From 7bad56b74e057a13c5577d512adf61bc82e2af86 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 10 May 2020 21:43:33 -0700 Subject: [PATCH] [LINT] clang-format the h,cc,m files. (#5557) This PR prepares for our migration to use the clang-format as part of the linter system. --- 3rdparty/bfloat16/bfloat16.cc | 4 +- 3rdparty/cma/cma.h | 17 +- 3rdparty/cma/cma_api_impl.h | 62 +- 3rdparty/compiler-rt/builtin_fp16.h | 82 +- apps/android_camera/app/src/main/jni/tvm_runtime.h | 20 +- apps/android_deploy/app/src/main/jni/tvm_runtime.h | 16 +- apps/android_rpc/app/src/main/jni/tvm_runtime.h | 20 +- apps/bundle_deploy/bundle.cc | 42 +- apps/bundle_deploy/bundle.h | 17 +- apps/bundle_deploy/demo.cc | 55 +- apps/bundle_deploy/runtime.cc | 14 +- apps/bundle_deploy/test.cc | 54 +- apps/cpp_rpc/main.cc | 62 +- apps/cpp_rpc/rpc_env.cc | 78 +- apps/cpp_rpc/rpc_env.h | 5 +- apps/cpp_rpc/rpc_server.cc | 58 +- apps/cpp_rpc/rpc_server.h | 11 +- apps/cpp_rpc/rpc_tracker_client.h | 92 +- apps/cpp_rpc/win32_process.cc | 104 +- apps/cpp_rpc/win32_process.h | 13 +- apps/dso_plugin_module/plugin_module.cc | 32 +- apps/extension/src/tvm_ext.cc | 103 +- apps/howto_deploy/cpp_deploy.cc | 14 +- apps/howto_deploy/tvm_runtime_pack.cc | 10 +- apps/ios_rpc/tvmrpc/AppDelegate.h | 3 +- apps/ios_rpc/tvmrpc/TVMRuntime.h | 9 +- apps/ios_rpc/tvmrpc/TVMRuntime.mm | 99 +- apps/ios_rpc/tvmrpc/ViewController.h | 21 +- apps/ios_rpc/tvmrpc/ViewController.mm | 22 +- apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm | 5 +- apps/rocm_rpc/rocm_runtime_pack.cc | 8 +- golang/src/gotvm.cc | 59 +- golang/src/gotvm.h | 2 +- golang/src/tvm_runtime_pack.cc | 10 +- include/tvm/arith/analyzer.h | 26 +- include/tvm/arith/bound.h | 15 +- include/tvm/arith/int_set.h | 31 +- include/tvm/arith/int_solver.h | 50 +- include/tvm/arith/pattern.h | 8 +- include/tvm/driver/driver_api.h | 59 +- include/tvm/ir/adt.h | 25 +- include/tvm/ir/attrs.h | 210 ++- include/tvm/ir/env_func.h | 22 +- include/tvm/ir/error.h | 16 +- include/tvm/ir/expr.h | 80 +- include/tvm/ir/function.h | 22 +- include/tvm/ir/module.h | 24 +- include/tvm/ir/op.h | 77 +- include/tvm/ir/span.h | 10 +- include/tvm/ir/tensor_type.h | 6 +- include/tvm/ir/transform.h | 31 +- include/tvm/ir/type.h | 68 +- include/tvm/ir/type_functor.h | 31 +- include/tvm/ir/type_relation.h | 37 +- include/tvm/node/container.h | 221 ++- include/tvm/node/functor.h | 45 +- include/tvm/node/node.h | 26 +- include/tvm/node/reflection.h | 128 +- include/tvm/node/repr_printer.h | 1 + include/tvm/node/structural_equal.h | 42 +- include/tvm/node/structural_hash.h | 58 +- include/tvm/relay/adt.h | 34 +- include/tvm/relay/analysis.h | 10 +- include/tvm/relay/attrs/algorithm.h | 45 +- include/tvm/relay/attrs/annotation.h | 13 +- include/tvm/relay/attrs/bitserial.h | 26 +- include/tvm/relay/attrs/debug.h | 5 +- include/tvm/relay/attrs/device_copy.h | 11 +- include/tvm/relay/attrs/image.h | 114 +- include/tvm/relay/attrs/memory.h | 36 +- include/tvm/relay/attrs/nn.h | 1112 ++++++++------- include/tvm/relay/attrs/reduce.h | 14 +- include/tvm/relay/attrs/transform.h | 153 +-- include/tvm/relay/attrs/vision.h | 100 +- include/tvm/relay/base.h | 30 +- include/tvm/relay/expr.h | 62 +- include/tvm/relay/expr_functor.h | 66 +- include/tvm/relay/feature.h | 20 +- include/tvm/relay/function.h | 17 +- include/tvm/relay/interpreter.h | 18 +- include/tvm/relay/op.h | 5 +- include/tvm/relay/op_attr_types.h | 66 +- include/tvm/relay/op_strategy.h | 23 +- include/tvm/relay/pattern_functor.h | 34 +- include/tvm/relay/qnn/attrs.h | 38 +- include/tvm/relay/qnn/transform.h | 2 +- include/tvm/relay/transform.h | 46 +- include/tvm/relay/type.h | 8 +- include/tvm/runtime/c_backend_api.h | 35 +- include/tvm/runtime/c_runtime_api.h | 101 +- include/tvm/runtime/container.h | 83 +- include/tvm/runtime/crt/memory.h | 6 +- include/tvm/runtime/data_type.h | 188 ++- include/tvm/runtime/device_api.h | 77 +- include/tvm/runtime/memory.h | 53 +- include/tvm/runtime/module.h | 36 +- include/tvm/runtime/ndarray.h | 120 +- include/tvm/runtime/object.h | 316 ++--- include/tvm/runtime/packed_func.h | 546 +++----- include/tvm/runtime/registry.h | 42 +- include/tvm/runtime/serializer.h | 12 +- include/tvm/runtime/threading_backend.h | 34 +- include/tvm/runtime/vm.h | 39 +- include/tvm/support/logging.h | 29 +- include/tvm/support/with.h | 12 +- include/tvm/target/codegen.h | 8 +- include/tvm/target/generic_func.h | 33 +- include/tvm/target/target.h | 58 +- include/tvm/target/target_info.h | 1 + include/tvm/te/autodiff.h | 19 +- include/tvm/te/operation.h | 338 ++--- include/tvm/te/schedule.h | 99 +- include/tvm/te/schedule_pass.h | 9 +- include/tvm/te/tensor.h | 68 +- include/tvm/te/tensor_intrin.h | 19 +- include/tvm/tir/analysis.h | 12 +- include/tvm/tir/buffer.h | 35 +- include/tvm/tir/data_layout.h | 39 +- include/tvm/tir/expr.h | 175 +-- include/tvm/tir/expr_functor.h | 27 +- include/tvm/tir/function.h | 17 +- include/tvm/tir/op.h | 101 +- include/tvm/tir/stmt.h | 216 +-- include/tvm/tir/stmt_functor.h | 83 +- include/tvm/tir/transform.h | 43 +- include/tvm/tir/var.h | 67 +- jvm/native/src/main/native/jni_helper_func.h | 54 +- .../src/main/native/org_apache_tvm_native_c_api.cc | 262 ++-- nnvm/include/nnvm/base.h | 8 +- nnvm/include/nnvm/c_api.h | 129 +- nnvm/include/nnvm/graph.h | 95 +- nnvm/include/nnvm/graph_attr_types.h | 9 +- nnvm/include/nnvm/layout.h | 124 +- nnvm/include/nnvm/node.h | 63 +- nnvm/include/nnvm/op.h | 180 ++- nnvm/include/nnvm/op_attr_types.h | 56 +- nnvm/include/nnvm/pass.h | 23 +- nnvm/include/nnvm/pass_functions.h | 50 +- nnvm/include/nnvm/symbolic.h | 20 +- nnvm/include/nnvm/tuple.h | 189 ++- nnvm/src/c_api/c_api_common.h | 31 +- nnvm/src/c_api/c_api_error.cc | 13 +- nnvm/src/c_api/c_api_graph.cc | 39 +- nnvm/src/c_api/c_api_symbolic.cc | 157 +-- nnvm/src/core/graph.cc | 105 +- nnvm/src/core/op.cc | 22 +- nnvm/src/core/pass.cc | 19 +- nnvm/src/core/symbolic.cc | 245 ++-- nnvm/src/pass/correct_layout.cc | 29 +- nnvm/src/pass/gradient.cc | 62 +- nnvm/src/pass/graph_algorithm.h | 23 +- nnvm/src/pass/infer_shape_type.cc | 90 +- nnvm/src/pass/order_mutation.cc | 60 +- nnvm/src/pass/place_device.cc | 55 +- nnvm/src/pass/plan_memory.cc | 82 +- nnvm/src/pass/print_graph_ir.cc | 52 +- nnvm/src/pass/saveload_json.cc | 82 +- nnvm/tests/cpp/op_test.cc | 17 +- nnvm/tests/cpp/tuple_test.cc | 8 +- src/arith/analyzer.cc | 109 +- src/arith/bound_deducer.cc | 54 +- src/arith/canonical_simplify.cc | 149 +- src/arith/compute_expr.h | 28 +- src/arith/const_fold.h | 352 +++-- src/arith/const_int_bound.cc | 83 +- src/arith/detect_linear_equation.cc | 36 +- src/arith/domain_touched.cc | 31 +- src/arith/int_constraints.cc | 41 +- src/arith/int_operator.h | 60 +- src/arith/int_set.cc | 328 ++--- src/arith/interval_set.h | 32 +- src/arith/ir_mutator_with_analyzer.cc | 70 +- src/arith/ir_mutator_with_analyzer.h | 8 +- src/arith/ir_visitor_with_analyzer.h | 13 +- src/arith/modular_set.cc | 101 +- src/arith/pattern_match.h | 340 ++--- src/arith/rewrite_simplify.cc | 912 +++++-------- src/arith/rewrite_simplify.h | 22 +- src/arith/solve_linear_equation.cc | 108 +- src/arith/util.cc | 4 +- src/autotvm/feature_visitor.cc | 10 +- src/autotvm/feature_visitor.h | 28 +- src/autotvm/touch_extractor.cc | 184 ++- src/autotvm/touch_extractor.h | 40 +- src/contrib/hybrid/codegen_hybrid.cc | 94 +- src/contrib/hybrid/codegen_hybrid.h | 88 +- src/driver/driver_api.cc | 133 +- src/ir/adt.cc | 46 +- src/ir/attr_functor.h | 14 +- src/ir/attrs.cc | 30 +- src/ir/env_func.cc | 44 +- src/ir/error.cc | 29 +- src/ir/expr.cc | 179 ++- src/ir/function.cc | 36 +- src/ir/module.cc | 174 +-- src/ir/op.cc | 161 +-- src/ir/span.cc | 36 +- src/ir/tensor_type.cc | 17 +- src/ir/transform.cc | 223 ++- src/ir/type.cc | 120 +- src/ir/type_functor.cc | 60 +- src/ir/type_relation.cc | 36 +- src/node/container.cc | 379 +++-- src/node/reflection.cc | 135 +- src/node/repr_printer.cc | 13 +- src/node/serialization.cc | 114 +- src/node/structural_equal.cc | 42 +- src/node/structural_hash.cc | 46 +- src/printer/doc.cc | 37 +- src/printer/doc.h | 15 +- src/printer/meta_data.h | 15 +- src/printer/relay_text_printer.cc | 136 +- src/printer/text_printer.cc | 17 +- src/printer/text_printer.h | 60 +- src/printer/tir_text_printer.cc | 79 +- src/relay/analysis/annotated_region_set.cc | 29 +- src/relay/analysis/annotated_region_set.h | 82 +- src/relay/analysis/call_graph.cc | 83 +- src/relay/analysis/call_graph.h | 74 +- src/relay/analysis/dependency_graph.cc | 15 +- src/relay/analysis/dependency_graph.h | 6 +- src/relay/analysis/feature.cc | 40 +- src/relay/analysis/kind_check.cc | 69 +- src/relay/analysis/mac_count.cc | 76 +- src/relay/analysis/match_exhaustion.cc | 43 +- src/relay/analysis/type_solver.cc | 174 +-- src/relay/analysis/type_solver.h | 10 +- src/relay/analysis/util.cc | 148 +- src/relay/analysis/well_formed.cc | 21 +- src/relay/backend/build_module.cc | 137 +- src/relay/backend/compile_engine.cc | 211 ++- src/relay/backend/compile_engine.h | 33 +- src/relay/backend/contrib/codegen_c/codegen.cc | 4 +- src/relay/backend/contrib/codegen_c/codegen_c.h | 16 +- src/relay/backend/contrib/dnnl/codegen.cc | 2 +- src/relay/backend/graph_plan_memory.cc | 49 +- src/relay/backend/graph_runtime_codegen.cc | 88 +- src/relay/backend/interpreter.cc | 197 ++- src/relay/backend/param_dict.cc | 127 +- src/relay/backend/param_dict.h | 2 +- src/relay/backend/utils.h | 1 - src/relay/backend/vm/compiler.cc | 250 ++-- src/relay/backend/vm/compiler.h | 19 +- src/relay/backend/vm/inline_primitives.cc | 21 +- src/relay/backend/vm/lambda_lift.cc | 31 +- src/relay/backend/vm/removed_unused_funcs.cc | 20 +- src/relay/ir/adt.cc | 81 +- src/relay/ir/base.cc | 5 +- src/relay/ir/expr.cc | 149 +- src/relay/ir/expr_functor.cc | 112 +- src/relay/ir/function.cc | 36 +- src/relay/ir/op_strategy.cc | 64 +- src/relay/ir/pattern_functor.cc | 32 +- src/relay/ir/transform.cc | 45 +- src/relay/op/algorithm/argsort.cc | 30 +- src/relay/op/algorithm/topk.cc | 30 +- src/relay/op/annotation/annotation.cc | 255 ++-- src/relay/op/debug.cc | 35 +- src/relay/op/device_copy.cc | 34 +- src/relay/op/image/dilation2d.cc | 58 +- src/relay/op/image/resize.cc | 88 +- src/relay/op/memory/memory.cc | 17 +- src/relay/op/nn/bitserial.cc | 48 +- src/relay/op/nn/convolution.cc | 580 +++----- src/relay/op/nn/convolution.h | 189 +-- src/relay/op/nn/nn.cc | 626 ++++----- src/relay/op/nn/nn.h | 8 +- src/relay/op/nn/pad.cc | 135 +- src/relay/op/nn/pooling.cc | 630 ++++----- src/relay/op/nn/sparse.cc | 56 +- src/relay/op/nn/upsampling.cc | 115 +- src/relay/op/op_common.h | 101 +- src/relay/op/tensor/binary.cc | 183 ++- src/relay/op/tensor/reduce.cc | 346 ++--- src/relay/op/tensor/transform.cc | 1444 ++++++++------------ src/relay/op/tensor/transform.h | 45 +- src/relay/op/tensor/unary.cc | 341 ++--- src/relay/op/type_relations.cc | 46 +- src/relay/op/type_relations.h | 19 +- src/relay/op/vision/multibox_op.cc | 86 +- src/relay/op/vision/nms.cc | 66 +- src/relay/op/vision/rcnn_op.cc | 51 +- src/relay/op/vision/yolo.cc | 55 +- src/relay/qnn/op/add.cc | 24 +- src/relay/qnn/op/concatenate.cc | 55 +- src/relay/qnn/op/convolution.cc | 46 +- src/relay/qnn/op/dense.cc | 36 +- src/relay/qnn/op/dequantize.cc | 31 +- src/relay/qnn/op/mul.cc | 15 +- src/relay/qnn/op/op_common.h | 83 +- src/relay/qnn/op/quantize.cc | 36 +- src/relay/qnn/op/requantize.cc | 48 +- src/relay/qnn/op/subtract.cc | 27 +- src/relay/qnn/util.cc | 13 +- src/relay/qnn/util.h | 16 +- src/relay/quantize/annotate.cc | 38 +- src/relay/quantize/calibrate.cc | 47 +- src/relay/quantize/partition.cc | 22 +- src/relay/quantize/quantize.cc | 107 +- src/relay/quantize/quantize.h | 46 +- src/relay/quantize/realize.cc | 139 +- src/relay/transforms/alter_op_layout.cc | 25 +- src/relay/transforms/annotate_target.cc | 3 +- src/relay/transforms/canonicalize_cast.cc | 19 +- src/relay/transforms/canonicalize_ops.cc | 12 +- src/relay/transforms/combine_parallel_conv2d.cc | 44 +- src/relay/transforms/combine_parallel_dense.cc | 21 +- src/relay/transforms/combine_parallel_op.cc | 84 +- src/relay/transforms/combine_parallel_op.h | 24 +- src/relay/transforms/combine_parallel_op_batch.cc | 53 +- src/relay/transforms/combine_parallel_op_batch.h | 25 +- src/relay/transforms/convert_layout.cc | 15 +- src/relay/transforms/convert_sparse_dense.cc | 12 +- src/relay/transforms/de_duplicate.cc | 31 +- src/relay/transforms/dead_code.cc | 37 +- src/relay/transforms/device_annotation.cc | 61 +- src/relay/transforms/eliminate_common_subexpr.cc | 12 +- src/relay/transforms/eta_expand.cc | 30 +- src/relay/transforms/expr_subst.cc | 3 +- src/relay/transforms/expr_subst.h | 4 +- src/relay/transforms/fast_math.cc | 17 +- src/relay/transforms/fold_constant.cc | 54 +- src/relay/transforms/fold_scale_axis.cc | 273 ++-- src/relay/transforms/forward_rewrite.cc | 30 +- src/relay/transforms/fuse_ops.cc | 175 +-- src/relay/transforms/gradient.cc | 251 ++-- src/relay/transforms/infer_layout_util.h | 50 +- src/relay/transforms/inline.cc | 35 +- src/relay/transforms/lazy_gradient_init.cc | 134 +- src/relay/transforms/legalize.cc | 2 +- src/relay/transforms/let_list.h | 22 +- src/relay/transforms/merge_composite.cc | 3 +- src/relay/transforms/partial_eval.cc | 412 +++--- src/relay/transforms/pass_util.h | 35 +- src/relay/transforms/pattern_util.h | 142 +- src/relay/transforms/simplify_fc_transpose.cc | 12 +- src/relay/transforms/simplify_inference.cc | 77 +- src/relay/transforms/to_a_normal_form.cc | 72 +- src/relay/transforms/to_cps.cc | 126 +- src/relay/transforms/to_graph_normal_form.cc | 24 +- src/relay/transforms/transform_layout.h | 15 +- src/relay/transforms/type_infer.cc | 353 ++--- src/runtime/builtin_fp16.cc | 6 +- src/runtime/c_runtime_api.cc | 243 ++-- src/runtime/container.cc | 24 +- src/runtime/contrib/cblas/cblas.cc | 31 +- src/runtime/contrib/cblas/gemm_common.h | 81 +- src/runtime/contrib/coreml/coreml_runtime.h | 24 +- src/runtime/contrib/coreml/coreml_runtime.mm | 86 +- src/runtime/contrib/cublas/cublas.cc | 356 ++--- src/runtime/contrib/cublas/cublas_utils.cc | 11 +- src/runtime/contrib/cublas/cublas_utils.h | 62 +- src/runtime/contrib/cudnn/conv_forward.cc | 433 +++--- src/runtime/contrib/cudnn/cudnn_utils.cc | 68 +- src/runtime/contrib/cudnn/cudnn_utils.h | 24 +- src/runtime/contrib/cudnn/softmax.cc | 100 +- src/runtime/contrib/dnnl/dnnl.cc | 34 +- src/runtime/contrib/dnnl/dnnl_kernel.h | 1 + src/runtime/contrib/edgetpu/edgetpu_runtime.cc | 24 +- src/runtime/contrib/edgetpu/edgetpu_runtime.h | 9 +- .../example_ext_runtime/example_ext_runtime.cc | 30 +- src/runtime/contrib/miopen/conv_forward.cc | 226 ++- src/runtime/contrib/miopen/miopen_utils.cc | 22 +- src/runtime/contrib/miopen/miopen_utils.h | 19 +- src/runtime/contrib/mps/conv.mm | 127 +- src/runtime/contrib/mps/gemm.mm | 64 +- src/runtime/contrib/mps/mps_utils.h | 17 +- src/runtime/contrib/mps/mps_utils.mm | 78 +- src/runtime/contrib/nnpack/convolution.cc | 144 +- src/runtime/contrib/nnpack/fully_connected.cc | 52 +- src/runtime/contrib/nnpack/nnpack_utils.cc | 15 +- src/runtime/contrib/nnpack/nnpack_utils.h | 6 +- src/runtime/contrib/random/mt_random_engine.cc | 77 +- src/runtime/contrib/random/random.cc | 117 +- src/runtime/contrib/rocblas/rocblas.cc | 103 +- src/runtime/contrib/sort/sort.cc | 104 +- src/runtime/contrib/tflite/tflite_runtime.cc | 154 +-- src/runtime/contrib/tflite/tflite_runtime.h | 16 +- src/runtime/cpu_device_api.cc | 47 +- src/runtime/crt/graph_runtime.h | 80 +- src/runtime/crt/load_json.h | 42 +- src/runtime/crt/logging.h | 30 +- src/runtime/crt/module.h | 4 +- src/runtime/crt/ndarray.h | 23 +- src/runtime/crt/packed_func.h | 43 +- src/runtime/cuda/cuda_common.h | 16 +- src/runtime/cuda/cuda_device_api.cc | 119 +- src/runtime/cuda/cuda_module.cc | 118 +- src/runtime/cuda/cuda_module.h | 12 +- src/runtime/dso_library.cc | 39 +- src/runtime/file_util.cc | 32 +- src/runtime/file_util.h | 24 +- src/runtime/graph/debug/graph_runtime_debug.cc | 112 +- src/runtime/graph/graph_runtime.cc | 155 +-- src/runtime/graph/graph_runtime.h | 119 +- src/runtime/hexagon/hexagon_device_api.cc | 45 +- src/runtime/hexagon/hexagon_module.cc | 91 +- src/runtime/hexagon/hexagon_module.h | 18 +- src/runtime/hexagon/hexagon_posix.cc | 6 +- src/runtime/hexagon/sim/hexagon_device_sim.cc | 109 +- .../hexagon/target/fastrpc/src/tvm_remote_imp.cc | 61 +- .../target/fastrpc/src/tvm_remote_nd_imp.cc | 52 +- .../hexagon/target/fastrpc/src/tvm_wrap_pthread.cc | 9 +- .../hexagon/target/hexagon_device_target.cc | 115 +- src/runtime/hexagon/target/hexagon_stubapi.cc | 3 +- src/runtime/hexagon/target/hexagon_stubapi.h | 15 +- src/runtime/hexagon/target/hexagon_target_log.h | 18 +- src/runtime/library_module.cc | 84 +- src/runtime/library_module.h | 9 +- src/runtime/meta_data.h | 12 +- src/runtime/metal/metal_common.h | 35 +- src/runtime/metal/metal_device_api.mm | 177 +-- src/runtime/metal/metal_module.h | 15 +- src/runtime/metal/metal_module.mm | 139 +- src/runtime/micro/host_driven/utvm_runtime.h | 6 +- src/runtime/micro/host_low_level_device.cc | 16 +- src/runtime/micro/low_level_device.h | 8 +- src/runtime/micro/micro_common.cc | 81 +- src/runtime/micro/micro_common.h | 65 +- src/runtime/micro/micro_device_api.cc | 41 +- src/runtime/micro/micro_module.cc | 35 +- src/runtime/micro/micro_section_allocator.h | 28 +- src/runtime/micro/micro_session.cc | 474 +++---- src/runtime/micro/micro_session.h | 161 +-- src/runtime/micro/openocd_low_level_device.cc | 38 +- src/runtime/micro/standalone/minimal_vector.h | 1 - src/runtime/micro/standalone/utvm_graph_runtime.cc | 2 + src/runtime/micro/standalone/utvm_runtime.cc | 8 +- src/runtime/micro/standalone/utvm_runtime_api.cc | 1 + src/runtime/micro/standalone/utvm_runtime_api.h | 1 + src/runtime/micro/target_data_layout_encoder.h | 30 +- src/runtime/micro/tcl_socket.cc | 14 +- src/runtime/micro/tcl_socket.h | 4 +- src/runtime/module.cc | 51 +- src/runtime/ndarray.cc | 125 +- src/runtime/object.cc | 50 +- src/runtime/object_internal.h | 5 +- src/runtime/opencl/aocl/aocl_common.h | 2 +- src/runtime/opencl/aocl/aocl_device_api.cc | 20 +- src/runtime/opencl/aocl/aocl_module.cc | 26 +- src/runtime/opencl/aocl/aocl_module.h | 11 +- src/runtime/opencl/opencl_common.h | 222 +-- src/runtime/opencl/opencl_device_api.cc | 140 +- src/runtime/opencl/opencl_module.cc | 77 +- src/runtime/opencl/opencl_module.h | 11 +- src/runtime/opencl/sdaccel/sdaccel_common.h | 6 +- src/runtime/opencl/sdaccel/sdaccel_device_api.cc | 28 +- src/runtime/opencl/sdaccel/sdaccel_module.cc | 29 +- src/runtime/opencl/sdaccel/sdaccel_module.h | 15 +- src/runtime/opengl/opengl_common.h | 110 +- src/runtime/opengl/opengl_device_api.cc | 158 +-- src/runtime/opengl/opengl_module.cc | 149 +- src/runtime/opengl/opengl_module.h | 24 +- src/runtime/pack_args.h | 71 +- src/runtime/registry.cc | 29 +- src/runtime/rocm/rocm_common.h | 30 +- src/runtime/rocm/rocm_device_api.cc | 73 +- src/runtime/rocm/rocm_module.cc | 127 +- src/runtime/rocm/rocm_module.h | 17 +- src/runtime/rpc/minrpc/minrpc_server.h | 101 +- src/runtime/rpc/minrpc/posix_popen_server.cc | 17 +- src/runtime/rpc/rpc_channel.cc | 3 +- src/runtime/rpc/rpc_channel.h | 1 + src/runtime/rpc/rpc_device_api.cc | 66 +- src/runtime/rpc/rpc_endpoint.cc | 523 +++---- src/runtime/rpc/rpc_endpoint.h | 51 +- src/runtime/rpc/rpc_event_impl.cc | 18 +- src/runtime/rpc/rpc_local_session.cc | 51 +- src/runtime/rpc/rpc_local_session.h | 33 +- src/runtime/rpc/rpc_module.cc | 185 +-- src/runtime/rpc/rpc_pipe_impl.cc | 27 +- src/runtime/rpc/rpc_protocol.h | 108 +- src/runtime/rpc/rpc_server_env.cc | 45 +- src/runtime/rpc/rpc_session.cc | 62 +- src/runtime/rpc/rpc_session.h | 68 +- src/runtime/rpc/rpc_socket_impl.cc | 66 +- src/runtime/runtime_base.h | 18 +- src/runtime/stackvm/stackvm.cc | 304 +++-- src/runtime/stackvm/stackvm.h | 58 +- src/runtime/stackvm/stackvm_module.cc | 45 +- src/runtime/stackvm/stackvm_module.h | 9 +- src/runtime/system_library.cc | 20 +- src/runtime/thread_pool.cc | 107 +- src/runtime/thread_storage_scope.h | 64 +- src/runtime/threading_backend.cc | 75 +- src/runtime/vm/executable.cc | 85 +- src/runtime/vm/memory_manager.cc | 17 +- src/runtime/vm/memory_manager.h | 11 +- src/runtime/vm/naive_allocator.h | 5 +- src/runtime/vm/pooled_allocator.h | 1 + src/runtime/vm/profiler/vm.cc | 34 +- src/runtime/vm/profiler/vm.h | 7 +- src/runtime/vm/serialize_util.h | 17 +- src/runtime/vm/vm.cc | 121 +- src/runtime/vulkan/vulkan.cc | 31 +- src/runtime/vulkan/vulkan_common.h | 3 +- src/runtime/vulkan/vulkan_shader.h | 1 - src/runtime/vulkan/vulkan_stream.h | 6 +- src/runtime/workspace_pool.cc | 15 +- src/runtime/workspace_pool.h | 7 +- src/support/arena.h | 29 +- src/support/base64.h | 100 +- src/support/ffi_testing.cc | 98 +- src/support/pipe.h | 30 +- src/support/ring_buffer.h | 40 +- src/support/socket.h | 187 ++- src/support/str_escape.h | 6 +- src/support/util.h | 13 +- src/target/build_common.h | 20 +- src/target/codegen.cc | 71 +- src/target/datatype/registry.cc | 18 +- src/target/datatype/registry.h | 3 +- src/target/generic_func.cc | 59 +- src/target/intrin_rule.cc | 130 +- src/target/intrin_rule.h | 13 +- src/target/llvm/codegen_amdgpu.cc | 115 +- src/target/llvm/codegen_arm.cc | 45 +- src/target/llvm/codegen_blob.cc | 72 +- src/target/llvm/codegen_blob.h | 9 +- src/target/llvm/codegen_cpu.cc | 489 +++---- src/target/llvm/codegen_cpu.h | 24 +- src/target/llvm/codegen_llvm.cc | 444 +++--- src/target/llvm/codegen_llvm.h | 58 +- src/target/llvm/codegen_nvptx.cc | 100 +- src/target/llvm/codegen_x86_64.cc | 35 +- src/target/llvm/intrin_rule_llvm.cc | 175 ++- src/target/llvm/intrin_rule_llvm.h | 15 +- src/target/llvm/intrin_rule_nvptx.cc | 73 +- src/target/llvm/intrin_rule_rocm.cc | 72 +- src/target/llvm/llvm_common.cc | 54 +- src/target/llvm/llvm_common.h | 45 +- src/target/llvm/llvm_module.cc | 191 ++- src/target/opt/build_aocl_off.cc | 9 +- src/target/opt/build_cuda_off.cc | 12 +- src/target/opt/build_cuda_on.cc | 42 +- src/target/opt/build_hexagon_off.cc | 5 +- src/target/opt/build_metal_off.cc | 8 +- src/target/opt/build_opencl_off.cc | 9 +- src/target/opt/build_opengl_off.cc | 5 +- src/target/opt/build_rocm_off.cc | 15 +- src/target/opt/build_sdaccel_off.cc | 9 +- src/target/source/codegen_aocl.cc | 33 +- src/target/source/codegen_c.cc | 273 ++-- src/target/source/codegen_c.h | 123 +- src/target/source/codegen_c_host.cc | 97 +- src/target/source/codegen_c_host.h | 22 +- src/target/source/codegen_cuda.cc | 267 ++-- src/target/source/codegen_cuda.h | 48 +- src/target/source/codegen_metal.cc | 101 +- src/target/source/codegen_metal.h | 23 +- src/target/source/codegen_opencl.cc | 130 +- src/target/source/codegen_opencl.h | 28 +- src/target/source/codegen_opengl.cc | 43 +- src/target/source/codegen_opengl.h | 12 +- src/target/source/codegen_source_base.cc | 6 +- src/target/source/codegen_source_base.h | 18 +- src/target/source/codegen_vhls.cc | 81 +- src/target/source/codegen_vhls.h | 6 +- src/target/source/intrin_rule_aocl.cc | 68 +- src/target/source/intrin_rule_cuda.cc | 119 +- src/target/source/intrin_rule_metal.cc | 60 +- src/target/source/intrin_rule_opencl.cc | 69 +- src/target/source/intrin_rule_opengl.cc | 48 +- src/target/source/intrin_rule_vhls.cc | 57 +- src/target/source/source_module.cc | 86 +- src/target/spirv/build_vulkan.cc | 52 +- src/target/spirv/codegen_spirv.cc | 179 +-- src/target/spirv/codegen_spirv.h | 27 +- src/target/spirv/intrin_rule_spirv.cc | 62 +- src/target/spirv/ir_builder.cc | 214 ++- src/target/spirv/ir_builder.h | 85 +- src/target/stackvm/codegen_stackvm.cc | 117 +- src/target/stackvm/codegen_stackvm.h | 25 +- src/target/target.cc | 250 ++-- src/target/target_info.cc | 18 +- src/te/autodiff/ad_util.cc | 11 +- src/te/autodiff/ad_util.h | 5 +- src/te/autodiff/adjoint.cc | 107 +- src/te/autodiff/jacobian.cc | 120 +- src/te/operation/compute_op.cc | 265 ++-- src/te/operation/compute_op.h | 34 +- src/te/operation/cross_thread_reduction.cc | 62 +- src/te/operation/extern_op.cc | 103 +- src/te/operation/hybrid_op.cc | 269 ++-- src/te/operation/hybrid_op.h | 27 +- src/te/operation/op_util.cc | 141 +- src/te/operation/op_util.h | 29 +- src/te/operation/placeholder_op.cc | 63 +- src/te/operation/scan_op.cc | 141 +- src/te/operation/tensor_compute_op.cc | 96 +- src/te/operation/tensorize.cc | 225 ++- src/te/schedule/auto_inline_elem_wise.cc | 10 +- src/te/schedule/bound.cc | 58 +- src/te/schedule/graph.cc | 113 +- src/te/schedule/graph.h | 11 +- src/te/schedule/message_passing.cc | 149 +- src/te/schedule/message_passing.h | 46 +- src/te/schedule/operation_inline.cc | 19 +- src/te/schedule/operation_inline.h | 9 +- src/te/schedule/schedule_dataflow_rewrite.cc | 302 ++-- src/te/schedule/schedule_lang.cc | 515 +++---- src/te/schedule/schedule_ops.cc | 144 +- .../schedule_postproc_rewrite_for_tensor_core.cc | 434 +++--- src/te/schedule/schedule_postproc_to_primfunc.cc | 35 +- src/te/schedule/verify_compact_buffer.cc | 7 +- src/te/tensor.cc | 100 +- src/tir/analysis/deep_equal.cc | 18 +- src/tir/analysis/side_effect.cc | 5 +- src/tir/analysis/var_touch.cc | 14 +- src/tir/analysis/verify_gpu_code.cc | 37 +- src/tir/analysis/verify_memory.cc | 90 +- src/tir/analysis/verify_ssa.cc | 20 +- src/tir/ir/buffer.cc | 170 +-- src/tir/ir/data_layout.cc | 148 +- src/tir/ir/expr.cc | 721 +++++----- src/tir/ir/expr_functor.cc | 89 +- src/tir/ir/function.cc | 45 +- src/tir/ir/functor_common.h | 8 +- src/tir/ir/op.cc | 279 ++-- src/tir/ir/stmt.cc | 634 ++++----- src/tir/ir/stmt_functor.cc | 168 +-- src/tir/ir/transform.cc | 33 +- src/tir/pass/hoist_if_then_else.cc | 147 +- src/tir/transforms/arg_binder.cc | 138 +- src/tir/transforms/arg_binder.h | 52 +- src/tir/transforms/bound_checker.cc | 62 +- src/tir/transforms/combine_context_call.cc | 23 +- src/tir/transforms/coproc_sync.cc | 162 +-- src/tir/transforms/decorate_device_scope.cc | 10 +- src/tir/transforms/inject_copy_intrin.cc | 77 +- src/tir/transforms/inject_double_buffer.cc | 61 +- src/tir/transforms/inject_prefetch.cc | 19 +- src/tir/transforms/inject_virtual_thread.cc | 109 +- src/tir/transforms/ir_util.cc | 38 +- src/tir/transforms/ir_util.h | 42 +- src/tir/transforms/lift_attr_scope.cc | 45 +- src/tir/transforms/loop_partition.cc | 124 +- src/tir/transforms/lower_custom_datatypes.cc | 41 +- .../transforms/lower_device_storage_access_info.cc | 42 +- src/tir/transforms/lower_intrin.cc | 57 +- src/tir/transforms/lower_thread_allreduce.cc | 147 +- src/tir/transforms/lower_tvm_builtin.cc | 170 +-- src/tir/transforms/lower_warp_memory.cc | 109 +- src/tir/transforms/make_packed_api.cc | 97 +- src/tir/transforms/narrow_datatype.cc | 119 +- src/tir/transforms/remap_thread_axis.cc | 24 +- src/tir/transforms/remove_no_op.cc | 14 +- src/tir/transforms/rewrite_unsafe_select.cc | 44 +- src/tir/transforms/simplify.cc | 24 +- src/tir/transforms/skip_assert.cc | 11 +- src/tir/transforms/split_host_device.cc | 77 +- src/tir/transforms/storage_access.cc | 19 +- src/tir/transforms/storage_access.h | 32 +- src/tir/transforms/storage_flatten.cc | 189 +-- src/tir/transforms/storage_rewrite.cc | 220 ++- src/tir/transforms/tensorcore_infer_fragment.cc | 31 +- src/tir/transforms/thread_storage_sync.cc | 95 +- src/tir/transforms/unroll_loop.cc | 66 +- src/tir/transforms/vectorize_loop.cc | 207 +-- tests/cpp/arith_simplify_test.cc | 6 +- tests/cpp/attrs_test.cc | 25 +- tests/cpp/build_module_test.cc | 45 +- tests/cpp/container_test.cc | 18 +- tests/cpp/crt_memory_test.cc | 8 +- tests/cpp/expr_test.cc | 4 +- tests/cpp/ir_functor_test.cc | 81 +- tests/cpp/object_protocol_test.cc | 5 +- tests/cpp/packed_func_test.cc | 191 ++- tests/cpp/pattern_match_test.cc | 44 +- tests/cpp/relay_build_module_test.cc | 76 +- tests/cpp/relay_pass_type_infer_test.cc | 16 +- tests/cpp/relay_transform_sequential.cc | 27 +- tests/cpp/simple_passes_test.cc | 3 +- tests/cpp/tensor_test.cc | 12 +- tests/cpp/threading_backend_test.cc | 6 +- tests/cpp/topi_ewise_test.cc | 8 +- tests/cpp/utvm_runtime_standalone_test.cc | 11 +- topi/include/topi/broadcast.h | 82 +- topi/include/topi/contrib/cublas.h | 84 +- topi/include/topi/contrib/rocblas.h | 40 +- topi/include/topi/cuda/dense.h | 65 +- topi/include/topi/cuda/injective.h | 8 +- topi/include/topi/cuda/normalization.h | 12 +- topi/include/topi/cuda/pooling.h | 42 +- topi/include/topi/cuda/reduction.h | 50 +- topi/include/topi/cuda/softmax.h | 8 +- topi/include/topi/detail/array_utils.h | 2 +- topi/include/topi/detail/broadcast.h | 30 +- topi/include/topi/detail/constant_utils.h | 25 +- topi/include/topi/detail/extern.h | 60 +- topi/include/topi/detail/pad_utils.h | 15 +- topi/include/topi/detail/ravel_unravel.h | 34 +- topi/include/topi/detail/tensor_utils.h | 7 +- topi/include/topi/elemwise.h | 516 ++++--- topi/include/topi/generic/default.h | 36 +- topi/include/topi/generic/extern.h | 22 +- topi/include/topi/generic/injective.h | 8 +- topi/include/topi/nn.h | 252 ++-- topi/include/topi/nn/batch_matmul.h | 25 +- topi/include/topi/nn/bias_add.h | 23 +- topi/include/topi/nn/bnn.h | 113 +- topi/include/topi/nn/dense.h | 43 +- topi/include/topi/nn/dilate.h | 89 +- topi/include/topi/nn/flatten.h | 52 +- topi/include/topi/nn/local_response_norm.h | 73 +- topi/include/topi/nn/mapping.h | 68 +- topi/include/topi/nn/pooling.h | 841 ++++++------ topi/include/topi/nn/softmax.h | 93 +- topi/include/topi/reduction.h | 362 +++-- topi/include/topi/rocm/dense.h | 64 +- topi/include/topi/rocm/injective.h | 6 +- topi/include/topi/rocm/normalization.h | 16 +- topi/include/topi/rocm/pooling.h | 42 +- topi/include/topi/rocm/reduction.h | 18 +- topi/include/topi/rocm/softmax.h | 6 +- topi/include/topi/tags.h | 14 +- topi/include/topi/transform.h | 1188 ++++++++-------- topi/include/topi/vision/reorg.h | 41 +- topi/include/topi/x86/bnn.h | 36 +- topi/include/topi/x86/default.h | 57 +- topi/include/topi/x86/injective.h | 22 +- topi/src/broadcast.cc | 49 +- topi/src/elemwise.cc | 181 +-- topi/src/nn.cc | 158 +-- topi/src/reduction.cc | 51 +- topi/src/schedule.cc | 302 ++-- topi/src/transform.cc | 126 +- topi/src/vision.cc | 14 +- vta/runtime/device_api.cc | 50 +- vta/runtime/runtime.cc | 595 +++----- vta/runtime/runtime.h | 67 +- web/emcc/tvmjs_support.cc | 123 +- web/emcc/wasm_runtime.cc | 45 +- web/emcc/webgpu_runtime.cc | 63 +- 734 files changed, 25112 insertions(+), 35054 deletions(-) diff --git a/3rdparty/bfloat16/bfloat16.cc b/3rdparty/bfloat16/bfloat16.cc index 56d05ef..674feb4 100644 --- a/3rdparty/bfloat16/bfloat16.cc +++ b/3rdparty/bfloat16/bfloat16.cc @@ -17,6 +17,7 @@ ==============================================================================*/ #include + #include #include @@ -50,8 +51,7 @@ void BFloat16ToFloat(const uint16_t* src, float* dst, size_t size) { #endif } -void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, - size_t size) { +void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, size_t size) { float a_f, b_f; BFloat16ToFloat(a, &a_f, 1); BFloat16ToFloat(b, &b_f, 1); diff --git a/3rdparty/cma/cma.h b/3rdparty/cma/cma.h index f005b30..2cd5501 100644 --- a/3rdparty/cma/cma.h +++ b/3rdparty/cma/cma.h @@ -27,20 +27,17 @@ #ifndef VTA_DE10_NANO_KERNEL_MODULE_CMA_H_ #define VTA_DE10_NANO_KERNEL_MODULE_CMA_H_ - /* Should be defined in settings.mk file */ #ifndef CMA_IOCTL_MAGIC -#define CMA_IOCTL_MAGIC 0xf2 +#define CMA_IOCTL_MAGIC 0xf2 #endif +#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 1, 4) +#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 2, 4) +#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) +#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 4, 4) +#define CMA_GET_SIZE _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 5, 4) -#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 1, 4) -#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 2, 4) -#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) -#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 4, 4) -#define CMA_GET_SIZE _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 5, 4) - -#define CMA_IOCTL_MAXNR 5 - +#define CMA_IOCTL_MAXNR 5 #endif // VTA_DE10_NANO_KERNEL_MODULE_CMA_H_ diff --git a/3rdparty/cma/cma_api_impl.h b/3rdparty/cma/cma_api_impl.h index 12c0e3b..317be5c 100644 --- a/3rdparty/cma/cma_api_impl.h +++ b/3rdparty/cma/cma_api_impl.h @@ -30,48 +30,47 @@ * \brief Application layer implementation for contigous memory allocation. */ +#include +#include #include #include -#include -#include -#include #include -#include #include #include +#include +#include #include "cma_api.h" #ifndef CMA_IOCTL_MAGIC -#define CMA_IOCTL_MAGIC 0xf2 +#define CMA_IOCTL_MAGIC 0xf2 #endif -#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 1, 4) -#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 2, 4) -#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) -#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 4, 4) -#define CMA_GET_SIZE _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 5, 4) +#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 1, 4) +#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 2, 4) +#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) +#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 4, 4) +#define CMA_GET_SIZE _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 5, 4) -#define CMA_IOCTL_MAXNR 5 +#define CMA_IOCTL_MAXNR 5 #ifndef CMA_DEBUG - #define CMA_DEBUG 0 +#define CMA_DEBUG 0 #endif #ifndef DRIVER_NODE_NAME - #define DRIVER_NODE_NAME "cma" +#define DRIVER_NODE_NAME "cma" #endif #if CMA_DEBUG == 1 - #define __DEBUG(fmt, args...) printf("CMA_API_DEBUG: " fmt, ##args) +#define __DEBUG(fmt, args...) printf("CMA_API_DEBUG: " fmt, ##args) #else - #define __DEBUG(fmt, args...) +#define __DEBUG(fmt, args...) #endif -#define ROUND_UP(N, S) ((((N) + (S) - 1) / (S)) * (S)) - +#define ROUND_UP(N, S) ((((N) + (S)-1) / (S)) * (S)) /* Private functions */ -void *cma_alloc(size_t size, unsigned ioctl_cmd); +void* cma_alloc(size_t size, unsigned ioctl_cmd); /* Global file descriptor */ int cma_fd = 0; @@ -99,23 +98,19 @@ int cma_release(void) { return 0; } -void *cma_alloc_cached(size_t size) { - return cma_alloc(size, CMA_ALLOC_CACHED); -} +void* cma_alloc_cached(size_t size) { return cma_alloc(size, CMA_ALLOC_CACHED); } -void *cma_alloc_noncached(size_t size) { - return cma_alloc(size, CMA_ALLOC_NONCACHED); -} +void* cma_alloc_noncached(size_t size) { return cma_alloc(size, CMA_ALLOC_NONCACHED); } -int cma_free(void *mem) { +int cma_free(void* mem) { __DEBUG("Releasing contigous memory from 0x%x\n", (unsigned)mem); unsigned data, v_addr; /* save user space pointer value */ - data = (unsigned)mem; + data = (unsigned)mem; v_addr = (unsigned)mem; - if ( ioctl(cma_fd, CMA_GET_SIZE, &data) == -1 ) { + if (ioctl(cma_fd, CMA_GET_SIZE, &data) == -1) { __DEBUG("cma_free - ioctl command unsuccsessful - 0\n"); return -1; } @@ -125,7 +120,7 @@ int cma_free(void *mem) { munmap(mem, data); /* free cma entry */ - if ( ioctl(cma_fd, CMA_FREE, &v_addr) == -1 ) { + if (ioctl(cma_fd, CMA_FREE, &v_addr) == -1) { __DEBUG("cma_free - ioctl command unsuccsessful - 1\n"); return -1; } @@ -133,7 +128,7 @@ int cma_free(void *mem) { return 0; } -unsigned cma_get_phy_addr(void *mem) { +unsigned cma_get_phy_addr(void* mem) { unsigned data; __DEBUG("Getting physical address from 0x%x\n", (unsigned)mem); @@ -141,7 +136,7 @@ unsigned cma_get_phy_addr(void *mem) { data = (unsigned)mem; /* get physical address */ - if ( ioctl(cma_fd, CMA_GET_PHY_ADDR, &data) == -1 ) { + if (ioctl(cma_fd, CMA_GET_PHY_ADDR, &data) == -1) { __DEBUG("cma_free - ioctl command unsuccsessful\n"); return 0; } @@ -150,10 +145,9 @@ unsigned cma_get_phy_addr(void *mem) { return data; } - -void *cma_alloc(size_t size, unsigned ioctl_cmd) { +void* cma_alloc(size_t size, unsigned ioctl_cmd) { unsigned data; - void *mem; + void* mem; __DEBUG("Allocating 0x%x bytes of contigous memory\n", size); /* Page align size */ @@ -161,7 +155,7 @@ void *cma_alloc(size_t size, unsigned ioctl_cmd) { /* ioctl cmd to allocate contigous memory */ data = (unsigned)size; - if ( ioctl(cma_fd, ioctl_cmd, &data) == -1 ) { + if (ioctl(cma_fd, ioctl_cmd, &data) == -1) { __DEBUG("cma_alloc - ioctl command unsuccsessful\n"); return NULL; } diff --git a/3rdparty/compiler-rt/builtin_fp16.h b/3rdparty/compiler-rt/builtin_fp16.h index fa8efdd..8048980 100644 --- a/3rdparty/compiler-rt/builtin_fp16.h +++ b/3rdparty/compiler-rt/builtin_fp16.h @@ -29,16 +29,33 @@ static inline uint32_t __clz(uint32_t x) { int n = 32; uint32_t y; - y = x >>16; if (y) { n = n -16; x = y; } - y = x >> 8; if (y) { n = n - 8; x = y; } - y = x >> 4; if (y) { n = n - 4; x = y; } - y = x >> 2; if (y) { n = n - 2; x = y; } - y = x >> 1; if (y) return n - 2; + y = x >> 16; + if (y) { + n = n - 16; + x = y; + } + y = x >> 8; + if (y) { + n = n - 8; + x = y; + } + y = x >> 4; + if (y) { + n = n - 4; + x = y; + } + y = x >> 2; + if (y) { + n = n - 2; + x = y; + } + y = x >> 1; + if (y) return n - 2; return n - x; } -template +template static inline DST_T __truncXfYf2__(SRC_T a) { // Various constants whose values follow from the type parameters. // Any reasonable optimizer will fold and propagate all of these. @@ -71,7 +88,10 @@ static inline DST_T __truncXfYf2__(SRC_T a) { const DST_REP_T dstNaNCode = dstQNaN - 1; // Break a into a sign and representation of the absolute value - union SrcExchangeType { SRC_T f; SRC_REP_T i; }; + union SrcExchangeType { + SRC_T f; + SRC_REP_T i; + }; SrcExchangeType src_rep; src_rep.f = a; const SRC_REP_T aRep = src_rep.i; @@ -88,25 +108,21 @@ static inline DST_T __truncXfYf2__(SRC_T a) { const SRC_REP_T roundBits = aAbs & roundMask; // Round to nearest - if (roundBits > halfway) - absResult++; - // Ties to even + if (roundBits > halfway) absResult++; + // Ties to even else if (roundBits == halfway) absResult += absResult & 1; - } - else if (aAbs > srcInfinity) { + } else if (aAbs > srcInfinity) { // a is NaN. // Conjure the result by beginning with infinity, setting the qNaN // bit and inserting the (truncated) trailing NaN field. absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS; absResult |= dstQNaN; absResult |= ((aAbs & srcNaNCode) >> (SRC_SIG_BITS - DST_SIG_BITS)) & dstNaNCode; - } - else if (aAbs >= overflow) { + } else if (aAbs >= overflow) { // a overflows to infinity. absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS; - } - else { + } else { // a underflows on conversion to the destination type or is an exact // zero. The result may be a denormal or zero. Extract the exponent // to get the shift amount for the denormalization. @@ -124,9 +140,8 @@ static inline DST_T __truncXfYf2__(SRC_T a) { absResult = denormalizedSignificand >> (SRC_SIG_BITS - DST_SIG_BITS); const SRC_REP_T roundBits = denormalizedSignificand & roundMask; // Round to nearest - if (roundBits > halfway) - absResult++; - // Ties to even + if (roundBits > halfway) absResult++; + // Ties to even else if (roundBits == halfway) absResult += absResult & 1; } @@ -134,14 +149,17 @@ static inline DST_T __truncXfYf2__(SRC_T a) { // Apply the signbit to (DST_T)abs(a). const DST_REP_T result = absResult | sign >> (srcBits - dstBits); - union DstExchangeType { DST_T f; DST_REP_T i; }; + union DstExchangeType { + DST_T f; + DST_REP_T i; + }; DstExchangeType dst_rep; dst_rep.i = result; return dst_rep.f; } -template +template static inline DST_T __extendXfYf2__(SRC_T a) { // Various constants whose values follow from the type parameters. // Any reasonable optimizer will fold and propagate all of these. @@ -157,7 +175,7 @@ static inline DST_T __extendXfYf2__(SRC_T a) { const SRC_REP_T srcQNaN = SRC_REP_T(1) << (SRC_SIG_BITS - 1); const SRC_REP_T srcNaNCode = srcQNaN - 1; - const int dstBits = sizeof(DST_T)*8; + const int dstBits = sizeof(DST_T) * 8; const int dstExpBits = dstBits - DST_SIG_BITS - 1; const int dstInfExp = (1 << dstExpBits) - 1; const int dstExpBias = dstInfExp >> 1; @@ -165,7 +183,10 @@ static inline DST_T __extendXfYf2__(SRC_T a) { const DST_REP_T dstMinNormal = DST_REP_T(1) << DST_SIG_BITS; // Break a into a sign and representation of the absolute value - union SrcExchangeType { SRC_T f; SRC_REP_T i; }; + union SrcExchangeType { + SRC_T f; + SRC_REP_T i; + }; SrcExchangeType src_rep; src_rep.f = a; const SRC_REP_T aRep = src_rep.i; @@ -191,8 +212,7 @@ static inline DST_T __extendXfYf2__(SRC_T a) { absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS; absResult |= (DST_REP_T)(aAbs & srcQNaN) << (DST_SIG_BITS - SRC_SIG_BITS); absResult |= (DST_REP_T)(aAbs & srcNaNCode) << (DST_SIG_BITS - SRC_SIG_BITS); - } - else if (aAbs) { + } else if (aAbs) { // a is denormal. // renormalize the significand and clear the leading bit, then insert // the correct adjusted exponent in the destination type. @@ -201,15 +221,17 @@ static inline DST_T __extendXfYf2__(SRC_T a) { absResult ^= dstMinNormal; const int resultExponent = dstExpBias - srcExpBias - scale + 1; absResult |= (DST_REP_T)resultExponent << DST_SIG_BITS; - } - else { + } else { // a is zero. absResult = 0; } // Apply the signbit to (DST_T)abs(a). const DST_REP_T result = absResult | (DST_REP_T)sign << (dstBits - srcBits); - union DstExchangeType { DST_T f; DST_REP_T i; }; + union DstExchangeType { + DST_T f; + DST_REP_T i; + }; DstExchangeType dst_rep; dst_rep.i = result; return dst_rep.f; diff --git a/apps/android_camera/app/src/main/jni/tvm_runtime.h b/apps/android_camera/app/src/main/jni/tvm_runtime.h index a58252e..bc10bda 100644 --- a/apps/android_camera/app/src/main/jni/tvm_runtime.h +++ b/apps/android_camera/app/src/main/jni/tvm_runtime.h @@ -22,6 +22,7 @@ * \brief Pack all tvm runtime source files */ #include + #include /* Enable custom logging - this will cause TVM to pass every log message @@ -38,23 +39,23 @@ #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" +#include "../src/runtime/dso_library.cc" +#include "../src/runtime/file_util.cc" +#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" +#include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_module.cc" +#include "../src/runtime/rpc/rpc_server_env.cc" +#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_socket_impl.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/threading_backend.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" +#include "../src/runtime/workspace_pool.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" @@ -69,7 +70,6 @@ #include "../src/runtime/contrib/sort/sort.cc" #endif - #include void dmlc::CustomLogMessage::Log(const std::string& msg) { diff --git a/apps/android_deploy/app/src/main/jni/tvm_runtime.h b/apps/android_deploy/app/src/main/jni/tvm_runtime.h index 0d038fb..f1a47a6 100644 --- a/apps/android_deploy/app/src/main/jni/tvm_runtime.h +++ b/apps/android_deploy/app/src/main/jni/tvm_runtime.h @@ -22,23 +22,23 @@ * \brief Pack all tvm runtime source files */ #include + #include #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" +#include "../src/runtime/dso_library.cc" +#include "../src/runtime/file_util.cc" +#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" +#include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" -#include "../src/runtime/object.cc" #include "../src/runtime/threading_backend.cc" -#include "../src/runtime/ndarray.cc" - -#include "../src/runtime/graph/graph_runtime.cc" +#include "../src/runtime/workspace_pool.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 5d2bca2..0b713b8 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -22,6 +22,7 @@ * \brief Pack all tvm runtime source files */ #include + #include /* Enable custom logging - this will cause TVM to pass every log message @@ -38,23 +39,23 @@ #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" +#include "../src/runtime/dso_library.cc" +#include "../src/runtime/file_util.cc" +#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" +#include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_module.cc" +#include "../src/runtime/rpc/rpc_server_env.cc" +#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_socket_impl.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/threading_backend.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" +#include "../src/runtime/workspace_pool.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" @@ -69,7 +70,6 @@ #include "../src/runtime/contrib/sort/sort.cc" #endif - #include void dmlc::CustomLogMessage::Log(const std::string& msg) { diff --git a/apps/bundle_deploy/bundle.cc b/apps/bundle_deploy/bundle.cc index 3e50809..d8ff683 100644 --- a/apps/bundle_deploy/bundle.cc +++ b/apps/bundle_deploy/bundle.cc @@ -17,51 +17,47 @@ * under the License. */ -#include #include #include +#include + #define TVM_BUNDLE_FUNCTION __attribute__((visibility("default"))) extern "C" { -TVM_BUNDLE_FUNCTION void *tvm_runtime_create(const char * build_graph_json, - const char * build_params_bin, +TVM_BUNDLE_FUNCTION void* tvm_runtime_create(const char* build_graph_json, + const char* build_params_bin, const uint64_t build_params_bin_len) { const int build_graph_json_len = strlen(build_graph_json); - const std::string json_data(&build_graph_json[0], - &build_graph_json[0] + build_graph_json_len); - tvm::runtime::Module mod_syslib = - (*tvm::runtime::Registry::Get("runtime.SystemLib"))(); + const std::string json_data(&build_graph_json[0], &build_graph_json[0] + build_graph_json_len); + tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("runtime.SystemLib"))(); int device_type = kDLCPU; int device_id = 0; - tvm::runtime::Module mod = - (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))( - json_data, mod_syslib, device_type, device_id); + tvm::runtime::Module mod = (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))( + json_data, mod_syslib, device_type, device_id); TVMByteArray params; - params.data = reinterpret_cast(&build_params_bin[0]); + params.data = reinterpret_cast(&build_params_bin[0]); params.size = build_params_bin_len; mod.GetFunction("load_params")(params); return new tvm::runtime::Module(mod); } -TVM_BUNDLE_FUNCTION void tvm_runtime_destroy(void *handle) { - delete reinterpret_cast(handle); +TVM_BUNDLE_FUNCTION void tvm_runtime_destroy(void* handle) { + delete reinterpret_cast(handle); } -TVM_BUNDLE_FUNCTION void tvm_runtime_set_input(void *handle, const char *name, - void *tensor) { - reinterpret_cast(handle)->GetFunction("set_input")( - name, reinterpret_cast(tensor)); +TVM_BUNDLE_FUNCTION void tvm_runtime_set_input(void* handle, const char* name, void* tensor) { + reinterpret_cast(handle)->GetFunction("set_input")( + name, reinterpret_cast(tensor)); } -TVM_BUNDLE_FUNCTION void tvm_runtime_run(void *handle) { - reinterpret_cast(handle)->GetFunction("run")(); +TVM_BUNDLE_FUNCTION void tvm_runtime_run(void* handle) { + reinterpret_cast(handle)->GetFunction("run")(); } -TVM_BUNDLE_FUNCTION void tvm_runtime_get_output(void *handle, int index, - void *tensor) { - reinterpret_cast(handle)->GetFunction("get_output")( - index, reinterpret_cast(tensor)); +TVM_BUNDLE_FUNCTION void tvm_runtime_get_output(void* handle, int index, void* tensor) { + reinterpret_cast(handle)->GetFunction("get_output")( + index, reinterpret_cast(tensor)); } } diff --git a/apps/bundle_deploy/bundle.h b/apps/bundle_deploy/bundle.h index aa57faa..80238e1 100644 --- a/apps/bundle_deploy/bundle.h +++ b/apps/bundle_deploy/bundle.h @@ -22,20 +22,15 @@ #include -TVM_DLL void * tvm_runtime_create(const char * json_data, - const char * params_data, - const uint64_t params_size); +TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, + const uint64_t params_size); -TVM_DLL void tvm_runtime_destroy(void * runtime); +TVM_DLL void tvm_runtime_destroy(void* runtime); -TVM_DLL void tvm_runtime_set_input(void * runtime, - const char * name, - DLTensor * tensor); +TVM_DLL void tvm_runtime_set_input(void* runtime, const char* name, DLTensor* tensor); -TVM_DLL void tvm_runtime_run(void * runtime); +TVM_DLL void tvm_runtime_run(void* runtime); -TVM_DLL void tvm_runtime_get_output(void * runtime, - int32_t index, - DLTensor * tensor); +TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, DLTensor* tensor); #endif /* TVM_APPS_BUNDLE_DEPLOY_BUNDLE_H_ */ diff --git a/apps/bundle_deploy/demo.cc b/apps/bundle_deploy/demo.cc index 0de10d7..5c210a2 100644 --- a/apps/bundle_deploy/demo.cc +++ b/apps/bundle_deploy/demo.cc @@ -17,44 +17,44 @@ * under the License. */ +#include +#include //dlopen +#include #include -#include -#include //dlopen #include #include #include -#include #include "build/graph.json.c" #include "build/params.bin.c" -template auto getFunc(void *bundle, const char *name) { +template +auto getFunc(void* bundle, const char* name) { dlerror(); - auto *f = - reinterpret_cast::type>(dlsym(bundle, name)); + auto* f = reinterpret_cast::type>(dlsym(bundle, name)); assert(!dlerror()); return f; } -int main(int argc, char **argv) { +int main(int argc, char** argv) { assert(argc == 3 && "Usage: demo "); - auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); + auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); assert(bundle); - char * json_data = reinterpret_cast(build_graph_json); - char * params_data = reinterpret_cast(build_params_bin); + char* json_data = reinterpret_cast(build_graph_json); + char* params_data = reinterpret_cast(build_params_bin); uint64_t params_size = build_params_bin_len; struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto *handle = getFunc(bundle, "tvm_runtime_create")( + auto* handle = getFunc(bundle, "tvm_runtime_create")( json_data, params_data, params_size); gettimeofday(&t1, 0); float input_storage[1 * 3 * 224 * 224]; - FILE * fp = fopen(argv[2], "rb"); + FILE* fp = fopen(argv[2], "rb"); fread(input_storage, 3 * 224 * 224, 4, fp); fclose(fp); @@ -68,12 +68,10 @@ int main(int argc, char **argv) { input.strides = nullptr; input.byte_offset = 0; - getFunc(bundle, "tvm_runtime_set_input")( - handle, "data", &input); + getFunc(bundle, "tvm_runtime_set_input")(handle, "data", &input); gettimeofday(&t2, 0); - auto *ftvm_runtime_run = - (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run"); + auto* ftvm_runtime_run = (auto (*)(void*)->void)dlsym(bundle, "tvm_runtime_run"); assert(!dlerror()); ftvm_runtime_run(handle); gettimeofday(&t3, 0); @@ -89,8 +87,7 @@ int main(int argc, char **argv) { output.strides = nullptr; output.byte_offset = 0; - getFunc(bundle, "tvm_runtime_get_output")( - handle, 0, &output); + getFunc(bundle, "tvm_runtime_get_output")(handle, 0, &output); gettimeofday(&t4, 0); float max_iter = -std::numeric_limits::max(); @@ -102,19 +99,19 @@ int main(int argc, char **argv) { } } - getFunc(bundle, "tvm_runtime_destroy")(handle); + getFunc(bundle, "tvm_runtime_destroy")(handle); gettimeofday(&t5, 0); - printf("The maximum position in output vector is: %d, with max-value %f.\n", - max_index, max_iter); - printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " - "%.2f ms (get_output), %.2f ms (destroy)\n", - (t1.tv_sec-t0.tv_sec)*1000.0f + (t1.tv_usec-t0.tv_usec)/1000.f, - (t2.tv_sec-t1.tv_sec)*1000.0f + (t2.tv_usec-t1.tv_usec)/1000.f, - (t3.tv_sec-t2.tv_sec)*1000.0f + (t3.tv_usec-t2.tv_usec)/1000.f, - (t4.tv_sec-t3.tv_sec)*1000.0f + (t4.tv_usec-t3.tv_usec)/1000.f, - (t5.tv_sec-t4.tv_sec)*1000.0f + (t5.tv_usec-t4.tv_usec)/1000.f); + printf("The maximum position in output vector is: %d, with max-value %f.\n", max_index, max_iter); + printf( + "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec - t0.tv_sec) * 1000.0f + (t1.tv_usec - t0.tv_usec) / 1000.f, + (t2.tv_sec - t1.tv_sec) * 1000.0f + (t2.tv_usec - t1.tv_usec) / 1000.f, + (t3.tv_sec - t2.tv_sec) * 1000.0f + (t3.tv_usec - t2.tv_usec) / 1000.f, + (t4.tv_sec - t3.tv_sec) * 1000.0f + (t4.tv_usec - t3.tv_usec) / 1000.f, + (t5.tv_sec - t4.tv_sec) * 1000.0f + (t5.tv_usec - t4.tv_usec) / 1000.f); dlclose(bundle); - + return 0; } diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc index 7a116e8..8e294a0 100644 --- a/apps/bundle_deploy/runtime.cc +++ b/apps/bundle_deploy/runtime.cc @@ -19,19 +19,19 @@ #include #include -#include #include +#include #include "../../src/runtime/c_runtime_api.cc" #include "../../src/runtime/cpu_device_api.cc" -#include "../../src/runtime/workspace_pool.cc" +#include "../../src/runtime/file_util.cc" +#include "../../src/runtime/graph/graph_runtime.cc" #include "../../src/runtime/library_module.cc" #include "../../src/runtime/module.cc" -#include "../../src/runtime/registry.cc" -#include "../../src/runtime/file_util.cc" -#include "../../src/runtime/threading_backend.cc" -#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/ndarray.cc" #include "../../src/runtime/object.cc" +#include "../../src/runtime/registry.cc" #include "../../src/runtime/system_library.cc" -#include "../../src/runtime/graph/graph_runtime.cc" +#include "../../src/runtime/thread_pool.cc" +#include "../../src/runtime/threading_backend.cc" +#include "../../src/runtime/workspace_pool.cc" diff --git a/apps/bundle_deploy/test.cc b/apps/bundle_deploy/test.cc index c92400d..882e04b 100644 --- a/apps/bundle_deploy/test.cc +++ b/apps/bundle_deploy/test.cc @@ -17,35 +17,35 @@ * under the License. */ +#include +#include //dlopen +#include +#include #include -#include -#include //dlopen #include #include #include -#include -#include -template auto getFunc(void *bundle, const char *name) { +template +auto getFunc(void* bundle, const char* name) { dlerror(); - auto *f = - reinterpret_cast::type>(dlsym(bundle, name)); + auto* f = reinterpret_cast::type>(dlsym(bundle, name)); assert(!dlerror()); return f; } -int main(int argc, char **argv) { +int main(int argc, char** argv) { assert(argc == 6 && "Usage: test "); - auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); + auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); assert(bundle); struct stat st; - char * json_data; - char * params_data; + char* json_data; + char* params_data; uint64_t params_size; - FILE * fp = fopen(argv[4], "rb"); + FILE* fp = fopen(argv[4], "rb"); stat(argv[4], &st); json_data = (char*)malloc(st.st_size); fread(json_data, st.st_size, 1, fp); @@ -61,7 +61,7 @@ int main(int argc, char **argv) { struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto *handle = getFunc(bundle, "tvm_runtime_create")( + auto* handle = getFunc(bundle, "tvm_runtime_create")( json_data, params_data, params_size); gettimeofday(&t1, 0); @@ -85,12 +85,10 @@ int main(int argc, char **argv) { input.strides = nullptr; input.byte_offset = 0; - getFunc(bundle, "tvm_runtime_set_input")( - handle, "x", &input); + getFunc(bundle, "tvm_runtime_set_input")(handle, "x", &input); gettimeofday(&t2, 0); - auto *ftvm_runtime_run = - (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run"); + auto* ftvm_runtime_run = (auto (*)(void*)->void)dlsym(bundle, "tvm_runtime_run"); assert(!dlerror()); ftvm_runtime_run(handle); gettimeofday(&t3, 0); @@ -106,8 +104,7 @@ int main(int argc, char **argv) { output.strides = nullptr; output.byte_offset = 0; - getFunc(bundle, "tvm_runtime_get_output")( - handle, 0, &output); + getFunc(bundle, "tvm_runtime_get_output")(handle, 0, &output); gettimeofday(&t4, 0); for (auto i = 0; i < 10 * 5; ++i) { @@ -117,20 +114,21 @@ int main(int argc, char **argv) { } } - getFunc(bundle, "tvm_runtime_destroy")(handle); + getFunc(bundle, "tvm_runtime_destroy")(handle); gettimeofday(&t5, 0); - printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " - "%.2f ms (get_output), %.2f ms (destroy)\n", - (t1.tv_sec-t0.tv_sec)*1000.0f + (t1.tv_usec-t0.tv_usec)/1000.f, - (t2.tv_sec-t1.tv_sec)*1000.0f + (t2.tv_usec-t1.tv_usec)/1000.f, - (t3.tv_sec-t2.tv_sec)*1000.0f + (t3.tv_usec-t2.tv_usec)/1000.f, - (t4.tv_sec-t3.tv_sec)*1000.0f + (t4.tv_usec-t3.tv_usec)/1000.f, - (t5.tv_sec-t4.tv_sec)*1000.0f + (t5.tv_usec-t4.tv_usec)/1000.f); + printf( + "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec - t0.tv_sec) * 1000.0f + (t1.tv_usec - t0.tv_usec) / 1000.f, + (t2.tv_sec - t1.tv_sec) * 1000.0f + (t2.tv_usec - t1.tv_usec) / 1000.f, + (t3.tv_sec - t2.tv_sec) * 1000.0f + (t3.tv_usec - t2.tv_usec) / 1000.f, + (t4.tv_sec - t3.tv_sec) * 1000.0f + (t4.tv_usec - t3.tv_usec) / 1000.f, + (t5.tv_sec - t4.tv_sec) * 1000.0f + (t5.tv_usec - t4.tv_usec) / 1000.f); free(json_data); free(params_data); dlclose(bundle); - + return 0; } diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index 5168da3..ae2636d 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -21,20 +21,21 @@ * \file rpc_server.cc * \brief RPC Server for TVM. */ -#include #include #include +#include #if defined(__linux__) || defined(__ANDROID__) #include #endif #include -#include + #include -#include +#include #include +#include -#include "../../src/support/util.h" #include "../../src/support/socket.h" +#include "../../src/support/util.h" #include "rpc_server.h" #if defined(_WIN32) @@ -45,21 +46,21 @@ using namespace std; using namespace tvm::runtime; using namespace tvm::support; -static const string kUsage = \ -"Command line usage\n" \ -" server - Start the server\n" \ -"--host - The hostname of the server, Default=0.0.0.0\n" \ -"--port - The port of the RPC, Default=9090\n" \ -"--port-end - The end search port of the RPC, Default=9199\n" \ -"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \ -"--key - The key used to identify the device type in tracker. Default=\"\"\n" \ -"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \ -"--silent - Whether to run in silent mode. Default=False\n" \ -"\n" \ -" Example\n" \ -" ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 " -" --tracker=127.0.0.1:9190 --key=rasp" \ -"\n"; +static const string kUsage = + "Command line usage\n" + " server - Start the server\n" + "--host - The hostname of the server, Default=0.0.0.0\n" + "--port - The port of the RPC, Default=9090\n" + "--port-end - The end search port of the RPC, Default=9199\n" + "--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" + "--key - The key used to identify the device type in tracker. Default=\"\"\n" + "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" + "--silent - Whether to run in silent mode. Default=False\n" + "\n" + " Example\n" + " ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 " + " --tracker=127.0.0.1:9190 --key=rasp" + "\n"; /*! * \brief RpcServerArgs. @@ -95,7 +96,7 @@ void PrintArgs(const RpcServerArgs& args) { LOG(INFO) << "tracker = " << args.tracker; LOG(INFO) << "key = " << args.key; LOG(INFO) << "custom_addr = " << args.custom_addr; - LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); + LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False")); } #if defined(__linux__) || defined(__ANDROID__) @@ -151,7 +152,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { * \param tracker The tracker input. * \return result of operation. */ -bool ValidateTracker(string &tracker) { +bool ValidateTracker(string& tracker) { vector list = Split(tracker, ':'); if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) { return false; @@ -168,7 +169,7 @@ bool ValidateTracker(string &tracker) { * \param argv arg values * \param args the output structure which holds the parsed values */ -void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { +void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) { const string silent = GetCmdOption(argc, argv, "--silent", true); if (!silent.empty()) { args.silent = true; @@ -232,12 +233,11 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { } #if defined(WIN32) const string mmap_path = GetCmdOption(argc, argv, "--child_proc="); - if(!mmap_path.empty()) { + if (!mmap_path.empty()) { args.mmap_path = mmap_path; dmlc::InitLogging("--minloglevel=0"); } #endif - } /*! @@ -246,7 +246,7 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { * \param argv arg values * \return result of operation. */ -int RpcServer(int argc, char * argv[]) { +int RpcServer(int argc, char* argv[]) { RpcServerArgs args; /* parse the command line args */ @@ -260,21 +260,21 @@ int RpcServer(int argc, char * argv[]) { #endif #if defined(WIN32) - if(!args.mmap_path.empty()) { + if (!args.mmap_path.empty()) { int ret = 0; try { - ChildProcSocketHandler(args.mmap_path); + ChildProcSocketHandler(args.mmap_path); } catch (const std::exception&) { - ret = -1; + ret = -1; } return ret; } #endif - RPCServerCreate(args.host, args.port, args.port_end, args.tracker, - args.key, args.custom_addr, args.silent); + RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr, + args.silent); return 0; } @@ -284,7 +284,7 @@ int RpcServer(int argc, char * argv[]) { * \param argv arg values * \return result of operation. */ -int main(int argc, char * argv[]) { +int main(int argc, char* argv[]) { if (argc <= 1) { LOG(INFO) << kUsage; return 0; diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 4a363cb..a690fd8 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -20,8 +20,9 @@ * \file rpc_env.cc * \brief Server environment of the RPC. */ -#include #include + +#include #ifndef _WIN32 #include #include @@ -30,46 +31,45 @@ #include #include namespace { - int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } -} +int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } +} // namespace #endif #include #include #include #include #include -#include -#include "../../src/support/util.h" #include "../../src/runtime/file_util.h" +#include "../../src/support/util.h" #include "rpc_env.h" namespace { - std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { - std::string untar_cmd; - untar_cmd.reserve(512); +std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { + std::string untar_cmd; + untar_cmd.reserve(512); #if defined(__linux__) || defined(__ANDROID__) - untar_cmd += "tar -C "; - untar_cmd += output_dir; - untar_cmd += " -zxf "; - untar_cmd += tar_file; + untar_cmd += "tar -C "; + untar_cmd += output_dir; + untar_cmd += " -zxf "; + untar_cmd += tar_file; #elif defined(_WIN32) - untar_cmd += "python -m tarfile -e "; - untar_cmd += tar_file; - untar_cmd += " "; - untar_cmd += output_dir; + untar_cmd += "python -m tarfile -e "; + untar_cmd += tar_file; + untar_cmd += " "; + untar_cmd += output_dir; #endif - return untar_cmd; - } + return untar_cmd; +} -}// Anonymous namespace +} // Anonymous namespace namespace tvm { namespace runtime { RPCEnv::RPCEnv() { #ifndef _WIN32 char cwd[PATH_MAX]; - if (char *rc = getcwd(cwd, sizeof(cwd))) { + if (char* rc = getcwd(cwd, sizeof(cwd))) { base_ = std::string(cwd) + "/rpc"; } else { base_ = "./rpc"; @@ -172,22 +172,20 @@ std::vector ListDir(const std::string& dirname) { * \param options The compiler options * \param cc The compiler */ -void LinuxShared(const std::string output, - const std::vector &files, - std::string options = "", - std::string cc = "g++") { - std::string cmd = cc; - cmd += " -shared -fPIC "; - cmd += " -o " + output; - for (auto f = files.begin(); f != files.end(); ++f) { - cmd += " " + *f; - } - cmd += " " + options; - std::string err_msg; - auto executed_status = support::Execute(cmd, &err_msg); - if (executed_status) { - LOG(FATAL) << err_msg; - } +void LinuxShared(const std::string output, const std::vector& files, + std::string options = "", std::string cc = "g++") { + std::string cmd = cc; + cmd += " -shared -fPIC "; + cmd += " -o " + output; + for (auto f = files.begin(); f != files.end(); ++f) { + cmd += " " + *f; + } + cmd += " " + options; + std::string err_msg; + auto executed_status = support::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; + } } #endif @@ -199,10 +197,8 @@ void LinuxShared(const std::string output, * \param options The compiler options * \param cc The compiler */ -void WindowsShared(const std::string& output, - const std::vector& files, - const std::string& options = "", - const std::string& cc = "clang") { +void WindowsShared(const std::string& output, const std::vector& files, + const std::string& options = "", const std::string& cc = "clang") { std::string cmd = cc; cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared "; cmd += " -o " + output; @@ -243,7 +239,7 @@ void CreateShared(const std::string& output, const std::vector& fil * \param fmt The format of file * \return Module The loaded module */ -Module Load(std::string *fileIn, const std::string& fmt) { +Module Load(std::string* fileIn, const std::string& fmt) { const std::string& file = *fileIn; if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll")) { return Module::LoadFromFile(file, fmt); diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index d046f6e..464b10a 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -25,6 +25,7 @@ #define TVM_APPS_CPP_RPC_ENV_H_ #include + #include namespace tvm { @@ -40,13 +41,13 @@ namespace runtime { * \param file The format of file * \return Module The loaded module */ -Module Load(std::string *path, const std::string& fmt = ""); +Module Load(std::string* path, const std::string& fmt = ""); /*! * \brief CleanDir Removes the files from the directory * \param dirname THe name of the directory */ -void CleanDir(const std::string &dirname); +void CleanDir(const std::string& dirname); /*! * \brief RPCEnv The RPC Environment parameters for c++ rpc server diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 2c8bdfa..2628ff7 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -32,9 +32,9 @@ #include #include -#include "../../src/support/socket.h" #include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" +#include "../../src/support/socket.h" #include "rpc_env.h" #include "rpc_server.h" #include "rpc_tracker_client.h" @@ -78,7 +78,7 @@ static std::string getNextString(std::stringstream* iss) { while (end < len && !isspace(str[end])) end++; iss->seekg(end); - return str.substr(start, end-start); + return str.substr(start, end - start); } #endif @@ -96,14 +96,15 @@ class RPCServer { /*! * \brief Constructor. */ - RPCServer(std::string host, int port, int port_end, std::string tracker_addr, - std::string key, std::string custom_addr) : - host_(std::move(host)), port_(port), my_port_(0), port_end_(port_end), - tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), - custom_addr_(std::move(custom_addr)) - { - - } + RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key, + std::string custom_addr) + : host_(std::move(host)), + port_(port), + my_port_(0), + port_end_(port_end), + tracker_addr_(std::move(tracker_addr)), + key_(std::move(key)), + custom_addr_(std::move(custom_addr)) {} /*! * \brief Destructor. @@ -113,8 +114,7 @@ class RPCServer { // Free the resources tracker_sock_.Close(); listen_sock_.Close(); - } catch(...) { - + } catch (...) { } } @@ -213,7 +213,6 @@ class RPCServer { try { SpawnRPCChild(conn.sockfd, seconds(timeout)); } catch (const std::exception&) { - } auto dur = high_resolution_clock::now() - start_time; @@ -233,11 +232,8 @@ class RPCServer { * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient* tracker, - support::TCPSocket* conn_sock, - support::SockAddr* addr, - std::string* opts, - int ping_period = 2) { + void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock, + support::SockAddr* addr, std::string* opts, int ping_period = 2) { std::set old_keyset; std::string matchkey; @@ -249,7 +245,7 @@ class RPCServer { support::TCPSocket conn = listen_sock_.Accept(addr); int code = kRPCMagic; - CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); if (code != kRPCMagic) { conn.Close(); LOG(FATAL) << "Client connected is not TVM RPC server"; @@ -348,9 +344,9 @@ class RPCServer { #if defined(WIN32) /*! -* \brief ServerLoopFromChild The Server loop process. -* \param socket The socket information -*/ + * \brief ServerLoopFromChild The Server loop process. + * \param socket The socket information + */ void ServerLoopFromChild(SOCKET socket) { // Server loop tvm::support::TCPSocket sock(socket); @@ -367,10 +363,10 @@ void ServerLoopFromChild(SOCKET socket) { * \param host The hostname of the server, Default=0.0.0.0 * \param port The port of the RPC, Default=9090 * \param port_end The end search port of the RPC, Default=9199 - * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" - * \param key The key used to identify the device type in tracker. Default="" - * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" - * \param silent Whether run in silent mode. Default=True + * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 + * Default="" \param key The key used to identify the device type in tracker. Default="" \param + * custom_addr Custom IP Address to Report to RPC Tracker. Default="" \param silent Whether run in + * silent mode. Default=True */ void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr, std::string key, std::string custom_addr, bool silent) { @@ -379,13 +375,13 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track dmlc::InitLogging("--minloglevel=2"); } // Start the rpc server - RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr)); + RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), + std::move(custom_addr)); rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc.ServerCreate") -.set_body([](TVMArgs args, TVMRetValue* rv) { - RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); - }); +TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) { + RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); +}); } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index db7c89d..0936c51 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -25,6 +25,7 @@ #define TVM_APPS_CPP_RPC_SERVER_H_ #include + #include "tvm/runtime/c_runtime_api.h" namespace tvm { @@ -49,13 +50,9 @@ void ServerLoopFromChild(SOCKET socket); * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True */ -void RPCServerCreate(std::string host = "", - int port = 9090, - int port_end = 9099, - std::string tracker_addr = "", - std::string key = "", - std::string custom_addr = "", - bool silent = true); +void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099, + std::string tracker_addr = "", std::string key = "", + std::string custom_addr = "", bool silent = true); } // namespace runtime } // namespace tvm #endif // TVM_APPS_CPP_RPC_SERVER_H_ diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index 112f7d2..cdfb647 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -24,12 +24,12 @@ #ifndef TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ #define TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ -#include -#include #include +#include #include -#include +#include #include +#include #include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/support/socket.h" @@ -47,29 +47,28 @@ class TrackerClient { public: /*! * \brief Constructor. - */ - TrackerClient(const std::string& tracker_addr, - const std::string& key, + */ + TrackerClient(const std::string& tracker_addr, const std::string& key, const std::string& custom_addr) - : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr), - gen_(std::random_device{}()), dis_(0.0, 1.0) { - } + : tracker_addr_(tracker_addr), + key_(key), + custom_addr_(custom_addr), + gen_(std::random_device{}()), + dis_(0.0, 1.0) {} /*! * \brief Destructor. - */ + */ ~TrackerClient() { // Free the resources Close(); } /*! * \brief IsValid Check tracker is valid. - */ - bool IsValid() { - return (!tracker_addr_.empty() && !tracker_sock_.IsClosed()); - } + */ + bool IsValid() { return (!tracker_addr_.empty() && !tracker_sock_.IsClosed()); } /*! * \brief TryConnect Connect to tracker if the tracker address is valid. - */ + */ void TryConnect() { if (!tracker_addr_.empty() && (tracker_sock_.IsClosed())) { tracker_sock_ = ConnectWithRetry(); @@ -80,8 +79,8 @@ class TrackerClient { CHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker"; std::ostringstream ss; - ss << "[" << static_cast(TrackerCode::kUpdateInfo) - << ", {\"key\": \"server:"<< key_ << "\"}]"; + ss << "[" << static_cast(TrackerCode::kUpdateInfo) << ", {\"key\": \"server:" << key_ + << "\"}]"; tracker_sock_.SendBytes(ss.str()); // Receive status and validate @@ -91,20 +90,19 @@ class TrackerClient { } /*! * \brief Close Clean up tracker resources. - */ + */ void Close() { // close tracker resource if (!tracker_sock_.IsClosed()) { tracker_sock_.Close(); } } - /*! - * \brief ReportResourceAndGetKey Report resource to tracker. - * \param port listening port. - * \param matchkey Random match key output. - */ - void ReportResourceAndGetKey(int port, - std::string *matchkey) { + /*! + * \brief ReportResourceAndGetKey Report resource to tracker. + * \param port listening port. + * \param matchkey Random match key output. + */ + void ReportResourceAndGetKey(int port, std::string* matchkey) { if (!tracker_sock_.IsClosed()) { *matchkey = RandomKey(key_ + ":", old_keyset_); if (custom_addr_.empty()) { @@ -112,8 +110,8 @@ class TrackerClient { } std::ostringstream ss; - ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" - << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port + << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; tracker_sock_.SendBytes(ss.str()); @@ -121,7 +119,7 @@ class TrackerClient { std::string remote_status = tracker_sock_.RecvBytes(); CHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); } else { - *matchkey = key_; + *matchkey = key_; } } @@ -131,11 +129,9 @@ class TrackerClient { * \param port listening port. * \param ping_period Select wait time. * \param matchkey Random match key output. - */ - void WaitConnectionAndUpdateKey(support::TCPSocket listen_sock, - int port, - int ping_period, - std::string *matchkey) { + */ + void WaitConnectionAndUpdateKey(support::TCPSocket listen_sock, int port, int ping_period, + std::string* matchkey) { int unmatch_period_count = 0; int unmatch_timeout = 4; while (true) { @@ -155,9 +151,9 @@ class TrackerClient { // if match key not in pending key set // it means the key is acquired by a client but not used. if (pending_keys.find(*matchkey) == std::string::npos) { - unmatch_period_count += 1; + unmatch_period_count += 1; } else { - unmatch_period_count = 0; + unmatch_period_count = 0; } // regenerate match key if key is acquired but not used for a while if (unmatch_period_count * ping_period > unmatch_timeout + ping_period) { @@ -166,8 +162,8 @@ class TrackerClient { *matchkey = RandomKey(key_ + ":", old_keyset_); std::ostringstream ss; - ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" - << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port + << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; tracker_sock_.SendBytes(ss.str()); std::string remote_status = tracker_sock_.RecvBytes(); @@ -201,26 +197,25 @@ class TrackerClient { } auto period = (std::chrono::duration_cast( - std::chrono::system_clock::now() - tbegin)).count(); + std::chrono::system_clock::now() - tbegin)) + .count(); CHECK(period < timeout) << "Failed to connect to server" << addr.AsString(); - LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() - << " retry in " << retry_period << " seconds."; + LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() << " retry in " + << retry_period << " seconds."; std::this_thread::sleep_for(std::chrono::seconds(retry_period)); } } /*! - * \brief Random Generate a random number between 0 and 1. - * \return random float value. - */ - float Random() { - return dis_(gen_); - } + * \brief Random Generate a random number between 0 and 1. + * \return random float value. + */ + float Random() { return dis_(gen_); } /*! * \brief Generate a random key. * \param prefix The string prefix. * \return cmap The conflict map set. */ - std::string RandomKey(const std::string& prefix, const std::set &cmap) { + std::string RandomKey(const std::string& prefix, const std::set& cmap) { if (!cmap.empty()) { while (true) { std::string key = prefix + std::to_string(Random()); @@ -236,10 +231,9 @@ class TrackerClient { std::string key_; std::string custom_addr_; support::TCPSocket tracker_sock_; - std::set old_keyset_; + std::set old_keyset_; std::mt19937 gen_; std::uniform_real_distribution dis_; - }; } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/win32_process.cc b/apps/cpp_rpc/win32_process.cc index c6c72d7..bbf8367 100644 --- a/apps/cpp_rpc/win32_process.cc +++ b/apps/cpp_rpc/win32_process.cc @@ -20,15 +20,18 @@ #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif +#include "win32_process.h" + +#include +#include #include #include + #include #include -#include -#include #include -#include -#include "win32_process.h" +#include + #include "rpc_server.h" using namespace std::chrono; @@ -82,36 +85,36 @@ UniqueHandle MakeUniqueHandle(HANDLE handle) { */ SOCKET GetSocket(const std::string& mmap_path) { WSAPROTOCOL_INFO protocol_info; - + const std::string parent_event_name = mmap_path + kParent; const std::string child_event_name = mmap_path + kChild; // Open the events UniqueHandle parent_file_mapping_event; - if ((parent_file_mapping_event = MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { + if ((parent_file_mapping_event = MakeUniqueHandle( + OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); } UniqueHandle child_file_mapping_event; - if ((child_file_mapping_event = MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { + if ((child_file_mapping_event = MakeUniqueHandle( + OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); } - + // Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read - if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { - LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); + if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != + WAIT_OBJECT_0) { + LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); } - const UniqueHandle file_map = MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, - false, - mmap_path.c_str())); + const UniqueHandle file_map = + MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, false, mmap_path.c_str())); if (!file_map) { - LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); + LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); } - void* map_view = MapViewOfFile(file_map.get(), - FILE_MAP_READ | FILE_MAP_WRITE, - 0, 0, 0); + void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0); SOCKET sock_duplicated = INVALID_SOCKET; @@ -120,12 +123,8 @@ SOCKET GetSocket(const std::string& mmap_path) { UnmapViewOfFile(map_view); // Creates the duplicate socket, that was created in the parent - sock_duplicated = WSASocket(FROM_PROTOCOL_INFO, - FROM_PROTOCOL_INFO, - FROM_PROTOCOL_INFO, - &protocol_info, - 0, - 0); + sock_duplicated = + WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, &protocol_info, 0, 0); // Let the parent know we are finished dupicating the socket SetEvent(child_file_mapping_event.get()); @@ -135,7 +134,7 @@ SOCKET GetSocket(const std::string& mmap_path) { return sock_duplicated; } -}// Anonymous namespace +} // Anonymous namespace namespace tvm { namespace runtime { @@ -146,7 +145,7 @@ namespace runtime { */ void SpawnRPCChild(SOCKET fd, seconds timeout) { STARTUPINFOA startup_info; - + memset(&startup_info, 0, sizeof(startup_info)); startup_info.cb = sizeof(startup_info); @@ -157,13 +156,15 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { // Create an event to let the child know the socket info was set to the mmap file UniqueHandle parent_file_mapping_event; - if ((parent_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { + if ((parent_file_mapping_event = MakeUniqueHandle( + CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { LOG(FATAL) << "CreateEvent for parent file mapping failed"; } UniqueHandle child_file_mapping_event; // An event to let the parent know the socket info was read from the mmap file - if ((child_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { + if ((child_file_mapping_event = MakeUniqueHandle( + CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { LOG(FATAL) << "CreateEvent for child file mapping failed"; } @@ -181,35 +182,22 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { strcpy(command_line_ptr.get(), child_command_line.c_str()); PROCESS_INFORMATION child_process_info; - if (CreateProcessA(nullptr, - command_line_ptr.get(), - nullptr, - nullptr, - false, - CREATE_NO_WINDOW, - nullptr, - nullptr, - &startup_info, - &child_process_info)) { + if (CreateProcessA(nullptr, command_line_ptr.get(), nullptr, nullptr, false, CREATE_NO_WINDOW, + nullptr, nullptr, &startup_info, &child_process_info)) { // Child process and thread handles must be closed, so wrapped in RAII auto child_process_handle = MakeUniqueHandle(child_process_info.hProcess); auto child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread); WSAPROTOCOL_INFO protocol_info; // Get info needed to duplicate the socket - if (WSADuplicateSocket(fd, - child_process_info.dwProcessId, - &protocol_info) == SOCKET_ERROR) { + if (WSADuplicateSocket(fd, child_process_info.dwProcessId, &protocol_info) == SOCKET_ERROR) { LOG(FATAL) << "WSADuplicateSocket(): failed. Error =" << WSAGetLastError(); } // Create a mmap file to store the info needed for duplicating the SOCKET in the child proc - UniqueHandle file_map = MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, - nullptr, - PAGE_READWRITE, - 0, - sizeof(WSAPROTOCOL_INFO), - file_map_path.c_str())); + UniqueHandle file_map = + MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, 0, + sizeof(WSAPROTOCOL_INFO), file_map_path.c_str())); if (!file_map) { LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); } @@ -225,11 +213,13 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { // Let child proc know the mmap file is ready to be read SetEvent(parent_file_mapping_event.get()); - + // Wait for the child to finish reading mmap file - if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { + if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != + WAIT_OBJECT_0) { TerminateProcess(child_process_handle.get(), 0); - LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child process."; + LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child " + "process."; } } else { TerminateProcess(child_process_handle.get(), 0); @@ -237,9 +227,8 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { } } - const DWORD process_timeout = timeout.count() - ? uint32_t(duration_cast(timeout).count()) - : INFINITE; + const DWORD process_timeout = + timeout.count() ? uint32_t(duration_cast(timeout).count()) : INFINITE; // Wait for child process to exit, or hit configured timeout if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) { @@ -251,8 +240,9 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { } } /*! - * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket - * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client + * socket \param mmap_path The memory mapped file path that will contain the information to + * duplicate the client socket from the parent */ void ChildProcSocketHandler(const std::string& mmap_path) { SOCKET socket; @@ -260,14 +250,12 @@ void ChildProcSocketHandler(const std::string& mmap_path) { // Set high thread priority to avoid the thread scheduler from // interfering with any measurements in the RPC server. SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); - + if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) { tvm::runtime::ServerLoopFromChild(socket); - } - else { + } else { LOG(FATAL) << "GetSocket() failed"; } - } } // namespace runtime } // namespace tvm \ No newline at end of file diff --git a/apps/cpp_rpc/win32_process.h b/apps/cpp_rpc/win32_process.h index 7d1a276..621444e 100644 --- a/apps/cpp_rpc/win32_process.h +++ b/apps/cpp_rpc/win32_process.h @@ -17,10 +17,10 @@ * under the License. */ - /*! - * \file win32_process.h - * \brief Win32 process code to mimic a POSIX fork() - */ +/*! + * \file win32_process.h + * \brief Win32 process code to mimic a POSIX fork() + */ #ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ #define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ #include @@ -34,8 +34,9 @@ namespace runtime { */ void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout); /*! - * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket - * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client + * socket \param mmap_path The memory mapped file path that will contain the information to + * duplicate the client socket from the parent */ void ChildProcSocketHandler(const std::string& mmap_path); } // namespace runtime diff --git a/apps/dso_plugin_module/plugin_module.cc b/apps/dso_plugin_module/plugin_module.cc index 7c3c5ac..eed11f8 100644 --- a/apps/dso_plugin_module/plugin_module.cc +++ b/apps/dso_plugin_module/plugin_module.cc @@ -20,10 +20,10 @@ * \brief Example code that can be compiled and loaded by TVM runtime. * \file plugin_module.cc */ -#include #include -#include #include +#include +#include namespace tvm_dso_plugin { @@ -31,24 +31,16 @@ using namespace tvm::runtime; class MyModuleNode : public ModuleNode { public: - explicit MyModuleNode(int value) - : value_(value) {} + explicit MyModuleNode(int value) : value_(value) {} - virtual const char* type_key() const final { - return "MyModule"; - } + virtual const char* type_key() const final { return "MyModule"; } - virtual PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) final { if (name == "add") { - return TypedPackedFunc([sptr_to_self, this](int value) { - return value_ + value; - }); + return TypedPackedFunc([sptr_to_self, this](int value) { return value_ + value; }); } else if (name == "mul") { - return TypedPackedFunc([sptr_to_self, this](int value) { - return value_ * value; - }); + return TypedPackedFunc([sptr_to_self, this](int value) { return value_ * value; }); } else { LOG(FATAL) << "unknown function " << name; return PackedFunc(); @@ -64,18 +56,14 @@ void CreateMyModule_(TVMArgs args, TVMRetValue* rv) { *rv = Module(make_object(value)); } -int SubOne_(int x) { - return x - 1; -} +int SubOne_(int x) { return x - 1; } // USE TVM_DLL_EXPORT_TYPED_PACKED_FUNC to export a // typed function as packed function. TVM_DLL_EXPORT_TYPED_FUNC(SubOne, SubOne_); // TVM_DLL_EXPORT_TYPED_PACKED_FUNC also works for lambda. -TVM_DLL_EXPORT_TYPED_FUNC(AddOne, [](int x) -> int { - return x + 1; -}); +TVM_DLL_EXPORT_TYPED_FUNC(AddOne, [](int x) -> int { return x + 1; }); // Use TVM_EXPORT_PACKED_FUNC to export a function with TVM_DLL_EXPORT_PACKED_FUNC(CreateMyModule, tvm_dso_plugin::CreateMyModule_); diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index a92d55f..87cb69b 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -17,16 +17,15 @@ * under the License. */ - /*! * \brief Example package that uses TVM. * \file tvm_ext.cc */ -#include +#include #include -#include #include -#include +#include +#include #include using namespace tvm; @@ -50,8 +49,7 @@ class NDSubClass : public tvm::runtime::NDArray { public: class SubContainer : public NDArray::Container { public: - SubContainer(int additional_info) : - additional_info_(additional_info) { + SubContainer(int additional_info) : additional_info_(additional_info) { type_index_ = SubContainer::RuntimeTypeIndex(); } int additional_info_{0}; @@ -74,14 +72,14 @@ class NDSubClass : public tvm::runtime::NDArray { data_ = GetObjectPtr(ptr); } - NDSubClass AddWith(const NDSubClass &other) const { - SubContainer *a = static_cast(get_mutable()); - SubContainer *b = static_cast(other.get_mutable()); + NDSubClass AddWith(const NDSubClass& other) const { + SubContainer* a = static_cast(get_mutable()); + SubContainer* b = static_cast(other.get_mutable()); CHECK(a != nullptr && b != nullptr); return NDSubClass(a->additional_info_ + b->additional_info_); } int get_additional_info() const { - SubContainer *self = static_cast(get_mutable()); + SubContainer* self = static_cast(get_mutable()); CHECK(self != nullptr); return self->additional_info_; } @@ -116,60 +114,48 @@ TVM_REGISTER_OBJECT_TYPE(IntVectorObj); namespace tvm_ext { -TVM_REGISTER_GLOBAL("tvm_ext.ivec_create") -.set_body([](TVMArgs args, TVMRetValue *rv) { - auto n = tvm::runtime::make_object(); - for (int i = 0; i < args.size(); ++i) { - n->vec.push_back(args[i].operator int()); - } - *rv = IntVector(n); - }); - -TVM_REGISTER_GLOBAL("tvm_ext.ivec_get") -.set_body([](TVMArgs args, TVMRetValue *rv) { - IntVector p = args[0]; - *rv = p->vec[args[1].operator int()]; - }); - - -TVM_REGISTER_GLOBAL("tvm_ext.bind_add") -.set_body([](TVMArgs args_, TVMRetValue *rv_) { - PackedFunc pf = args_[0]; - int b = args_[1]; - *rv_ = PackedFunc([pf, b](TVMArgs args, TVMRetValue *rv) { - *rv = pf(b, args[0]); - }); - }); - -TVM_REGISTER_GLOBAL("tvm_ext.sym_add") -.set_body([](TVMArgs args, TVMRetValue *rv) { - Var a = args[0]; - Var b = args[1]; - *rv = a + b; - }); - -TVM_REGISTER_GLOBAL("device_api.ext_dev") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); - }); - -TVM_REGISTER_GLOBAL("tvm_ext.nd_create") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("tvm_ext.ivec_create").set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = tvm::runtime::make_object(); + for (int i = 0; i < args.size(); ++i) { + n->vec.push_back(args[i].operator int()); + } + *rv = IntVector(n); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.ivec_get").set_body([](TVMArgs args, TVMRetValue* rv) { + IntVector p = args[0]; + *rv = p->vec[args[1].operator int()]; +}); + +TVM_REGISTER_GLOBAL("tvm_ext.bind_add").set_body([](TVMArgs args_, TVMRetValue* rv_) { + PackedFunc pf = args_[0]; + int b = args_[1]; + *rv_ = PackedFunc([pf, b](TVMArgs args, TVMRetValue* rv) { *rv = pf(b, args[0]); }); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.sym_add").set_body([](TVMArgs args, TVMRetValue* rv) { + Var a = args[0]; + Var b = args[1]; + *rv = a + b; +}); + +TVM_REGISTER_GLOBAL("device_api.ext_dev").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.nd_create").set_body([](TVMArgs args, TVMRetValue* rv) { int additional_info = args[0]; *rv = NDSubClass(additional_info); CHECK_EQ(rv->type_code(), kTVMNDArrayHandle); - }); -TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two").set_body([](TVMArgs args, TVMRetValue* rv) { NDSubClass a = args[0]; NDSubClass b = args[1]; *rv = a.AddWith(b); }); -TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info").set_body([](TVMArgs args, TVMRetValue* rv) { NDSubClass a = args[0]; *rv = a.get_additional_info(); }); @@ -177,17 +163,14 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info") } // namespace tvm_ext // External function exposed to runtime. -extern "C" float TVMTestAddOne(float y) { - return y + 1; -} +extern "C" float TVMTestAddOne(float y) { return y + 1; } // This callback approach allows extension allows tvm to extract // This way can be helpful when we want to use a header only // minimum version of TVM Runtime. extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) { - const PackedFunc& fregister = - *static_cast(pregister); - auto mul = [](TVMArgs args, TVMRetValue *rv) { + const PackedFunc& fregister = *static_cast(pregister); + auto mul = [](TVMArgs args, TVMRetValue* rv) { int x = args[0]; int y = args[1]; *rv = x * y; diff --git a/apps/howto_deploy/cpp_deploy.cc b/apps/howto_deploy/cpp_deploy.cc index a386dff..b7a60f4 100644 --- a/apps/howto_deploy/cpp_deploy.cc +++ b/apps/howto_deploy/cpp_deploy.cc @@ -21,11 +21,12 @@ * \brief Example code on load and run TVM module.s * \file cpp_deploy.cc */ -#include #include #include -#include #include +#include + +#include void Verify(tvm::runtime::Module mod, std::string fname) { // Get the function from the module. @@ -52,10 +53,8 @@ void Verify(tvm::runtime::Module mod, std::string fname) { int device_type = kDLCPU; int device_id = 0; int64_t shape[1] = {10}; - TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, - device_type, device_id, &x); - TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, - device_type, device_id, &y); + TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &x); + TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &y); for (int i = 0; i < shape[0]; ++i) { static_cast(x->data)[i] = i; } @@ -72,8 +71,7 @@ void Verify(tvm::runtime::Module mod, std::string fname) { int main(void) { // Normally we can directly - tvm::runtime::Module mod_dylib = - tvm::runtime::Module::LoadFromFile("lib/test_addone_dll.so"); + tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile("lib/test_addone_dll.so"); LOG(INFO) << "Verify dynamic loading from test_addone_dll.so"; Verify(mod_dylib, "addone"); // For libraries that are directly packed as system lib and linked together with the app diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index 81bab49..37e3968 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -39,15 +39,15 @@ */ #include "../../src/runtime/c_runtime_api.cc" #include "../../src/runtime/cpu_device_api.cc" -#include "../../src/runtime/workspace_pool.cc" +#include "../../src/runtime/file_util.cc" #include "../../src/runtime/library_module.cc" #include "../../src/runtime/module.cc" -#include "../../src/runtime/registry.cc" -#include "../../src/runtime/file_util.cc" -#include "../../src/runtime/threading_backend.cc" -#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/ndarray.cc" #include "../../src/runtime/object.cc" +#include "../../src/runtime/registry.cc" +#include "../../src/runtime/thread_pool.cc" +#include "../../src/runtime/threading_backend.cc" +#include "../../src/runtime/workspace_pool.cc" // NOTE: all the files after this are optional modules // that you can include remove, depending on how much feature you use. diff --git a/apps/ios_rpc/tvmrpc/AppDelegate.h b/apps/ios_rpc/tvmrpc/AppDelegate.h index 0c54a47..a810aea 100644 --- a/apps/ios_rpc/tvmrpc/AppDelegate.h +++ b/apps/ios_rpc/tvmrpc/AppDelegate.h @@ -25,7 +25,6 @@ @interface AppDelegate : UIResponder -@property (strong, nonatomic) UIWindow *window; - +@property(strong, nonatomic) UIWindow* window; @end diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.h b/apps/ios_rpc/tvmrpc/TVMRuntime.h index 96a5c1b..f6a6dc6 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.h +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,8 +25,8 @@ #define DMLC_LOG_CUSTOMIZE 1 #define TVM_METAL_RUNTIME 1 -#include #include +#include #include namespace tvm { @@ -52,8 +52,7 @@ using FEventHandler = std::function(data) - maxLength:size]; + ssize_t nbytes = [stream_ write:reinterpret_cast(data) maxLength:size]; if (nbytes < 0) { - NSLog(@"%@",[stream_ streamError].localizedDescription); + NSLog(@"%@", [stream_ streamError].localizedDescription); throw dmlc::Error("Stream error"); } return nbytes; @@ -83,8 +79,8 @@ class NSStreamChannel final : public RPCChannel { NSOutputStream* stream_; }; -FEventHandler CreateServerEventHandler( - NSOutputStream *outputStream, std::string name, std::string remote_key) { +FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string name, + std::string remote_key) { std::unique_ptr ch(new NSStreamChannel(outputStream)); std::shared_ptr sess = RPCSession::Create(std::move(ch), name, remote_key); return [sess](const std::string& in_bytes, int flag) { @@ -103,9 +99,7 @@ struct RPCEnv { } } // Get Path. - std::string GetPath(const std::string& file_name) { - return base_ + file_name; - } + std::string GetPath(const std::string& file_name) { return base_ + file_name; } private: std::string base_; @@ -115,49 +109,44 @@ void LaunchSyncServer() { // only load dylib from frameworks. NSBundle* bundle = [NSBundle mainBundle]; NSString* base = [bundle privateFrameworksPath]; - NSString* path = [base stringByAppendingPathComponent: @"tvm/rpc_config.txt"]; + NSString* path = [base stringByAppendingPathComponent:@"tvm/rpc_config.txt"]; std::string name = [path UTF8String]; std::ifstream fs(name, std::ios::in); std::string url, key; int port; - CHECK(fs >> url >> port >> key) - << "Invalid RPC config file " << name; - RPCConnect(url, port, "server:" + key) - ->ServerLoop(); + CHECK(fs >> url >> port >> key) << "Invalid RPC config file " << name; + RPCConnect(url, port, "server:" + key)->ServerLoop(); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCEnv env; - *rv = env.GetPath(args[0]); - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string name = args[0]; - std::string fmt = GetFileFormat(name, ""); - NSString* base; - if (fmt == "dylib") { - // only load dylib from frameworks. - NSBundle* bundle = [NSBundle mainBundle]; - base = [[bundle privateFrameworksPath] - stringByAppendingPathComponent: @"tvm"]; - } else { - // Load other modules in tempdir. - base = NSTemporaryDirectory(); - } - NSString* path = [base stringByAppendingPathComponent: - [NSString stringWithUTF8String:name.c_str()]]; - name = [path UTF8String]; - *rv = Module::LoadFromFile(name, fmt); - LOG(INFO) << "Load module from " << name << " ..."; - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + *rv = env.GetPath(args[0]); +}); + +TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + std::string fmt = GetFileFormat(name, ""); + NSString* base; + if (fmt == "dylib") { + // only load dylib from frameworks. + NSBundle* bundle = [NSBundle mainBundle]; + base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"]; + } else { + // Load other modules in tempdir. + base = NSTemporaryDirectory(); + } + NSString* path = + [base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]]; + name = [path UTF8String]; + *rv = Module::LoadFromFile(name, fmt); + LOG(INFO) << "Load module from " << name << " ..."; +}); } // namespace runtime } // namespace tvm @implementation TVMRuntime -+(void) launchSyncServer { ++ (void)launchSyncServer { tvm::runtime::LaunchSyncServer(); } diff --git a/apps/ios_rpc/tvmrpc/ViewController.h b/apps/ios_rpc/tvmrpc/ViewController.h index 3a3c928..b188a87 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.h +++ b/apps/ios_rpc/tvmrpc/ViewController.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,12 +24,11 @@ #import #include "TVMRuntime.h" -@interface ViewController : UIViewController -{ +@interface ViewController : UIViewController { // input socket stream - NSInputStream *inputStream_; + NSInputStream* inputStream_; // output socket stream - NSOutputStream *outputStream_; + NSOutputStream* outputStream_; // temporal receive buffer. std::string recvBuffer_; // Whether connection is initialized. @@ -46,11 +45,11 @@ tvm::runtime::FEventHandler handler_; } -@property (weak, nonatomic) IBOutlet UITextField *proxyURL; -@property (weak, nonatomic) IBOutlet UITextField *proxyPort; -@property (weak, nonatomic) IBOutlet UITextField *proxyKey; -@property (weak, nonatomic) IBOutlet UILabel *statusLabel; -@property (weak, nonatomic) IBOutlet UITextView *infoText; +@property(weak, nonatomic) IBOutlet UITextField* proxyURL; +@property(weak, nonatomic) IBOutlet UITextField* proxyPort; +@property(weak, nonatomic) IBOutlet UITextField* proxyKey; +@property(weak, nonatomic) IBOutlet UILabel* statusLabel; +@property(weak, nonatomic) IBOutlet UITextView* infoText; - (IBAction)connect:(id)sender; - (IBAction)disconnect:(id)sender; diff --git a/apps/ios_rpc/tvmrpc/ViewController.mm b/apps/ios_rpc/tvmrpc/ViewController.mm index 0f76110..6c618c4 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.mm +++ b/apps/ios_rpc/tvmrpc/ViewController.mm @@ -21,12 +21,12 @@ * \file ViewController.mm */ -#include #import "ViewController.h" +#include @implementation ViewController -- (void)stream:(NSStream *)strm handleEvent:(NSStreamEvent)event { +- (void)stream:(NSStream*)strm handleEvent:(NSStreamEvent)event { std::string buffer; switch (event) { case NSStreamEventOpenCompleted: { @@ -45,7 +45,7 @@ break; } case NSStreamEventErrorOccurred: { - NSLog(@"%@",[strm streamError].localizedDescription); + NSLog(@"%@", [strm streamError].localizedDescription); break; } case NSStreamEventEndEncountered: { @@ -64,8 +64,7 @@ constexpr int kRPCMagic = 0xff271; if (!initialized_) { int code; - size_t nbytes = [inputStream_ read:reinterpret_cast(&code) - maxLength:sizeof(code)]; + size_t nbytes = [inputStream_ read:reinterpret_cast(&code) maxLength:sizeof(code)]; if (nbytes != sizeof(code)) { self.infoText.text = @"Fail to receive remote confirmation code."; [self close]; @@ -115,7 +114,7 @@ - (void)onWriteAvailable { if (initSendPtr_ < initBytes_.length()) { initSendPtr_ += [outputStream_ write:reinterpret_cast(&initBytes_[initSendPtr_]) - maxLength:(initBytes_.length() - initSendPtr_)]; + maxLength:(initBytes_.length() - initSendPtr_)]; } if (initialized_) { try { @@ -148,13 +147,10 @@ // Initialize the network. CFReadStreamRef readStream; CFWriteStreamRef writeStream; - CFStreamCreatePairWithSocketToHost( - NULL, - (__bridge CFStringRef) self.proxyURL.text, - [self.proxyPort.text intValue], - &readStream, &writeStream); - inputStream_ = (__bridge_transfer NSInputStream *)readStream; - outputStream_ = (__bridge_transfer NSOutputStream *)writeStream; + CFStreamCreatePairWithSocketToHost(NULL, (__bridge CFStringRef)self.proxyURL.text, + [self.proxyPort.text intValue], &readStream, &writeStream); + inputStream_ = (__bridge_transfer NSInputStream*)readStream; + outputStream_ = (__bridge_transfer NSOutputStream*)writeStream; [inputStream_ setDelegate:self]; [outputStream_ setDelegate:self]; [inputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; diff --git a/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm b/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm index c4a6f8b..eb538f0 100644 --- a/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm +++ b/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm @@ -32,16 +32,15 @@ @implementation tvmrpcLauncher - (void)setUp { - [super setUp]; + [super setUp]; } - (void)tearDown { - [super tearDown]; + [super tearDown]; } - (void)testRPC { [TVMRuntime launchSyncServer]; } - @end diff --git a/apps/rocm_rpc/rocm_runtime_pack.cc b/apps/rocm_rpc/rocm_runtime_pack.cc index a137a9b..de5c504 100644 --- a/apps/rocm_rpc/rocm_runtime_pack.cc +++ b/apps/rocm_rpc/rocm_runtime_pack.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,7 @@ #define TVM_USE_MIOPEN 1 #define __HIP_PLATFORM_HCC__ 1 -#include "../../src/runtime/rocm/rocm_device_api.cc" -#include "../../src/runtime/rocm/rocm_module.cc" #include "../../src/contrib/miopen/conv_forward.cc" #include "../../src/contrib/miopen/miopen_utils.cc" +#include "../../src/runtime/rocm/rocm_device_api.cc" +#include "../../src/runtime/rocm/rocm_module.cc" diff --git a/golang/src/gotvm.cc b/golang/src/gotvm.cc index af6e430..f599c40 100644 --- a/golang/src/gotvm.cc +++ b/golang/src/gotvm.cc @@ -24,14 +24,17 @@ // Standard includes #include +#include #include #include #include #include -#include // golang string compatible definition -typedef struct { char *p; int n; } _gostring_; +typedef struct { + char* p; + int n; +} _gostring_; #include #ifdef __cplusplus @@ -39,8 +42,8 @@ extern "C" { #endif // TVM runtime C interface -#include #include +#include /*! * \brief Convert native char array to _gostring_ structure. @@ -53,7 +56,7 @@ extern "C" { * \return _gostring_ object corresponding to native char array. * Caller is responsible to free the memory block allocated here. */ -static _gostring_ _native_to_gostring(const char *p, size_t l) { +static _gostring_ _native_to_gostring(const char* p, size_t l) { _gostring_ ret; ret.p = reinterpret_cast(malloc(l)); if (NULL == ret.p) { @@ -72,10 +75,10 @@ static _gostring_ _native_to_gostring(const char *p, size_t l) { * \param off is the offset in the string object. * \param v is the uint64_t value which need to embed into given string. */ -static void putuint64(std::string *s, size_t off, uint64_t v) { - for (int i = 0; i < 8; i++) { - (*s)[off + i] = (v >> (i * 8)) & 0xff; - } +static void putuint64(std::string* s, size_t off, uint64_t v) { + for (int i = 0; i < 8; i++) { + (*s)[off + i] = (v >> (i * 8)) & 0xff; + } } // TVM runtime C interface wrappers @@ -86,7 +89,7 @@ static void putuint64(std::string *s, size_t off, uint64_t v) { * \return char pointer to TVM-VERSION */ const char* _TVM_VERSION(void) { - const char *version = TVM_VERSION; + const char* version = TVM_VERSION; return version; } @@ -101,16 +104,16 @@ const char* _TVM_VERSION(void) { */ int _TVMFuncListGlobalNames(_gostring_* names) { int names_size; - char **names_array; + char** names_array; int result; - result = TVMFuncListGlobalNames(&names_size, (char const ***)&names_array); + result = TVMFuncListGlobalNames(&names_size, (char const***)&names_array); if (result) { return result; } size_t tot = 8; - for (int ii = 0; ii < names_size ; ++ii) { + for (int ii = 0; ii < names_size; ++ii) { tot += 8 + strlen(names_array[ii]); } @@ -118,7 +121,7 @@ int _TVMFuncListGlobalNames(_gostring_* names) { str.resize(tot); putuint64(&str, 0, names_size); size_t off = 8; - for (int64_t ii = 0; ii < names_size ; ++ii) { + for (int64_t ii = 0; ii < names_size; ++ii) { putuint64(&str, off, strlen(names_array[ii])); off += 8; str.replace(off, strlen(names_array[ii]), names_array[ii]); @@ -143,9 +146,9 @@ int _TVMFuncListGlobalNames(_gostring_* names) { * \param array index in native array. */ void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) { - TVMValue *from_p = reinterpret_cast(from_ptr); - TVMValue *to_p = reinterpret_cast(to_ptr); - memcpy(to_p+ind, from_p, sizeof(TVMValue)); + TVMValue* from_p = reinterpret_cast(from_ptr); + TVMValue* to_p = reinterpret_cast(to_ptr); + memcpy(to_p + ind, from_p, sizeof(TVMValue)); } /*! @@ -157,9 +160,9 @@ void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) { * \param array index in native array. */ void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) { - TVMValue *from_p = reinterpret_cast(from_ptr); - TVMValue *to_p = reinterpret_cast(to_ptr); - memcpy(to_p, from_p+ind, sizeof(TVMValue)); + TVMValue* from_p = reinterpret_cast(from_ptr); + TVMValue* to_p = reinterpret_cast(to_ptr); + memcpy(to_p, from_p + ind, sizeof(TVMValue)); } extern int goTVMCallback(void*, void*, int, void*, void*); @@ -175,21 +178,16 @@ extern int goTVMCallback(void*, void*, int, void*, void*); * * \returns the error status as TVM_DLL */ -int _TVMCallback(TVMValue* args, - int* type_codes, - int num_args, - TVMRetValueHandle ret, +int _TVMCallback(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, void* resource_handle) { - return goTVMCallback(args, type_codes, num_args, ret, resource_handle); + return goTVMCallback(args, type_codes, num_args, ret, resource_handle); } /*! * _TVMPackedCFuncFinalizer is finalizer for packed function system. * */ -void _TVMPackedCFuncFinalizer(void* resource_handle) { - return; -} +void _TVMPackedCFuncFinalizer(void* resource_handle) { return; } /*! * /brief _ConvertFunction creates a packed function for with given resource handle. @@ -199,11 +197,8 @@ void _TVMPackedCFuncFinalizer(void* resource_handle) { * * /return is an int indicating the return status. */ -int _ConvertFunction(void* fptr, TVMFunctionHandle *fhandle) { - int ret = TVMFuncCreateFromCFunc(_TVMCallback, - fptr, - _TVMPackedCFuncFinalizer, - fhandle); +int _ConvertFunction(void* fptr, TVMFunctionHandle* fhandle) { + int ret = TVMFuncCreateFromCFunc(_TVMCallback, fptr, _TVMPackedCFuncFinalizer, fhandle); return ret; } diff --git a/golang/src/gotvm.h b/golang/src/gotvm.h index 12b594b..a053e39 100644 --- a/golang/src/gotvm.h +++ b/golang/src/gotvm.h @@ -32,11 +32,11 @@ extern "C" { #endif +#include #include #include #include #include -#include // Some type definitions for golang "C" typedef void* native_voidp; diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc index 416067dc..644249f 100644 --- a/golang/src/tvm_runtime_pack.cc +++ b/golang/src/tvm_runtime_pack.cc @@ -23,15 +23,15 @@ */ #include "src/runtime/c_runtime_api.cc" #include "src/runtime/cpu_device_api.cc" -#include "src/runtime/workspace_pool.cc" +#include "src/runtime/file_util.cc" #include "src/runtime/library_module.cc" #include "src/runtime/module.cc" -#include "src/runtime/registry.cc" -#include "src/runtime/file_util.cc" -#include "src/runtime/threading_backend.cc" -#include "src/runtime/thread_pool.cc" #include "src/runtime/ndarray.cc" #include "src/runtime/object.cc" +#include "src/runtime/registry.cc" +#include "src/runtime/thread_pool.cc" +#include "src/runtime/threading_backend.cc" +#include "src/runtime/workspace_pool.cc" // NOTE: all the files after this are optional modules // that you can include remove, depending on how much feature you use. diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 340da7f..4623b5e 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -24,14 +24,14 @@ #ifndef TVM_ARITH_ANALYZER_H_ #define TVM_ARITH_ANALYZER_H_ -#include -#include #include +#include +#include -#include -#include -#include #include +#include +#include +#include namespace tvm { /*! \brief namespace of arithmetic analysis. */ @@ -130,9 +130,7 @@ class ConstIntBoundAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, - const ConstIntBound& info, - bool override = false); + TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool override = false); /*! * \brief Bind variable to a range. * @@ -221,9 +219,7 @@ class ModularSetAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, - const ModularSet& info, - bool override = false); + TVM_DLL void Update(const Var& var, const ModularSet& info, bool override = false); private: friend class Analyzer; @@ -262,9 +258,7 @@ class RewriteSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false); std::function EnterConstraint(const PrimExpr& constraint); @@ -298,9 +292,7 @@ class CanonicalSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false); private: friend class Analyzer; diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index b1cb779..df1a9e7 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -23,9 +23,9 @@ #ifndef TVM_ARITH_BOUND_H_ #define TVM_ARITH_BOUND_H_ -#include -#include #include +#include +#include #include #include @@ -38,10 +38,10 @@ class Tensor; } namespace arith { -using tir::Var; -using tir::VarNode; using tir::Domain; using tir::Stmt; +using tir::Var; +using tir::VarNode; /*! * \brief Deduce the bound of the target variable in a expression, @@ -58,8 +58,7 @@ using tir::Stmt; * The deduce bound must implies e for all value in relax_map * \return An integer set that always satisfies the condition. */ -IntSet DeduceBound(PrimExpr v, PrimExpr cond, - const Map& hint_map, +IntSet DeduceBound(PrimExpr v, PrimExpr cond, const Map& hint_map, const Map& relax_map); /*! * \brief Same as DeduceBound with unordered_map signature. @@ -83,9 +82,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, * \param consider_stores If stores are considered. * \return The domain that covers all the calls or provides within the given statement. */ -Domain DomainTouched(const Stmt& body, - const tir::Buffer& buffer, - bool consider_loads, +Domain DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads, bool consider_stores); } // namespace arith diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index ab73b07..7cd74d2 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -26,14 +26,15 @@ #include #include + #include namespace tvm { namespace arith { +using tir::IterVar; using tir::Var; using tir::VarNode; -using tir::IterVar; //----------------------------------------------- // Integer set data structure. @@ -44,12 +45,7 @@ using tir::IterVar; /*! * \brief Sign type of an integer expression. */ -enum SignType { - kPositive, - kNegative, - kZero, - kUnknown -}; +enum SignType { kPositive, kNegative, kZero, kUnknown }; /*! * \brief Base class of all Integer set containers. @@ -77,9 +73,7 @@ class IntSet : public ObjectRef { * \brief access the internal node container * \return the pointer to the internal node container */ - const IntSetNode* operator->() const { - return static_cast(get()); - } + const IntSetNode* operator->() const { return static_cast(get()); } /*! * \brief Find a range that covers the region. * \param max_range The range to be covered. @@ -175,8 +169,7 @@ IntSet EvalSet(PrimExpr e, const Map& dom_map); * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map); +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over * all the possible conditional values in dom_map. @@ -185,8 +178,7 @@ IntSet EvalSet(PrimExpr e, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values. */ -IntSet EvalSet(Range r, - const Map& dom_map); +IntSet EvalSet(Range r, const Map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over @@ -196,8 +188,7 @@ IntSet EvalSet(Range r, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values. */ -IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map); +IntSet EvalSet(IntSet s, const std::unordered_map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -205,8 +196,7 @@ IntSet EvalSet(IntSet s, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(Range r, - const std::unordered_map& dom_map); +IntSet EvalSet(Range r, const std::unordered_map& dom_map); /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; /*! @@ -217,9 +207,8 @@ using ExprIntSetMap = std::unordered_map& dom_map); +ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, + const std::unordered_map& dom_map); /*! * \brief Create an union set of all sets diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 57f3af4..ae18cab 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -26,15 +26,16 @@ #include #include + #include #include namespace tvm { namespace arith { +using tir::IterVar; using tir::Var; using tir::VarNode; -using tir::IterVar; /*! * \brief Represent integer constrains including (integer) variables, their ranges and @@ -60,10 +61,8 @@ class IntConstraintsNode : public Object { } bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const { - return - equal(variables, other->variables) && - equal(ranges, other->ranges) && - equal(relations, other->relations); + return equal(variables, other->variables) && equal(ranges, other->ranges) && + equal(relations, other->relations); } void SHashReduce(SHashReducer hash_reduce) const { @@ -90,9 +89,7 @@ class IntConstraints : public ObjectRef { * \param relations The linear relations between the variables * (either equations or inequalities) */ - TVM_DLL IntConstraints(Array variables, - Map ranges, - Array relations); + TVM_DLL IntConstraints(Array variables, Map ranges, Array relations); TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode); }; @@ -126,11 +123,8 @@ class IntConstraintsTransformNode : public Object { } bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const { - return - equal(src, other->src) && - equal(dst, other->dst) && - equal(src_to_dst, other->src_to_dst) && - equal(dst_to_src, other->dst_to_src); + return equal(src, other->src) && equal(dst, other->dst) && + equal(src_to_dst, other->src_to_dst) && equal(dst_to_src, other->dst_to_src); } void SHashReduce(SHashReducer hash_reduce) const { @@ -161,10 +155,8 @@ class IntConstraintsTransform : public ObjectRef { * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src, * e.g., {m -> a, n -> -b} */ - TVM_DLL IntConstraintsTransform(IntConstraints src, - IntConstraints dst, - Map src_to_dst, - Map dst_to_src); + TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst, + Map src_to_dst, Map dst_to_src); TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; @@ -176,20 +168,16 @@ class IntConstraintsTransform : public ObjectRef { * NOTE: Although in standard Smith Normal Form the diagonal elements satisfy * s_i | s_{i+1} (| means divides), the implement here does not guarantee it. * TODO(yzhliu): From sergei-grechanik: - * computing the proper Smith normal form may improve stability of automatic differentiation - * (generating the same gradient code for slightly different but equivalent input code - * U_{mxm} and V_{nxn} are invertible matrices. - * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn}, - * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. - * \param S the original A_{mxn}, it will be modified to S_{mxn} - * \param V an identity matrix, it will be modified to V_{nxn} - * \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1} - * \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1} + * computing the proper Smith normal form may improve stability of automatic + * differentiation (generating the same gradient code for slightly different but equivalent input + * code U_{mxm} and V_{nxn} are invertible matrices. This function modifies \p S to be S_{mxn}, \p V + * to be V_{nxn}, \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. \param S the original + * A_{mxn}, it will be modified to S_{mxn} \param V an identity matrix, it will be modified to + * V_{nxn} \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1} \param y the y + * in A x = y. it will be modified to U_{mxm} y_{mx1} */ -void SmithNormalFormDiag(std::vector> *S, - std::vector> *V, - std::vector* x, - std::vector *y); +void SmithNormalFormDiag(std::vector>* S, std::vector>* V, + std::vector* x, std::vector* y); /*! * \brief Solve linear equations. @@ -201,7 +189,7 @@ void SmithNormalFormDiag(std::vector> *S, * as well as inequalities inferred from the \p system_to_solve. * You can get the mapping from the original variables to the solution via ret->src_to_dst. */ -IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve); +IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve); } // namespace arith } // namespace tvm diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h index d3ba3e9..301d956 100644 --- a/include/tvm/arith/pattern.h +++ b/include/tvm/arith/pattern.h @@ -24,8 +24,8 @@ #ifndef TVM_ARITH_PATTERN_H_ #define TVM_ARITH_PATTERN_H_ -#include #include +#include #include namespace tvm { @@ -38,8 +38,7 @@ namespace arith { * \param vars List of variables to be used in detection. * \return [coeff[i]] if it is possible, empty array if it is not. */ -Array DetectLinearEquation(const PrimExpr& e, - const Array& vars); +Array DetectLinearEquation(const PrimExpr& e, const Array& vars); /*! * \brief Detect if expression corresponds to clip bound of the vars @@ -49,8 +48,7 @@ Array DetectLinearEquation(const PrimExpr& e, * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value * return empty if the e does not match the pattern. */ -Array DetectClipBound(const PrimExpr& e, - const Array& vars); +Array DetectClipBound(const PrimExpr& e, const Array& vars); } // namespace arith } // namespace tvm diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index e6d4427..1d4d493 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -29,47 +29,42 @@ #ifndef TVM_DRIVER_DRIVER_API_H_ #define TVM_DRIVER_DRIVER_API_H_ +#include #include -#include #include -#include +#include #include #include -#include -#include #include #include +#include +#include namespace tvm { /*! -* \brief Build an IRModule given a schedule, args and binds -* \param sch The schedule to lower. -* \param args The arguments to the function. -* \param name The name of the lowered function. -* \param binds Buffer assignments. -* \param config The build configuration. -* \return The result module. -*/ -TVM_DLL IRModule lower( - te::Schedule sch, - const Array& args, - const std::string& name, - const std::unordered_map& binds, - const BuildConfig& config); + * \brief Build an IRModule given a schedule, args and binds + * \param sch The schedule to lower. + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \param config The build configuration. + * \return The result module. + */ +TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, + const BuildConfig& config); /*! -* \brief Build a device and host module for a specific target from an IRModule. -* \param funcs The functions to be built. -* \param target The target device to build for. -* \param target_host The target for building host code. To use the default, pass Target() -* \param config The build configuration. -* \return The built module. -*/ -TVM_DLL runtime::Module build(const IRModule& funcs, - const Target& target, - const Target& target_host, - const BuildConfig& config); + * \brief Build a device and host module for a specific target from an IRModule. + * \param funcs The functions to be built. + * \param target The target device to build for. + * \param target_host The target for building host code. To use the default, pass Target() + * \param config The build configuration. + * \return The built module. + */ +TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target, + const Target& target_host, const BuildConfig& config); /*! * \brief Build a device and host module for a specific target from a map @@ -81,8 +76,7 @@ TVM_DLL runtime::Module build(const IRModule& funcs, * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map& input, - const Target& target_host, +TVM_DLL runtime::Module build(const Map& input, const Target& target_host, const BuildConfig& config); /*! @@ -95,8 +89,7 @@ TVM_DLL runtime::Module build(const Map& input, * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map& input, - const Target& target_host, +TVM_DLL runtime::Module build(const Map& input, const Target& target_host, const BuildConfig& config); } // namespace tvm diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index f9cb622..9d45dc1 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -27,11 +27,12 @@ #ifndef TVM_IR_ADT_H_ #define TVM_IR_ADT_H_ -#include -#include -#include #include #include +#include +#include +#include + #include namespace tvm { @@ -66,9 +67,7 @@ class ConstructorNode : public RelayExprNode { bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const { // Use namehint for now to be consistent with the legacy relay impl // TODO(tvm-team) revisit, need to check the type var. - return - equal(name_hint, other->name_hint) && - equal(inputs, other->inputs); + return equal(name_hint, other->name_hint) && equal(inputs, other->inputs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -92,9 +91,7 @@ class Constructor : public RelayExpr { * \param inputs The input types. * \param belong_to The data type var the constructor will construct. */ - TVM_DLL Constructor(std::string name_hint, - Array inputs, - GlobalTypeVar belong_to); + TVM_DLL Constructor(std::string name_hint, Array inputs, GlobalTypeVar belong_to); TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode); }; @@ -122,10 +119,8 @@ class TypeDataNode : public TypeNode { } bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const { - return - equal.DefEqual(header, other->header) && - equal.DefEqual(type_vars, other->type_vars) && - equal(constructors, other->constructors); + return equal.DefEqual(header, other->header) && equal.DefEqual(type_vars, other->type_vars) && + equal(constructors, other->constructors); } void SHashReduce(SHashReducer hash_reduce) const { @@ -157,9 +152,7 @@ class TypeData : public Type { * \param type_vars type variables. * \param constructors constructors field. */ - TVM_DLL TypeData(GlobalTypeVar header, - Array type_vars, - Array constructors); + TVM_DLL TypeData(GlobalTypeVar header, Array type_vars, Array constructors); TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode); }; diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index d12f1b8..819aafa 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -50,12 +50,12 @@ #include #include -#include -#include #include -#include #include +#include +#include #include +#include namespace tvm { /*! @@ -63,34 +63,30 @@ namespace tvm { * \param ClassName The name of the class. * \param TypeKey The type key to be used by the TVM node system. */ -#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ - static constexpr const char* _type_key = TypeKey; \ - TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ - template \ +#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ + static constexpr const char* _type_key = TypeKey; \ + TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ + template \ void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) - /*! * \brief Declare an attribute field. * \param FieldName The field name. */ -#define TVM_ATTR_FIELD(FieldName) \ - __fvisit__(#FieldName, &FieldName) - +#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName) /*! * \brief Create a NodeRef type that represents null. * \tparam TNodeRef the type to be created. * \return A instance that will represent None. */ -template +template inline TObjectRef NullValue() { - static_assert(TObjectRef::_type_is_nullable, - "Can only get NullValue for nullable types"); + static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types"); return TObjectRef(ObjectPtr(nullptr)); } -template<> +template <> inline DataType NullValue() { return DataType(DataType::kHandle, 0, 0); } @@ -101,8 +97,7 @@ struct AttrError : public dmlc::Error { * \brief constructor * \param msg error message */ - explicit AttrError(const std::string &msg) - : dmlc::Error(msg) {} + explicit AttrError(const std::string& msg) : dmlc::Error(msg) {} }; /*! @@ -154,13 +149,13 @@ class BaseAttrsNode : public Object { * \param args The postional arguments in the form * [key0, value0, key1, value1, ..., key_n, value_n] */ - template - inline void InitBySeq(Args&& ...args); + template + inline void InitBySeq(Args&&... args); /*! * \brief Print readible docstring to ostream, add newline. * \param os the stream to print the docstring to. */ - inline void PrintDocString(std::ostream &os) const; // NOLINT(*) + inline void PrintDocString(std::ostream& os) const; // NOLINT(*) /*! * \brief Visit attributes that do not equal the default value. * @@ -212,9 +207,7 @@ class DictAttrsNode : public BaseAttrsNode { return equal(dict, other->dict); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dict); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); } // implementations void VisitAttrs(AttrVisitor* v) final; @@ -239,7 +232,6 @@ class DictAttrs : public Attrs { */ TVM_DLL explicit DictAttrs(Map dict); - TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -252,18 +244,16 @@ using runtime::TVMArgValue; struct AttrNopEntry { using TSelf = AttrNopEntry; - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { - return *this; - } - template + TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } + template TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { return *this; } - template + template TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } - template + template TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } @@ -272,10 +262,8 @@ struct AttrNopEntry { // Wrapper for normal visitor. class AttrNormalVisitor { public: - explicit AttrNormalVisitor(AttrVisitor* visitor) - : visitor_(visitor) { - } - template + explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {} + template AttrNopEntry operator()(const char* key, T* value) { visitor_->Visit(key, value); return AttrNopEntry(); @@ -290,16 +278,13 @@ class AttrsSEqualVisitor { bool result_{true}; // constructor AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal) - : lhs_(lhs), rhs_(rhs), equal_(equal) { - } - template + : lhs_(lhs), rhs_(rhs), equal_(equal) {} + template AttrNopEntry operator()(const char* key, T* lhs_value) { if (!result_) return AttrNopEntry(); - const T* rhs_value = - reinterpret_cast( - reinterpret_cast(rhs_) + - (reinterpret_cast(lhs_value) - - reinterpret_cast(lhs_))); + const T* rhs_value = reinterpret_cast( + reinterpret_cast(rhs_) + + (reinterpret_cast(lhs_value) - reinterpret_cast(lhs_))); if (!equal_(*lhs_value, *rhs_value)) { result_ = false; } @@ -314,10 +299,9 @@ class AttrsSEqualVisitor { class AttrsSHashVisitor { public: - explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) - : hash_reducer_(hash_reducer) {} + explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {} - template + template AttrNopEntry operator()(const char* key, T* value) { hash_reducer_(*value); return AttrNopEntry(); @@ -328,7 +312,7 @@ class AttrsSHashVisitor { }; // helper entry that does initialization, set default. -template +template struct AttrInitEntry { // The attributes using TSelf = AttrInitEntry; @@ -344,34 +328,31 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ - << "\' during initialization"; + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization"; throw AttrError(os.str()); } } // override fields. // This function sets the lower bound of the attribute TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { - if (this->value_missing_) return *this; + if (this->value_missing_) return *this; const T& val = *value_; if (begin > val) { std::ostringstream os; os << type_key_ << "." << key_ << ": " - << "value " << val - << " is smaller than the lower bound " << begin; + << "value " << val << " is smaller than the lower bound " << begin; throw AttrError(os.str()); } return *this; } // This function sets the upper bound of the attribute TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { - if (this->value_missing_) return *this; + if (this->value_missing_) return *this; const T& val = *value_; if (val > end) { std::ostringstream os; os << type_key_ << "." << key_ << ": " - << "value " << val - << " is bigger than the upper bound " << end; + << "value " << val << " is bigger than the upper bound " << end; throw AttrError(os.str()); } return *this; @@ -383,19 +364,17 @@ struct AttrInitEntry { value_missing_ = false; return *this; } - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { - return *this; - } + TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } }; // Template function to allow smart conversion // from Expr types into the constants. -template +template inline void SetValue(T* ptr, const TVMArgValue& val) { *ptr = val.operator T(); } -template +template inline void SetIntValue(T* ptr, const TVMArgValue& val) { if (val.type_code() == kDLInt) { *ptr = static_cast(val.value().v_int64); @@ -405,7 +384,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { } } -template<> +template <> inline void SetValue(std::string* ptr, const TVMArgValue& val) { if (val.type_code() == kTVMStr) { *ptr = val.operator std::string(); @@ -414,7 +393,7 @@ inline void SetValue(std::string* ptr, const TVMArgValue& val) { } } -template<> +template <> inline void SetValue(double* ptr, const TVMArgValue& val) { if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { *ptr = val.operator double(); @@ -430,36 +409,34 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { } } } -template<> +template <> inline void SetValue(int* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } -template<> +template <> inline void SetValue(int64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } -template<> +template <> inline void SetValue(uint64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } -template<> +template <> inline void SetValue(bool* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } // Visitor for value initialization -template +template class AttrInitVisitor { public: // Counter of number of matched attributes during visit. // This is used to decide if there is additional unmatched attributes. size_t hit_count_{0}; // constructor - AttrInitVisitor(const char* type_key, FFind ffind) - : type_key_(type_key), ffind_(ffind) { - } + AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {} - template + template AttrInitEntry operator()(const char* key, T* value) { TVMArgValue val; AttrInitEntry opt; @@ -482,10 +459,8 @@ class AttrInitVisitor { FFind ffind_; }; -template -inline AttrInitVisitor CreateInitVisitor( - const char* type_key, - FFind ffind) { +template +inline AttrInitVisitor CreateInitVisitor(const char* type_key, FFind ffind) { return AttrInitVisitor(type_key, ffind); } @@ -493,47 +468,47 @@ inline AttrInitVisitor CreateInitVisitor( * \brief Helper struct to get the type name known to tvm. * \tparam T the type we are interested in. */ -template +template struct TypeName { static constexpr const char* value = T::ContainerType::_type_key; }; -template<> +template <> struct TypeName { static constexpr const char* value = "int"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "int64"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "uint64_t"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "DataType"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "str"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "bool"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "handle"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "double"; }; @@ -542,25 +517,23 @@ class AttrDocEntry { public: using TSelf = AttrDocEntry; - explicit AttrDocEntry(ObjectPtr info) - : info_(info) { - } + explicit AttrDocEntry(ObjectPtr info) : info_(info) {} TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { info_->description = str; return *this; } - template + template TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { std::ostringstream os; os << info_->type_info << ", default=" << value; info_->type_info = os.str(); return *this; } - template + template TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { return *this; } - template + template TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { return *this; } @@ -571,10 +544,9 @@ class AttrDocEntry { class AttrDocVisitor { public: - template + template AttrDocEntry operator()(const char* key, T* v) { - ObjectPtr info - = make_object(); + ObjectPtr info = make_object(); info->name = key; info->type_info = TypeName::value; fields_.push_back(AttrFieldInfo(info)); @@ -589,7 +561,7 @@ class AttrExistVisitor { std::string key_; bool exist_{false}; - template + template AttrNopEntry operator()(const char* key, T* v) { if (exist_) return AttrNopEntry(); if (key == key_) exist_ = true; @@ -597,12 +569,11 @@ class AttrExistVisitor { } }; -template +template struct AttrTriggerNonDefaultEntry { using TSelf = AttrTriggerNonDefaultEntry; // constructor - AttrTriggerNonDefaultEntry( - AttrVisitor* visitor, const char* key, T* data) + AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data) : visitor_(visitor), key_(key), data_(data) {} ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION { @@ -610,37 +581,28 @@ struct AttrTriggerNonDefaultEntry { visitor_->Visit(key_, data_); } } - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { - return *this; - } + TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } TSelf& set_default(const T& value) { if (tvm::StructuralEqual()(value, *data_)) { trigger_ = false; } return *this; } - TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { - return *this; - } - TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { - return *this; - } + TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } + TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } private: AttrVisitor* visitor_; - const char * key_; - T *data_; + const char* key_; + T* data_; bool trigger_{true}; }; class AttrNonDefaultVisitor { public: - explicit AttrNonDefaultVisitor(AttrVisitor* visitor) - : visitor_(visitor) { - } - template - AttrTriggerNonDefaultEntry - operator()(const char* key, T* value) { + explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {} + template + AttrTriggerNonDefaultEntry operator()(const char* key, T* value) { return AttrTriggerNonDefaultEntry(visitor_, key, value); } @@ -655,7 +617,7 @@ class AttrNonDefaultVisitor { * * \tparam DerivedType The final attribute type. */ -template +template class AttrsNode : public BaseAttrsNode { public: void VisitAttrs(AttrVisitor* v) { @@ -695,7 +657,7 @@ class AttrsNode : public BaseAttrsNode { CHECK_EQ(args.type_codes[i], kTVMStr); kwargs[args[i].operator std::string()] = args[i + 1]; } - auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) { + auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) { auto it = kwargs.find(key); if (it != kwargs.end()) { *val = it->second; @@ -715,8 +677,7 @@ class AttrsNode : public BaseAttrsNode { self()->__VisitAttrs__(visitor); if (!visitor.exist_) { std::ostringstream os; - os << DerivedType::_type_key - << ": does not have field \'" << visitor.key_ + os << DerivedType::_type_key << ": does not have field \'" << visitor.key_ << "\', Possible fields:\n"; os << "----------------\n"; this->PrintDocString(os); @@ -746,21 +707,18 @@ class AttrsNode : public BaseAttrsNode { private: DerivedType* self() const { - return const_cast( - static_cast(this)); + return const_cast(static_cast(this)); } }; - -template -inline void BaseAttrsNode::InitBySeq(Args&& ...args) { - runtime::PackedFunc pf([this](const TVMArgs& args, TVMRetValue *rv) { - this->InitByPackedArgs(args); - }); +template +inline void BaseAttrsNode::InitBySeq(Args&&... args) { + runtime::PackedFunc pf( + [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); }); pf(std::forward(args)...); } -inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) +inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*) Array entry = this->ListFieldInfo(); for (AttrFieldInfo info : entry) { os << info->name << " : " << info->type_info << '\n'; diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 67492ab..320d6e3 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -47,9 +47,7 @@ class EnvFuncNode : public Object { /*! \brief constructor */ EnvFuncNode() {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const { // name uniquely identifies the env function. @@ -76,15 +74,13 @@ class EnvFunc : public ObjectRef { EnvFunc() {} explicit EnvFunc(ObjectPtr n) : ObjectRef(n) {} /*! \return The internal global function pointer */ - const EnvFuncNode* operator->() const { - return static_cast(get()); - } + const EnvFuncNode* operator->() const { return static_cast(get()); } /*! * \brief Invoke the function. * \param args The arguments * \returns The return value. */ - template + template runtime::TVMRetValue operator()(Args&&... args) const { const EnvFuncNode* n = operator->(); CHECK(n != nullptr); @@ -104,7 +100,7 @@ class EnvFunc : public ObjectRef { /*! * \brief Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc" */ -template +template class TypedEnvFunc; /*! @@ -116,7 +112,7 @@ class TypedEnvFunc; * \tparam Args The argument signature of the function. * \sa EnvFunc */ -template +template class TypedEnvFunc : public ObjectRef { public: /*! \brief short hand for this function type */ @@ -133,9 +129,7 @@ class TypedEnvFunc : public ObjectRef { return *this; } /*! \return The internal global function pointer */ - const EnvFuncNode* operator->() const { - return static_cast(get()); - } + const EnvFuncNode* operator->() const { return static_cast(get()); } /*! * \brief Invoke the function. * \param args The arguments @@ -144,8 +138,8 @@ class TypedEnvFunc : public ObjectRef { R operator()(Args... args) const { const EnvFuncNode* n = operator->(); CHECK(n != nullptr); - return runtime::detail::typed_packed_call_dispatcher - ::run(n->func, std::forward(args)...); + return runtime::detail::typed_packed_call_dispatcher::run(n->func, + std::forward(args)...); } /*! \brief specify container node */ using ContainerType = EnvFuncNode; diff --git a/include/tvm/ir/error.h b/include/tvm/ir/error.h index 94064ae..c6576c8 100644 --- a/include/tvm/ir/error.h +++ b/include/tvm/ir/error.h @@ -24,13 +24,13 @@ #ifndef TVM_IR_ERROR_H_ #define TVM_IR_ERROR_H_ -#include #include +#include -#include -#include #include +#include #include +#include namespace tvm { /*! @@ -51,7 +51,7 @@ namespace tvm { */ struct ErrorBuilder { public: - template + template ErrorBuilder& operator<<(const T& val) { // NOLINT(*) stream_ << val; return *this; @@ -78,12 +78,12 @@ class Error : public dmlc::Error { * \brief construct error from error builder. * \param err The error builder */ - Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*) + Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*) /*! * \brief copy constructor. * \param other The other ereor. */ - Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*) + Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*) /*! * \brief default constructor. */ Error() : dmlc::Error(""), span(nullptr) {} @@ -173,9 +173,7 @@ class ErrorReporter { */ void RenderErrors(const IRModule& module, bool use_color = true); - inline bool AnyErrors() { - return errors_.size() != 0; - } + inline bool AnyErrors() { return errors_.size() != 0; } private: std::vector errors_; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index fba35a9..717ffb1 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,14 +24,15 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ -#include -#include -#include #include #include -#include +#include +#include +#include + #include #include +#include #include namespace tvm { @@ -111,9 +112,7 @@ class PrimExpr : public BaseExpr { TVM_DLL PrimExpr(float value); // NOLINT(*) /*! \return the data type of this expression. */ - DataType dtype() const { - return static_cast(get())->dtype; - } + DataType dtype() const { return static_cast(get())->dtype; } TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); @@ -160,7 +159,7 @@ class RelayExprNode : public BaseExprNode { * \return The corresponding TTypeNode pointer. * \tparam The specific TypeNode we look for. */ - template + template inline const TTypeNode* type_as() const; static constexpr const char* _type_key = "RelayExpr"; @@ -199,9 +198,7 @@ class GlobalVarNode : public RelayExprNode { bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { // name matters for global var. - return - equal(name_hint, other->name_hint) && - equal.FreeVarEqualImpl(this, other); + return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -322,35 +319,21 @@ class FloatImm : public PrimExpr { */ class Bool : public IntImm { public: - explicit Bool(bool value) - : IntImm(DataType::Bool(), value) { - } - Bool operator!() const { - return Bool((*this)->value == 0); - } - operator bool() const { - return (*this)->value != 0; - } + explicit Bool(bool value) : IntImm(DataType::Bool(), value) {} + Bool operator!() const { return Bool((*this)->value == 0); } + operator bool() const { return (*this)->value != 0; } TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); }; // Overload operators to make sure we have the most fine grained types. -inline Bool operator||(const Bool& a, bool b) { - return Bool(a.operator bool() || b); -} -inline Bool operator||(bool a, const Bool& b) { - return Bool(a || b.operator bool()); -} +inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); } +inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); } inline Bool operator||(const Bool& a, const Bool& b) { return Bool(a.operator bool() || b.operator bool()); } -inline Bool operator&&(const Bool& a, bool b) { - return Bool(a.operator bool() && b); -} -inline Bool operator&&(bool a, const Bool& b) { - return Bool(a && b.operator bool()); -} +inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); } +inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); } inline Bool operator&&(const Bool& a, const Bool& b) { return Bool(a.operator bool() && b.operator bool()); } @@ -384,8 +367,7 @@ class Integer : public IntImm { * \tparam Enum The enum type. * \param value The enum value. */ - template::value>::type> + template ::value>::type> explicit Integer(Enum value) : Integer(static_cast(value)) { static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); @@ -402,8 +384,7 @@ class Integer : public IntImm { * \brief convert to int64_t */ operator int64_t() const { - CHECK(data_ != nullptr) - << " Trying to reference a null Integer"; + CHECK(data_ != nullptr) << " Trying to reference a null Integer"; return (*this)->value; } // comparators @@ -411,16 +392,12 @@ class Integer : public IntImm { if (data_ == nullptr) return Bool(false); return Bool((*this)->value == other); } - Bool operator!=(int other) const { - return !(*this == other); - } - template::value>::type> + Bool operator!=(int other) const { return !(*this == other); } + template ::value>::type> Bool operator==(Enum other) const { return *this == static_cast(other); } - template::value>::type> + template ::value>::type> Bool operator!=(Enum other) const { return *this != static_cast(other); } @@ -482,24 +459,21 @@ class Range : public ObjectRef { // implementataions inline const Type& RelayExprNode::checked_type() const { - CHECK(checked_type_.defined()) - << "internal error: the type checker has " - << "not populated the checked_type " - << "field for " - << GetRef(this); + CHECK(checked_type_.defined()) << "internal error: the type checker has " + << "not populated the checked_type " + << "field for " << GetRef(this); return this->checked_type_; } -template +template inline const TTypeNode* RelayExprNode::type_as() const { static_assert(std::is_base_of::value, "TType must be a special case of type"); CHECK(checked_type_.defined()) << "Type inference for this Expr has not completed. Try to call infer_type pass."; const TTypeNode* node = checked_type_.as(); - CHECK(node != nullptr) - << "Expected type to be " << TTypeNode::_type_key - << ", but get " << checked_type_->GetTypeKey(); + CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get " + << checked_type_->GetTypeKey(); return node; } @@ -507,7 +481,7 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -template<> +template <> struct PackedFuncValueConverter { // common rule for both RetValue and ArgValue. static PrimExpr From(const TVMPODValue_& val) { diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index b4a9ed0..00626e6 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -24,12 +24,12 @@ #ifndef TVM_IR_FUNCTION_H_ #define TVM_IR_FUNCTION_H_ -#include #include +#include #include -#include -#include +#include +#include namespace tvm { @@ -96,7 +96,7 @@ class BaseFuncNode : public RelayExprNode { * * \endcode */ - template + template Optional GetAttr( const std::string& attr_key, Optional default_value = Optional(nullptr)) const { @@ -111,9 +111,8 @@ class BaseFuncNode : public RelayExprNode { } } // variant that uses TObjectRef to enable implicit conversion to default value. - template - Optional GetAttr( - const std::string& attr_key, TObjectRef default_value) const { + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { return GetAttr(attr_key, Optional(default_value)); } /*! @@ -180,12 +179,9 @@ class BaseFunc : public RelayExpr { * * \endcode */ -template::value>::type> -inline TFunc WithAttr(TFunc func, - const std::string& attr_key, - ObjectRef attr_value) { +template ::value>::type> +inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = func.CopyOnWrite(); diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index ae78383..ba9a62a 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -24,15 +24,16 @@ #ifndef TVM_IR_MODULE_H_ #define TVM_IR_MODULE_H_ -#include +#include #include #include -#include +#include #include + #include -#include #include #include +#include namespace tvm { class IRModule; @@ -102,8 +103,7 @@ class IRModuleNode : public Object { * * It does not do type checking as AddTypeDef does. */ - TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, - const TypeData& type, + TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false); /*! @@ -303,9 +303,7 @@ class IRModule : public ObjectRef { * * \returns The constructed module */ - static IRModule Empty() { - return IRModule(Map()); - } + static IRModule Empty() { return IRModule(Map()); } /*! * \brief Construct a module from a standalone expression. * @@ -318,10 +316,9 @@ class IRModule : public ObjectRef { * * \returns A module with expr set as the main function. */ - TVM_DLL static IRModule FromExpr( - const RelayExpr& expr, - const Map& global_funcs = {}, - const Map& type_definitions = {}); + TVM_DLL static IRModule FromExpr(const RelayExpr& expr, + const Map& global_funcs = {}, + const Map& type_definitions = {}); /*! * \brief Parse text format source file into an IRModule. @@ -362,8 +359,7 @@ TVM_DLL String PrettyPrint(const ObjectRef& node); * \sa PrettyPrint. * \return The text representation. */ -TVM_DLL String AsText(const ObjectRef& node, - bool show_meta_data = true, +TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 48cf61d..7fafb5a 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -27,10 +27,10 @@ #include #include -#include #include #include #include +#include #include #include @@ -227,8 +227,7 @@ class OpRegistry { * \param description Description of the argument. * \return reference to self. */ - inline OpRegistry& add_argument(const std::string& name, - const std::string& type, + inline OpRegistry& add_argument(const std::string& name, const std::string& type, const std::string& description); /*! * \brief Attach the type function corresponding to the return type. @@ -239,16 +238,14 @@ class OpRegistry { */ inline OpRegistry& add_type_rel( const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func); + runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> + type_rel_func); /*! * \brief Set the the attrs type key and index to be AttrsType. * \tparam AttrsType the attribute type to b set. * \return reference to self. */ - template + template inline OpRegistry& set_attrs_type(); /*! * \brief Set the num_inputs @@ -306,9 +303,7 @@ class OpRegistry { // return internal pointer to op. inline OpNode* get(); // update the attribute OpMap - TVM_DLL void UpdateAttr(const std::string& key, - runtime::TVMRetValue value, - int plevel); + TVM_DLL void UpdateAttr(const std::string& key, runtime::TVMRetValue value, int plevel); }; /*! @@ -410,8 +405,7 @@ class OpMap { #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) // internal macros to make -#define TVM_OP_REGISTER_VAR_DEF \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op +#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op /*! * \def TVM_REGISTER_OP @@ -428,38 +422,28 @@ class OpMap { * * \endcode */ -#define TVM_REGISTER_OP(OpName) \ - TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \ - ::tvm::OpRegistry::Registry() \ - ->__REGISTER_OR_GET__(OpName) \ - .set_name() +#define TVM_REGISTER_OP(OpName) \ + TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name() // implementations -inline const OpNode* Op::operator->() const { - return static_cast(get()); -} +inline const OpNode* Op::operator->() const { return static_cast(get()); } template inline OpMap Op::GetAttr(const std::string& key) { return OpMap(Op::GetGenericAttr(key)); } -inline bool Op::HasAttr(const std::string& key) { - return Op::HasGenericAttr(key); -} +inline bool Op::HasAttr(const std::string& key) { return Op::HasGenericAttr(key); } -inline OpNode* OpRegistry::get() { - return const_cast(op_.operator->()); -} +inline OpNode* OpRegistry::get() { return const_cast(op_.operator->()); } -inline OpRegistry& OpRegistry::describe( - const std::string& descr) { // NOLINT(*) +inline OpRegistry& OpRegistry::describe(const std::string& descr) { // NOLINT(*) get()->description = descr; return *this; } -inline OpRegistry& OpRegistry::add_argument(const std::string& name, - const std::string& type, +inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::string& type, const std::string& description) { auto n = make_object(); n->name = name; @@ -471,10 +455,8 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, inline OpRegistry& OpRegistry::add_type_rel( const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func) { + runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> + type_rel_func) { auto func_name = std::string("tvm.relay.type_relation.") + rel_name; TypeRelationFn env_type_rel_func; @@ -482,8 +464,7 @@ inline OpRegistry& OpRegistry::add_type_rel( auto env_func = EnvFunc::Get(func_name); env_type_rel_func = env_func; } else { - runtime::Registry::Register(func_name) - .set_body(type_rel_func.packed()); + runtime::Registry::Register(func_name).set_body(type_rel_func.packed()); auto env_func = EnvFunc::Get(func_name); env_type_rel_func = env_func; } @@ -517,13 +498,9 @@ inline OpRegistry& OpRegistry::add_type_rel( // A common example is sum(x, axis), where the choice of axis // can affect the type of the function. TypeConstraint type_rel = - TypeRelation(env_type_rel_func, - ty_call_args, - arg_types.size(), - Attrs()); + TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs()); - auto func_type = - FuncType(arg_types, out_param, type_params, {type_rel}); + auto func_type = FuncType(arg_types, out_param, type_params, {type_rel}); get()->op_type = func_type; @@ -535,7 +512,7 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) return *this; } -template +template inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) get()->attrs_type_key = AttrsType::_type_key; get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); @@ -567,13 +544,11 @@ inline int GenericOpMap::count(const Op& op) const { } } -inline const runtime::TVMRetValue& -GenericOpMap::operator[](const Op& op) const { +inline const runtime::TVMRetValue& GenericOpMap::operator[](const Op& op) const { CHECK(op.defined()); const uint32_t idx = op->index_; CHECK(idx < data_.size() && data_[idx].second != 0) - << "Attribute " << attr_name_ << " has not been registered for Operator " - << op->name; + << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name; return data_[idx].first; } @@ -614,14 +589,12 @@ inline ValueType OpMap::operator[](const Op& op) const { } template -inline ValueType OpMap::get(const Op& op, - ValueType def_value) const { +inline ValueType OpMap::get(const Op& op, ValueType def_value) const { return map_.get(op, def_value); } template -inline ValueType OpMap::get(const RelayExpr& expr, - ValueType def_value) const { +inline ValueType OpMap::get(const RelayExpr& expr, ValueType def_value) const { return map_.get(expr, def_value); } diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 411b733..1ed6848 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -24,8 +24,9 @@ #ifndef TVM_IR_SPAN_H_ #define TVM_IR_SPAN_H_ -#include #include +#include + #include namespace tvm { @@ -92,10 +93,8 @@ class SpanNode : public Object { } bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { - return - equal(source, other->source) && - equal(lineno, other->lineno) && - equal(col_offset, other->col_offset); + return equal(source, other->source) && equal(lineno, other->lineno) && + equal(col_offset, other->col_offset); } TVM_DLL static Span make(SourceName source, int lineno, int col_offset); @@ -104,7 +103,6 @@ class SpanNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); }; - class Span : public ObjectRef { public: TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); diff --git a/include/tvm/ir/tensor_type.h b/include/tvm/ir/tensor_type.h index 489ea64..7a70025 100644 --- a/include/tvm/ir/tensor_type.h +++ b/include/tvm/ir/tensor_type.h @@ -24,8 +24,8 @@ #ifndef TVM_IR_TENSOR_TYPE_H_ #define TVM_IR_TENSOR_TYPE_H_ -#include #include +#include namespace tvm { /*! @@ -75,9 +75,7 @@ class TensorTypeNode : public BaseTensorTypeNode { } bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const { - return - equal(shape, other->shape) && - equal(dtype, other->dtype); + return equal(shape, other->shape) && equal(dtype, other->dtype); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 3680f6d..558d2da 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -56,11 +56,12 @@ #ifndef TVM_IR_TRANSFORM_H_ #define TVM_IR_TRANSFORM_H_ -#include -#include -#include #include #include +#include +#include +#include + #include #include @@ -74,9 +75,7 @@ class PassInfo; * */ using TraceFunc = - runtime::TypedPackedFunc; + runtime::TypedPackedFunc; /*! * \brief PassContextNode contains the information that a pass can rely on, @@ -117,7 +116,6 @@ class PassContextNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); }; - /*! * \brief PassContext that is used to configure the pass behavior. * @@ -226,9 +224,7 @@ class PassInfo : public ObjectRef { * \param name Name of the pass. * \param required The passes that are required to perform the current pass. */ - TVM_DLL PassInfo(int opt_level, - std::string name, - Array required); + TVM_DLL PassInfo(int opt_level, std::string name, Array required); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -264,8 +260,7 @@ class PassNode : public Object { * * \return The transformed module. */ - virtual IRModule operator()(IRModule mod, - const PassContext& pass_ctx) const = 0; + virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; void VisitAttrs(AttrVisitor* v) {} @@ -303,8 +298,7 @@ class Pass : public ObjectRef { * * \return The transformed module. */ - IRModule operator()(IRModule mod, - const PassContext& pass_ctx) const { + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); CHECK(node != nullptr); return node->operator()(std::move(mod), pass_ctx); @@ -352,12 +346,9 @@ class Sequential : public Pass { * * \return The created module pass. */ -TVM_DLL Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const Array& required); - +TVM_DLL Pass +CreateModulePass(const runtime::TypedPackedFunc& pass_func, + int opt_level, const std::string& name, const Array& required); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index a20dbdd..ed64841 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -49,11 +49,12 @@ #ifndef TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_ -#include -#include -#include -#include #include +#include +#include +#include +#include + #include namespace tvm { @@ -109,23 +110,18 @@ class PrimTypeNode : public TypeNode { */ runtime::DataType dtype; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); } bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const { return equal(dtype, other->dtype); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } static constexpr const char* _type_key = "PrimType"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); }; - /* * \brief Managed reference to PrimTypeNode. * \sa PrimTypeNode @@ -141,7 +137,6 @@ class PrimType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); }; - /*! * \brief Low-level raw pointer type. * @@ -159,17 +154,13 @@ class PointerTypeNode : public TypeNode { */ Type element_type; - void VisitAttrs(AttrVisitor* v) { - v->Visit("element_type", &element_type); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("element_type", &element_type); } bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const { return equal(element_type, other->element_type); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(element_type); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(element_type); } static constexpr const char* _type_key = "PointerType"; TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); @@ -190,7 +181,6 @@ class PointerType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode); }; - /*! \brief Possible kinds of TypeVars. */ enum TypeKind : int { kType = 0, @@ -238,9 +228,7 @@ class TypeVarNode : public TypeNode { } bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const { - return - equal(kind, other->kind) && - equal.FreeVarEqualImpl(this, other); + return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -290,9 +278,7 @@ class GlobalTypeVarNode : public TypeNode { bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const { // name matters for now in global type var. - return - equal(name_hint, other->name_hint) && - equal.FreeVarEqualImpl(this, other); + return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -340,9 +326,7 @@ class TupleTypeNode : public TypeNode { return equal(fields, other->fields); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(fields); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } static constexpr const char* _type_key = "TupleType"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); @@ -372,9 +356,7 @@ class TupleType : public Type { /*! * \return a type that represents void. */ -inline Type VoidType() { - return TupleType::Empty(); -} +inline Type VoidType() { return TupleType::Empty(); } /*! * \brief Check whether the tyep represents void. @@ -439,11 +421,8 @@ class FuncTypeNode : public TypeNode { bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const { // type params first as they defines type vars. - return - equal.DefEqual(type_params, other->type_params) && - equal(arg_types, other->arg_types) && - equal(ret_type, other->ret_type) && - equal(type_constraints, other->type_constraints); + return equal.DefEqual(type_params, other->type_params) && equal(arg_types, other->arg_types) && + equal(ret_type, other->ret_type) && equal(type_constraints, other->type_constraints); } void SHashReduce(SHashReducer hash_reduce) const { @@ -471,9 +450,7 @@ class FuncType : public Type { * \param type_constraints The type constraints. * \sa FuncTypeNode for more docs about these fields. */ - TVM_DLL FuncType(Array arg_types, - Type ret_type, - Array type_params, + TVM_DLL FuncType(Array arg_types, Type ret_type, Array type_params, Array type_constraints); TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); @@ -500,14 +477,10 @@ class IncompleteTypeNode : public TypeNode { } bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const { - return - equal(kind, other->kind) && - equal.FreeVarEqualImpl(this, other); + return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(kind); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(kind); } static constexpr const char* _type_key = "IncompleteType"; TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); @@ -528,7 +501,6 @@ class IncompleteType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode); }; - /*! * \brief Reference Type High-level Relay IR. * @@ -550,9 +522,7 @@ class RelayRefTypeNode : public TypeNode { return equal(value, other->value); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } // Keep the relay prefix in the type as this type is specific // to the relay itself. diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index 5507191..2a6314c 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -25,11 +25,12 @@ #define TVM_IR_TYPE_FUNCTOR_H_ #include -#include #include +#include + #include -#include #include +#include namespace tvm { @@ -37,16 +38,13 @@ template class TypeFunctor; // functions to be overriden. -#define TYPE_FUNCTOR_DEFAULT \ +#define TYPE_FUNCTOR_DEFAULT \ { return VisitTypeDefault_(op, std::forward(args)...); } - -#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitType_(static_cast(n.get()), \ - std::forward(args)...); \ - }); +#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitType_(static_cast(n.get()), std::forward(args)...); \ + }); template class TypeFunctor { @@ -65,9 +63,7 @@ class TypeFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Type& n, Args... args) { - return VisitType(n, std::forward(args)...); - } + R operator()(const Type& n, Args... args) { return VisitType(n, std::forward(args)...); } /*! * \brief The functor call. * \param n The expression node. @@ -80,8 +76,7 @@ class TypeFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitType_(const TensorTypeNode* op, - Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; @@ -126,8 +121,7 @@ class TypeFunctor { /*! * \brief A type visitor that recursively visit types. */ -class TVM_DLL TypeVisitor : - public TypeFunctor { +class TVM_DLL TypeVisitor : public TypeFunctor { public: void VisitType_(const TypeVarNode* op) override; void VisitType_(const IncompleteTypeNode* op) override; @@ -146,8 +140,7 @@ class TVM_DLL TypeVisitor : /*! * \brief TypeMutator that mutates expressions. */ -class TVM_DLL TypeMutator : - public TypeFunctor { +class TVM_DLL TypeMutator : public TypeFunctor { public: Type VisitType(const Type& t) override; Type VisitType_(const TypeVarNode* op) override; diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index 06bcb72..dbd241a 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -24,10 +24,10 @@ #ifndef TVM_IR_TYPE_RELATION_H_ #define TVM_IR_TYPE_RELATION_H_ -#include -#include -#include #include +#include +#include +#include namespace tvm { @@ -51,9 +51,7 @@ class TypeCallNode : public TypeNode { } bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(args, other->args); + return equal(func, other->func) && equal(args, other->args); } void SHashReduce(SHashReducer hash_reduce) const { @@ -105,7 +103,7 @@ class TypeReporterNode : public Object { * \return false if assertation can be proven to have failed * true if solver can still proceed. */ - TVM_DLL virtual bool Assert(const PrimExpr& cond)= 0; + TVM_DLL virtual bool Assert(const PrimExpr& cond) = 0; /*! * \brief assert shape expression equals each other. * \param lhs The left operand. @@ -141,11 +139,9 @@ class TypeReporterNode : public Object { class TypeReporter : public ObjectRef { public: TypeReporter() {} - explicit TypeReporter(ObjectPtr n) : ObjectRef(n) { - } + explicit TypeReporter(ObjectPtr n) : ObjectRef(n) {} TypeReporterNode* operator->() const { - return const_cast( - static_cast(get())); + return const_cast(static_cast(get())); } using ContainerType = TypeReporterNode; }; @@ -169,11 +165,8 @@ class TypeReporter : public ObjectRef { * \return false if This relation cannot be resolved. * true if this relation has been resolved. */ -using TypeRelationFn = - TypedEnvFunc& args, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter)>; +using TypeRelationFn = TypedEnvFunc& args, int num_inputs, + const Attrs& attrs, const TypeReporter& reporter)>; /*! * \brief User defined type relation, it is an input-output relation on types. @@ -207,11 +200,8 @@ class TypeRelationNode : public TypeConstraintNode { } bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(args, other->args) && - equal(num_inputs, other->num_inputs) && - equal(attrs, other->attrs); + return equal(func, other->func) && equal(args, other->args) && + equal(num_inputs, other->num_inputs) && equal(attrs, other->attrs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -239,10 +229,7 @@ class TypeRelation : public TypeConstraint { * \param attrs Attributes to the relation function. * \sa TypeRelationNode for more docs about these fields. */ - TVM_DLL TypeRelation(TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs); + TVM_DLL TypeRelation(TypeRelationFn func, Array args, int num_inputs, Attrs attrs); TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); }; diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index ba1edf8..2b6645f 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -23,28 +23,28 @@ #ifndef TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_ -#include +#include #include +#include #include -#include -#include -#include #include +#include +#include #include #include -#include +#include namespace tvm { -using runtime::String; -using runtime::StringObj; +using runtime::make_object; using runtime::Object; +using runtime::ObjectEqual; +using runtime::ObjectHash; using runtime::ObjectPtr; using runtime::ObjectRef; -using runtime::make_object; -using runtime::ObjectHash; -using runtime::ObjectEqual; +using runtime::String; +using runtime::StringObj; /*! \brief array node content in array */ class ArrayNode : public Object { @@ -60,10 +60,7 @@ class ArrayNode : public Object { class MapNode : public Object { public: /*! \brief The corresponding conatiner type */ - using ContainerType = std::unordered_map< - ObjectRef, - ObjectRef, - ObjectHash, ObjectEqual>; + using ContainerType = std::unordered_map; /*! \brief the data content */ ContainerType data; @@ -72,7 +69,6 @@ class MapNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); }; - /*! \brief specialized map node with string as key */ class StrMapNode : public Object { public: @@ -91,14 +87,13 @@ class StrMapNode : public Object { * \tparam Converter a struct that contains converting function * \tparam TIter the content iterator type. */ -template +template class IterAdapter { public: using difference_type = typename std::iterator_traits::difference_type; using value_type = typename Converter::ResultType; using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) + using reference = typename Converter::ResultType&; // NOLINT(*) using iterator_category = typename std::iterator_traits::iterator_category; explicit IterAdapter(TIter iter) : iter_(iter) {} @@ -106,26 +101,18 @@ class IterAdapter { ++iter_; return *this; } - inline IterAdapter operator+(difference_type offset) const { - return IterAdapter(iter_ + offset); - } + inline IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } - template + template typename std::enable_if::value, - typename T::difference_type>::type - inline operator-(const IterAdapter& rhs) const { + typename T::difference_type>::type inline + operator-(const IterAdapter& rhs) const { return iter_ - rhs.iter_; } - inline bool operator==(IterAdapter other) const { - return iter_ == other.iter_; - } - inline bool operator!=(IterAdapter other) const { - return !(*this == other); - } - inline const value_type operator*() const { - return Converter::convert(*iter_); - } + inline bool operator==(IterAdapter other) const { return iter_ == other.iter_; } + inline bool operator!=(IterAdapter other) const { return !(*this == other); } + inline const value_type operator*() const { return Converter::convert(*iter_); } private: TIter iter_; @@ -139,28 +126,26 @@ class IterAdapter { * operator[] only provide const acces, use Set to mutate the content. * \tparam T The content NodeRef type. */ -template::value>::type > +template ::value>::type> class Array : public ObjectRef { public: /*! * \brief default constructor */ - Array() { - data_ = make_object(); - } + Array() { data_ = make_object(); } /*! * \brief move constructor * \param other source */ - Array(Array && other) : ObjectRef() { // NOLINT(*) + Array(Array&& other) : ObjectRef() { // NOLINT(*) data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Array(const Array &other) : ObjectRef() { // NOLINT(*) + Array(const Array& other) : ObjectRef() { // NOLINT(*) data_ = std::move(other.data_); } /*! @@ -174,7 +159,7 @@ class Array : public ObjectRef { * \param end end of iterator * \tparam IterType The type of iterator */ - template + template Array(IterType begin, IterType end) { assign(begin, end); } @@ -182,14 +167,14 @@ class Array : public ObjectRef { * \brief constructor from initializer list * \param init The initalizer list */ - Array(std::initializer_list init) { // NOLINT(*) + Array(std::initializer_list init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! * \brief constructor from vector * \param init The vector */ - Array(const std::vector& init) { // NOLINT(*) + Array(const std::vector& init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! @@ -209,7 +194,7 @@ class Array : public ObjectRef { * \param other The source of assignment * \return reference to self. */ - Array& operator=(Array && other) { + Array& operator=(Array&& other) { data_ = std::move(other.data_); return *this; } @@ -218,7 +203,7 @@ class Array : public ObjectRef { * \param other The source of assignment * \return reference to self. */ - Array& operator=(const Array & other) { + Array& operator=(const Array& other) { data_ = other.data_; return *this; } @@ -228,7 +213,7 @@ class Array : public ObjectRef { * \param end end of iterator * \tparam IterType The type of iterator */ - template + template void assign(IterType begin, IterType end) { auto n = make_object(); for (IterType it = begin; it != end; ++it) { @@ -242,8 +227,7 @@ class Array : public ObjectRef { * \return the i-th element. */ inline const T operator[](size_t i) const { - return DowncastNoCheck( - static_cast(data_.get())->data[i]); + return DowncastNoCheck(static_cast(data_.get())->data[i]); } /*! \return The size of the array */ inline size_t size() const { @@ -259,7 +243,7 @@ class Array : public ObjectRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline ArrayNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); @@ -292,16 +276,14 @@ class Array : public ObjectRef { n->data[i] = value; } /*! \return whether array is empty */ - inline bool empty() const { - return size() == 0; - } + inline bool empty() const { return size() == 0; } /*! * \brief Helper function to apply fmutate to mutate an array. * \param fmutate The transformation function T -> T. * \tparam F the type of the mutation function. * \note This function performs copy on write optimization. */ - template + template inline void MutateByApply(F fmutate) { ArrayNode* ptr = static_cast(data_.get()); if (ptr == nullptr) return; @@ -342,16 +324,12 @@ class Array : public ObjectRef { struct ValueConverter { using ResultType = T; - static inline T convert(const ObjectRef& n) { - return DowncastNoCheck(n); - } + static inline T convert(const ObjectRef& n) { return DowncastNoCheck(n); } }; - using iterator = IterAdapter::const_iterator>; + using iterator = IterAdapter::const_iterator>; - using reverse_iterator = IterAdapter< - ValueConverter, - std::vector::const_reverse_iterator>; + using reverse_iterator = + IterAdapter::const_reverse_iterator>; /*! \return begin iterator */ inline iterator begin() const { @@ -380,32 +358,28 @@ class Array : public ObjectRef { * \tparam K The key NodeRef type. * \tparam V The value NodeRef type. */ -template::value || - std::is_base_of::value >::type, - typename = typename std::enable_if::value>::type> +template ::value || + std::is_base_of::value>::type, + typename = typename std::enable_if::value>::type> class Map : public ObjectRef { public: /*! * \brief default constructor */ - Map() { - data_ = make_object(); - } + Map() { data_ = make_object(); } /*! * \brief move constructor * \param other source */ - Map(Map && other) { // NOLINT(*) + Map(Map&& other) { // NOLINT(*) data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) + Map(const Map& other) : ObjectRef(other.data_) { // NOLINT(*) } /*! * \brief constructor from pointer @@ -418,7 +392,7 @@ class Map : public ObjectRef { * \param end end of iterator * \tparam IterType The type of iterator */ - template + template Map(IterType begin, IterType end) { assign(begin, end); } @@ -426,15 +400,15 @@ class Map : public ObjectRef { * \brief constructor from initializer list * \param init The initalizer list */ - Map(std::initializer_list > init) { // NOLINT(*) + Map(std::initializer_list > init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! * \brief constructor from vector * \param init The vector */ - template - Map(const std::unordered_map& init) { // NOLINT(*) + template + Map(const std::unordered_map& init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! @@ -442,7 +416,7 @@ class Map : public ObjectRef { * \param other The source of assignment * \return reference to self. */ - Map& operator=(Map && other) { + Map& operator=(Map&& other) { data_ = std::move(other.data_); return *this; } @@ -451,7 +425,7 @@ class Map : public ObjectRef { * \param other The source of assignment * \return reference to self. */ - Map& operator=(const Map & other) { + Map& operator=(const Map& other) { data_ = other.data_; return *this; } @@ -461,7 +435,7 @@ class Map : public ObjectRef { * \param end end of iterator * \tparam IterType The type of iterator */ - template + template void assign(IterType begin, IterType end) { ObjectPtr n = make_object(); for (IterType i = begin; i != end; ++i) { @@ -475,8 +449,7 @@ class Map : public ObjectRef { * \return the corresonding element. */ inline const V operator[](const K& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } /*! * \brief Read element from map. @@ -484,8 +457,7 @@ class Map : public ObjectRef { * \return the corresonding element. */ inline const V at(const K& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } /*! \return The size of the array */ inline size_t size() const { @@ -506,7 +478,7 @@ class Map : public ObjectRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline MapNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); @@ -524,24 +496,18 @@ class Map : public ObjectRef { } /*! \return whether array is empty */ - inline bool empty() const { - return size() == 0; - } + inline bool empty() const { return size() == 0; } /*! \brief specify container node */ using ContainerType = MapNode; struct ValueConverter { using ResultType = std::pair; - static inline ResultType convert(const std::pair< - ObjectRef, - ObjectRef>& n) { - return std::make_pair(DowncastNoCheck(n.first), - DowncastNoCheck(n.second)); + static inline ResultType convert(const std::pair& n) { + return std::make_pair(DowncastNoCheck(n.first), DowncastNoCheck(n.second)); } }; - using iterator = IterAdapter< - ValueConverter, MapNode::ContainerType::const_iterator>; + using iterator = IterAdapter; /*! \return begin iterator */ inline iterator begin() const { @@ -553,46 +519,43 @@ class Map : public ObjectRef { } /*! \return begin iterator */ inline iterator find(const K& key) const { - return iterator( - static_cast(data_.get())->data.find(key)); + return iterator(static_cast(data_.get())->data.find(key)); } }; // specialize of string map -template +template class Map : public ObjectRef { public: // for code reuse - Map() { - data_ = make_object(); - } - Map(Map && other) { // NOLINT(*) + Map() { data_ = make_object(); } + Map(Map&& other) { // NOLINT(*) data_ = std::move(other.data_); } - Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) + Map(const Map& other) : ObjectRef(other.data_) { // NOLINT(*) } explicit Map(ObjectPtr n) : ObjectRef(n) {} - template + template Map(IterType begin, IterType end) { assign(begin, end); } - Map(std::initializer_list > init) { // NOLINT(*) + Map(std::initializer_list > init) { // NOLINT(*) assign(init.begin(), init.end()); } - template - Map(const std::unordered_map& init) { // NOLINT(*) + template + Map(const std::unordered_map& init) { // NOLINT(*) assign(init.begin(), init.end()); } - Map& operator=(Map && other) { + Map& operator=(Map&& other) { data_ = std::move(other.data_); return *this; } - Map& operator=(const Map & other) { + Map& operator=(const Map& other) { data_ = other.data_; return *this; } - template + template void assign(IterType begin, IterType end) { auto n = make_object(); for (IterType i = begin; i != end; ++i) { @@ -601,12 +564,10 @@ class Map : public ObjectRef { data_ = std::move(n); } inline const V operator[](const std::string& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } inline const V at(const std::string& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } inline size_t size() const { if (data_.get() == nullptr) return 0; @@ -617,7 +578,7 @@ class Map : public ObjectRef { return static_cast(data_.get())->data.count(key); } inline StrMapNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); @@ -628,22 +589,17 @@ class Map : public ObjectRef { StrMapNode* n = this->CopyOnWrite(); n->data[key] = value; } - inline bool empty() const { - return size() == 0; - } + inline bool empty() const { return size() == 0; } using ContainerType = StrMapNode; struct ValueConverter { using ResultType = std::pair; - static inline ResultType convert(const std::pair< - std::string, - ObjectRef>& n) { + static inline ResultType convert(const std::pair& n) { return std::make_pair(n.first, DowncastNoCheck(n.second)); } }; - using iterator = IterAdapter< - ValueConverter, StrMapNode::ContainerType::const_iterator>; + using iterator = IterAdapter; /*! \return begin iterator */ inline iterator begin() const { @@ -663,7 +619,7 @@ class Map : public ObjectRef { namespace tvm { namespace runtime { // Additional overloads for PackedFunc checking. -template +template struct ObjectTypeChecker > { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; @@ -676,12 +632,10 @@ struct ObjectTypeChecker > { } return true; } - static std::string TypeName() { - return "List[" + ObjectTypeChecker::TypeName() + "]"; - } + static std::string TypeName() { return "List[" + ObjectTypeChecker::TypeName() + "]"; } }; -template +template struct ObjectTypeChecker > { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; @@ -692,13 +646,10 @@ struct ObjectTypeChecker > { } return true; } - static std::string TypeName() { - return "Map[str, " + - ObjectTypeChecker::TypeName()+ ']'; - } + static std::string TypeName() { return "Map[str, " + ObjectTypeChecker::TypeName() + ']'; } }; -template +template struct ObjectTypeChecker > { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; @@ -711,10 +662,8 @@ struct ObjectTypeChecker > { return true; } static std::string TypeName() { - return "Map[" + - ObjectTypeChecker::TypeName() + - ", " + - ObjectTypeChecker::TypeName()+ ']'; + return "Map[" + ObjectTypeChecker::TypeName() + ", " + ObjectTypeChecker::TypeName() + + ']'; } }; } // namespace runtime diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h index e11fda8..0837f35 100644 --- a/include/tvm/node/functor.h +++ b/include/tvm/node/functor.h @@ -26,9 +26,9 @@ #include #include -#include #include #include +#include namespace tvm { @@ -60,16 +60,16 @@ using runtime::ObjectRef; * \tparam FType function signiture * This type if only defined for FType with function signature */ -template +template class NodeFunctor; -template +template class NodeFunctor { private: /*! \brief internal function pointer type */ - typedef R (*FPointer)(const ObjectRef&n, Args...); + typedef R (*FPointer)(const ObjectRef& n, Args...); /*! \brief refer to itself. */ - using TSelf = NodeFunctor; + using TSelf = NodeFunctor; /*! \brief internal function table */ std::vector func_; @@ -92,9 +92,8 @@ class NodeFunctor { * \return The result. */ R operator()(const ObjectRef& n, Args... args) const { - CHECK(can_dispatch(n)) - << "NodeFunctor calls un-registered function on type " - << n->GetTypeKey(); + CHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type " + << n->GetTypeKey(); return (*func_[n->type_index()])(n, std::forward(args)...); } /*! @@ -103,37 +102,32 @@ class NodeFunctor { * \tparam TNode the type of Node to be dispatched. * \return reference to self. */ - template + template TSelf& set_dispatch(FPointer f) { // NOLINT(*) uint32_t tindex = TNode::RuntimeTypeIndex(); if (func_.size() <= tindex) { func_.resize(tindex + 1, nullptr); } - CHECK(func_[tindex] == nullptr) - << "Dispatch for " << TNode::_type_key - << " is already set"; + CHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set"; func_[tindex] = f; return *this; } /*! - * \brief unset the dispacher for type TNode - * - * \tparam TNode the type of Node to be dispatched. - * \return reference to self. - */ - template + * \brief unset the dispacher for type TNode + * + * \tparam TNode the type of Node to be dispatched. + * \return reference to self. + */ + template TSelf& clear_dispatch() { // NOLINT(*) uint32_t tindex = TNode::RuntimeTypeIndex(); - CHECK_LT(tindex, func_.size()) - << "clear_dispatch: index out of range"; + CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; func_[tindex] = nullptr; return *this; } }; - -#define TVM_REG_FUNC_VAR_DEF(ClsName) \ - static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName +#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName /*! * \brief Useful macro to set NodeFunctor dispatch in a global static field. @@ -176,8 +170,7 @@ class NodeFunctor { * \param ClsName The name of the class * \param FField The static function that returns a singleton of NodeFunctor. */ -#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ - TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = \ - ClsName::FField() +#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ + TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField() } // namespace tvm #endif // TVM_NODE_FUNCTOR_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 471a0de..b622fc7 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -34,35 +34,35 @@ #ifndef TVM_NODE_NODE_H_ #define TVM_NODE_NODE_H_ -#include -#include -#include -#include +#include #include #include -#include #include #include +#include +#include +#include +#include #include -#include -#include #include +#include +#include namespace tvm { -using runtime::TypeIndex; +using runtime::Downcast; +using runtime::GetRef; +using runtime::make_object; using runtime::Object; +using runtime::ObjectEqual; +using runtime::ObjectHash; using runtime::ObjectPtr; using runtime::ObjectRef; -using runtime::GetRef; -using runtime::Downcast; -using runtime::ObjectHash; -using runtime::ObjectEqual; -using runtime::make_object; using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +using runtime::TypeIndex; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 9ed87df..643b638 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -23,18 +23,18 @@ #ifndef TVM_NODE_REFLECTION_H_ #define TVM_NODE_REFLECTION_H_ +#include +#include #include -#include +#include #include -#include #include -#include -#include -#include +#include +#include -#include #include #include +#include namespace tvm { @@ -51,7 +51,7 @@ using runtime::ObjectRef; */ class AttrVisitor { public: -//! \cond Doxygen_Suppress + //! \cond Doxygen_Suppress TVM_DLL virtual ~AttrVisitor() = default; TVM_DLL virtual void Visit(const char* key, double* value) = 0; TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0; @@ -63,14 +63,13 @@ class AttrVisitor { TVM_DLL virtual void Visit(const char* key, DataType* value) = 0; TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0; TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0; - template::value>::type> + template ::value>::type> void Visit(const char* key, ENum* ptr) { static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); this->Visit(key, reinterpret_cast(ptr)); } -//! \endcond + //! \endcond }; /*! @@ -166,7 +165,7 @@ class ReflectionVTable { TVM_DLL static ReflectionVTable* Global(); class Registry; - template + template inline Registry Register(); private: @@ -174,7 +173,7 @@ class ReflectionVTable { std::vector fvisit_attrs_; /*! \brief Structural equal function. */ std::vector fsequal_reduce_; - /*! \brief Structural hash function. */ + /*! \brief Structural hash function. */ std::vector fshash_reduce_; /*! \brief Creation function. */ std::vector fcreate_; @@ -186,7 +185,7 @@ class ReflectionVTable { class ReflectionVTable::Registry { public: Registry(ReflectionVTable* parent, uint32_t type_index) - : parent_(parent), type_index_(type_index) { } + : parent_(parent), type_index_(type_index) {} /*! * \brief Set fcreate function. * \param f The creator function. @@ -213,10 +212,8 @@ class ReflectionVTable::Registry { uint32_t type_index_; }; - -#define TVM_REFLECTION_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry \ - __make_reflectiion +#define TVM_REFLECTION_REG_VAR_DEF \ + static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflectiion /*! * \brief Directly register reflection VTable. @@ -247,122 +244,108 @@ class ReflectionVTable::Registry { * \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE. * And can be used to register the related reflection functions for runtime objects. */ -#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \ - TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::ReflectionVTable::Global()->Register() \ +#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \ + TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ + ::tvm::ReflectionVTable::Global()->Register() /*! * \brief Register a node type to object registry and reflection registry. * \param TypeName The name of the type. * \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well. */ -#define TVM_REGISTER_NODE_TYPE(TypeName) \ - TVM_REGISTER_OBJECT_TYPE(TypeName); \ +#define TVM_REGISTER_NODE_TYPE(TypeName) \ + TVM_REGISTER_OBJECT_TYPE(TypeName); \ TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait) \ - .set_creator([](const std::string&) -> ObjectPtr { \ - return ::tvm::runtime::make_object(); \ - }) - + .set_creator([](const std::string&) -> ObjectPtr { \ + return ::tvm::runtime::make_object(); \ + }) // Implementation details namespace detail { -template +template struct ImplVisitAttrs { static constexpr const std::nullptr_t VisitAttrs = nullptr; }; -template +template struct ImplVisitAttrs { - static void VisitAttrs(T* self, AttrVisitor* v) { - self->VisitAttrs(v); - } + static void VisitAttrs(T* self, AttrVisitor* v) { self->VisitAttrs(v); } }; -template +template struct ImplSEqualReduce { static constexpr const std::nullptr_t SEqualReduce = nullptr; }; -template +template struct ImplSEqualReduce { static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) { return self->SEqualReduce(other, equal); } }; -template +template struct ImplSHashReduce { static constexpr const std::nullptr_t SHashReduce = nullptr; }; -template +template struct ImplSHashReduce { static void SHashReduce(const T* self, SHashReducer hash_reduce) { self->SHashReduce(hash_reduce); } }; -template -struct ReflectionTrait : - public ImplVisitAttrs, - public ImplSEqualReduce, - public ImplSHashReduce { -}; +template +struct ReflectionTrait : public ImplVisitAttrs, + public ImplSEqualReduce, + public ImplSHashReduce {}; -template::value> +template ::value> struct SelectVisitAttrs { static constexpr const std::nullptr_t VisitAttrs = nullptr; }; -template +template struct SelectVisitAttrs { static void VisitAttrs(Object* self, AttrVisitor* v) { TraitName::VisitAttrs(static_cast(self), v); } }; -template::value> +template ::value> struct SelectSEqualReduce { static constexpr const std::nullptr_t SEqualReduce = nullptr; }; -template +template struct SelectSEqualReduce { - static bool SEqualReduce(const Object* self, - const Object* other, - SEqualReducer equal) { - return TraitName::SEqualReduce(static_cast(self), - static_cast(other), + static bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) { + return TraitName::SEqualReduce(static_cast(self), static_cast(other), equal); } }; -template::value> +template ::value> struct SelectSHashReduce { static constexpr const std::nullptr_t SHashReduce = nullptr; }; -template +template struct SelectSHashReduce { - static void SHashReduce(const Object* self, - SHashReducer hash_reduce) { - return TraitName::SHashReduce(static_cast(self), - hash_reduce); + static void SHashReduce(const Object* self, SHashReducer hash_reduce) { + return TraitName::SHashReduce(static_cast(self), hash_reduce); } }; } // namespace detail -template -inline ReflectionVTable::Registry -ReflectionVTable::Register() { +template +inline ReflectionVTable::Registry ReflectionVTable::Register() { uint32_t tindex = T::RuntimeTypeIndex(); if (tindex >= fvisit_attrs_.size()) { fvisit_attrs_.resize(tindex + 1, nullptr); @@ -372,20 +355,16 @@ ReflectionVTable::Register() { fshash_reduce_.resize(tindex + 1, nullptr); } // functor that implemnts the redirection. - fvisit_attrs_[tindex] = - ::tvm::detail::SelectVisitAttrs::VisitAttrs; + fvisit_attrs_[tindex] = ::tvm::detail::SelectVisitAttrs::VisitAttrs; - fsequal_reduce_[tindex] = - ::tvm::detail::SelectSEqualReduce::SEqualReduce; + fsequal_reduce_[tindex] = ::tvm::detail::SelectSEqualReduce::SEqualReduce; - fshash_reduce_[tindex] = - ::tvm::detail::SelectSHashReduce::SHashReduce; + fshash_reduce_[tindex] = ::tvm::detail::SelectSHashReduce::SHashReduce; return Registry(this, tindex); } -inline void ReflectionVTable:: -VisitAttrs(Object* self, AttrVisitor* visitor) const { +inline void ReflectionVTable::VisitAttrs(Object* self, AttrVisitor* visitor) const { uint32_t tindex = self->type_index(); if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) { LOG(FATAL) << "TypeError: " << self->GetTypeKey() @@ -394,8 +373,7 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const { fvisit_attrs_[tindex](self, visitor); } -inline bool ReflectionVTable::GetReprBytes(const Object* self, - std::string* repr_bytes) const { +inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr_bytes) const { uint32_t tindex = self->type_index(); if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) { if (repr_bytes != nullptr) { diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 5782430..532425a 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -24,6 +24,7 @@ #define TVM_NODE_REPR_PRINTER_H_ #include + #include namespace tvm { diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index f719e24..9424f6d 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -23,9 +23,10 @@ #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_ #define TVM_NODE_STRUCTURAL_EQUAL_H_ -#include -#include #include +#include +#include + #include namespace tvm { @@ -43,26 +44,13 @@ class BaseValueEqual { return diff > -atol && diff < atol; } - bool operator()(const int64_t& lhs, const int64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const int& lhs, const int& rhs) const { - return lhs == rhs; - } - bool operator()(const bool& lhs, const bool& rhs) const { - return lhs == rhs; - } - bool operator()(const std::string& lhs, const std::string& rhs) const { - return lhs == rhs; - } - bool operator()(const DataType& lhs, const DataType& rhs) const { - return lhs == rhs; - } - template::value>::type> + bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } + bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; } + bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; } + bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; } + bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; } + bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; } + template ::value>::type> bool operator()(const ENum& lhs, const ENum& rhs) const { return lhs == rhs; } @@ -127,9 +115,7 @@ class SEqualReducer : public BaseValueEqual { * \note This function may save the equality condition of (lhs == rhs) in an internal * stack and try to resolve later. */ - virtual bool SEqualReduce(const ObjectRef& lhs, - const ObjectRef& rhs, - bool map_free_vars) = 0; + virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0; /*! * \brief Lookup the graph node equal map for vars that are already mapped. * @@ -185,7 +171,7 @@ class SEqualReducer : public BaseValueEqual { * \param rhs The right operand. * \return the immediate check result. */ - template + template bool operator()(const Array& lhs, const Array& rhs) const { // quick specialization for Array to reduce amount of recursion // depth as array comparison is pretty common. @@ -210,9 +196,7 @@ class SEqualReducer : public BaseValueEqual { } /*! \return Get the internal handler. */ - Handler* operator->() const { - return handler_; - } + Handler* operator->() const { return handler_; } private: /*! \brief Internal class pointer. */ diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index affc5f4..ed89d84 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -23,11 +23,12 @@ #ifndef TVM_NODE_STRUCTURAL_HASH_H_ #define TVM_NODE_STRUCTURAL_HASH_H_ -#include -#include #include -#include +#include +#include + #include +#include namespace tvm { @@ -36,39 +37,25 @@ namespace tvm { */ class BaseValueHash { public: - size_t operator()(const double& key) const { - return std::hash()(key); - } + size_t operator()(const double& key) const { return std::hash()(key); } - size_t operator()(const int64_t& key) const { - return std::hash()(key); - } + size_t operator()(const int64_t& key) const { return std::hash()(key); } - size_t operator()(const uint64_t& key) const { - return std::hash()(key); - } + size_t operator()(const uint64_t& key) const { return std::hash()(key); } - size_t operator()(const int& key) const { - return std::hash()(key); - } + size_t operator()(const int& key) const { return std::hash()(key); } - size_t operator()(const bool& key) const { - return std::hash()(key); - } + size_t operator()(const bool& key) const { return std::hash()(key); } - size_t operator()(const std::string& key) const { - return std::hash()(key); - } + size_t operator()(const std::string& key) const { return std::hash()(key); } size_t operator()(const runtime::DataType& key) const { - return std::hash()( - static_cast(key.code()) | - (static_cast(key.bits()) << 8) | - (static_cast(key.lanes()) << 16)); + return std::hash()(static_cast(key.code()) | + (static_cast(key.bits()) << 8) | + (static_cast(key.lanes()) << 16)); } - template::value>::type> + template ::value>::type> bool operator()(const ENum& key) const { return std::hash()(static_cast(key)); } @@ -173,9 +160,8 @@ class SHashReducer { * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. */ - template::value>::type> + template ::value>::type> void operator()(const T& key) const { // handle normal values. handler_->SHashReduceHashedValue(BaseValueHash()(key)); @@ -184,17 +170,13 @@ class SHashReducer { * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. */ - void operator()(const ObjectRef& key) const { - return handler_->SHashReduce(key, map_free_vars_); - } + void operator()(const ObjectRef& key) const { return handler_->SHashReduce(key, map_free_vars_); } /*! * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. * \note This function indicate key could contain var defintions. */ - void DefHash(const ObjectRef& key) const { - return handler_->SHashReduce(key, true); - } + void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); } /*! * \brief Implementation for hash for a free var. * \param var The variable. @@ -205,9 +187,7 @@ class SHashReducer { } /*! \return Get the internal handler. */ - Handler* operator->() const { - return handler_; - } + Handler* operator->() const { return handler_; } private: /*! \brief Internal class pointer. */ diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 1ee7c9c..b2164ba 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -24,13 +24,14 @@ #ifndef TVM_RELAY_ADT_H_ #define TVM_RELAY_ADT_H_ -#include #include +#include #include #include #include -#include + #include +#include #include namespace tvm { @@ -72,16 +73,11 @@ class PatternWildcard; /*! \brief PatternWildcard container node */ class PatternWildcardNode : public PatternNode { public: - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("span", &span); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } - bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { - return true; - } + bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; } - void SHashReduce(SHashReducer hash_reduce) const { - } + void SHashReduce(SHashReducer hash_reduce) const {} static constexpr const char* _type_key = "relay.PatternWildcard"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode); @@ -131,9 +127,7 @@ class PatternVarNode : public PatternNode { return equal.DefEqual(var, other->var); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(var); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); } static constexpr const char* _type_key = "relay.PatternVar"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode); @@ -167,9 +161,7 @@ class PatternConstructorNode : public PatternNode { } bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const { - return - equal(constructor, other->constructor) && - equal(patterns, other->patterns); + return equal(constructor, other->constructor) && equal(patterns, other->patterns); } void SHashReduce(SHashReducer hash_reduce) const { @@ -210,9 +202,7 @@ class PatternTupleNode : public PatternNode { return equal(patterns, other->patterns); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(patterns); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); } static constexpr const char* _type_key = "relay.PatternTuple"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode); @@ -297,10 +287,8 @@ class MatchNode : public ExprNode { bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal(data, other->data) && - equal(clauses, other->clauses) && - equal(complete, other->complete); + return equal(data, other->data) && equal(clauses, other->clauses) && + equal(complete, other->complete); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index a2c0c75..b4b1b9d 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -24,11 +24,12 @@ #ifndef TVM_RELAY_ANALYSIS_H_ #define TVM_RELAY_ANALYSIS_H_ +#include #include #include #include -#include #include + #include #include @@ -73,9 +74,9 @@ TVM_DLL bool ConstantCheck(const Expr& e); * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, * although x is not shadowed. * - * \param expr the expression to check. + * \param expr the expression to check. * - * \return true iff all Var in expr is bound at most once. + * \return true iff all Var in expr is bound at most once. */ TVM_DLL bool WellFormed(const Expr& expr); @@ -233,8 +234,7 @@ TVM_DLL Array UnmatchedCases(const Match& match, const IRModule& mod); * * \return The reference count mapping. */ -TVM_DLL std::unordered_map -GetExprRefCount(const Expr& body); +TVM_DLL std::unordered_map GetExprRefCount(const Expr& body); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 2d1b902..a7d4708 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -38,14 +39,15 @@ struct ArgsortAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Axis along which to sort the input tensor." - "If not given, the flattened array is used."); - TVM_ATTR_FIELD(is_ascend).set_default(true) - .describe("Whether to sort in ascending or descending order." - "By default, sort in ascending order"); - TVM_ATTR_FIELD(dtype).set_default(NullValue()) - .describe("DType of the output indices."); + TVM_ATTR_FIELD(axis).set_default(-1).describe( + "Axis along which to sort the input tensor." + "If not given, the flattened array is used."); + TVM_ATTR_FIELD(is_ascend).set_default(true).describe( + "Whether to sort in ascending or descending order." + "By default, sort in ascending order"); + TVM_ATTR_FIELD(dtype) + .set_default(NullValue()) + .describe("DType of the output indices."); } }; @@ -57,20 +59,19 @@ struct TopKAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") { - TVM_ATTR_FIELD(k).set_default(1) - .describe("Number of top elements to select"); - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Axis along which to sort the input tensor."); - TVM_ATTR_FIELD(ret_type).set_default("both") - .describe("The return type [both, values, indices]." - "both - return both top k data and indices." - "values - return top k data only." - "indices - return top k indices only."); - TVM_ATTR_FIELD(is_ascend).set_default(false) - .describe("Whether to sort in ascending or descending order." - "By default, sort in descending order"); - TVM_ATTR_FIELD(dtype).set_default(NullValue()) - .describe("Data type of the output indices."); + TVM_ATTR_FIELD(k).set_default(1).describe("Number of top elements to select"); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor."); + TVM_ATTR_FIELD(ret_type).set_default("both").describe( + "The return type [both, values, indices]." + "both - return both top k data and indices." + "values - return top k data only." + "indices - return top k indices only."); + TVM_ATTR_FIELD(is_ascend).set_default(false).describe( + "Whether to sort in ascending or descending order." + "By default, sort in descending order"); + TVM_ATTR_FIELD(dtype) + .set_default(NullValue()) + .describe("Data type of the output indices."); } }; diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index cc21e34..4a2eb63 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_ANNOTATION_H_ #include + #include namespace tvm { @@ -38,9 +39,8 @@ struct OnDeviceAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { TVM_ATTR_FIELD(device_type) - .describe( - "The virutal device/context type that an expression is annotated with.") - .set_default(0); + .describe("The virutal device/context type that an expression is annotated with.") + .set_default(0); } }; @@ -51,9 +51,7 @@ struct CastHintAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") { - TVM_ATTR_FIELD(dtype) - .describe( - "The data type denoted to be cast."); + TVM_ATTR_FIELD(dtype).describe("The data type denoted to be cast."); } }; @@ -65,8 +63,7 @@ struct CompilerAttrs : public tvm::AttrsNode { std::string compiler; TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") { - TVM_ATTR_FIELD(compiler) - .describe("A 3rd party compiler used for code generation."); + TVM_ATTR_FIELD(compiler).describe("A 3rd party compiler used for code generation."); } }; diff --git a/include/tvm/relay/attrs/bitserial.h b/include/tvm/relay/attrs/bitserial.h index 962afc2..ed04c59 100644 --- a/include/tvm/relay/attrs/bitserial.h +++ b/include/tvm/relay/attrs/bitserial.h @@ -27,6 +27,7 @@ #include #include + #include namespace tvm { @@ -112,23 +113,18 @@ struct BinaryDenseAttrs : public tvm::AttrsNode { bool unipolar; TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") { - TVM_ATTR_FIELD(units) - .describe("Number of hidden units of the dense transformation."); - TVM_ATTR_FIELD(data_bits) - .set_default(1) - .describe("Number of bits to pack for incoming tensor."); + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); + TVM_ATTR_FIELD(data_bits).set_default(1).describe( + "Number of bits to pack for incoming tensor."); TVM_ATTR_FIELD(weight_bits) - .set_default(1) - .describe("Number of bits to pack for weight tensor."); + .set_default(1) + .describe("Number of bits to pack for weight tensor."); TVM_ATTR_FIELD(pack_dtype) - .set_default(NullValue()) - .describe("Datatype to pack bits into before computation."); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type."); - TVM_ATTR_FIELD(unipolar) - .set_default(true) - .describe("Whether to use unipolar or bipolar quantization for inputs."); + .set_default(NullValue()) + .describe("Datatype to pack bits into before computation."); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); + TVM_ATTR_FIELD(unipolar).set_default(true).describe( + "Whether to use unipolar or bipolar quantization for inputs."); } }; diff --git a/include/tvm/relay/attrs/debug.h b/include/tvm/relay/attrs/debug.h index ed9ed4e..112228b 100644 --- a/include/tvm/relay/attrs/debug.h +++ b/include/tvm/relay/attrs/debug.h @@ -25,6 +25,8 @@ #define TVM_RELAY_ATTRS_DEBUG_H_ #include +#include + #include namespace tvm { @@ -37,8 +39,7 @@ struct DebugAttrs : public tvm::AttrsNode { EnvFunc debug_func; TVM_DECLARE_ATTRS(DebugAttrs, "relay.attrs.DebugAttrs") { - TVM_ATTR_FIELD(debug_func) - .describe("The function to use when debugging."); + TVM_ATTR_FIELD(debug_func).describe("The function to use when debugging."); } }; diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h index 2486fcd..7da92b3 100644 --- a/include/tvm/relay/attrs/device_copy.h +++ b/include/tvm/relay/attrs/device_copy.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_DEVICE_COPY_H_ #include + #include namespace tvm { @@ -39,13 +40,11 @@ struct DeviceCopyAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") { TVM_ATTR_FIELD(src_dev_type) - .describe( - "The virtual device/context type where the op copies data from.") - .set_default(0); + .describe("The virtual device/context type where the op copies data from.") + .set_default(0); TVM_ATTR_FIELD(dst_dev_type) - .describe( - "The virtual device/context type where the op copies data to.") - .set_default(0); + .describe("The virtual device/context type where the op copies data to.") + .set_default(0); } }; diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index 52bb2ef..b927c98 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -40,26 +41,27 @@ struct ResizeAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { - TVM_ATTR_FIELD(size).set_default(NullValue >()) - .describe("Output Size."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Resize is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("bilinear") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation" - "bicubic - Bicubic Interpolation"); - TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") - .describe("Describes how to transform the coordinate in the resized tensor" - "to the coordinate in the original tensor." - "Refer to the ONNX Resize operator specification for details" - "Available options are half_pixel, align_corners and asymmetric"); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type."); + TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("bilinear") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation" + "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; @@ -72,22 +74,22 @@ struct CropAndResizeAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(CropAndResizeAttrs, "relay.attrs.CropAndResizeAttrs") { - TVM_ATTR_FIELD(crop_size).set_default(NullValue >()) - .describe("Target Size."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Resize is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("bilinear") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation"); - TVM_ATTR_FIELD(extrapolation_value).set_default(0.0) + TVM_ATTR_FIELD(crop_size).set_default(NullValue >()).describe("Target Size."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("bilinear") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation"); + TVM_ATTR_FIELD(extrapolation_value) + .set_default(0.0) .describe("Specify value for extrapolation."); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type."); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; @@ -101,25 +103,33 @@ struct Dilation2DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Dilation2DAttrs, "relay.attrs.Dilation2DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the sliding window. [stride_height, stride_width]."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilations).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilations) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use. [dilation_height, dilation_width]"); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("IHW") - .describe("Dimension ordering of weight. Can be 'IHW', 'HWI', etc." - "'I', 'H', 'W' stands for input_channel, height, and width" - "dimensions respectively."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("IHW") + .describe( + "Dimension ordering of weight. Can be 'IHW', 'HWI', etc." + "'I', 'H', 'W' stands for input_channel, height, and width" + "dimensions respectively."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h index d232f86..7429c39 100644 --- a/include/tvm/relay/attrs/memory.h +++ b/include/tvm/relay/attrs/memory.h @@ -26,6 +26,7 @@ #include #include + #include #include @@ -46,15 +47,10 @@ struct AllocStorageAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") { TVM_ATTR_FIELD(dtype) - .describe( - "The dtype of the tensor to allocate.") - .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(device_id) - .describe( - "The device id on which to allocate memory."); - TVM_ATTR_FIELD(device_type) - .describe( - "The device type on which to allocate memory."); + .describe("The dtype of the tensor to allocate.") + .set_default(DataType::Float(32, 1)); + TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory."); + TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory."); } }; @@ -68,16 +64,13 @@ struct AllocTensorAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") { TVM_ATTR_FIELD(dtype) - .describe( - "The dtype of the tensor to allocate.") - .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(const_shape) - .describe( - "The shape of constant used to aid in type inference."); + .describe("The dtype of the tensor to allocate.") + .set_default(DataType::Float(32, 1)); + TVM_ATTR_FIELD(const_shape).describe("The shape of constant used to aid in type inference."); TVM_ATTR_FIELD(assert_shape) - .describe( - "The shape to cast the return type of the allocation to, "\ - "used to specify the shape obtained via further analysis."); + .describe( + "The shape to cast the return type of the allocation to, " + "used to specify the shape obtained via further analysis."); } }; @@ -88,10 +81,9 @@ struct ShapeFuncAttrs : public tvm::AttrsNode { Array is_input; TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") { - TVM_ATTR_FIELD(is_input) - .describe( - "A bool indicating whether the shape function should"\ - "expect shape or input in each position."); + TVM_ATTR_FIELD(is_input).describe( + "A bool indicating whether the shape function should" + "expect shape or input in each position."); } }; diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index fdf56a7..a9c3059 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -42,13 +43,10 @@ struct BiasAddAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") { - TVM_ATTR_FIELD(axis) - .describe("The axis to add the bias") - .set_default(1); + TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1); } }; - /*! \brief Attributes used in 1D convolution operators */ struct Conv1DAttrs : public tvm::AttrsNode { Array strides; @@ -63,31 +61,44 @@ struct Conv1DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv1DAttrs, "relay.attrs.Conv1DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, })) + TVM_ATTR_FIELD(strides) + .set_default(Array({ + 1, + })) .describe("Specifies the stride of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, })) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "on both sides for padding number of points"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({ + 1, + })) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Currently unused but may be added in the future."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Currently unused but may be added in the future."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCW") - .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Convolution is applied on the 'W'" - "dimension."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIW") - .describe("Dimension ordering of weight. Can be 'OIW', or 'WIO', etc." - "'O', 'I', 'W' stands for num_filter, input_channel, and width" - "dimensions respectively."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCW") + .describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Convolution is applied on the 'W'" + "dimension."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIW") + .describe( + "Dimension ordering of weight. Can be 'OIW', or 'WIO', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -96,7 +107,6 @@ struct Conv1DAttrs : public tvm::AttrsNode { } }; - /*! \brief Attributes used in convolution operators */ struct Conv2DAttrs : public tvm::AttrsNode { Array strides; @@ -111,42 +121,53 @@ struct Conv2DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -156,14 +177,13 @@ struct Conv2DAttrs : public tvm::AttrsNode { }; /*! \brief Attributes used in winograd weight transformation operators */ -struct ConvWinogradWeightTransformAttrs : - public tvm::AttrsNode { +struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode { int tile_size; TVM_DECLARE_ATTRS(ConvWinogradWeightTransformAttrs, - "relay.attrs.ConvWinogradWeightTransformAttrs") { - TVM_ATTR_FIELD(tile_size) - .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + "relay.attrs.ConvWinogradWeightTransformAttrs") { + TVM_ATTR_FIELD(tile_size).describe( + "Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); } }; @@ -182,44 +202,55 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") { - TVM_ATTR_FIELD(tile_size) - .describe("The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(tile_size).describe( + "The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -261,43 +292,54 @@ struct Conv3DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom," - "right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom," + "right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Convolution is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW") - .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." - "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," - "and width dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCDHW") + .describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Convolution is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIDHW") + .describe( + "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." + "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," + "and width dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -321,45 +363,56 @@ struct Conv3DWinogradAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") { - TVM_ATTR_FIELD(tile_size) - .describe("The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)"); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(tile_size).describe( + "The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)"); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom," - "right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom," + "right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Convolution is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW") - .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." - "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," - "and width dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCDHW") + .describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Convolution is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIDHW") + .describe( + "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." + "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," + "and width dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -368,14 +421,12 @@ struct Conv3DWinogradAttrs : public tvm::AttrsNode { } }; - /*! \brief Attributes used in softmax operators */ struct SoftmaxAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("The axis to sum over when computing softmax."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("The axis to sum over when computing softmax."); } }; @@ -395,47 +446,60 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") { TVM_ATTR_FIELD(channels) - .set_default(NullValue()) - .describe("The dimensionality of the output space" - "i.e. the number of output channels in the convolution."); + .set_default(NullValue()) + .describe( + "The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); TVM_ATTR_FIELD(kernel_size) - .describe("The dimensions of the convolution window.") - .set_default(NullValue >()); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("The strides of the convolution."); - TVM_ATTR_FIELD(output_padding).set_default(Array({0, 0})) - .describe("Zero-padding added to one side of the output." - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + .describe("The dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("The strides of the convolution."); + TVM_ATTR_FIELD(output_padding) + .set_default(Array({0, 0})) + .describe( + "Zero-padding added to one side of the output." + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); @@ -447,8 +511,9 @@ struct DilateAttrs : public tvm::AttrsNode { Array strides; TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("Dilation stride on each dimension, 1 means no dilation."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Dilation stride on each dimension, 1 means no dilation."); } }; @@ -468,42 +533,54 @@ struct Conv1DTransposeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relay.attrs.Conv1DTransposeAttrs") { TVM_ATTR_FIELD(channels) - .set_default(NullValue()) - .describe("The dimensionality of the output space" - "i.e. the number of output channels in the convolution."); + .set_default(NullValue()) + .describe( + "The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); TVM_ATTR_FIELD(kernel_size) - .describe("The dimensions of the convolution window.") - .set_default(NullValue >()); - TVM_ATTR_FIELD(strides).set_default(Array({1})) - .describe("The strides of the convolution."); - TVM_ATTR_FIELD(output_padding).set_default(Array({0})) - .describe("Zero-padding added to one side of the output."); - TVM_ATTR_FIELD(padding).set_default(Array({0})) - .describe("Symmetric or asymmetric padding." - "Single value: the input is implicitly zero-padded on both sides." - "Two values: padding[0] is used for left input padding, " - "padding[1] is used for right input padding,"); - TVM_ATTR_FIELD(dilation).set_default(Array({1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - TVM_ATTR_FIELD(data_layout).set_default("NCW") - .describe("Dimension ordering of data. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Convolution is applied on the" - "'W' dimension."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIW") - .describe("Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc." - "'O', 'I', 'W' stands for num_filter, input_channel, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Default to be same as input layout."); + .describe("The dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1})) + .describe("The strides of the convolution."); + TVM_ATTR_FIELD(output_padding) + .set_default(Array({0})) + .describe("Zero-padding added to one side of the output."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0})) + .describe( + "Symmetric or asymmetric padding." + "Single value: the input is implicitly zero-padded on both sides." + "Two values: padding[0] is used for left input padding, " + "padding[1] is used for right input padding,"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCW") + .describe( + "Dimension ordering of data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Convolution is applied on the" + "'W' dimension."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIW") + .describe( + "Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Default to be same as input layout."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); @@ -519,23 +596,25 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); } }; @@ -549,25 +628,28 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { bool count_include_pad; TVM_DECLARE_ATTRS(AvgPool2DAttrs, "relay.attrs.AvgPool2DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - TVM_ATTR_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad) + .set_default(false) + .describe("When true, will include padding to compute the average"); } }; @@ -576,11 +658,11 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode { std::string layout; TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") { - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -590,13 +672,14 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode { std::string layout; TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") { - TVM_ATTR_FIELD(output_size).set_default(Array({})) - .describe("Output height and width."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); + TVM_ATTR_FIELD(output_size) + .set_default(Array({})) + .describe("Output height and width."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -605,17 +688,17 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode { std::string layout; TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") { - TVM_ATTR_FIELD(output_size).set_default(Array({})) - .describe("Output depth, height and width."); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Pooling is applied on 'D', 'H' and" - "'W' dimensions."); + TVM_ATTR_FIELD(output_size) + .set_default(Array({})) + .describe("Output depth, height and width."); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on 'D', 'H' and" + "'W' dimensions."); } }; - /*! \brief Attributes for 1D max pool operator */ struct MaxPool1DAttrs : public tvm::AttrsNode { Array pool_size; @@ -625,22 +708,24 @@ struct MaxPool1DAttrs : public tvm::AttrsNode { bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCW") - .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Pooling is applied on the 'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCW").describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); } }; @@ -654,28 +739,30 @@ struct AvgPool1DAttrs : public tvm::AttrsNode { bool count_include_pad; TVM_DECLARE_ATTRS(AvgPool1DAttrs, "relay.attrs.AvgPool1DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCW") - .describe("Dimension ordering of input data. Can be 'NCW', 'NHC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Pooling is applied on the 'W' dimension."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - TVM_ATTR_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCW").describe( + "Dimension ordering of input data. Can be 'NCW', 'NHC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimension."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad) + .set_default(false) + .describe("When true, will include padding to compute the average"); } }; - /*! \brief Attributes for 3D max pool operator */ struct MaxPool3DAttrs : public tvm::AttrsNode { Array pool_size; @@ -685,23 +772,25 @@ struct MaxPool3DAttrs : public tvm::AttrsNode { bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Pooling is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); } }; @@ -715,37 +804,38 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { bool count_include_pad; TVM_DECLARE_ATTRS(AvgPool3DAttrs, "relay.attrs.AvgPool3DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Pooling is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - TVM_ATTR_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad) + .set_default(false) + .describe("When true, will include padding to compute the average"); } }; - /*! \brief Attributes for dense operator */ struct DenseAttrs : public tvm::AttrsNode { IndexExpr units; DataType out_dtype; TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") { - TVM_ATTR_FIELD(units) - .describe("Number of hidden units of the dense transformation."); + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -782,21 +872,22 @@ struct UpSamplingAttrs : public tvm::AttrsNode { bool align_corners; TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { - TVM_ATTR_FIELD(scale_h) - .describe("The upsampling factor for height"); - TVM_ATTR_FIELD(scale_w) - .describe("The upsampling factor for width"); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Upsampling is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("nearest_neighbor") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation" - "bicubic - Bicubic Interpolation"); - TVM_ATTR_FIELD(align_corners).set_default(false) + TVM_ATTR_FIELD(scale_h).describe("The upsampling factor for height"); + TVM_ATTR_FIELD(scale_w).describe("The upsampling factor for width"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Upsampling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("nearest_neighbor") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation" + "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(align_corners) + .set_default(false) .describe("Should be true to preserve the values at the corner pixels"); } }; @@ -811,26 +902,27 @@ struct UpSampling3DAttrs : public tvm::AttrsNode { std::string coordinate_transformation_mode; TVM_DECLARE_ATTRS(UpSampling3DAttrs, "relay.attrs.UpSampling3DAttrs") { - TVM_ATTR_FIELD(scale_d) - .describe("The upsampling factor for depth"); - TVM_ATTR_FIELD(scale_h) - .describe("The upsampling factor for height"); - TVM_ATTR_FIELD(scale_w) - .describe("The upsampling factor for width"); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Upsampling is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("nearest_neighbor") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "trilinear - Trilinear Interpolation"); - TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") - .describe("Describes how to transform the coordinate in the resized tensor" - "to the coordinate in the original tensor." - "Refer to the ONNX Resize operator specification for details" - "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(scale_d).describe("The upsampling factor for depth"); + TVM_ATTR_FIELD(scale_h).describe("The upsampling factor for height"); + TVM_ATTR_FIELD(scale_w).describe("The upsampling factor for width"); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Upsampling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("nearest_neighbor") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "trilinear - Trilinear Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); } }; @@ -841,15 +933,17 @@ struct PadAttrs : public tvm::AttrsNode { std::string pad_mode; TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") { - TVM_ATTR_FIELD(pad_value).set_default(0.0) - .describe("The value used for padding when mode is 'constant'."); - TVM_ATTR_FIELD(pad_width) - .describe("Number of values padded to the edges of each axis, " - "in the format of ((before_1, after_1), ..., (before_N, after_N))"); - TVM_ATTR_FIELD(pad_mode).set_default("constant") - .describe("Padding type to use. \"constant\" pads with constant_value, " - "\"edge\" pads using the edge values of the input array, " - "\"reflect\" pads by reflecting values with respect to the edges."); + TVM_ATTR_FIELD(pad_value).set_default(0.0).describe( + "The value used for padding when mode is 'constant'."); + TVM_ATTR_FIELD(pad_width).describe( + "Number of values padded to the edges of each axis, " + "in the format of ((before_1, after_1), ..., (before_N, after_N))"); + TVM_ATTR_FIELD(pad_mode) + .set_default("constant") + .describe( + "Padding type to use. \"constant\" pads with constant_value, " + "\"edge\" pads using the edge values of the input array, " + "\"reflect\" pads by reflecting values with respect to the edges."); } }; @@ -859,11 +953,12 @@ struct MirrorPadAttrs : public tvm::AttrsNode { Array > pad_width; TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") { - TVM_ATTR_FIELD(mode).set_default("SYMMETRIC") - .describe("Specifies how mirroring should be performed."); - TVM_ATTR_FIELD(pad_width) - .describe("Number of values padded to the edges of each axis, " - "in the format of ((before_1, after_1), ..., (before_N, after_N))"); + TVM_ATTR_FIELD(mode) + .set_default("SYMMETRIC") + .describe("Specifies how mirroring should be performed."); + TVM_ATTR_FIELD(pad_width).describe( + "Number of values padded to the edges of each axis, " + "in the format of ((before_1, after_1), ..., (before_N, after_N))"); } }; @@ -872,30 +967,28 @@ struct LeakyReluAttrs : public tvm::AttrsNode { double alpha; TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") { - TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25) - .describe("Slope coefficient for the negative half axis."); + TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25).describe( + "Slope coefficient for the negative half axis."); } }; - /*! \brief Attributes for prelu operator */ struct PReluAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(PReluAttrs, "relay.attrs.PReluAttrs") { - TVM_ATTR_FIELD(axis).set_default(1) - .describe("Specify which shape axis the channel is specified."); + TVM_ATTR_FIELD(axis).set_default(1).describe( + "Specify which shape axis the channel is specified."); } }; - /*! \brief Attributes used in dropout operator */ struct DropoutAttrs : public tvm::AttrsNode { double rate; TVM_DECLARE_ATTRS(DropoutAttrs, "relay.attrs.DropoutAttrs") { TVM_ATTR_FIELD(rate) - .describe("Fraction of the input that gets dropped out during training time") - .set_default(0.5); + .describe("Fraction of the input that gets dropped out during training time") + .set_default(0.5); } }; // struct DropoutAttrs @@ -907,24 +1000,22 @@ struct BatchNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(BatchNormAttrs, "relay.attrs.BatchNormAttrs") { - TVM_ATTR_FIELD(axis) - .describe("Specify which shape axis denotes the channel.") - .set_default(1); + TVM_ATTR_FIELD(axis).describe("Specify which shape axis denotes the channel.").set_default(1); TVM_ATTR_FIELD(epsilon) - .describe("Small float added to variance to avoid dividing by zero") - .set_default(1e-5); + .describe("Small float added to variance to avoid dividing by zero") + .set_default(1e-5); TVM_ATTR_FIELD(center) - .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored") - .set_default(true); + .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored") + .set_default(true); TVM_ATTR_FIELD(scale) - .describe("If True, multiply by gamma. If False, gamma is not used. " - "When the next layer is piecewise linear (also, e.g., nn.relu), " - "this can be disabled since the scaling will be done by the next layer.") - .set_default(true); + .describe( + "If True, multiply by gamma. If False, gamma is not used. " + "When the next layer is piecewise linear (also, e.g., nn.relu), " + "this can be disabled since the scaling will be done by the next layer.") + .set_default(true); } }; // struct BatchNormAttrs - /*! \brief Attributes used in instance_norm operator */ struct InstanceNormAttrs : public tvm::AttrsNode { int axis; @@ -933,21 +1024,18 @@ struct InstanceNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") { - TVM_ATTR_FIELD(axis) - .describe("Specify which shape axis denotes the channel.") - .set_default(1); + TVM_ATTR_FIELD(axis).describe("Specify which shape axis denotes the channel.").set_default(1); TVM_ATTR_FIELD(epsilon) - .describe("Small float added to variance to avoid dividing by zero") - .set_default(1e-5); - TVM_ATTR_FIELD(center).set_default(true) - .describe("If true, add offset of beta to normalized tensor; " - "otherwise, beta is ignored."); - TVM_ATTR_FIELD(scale).set_default(true) - .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + .describe("Small float added to variance to avoid dividing by zero") + .set_default(1e-5); + TVM_ATTR_FIELD(center).set_default(true).describe( + "If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true).describe( + "If true, multiply by gamma; otherwise, gamma is ignored."); } }; // struct InstanceNormAttrs - /*! \brief Attributes used in layer_norm operator */ struct LayerNormAttrs : public tvm::AttrsNode { int axis; @@ -956,19 +1044,17 @@ struct LayerNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(LayerNormAttrs, "relay.attrs.LayerNormAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Specify which shape axis denotes the channel."); - TVM_ATTR_FIELD(epsilon).set_default(1e-5) - .describe("Small float added to variance to avoid dividing by zero"); - TVM_ATTR_FIELD(center).set_default(true) - .describe("If true, add offset of beta to normalized tensor; " - "otherwise, beta is ignored."); - TVM_ATTR_FIELD(scale).set_default(true) - .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Specify which shape axis denotes the channel."); + TVM_ATTR_FIELD(epsilon).set_default(1e-5).describe( + "Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).set_default(true).describe( + "If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true).describe( + "If true, multiply by gamma; otherwise, gamma is ignored."); } }; // struct LayerNormAttrs - /*! \brief Attributes used in group_norm operator */ struct GroupNormAttrs : public tvm::AttrsNode { int num_groups; @@ -978,21 +1064,20 @@ struct GroupNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(GroupNormAttrs, "relay.attrs.GroupNormAttrs") { - TVM_ATTR_FIELD(num_groups).set_default(0) - .describe("Specify number of groups to separate the channels into."); - TVM_ATTR_FIELD(axis).set_default(1) - .describe("Specify which shape axis denotes the channel."); - TVM_ATTR_FIELD(epsilon).set_default(1e-5) - .describe("Small float added to variance to avoid dividing by zero"); - TVM_ATTR_FIELD(center).set_default(true) - .describe("If true, add offset of beta to normalized tensor; " - "otherwise, beta is ignored."); - TVM_ATTR_FIELD(scale).set_default(true) - .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + TVM_ATTR_FIELD(num_groups) + .set_default(0) + .describe("Specify number of groups to separate the channels into."); + TVM_ATTR_FIELD(axis).set_default(1).describe("Specify which shape axis denotes the channel."); + TVM_ATTR_FIELD(epsilon).set_default(1e-5).describe( + "Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).set_default(true).describe( + "If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true).describe( + "If true, multiply by gamma; otherwise, gamma is ignored."); } }; // struct GroupNormAttrs - /*! \brief Attributes for LRN operator */ struct LRNAttrs : public tvm::AttrsNode { int size; @@ -1002,34 +1087,26 @@ struct LRNAttrs : public tvm::AttrsNode { double beta; TVM_DECLARE_ATTRS(LRNAttrs, "relay.attrs.LRNAttrs") { - TVM_ATTR_FIELD(size).set_default(5) - .describe("The size of the local region to be considered for normalization."); - TVM_ATTR_FIELD(axis).set_default(1) - .describe("Axis of input data layout channel."); - TVM_ATTR_FIELD(bias).set_default(2) - .describe("The offset parameter to avoid division by 0."); - TVM_ATTR_FIELD(alpha).set_default(0.0001) - .describe("The scaling parameter."); - TVM_ATTR_FIELD(beta).set_default(0.75) - .describe("The exponent parameter."); + TVM_ATTR_FIELD(size).set_default(5).describe( + "The size of the local region to be considered for normalization."); + TVM_ATTR_FIELD(axis).set_default(1).describe("Axis of input data layout channel."); + TVM_ATTR_FIELD(bias).set_default(2).describe("The offset parameter to avoid division by 0."); + TVM_ATTR_FIELD(alpha).set_default(0.0001).describe("The scaling parameter."); + TVM_ATTR_FIELD(beta).set_default(0.75).describe("The exponent parameter."); } }; - /*! \brief Attributes for L2Normalize operator */ struct L2NormalizeAttrs : public tvm::AttrsNode { double eps; Array axis; TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") { - TVM_ATTR_FIELD(eps) - .describe("A lower bound value for the norm, to avoid division by 0."); - TVM_ATTR_FIELD(axis) - .describe("Axis over the normalization applied."); + TVM_ATTR_FIELD(eps).describe("A lower bound value for the norm, to avoid division by 0."); + TVM_ATTR_FIELD(axis).describe("Axis over the normalization applied."); } }; - /*! \brief Attributes for DeformableConv2D operator */ struct DeformableConv2DAttrs : public tvm::AttrsNode { Array strides; @@ -1045,46 +1122,59 @@ struct DeformableConv2DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(DeformableConv2DAttrs, "relay.attrs.DeformableConv2DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(deformable_groups).set_default(1) - .describe("Controls the connections between inputs and offsets." - "Input channels are partitioned into multiple deformable groups. Offsets" - "are shared across input channels in the same deformable group."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(deformable_groups) + .set_default(1) + .describe( + "Controls the connections between inputs and offsets." + "Input channels are partitioned into multiple deformable groups. Offsets" + "are shared across input channels in the same deformable group."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 443efb5..f57c1f4 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_REDUCE_H_ #include + #include namespace tvm { @@ -37,7 +38,8 @@ struct ReduceAttrs : public tvm::AttrsNode { bool exclude; TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue>()) + TVM_ATTR_FIELD(axis) + .set_default(NullValue>()) .describe(R"code(The axis or axes along which to perform the reduction. The default, `axis=()`, will compute over all elements into a @@ -51,11 +53,11 @@ struct ReduceAttrs : public tvm::AttrsNode { If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead.)code"); - TVM_ATTR_FIELD(keepdims).set_default(false) - .describe("If this is set to `True`, the reduced axes are left " - "in the result as dimension with size one."); - TVM_ATTR_FIELD(exclude).set_default(false) - .describe("Whether to perform reduction on axis that are NOT in axis instead."); + TVM_ATTR_FIELD(keepdims).set_default(false).describe( + "If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + TVM_ATTR_FIELD(exclude).set_default(false).describe( + "Whether to perform reduction on axis that are NOT in axis instead."); } }; } // namespace relay diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index ae2ac11..84dda6f 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -27,6 +27,7 @@ #include #include #include + #include namespace tvm { @@ -37,8 +38,7 @@ struct CastAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") { - TVM_ATTR_FIELD(dtype) - .describe("Target data type"); + TVM_ATTR_FIELD(dtype).describe("Target data type"); } }; // struct CastAttrs. @@ -48,11 +48,11 @@ struct ExpandDimsAttrs : public tvm::AttrsNode { int num_newaxis; TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs") { - TVM_ATTR_FIELD(axis) - .describe("The axis at which the input array is expanded." - "Should lie in range `[-data.ndim - 1, data.ndim]`." - "If `axis < 0`, it is the first axis inserted;" - "If `axis >= 0`, it is the last axis inserted in Python's negative indexing."); + TVM_ATTR_FIELD(axis).describe( + "The axis at which the input array is expanded." + "Should lie in range `[-data.ndim - 1, data.ndim]`." + "If `axis < 0`, it is the first axis inserted;" + "If `axis >= 0`, it is the last axis inserted in Python's negative indexing."); TVM_ATTR_FIELD(num_newaxis) .describe("Number of axises to be inserted. Should be >= 0.") .set_lower_bound(0) @@ -65,8 +65,9 @@ struct ConcatenateAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") { TVM_ATTR_FIELD(axis) - .describe("The axis at which the input arrays are concatenated." - "Should lie in range `[-ndim, ndim)`.") + .describe( + "The axis at which the input arrays are concatenated." + "Should lie in range `[-ndim, ndim)`.") .set_default(0); } }; // struct ConcatenateAttrs @@ -75,8 +76,7 @@ struct ConcatenateAttrs : public tvm::AttrsNode { struct TransposeAttrs : public tvm::AttrsNode { Array axes; TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") { - TVM_ATTR_FIELD(axes) - .describe("The target axes order, reverse order if not specified."); + TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified."); } }; // struct TransposeAttrs @@ -85,8 +85,8 @@ struct ReshapeAttrs : public tvm::AttrsNode { Array newshape; bool reverse; TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") { - TVM_ATTR_FIELD(newshape) - .describe("The new shape. Should be compatible with the original shape."); + TVM_ATTR_FIELD(newshape).describe( + "The new shape. Should be compatible with the original shape."); TVM_ATTR_FIELD(reverse) .describe("Infer the special values from right to left if true") .set_default(false); @@ -98,13 +98,14 @@ struct TakeAttrs : public tvm::AttrsNode { std::string mode; TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue()) + TVM_ATTR_FIELD(axis) + .set_default(NullValue()) .describe("The axis over which to select values."); - TVM_ATTR_FIELD(mode).set_default("clip") - .describe("Specify how out-of-bound indices will behave." - "clip - clip to the range (default)" - "wrap - wrap around the indices" - "fast - no clip or wrap around (user must make sure indices are in-bound)"); + TVM_ATTR_FIELD(mode).set_default("clip").describe( + "Specify how out-of-bound indices will behave." + "clip - clip to the range (default)" + "wrap - wrap around the indices" + "fast - no clip or wrap around (user must make sure indices are in-bound)"); } }; @@ -114,11 +115,8 @@ struct InitOpAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") { - TVM_ATTR_FIELD(shape) - .describe("Target shape."); - TVM_ATTR_FIELD(dtype) - .describe("Target data type.") - .set_default(NullValue()); + TVM_ATTR_FIELD(shape).describe("Target shape."); + TVM_ATTR_FIELD(dtype).describe("Target data type.").set_default(NullValue()); } }; // struct InitOpAttrs @@ -130,14 +128,10 @@ struct ArangeAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") { - TVM_ATTR_FIELD(start) - .describe("Start of interval. The interval includes this value."); - TVM_ATTR_FIELD(stop) - .describe("Stop of interval. The interval does not include this value."); - TVM_ATTR_FIELD(step) - .describe("Spacing between values."); - TVM_ATTR_FIELD(dtype) - .describe("Target data type."); + TVM_ATTR_FIELD(start).describe("Start of interval. The interval includes this value."); + TVM_ATTR_FIELD(stop).describe("Stop of interval. The interval does not include this value."); + TVM_ATTR_FIELD(step).describe("Spacing between values."); + TVM_ATTR_FIELD(dtype).describe("Target data type."); } }; // struct ArangeAttrs @@ -145,8 +139,8 @@ struct ArangeAttrs : public tvm::AttrsNode { struct StackAttrs : public tvm::AttrsNode { Integer axis; TVM_DECLARE_ATTRS(StackAttrs, "relay.attrs.StackAttrs") { - TVM_ATTR_FIELD(axis).set_default(0) - .describe("The axis in the result array along which the input arrays are stacked."); + TVM_ATTR_FIELD(axis).set_default(0).describe( + "The axis in the result array along which the input arrays are stacked."); } }; // struct StackAttrs @@ -155,9 +149,9 @@ struct RepeatAttrs : public tvm::AttrsNode { Integer repeats; Integer axis; TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") { - TVM_ATTR_FIELD(repeats) - .describe("The number of repetitions for each element."); - TVM_ATTR_FIELD(axis).set_default(NullValue()) + TVM_ATTR_FIELD(repeats).describe("The number of repetitions for each element."); + TVM_ATTR_FIELD(axis) + .set_default(NullValue()) .describe(" The axis along which to repeat values."); } }; // struct RepeatAttrs @@ -166,9 +160,9 @@ struct RepeatAttrs : public tvm::AttrsNode { struct TileAttrs : public tvm::AttrsNode { Array reps; TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") { - TVM_ATTR_FIELD(reps) - .describe("The number of times for repeating the tensor a." - "Each dim sizeof reps must be a positive integer."); + TVM_ATTR_FIELD(reps).describe( + "The number of times for repeating the tensor a." + "Each dim sizeof reps must be a positive integer."); } }; // struct TileAttrs @@ -176,7 +170,8 @@ struct TileAttrs : public tvm::AttrsNode { struct ReverseAttrs : public tvm::AttrsNode { Integer axis; TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue()) + TVM_ATTR_FIELD(axis) + .set_default(NullValue()) .describe("The axis along which to reverse elements."); } }; // struct ReverseAttrs @@ -188,10 +183,11 @@ struct SqueezeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") { TVM_ATTR_FIELD(axis) - .describe("The axis to squeeze in the input tensor." - "If `axis = None`, all axis of dimension 1 get squeezed;" - "Else, the dimension in axes get squeezed." - "It is an error if an axis does not has dimension 1.") + .describe( + "The axis to squeeze in the input tensor." + "If `axis = None`, all axis of dimension 1 get squeezed;" + "Else, the dimension in axes get squeezed." + "It is an error if an axis does not has dimension 1.") .set_default(NullValue >()); } }; // struct SqueezeAttrs @@ -202,13 +198,13 @@ struct SplitAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { TVM_ATTR_FIELD(indices_or_sections) - .describe("Indices or sections to split into. Accepts an int or a tuple" - "If indices_or_sections is an integer, the input will be divided equally" - "along given axis. If such a split is not possible, an error is raised." - "If indices_or_sections is a tuple of sorted integers," - "the entries indicate where along axis the array is split."); - TVM_ATTR_FIELD(axis).set_default(0) - .describe("the axis to be splitted."); + .describe( + "Indices or sections to split into. Accepts an int or a tuple" + "If indices_or_sections is an integer, the input will be divided equally" + "along given axis. If such a split is not possible, an error is raised." + "If indices_or_sections is a tuple of sorted integers," + "the entries indicate where along axis the array is split."); + TVM_ATTR_FIELD(axis).set_default(0).describe("the axis to be splitted."); } }; @@ -219,12 +215,9 @@ struct StridedSliceAttrs : public tvm::AttrsNode { Array strides; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { - TVM_ATTR_FIELD(begin) - .describe("Indices for begin of slice, begin index is also inclusive"); - TVM_ATTR_FIELD(end) - .describe("Indices for end of slice, end index is exclusive"); - TVM_ATTR_FIELD(strides).set_default(Array({})) - .describe("Stride values of the slice"); + TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive"); + TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive"); + TVM_ATTR_FIELD(strides).set_default(Array({})).describe("Stride values of the slice"); } }; @@ -232,10 +225,10 @@ struct SliceLikeAttrs : public tvm::AttrsNode { Array axes; TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") { - TVM_ATTR_FIELD(axes) - .describe("List of axes on which input data will be sliced according to the " - "corresponding size of the second input. By default will slice " - "on all axes. Negative axes mean counting in reverse."); + TVM_ATTR_FIELD(axes).describe( + "List of axes on which input data will be sliced according to the " + "corresponding size of the second input. By default will slice " + "on all axes. Negative axes mean counting in reverse."); } }; @@ -245,10 +238,8 @@ struct ClipAttrs : public tvm::AttrsNode { double a_max; TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { - TVM_ATTR_FIELD(a_min) - .describe("The minimum clip value."); - TVM_ATTR_FIELD(a_max) - .describe("The maximum clip value."); + TVM_ATTR_FIELD(a_min).describe("The minimum clip value."); + TVM_ATTR_FIELD(a_max).describe("The maximum clip value."); } }; @@ -258,10 +249,8 @@ struct LayoutTransformAttrs : public tvm::AttrsNode { std::string dst_layout; TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") { - TVM_ATTR_FIELD(src_layout) - .describe("The source layout of the tensor. (e.g. NCHW)"); - TVM_ATTR_FIELD(dst_layout) - .describe("The destination layout of the tensor. (e.g. NCHW16c)"); + TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. NCHW)"); + TVM_ATTR_FIELD(dst_layout).describe("The destination layout of the tensor. (e.g. NCHW16c)"); } }; @@ -270,9 +259,7 @@ struct ShapeOfAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") { - TVM_ATTR_FIELD(dtype) - .describe("Target data type") - .set_default(NullValue()); + TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue()); } }; @@ -281,10 +268,9 @@ struct SequenceMaskAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(SequenceMaskAttrs, "relay.attrs.SequenceMaskAttrs") { - TVM_ATTR_FIELD(mask_value).set_default(0) - .describe("The masking value."); - TVM_ATTR_FIELD(axis).set_default(0) - .describe("The axis of the length dimension. Can only be 0 or 1."); + TVM_ATTR_FIELD(mask_value).set_default(0).describe("The masking value."); + TVM_ATTR_FIELD(axis).set_default(0).describe( + "The axis of the length dimension. Can only be 0 or 1."); } }; // struct SequenceMaskAttrs. @@ -293,9 +279,7 @@ struct NdarraySizeAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs") { - TVM_ATTR_FIELD(dtype) - .describe("Target data type") - .set_default(NullValue()); + TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue()); } }; @@ -306,12 +290,9 @@ struct OneHotAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") { - TVM_ATTR_FIELD(depth).set_default(1) - .describe("Depth of the one hot dimension."); - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Axis to fill."); - TVM_ATTR_FIELD(dtype).set_default(NullValue()) - .describe("Output data type."); + TVM_ATTR_FIELD(depth).set_default(1).describe("Depth of the one hot dimension."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill."); + TVM_ATTR_FIELD(dtype).set_default(NullValue()).describe("Output data type."); } }; // struct OneHotAttrs diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index c4a30ce..e7e24b1 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -41,39 +42,32 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(MultiBoxPriorAttrs, "relay.attrs.MultiBoxPriorAttrs") { TVM_ATTR_FIELD(sizes) - .set_default(Array({static_cast(1.0)})) - .describe("List of sizes of generated MultiBoxPriores."); + .set_default(Array({static_cast(1.0)})) + .describe("List of sizes of generated MultiBoxPriores."); TVM_ATTR_FIELD(ratios) - .set_default(Array({static_cast(1.0)})) - .describe("List of aspect ratios of generated MultiBoxPriores."); + .set_default(Array({static_cast(1.0)})) + .describe("List of aspect ratios of generated MultiBoxPriores."); TVM_ATTR_FIELD(steps) - .set_default(Array({static_cast(-1.0), - static_cast(-1.0)})) - .describe("Priorbox step across y and x, -1 for auto calculation."); + .set_default(Array({static_cast(-1.0), static_cast(-1.0)})) + .describe("Priorbox step across y and x, -1 for auto calculation."); TVM_ATTR_FIELD(offsets) - .set_default(Array({static_cast(0.5), - static_cast(0.5)})) - .describe("Priorbox center offsets, y and x respectively."); - TVM_ATTR_FIELD(clip).set_default(false) - .describe("Whether to clip out-of-boundary boxes."); + .set_default(Array({static_cast(0.5), static_cast(0.5)})) + .describe("Priorbox center offsets, y and x respectively."); + TVM_ATTR_FIELD(clip).set_default(false).describe("Whether to clip out-of-boundary boxes."); } }; -struct MultiBoxTransformLocAttrs - : public tvm::AttrsNode { +struct MultiBoxTransformLocAttrs : public tvm::AttrsNode { bool clip; double threshold; Array variances; - TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, - "relay.attrs.MultiBoxTransformLocAttrs") { - TVM_ATTR_FIELD(clip).set_default(true) - .describe("Clip out-of-boundary boxes."); - TVM_ATTR_FIELD(threshold).set_default(0.01) - .describe("Threshold to be a positive prediction."); + TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, "relay.attrs.MultiBoxTransformLocAttrs") { + TVM_ATTR_FIELD(clip).set_default(true).describe("Clip out-of-boundary boxes."); + TVM_ATTR_FIELD(threshold).set_default(0.01).describe("Threshold to be a positive prediction."); TVM_ATTR_FIELD(variances) - .set_default(Array({0.1f, 0.1f , 0.2f, 0.2f})) - .describe("Variances to be decoded from box regression output."); + .set_default(Array({0.1f, 0.1f, 0.2f, 0.2f})) + .describe("Variances to be decoded from box regression output."); } }; @@ -84,12 +78,11 @@ struct GetValidCountsAttrs : public tvm::AttrsNode { int score_index; TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { - TVM_ATTR_FIELD(score_threshold).set_default(0.0) - .describe("Lower limit of score for valid bounding boxes."); - TVM_ATTR_FIELD(id_index).set_default(0) - .describe("Axis index of id."); - TVM_ATTR_FIELD(score_index).set_default(1) - .describe("Index of the scores/confidence of boxes."); + TVM_ATTR_FIELD(score_threshold) + .set_default(0.0) + .describe("Lower limit of score for valid bounding boxes."); + TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id."); + TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes."); } }; @@ -106,25 +99,28 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { Integer stride; TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") { - TVM_ATTR_FIELD(stride) - .set_default(1) - .describe("Stride value for yolo reorg"); + TVM_ATTR_FIELD(stride).set_default(1).describe("Stride value for yolo reorg"); } }; @@ -206,10 +200,8 @@ struct ProposalAttrs : public tvm::AttrsNode { .describe( "The size of the receptive field each unit in the convolution layer of the rpn," "for example the product of all stride's prior to this layer."); - TVM_ATTR_FIELD(threshold) - .set_default(0.7) - .describe( - "IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)"); + TVM_ATTR_FIELD(threshold).set_default(0.7).describe( + "IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)"); TVM_ATTR_FIELD(rpn_pre_nms_top_n) .set_default(6000) .describe("Number of top scoring boxes to apply NMS. -1 to use all boxes"); diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 1d01206..c78ab75 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -24,10 +24,10 @@ #ifndef TVM_RELAY_BASE_H_ #define TVM_RELAY_BASE_H_ - #include -#include #include +#include + #include #include @@ -42,17 +42,19 @@ namespace tvm { */ namespace relay { -#define RELAY_DEBUG(...) \ -{ auto fdebug = runtime::Registry::Get("relay.debug"); \ - CHECK(fdebug) << "Could not find Relay Python debugger function."; \ - (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ -} +#define RELAY_DEBUG(...) \ + { \ + auto fdebug = runtime::Registry::Get("relay.debug"); \ + CHECK(fdebug) << "Could not find Relay Python debugger function."; \ + (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ + } -#define RELAY_DEBUG_INTERP(...) \ -{ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \ - CHECK(fdebug) << "Could not find Relay Python debugger function."; \ - (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ -} +#define RELAY_DEBUG_INTERP(...) \ + { \ + auto fdebug = runtime::Registry::Get("relay.debug_interp"); \ + CHECK(fdebug) << "Could not find Relay Python debugger function."; \ + (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ + } /*! * \brief Symbolic expression for tensor shape. @@ -93,9 +95,7 @@ class IdNode : public Object { */ std::string name_hint; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name_hint", &name_hint); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); } static constexpr const char* _type_key = "relay.Id"; TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 8c50260..69a60a7 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -26,10 +26,12 @@ #include #include -#include #include -#include +#include + #include +#include + #include "./base.h" #include "./type.h" @@ -63,9 +65,7 @@ class ConstantNode : public ExprNode { TensorType tensor_type() const; /*! \return Whether it is scalar(rank-0 tensor) */ - bool is_scalar() const { - return data->ndim == 0; - } + bool is_scalar() const { return data->ndim == 0; } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); @@ -77,9 +77,7 @@ class ConstantNode : public ExprNode { return equal(data, other->data); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(data); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); } static constexpr const char* _type_key = "relay.Constant"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode); @@ -172,9 +170,7 @@ class VarNode : public ExprNode { Type type_annotation; /*! \return The name hint of the variable */ - const std::string& name_hint() const { - return vid->name_hint; - } + const std::string& name_hint() const { return vid->name_hint; } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("vid", &vid); @@ -184,9 +180,7 @@ class VarNode : public ExprNode { } bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { - return - equal(type_annotation, other->type_annotation) && - equal.FreeVarEqualImpl(this, other); + return equal(type_annotation, other->type_annotation) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -194,11 +188,9 @@ class VarNode : public ExprNode { hash_reduce.FreeVarHashImpl(this); } - TVM_DLL static Var make(std::string name_hint, - Type type_annotation); + TVM_DLL static Var make(std::string name_hint, Type type_annotation); - TVM_DLL static Var make(Id vid, - Type type_annotation); + TVM_DLL static Var make(Id vid, Type type_annotation); static constexpr const char* _type_key = "relay.Var"; TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); @@ -211,8 +203,7 @@ class Var : public Expr { * \param name_hint The name hint of a variable. * \param type_annotation The type annotation of a variable. */ - TVM_DLL Var(std::string name_hint, Type type_annotation) : - Var(Id(name_hint), type_annotation) {} + TVM_DLL Var(std::string name_hint, Type type_annotation) : Var(Id(name_hint), type_annotation) {} /*! * \brief The constructor @@ -278,11 +269,8 @@ class CallNode : public ExprNode { bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { // skip type_args check for primitive ops. equal->MarkGraphNode(); - return - equal(op, other->op) && - equal(args, other->args) && - equal(attrs, other->attrs) && - (IsPrimitiveOp(op) || equal(type_args, other->type_args)); + return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && + (IsPrimitiveOp(op) || equal(type_args, other->type_args)); } void SHashReduce(SHashReducer hash_reduce) const { @@ -308,9 +296,7 @@ class Call : public Expr { * \param attrs The attributes of the call node. * \param type_args The type arguments passed to a polymorphic function. */ - TVM_DLL Call(Expr op, - Array args, - Attrs attrs = Attrs(), + TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), Array type_args = Array()); TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); @@ -348,10 +334,8 @@ class LetNode : public ExprNode { bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal.DefEqual(var, other->var) && - equal(value, other->value) && - equal(body, other->body); + return equal.DefEqual(var, other->var) && equal(value, other->value) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -410,10 +394,8 @@ class IfNode : public ExprNode { bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal(cond, other->cond) && - equal(true_branch, other->true_branch) && - equal(false_branch, other->false_branch); + return equal(cond, other->cond) && equal(true_branch, other->true_branch) && + equal(false_branch, other->false_branch); } void SHashReduce(SHashReducer hash_reduce) const { @@ -457,9 +439,7 @@ class TupleGetItemNode : public ExprNode { } bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { - return - equal(tuple, other->tuple) && - equal(index, other->index); + return equal(tuple, other->tuple) && equal(index, other->index); } void SHashReduce(SHashReducer hash_reduce) const { @@ -576,9 +556,7 @@ class RefWriteNode : public ExprNode { bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal(ref, other->ref) && - equal(value, other->value); + return equal(ref, other->ref) && equal(value, other->value); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 04b2754..559f9b8 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -25,16 +25,16 @@ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_ -#include #include +#include +#include #include #include -#include #include #include -#include #include +#include namespace tvm { namespace relay { @@ -54,15 +54,13 @@ template class ExprFunctor; // functions to be overriden. -#define EXPR_FUNCTOR_DEFAULT \ +#define EXPR_FUNCTOR_DEFAULT \ { return VisitExprDefault_(op, std::forward(args)...); } -#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); +#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ + }); template class ExprFunctor { @@ -81,9 +79,7 @@ class ExprFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Expr& n, Args... args) { - return VisitExpr(n, std::forward(args)...); - } + R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward(args)...); } /*! * \brief The functor call. * \param n The expression node. @@ -96,22 +92,15 @@ class ExprFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitExpr_(const ConstantNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const TupleNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const VarNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const GlobalVarNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FunctionNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const IfNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const OpNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -154,8 +143,7 @@ class ExprFunctor { * ExprVisitor treats Expr as dataflow graph, * and only visit each Expr node once. */ -class ExprVisitor - : public ::tvm::relay::ExprFunctor { +class ExprVisitor : public ::tvm::relay::ExprFunctor { public: void VisitExpr(const Expr& expr) override; void VisitExpr_(const VarNode* op) override; @@ -189,16 +177,13 @@ class ExprVisitor * The mutated results are memoized in a map and reused so that * local transformation on the dataflow preserves the graph structure. */ -class ExprMutator - : public ::tvm::relay::ExprFunctor { +class ExprMutator : public ::tvm::relay::ExprFunctor { public: /*! * \brief Mutate is alias for VisitExpr * \return expr. */ - Expr Mutate(const Expr& expr) { - return this->VisitExpr(expr); - } + Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); } Expr VisitExpr(const Expr& expr) override; Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const ConstantNode* op) override; @@ -283,7 +268,8 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions * of the graph and processes them iteratatively to prevent stack overflows * - * Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive behavior. + * Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive + * behavior. */ class MixedModeMutator : public ::tvm::relay::ExprMutator { public: @@ -293,14 +279,14 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); }; Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); }; /*! - * \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be - * able to rewrite the op only with data about the original node `pre` and the same node with + * \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will + * be able to rewrite the op only with data about the original node `pre` and the same node with * modified inputs `post` and should not recurse. * * \param pre The expression node before rewriting. * \param post The expression with rewritten inputs. */ - virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;} + virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post; } virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; } virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; } @@ -350,9 +336,7 @@ class ExprRewriter { * \param post The expression node with rewritten inputs. * \return The result of the call */ - Expr operator()(const Expr& pre, const Expr& post) { - return Rewrite(pre, post); - } + Expr operator()(const Expr& pre, const Expr& post) { return Rewrite(pre, post); } /*! * \brief The functor call. * \param pre The expression node before rewriting. diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 744d7c4..3783e32 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -24,9 +24,9 @@ #ifndef TVM_RELAY_FEATURE_H_ #define TVM_RELAY_FEATURE_H_ +#include #include #include -#include #include @@ -65,9 +65,7 @@ class FeatureSet { public: FeatureSet(const FeatureSet&) = default; /*! \brief A singleton set containing a single Feature. */ - explicit FeatureSet(Feature ft) { - bs_.set(static_cast(ft)); - } + explicit FeatureSet(Feature ft) { bs_.set(static_cast(ft)); } explicit FeatureSet(const tvm::Array& ft) { for (Integer i : ft) { (*this) += Feature(static_cast(i)); @@ -93,25 +91,25 @@ class FeatureSet { FeatureSet fs; return fs; } - template + template FeatureSet& operator+=(const T& rhs) { bs_ |= FeatureSet(rhs).bs_; return *this; } /*! \brief Set union. */ - template + template FeatureSet operator+(const T& rhs) const { FeatureSet fs(*this); fs += rhs; return fs; } - template + template FeatureSet& operator-=(const T& rhs) { bs_ &= ~(FeatureSet(rhs)).bs_; return *this; } /*! \brief Set difference. */ - template + template FeatureSet operator-(const T& rhs) const { FeatureSet fs(*this); fs -= rhs; @@ -124,14 +122,12 @@ class FeatureSet { * * \return true only if this is a subset of rhs. */ - bool is_subset_of(const FeatureSet& rhs) const { - return ((*this) - rhs).bs_.none(); - } + bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); } private: std::bitset bs_; FeatureSet() = default; - explicit FeatureSet(const std::bitset& bs) : bs_(bs) { } + explicit FeatureSet(const std::bitset& bs) : bs_(bs) {} }; /*! diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 33b813b..ab9111b 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -26,8 +26,8 @@ #include #include -#include +#include namespace tvm { namespace relay { @@ -71,12 +71,9 @@ class FunctionNode : public BaseFuncNode { bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { // Important to make def equal first. equal->MarkGraphNode(); - return - equal.DefEqual(params, other->params) && - equal.DefEqual(type_params, other->type_params) && - equal(ret_type, other->ret_type) && - equal(attrs, other->attrs) && - equal(body, other->body); + return equal.DefEqual(params, other->params) && + equal.DefEqual(type_params, other->type_params) && equal(ret_type, other->ret_type) && + equal(attrs, other->attrs) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -100,7 +97,6 @@ class FunctionNode : public BaseFuncNode { TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); }; - /*! * \brief Managed reference to FunctionNode. * \sa FunctionNode @@ -115,10 +111,7 @@ class Function : public BaseFunc { * \param ty_params The type parameters. * \param attrs Additional function attributes. */ - TVM_DLL Function(tvm::Array params, - Expr body, - Type ret_type, - tvm::Array ty_params, + TVM_DLL Function(tvm::Array params, Expr body, Type ret_type, tvm::Array ty_params, tvm::DictAttrs attrs = NullValue()); TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index ae1f84a..bda73ed 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -36,12 +36,11 @@ #include #include -#include #include +#include #include #include - namespace tvm { namespace relay { @@ -64,8 +63,8 @@ namespace relay { * \param target Compiler target flag to compile the functions on the context. * \return A function that takes in an expression and returns a value. */ -runtime::TypedPackedFunc -CreateInterpreter(IRModule mod, DLContext context, Target target); +runtime::TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, + Target target); /*! \brief The container type of Closures used by the interpreter. */ class InterpreterClosureObj : public runtime::vm::ClosureObj { @@ -96,8 +95,7 @@ class InterpreterClosureObj : public runtime::vm::ClosureObj { class InterpreterClosure : public runtime::vm::Closure { public: TVM_DLL InterpreterClosure(tvm::Map env, Function func); - TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, - InterpreterClosureObj); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, InterpreterClosureObj); }; /*! \brief The container type of RecClosure. */ @@ -130,9 +128,7 @@ struct RefValueObj : Object { RefValueObj() {} - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("value", &value); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); } static constexpr const char* _type_key = "relay.RefValue"; TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object); @@ -164,9 +160,7 @@ struct ConstructorValueObj : Object { class ConstructorValue : public ObjectRef { public: - TVM_DLL ConstructorValue(int32_t tag, - tvm::Array fields, - Constructor construtor = {}); + TVM_DLL ConstructorValue(int32_t tag, tvm::Array fields, Constructor construtor = {}); TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj); }; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index fa47da2..1284515 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -25,8 +25,8 @@ #define TVM_RELAY_OP_H_ #include -#include #include +#include namespace tvm { namespace relay { @@ -34,8 +34,7 @@ namespace relay { using Op = tvm::Op; using OpNode = tvm::OpNode; -#define RELAY_REGISTER_OP(OpName) \ - TVM_REGISTER_OP(OpName) +#define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName) } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 5b2fdd3..b3e70f5 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -24,21 +24,22 @@ #ifndef TVM_RELAY_OP_ATTR_TYPES_H_ #define TVM_RELAY_OP_ATTR_TYPES_H_ -#include -#include -#include #include -#include +#include #include +#include +#include +#include #include + #include namespace tvm { namespace relay { +using tir::BijectiveLayoutNode; using tir::Layout; using tir::LayoutAxis; -using tir::BijectiveLayoutNode; /*! \brief operator pattern used in graph fusion */ enum OpPatternKind { @@ -104,10 +105,8 @@ using TShapeDataDependant = bool; & these are always placeholders. * \return The output compute description of the operator. */ -using FTVMCompute = runtime::TypedPackedFunc< - Array(const Attrs& attrs, - const Array& inputs, - const Type& out_type)>; +using FTVMCompute = runtime::TypedPackedFunc( + const Attrs& attrs, const Array& inputs, const Type& out_type)>; /*! * \brief Build the computation schedule for @@ -118,10 +117,8 @@ using FTVMCompute = runtime::TypedPackedFunc< * \param target The build target. * \return schedule The computation schedule. */ -using FTVMSchedule = runtime::TypedPackedFunc< - te::Schedule(const Attrs& attrs, - const Array& outs, - const Target& target)>; +using FTVMSchedule = runtime::TypedPackedFunc& outs, const Target& target)>; /*! * \brief Generate the strategy of operators. This function is a generic @@ -143,11 +140,9 @@ using FTVMStrategy = GenericFunc; * and dtype of the inputs. * \return new_expr The modified expression. */ -using FTVMAlterOpLayout = runtime::TypedPackedFunc< - Expr(const Attrs& attrs, - const Array& args, - const Array& tinfos, - const Type& out_type)>; +using FTVMAlterOpLayout = + runtime::TypedPackedFunc& args, + const Array& tinfos, const Type& out_type)>; /*! * \brief Convert the layout of operators or replace the @@ -160,11 +155,9 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc< * \param desired_layout The desired layout. * \return new_expr The modified expression. */ -using FTVMConvertOpLayout = runtime::TypedPackedFunc< - Expr(const Attrs& attrs, - const Array& args, - const Array& tinfos, - const std::string& desired_layout)>; +using FTVMConvertOpLayout = runtime::TypedPackedFunc& args, const Array& tinfos, + const std::string& desired_layout)>; /*! * \brief Legalizes an expression with another expression. This function will be * invoked in Legalize pass. It is a target-dependent pass. @@ -174,10 +167,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc< * and dtype of the inputs. * \return new_expr The modified expression. */ -using FTVMLegalize = runtime::TypedPackedFunc< - Expr(const Attrs& attrs, - const Array& args, - const Array& arg_types)>; +using FTVMLegalize = runtime::TypedPackedFunc& args, + const Array& arg_types)>; /*! * \brief Annotates an expression to indicate if an op should be compiled using @@ -189,9 +180,8 @@ using FTVMLegalize = runtime::TypedPackedFunc< * \return true if this op should be registered to invoke a specific compiler * for codegen, otherwise, false. */ -using FTVMAnnotateTarget = runtime::TypedPackedFunc< - bool(const Attrs& attrs, // NOLINT(*) - const Array& args)>; +using FTVMAnnotateTarget = runtime::TypedPackedFunc& args)>; /*! * \brief Forward rewriting rule for a specific op. @@ -207,10 +197,8 @@ using FTVMAnnotateTarget = runtime::TypedPackedFunc< * \note When we register the function, we can register * a different signature with ctx to be a specific node type. */ -using FForwardRewrite = runtime::TypedPackedFunc< - Expr(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx)>; +using FForwardRewrite = runtime::TypedPackedFunc& new_args, const ObjectRef& ctx)>; /*! * \brief Gradient for a specific op. @@ -219,8 +207,8 @@ using FForwardRewrite = runtime::TypedPackedFunc< * \param output_grad the gradient of the Expr. * \return the gradient for each parameters. */ -using FPrimalGradient = runtime::TypedPackedFunc(const Expr& orig_call, - const Expr& output_grad)>; +using FPrimalGradient = + runtime::TypedPackedFunc(const Expr& orig_call, const Expr& output_grad)>; /*! * \brief The codegeneration strategy for dynamic dimensions. @@ -233,10 +221,8 @@ enum AnyCodegenStrategy { /*! \brief A runtime representation of shape. */ using Shape = Array; -using FShapeFunc = runtime::TypedPackedFunc< - Array(const Attrs& attrs, - const Array& inputs, - const Array& out_ndims)>; +using FShapeFunc = runtime::TypedPackedFunc( + const Attrs& attrs, const Array& inputs, const Array& out_ndims)>; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/op_strategy.h b/include/tvm/relay/op_strategy.h index a4da95a..3f5876d 100644 --- a/include/tvm/relay/op_strategy.h +++ b/include/tvm/relay/op_strategy.h @@ -25,11 +25,12 @@ #ifndef TVM_RELAY_OP_STRATEGY_H_ #define TVM_RELAY_OP_STRATEGY_H_ -#include -#include #include #include #include +#include +#include + #include namespace tvm { @@ -70,8 +71,7 @@ class OpImplementation : public ObjectRef { * \param out_type The output type information. * \return The output compute description of the operator. */ - TVM_DLL Array Compute(const Attrs& attrs, - const Array& inputs, + TVM_DLL Array Compute(const Attrs& attrs, const Array& inputs, const Type& out_type); /*! * \brief Build the computation schedule. @@ -80,8 +80,7 @@ class OpImplementation : public ObjectRef { * \param target The build target. * \return The computation schedule. */ - TVM_DLL te::Schedule Schedule(const Attrs& attrs, - const Array& outs, + TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array& outs, const Target& target); TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode); @@ -119,8 +118,8 @@ class OpSpecialization : public ObjectRef { * \param name Name of the implementation * \param plevel Priority level of the implementation */ - TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, - std::string name, int plevel); + TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name, + int plevel); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode); }; @@ -133,9 +132,7 @@ class OpStrategyNode : public Object { /*! \brief List of operator specializations. */ Array specializations; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("specializations", &specializations); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); } static constexpr const char* _type_key = "relay.OpStrategy"; TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode); @@ -153,8 +150,8 @@ class OpStrategy : public ObjectRef { * \param name Name of the implementation * \param plevel Priority level of the implementation */ - TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, - std::string name, int plevel); + TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name, + int plevel); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode); }; diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 6e0fb17..ada69c6 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -25,16 +25,16 @@ #ifndef TVM_RELAY_PATTERN_FUNCTOR_H_ #define TVM_RELAY_PATTERN_FUNCTOR_H_ -#include #include +#include #include -#include #include +#include +#include "./adt.h" #include "./expr.h" #include "./op.h" -#include "./adt.h" namespace tvm { namespace relay { @@ -54,15 +54,13 @@ template class PatternFunctor; // functions to be overriden. -#define PATTERN_FUNCTOR_DEFAULT \ +#define PATTERN_FUNCTOR_DEFAULT \ { return VisitPatternDefault_(op, std::forward(args)...); } -#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitPattern_(static_cast(n.get()), \ - std::forward(args)...); \ - }); +#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast(n.get()), std::forward(args)...); \ + }); template class PatternFunctor { @@ -96,14 +94,10 @@ class PatternFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitPattern_(const PatternWildcardNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPattern_(const PatternVarNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPattern_(const PatternConstructorNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPattern_(const PatternTupleNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternWildcardNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternVarNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternConstructorNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPatternDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; @@ -144,8 +138,7 @@ class PatternVisitor : public ::tvm::relay::PatternFunctor { +class PatternMutator : public ::tvm::relay::PatternFunctor { public: Pattern Mutate(const Pattern& pat); Pattern VisitPattern_(const PatternWildcardNode* op) override; @@ -163,6 +156,7 @@ class PatternMutator virtual Var VisitVar(const Var& v); /*! \brief Used to visit the vars inside of patterns. */ virtual Constructor VisitConstructor(const Constructor& c); + private: std::unordered_map var_map_; }; diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 3c1c4a3..4b5cd89 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -25,6 +25,7 @@ #define TVM_RELAY_QNN_ATTRS_H_ #include + #include namespace tvm { @@ -39,19 +40,20 @@ struct RequantizeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { TVM_ATTR_FIELD(axis) - .describe("The output channel axis for channel wise quantization. Default value is -1," - "which corresponds to the last axis.") - .set_default(-1); - TVM_ATTR_FIELD(rounding).set_default("UPWARD") - .describe("Defines the rounding direction when the value is midway between" - "two representable values. There are two supported modes - UPWARD" - "or TONEAREST. Both modes behave exactly same except at the" - "midpoints between the two representable values. At the midpoint," - "UPWARD rounds towards positive infinity (for example -1.5 will be" - "rounded to -1). TONEAREST is the standard rounding where the" - "value is rounded away from zero at midpoints (for example, -1.5" - "rounds to -2). More context can be found at following gblic manual" - "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."); + .describe( + "The output channel axis for channel wise quantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); + TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe( + "Defines the rounding direction when the value is midway between" + "two representable values. There are two supported modes - UPWARD" + "or TONEAREST. Both modes behave exactly same except at the" + "midpoints between the two representable values. At the midpoint," + "UPWARD rounds towards positive infinity (for example -1.5 will be" + "rounded to -1). TONEAREST is the standard rounding where the" + "value is rounded away from zero at midpoints (for example, -1.5" + "rounds to -2). More context can be found at following gblic manual" + "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); @@ -64,12 +66,12 @@ struct QuantizeAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") { - TVM_ATTR_FIELD(out_dtype) - .describe("Output data type, can be one of [int8 or uint8]."); + TVM_ATTR_FIELD(out_dtype).describe("Output data type, can be one of [int8 or uint8]."); TVM_ATTR_FIELD(axis) - .describe("The output channel axis for channel wise quantization. Default value is -1," - "which corresponds to the last axis.") - .set_default(-1); + .describe( + "The output channel axis for channel wise quantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); } }; diff --git a/include/tvm/relay/qnn/transform.h b/include/tvm/relay/qnn/transform.h index 10cd19a..d1f07c9 100644 --- a/include/tvm/relay/qnn/transform.h +++ b/include/tvm/relay/qnn/transform.h @@ -25,8 +25,8 @@ #ifndef TVM_RELAY_QNN_TRANSFORM_H_ #define TVM_RELAY_QNN_TRANSFORM_H_ -#include #include +#include namespace tvm { namespace relay { diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index dc4097a..461276b 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -24,13 +24,13 @@ #ifndef TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_ -#include -#include #include +#include #include #include -#include #include +#include +#include #include @@ -56,11 +56,9 @@ using Sequential = tvm::transform::Sequential; * * \return The created function pass. */ -TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< - Function(Function, IRModule, PassContext)>& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); +TVM_DLL Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, const std::string& name, const tvm::Array& required); /*! \brief Remove expressions which does not effect the program result. * @@ -79,17 +77,17 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< TVM_DLL Pass DeadCodeElimination(bool inline_once = false); /*! -* \brief Convert all expressions of TensorType into GradCell, -* an algebraic data type defined in gradient.rly. -* -* This will delay or decrease memory usage. All calls to -* ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory, -* rather only instantiate if needed. It also defines + and * operation -* between GradCell types which can increase performance when using -* zero-filled or one-filled tensors, which is the case in reverse mode ad. -* -* \return the pass -*/ + * \brief Convert all expressions of TensorType into GradCell, + * an algebraic data type defined in gradient.rly. + * + * This will delay or decrease memory usage. All calls to + * ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory, + * rather only instantiate if needed. It also defines + and * operation + * between GradCell types which can increase performance when using + * zero-filled or one-filled tensors, which is the case in reverse mode ad. + * + * \return the pass + */ TVM_DLL Pass LazyGradientInit(); /*! @@ -373,9 +371,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); * \return A type checked Function with its checked_type field populated. * \note this function mutates mod and is not thread-safe. */ -TVM_DLL Function InferType(const Function& f, - const IRModule& mod, - const GlobalVar& var); +TVM_DLL Function InferType(const Function& f, const IRModule& mod, const GlobalVar& var); /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. This @@ -389,8 +385,7 @@ TVM_DLL Function InferType(const Function& f, * an Expr consumed by multiple callers. * \return The rewritten expression. */ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_attr_name, +TVM_DLL Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_attr_name, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); @@ -406,8 +401,7 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr, * * \return The rewritten expression. */ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, +TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index e8f402a..105f74e 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -24,18 +24,18 @@ #ifndef TVM_RELAY_TYPE_H_ #define TVM_RELAY_TYPE_H_ -#include +#include +#include #include +#include #include -#include #include -#include #include + #include #include "base.h" - namespace tvm { namespace relay { diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index abfc792..741b280 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -45,11 +45,8 @@ extern "C" { * * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. */ -typedef int (*TVMBackendPackedCFunc)(TVMValue* args, - int* type_codes, - int num_args, - TVMValue* out_ret_value, - int* out_ret_tcode); +typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args, + TVMValue* out_ret_value, int* out_ret_tcode); /*! * \brief Backend function for modules to get function @@ -61,9 +58,7 @@ typedef int (*TVMBackendPackedCFunc)(TVMValue* args, * \param out The result function. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, - const char* func_name, - TVMFunctionHandle *out); +TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out); /*! * \brief Backend function to register system-wide library symbol. * @@ -87,11 +82,8 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); * certain backends such as OpenGL. * \return nullptr when error is thrown, a valid ptr if success */ -TVM_DLL void* TVMBackendAllocWorkspace(int device_type, - int device_id, - uint64_t nbytes, - int dtype_code_hint, - int dtype_bits_hint); +TVM_DLL void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, + int dtype_code_hint, int dtype_bits_hint); /*! * \brief Backend function to free temporal workspace. @@ -103,9 +95,7 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type, * * \sa TVMBackendAllocWorkspace */ -TVM_DLL int TVMBackendFreeWorkspace(int device_type, - int device_id, - void* ptr); +TVM_DLL int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr); /*! * \brief Environment for TVM parallel task. @@ -125,8 +115,7 @@ typedef struct { * \param penv The parallel environment backs the execution. * \param cdata The supporting closure data. */ -typedef int (*FTVMParallelLambda)( - int task_id, TVMParallelGroupEnv* penv, void* cdata); +typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata); /*! * \brief Backend function for running parallel jobs. @@ -138,9 +127,7 @@ typedef int (*FTVMParallelLambda)( * * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, - void* cdata, - int num_task); +TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task); /*! * \brief BSP barrrier between parallel threads @@ -150,7 +137,6 @@ TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, */ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); - /*! * \brief Simple static initialization function. * Run f once and set handle to be not null. @@ -162,10 +148,7 @@ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); * \param nbytes Number of bytes in the closure data. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendRunOnce(void** handle, - int (*f)(void*), - void *cdata, - int nbytes); +TVM_DLL int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes); #ifdef __cplusplus } // TVM_EXTERN_C diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 5d371ee..bb38ad8 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -63,15 +63,14 @@ // TVM version #define TVM_VERSION "0.7.dev1" - // TVM Runtime is DLPack compatible. #include #ifdef __cplusplus extern "C" { #endif -#include #include +#include /*! \brief type of array index. */ typedef int64_t tvm_index_t; @@ -180,7 +179,7 @@ TVM_DLL void TVMAPISetLastError(const char* msg); * this function is threadsafe and can be called by different thread * \return error info */ -TVM_DLL const char *TVMGetLastError(void); +TVM_DLL const char* TVMGetLastError(void); /*! * \brief Load module from file. * \param file_name The file name to load the module from. @@ -191,9 +190,7 @@ TVM_DLL const char *TVMGetLastError(void); * \note The resulting module do not contain import relation. * It can be reconstructed by TVMModImport. */ -TVM_DLL int TVMModLoadFromFile(const char* file_name, - const char* format, - TVMModuleHandle* out); +TVM_DLL int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out); /*! * \brief Add dep to mod's dependency. @@ -203,8 +200,7 @@ TVM_DLL int TVMModLoadFromFile(const char* file_name, * \param dep The dependent module to be imported. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMModImport(TVMModuleHandle mod, - TVMModuleHandle dep); +TVM_DLL int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep); /*! * \brief Get function from the module. @@ -214,10 +210,8 @@ TVM_DLL int TVMModImport(TVMModuleHandle mod, * \param out The result function, can be NULL if it is not available. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, - const char* func_name, - int query_imports, - TVMFunctionHandle *out); +TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, + TVMFunctionHandle* out); /*! * \brief Free the Module @@ -259,12 +253,8 @@ TVM_DLL int TVMFuncFree(TVMFunctionHandle func); * The front-end need to call free function (e.g. TVMFuncFree) * to free these handles. */ -TVM_DLL int TVMFuncCall(TVMFunctionHandle func, - TVMValue* arg_values, - int* type_codes, - int num_args, - TVMValue* ret_val, - int* ret_type_code); +TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args, + TVMValue* ret_val, int* ret_type_code); /*! * \brief Set the return value of TVMPackedCFunc. @@ -277,10 +267,7 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func, * \param type_code The type of the value to be returned. * \param num_ret Number of return values, for now only 1 is supported. */ -TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, - TVMValue* value, - int* type_code, - int num_ret); +TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret); /*! * \brief Inplace translate callback argument value to return value. @@ -305,12 +292,8 @@ TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code); * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. * \sa TVMCFuncSetReturn */ -typedef int (*TVMPackedCFunc)( - TVMValue* args, - int* type_codes, - int num_args, - TVMRetValueHandle ret, - void* resource_handle); +typedef int (*TVMPackedCFunc)(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, + void* resource_handle); /*! * \brief C callback to free the resource handle in C packed function. @@ -340,10 +323,8 @@ typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle); * \param out the result function handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, - void* resource_handle, - TVMPackedCFuncFinalizer fin, - TVMFunctionHandle *out); +TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, + TVMPackedCFuncFinalizer fin, TVMFunctionHandle* out); /*! * \brief Register the function to runtime's global table. @@ -354,8 +335,7 @@ TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, * \param f The function to be registered. * \param override Whether allow override already registered function. */ -TVM_DLL int TVMFuncRegisterGlobal( - const char* name, TVMFunctionHandle f, int override); +TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override); /*! * \brief Get a global function. @@ -374,8 +354,7 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); * \param out_array The array of function names. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMFuncListGlobalNames(int* out_size, - const char*** out_array); +TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array); // Array related apis for quick proptyping /*! @@ -392,14 +371,8 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size, * \param out The output handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, - int ndim, - int dtype_code, - int dtype_bits, - int dtype_lanes, - int device_type, - int device_id, - TVMArrayHandle* out); +TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, + int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out); /*! * \brief Free the TVM Array. @@ -415,9 +388,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle); * \param nbytes The number of bytes to copy. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, - void* data, - size_t nbytes); +TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes); /*! * \brief Copy array data to CPU byte array. @@ -426,9 +397,7 @@ TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, * \param nbytes The number of bytes to copy. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, - void* data, - size_t nbytes); +TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes); /*! * \brief Copy the array, both from and to must be valid during the copy. @@ -437,9 +406,7 @@ TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, * \param stream The stream where the copy happens, can be NULL. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, - TVMArrayHandle to, - TVMStreamHandle stream); +TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream); /*! * \brief Produce an array from the DLManagedTensor that shares data memory @@ -448,8 +415,7 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, * \param out The output array handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, - TVMArrayHandle* out); +TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out); /*! * \brief Produce a DLMangedTensor from the array that shares data memory with @@ -458,8 +424,7 @@ TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, * \param out The DLManagedTensor handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, - DLManagedTensor** out); +TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out); /*! * \brief Delete (free) a DLManagedTensor's data. @@ -519,9 +484,7 @@ TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle strea * \param dst The destination stream to synchronize. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMStreamStreamSynchronize(int device_type, - int device_id, - TVMStreamHandle src, +TVM_DLL int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, TVMStreamHandle dst); /*! @@ -561,11 +524,8 @@ TVM_DLL int TVMObjectFree(TVMObjectHandle obj); * \param out_data The allocated device pointer. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx, - size_t nbytes, - size_t alignment, - DLDataType type_hint, - void** out_data); +TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint, void** out_data); /*! * \brief Free a data space on device. @@ -589,14 +549,9 @@ TVM_DLL int TVMDeviceFreeDataSpace(TVMContext ctx, void* ptr); * \param stream Optional stream object. * \return 0 when success, -1 when failure happens. */ -TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t num_bytes, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, +TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t num_bytes, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream); #ifdef __cplusplus diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index cdb92ba..49c005e 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -39,8 +39,7 @@ // string_view: // https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations // https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros -#if defined(__cpp_lib_experimental_string_view) && \ - __cpp_lib_experimental_string_view >= 201411 +#if defined(__cpp_lib_experimental_string_view) && __cpp_lib_experimental_string_view >= 201411 #define TVM_USE_CXX14_STRING_VIEW_HASH 1 #else #define TVM_USE_CXX14_STRING_VIEW_HASH 0 @@ -135,8 +134,7 @@ class InplaceArrayBase { * \brief Destroy the Inplace Array Base object */ ~InplaceArrayBase() { - if (!(std::is_standard_layout::value && - std::is_trivial::value)) { + if (!(std::is_standard_layout::value && std::is_trivial::value)) { size_t size = Self()->GetSize(); for (size_t i = 0; i < size; ++i) { ElemType* fp = reinterpret_cast(AddressOf(i)); @@ -179,10 +177,10 @@ class InplaceArrayBase { * \return Raw pointer to the element. */ void* AddressOf(size_t idx) const { - static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && - sizeof(ArrayType) % alignof(ElemType) == 0, - "The size and alignment of ArrayType should respect " - "ElemType's alignment."); + static_assert( + alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); size_t kDataStart = sizeof(ArrayType); ArrayType* self = Self(); @@ -242,8 +240,7 @@ class ADT : public ObjectRef { * \param fields The fields of the ADT object. * \return The constructed ADT object reference. */ - ADT(int32_t tag, std::vector fields) - : ADT(tag, fields.begin(), fields.end()){}; + ADT(int32_t tag, std::vector fields) : ADT(tag, fields.begin(), fields.end()){}; /*! * \brief construct an ADT object reference. @@ -267,8 +264,7 @@ class ADT : public ObjectRef { * \param init The initializer list of fields. * \return The constructed ADT object reference. */ - ADT(int32_t tag, std::initializer_list init) - : ADT(tag, init.begin(), init.end()){}; + ADT(int32_t tag, std::initializer_list init) : ADT(tag, init.begin(), init.end()){}; /*! * \brief Access element at index. @@ -276,9 +272,7 @@ class ADT : public ObjectRef { * \param idx The array index * \return const ObjectRef */ - const ObjectRef& operator[](size_t idx) const { - return operator->()->operator[](idx); - } + const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); } /*! * \brief Return the ADT tag. @@ -390,9 +384,7 @@ class String : public ObjectRef { * * \return the comparison result */ - bool operator==(const std::string& other) const { - return this->compare(other) == 0; - } + bool operator==(const std::string& other) const { return this->compare(other) == 0; } /*! * \brief Compare is not equal to other std::string @@ -512,11 +504,9 @@ class String : public ObjectRef { // This function falls back to string copy with c++11 compiler and is // recommended to be compiled with c++14 #if TVM_USE_CXX17_STRING_VIEW_HASH - return std::hash()( - std::string_view(data, size)); + return std::hash()(std::string_view(data, size)); #elif TVM_USE_CXX14_STRING_VIEW_HASH - return std::hash()( - std::experimental::string_view(data, size)); + return std::hash()(std::experimental::string_view(data, size)); #else return std::hash()(std::string(data, size)); #endif @@ -538,8 +528,7 @@ class String : public ObjectRef { * \return int zero if both char sequences compare equal. negative if this * appear before other, positive otherwise. */ - static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, - size_t rhs_count); + static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); }; /*! \brief An object representing string moved from std::string. */ @@ -575,8 +564,7 @@ inline String String::operator=(std::string other) { return Downcast(*this); } -inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, - size_t rhs_count) { +inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { if (lhs == rhs && lhs_count == rhs_count) return 0; for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { @@ -592,7 +580,7 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, } } -template<> +template <> struct PackedFuncValueConverter<::tvm::runtime::String> { static String From(const TVMArgValue& val) { if (val.IsObjectRef()) { @@ -612,8 +600,7 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { }; /*! \brief Helper to represent nullptr for optional. */ -struct NullOptType { -}; +struct NullOptType {}; /*! * \brief Optional container that to represent to a Nullable variant of T. @@ -628,12 +615,11 @@ struct NullOptType { * * \endcode */ -template +template class Optional : public ObjectRef { public: using ContainerType = typename T::ContainerType; - static_assert(std::is_base_of::value, - "Optional is only defined for ObjectRef."); + static_assert(std::is_base_of::value, "Optional is only defined for ObjectRef."); // default constructors. Optional() = default; Optional(const Optional&) = default; @@ -656,9 +642,8 @@ class Optional : public ObjectRef { return *this; } // normal value handling. - Optional(T other) // NOLINT(*) - : ObjectRef(std::move(other)) { - } + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} Optional& operator=(T other) { ObjectRef::operator=(std::move(other)); return *this; @@ -680,20 +665,12 @@ class Optional : public ObjectRef { * \return The contained value if the Optional is not null * otherwise return the default_value. */ - T value_or(T default_value) const { - return data_ != nullptr ? T(data_) : default_value; - } + T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } /*! \return Whether the container is not nullptr.*/ - explicit operator bool() const { - return *this != nullptr; - } + explicit operator bool() const { return *this != nullptr; } // operator overloadings - bool operator==(std::nullptr_t) const { - return data_ == nullptr; - } - bool operator!=(std::nullptr_t) const { - return data_ != nullptr; - } + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } auto operator==(const Optional& other) const { // support case where sub-class returns a symbolic ref type. using RetType = decltype(value() == other.value()); @@ -722,16 +699,14 @@ class Optional : public ObjectRef { if (*this != nullptr) return value() == other; return RetType(false); } - auto operator!=(const T& other) const { - return !(*this == other); - } - template + auto operator!=(const T& other) const { return !(*this == other); } + template auto operator==(const U& other) const { using RetType = decltype(value() == other); if (*this == nullptr) return RetType(false); return value() == other; } - template + template auto operator!=(const U& other) const { using RetType = decltype(value() != other); if (*this == nullptr) return RetType(true); @@ -740,7 +715,7 @@ class Optional : public ObjectRef { static constexpr bool _type_is_nullable = true; }; -template +template struct PackedFuncValueConverter> { static Optional From(const TVMArgValue& val) { if (val.type_code() == kTVMNullptr) return Optional(nullptr); @@ -755,8 +730,8 @@ struct PackedFuncValueConverter> { } // namespace runtime // expose the functions to the root namespace. -using runtime::String; using runtime::Optional; +using runtime::String; constexpr runtime::NullOptType NullOpt{}; } // namespace tvm diff --git a/include/tvm/runtime/crt/memory.h b/include/tvm/runtime/crt/memory.h index 3e47060..7b88b31 100644 --- a/include/tvm/runtime/crt/memory.h +++ b/include/tvm/runtime/crt/memory.h @@ -32,7 +32,7 @@ static int vleak_size = 0; * \param size The size of memory * \return The virtual address */ -void * vmalloc(size_t size); +void* vmalloc(size_t size); /*! * \brief Reallocate memory from manager @@ -40,13 +40,13 @@ void * vmalloc(size_t size); * \param size The size of memory * \return The virtual address */ -void * vrealloc(void * ptr, size_t size); +void* vrealloc(void* ptr, size_t size); /*! * \brief Free the memory. * \param ptr The pointer to the memory to deallocate * \return The virtual address */ -void vfree(void * ptr); +void vfree(void* ptr); #endif // TVM_RUNTIME_CRT_MEMORY_H_ diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 940818a..a10b83f 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -24,10 +24,11 @@ #ifndef TVM_RUNTIME_DATA_TYPE_H_ #define TVM_RUNTIME_DATA_TYPE_H_ -#include #include -#include +#include + #include +#include namespace tvm { namespace runtime { @@ -52,8 +53,7 @@ class DataType { * \brief Constructor * \param dtype The DLDataType */ - explicit DataType(DLDataType dtype) - : data_(dtype) {} + explicit DataType(DLDataType dtype) : data_(dtype) {} /*! * \brief Constructor * \param code The type code. @@ -66,110 +66,70 @@ class DataType { data_.lanes = static_cast(lanes); } /*! \return The type code. */ - int code() const { - return static_cast(data_.code); - } + int code() const { return static_cast(data_.code); } /*! \return number of bits in the data. */ - int bits() const { - return static_cast(data_.bits); - } + int bits() const { return static_cast(data_.bits); } /*! \return number of bytes to store each scalar. */ - int bytes() const { - return (bits() + 7) / 8; - } + int bytes() const { return (bits() + 7) / 8; } /*! \return number of lanes in the data. */ - int lanes() const { - return static_cast(data_.lanes); - } + int lanes() const { return static_cast(data_.lanes); } /*! \return whether type is a scalar type. */ - bool is_scalar() const { - return lanes() == 1; - } + bool is_scalar() const { return lanes() == 1; } /*! \return whether type is a scalar type. */ - bool is_bool() const { - return code() == DataType::kUInt && bits() == 1; - } + bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } /*! \return whether type is a float type. */ - bool is_float() const { - return code() == DataType::kFloat; - } + bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a float16 type. */ - bool is_float16() const { - return is_float() && bits() == 16; - } + bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is an int type. */ - bool is_int() const { - return code() == DataType::kInt; - } + bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. */ - bool is_uint() const { - return code() == DataType::kUInt; - } + bool is_uint() const { return code() == DataType::kUInt; } /*! \return whether type is a handle type. */ - bool is_handle() const { - return code() == DataType::kHandle && !is_void(); - } + bool is_handle() const { return code() == DataType::kHandle && !is_void(); } /*! \return whether type is a vector type. */ - bool is_vector() const { - return lanes() > 1; - } + bool is_vector() const { return lanes() > 1; } /*! \return whether type is a bool vector type. */ - bool is_vector_bool() const { - return is_vector() && bits() == 1; - } + bool is_vector_bool() const { return is_vector() && bits() == 1; } /*! \return whether type is a Void type. */ - bool is_void() const { - return code() == DataType::kHandle && bits() == 0 && lanes() == 0; - } + bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. * \return the result type. */ - DataType with_lanes(int lanes) const { - return DataType(data_.code, data_.bits, lanes); - } + DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); } /*! * \brief Create a new data type by change bits to a specified value. * \param bits The target number of bits. * \return the result type. */ - DataType with_bits(int bits) const { - return DataType(data_.code, bits, data_.lanes); - } + DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); } /*! * \brief Get the scalar version of the type. * \return the result type. */ - DataType element_of() const { - return with_lanes(1); - } + DataType element_of() const { return with_lanes(1); } /*! * \brief Equal comparator. * \param other The data type to compre against. * \return The comparison resilt. */ bool operator==(const DataType& other) const { - return - data_.code == other.data_.code && - data_.bits == other.data_.bits && - data_.lanes == other.data_.lanes; + return data_.code == other.data_.code && data_.bits == other.data_.bits && + data_.lanes == other.data_.lanes; } /*! * \brief NotEqual comparator. * \param other The data type to compre against. * \return The comparison resilt. */ - bool operator!=(const DataType& other) const { - return !operator==(other); - } + bool operator!=(const DataType& other) const { return !operator==(other); } /*! * \brief Converter to DLDataType * \return the result. */ - operator DLDataType () const { - return data_; - } + operator DLDataType() const { return data_; } /*! * \brief Construct an int type. @@ -177,51 +137,39 @@ class DataType { * \param lanes The number of lanes. * \return The constructed data type. */ - static DataType Int(int bits, int lanes = 1) { - return DataType(kDLInt, bits, lanes); - } + static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); } /*! * \brief Construct an uint type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType UInt(int bits, int lanes = 1) { - return DataType(kDLUInt, bits, lanes); - } + static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); } /*! * \brief Construct an uint type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Float(int bits, int lanes = 1) { - return DataType(kDLFloat, bits, lanes); - } + static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); } /*! * \brief Construct a bool type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Bool(int lanes = 1) { - return DataType::UInt(1, lanes); - } + static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); } /*! * \brief Construct a handle type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Handle(int bits = 64, int lanes = 1) { - return DataType(kHandle, bits, lanes); - } + static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); } /*! * \brief Construct a Void type. * \return The constructed data type. */ - static DataType Void() { - return DataType(kHandle, 0, 0); - } + static DataType Void() { return DataType(kHandle, 0, 0); } /*! * \brief Get the corresponding type of TVMShapeIndex. * \return The type of TVM shape index. @@ -246,14 +194,11 @@ class DataType { inline int GetVectorBytes(DataType dtype) { int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist - if (dtype == DataType::Bool() || - dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || + if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { return 1; } - CHECK_EQ(data_bits % 8, 0U) - << "Need to load/store by multiple of bytes"; + CHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; return data_bits / 8; } @@ -322,29 +267,46 @@ inline std::string DLDataType2String(DLDataType t); // implementation details inline const char* TypeCode2Str(int type_code) { switch (type_code) { - case kDLInt: return "int"; - case kDLUInt: return "uint"; - case kDLFloat: return "float"; - case kTVMStr: return "str"; - case kTVMBytes: return "bytes"; - case kTVMOpaqueHandle: return "handle"; - case kTVMNullptr: return "NULL"; - case kTVMDLTensorHandle: return "ArrayHandle"; - case kTVMDataType: return "DLDataType"; - case kTVMContext: return "TVMContext"; - case kTVMPackedFuncHandle: return "FunctionHandle"; - case kTVMModuleHandle: return "ModuleHandle"; - case kTVMNDArrayHandle: return "NDArrayContainer"; - case kTVMObjectHandle: return "Object"; - case kTVMObjectRValueRefArg: return "ObjectRValueRefArg"; - default: LOG(FATAL) << "unknown type_code=" - << static_cast(type_code); return ""; + case kDLInt: + return "int"; + case kDLUInt: + return "uint"; + case kDLFloat: + return "float"; + case kTVMStr: + return "str"; + case kTVMBytes: + return "bytes"; + case kTVMOpaqueHandle: + return "handle"; + case kTVMNullptr: + return "NULL"; + case kTVMDLTensorHandle: + return "ArrayHandle"; + case kTVMDataType: + return "DLDataType"; + case kTVMContext: + return "TVMContext"; + case kTVMPackedFuncHandle: + return "FunctionHandle"; + case kTVMModuleHandle: + return "ModuleHandle"; + case kTVMNDArrayHandle: + return "NDArrayContainer"; + case kTVMObjectHandle: + return "Object"; + case kTVMObjectRValueRefArg: + return "ObjectRValueRefArg"; + default: + LOG(FATAL) << "unknown type_code=" << static_cast(type_code); + return ""; } } inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { - os << "bool"; return os; + os << "bool"; + return os; } if (DataType(t).is_void()) { return os << "void"; @@ -362,7 +324,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) return os; } -inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) +inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) return os << dtype.operator DLDataType(); } @@ -380,14 +342,18 @@ inline DLDataType String2DLDataType(std::string s) { t = DataType::Void(); return t; } - t.bits = 32; t.lanes = 1; + t.bits = 32; + t.lanes = 1; const char* scan; if (s.substr(0, 3) == "int") { - t.code = kDLInt; scan = s.c_str() + 3; + t.code = kDLInt; + scan = s.c_str() + 3; } else if (s.substr(0, 4) == "uint") { - t.code = kDLUInt; scan = s.c_str() + 4; + t.code = kDLUInt; + scan = s.c_str() + 4; } else if (s.substr(0, 5) == "float") { - t.code = kDLFloat; scan = s.c_str() + 5; + t.code = kDLFloat; + scan = s.c_str() + 5; } else if (s.substr(0, 6) == "handle") { t.code = kTVMOpaqueHandle; t.bits = 64; // handle uses 64 bit by default. diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 4ccaa3c..7fb2f9d 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -85,9 +86,7 @@ class TVM_DLL DeviceAPI { * as OpenGL, as nbytes & alignment are sufficient for most backends. * \return The allocated device pointer. */ - virtual void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + virtual void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) = 0; /*! * \brief Free a data space on device. @@ -108,16 +107,10 @@ class TVM_DLL DeviceAPI { * can be useful for cross device endian converison. * \param stream Optional stream object. */ - virtual void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t num_bytes, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, - TVMStreamHandle stream) = 0; - /*! + virtual void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) = 0; + /*! * \brief Create a new stream of execution. * * \param ctx The context of allocation. @@ -156,10 +149,9 @@ class TVM_DLL DeviceAPI { * \param event_src The source stream to synchronize. * \param event_dst The destination stream to synchronize. */ - virtual void SyncStreamFromTo(TVMContext ctx, - TVMStreamHandle event_src, + virtual void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst); - /*! + /*! * \brief Allocate temporal workspace for backend execution. * * \note We have the following assumption about backend temporal @@ -175,9 +167,7 @@ class TVM_DLL DeviceAPI { * \param type_hint The type of elements. Only needed by certain backends such * as OpenGL, as nbytes is sufficient for most backends. */ - virtual void* AllocWorkspace(TVMContext ctx, - size_t nbytes, - DLDataType type_hint = {}); + virtual void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}); /*! * \brief Free temporal workspace in backend execution. * @@ -214,22 +204,39 @@ constexpr int kRPCSessMask = 128; */ inline const char* DeviceName(int type) { switch (type) { - case kDLCPU: return "cpu"; - case kDLGPU: return "gpu"; - case kDLCPUPinned: return "cpu_pinned"; - case kDLOpenCL: return "opencl"; - case kDLSDAccel: return "sdaccel"; - case kDLAOCL: return "aocl"; - case kDLVulkan: return "vulkan"; - case kDLMetal: return "metal"; - case kDLVPI: return "vpi"; - case kDLROCM: return "rocm"; - case kOpenGL: return "opengl"; - case kDLExtDev: return "ext_dev"; - case kDLWebGPU: return "webgpu"; - case kDLMicroDev: return "micro_dev"; - case kDLHexagon: return "hexagon"; - default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; + case kDLCPU: + return "cpu"; + case kDLGPU: + return "gpu"; + case kDLCPUPinned: + return "cpu_pinned"; + case kDLOpenCL: + return "opencl"; + case kDLSDAccel: + return "sdaccel"; + case kDLAOCL: + return "aocl"; + case kDLVulkan: + return "vulkan"; + case kDLMetal: + return "metal"; + case kDLVPI: + return "vpi"; + case kDLROCM: + return "rocm"; + case kOpenGL: + return "opengl"; + case kDLExtDev: + return "ext_dev"; + case kDLWebGPU: + return "webgpu"; + case kDLMicroDev: + return "micro_dev"; + case kDLHexagon: + return "hexagon"; + default: + LOG(FATAL) << "unknown type =" << type; + return "Unknown"; } } diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index b9b420a..1199c42 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -24,9 +24,10 @@ #define TVM_RUNTIME_MEMORY_H_ #include + #include -#include #include +#include namespace tvm { namespace runtime { @@ -36,7 +37,7 @@ namespace runtime { * \tparam T the node type. * \return The ObjectPtr to the allocated object. */ -template +template inline ObjectPtr make_object(Args&&... args); // Detail implementations after this @@ -55,7 +56,7 @@ inline ObjectPtr make_object(Args&&... args); * * \tparam Derived The derived class. */ -template +template class ObjAllocatorBase { public: /*! @@ -64,13 +65,11 @@ class ObjAllocatorBase { * \tparam Args The constructor signature. * \param args The arguments. */ - template + template inline ObjectPtr make_object(Args&&... args) { using Handler = typename Derived::template Handler; - static_assert(std::is_base_of::value, - "make can only be used to create Object"); - T* ptr = Handler::New(static_cast(this), - std::forward(args)...); + static_assert(std::is_base_of::value, "make can only be used to create Object"); + T* ptr = Handler::New(static_cast(this), std::forward(args)...); ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); @@ -83,14 +82,13 @@ class ObjAllocatorBase { * \param num_elems The number of array elements. * \param args The arguments. */ - template + template inline ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { using Handler = typename Derived::template ArrayHandler; static_assert(std::is_base_of::value, "make_inplace_array can only be used to create Object"); - ArrayType* ptr = Handler::New(static_cast(this), - num_elems, - std::forward(args)...); + ArrayType* ptr = + Handler::New(static_cast(this), num_elems, std::forward(args)...); ptr->type_index_ = ArrayType::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); @@ -98,15 +96,14 @@ class ObjAllocatorBase { }; // Simple allocator that uses new/delete. -class SimpleObjAllocator : - public ObjAllocatorBase { +class SimpleObjAllocator : public ObjAllocatorBase { public: - template + template class Handler { public: using StorageType = typename std::aligned_storage::type; - template + template static T* New(SimpleObjAllocator*, Args&&... args) { // NOTE: the first argument is not needed for SimpleObjAllocator // It is reserved for special allocators that needs to recycle @@ -126,9 +123,7 @@ class SimpleObjAllocator : return reinterpret_cast(data); } - static Object::FDeleter Deleter() { - return Deleter_; - } + static Object::FDeleter Deleter() { return Deleter_; } private: static void Deleter_(Object* objptr) { @@ -146,16 +141,16 @@ class SimpleObjAllocator : }; // Array handler that uses new/delete. - template + template class ArrayHandler { public: using StorageType = typename std::aligned_storage::type; // for now only support elements that aligns with array header. static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && - sizeof(ArrayType) % alignof(ElemType) == 0, + sizeof(ArrayType) % alignof(ElemType) == 0, "element alignment constraint"); - template + template static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { // NOTE: the first argument is not needed for ArrayObjAllocator // It is reserved for special allocators that needs to recycle @@ -177,9 +172,7 @@ class SimpleObjAllocator : return reinterpret_cast(data); } - static Object::FDeleter Deleter() { - return Deleter_; - } + static Object::FDeleter Deleter() { return Deleter_; } private: static void Deleter_(Object* objptr) { @@ -193,20 +186,20 @@ class SimpleObjAllocator : // call a virtual destructor(which may not be available and is not required). tptr->ArrayType::~ArrayType(); StorageType* p = reinterpret_cast(tptr); - delete []p; + delete[] p; } }; }; -template +template inline ObjectPtr make_object(Args&&... args) { return SimpleObjAllocator().make_object(std::forward(args)...); } -template +template inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&... args) { - return SimpleObjAllocator().make_inplace_array( - num_elems, std::forward(args)...); + return SimpleObjAllocator().make_inplace_array(num_elems, + std::forward(args)...); } } // namespace runtime diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 3c43ae0..0e7cd2b 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -27,15 +27,14 @@ #define TVM_RUNTIME_MODULE_H_ #include - #include -#include #include +#include #include -#include #include #include +#include namespace tvm { namespace runtime { @@ -50,8 +49,7 @@ class Module : public ObjectRef { public: Module() {} // constructor from container. - explicit Module(ObjectPtr n) - : ObjectRef(n) {} + explicit Module(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Get packed function from current module by name. * @@ -82,8 +80,7 @@ class Module : public ObjectRef { * \note This function won't load the import relationship. * Re-create import relationship by calling Import. */ - TVM_DLL static Module LoadFromFile(const std::string& file_name, - const std::string& format = ""); + TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = ""); // refer to the corresponding container. using ContainerType = ModuleNode; friend class ModuleNode; @@ -137,16 +134,14 @@ class TVM_DLL ModuleNode : public Object { * If the function need resource from the module(e.g. late linking), * it should capture sptr_to_self. */ - virtual PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) = 0; + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) = 0; /*! * \brief Save the module to file. * \param file_name The file to be saved to. * \param format The format of the file. */ - virtual void SaveToFile(const std::string& file_name, - const std::string& format); + virtual void SaveToFile(const std::string& file_name, const std::string& format); /*! * \brief Save the module to binary stream. * \param stream The binary stream to save to. @@ -188,9 +183,7 @@ class TVM_DLL ModuleNode : public Object { */ const PackedFunc* GetFuncFromEnv(const std::string& name); /*! \return The module it imports from */ - const std::vector& imports() const { - return imports_; - } + const std::vector& imports() const { return imports_; } // integration with the existing components. static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule; @@ -207,8 +200,7 @@ class TVM_DLL ModuleNode : public Object { private: /*! \brief Cache used by GetImport */ - std::unordered_map > import_cache_; + std::unordered_map > import_cache_; }; /*! @@ -238,13 +230,9 @@ constexpr const char* tvm_module_main = "__tvm_main__"; // implementations of inline functions. -inline void Module::Import(Module other) { - return (*this)->Import(other); -} +inline void Module::Import(Module other) { return (*this)->Import(other); } -inline ModuleNode* Module::operator->() { - return static_cast(get_mutable()); -} +inline ModuleNode* Module::operator->() { return static_cast(get_mutable()); } inline const ModuleNode* Module::operator->() const { return static_cast(get()); @@ -254,4 +242,4 @@ inline const ModuleNode* Module::operator->() const { } // namespace tvm #include // NOLINT(*) -#endif // TVM_RUNTIME_MODULE_H_ +#endif // TVM_RUNTIME_MODULE_H_ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 33f27f4..8db93b4 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -29,8 +29,8 @@ #include #include -#include #include +#include namespace tvm { namespace runtime { @@ -53,8 +53,7 @@ class NDArray : public ObjectRef { * \brief constructor. * \param data ObjectPtr to the data container. */ - explicit NDArray(ObjectPtr data) - : ObjectRef(data) {} + explicit NDArray(ObjectPtr data) : ObjectRef(data) {} /*! \brief reset the content of NDArray to be nullptr */ inline void reset(); @@ -76,13 +75,13 @@ class NDArray : public ObjectRef { inline void CopyFrom(const DLTensor* other); inline void CopyFrom(const NDArray& other); /*! - * \brief Copy data content from a byte buffer. - * \param data The source bytes to be copied from. - * \param nbytes The size of the buffer in bytes - * Must be equal to the size of the NDArray. - * \note The copy may happen asynchronously if it involves a GPU context. - * TVMSynchronize is necessary. - */ + * \brief Copy data content from a byte buffer. + * \param data The source bytes to be copied from. + * \param nbytes The size of the buffer in bytes + * Must be equal to the size of the NDArray. + * \note The copy may happen asynchronously if it involves a GPU context. + * TVMSynchronize is necessary. + */ TVM_DLL void CopyFromBytes(const void* data, size_t nbytes); /*! * \brief Copy data content into another array. @@ -124,8 +123,7 @@ class NDArray : public ObjectRef { * \param dtype The data type of the new array. * \note The memory size of new array must be smaller than the current one. */ - TVM_DLL NDArray CreateView( - std::vector shape, DLDataType dtype); + TVM_DLL NDArray CreateView(std::vector shape, DLDataType dtype); /*! * \brief Create a reference view of NDArray that * represents as DLManagedTensor. @@ -139,9 +137,7 @@ class NDArray : public ObjectRef { * \param ctx The context of the Array. * \return The created Array */ - TVM_DLL static NDArray Empty(std::vector shape, - DLDataType dtype, - DLContext ctx); + TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); /*! * \brief Create a NDArray backed by a dlpack tensor. * @@ -160,8 +156,8 @@ class NDArray : public ObjectRef { * \param to The target array. * \param stream The stream used in copy. */ - TVM_DLL static void CopyFromTo( - const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); + TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to, + TVMStreamHandle stream = nullptr); TVM_DLL std::vector Shape() const; // internal namespace @@ -244,9 +240,7 @@ class NDArray::ContainerBase { * \brief Object container class that backs NDArray. * \note do not use this function directly, use NDArray. */ -class NDArray::Container : - public Object, - public NDArray::ContainerBase { +class NDArray::Container : public Object, public NDArray::ContainerBase { public: /*! \brief default constructor */ Container() { @@ -259,10 +253,7 @@ class NDArray::Container : dl_tensor.byte_offset = 0; } - Container(void* data, - std::vector shape, - DLDataType dtype, - DLContext ctx) { + Container(void* data, std::vector shape, DLDataType dtype, DLContext ctx) { // Initialize the type index. type_index_ = Container::RuntimeTypeIndex(); dl_tensor.data = data; @@ -278,9 +269,7 @@ class NDArray::Container : * \brief Set the deleter field. * \param deleter The deleter. */ - void SetDeleter(FDeleter deleter) { - deleter_ = deleter; - } + void SetDeleter(FDeleter deleter) { deleter_ = deleter; } // Expose DecRef and IncRef as public function // NOTE: they are only for developer purposes only. @@ -360,53 +349,44 @@ inline void NDArray::CopyTo(const NDArray& other) const { inline NDArray NDArray::CopyTo(const DLContext& ctx) const { CHECK(data_ != nullptr); const DLTensor* dptr = operator->(); - NDArray ret = Empty(std::vector(dptr->shape, dptr->shape + dptr->ndim), - dptr->dtype, ctx); + NDArray ret = + Empty(std::vector(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, ctx); this->CopyTo(ret); return ret; } -inline int NDArray::use_count() const { - return data_.use_count(); -} +inline int NDArray::use_count() const { return data_.use_count(); } -inline const DLTensor* NDArray::operator->() const { - return &(get_mutable()->dl_tensor); -} +inline const DLTensor* NDArray::operator->() const { return &(get_mutable()->dl_tensor); } inline NDArray::Container* NDArray::get_mutable() const { return static_cast(data_.get()); } inline ObjectPtr NDArray::FFIDataFromHandle(TVMArrayHandle handle) { - return GetObjectPtr(static_cast( - reinterpret_cast(handle))); + return GetObjectPtr( + static_cast(reinterpret_cast(handle))); } inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { // NOTE: it is necessary to cast to container then to base // so that the FFI handle uses the ContainerBase address. - return reinterpret_cast( - static_cast( - static_cast( - const_cast(nd.get())))); + return reinterpret_cast(static_cast( + static_cast(const_cast(nd.get())))); } inline void NDArray::FFIDecRef(TVMArrayHandle handle) { - static_cast( - reinterpret_cast(handle))->DecRef(); + static_cast(reinterpret_cast(handle))->DecRef(); } inline Object* TVMArrayHandleToObjectHandle(TVMArrayHandle handle) { - return static_cast( - reinterpret_cast(handle)); + return static_cast(reinterpret_cast(handle)); } /*! \brief Magic number for NDArray file */ constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; -inline bool SaveDLTensor(dmlc::Stream* strm, - const DLTensor* tensor) { +inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) { uint64_t header = kTVMNDArrayMagic, reserved = 0; strm->Write(header); strm->Write(reserved); @@ -435,16 +415,15 @@ inline bool SaveDLTensor(dmlc::Stream* strm, int64_t data_byte_size = type_bytes * num_elems; strm->Write(data_byte_size); - if (DMLC_IO_NO_ENDIAN_SWAP && - tensor->ctx.device_type == kDLCPU && - tensor->strides == nullptr && + if (DMLC_IO_NO_ENDIAN_SWAP && tensor->ctx.device_type == kDLCPU && tensor->strides == nullptr && tensor->byte_offset == 0) { // quick path strm->Write(tensor->data, data_byte_size); } else { std::vector bytes(data_byte_size); - CHECK_EQ(TVMArrayCopyToBytes( - const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size), 0) + CHECK_EQ( + TVMArrayCopyToBytes(const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size), + 0) << TVMGetLastError(); if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); @@ -454,33 +433,23 @@ inline bool SaveDLTensor(dmlc::Stream* strm, return true; } -inline void NDArray::Save(dmlc::Stream* strm) const { - SaveDLTensor(strm, operator->()); -} +inline void NDArray::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->()); } inline bool NDArray::Load(dmlc::Stream* strm) { uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&reserved)) - << "Invalid DLTensor file format"; - CHECK(header == kTVMNDArrayMagic) - << "Invalid DLTensor file format"; + CHECK(strm->Read(&header)) << "Invalid DLTensor file format"; + CHECK(strm->Read(&reserved)) << "Invalid DLTensor file format"; + CHECK(header == kTVMNDArrayMagic) << "Invalid DLTensor file format"; DLContext ctx; int ndim; DLDataType dtype; - CHECK(strm->Read(&ctx)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&ndim)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&dtype)) - << "Invalid DLTensor file format"; - CHECK_EQ(ctx.device_type, kDLCPU) - << "Invalid DLTensor context: can only save as CPU tensor"; + CHECK(strm->Read(&ctx)) << "Invalid DLTensor file format"; + CHECK(strm->Read(&ndim)) << "Invalid DLTensor file format"; + CHECK(strm->Read(&dtype)) << "Invalid DLTensor file format"; + CHECK_EQ(ctx.device_type, kDLCPU) << "Invalid DLTensor context: can only save as CPU tensor"; std::vector shape(ndim); if (ndim != 0) { - CHECK(strm->ReadArray(&shape[0], ndim)) - << "Invalid DLTensor file format"; + CHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format"; } NDArray ret = NDArray::Empty(shape, dtype, ctx); int64_t num_elems = 1; @@ -489,12 +458,9 @@ inline bool NDArray::Load(dmlc::Stream* strm) { num_elems *= ret->shape[i]; } int64_t data_byte_size; - CHECK(strm->Read(&data_byte_size)) - << "Invalid DLTensor file format"; - CHECK(data_byte_size == num_elems * elem_bytes) - << "Invalid DLTensor file format"; - CHECK(strm->Read(ret->data, data_byte_size)) - << "Invalid DLTensor file format"; + CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format"; + CHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format"; + CHECK(strm->Read(ret->data, data_byte_size)) << "Invalid DLTensor file format"; if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(ret->data, elem_bytes, num_elems); } diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 764dcdf..7d912c5 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -25,8 +25,9 @@ #include #include -#include + #include +#include #include /*! @@ -100,8 +101,8 @@ struct TypeIndex { * Recommendation: set to estimate number of children needed. * - _type_child_slots_can_overflow: * Whether we can add additional child classes even if the number of child classes - * exceeds the _type_child_slots. A fallback mechanism to check global type table will be used. - * Recommendation: set to false for optimal runtime speed if we know exact number of children. + * exceeds the _type_child_slots. A fallback mechanism to check global type table will be + * used. Recommendation: set to false for optimal runtime speed if we know exact number of children. * * Two macros are used to declare helper functions in the object: * - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. @@ -163,28 +164,22 @@ class Object { */ typedef void (*FDeleter)(Object* self); /*! \return The internal runtime type index of the object. */ - uint32_t type_index() const { - return type_index_; - } + uint32_t type_index() const { return type_index_; } /*! * \return the type key of the object. * \note this operation is expensive, can be used for error reporting. */ - std::string GetTypeKey() const { - return TypeIndex2Key(type_index_); - } + std::string GetTypeKey() const { return TypeIndex2Key(type_index_); } /*! * \return A hash value of the return of GetTypeKey. */ - size_t GetTypeKeyHash() const { - return TypeIndex2KeyHash(type_index_); - } + size_t GetTypeKeyHash() const { return TypeIndex2KeyHash(type_index_); } /*! * Check if the object is an instance of TargetType. * \tparam TargetType The target type to be checked. * \return Whether the target type is true. */ - template + template inline bool IsInstance() const; /*! @@ -214,12 +209,8 @@ class Object { static constexpr const char* _type_key = "runtime.Object"; - static uint32_t _GetOrAllocRuntimeTypeIndex() { - return TypeIndex::kRoot; - } - static uint32_t RuntimeTypeIndex() { - return TypeIndex::kRoot; - } + static uint32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kRoot; } + static uint32_t RuntimeTypeIndex() { return TypeIndex::kRoot; } // Default object type properties for sub-classes static constexpr bool _type_final = false; @@ -234,7 +225,6 @@ class Object { // The type index of Object is TypeIndex::kRoot static constexpr uint32_t _type_index = TypeIndex::kDynamic; - // Default constructor and copy constructor Object() {} // Override the copy and assign constructors to do nothing. @@ -246,10 +236,10 @@ class Object { } Object(Object&& other) { // NOLINT(*) } - Object& operator=(const Object& other) { //NOLINT(*) + Object& operator=(const Object& other) { // NOLINT(*) return *this; } - Object& operator=(Object&& other) { //NOLINT(*) + Object& operator=(Object&& other) { // NOLINT(*) return *this; } @@ -267,7 +257,7 @@ class Object { FDeleter deleter_ = nullptr; // Invariant checks. static_assert(sizeof(int32_t) == sizeof(RefCounterType) && - alignof(int32_t) == sizeof(RefCounterType), + alignof(int32_t) == sizeof(RefCounterType), "RefCounter ABI check."); /*! @@ -287,12 +277,10 @@ class Object { * \param type_child_slots_can_overflow Whether to allow child to overflow the slots. * \return The allocated type index. */ - TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex( - const std::string& key, - uint32_t static_tindex, - uint32_t parent_tindex, - uint32_t type_child_slots, - bool type_child_slots_can_overflow); + TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, + uint32_t parent_tindex, + uint32_t type_child_slots, + bool type_child_slots_can_overflow); // reference counter related operations /*! \brief developer function, increases reference counter. */ @@ -316,9 +304,9 @@ class Object { */ TVM_DLL bool DerivedFrom(uint32_t parent_tindex) const; // friend classes - template + template friend class ObjAllocatorBase; - template + template friend class ObjectPtr; friend class TVMRetValue; friend class ObjectInternal; @@ -398,9 +386,7 @@ class ObjectPtr { other.data_ = nullptr; } /*! \brief destructor */ - ~ObjectPtr() { - this->reset(); - } + ~ObjectPtr() { this->reset(); } /*! * \brief Swap this array with another Object * \param other The other Object @@ -411,15 +397,11 @@ class ObjectPtr { /*! * \return Get the content of the pointer */ - T* get() const { - return static_cast(data_); - } + T* get() const { return static_cast(data_); } /*! * \return The pointer */ - T* operator->() const { - return get(); - } + T* operator->() const { return get(); } /*! * \return The reference */ @@ -455,29 +437,17 @@ class ObjectPtr { } } /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_ != nullptr ? data_->use_count() : 0; - } + int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } /*! \return whether the reference is unique */ - bool unique() const { - return data_ != nullptr && data_->use_count() == 1; - } + bool unique() const { return data_ != nullptr && data_->use_count() == 1; } /*! \return Whether two ObjectPtr do not equal each other */ - bool operator==(const ObjectPtr& other) const { - return data_ == other.data_; - } + bool operator==(const ObjectPtr& other) const { return data_ == other.data_; } /*! \return Whether two ObjectPtr equals each other */ - bool operator!=(const ObjectPtr& other) const { - return data_ != other.data_; - } + bool operator!=(const ObjectPtr& other) const { return data_ != other.data_; } /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { - return data_ == nullptr; - } + bool operator==(std::nullptr_t null) const { return data_ == nullptr; } /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return data_ != nullptr; - } + bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } private: /*! \brief internal pointer field */ @@ -506,9 +476,9 @@ class ObjectPtr { friend class Object; friend class ObjectRef; friend struct ObjectHash; - template + template friend class ObjectPtr; - template + template friend class ObjAllocatorBase; friend class TVMPODValue_; friend class TVMArgsSetter; @@ -533,55 +503,37 @@ class ObjectRef { * \param other Another object ref. * \return the compare result. */ - bool same_as(const ObjectRef& other) const { - return data_ == other.data_; - } + bool same_as(const ObjectRef& other) const { return data_ == other.data_; } /*! * \brief Comparator * \param other Another object ref. * \return the compare result. */ - bool operator==(const ObjectRef& other) const { - return data_ == other.data_; - } + bool operator==(const ObjectRef& other) const { return data_ == other.data_; } /*! * \brief Comparator * \param other Another object ref. * \return the compare result. */ - bool operator!=(const ObjectRef& other) const { - return data_ != other.data_; - } + bool operator!=(const ObjectRef& other) const { return data_ != other.data_; } /*! * \brief Comparator * \param other Another object ref by address. * \return the compare result. */ - bool operator<(const ObjectRef& other) const { - return data_.get() < other.data_.get(); - } + bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } /*! * \return whether the object is defined(not null). */ - bool defined() const { - return data_ != nullptr; - } + bool defined() const { return data_ != nullptr; } /*! \return the internal object pointer */ - const Object* get() const { - return data_.get(); - } + const Object* get() const { return data_.get(); } /*! \return the internal object pointer */ - const Object* operator->() const { - return get(); - } + const Object* operator->() const { return get(); } /*! \return whether the reference is unique */ - bool unique() const { - return data_.unique(); - } + bool unique() const { return data_.unique(); } /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_.use_count(); - } + int use_count() const { return data_.use_count(); } /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. @@ -605,16 +557,14 @@ class ObjectRef { /*! \brief Internal pointer that backs the reference. */ ObjectPtr data_; /*! \return return a mutable internal ptr, can be used by sub-classes. */ - Object* get_mutable() const { - return data_.get(); - } + Object* get_mutable() const { return data_.get(); } /*! * \brief Internal helper function downcast a ref without check. * \note Only used for internal dev purposes. * \tparam T The target reference type. * \return The casted result. */ - template + template static T DowncastNoCheck(ObjectRef ref) { return T(std::move(ref.data_)); } @@ -623,16 +573,14 @@ class ObjectRef { * after we successfully moved the field. * \param ref The reference data. */ - static void FFIClearAfterMove(ObjectRef* ref) { - ref->data_.data_ = nullptr; - } + static void FFIClearAfterMove(ObjectRef* ref) { ref->data_.data_ = nullptr; } /*! * \brief Internal helper function get data_ as ObjectPtr of ObjectType. * \note only used for internal dev purpose. * \tparam ObjectType The corresponding object type. * \return the corresponding type. */ - template + template static ObjectPtr GetDataPtr(const ObjectRef& ref) { return ObjectPtr(ref.data_.data_); } @@ -657,68 +605,56 @@ inline ObjectPtr GetObjectPtr(ObjectType* ptr); /*! \brief ObjectRef hash functor */ struct ObjectHash { - size_t operator()(const ObjectRef& a) const { - return operator()(a.data_); - } + size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } - template + template size_t operator()(const ObjectPtr& a) const { return std::hash()(a.get()); } }; - /*! \brief ObjectRef equal functor */ struct ObjectEqual { - bool operator()(const ObjectRef& a, const ObjectRef& b) const { - return a.same_as(b); - } + bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); } - template + template size_t operator()(const ObjectPtr& a, const ObjectPtr& b) const { return a == b; } }; - /*! * \brief helper macro to declare a base object type that can be inheritated. * \param TypeName The name of the current type. * \param ParentType The name of the ParentType */ -#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static_assert(!ParentType::_type_final, "ParentObj maked as final"); \ - static uint32_t RuntimeTypeIndex() { \ - static_assert(TypeName::_type_child_slots == 0 || \ - ParentType::_type_child_slots == 0 || \ - TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); \ - if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ - return TypeName::_type_index; \ - } \ - return _GetOrAllocRuntimeTypeIndex(); \ - } \ - static uint32_t _GetOrAllocRuntimeTypeIndex() { \ - static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \ - TypeName::_type_key, \ - TypeName::_type_index, \ - ParentType::_GetOrAllocRuntimeTypeIndex(), \ - TypeName::_type_child_slots, \ - TypeName::_type_child_slots_can_overflow); \ - return tidx; \ - } \ - +#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + static_assert(!ParentType::_type_final, "ParentObj maked as final"); \ + static uint32_t RuntimeTypeIndex() { \ + static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ + TypeName::_type_child_slots < ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ + if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ + return TypeName::_type_index; \ + } \ + return _GetOrAllocRuntimeTypeIndex(); \ + } \ + static uint32_t _GetOrAllocRuntimeTypeIndex() { \ + static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \ + TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \ + TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \ + return tidx; \ + } /*! * \brief helper macro to declare type information in a final class. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ - static const constexpr bool _type_final = true; \ - static const constexpr int _type_child_slots = 0; \ - TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ + static const constexpr bool _type_final = true; \ + static const constexpr int _type_child_slots = 0; \ + TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) /*! \brief helper macro to supress unused warning */ #if defined(__GNUC__) @@ -730,8 +666,7 @@ struct ObjectEqual { #define TVM_STR_CONCAT_(__x, __y) __x##__y #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) -#define TVM_OBJECT_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid +#define TVM_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid /*! * \brief Helper macro to register the object type to runtime. @@ -739,20 +674,18 @@ struct ObjectEqual { * * Use this macro in the cc file for each terminal class. */ -#define TVM_REGISTER_OBJECT_TYPE(TypeName) \ - TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ - TypeName::_GetOrAllocRuntimeTypeIndex() - +#define TVM_REGISTER_OBJECT_TYPE(TypeName) \ + TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex() /* * \brief Define the default copy/move constructor and assign opeator * \param TypeName The class typename. */ -#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - TypeName(const TypeName& other) = default; \ - TypeName(TypeName&& other) = default; \ - TypeName& operator=(const TypeName& other) = default; \ - TypeName& operator=(TypeName&& other) = default; \ +#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + TypeName(const TypeName& other) = default; \ + TypeName(TypeName&& other) = default; \ + TypeName& operator=(const TypeName& other) = default; \ + TypeName& operator=(TypeName&& other) = default; /* * \brief Define object reference methods. @@ -760,15 +693,11 @@ struct ObjectEqual { * \param ParentType The parent type of the objectref * \param ObjectName The type name of the object. */ -#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - explicit TypeName( \ - ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ - : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - const ObjectName* operator->() const { \ - return static_cast(data_.get()); \ - } \ +#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName; /* @@ -778,15 +707,11 @@ struct ObjectEqual { * \param ParentType The parent type of the objectref * \param ObjectName The type name of the object. */ -#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName( \ - ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ - : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - const ObjectName* operator->() const { \ - return static_cast(data_.get()); \ - } \ - static constexpr bool _type_is_nullable = false; \ +#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { return static_cast(data_.get()); } \ + static constexpr bool _type_is_nullable = false; \ using ContainerType = ObjectName; /* @@ -797,15 +722,11 @@ struct ObjectEqual { * \note We recommend making objects immutable when possible. * This macro is only reserved for objects that stores runtime states. */ -#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - explicit TypeName( \ - ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ - : ParentType(n) {} \ - ObjectName* operator->() const { \ - return static_cast(data_.get()); \ - } \ +#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + ObjectName* operator->() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName; /*! @@ -827,23 +748,21 @@ struct ObjectEqual { * * \endcode */ -#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ - ObjectName* CopyOnWrite() { \ - CHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - auto n = make_object(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ - } +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + ObjectName* CopyOnWrite() { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = make_object(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ + } // Implementations details below // Object reference counting. #if TVM_OBJECT_ATOMIC_REF_COUNTER -inline void Object::IncRef() { - ref_counter_.fetch_add(1, std::memory_order_relaxed); -} +inline void Object::IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); } inline void Object::DecRef() { if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { @@ -854,15 +773,11 @@ inline void Object::DecRef() { } } -inline int Object::use_count() const { - return ref_counter_.load(std::memory_order_relaxed); -} +inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); } #else -inline void Object::IncRef() { - ++ref_counter_; -} +inline void Object::IncRef() { ++ref_counter_; } inline void Object::DecRef() { if (--ref_counter_ == 0) { @@ -872,13 +787,11 @@ inline void Object::DecRef() { } } -inline int Object::use_count() const { - return ref_counter_; -} +inline int Object::use_count() const { return ref_counter_; } #endif // TVM_OBJECT_ATOMIC_REF_COUNTER -template +template inline bool Object::IsInstance() const { const Object* self = this; // NOTE: the following code can be optimized by @@ -912,11 +825,9 @@ inline bool Object::IsInstance() const { } } - template inline const ObjectType* ObjectRef::as() const { - if (data_ != nullptr && - data_->IsInstance()) { + if (data_ != nullptr && data_->IsInstance()) { return static_cast(data_.get()); } else { return nullptr; @@ -944,12 +855,11 @@ template inline SubRef Downcast(BaseRef ref) { if (ref.defined()) { CHECK(ref->template IsInstance()) - << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; + << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key + << " failed."; } else { - CHECK(SubRef::_type_is_nullable) - << "Downcast from nullptr to not nullable reference of " - << SubRef::ContainerType::_type_key; + CHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of " + << SubRef::ContainerType::_type_key; } return SubRef(std::move(ref.data_)); } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index dfc21fc..01f8e99 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,19 +26,19 @@ #include #include +#include #include #include -#include #include + #include -#include -#include -#include #include #include -#include +#include +#include #include - +#include +#include // Whether use TVM runtime in header only mode. #ifndef TVM_RUNTIME_HEADER_ONLY @@ -91,7 +91,7 @@ class PackedFunc { * } * \endcode */ - using FType = std::function; + using FType = std::function; /*! \brief default constructor */ PackedFunc() {} /*! \brief constructor from null */ @@ -115,8 +115,8 @@ class PackedFunc { * } * \endcode */ - template - inline TVMRetValue operator()(Args&& ...args) const; + template + inline TVMRetValue operator()(Args&&... args) const; /*! * \brief Call the function in packed format. * \param args The arguments @@ -126,13 +126,9 @@ class PackedFunc { /*! \return the internal body function */ inline FType body() const; /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { - return body_ == nullptr; - } + bool operator==(std::nullptr_t null) const { return body_ == nullptr; } /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return body_ != nullptr; - } + bool operator!=(std::nullptr_t null) const { return body_ != nullptr; } private: /*! \brief internal container of packed function */ @@ -142,7 +138,7 @@ class PackedFunc { /*! * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc" */ -template +template class TypedPackedFunc; /*! @@ -177,7 +173,7 @@ class TypedPackedFunc; * \tparam R The return value of the function. * \tparam Args The argument signature of the function. */ -template +template class TypedPackedFunc { public: /*! \brief short hand for this function type */ @@ -234,11 +230,9 @@ class TypedPackedFunc { * \param typed_lambda typed lambda function. * \tparam FLambda the type of the lambda function. */ - template - >::value>::type> + template >::value>::type> TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*) this->AssignTypedLambda(typed_lambda); } @@ -258,11 +252,9 @@ class TypedPackedFunc { * \tparam FLambda the type of the lambda function. * \returns reference to self. */ - template - >::value>::type> + template >::value>::type> TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) this->AssignTypedLambda(typed_lambda); return *this; @@ -281,28 +273,20 @@ class TypedPackedFunc { * \param args The arguments * \returns The return value. */ - TVM_ALWAYS_INLINE R operator()(Args ...args) const; + TVM_ALWAYS_INLINE R operator()(Args... args) const; /*! * \brief convert to PackedFunc * \return the internal PackedFunc */ - operator PackedFunc() const { - return packed(); - } + operator PackedFunc() const { return packed(); } /*! * \return reference the internal PackedFunc */ - const PackedFunc& packed() const { - return packed_; - } + const PackedFunc& packed() const { return packed_; } /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { - return packed_ == nullptr; - } + bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return packed_ != nullptr; - } + bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } private: friend class TVMRetValue; @@ -315,7 +299,7 @@ class TypedPackedFunc { * \tparam FLambda The lambda function type. * \note We capture the lambda when possible for maximum efficiency. */ - template + template inline void AssignTypedLambda(FLambda flambda); }; @@ -331,12 +315,8 @@ class TVMArgs { * \param type_codes The argument type codes * \param num_args number of arguments. */ - TVMArgs(const TVMValue* values, - const int* type_codes, - int num_args) - : values(values), - type_codes(type_codes), - num_args(num_args) { } + TVMArgs(const TVMValue* values, const int* type_codes, int num_args) + : values(values), type_codes(type_codes), num_args(num_args) {} /*! \return size of the arguments */ inline int size() const; /*! @@ -348,15 +328,14 @@ class TVMArgs { }; // macro to check type code. -#define TVM_CHECK_TYPE_CODE(CODE, T) \ - CHECK_EQ(CODE, T) << " expected " \ - << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ +#define TVM_CHECK_TYPE_CODE(CODE, T) \ + CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) /*! * \brief Type traits for runtime type check during FFI conversion. * \tparam T the type to be checked. */ -template +template struct ObjectTypeChecker { static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; @@ -410,61 +389,53 @@ class TVMPODValue_ { return value_.v_handle; } operator DLTensor*() const { - if (type_code_ == kTVMDLTensorHandle || - type_code_ == kTVMNDArrayHandle) { + if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) { return static_cast(value_.v_handle); } else { if (type_code_ == kTVMNullptr) return nullptr; LOG(FATAL) << "Expect " - << "DLTensor* or NDArray but get " - << TypeCode2Str(type_code_); + << "DLTensor* or NDArray but get " << TypeCode2Str(type_code_); return nullptr; } } operator NDArray() const { if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr(nullptr)); TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); - return NDArray(NDArray::FFIDataFromHandle( - static_cast(value_.v_handle))); + return NDArray(NDArray::FFIDataFromHandle(static_cast(value_.v_handle))); } operator Module() const { if (type_code_ == kTVMNullptr) { return Module(ObjectPtr(nullptr)); } TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); - return Module( - ObjectPtr(static_cast(value_.v_handle))); + return Module(ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); return value_.v_ctx; } - int type_code() const { - return type_code_; - } + int type_code() const { return type_code_; } /*! * \brief return handle as specific pointer type. * \tparam T the data type. * \return The pointer type. */ - template + template T* ptr() const { return static_cast(value_.v_handle); } // ObjectRef handling - template::value>::type> + template ::value>::type> inline bool IsObjectRef() const; - template + template inline TObjectRef AsObjectRef() const; protected: friend class TVMArgsSetter; friend class TVMRetValue; TVMPODValue_() : type_code_(kTVMNullptr) {} - TVMPODValue_(TVMValue value, int type_code) - : value_(value), type_code_(type_code) {} + TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} /*! \brief The value */ TVMValue value_; @@ -487,9 +458,7 @@ class TVMArgValue : public TVMPODValue_ { * \param value of the function * \param type_code The type code. */ - TVMArgValue(TVMValue value, int type_code) - : TVMPODValue_(value, type_code) { - } + TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; @@ -501,8 +470,8 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator Module; - using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; + using TVMPODValue_::IsObjectRef; // conversion operator. operator std::string() const { @@ -523,31 +492,27 @@ class TVMArgValue : public TVMPODValue_ { // None type if (type_code_ == kTVMNullptr) { DLDataType t; - t.code = kTVMOpaqueHandle; t.bits = 0; t.lanes = 0; + t.code = kTVMOpaqueHandle; + t.bits = 0; + t.lanes = 0; return t; } TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); return value_.v_type; } - operator DataType() const { - return DataType(operator DLDataType()); - } + operator DataType() const { return DataType(operator DLDataType()); } operator PackedFunc() const { if (type_code_ == kTVMNullptr) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); return *ptr(); } - template + template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } - const TVMValue& value() const { - return value_; - } + const TVMValue& value() const { return value_; } - template::value>::type> + template ::value>::type> inline operator T() const; }; @@ -563,9 +528,7 @@ class TVMArgValue : public TVMPODValue_ { */ class TVMMovableArgValue_ : public TVMArgValue { public: - TVMMovableArgValue_(TVMValue value, int type_code) - : TVMArgValue(value, type_code) { - } + TVMMovableArgValue_(TVMValue value, int type_code) : TVMArgValue(value, type_code) {} // reuse converter from parent using TVMArgValue::operator double; using TVMArgValue::operator int64_t; @@ -584,9 +547,8 @@ class TVMMovableArgValue_ : public TVMArgValue { * Try to move out an argument if possible, * fall back to normal argument conversion rule otherwise. */ - template::value>::type> + template ::value>::type> inline operator T() const; }; @@ -606,15 +568,12 @@ class TVMRetValue : public TVMPODValue_ { * \brief move constructor from anoter return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) - : TVMPODValue_(other.value_, other.type_code_) { + TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { other.value_.v_handle = nullptr; other.type_code_ = kTVMNullptr; } /*! \brief destructor */ - ~TVMRetValue() { - this->Clear(); - } + ~TVMRetValue() { this->Clear(); } // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; @@ -626,12 +585,10 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; - using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; + using TVMPODValue_::IsObjectRef; - TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { - this->Assign(other); - } + TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } // conversion operators operator std::string() const { if (type_code_ == kTVMDataType) { @@ -649,15 +606,13 @@ class TVMRetValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); return value_.v_type; } - operator DataType() const { - return DataType(operator DLDataType()); - } + operator DataType() const { return DataType(operator DLDataType()); } operator PackedFunc() const { if (type_code_ == kTVMNullptr) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); return *ptr(); } - template + template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } @@ -704,9 +659,7 @@ class TVMRetValue : public TVMPODValue_ { value_.v_type = t; return *this; } - TVMRetValue& operator=(const DataType& other) { - return operator=(other.operator DLDataType()); - } + TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { this->SwitchToPOD(kDLInt); value_.v_int64 = value; @@ -743,7 +696,7 @@ class TVMRetValue : public TVMPODValue_ { } return *this; } - template + template TVMRetValue& operator=(const TypedPackedFunc& f) { return operator=(f.packed()); } @@ -768,8 +721,7 @@ class TVMRetValue : public TVMPODValue_ { * \param ret_value The return value. * \param ret_type_code The return type code. */ - void MoveToCHost(TVMValue* ret_value, - int* ret_type_code) { + void MoveToCHost(TVMValue* ret_value, int* ret_type_code) { // cannot move str; need specially handle. CHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes); *ret_value = value_; @@ -783,11 +735,9 @@ class TVMRetValue : public TVMPODValue_ { * \param type_code The type code. * \return The created TVMRetValue. */ - static TVMRetValue MoveFromCHost(TVMValue value, - int type_code) { + static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. - CHECK(type_code <= kTVMPackedFuncHandle || - type_code == kTVMNDArrayHandle); + CHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; @@ -795,24 +745,20 @@ class TVMRetValue : public TVMPODValue_ { } /*! \return The value field, if the data is POD */ const TVMValue& value() const { - CHECK(type_code_ != kTVMObjectHandle && - type_code_ != kTVMPackedFuncHandle && - type_code_ != kTVMModuleHandle && - type_code_ != kTVMStr) << "TVMRetValue.value can only be used for POD data"; + CHECK(type_code_ != kTVMObjectHandle && type_code_ != kTVMPackedFuncHandle && + type_code_ != kTVMModuleHandle && type_code_ != kTVMStr) + << "TVMRetValue.value can only be used for POD data"; return value_; } // ObjectRef handling - template::value>::type> + template ::value>::type> inline TVMRetValue& operator=(TObjectRef other); - template::value>::type> + template ::value>::type> inline operator T() const; private: - template + template void Assign(const T& other) { switch (other.type_code()) { case kTVMStr: { @@ -837,9 +783,8 @@ class TVMRetValue : public TVMPODValue_ { } case kTVMObjectHandle: { // Avoid operator ObjectRef as we already know it is not NDArray/Module - SwitchToObject( - kTVMObjectHandle, GetObjectPtr( - static_cast(other.value_.v_handle))); + SwitchToObject(kTVMObjectHandle, + GetObjectPtr(static_cast(other.value_.v_handle))); break; } case kTVMObjectRValueRefArg: { @@ -860,7 +805,7 @@ class TVMRetValue : public TVMPODValue_ { type_code_ = type_code; } } - template + template void SwitchToClass(int type_code, T v) { if (type_code_ != type_code) { this->Clear(); @@ -884,8 +829,13 @@ class TVMRetValue : public TVMPODValue_ { void Clear() { if (type_code_ == kTVMNullptr) return; switch (type_code_) { - case kTVMStr: case kTVMBytes: delete ptr(); break; - case kTVMPackedFuncHandle: delete ptr(); break; + case kTVMStr: + case kTVMBytes: + delete ptr(); + break; + case kTVMPackedFuncHandle: + delete ptr(); + break; case kTVMNDArrayHandle: { NDArray::FFIDecRef(static_cast(value_.v_handle)); break; @@ -912,24 +862,20 @@ class TVMRetValue : public TVMPODValue_ { * * \tparam TObjectRef the specific ObjectRefType. */ -template +template struct PackedFuncValueConverter { /*! * \brief Convert a TObjectRef from an argument value. * \param val The argument value. * \return the converted result. */ - static TObjectRef From(const TVMArgValue& val) { - return val.AsObjectRef(); - } + static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef(); } /*! * \brief Convert a TObjectRef from a return value. * \param val The argument value. * \return the converted result. */ - static TObjectRef From(const TVMRetValue& val) { - return val.AsObjectRef(); - } + static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef(); } }; /*! @@ -951,29 +897,22 @@ struct PackedFuncValueConverter { * * \endcode */ -#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(TVMValue* args, \ - int* type_code, \ - int num_args, \ - TVMValue* out_value, \ - int* out_type_code); \ - int ExportName(TVMValue* args, \ - int* type_code, \ - int num_args, \ - TVMValue* out_value, \ - int* out_type_code) { \ - try { \ - ::tvm::runtime::TVMRetValue rv; \ - Function(::tvm::runtime::TVMArgs( \ - args, type_code, num_args), &rv); \ - rv.MoveToCHost(out_value, out_type_code); \ - return 0; \ - } catch (const ::std::runtime_error& _except_) { \ - TVMAPISetLastError(_except_.what()); \ - return -1; \ - } \ - } \ +#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ + int* out_type_code); \ + int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ + int* out_type_code) { \ + try { \ + ::tvm::runtime::TVMRetValue rv; \ + Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ + rv.MoveToCHost(out_value, out_type_code); \ + return 0; \ + } catch (const ::std::runtime_error& _except_) { \ + TVMAPISetLastError(_except_.what()); \ + return -1; \ + } \ + } \ } /*! @@ -1011,137 +950,113 @@ struct PackedFuncValueConverter { * * \endcode */ -#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(TVMValue* args, \ - int* type_code, \ - int num_args, \ - TVMValue* out_value, \ - int* out_type_code) { \ - try { \ - auto f = Function; \ - using FType = ::tvm::runtime::detail:: \ - function_signature::FType; \ - ::tvm::runtime::TVMRetValue rv; \ - ::tvm::runtime::detail::unpack_call_by_signature::run( \ - f, \ - ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ - rv.MoveToCHost(out_value, out_type_code); \ - return 0; \ - } catch (const ::std::runtime_error& _except_) { \ - TVMAPISetLastError(_except_.what()); \ - return -1; \ - } \ - } \ +#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ + int* out_type_code) { \ + try { \ + auto f = Function; \ + using FType = ::tvm::runtime::detail::function_signature::FType; \ + ::tvm::runtime::TVMRetValue rv; \ + ::tvm::runtime::detail::unpack_call_by_signature::run( \ + f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ + rv.MoveToCHost(out_value, out_type_code); \ + return 0; \ + } catch (const ::std::runtime_error& _except_) { \ + TVMAPISetLastError(_except_.what()); \ + return -1; \ + } \ + } \ } - inline TVMArgValue TVMArgs::operator[](int i) const { - CHECK_LT(i, num_args) - << "not enough argument passed, " - << num_args << " passed" - << " but request arg[" << i << "]."; + CHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed" + << " but request arg[" << i << "]."; return TVMArgValue(values[i], type_codes[i]); } -inline int TVMArgs::size() const { - return num_args; -} +inline int TVMArgs::size() const { return num_args; } -inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { - body_(args, rv); -} +inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); } -inline PackedFunc::FType PackedFunc::body() const { - return body_; -} +inline PackedFunc::FType PackedFunc::body() const { return body_; } // internal namespace namespace detail { -template +template struct for_each_dispatcher { - template + template static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*) f(I, std::forward(value)); - for_each_dispatcher - ::run(f, std::forward(args)...); + for_each_dispatcher::run(f, std::forward(args)...); } }; -template -struct for_each_dispatcher { +template +struct for_each_dispatcher { static void run(const F& f) {} // NOLINT(*) }; -template +template inline void for_each(const F& f, Args&&... args) { // NOLINT(*) - for_each_dispatcher - ::run(f, std::forward(args)...); + for_each_dispatcher::run(f, std::forward(args)...); } -template +template struct func_signature_helper { using FType = void; }; -template +template struct func_signature_helper { using FType = R(Args...); - static_assert(!std::is_reference::value, - "TypedPackedFunc return reference"); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; -template +template struct func_signature_helper { using FType = R(Args...); - static_assert(!std::is_reference::value, - "TypedPackedFunc return reference"); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; /*! * \brief template class to get function signature of a function or functor. * \tparam T The funtion/functor type. */ -template +template struct function_signature { using FType = typename func_signature_helper::FType; }; // handle case of function. -template +template struct function_signature { using FType = R(Args...); - static_assert(!std::is_reference::value, - "TypedPackedFunc return reference"); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; // handle case of function ptr. -template +template struct function_signature { using FType = R(Args...); - static_assert(!std::is_reference::value, - "TypedPackedFunc return reference"); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; } // namespace detail /* \brief argument settter to PackedFunc */ class TVMArgsSetter { public: - TVMArgsSetter(TVMValue* values, int* type_codes) - : values_(values), type_codes_(type_codes) {} + TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {} // setters for POD types - template::value>::type> + template ::value>::type> TVM_ALWAYS_INLINE void operator()(size_t i, T value) const { values_[i].v_int64 = static_cast(value); type_codes_[i] = kDLInt; } TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); - CHECK_LE(value, - static_cast(std::numeric_limits::max())); + CHECK_LE(value, static_cast(std::numeric_limits::max())); type_codes_[i] = kDLInt; } TVM_ALWAYS_INLINE void operator()(size_t i, double value) const { @@ -1197,7 +1112,7 @@ class TVMArgsSetter { type_codes_[i] = kTVMNullptr; } } - template + template TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc& value) const { operator()(i, value.packed()); } @@ -1212,25 +1127,21 @@ class TVMArgsSetter { } } // ObjectRef handling - template::value> - ::type> + template ::value>::type> TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const { this->SetObject(i, value); } - template::type>::value> - ::type> + template ::type>::value>::type> TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const { this->SetObject(i, std::forward(value)); } private: - template + template inline void SetObject(size_t i, TObjectRef&& value) const; /*! \brief The values fields */ TVMValue* values_; @@ -1238,43 +1149,36 @@ class TVMArgsSetter { int* type_codes_; }; -template -inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { +template +inline TVMRetValue PackedFunc::operator()(Args&&... args) const { const int kNumArgs = sizeof...(Args); const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; - detail::for_each(TVMArgsSetter(values, type_codes), - std::forward(args)...); + detail::for_each(TVMArgsSetter(values, type_codes), std::forward(args)...); TVMRetValue rv; body_(TVMArgs(values, type_codes, kNumArgs), &rv); return rv; } namespace detail { -template +template struct unpack_call_dispatcher { - template - TVM_ALWAYS_INLINE static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, Args&&... unpacked_args) { // construct a movable argument value // which allows potential move of argument to the input of F. - unpack_call_dispatcher - ::run(f, args_pack, rv, - std::forward(unpacked_args)..., - TVMMovableArgValue_(args_pack.values[index], - args_pack.type_codes[index])); + unpack_call_dispatcher::run( + f, args_pack, rv, std::forward(unpacked_args)..., + TVMMovableArgValue_(args_pack.values[index], args_pack.type_codes[index])); } }; -template +template struct unpack_call_dispatcher { - template - TVM_ALWAYS_INLINE static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, Args&&... unpacked_args) { using RetType = decltype(f(std::forward(unpacked_args)...)); if (std::is_same::value) { @@ -1285,90 +1189,80 @@ struct unpack_call_dispatcher { } }; -template +template struct unpack_call_dispatcher { - template - TVM_ALWAYS_INLINE static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, Args&&... unpacked_args) { f(std::forward(unpacked_args)...); } }; -template -TVM_ALWAYS_INLINE void unpack_call( - const F& f, const TVMArgs& args, TVMRetValue* rv) { - CHECK_EQ(nargs, args.size()) - << "Expect " << nargs << " arguments but get " << args.size(); +template +TVM_ALWAYS_INLINE void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { + CHECK_EQ(nargs, args.size()) << "Expect " << nargs << " arguments but get " << args.size(); unpack_call_dispatcher::run(f, args, rv); } -template -struct unpack_call_by_signature { -}; +template +struct unpack_call_by_signature {}; -template +template struct unpack_call_by_signature { - template - TVM_ALWAYS_INLINE static void run( - const F& f, - const TVMArgs& args, - TVMRetValue* rv) { + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) { unpack_call(f, args, rv); } }; -template -TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&& ...args) { +template +TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) { return R(pf(std::forward(args)...)); } -template +template struct typed_packed_call_dispatcher { - template - TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&& ...args) { + template + TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) { return pf(std::forward(args)...); } }; -template<> +template <> struct typed_packed_call_dispatcher { - template - TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&& ...args) { + template + TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... args) { pf(std::forward(args)...); } }; } // namespace detail -template -TypedPackedFunc::TypedPackedFunc(PackedFunc packed) - : packed_(packed) {} +template +TypedPackedFunc::TypedPackedFunc(PackedFunc packed) : packed_(packed) {} -template +template TypedPackedFunc::TypedPackedFunc(const TVMRetValue& value) : packed_(value.operator PackedFunc()) {} -template +template TypedPackedFunc::TypedPackedFunc(const TVMArgValue& value) : packed_(value.operator PackedFunc()) {} -template +template TypedPackedFunc::TypedPackedFunc(TVMMovableArgValue_&& value) : packed_(value.operator PackedFunc()) {} -template -template +template +template inline void TypedPackedFunc::AssignTypedLambda(FType flambda) { packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { - detail::unpack_call(flambda, args, rv); - }); + detail::unpack_call(flambda, args, rv); + }); } -template +template TVM_ALWAYS_INLINE R TypedPackedFunc::operator()(Args... args) const { - return detail::typed_packed_call_dispatcher - ::run(packed_, std::forward(args)...); + return detail::typed_packed_call_dispatcher::run(packed_, std::forward(args)...); } // ObjectRef related conversion handling @@ -1376,7 +1270,7 @@ TVM_ALWAYS_INLINE R TypedPackedFunc::operator()(Args... args) const // kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle // // We use type traits to eliminate un-necessary checks. -template +template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { using ContainerType = typename std::remove_reference::type::ContainerType; if (value.defined()) { @@ -1403,38 +1297,35 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { } } -template +template inline bool TVMPODValue_::IsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { return type_code_ == kTVMNDArrayHandle && - TVMArrayHandleToObjectHandle( - static_cast(value_.v_handle))->IsInstance(); + TVMArrayHandleToObjectHandle(static_cast(value_.v_handle)) + ->IsInstance(); } if (std::is_base_of::value) { return type_code_ == kTVMModuleHandle && - static_cast(value_.v_handle)->IsInstance(); + static_cast(value_.v_handle)->IsInstance(); } // NOTE: we don't pass NDArray and runtime::Module as RValue ref. if (type_code_ == kTVMObjectRValueRefArg) { - return ObjectTypeChecker::Check( - *static_cast(value_.v_handle)); - } - return - (std::is_base_of::value && - type_code_ == kTVMNDArrayHandle) || - (std::is_base_of::value && - type_code_ == kTVMModuleHandle) || - (type_code_ == kTVMObjectHandle && - ObjectTypeChecker::Check(static_cast(value_.v_handle))); + return ObjectTypeChecker::Check(*static_cast(value_.v_handle)); + } + return (std::is_base_of::value && + type_code_ == kTVMNDArrayHandle) || + (std::is_base_of::value && + type_code_ == kTVMModuleHandle) || + (type_code_ == kTVMObjectHandle && + ObjectTypeChecker::Check(static_cast(value_.v_handle))); } -template +template inline TObjectRef TVMPODValue_::AsObjectRef() const { - static_assert( - std::is_base_of::value, - "Conversion only works for ObjectRef"); + static_assert(std::is_base_of::value, + "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; if (type_code_ == kTVMNullptr) { @@ -1446,8 +1337,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { if (std::is_base_of::value) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); - ObjectPtr data = NDArray::FFIDataFromHandle( - static_cast(value_.v_handle)); + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); CHECK(data->IsInstance()) << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); return TObjectRef(data); @@ -1464,20 +1355,20 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { // normal object type check. Object* ptr = static_cast(value_.v_handle); CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); + << "Expect " << ObjectTypeChecker::TypeName() << " but get " + << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); } else if (type_code_ == kTVMObjectRValueRefArg) { Object* ptr = *static_cast(value_.v_handle); CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); + << "Expect " << ObjectTypeChecker::TypeName() << " but get " + << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); } else if (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) { // Casting to a base class that NDArray can sub-class - ObjectPtr data = NDArray::FFIDataFromHandle( - static_cast(value_.v_handle)); + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); return TObjectRef(data); } else if (std::is_base_of::value && type_code_ == kTVMModuleHandle) { @@ -1489,7 +1380,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { } } -template +template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); @@ -1511,13 +1402,12 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { return *this; } - -template +template inline TVMArgValue::operator T() const { return PackedFuncValueConverter::From(*this); } -template +template inline TVMMovableArgValue_::operator T() const { if (type_code_ == kTVMObjectRValueRefArg) { auto** ref = static_cast(value_.v_handle); @@ -1529,7 +1419,7 @@ inline TVMMovableArgValue_::operator T() const { return PackedFuncValueConverter::From(*this); } -template +template inline TVMRetValue::operator T() const { return PackedFuncValueConverter::From(*this); } diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 6faa7b7..4a5a210 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -44,9 +44,10 @@ #define TVM_RUNTIME_REGISTRY_H_ #include + #include -#include #include +#include namespace tvm { namespace runtime { @@ -68,7 +69,8 @@ class Registry { } /*! * \brief set the body of the function to the given function. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -88,14 +90,15 @@ class Registry { * \param f The function to forward to. * \tparam FLambda The signature of the function. */ - template + template Registry& set_body_typed(FLambda f) { using FType = typename detail::function_signature::FType; return set_body(TypedPackedFunc(std::move(f)).packed()); } /*! * \brief set the body of the function to be the passed method pointer. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -113,9 +116,9 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template + template Registry& set_body_method(R (T::*f)(Args...)) { - auto fwrap =[f](T target, Args... params) -> R { + auto fwrap = [f](T target, Args... params) -> R { // call method pointer return (target.*f)(params...); }; @@ -124,7 +127,8 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -142,7 +146,7 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template + template Registry& set_body_method(R (T::*f)(Args...) const) { auto fwrap = [f](const T target, Args... params) -> R { // call method pointer @@ -154,7 +158,8 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. * Used when calling a method on a Node subclass through a ObjectRef subclass. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -181,8 +186,8 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template::value>::type> + template ::value>::type> Registry& set_body_method(R (TNode::*f)(Args...)) { auto fwrap = [f](TObjectRef ref, Args... params) { TNode* target = ref.operator->(); @@ -195,7 +200,8 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. * Used when calling a method on a Node subclass through a ObjectRef subclass. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -222,8 +228,8 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template::value>::type> + template ::value>::type> Registry& set_body_method(R (TNode::*f)(Args...) const) { auto fwrap = [f](TObjectRef ref, Args... params) { const TNode* target = ref.operator->(); @@ -270,8 +276,7 @@ class Registry { friend struct Manager; }; -#define TVM_FUNC_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM +#define TVM_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_##TVM /*! * \brief Register a function globally. @@ -281,9 +286,8 @@ class Registry { * }); * \endcode */ -#define TVM_REGISTER_GLOBAL(OpName) \ - TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::runtime::Registry::Register(OpName) +#define TVM_REGISTER_GLOBAL(OpName) \ + TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName) } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/serializer.h b/include/tvm/runtime/serializer.h index 37bb95f..f40c87e 100644 --- a/include/tvm/runtime/serializer.h +++ b/include/tvm/runtime/serializer.h @@ -33,14 +33,14 @@ namespace dmlc { namespace serializer { -template<> +template <> struct Handler { - inline static void Write(Stream *strm, const DLDataType& dtype) { + inline static void Write(Stream* strm, const DLDataType& dtype) { Handler::Write(strm, dtype.code); Handler::Write(strm, dtype.bits); Handler::Write(strm, dtype.lanes); } - inline static bool Read(Stream *strm, DLDataType* dtype) { + inline static bool Read(Stream* strm, DLDataType* dtype) { if (!Handler::Read(strm, &(dtype->code))) return false; if (!Handler::Read(strm, &(dtype->bits))) return false; if (!Handler::Read(strm, &(dtype->lanes))) return false; @@ -48,14 +48,14 @@ struct Handler { } }; -template<> +template <> struct Handler { - inline static void Write(Stream *strm, const DLContext& ctx) { + inline static void Write(Stream* strm, const DLContext& ctx) { int32_t device_type = static_cast(ctx.device_type); Handler::Write(strm, device_type); Handler::Write(strm, ctx.device_id); } - inline static bool Read(Stream *strm, DLContext* ctx) { + inline static bool Read(Stream* strm, DLContext* ctx) { int32_t device_type = 0; if (!Handler::Read(strm, &(device_type))) return false; ctx->device_type = static_cast(device_type); diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h index f198401..95a6404 100644 --- a/include/tvm/runtime/threading_backend.h +++ b/include/tvm/runtime/threading_backend.h @@ -40,26 +40,25 @@ class ThreadGroup { public: class Impl; - /*! - * \brief Creates a collection of threads which run a provided function. - * - * \param num_workers The total number of worker threads in this group. - Includes main thread if `exclude_worker0 = true` - * \param worker_callback A callback which is run in its own thread. - Receives the worker_id as an argument. - * \param exclude_worker0 Whether to use the main thread as a worker. - * If `true`, worker0 will not be launched in a new thread and - * `worker_callback` will only be called for values >= 1. This - * allows use of the main thread as a worker. - */ - ThreadGroup(int num_workers, - std::function worker_callback, + /*! + * \brief Creates a collection of threads which run a provided function. + * + * \param num_workers The total number of worker threads in this group. + Includes main thread if `exclude_worker0 = true` + * \param worker_callback A callback which is run in its own thread. + Receives the worker_id as an argument. + * \param exclude_worker0 Whether to use the main thread as a worker. + * If `true`, worker0 will not be launched in a new thread and + * `worker_callback` will only be called for values >= 1. This + * allows use of the main thread as a worker. + */ + ThreadGroup(int num_workers, std::function worker_callback, bool exclude_worker0 = false); ~ThreadGroup(); - /*! - * \brief Blocks until all non-main threads in the pool finish. - */ + /*! + * \brief Blocks until all non-main threads in the pool finish. + */ void Join(); enum AffinityMode : int { @@ -95,7 +94,6 @@ void Yield(); */ int MaxConcurrency(); - } // namespace threading } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 44d5898..62534d9 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -24,10 +24,11 @@ #ifndef TVM_RUNTIME_VM_H_ #define TVM_RUNTIME_VM_H_ -#include #include +#include #include #include + #include #include #include @@ -271,8 +272,8 @@ struct Instruction { * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensor(RegName storage, - const std::vector& shape, DLDataType dtype, RegName dst); + static Instruction AllocTensor(RegName storage, const std::vector& shape, + DLDataType dtype, RegName dst); /*! * \brief Construct an allocate tensor instruction with register. * \param storage The storage to allocate out of. @@ -281,8 +282,8 @@ struct Instruction { * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensorReg(RegName storage, - RegName shape_register, DLDataType dtype, RegName dst); + static Instruction AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, + RegName dst); /*! * \brief Construct an allocate datatype instruction. * \param tag The datatype tag. @@ -379,8 +380,8 @@ struct Instruction { * \param dst The destination to place the storage. * \return The alloc storage instruction. */ - static Instruction AllocStorage(RegName size, RegName alignment, - DLDataType dtype_hint, RegName dst); + static Instruction AllocStorage(RegName size, RegName alignment, DLDataType dtype_hint, + RegName dst); Instruction(); Instruction(const Instruction& instr); @@ -407,8 +408,7 @@ struct VMFunction { Index register_file_size; VMFunction(const std::string& name, std::vector params, - const std::vector& instructions, - Index register_file_size) + const std::vector& instructions, Index register_file_size) : name(name), params(params), instructions(instructions), @@ -473,8 +473,7 @@ class Executable : public ModuleNode { * * \return PackedFunc or nullptr when it is not available. */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; /*! * \brief Serialize the executable into global section, constant section, and @@ -559,9 +558,7 @@ class Executable : public ModuleNode { virtual ~Executable() {} - const char* type_key() const final { - return "VMExecutable"; - } + const char* type_key() const final { return "VMExecutable"; } /*! \brief The runtime module/library that contains both the host and also the device * code when executing on non-CPU devices. */ @@ -668,14 +665,11 @@ class VirtualMachine : public runtime::ModuleNode { * If the function needs resource from the module(e.g. late linking), * it should capture sptr_to_self. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); virtual ~VirtualMachine() {} - const char* type_key() const final { - return "VirtualMachine"; - } + const char* type_key() const final { return "VirtualMachine"; } VirtualMachine() : frames_(), func_index_(0), code_(nullptr), pc_(0), exec_(nullptr) {} @@ -763,11 +757,8 @@ class VirtualMachine : public runtime::ModuleNode { * * \note The return value will be stored in the last output_size slots of args. */ - virtual void InvokePacked(Index packed_index, - const PackedFunc& func, - Index arg_count, - Index output_size, - const std::vector& args); + virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, + Index output_size, const std::vector& args); /*! * \brief Initialize the virtual machine for a set of contexts. diff --git a/include/tvm/support/logging.h b/include/tvm/support/logging.h index 44b990e..c318b89 100644 --- a/include/tvm/support/logging.h +++ b/include/tvm/support/logging.h @@ -59,8 +59,8 @@ * a = ... * b = ... * // if quit_on_assertion is true, if a==b, continue, otherwise quit. - * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default behaviour) - * COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" + * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default + * behaviour) COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" * ... * for (int i = 0; i < N; i++) { * a = ... @@ -84,29 +84,24 @@ // Not supposed to be used by users directly. #define COND_CHECK_OP(quit_on_assert, x, y, what, op) \ - if (!quit_on_assert) { \ - if (!((x) op (y))) \ - what; \ - } \ - else /* NOLINT(*) */ \ + if (!quit_on_assert) { \ + if (!((x)op(y))) what; \ + } else /* NOLINT(*) */ \ CHECK_##op(x, y) #define COND_CHECK_EQ_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, ==) #define COND_CHECK_GE_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, >=) #define COND_CHECK_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - if (!(x)) \ - what; \ - } \ - else /* NOLINT(*) */ \ + if (!quit_on_assert) { \ + if (!(x)) what; \ + } else /* NOLINT(*) */ \ CHECK(x) #define COND_LOG_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - what; \ - } \ - else /* NOLINT(*) */ \ + if (!quit_on_assert) { \ + what; \ + } else /* NOLINT(*) */ \ LOG(x) #define COND_CHECK_EQ_3(quit_on_assert, x, y) COND_CHECK_EQ_4(quit_on_assert, x, y, return false) @@ -114,4 +109,4 @@ #define COND_CHECK_2(quit_on_assert, x) COND_CHECK_3(quit_on_assert, x, return false) #define COND_LOG_2(quit_on_assert, x) COND_LOG_3(quit_on_assert, x, return false) -#endif // TVM_SUPPORT_LOGGING_H_ +#endif // TVM_SUPPORT_LOGGING_H_ diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h index 46b091a..90c82c4 100644 --- a/include/tvm/support/with.h +++ b/include/tvm/support/with.h @@ -26,6 +26,7 @@ #define TVM_SUPPORT_WITH_H_ #include + #include namespace tvm { @@ -52,22 +53,19 @@ namespace tvm { * * \tparam ContextType Type of the context object. */ -template +template class With { public: /*! * \brief constructor. * Enter the scope of the context. */ - template - explicit With(Args&& ...args) - : ctx_(std::forward(args)...) { + template + explicit With(Args&&... args) : ctx_(std::forward(args)...) { ctx_.EnterWithScope(); } /*! \brief destructor, leaves the scope of the context. */ - ~With() DMLC_THROW_EXCEPTION { - ctx_.ExitWithScope(); - } + ~With() DMLC_THROW_EXCEPTION { ctx_.ExitWithScope(); } private: /*! \brief internal context type. */ diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 4b7ea56..e89d44d 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -24,14 +24,13 @@ #ifndef TVM_TARGET_CODEGEN_H_ #define TVM_TARGET_CODEGEN_H_ -#include #include -#include +#include #include +#include #include - namespace tvm { /*! \brief namespace for target translation and codegen. */ namespace codegen { @@ -71,8 +70,7 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib); * \param target_triple LLVM target triple * \return runtime::Module The generated LLVM module. */ -runtime::Module PackImportsToLLVM(const runtime::Module& m, - bool system_lib, +runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib, const std::string& target_triple); } // namespace codegen } // namespace tvm diff --git a/include/tvm/target/generic_func.h b/include/tvm/target/generic_func.h index f2a361b3..a310173 100644 --- a/include/tvm/target/generic_func.h +++ b/include/tvm/target/generic_func.h @@ -24,14 +24,14 @@ #ifndef TVM_TARGET_GENERIC_FUNC_H_ #define TVM_TARGET_GENERIC_FUNC_H_ -#include #include +#include #include -#include #include -#include #include +#include +#include namespace tvm { @@ -52,8 +52,7 @@ class GenericFunc : public ObjectRef { * false, an error will be logged if the call would override a previously registered function. * \return reference to self. */ - TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, - bool allow_override = false); + TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false); /*! * \brief Register a specialized function * \param tags The tags for this specialization @@ -63,8 +62,7 @@ class GenericFunc : public ObjectRef { * \return reference to self. */ TVM_DLL GenericFunc& register_func(const std::vector& tags, - const runtime::PackedFunc value, - bool allow_override = false); + const runtime::PackedFunc value, bool allow_override = false); /*! * \brief Call generic function by directly passing in unpacked format. * \param args Arguments to be passed. @@ -79,16 +77,15 @@ class GenericFunc : public ObjectRef { * } * \endcode */ - template - inline runtime::TVMRetValue operator()(Args&& ...args) const; + template + inline runtime::TVMRetValue operator()(Args&&... args) const; /*! * \brief Invoke the relevant function for the current target context, set by set_target_context. * Arguments are passed in packed format. * \param args The arguments to pass to the function. * \param ret The return value */ - TVM_DLL void CallPacked(runtime::TVMArgs args, - runtime::TVMRetValue* ret) const; + TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const; /*! * \brief Find or register the GenericFunc instance corresponding to the give name @@ -120,14 +117,14 @@ class GenericFunc : public ObjectRef { friend struct Manager; }; -template -inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const { +template +inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const { const int kNumArgs = sizeof...(Args); const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes), - std::forward(args)...); + std::forward(args)...); runtime::TVMRetValue rv; CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv); return rv; @@ -155,8 +152,7 @@ inline GenericFuncNode* GenericFunc::operator->() { return static_cast(get_mutable()); } -#define TVM_GENERIC_FUNC_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM +#define TVM_GENERIC_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM /*! * \def TVM_REGISTER_GENERIC_FUNC @@ -165,9 +161,8 @@ inline GenericFuncNode* GenericFunc::operator->() { * * \param name The name of the function */ -#define TVM_REGISTER_GENERIC_FUNC(name) \ - TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::GenericFunc::Get(#name) +#define TVM_REGISTER_GENERIC_FUNC(name) \ + TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::GenericFunc::Get(#name) } // namespace tvm #endif // TVM_TARGET_GENERIC_FUNC_H_ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 829de73..c28b051 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -24,15 +24,15 @@ #ifndef TVM_TARGET_TARGET_H_ #define TVM_TARGET_TARGET_H_ -#include -#include #include #include +#include +#include #include -#include #include #include +#include namespace tvm { /*! @@ -99,9 +99,9 @@ class Target : public ObjectRef { Target() {} explicit Target(ObjectPtr n) : ObjectRef(n) {} /*! - * \brief Create a Target given a string - * \param target_str the string to parse - */ + * \brief Create a Target given a string + * \param target_str the string to parse + */ TVM_DLL static Target Create(const std::string& target_str); /*! * \brief Get the current target context from thread local storage. @@ -113,12 +113,11 @@ class Target : public ObjectRef { */ TVM_DLL static tvm::Target Current(bool allow_not_defined = true); - const TargetNode* operator->() const { - return static_cast(get()); - } + const TargetNode* operator->() const { return static_cast(get()); } using ContainerType = TargetNode; class Internal; + private: // enable with syntax. friend class Internal; @@ -140,48 +139,37 @@ class Target : public ObjectRef { namespace target { /*! \return A target for LLVM */ -TVM_DLL Target llvm(const std::vector& options = - std::vector()); +TVM_DLL Target llvm(const std::vector& options = std::vector()); /*! \return A target for CUDA */ -TVM_DLL Target cuda(const std::vector& options = - std::vector()); +TVM_DLL Target cuda(const std::vector& options = std::vector()); /*! \return A target for ROCm */ -TVM_DLL Target rocm(const std::vector& options = - std::vector()); +TVM_DLL Target rocm(const std::vector& options = std::vector()); /*! \return A target for OpenCL */ -TVM_DLL Target opencl(const std::vector& options = - std::vector()); +TVM_DLL Target opencl(const std::vector& options = std::vector()); /*! \return A target for Metal */ -TVM_DLL Target metal(const std::vector& options = - std::vector()); +TVM_DLL Target metal(const std::vector& options = std::vector()); /*! \return A target for rasp */ -TVM_DLL Target rasp(const std::vector& options = - std::vector()); +TVM_DLL Target rasp(const std::vector& options = std::vector()); /*! \return A target for Mali */ -TVM_DLL Target mali(const std::vector& options = - std::vector()); +TVM_DLL Target mali(const std::vector& options = std::vector()); /*! \return A target for Intel Graphics */ -TVM_DLL Target intel_graphics(const std::vector& options = - std::vector()); +TVM_DLL Target intel_graphics(const std::vector& options = std::vector()); /*! \return A target for stackvm */ -TVM_DLL Target stackvm(const std::vector& options = - std::vector()); +TVM_DLL Target stackvm(const std::vector& options = std::vector()); /*! \return A target for external device */ -TVM_DLL Target ext_dev(const std::vector& options = - std::vector()); +TVM_DLL Target ext_dev(const std::vector& options = std::vector()); /*! \return A target for hexagon */ -TVM_DLL Target hexagon(const std::vector& options = - std::vector()); +TVM_DLL Target hexagon(const std::vector& options = std::vector()); } // namespace target /*! @@ -273,12 +261,8 @@ class BuildConfig : public ::tvm::ObjectRef { public: BuildConfig() {} explicit BuildConfig(ObjectPtr n) : ObjectRef(n) {} - const BuildConfigNode* operator->() const { - return static_cast(get()); - } - BuildConfigNode* operator->() { - return static_cast(get_mutable()); - } + const BuildConfigNode* operator->() const { return static_cast(get()); } + BuildConfigNode* operator->() { return static_cast(get_mutable()); } /*! * \brief Construct a BuildConfig containing a empty build config node. * \return The new BuildConfig diff --git a/include/tvm/target/target_info.h b/include/tvm/target/target_info.h index 4466476..1de15a5 100644 --- a/include/tvm/target/target_info.h +++ b/include/tvm/target/target_info.h @@ -25,6 +25,7 @@ #define TVM_TARGET_TARGET_INFO_H_ #include + #include namespace tvm { diff --git a/include/tvm/te/autodiff.h b/include/tvm/te/autodiff.h index 180ec0b..e2d3799 100644 --- a/include/tvm/te/autodiff.h +++ b/include/tvm/te/autodiff.h @@ -27,6 +27,7 @@ #include #include + #include "tensor.h" namespace tvm { @@ -59,8 +60,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& input); * * Differentiate \p output wrt \p input and multiply the result by \p head on the left using tensor * dot product. \p input must be an immediate dependency of \p output (must be called from within - * the body of \p output). That is, the function will compute one summand of the adjoint for \p input - * given the adjoint for \p output (which is called \p head here). + * the body of \p output). That is, the function will compute one summand of the adjoint for \p + * input given the adjoint for \p output (which is called \p head here). * * \param output The tensor to differentiate. * \param input The input tensor, which \p output should directly use. @@ -68,7 +69,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input); * \return The tensor of shape `prefix + input.shape` * representing the partial adjoint of \p input wrt one of its consumers (output) */ -Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head); +Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head); /*! * \brief Perform reverse mode automatic differentiation. @@ -82,14 +83,12 @@ Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Te * wrt all tensors the output depends on. * \param head The adjoint of the output, in other words, some tensor, by which the Jacobians * will be multiplied (using tensordot axes=`output.shape`). - * Its shape must be of the form `prefix + output.shape`. If the null pointer is provided, - * the identity tensor of shape `output.shape + output.shape` will be used. - * \return An array of adjoints corresponding to \p inputs. + * Its shape must be of the form `prefix + output.shape`. If the null pointer is + * provided, the identity tensor of shape `output.shape + output.shape` will be used. \return An + * array of adjoints corresponding to \p inputs. */ -TVM_DLL Array Gradient( - const Tensor& output, - const Array& inputs, - const Tensor& head = Tensor()); +TVM_DLL Array Gradient(const Tensor& output, const Array& inputs, + const Tensor& head = Tensor()); } // namespace te } // namespace tvm diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 2055899..739ea85 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -25,16 +25,15 @@ #define TVM_TE_OPERATION_H_ #include -#include #include - +#include +#include #include #include -#include #include -#include #include +#include namespace tvm { /*! \brief Tensor expression language DSL. */ @@ -46,8 +45,7 @@ namespace te { */ struct TensorDom { // constructor - explicit TensorDom(int ndim) - : data(ndim) {} + explicit TensorDom(int ndim) : data(ndim) {} /*! \brief The domain data */ std::vector > data; }; @@ -64,9 +62,7 @@ class OperationNode : public tir::FunctionBaseNode { /*! \brief additional attributes of the operation*/ Map attrs; /*! \return name of the operation */ - const std::string& func_name() const final { - return name; - } + const std::string& func_name() const final { return name; } /*! * \return The list of iteration variable at root * \note root_iter_vars decides the shape of the outputs. @@ -96,9 +92,8 @@ class OperationNode : public tir::FunctionBaseNode { * \param rmap The replacement map. * \return self if nothing is replaced, otherwise return replaced op. */ - virtual Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const = 0; + virtual Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const = 0; /*! * \brief Propagate the bounds to inputs * \param self The reference to self. @@ -108,11 +103,9 @@ class OperationNode : public tir::FunctionBaseNode { * The function is only asked to fill the bounds for Tensors that * is already in the out_dom_map */ - virtual void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const = 0; + virtual void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const = 0; /*! * \brief Gather the bound from output tensor. * Set the range of each root_iter_vars in the op to out_dom_map @@ -121,10 +114,9 @@ class OperationNode : public tir::FunctionBaseNode { * \param tensor_dom Domain map of Tensor->access set of each dimension. * \param out_dom_map The output domain map of each IterVar to be setted. */ - virtual void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const = 0; + virtual void GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const = 0; /*! * \brief Build the Realize statement that realizes * the op's output tensors. @@ -133,10 +125,9 @@ class OperationNode : public tir::FunctionBaseNode { * \param body The body that is going to get * \return A realization statement that wraps body. */ - virtual Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const = 0; + virtual Stmt BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const = 0; /*! * \brief Build the statement that provide the output tensors. * \param stage The schedule stage of the op. @@ -144,10 +135,8 @@ class OperationNode : public tir::FunctionBaseNode { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return A statement that add production and wraps consumer. */ - virtual Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const = 0; + virtual Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const = 0; static constexpr const char* _type_key = "Operation"; @@ -169,26 +158,17 @@ class PlaceholderOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -197,9 +177,7 @@ class PlaceholderOpNode : public OperationNode { v->Visit("shape", &shape); v->Visit("dtype", &dtype); } - static Operation make(std::string name, - Array shape, - DataType dtype); + static Operation make(std::string name, Array shape, DataType dtype); static constexpr const char* _type_key = "PlaceholderOp"; TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); @@ -219,21 +197,16 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { // override functions Array root_iter_vars() const final; Array output_shape(size_t idx) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; virtual size_t num_schedulable_dims() const = 0; static constexpr const char* _type_key = "BaseComputeOp"; TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode); }; - /*! * \brief A Compute op that compute a tensor on certain domain. */ @@ -247,18 +220,13 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { int num_outputs() const final; DataType output_dtype(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; size_t num_schedulable_dims() const final; void VisitAttrs(AttrVisitor* v) { @@ -269,11 +237,8 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { v->Visit("reduce_axis", &reduce_axis); v->Visit("body", &body); } - static Operation make(std::string name, - std::string tag, - Map attrs, - Array axis, - Array body); + static Operation make(std::string name, std::string tag, Map attrs, + Array axis, Array body); static constexpr const char* _type_key = "ComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); @@ -300,18 +265,13 @@ class TensorComputeOpNode : public BaseComputeOpNode { int num_outputs() const final; DataType output_dtype(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; size_t num_schedulable_dims() const final; void VisitAttrs(AttrVisitor* v) { @@ -325,14 +285,9 @@ class TensorComputeOpNode : public BaseComputeOpNode { v->Visit("input_regions", &input_regions); v->Visit("scalar_inputs", &scalar_inputs); } - static Operation make(std::string name, - std::string tag, - Array axis, - Array reduce_axis, - int schedulable_ndim, - TensorIntrin intrin, - Array tensors, - Array regions, + static Operation make(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, + Array tensors, Array regions, Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; @@ -375,26 +330,17 @@ class ScanOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -407,14 +353,9 @@ class ScanOpNode : public OperationNode { v->Visit("inputs", &inputs); v->Visit("spatial_axis_", &spatial_axis_); } - static Operation make(std::string name, - std::string tag, - Map attrs, - IterVar axis, - Array init, - Array update, - Array state_placeholder, - Array input); + static Operation make(std::string name, std::string tag, Map attrs, + IterVar axis, Array init, Array update, + Array state_placeholder, Array input); static constexpr const char* _type_key = "ScanOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); @@ -442,26 +383,17 @@ class ExternOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -472,13 +404,10 @@ class ExternOpNode : public OperationNode { v->Visit("output_placeholders", &output_placeholders); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array input_placeholders, - Array output_placeholders, - Stmt body); + TVM_DLL static Operation make(std::string name, std::string tag, + Map attrs, Array inputs, + Array input_placeholders, Array output_placeholders, + Stmt body); static constexpr const char* _type_key = "ExternOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); @@ -510,26 +439,17 @@ class HybridOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -540,12 +460,9 @@ class HybridOpNode : public OperationNode { v->Visit("axis", &axis); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array outputs, - Stmt body); + TVM_DLL static Operation make(std::string name, std::string tag, + Map attrs, Array inputs, + Array outputs, Stmt body); static constexpr const char* _type_key = "HybridOp"; TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); @@ -575,10 +492,10 @@ TVM_DLL IterVar thread_axis(Range dom, std::string tag); TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); /*! \brief The compute function to specify the input source of a Tensor */ -using FCompute = std::function& i)>; +using FCompute = std::function& i)>; /*! \brief The compute function to specify the inputs source of Tensors */ -using FBatchCompute = std::function (const Array& i)>; +using FBatchCompute = std::function(const Array& i)>; /*! * \brief create a place holder tensor. @@ -586,8 +503,7 @@ using FBatchCompute = std::function (const Array& i)>; * \param dtype the data type of the tensor. * \param name The name of the Tensor. */ -TVM_DLL Tensor placeholder(Array shape, - DataType dtype = DataType::Float(32), +TVM_DLL Tensor placeholder(Array shape, DataType dtype = DataType::Float(32), std::string name = "placeholder"); /*! @@ -599,11 +515,8 @@ TVM_DLL Tensor placeholder(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Tensor compute(Array shape, - FCompute fcompute, - std::string name = "tensor", - std::string tag = "", - Map attrs = {}); +TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor", + std::string tag = "", Map attrs = {}); /*! * \brief Construct a new tensor by computing over shape, @@ -614,10 +527,8 @@ TVM_DLL Tensor compute(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Array compute(Array shape, - FBatchCompute fcompute, - std::string name = "tensor", - std::string tag = "", +TVM_DLL Array compute(Array shape, FBatchCompute fcompute, + std::string name = "tensor", std::string tag = "", Map attrs = {}); /*! @@ -632,45 +543,34 @@ TVM_DLL Array compute(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Array scan(Array init, - Array update, - Array state_placeholder, - Array inputs = Array(), - std::string name = "scan", - std::string tag = "", +TVM_DLL Array scan(Array init, Array update, + Array state_placeholder, Array inputs = Array(), + std::string name = "scan", std::string tag = "", Map attrs = {}); // same as compute, specialized for different fcompute function -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0]); }; + FCompute fc = [f](const Array& i) { return f(i[0]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0], i[1]); }; + FCompute fc = [f](const Array& i) { return f(i[0], i[1]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2]); }; - return compute(shape, fc, name, tag, attrs); + FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2]); }; + return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2], i[3]); }; + FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2], i[3]); }; return compute(shape, fc, name, tag, attrs); } diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index a8a0236..3667e1e 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -25,10 +25,10 @@ #ifndef TVM_TE_SCHEDULE_H_ #define TVM_TE_SCHEDULE_H_ -#include +#include #include #include -#include +#include #include #include @@ -84,12 +84,12 @@ class Stage : public ObjectRef { * \param scope The iteration point to carry the schedule. * \return reference to self. */ - TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*) + TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*) /*! * \brief Compute the function inline. * \return reference to self. */ - TVM_DLL Stage& compute_inline(); // NOLINT(*) + TVM_DLL Stage& compute_inline(); // NOLINT(*) /*! * \brief Compute the function at group root. * \return reference to self. @@ -131,7 +131,8 @@ class Stage : public ObjectRef { * \param p_inner The result inner domain. * \return reference to self. */ - TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, + IterVar* p_inner); // NOLINT(*) /*! * \brief Split the iteration with given number of parts. * @@ -141,7 +142,8 @@ class Stage : public ObjectRef { * \param p_inner The result inner domain. * \return reference to self. */ - TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, + IterVar* p_inner); // NOLINT(*) /*! * \brief Fuse the inner outer domain to the target * \param outer The outer domain to be fused. @@ -169,7 +171,7 @@ class Stage : public ObjectRef { * \param order The order of iteration variable. * \return reference to self. */ - TVM_DLL Stage& reorder(const Array& order); // NOLINT(*) + TVM_DLL Stage& reorder(const Array& order); // NOLINT(*) /*! * \brief Perform tiling on two dimensions * The final loop order from outmost to inner most are @@ -185,16 +187,15 @@ class Stage : public ObjectRef { * \param p_y_inner Inner axis of y dimension * \return reference to self. */ - TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*) - PrimExpr x_factor, PrimExpr y_factor, - IterVar* p_x_outer, IterVar* p_y_outer, - IterVar* p_x_inner, IterVar* p_y_inner); + TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*) + PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer, + IterVar* p_x_inner, IterVar* p_y_inner); /*! * \brief Vectorize iteration. * \param var The axis to be vectorized. * \return reference to self. */ - TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*) + TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*) /*! * \brief Replace computation of the current stage by tensor intrinsic f. * \param var The axis marks beginning of tensorization. @@ -202,19 +203,19 @@ class Stage : public ObjectRef { * \param f The Tensor compute intrinsics. * \return reference to self. */ - TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*) + TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*) /*! * \brief Unroll iteration. * \param var The axis to be unrolled. * \return reference to self. */ - TVM_DLL Stage& unroll(IterVar var); // NOLINT(*) + TVM_DLL Stage& unroll(IterVar var); // NOLINT(*) /*! * \brief Parallelize iteration. * \param var The axis to be parallelized. * \return reference to self. */ - TVM_DLL Stage& parallel(IterVar var); // NOLINT(*) + TVM_DLL Stage& parallel(IterVar var); // NOLINT(*) /*! * \brief Annotate the iteration with pragma * @@ -224,9 +225,8 @@ class Stage : public ObjectRef { * * \return reference to self. */ - TVM_DLL Stage& pragma(IterVar var, - const std::string& pragma_type, - const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*) + TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type, + const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*) /*! * \brief Fetch data in advance. * \param domain the tensor to be prefetched @@ -234,7 +234,7 @@ class Stage : public ObjectRef { * \param offset the number of iterations be to fetched in advance * \return reference to self */ - TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*) + TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset); // NOLINT(*) /*! * \brief Set alignment requirement for specific dimension. * @@ -245,17 +245,17 @@ class Stage : public ObjectRef { * \param offset The required offset factor. * \return reference to self */ - TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*) + TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); // NOLINT(*) /*! * \brief Compute current stage with double buffering. * \return reference to self. */ - TVM_DLL Stage& double_buffer(); // NOLINT(*) + TVM_DLL Stage& double_buffer(); // NOLINT(*) /*! * \brief Schedule for OpenGL fragment shader. * \return reference to self. */ - Stage& opengl(); // NOLINT(*) + Stage& opengl(); // NOLINT(*) /*! * \brief whether the stage has been scheduled. * \return whether the stage has been scheduled. @@ -297,9 +297,7 @@ class Schedule : public ObjectRef { * \param tensor The tensor * \return The stage corresponding to the tensor's op */ - TVM_DLL Stage operator[](const Tensor& tensor) { - return this->operator[](tensor->op); - } + TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); } /*! * \brief Create a new stage group for all intermediate * operations between inputs and outputs. @@ -309,9 +307,8 @@ class Schedule : public ObjectRef { * \param include_inputs Whether include inputs if they are reachable from outputs. * \return The new grouped stage. */ - TVM_DLL Stage create_group(const Array& outputs, - const Array& inputs, - bool include_inputs = false); + TVM_DLL Stage create_group(const Array& outputs, const Array& inputs, + bool include_inputs = false); /*! * \brief create a cache read of original tensor for readers. * This will mutate the body of the readers. @@ -321,9 +318,8 @@ class Schedule : public ObjectRef { * \param readers The readers to redirect to the tensor. * \return The created tensor. */ - TVM_DLL Tensor cache_read(const Tensor& tensor, - const std::string& scope, - const Array& readers); + TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope, + const Array& readers); /*! * \brief Create a cache write tensor for producing tensor. * The the tensor will take over body of original tensor op. @@ -371,9 +367,7 @@ class Schedule : public ObjectRef { * \param factor_axis The position where the new axis is placed. * \return The created factored tensors. */ - TVM_DLL Array rfactor(const Tensor& tensor, - const IterVar& axis, - int factor_axis = 0); + TVM_DLL Array rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0); /*! * \brief Normalize the schedule. * This is needed before bound inference. @@ -565,9 +559,7 @@ class ScheduleNode : public Object { * \param tensor The candidate tensor. * \return true if the schedule has the tensor. Otherwise, false. */ - TVM_DLL bool Contain(const Tensor& tensor) const { - return Contain(tensor->op); - } + TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); } /*! * \brief Create a schedule for array of ops(and their dependencies). @@ -585,9 +577,7 @@ class ScheduleNode : public Object { * \param ops The ops to be scheduled. * \return sch The created Schedule. */ -inline Schedule create_schedule(Array ops) { - return ScheduleNode::make(ops); -} +inline Schedule create_schedule(Array ops) { return ScheduleNode::make(ops); } /*! \brief node container for IterVar attr */ class IterVarAttrNode : public Object { @@ -666,10 +656,7 @@ class SplitNode : public IterVarRelationNode { v->Visit("nparts", &nparts); } - static IterVarRelation make(IterVar parent, - IterVar outer, - IterVar inner, - PrimExpr factor, + static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts); static constexpr const char* _type_key = "Split"; @@ -694,8 +681,7 @@ class FuseNode : public IterVarRelationNode { v->Visit("fused", &fused); } - static IterVarRelation make( - IterVar outer, IterVar inner, IterVar fused); + static IterVarRelation make(IterVar outer, IterVar inner, IterVar fused); static constexpr const char* _type_key = "Fuse"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode); @@ -724,7 +710,6 @@ class RebaseNode : public IterVarRelationNode { TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode); }; - /*! * \brief Singleton iterator [0, 1) */ @@ -733,9 +718,7 @@ class SingletonNode : public IterVarRelationNode { /*! \brief The singleton iterator */ IterVar iter; - void VisitAttrs(AttrVisitor* v) { - v->Visit("iter", &iter); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); } static IterVarRelation make(IterVar iter); @@ -753,9 +736,7 @@ class SpecializedConditionNode : public Object { */ Array clauses; - void VisitAttrs(AttrVisitor* v) { - v->Visit("clauses", &clauses); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("clauses", &clauses); } static constexpr const char* _type_key = "SpecializedCondition"; TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object); @@ -792,19 +773,13 @@ class SpecializedCondition : public ObjectRef { }; // implementations -inline const StageNode* Stage::operator->() const { - return static_cast(get()); -} -inline StageNode* Stage::operator->() { - return static_cast(get_mutable()); -} +inline const StageNode* Stage::operator->() const { return static_cast(get()); } +inline StageNode* Stage::operator->() { return static_cast(get_mutable()); } inline const ScheduleNode* Schedule::operator->() const { return static_cast(get()); } -inline ScheduleNode* Schedule::operator->() { - return static_cast(get_mutable()); -} +inline ScheduleNode* Schedule::operator->() { return static_cast(get_mutable()); } inline const IterVarRelationNode* IterVarRelation::operator->() const { return static_cast(get()); diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index 618fc22..a4efa7a 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -90,10 +90,8 @@ Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivia * buffer assignment of input and outputs. * \return Transformed stmt. */ -Stmt SchedulePostProcRewriteForTensorCore( - Stmt stmt, - Schedule schedule, - Map extern_buffer); +Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, + Map extern_buffer); /*! * \brief Postprocessing the Stmt generated by ScheduleOps to create @@ -111,8 +109,7 @@ Stmt SchedulePostProcRewriteForTensorCore( * \param body The body of the function. * \param bindings potential Tensor to Buffer bindings for the Tensors in the body. */ -PrimFunc SchedulePostProcToPrimFunc(Array arg_list, - Stmt body, +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, Optional> bindings); } // namespace te diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index c247dca..f82df8c 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -24,15 +24,15 @@ #ifndef TVM_TE_TENSOR_H_ #define TVM_TE_TENSOR_H_ -#include #include +#include #include #include #include -#include -#include #include +#include +#include namespace tvm { namespace te { @@ -78,8 +78,8 @@ class Tensor : public ObjectRef { * \param args The indices * \return the result expression representing tensor read. */ - template - inline PrimExpr operator()(Args&& ...args) const { + template + inline PrimExpr operator()(Args&&... args) const { Array indices{std::forward(args)...}; return operator()(indices); } @@ -119,9 +119,7 @@ class Tensor : public ObjectRef { * This is only valid when all the coordinates are fully specified. * \return the corresponding expression of this slice. */ - inline operator PrimExpr() const { - return tensor_(indices_); - } + inline operator PrimExpr() const { return tensor_(indices_); } private: const Tensor& tensor_; @@ -132,9 +130,7 @@ class Tensor : public ObjectRef { * \param i the index of the coordinate * \return the subsequent slice. */ - inline Slice operator[](PrimExpr i) const { - return Slice(*this, {i}); - } + inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); } /*! \brief specify container node */ using ContainerType = TensorNode; }; @@ -180,57 +176,46 @@ class TensorNode : public Object { v->Visit("op", &op); v->Visit("value_index", &value_index); } - TVM_DLL static Tensor make(Array shape, - DataType dtype, - Operation op, - int value_index); + TVM_DLL static Tensor make(Array shape, DataType dtype, Operation op, int value_index); static constexpr const char* _type_key = "Tensor"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object); }; - // Implementations of inline functions inline const TensorNode* Tensor::operator->() const { return static_cast(get()); } -inline size_t Tensor::ndim() const { - return (*this)->shape.size(); -} +inline size_t Tensor::ndim() const { return (*this)->shape.size(); } inline bool Tensor::operator==(const Tensor& other) const { if (get() == other.get()) return true; if (get() == nullptr || other.get() == nullptr) return false; if ((*this)->op.defined() || other->op.defined()) { - return (*this)->op == other->op && - (*this)->value_index == other->value_index; + return (*this)->op == other->op && (*this)->value_index == other->value_index; } else { return false; } } -inline bool Tensor::operator!=(const Tensor& other) const { - return !(*this == other); -} +inline bool Tensor::operator!=(const Tensor& other) const { return !(*this == other); } // macro to turn every operation of slice to expression -#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ - inline PrimExpr operator Op (const Tensor::Slice& a) { \ - return Op a.operator PrimExpr() ; \ - } \ +#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ + inline PrimExpr operator Op(const Tensor::Slice& a) { return Op a.operator PrimExpr(); } -#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ - template \ - inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \ - return a.operator PrimExpr() Op b; \ - } \ - template \ - inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \ - return a Op b.operator PrimExpr(); \ - } \ - inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \ - return a.operator PrimExpr() Op b.operator PrimExpr(); \ +#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ + template \ + inline PrimExpr operator Op(const Tensor::Slice& a, const T& b) { \ + return a.operator PrimExpr() Op b; \ + } \ + template \ + inline PrimExpr operator Op(const T& a, const Tensor::Slice& b) { \ + return a Op b.operator PrimExpr(); \ + } \ + inline PrimExpr operator Op(const Tensor::Slice& a, const Tensor::Slice& b) { \ + return a.operator PrimExpr() Op b.operator PrimExpr(); \ } DEFINE_OVERLOAD_SLICE_UNARY_OP(!); @@ -254,8 +239,7 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) namespace std { template <> -struct hash<::tvm::te::Operation> : public ::tvm::ObjectHash { -}; +struct hash<::tvm::te::Operation> : public ::tvm::ObjectHash {}; template <> struct hash<::tvm::te::Tensor> { @@ -263,7 +247,7 @@ struct hash<::tvm::te::Tensor> { ::tvm::ObjectHash hasher; if (k.defined() && k->op.defined()) { return hasher(k->op); - } else{ + } else { return hasher(k); } } diff --git a/include/tvm/te/tensor_intrin.h b/include/tvm/te/tensor_intrin.h index c964d3e..252c5f5 100644 --- a/include/tvm/te/tensor_intrin.h +++ b/include/tvm/te/tensor_intrin.h @@ -100,14 +100,9 @@ class TensorIntrinNode : public Object { v->Visit("reduce_update", &reduce_update); } - TVM_DLL static TensorIntrin make(std::string name, - Operation op, - Array inputs, - Array buffers, - Array scalar_params, - Stmt body, - Stmt reduce_init, - Stmt reduce_update); + TVM_DLL static TensorIntrin make(std::string name, Operation op, Array inputs, + Array buffers, Array scalar_params, Stmt body, + Stmt reduce_init, Stmt reduce_update); static constexpr const char* _type_key = "TensorIntrin"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); @@ -144,7 +139,6 @@ class TensorIntrinCallNode : public Object { /*! \brief regions of input tensors */ Array regions; - /*! * \brief IterVar on each reduction axis, if the * intrin will use the reduce axis @@ -161,11 +155,8 @@ class TensorIntrinCallNode : public Object { v->Visit("reduce_axis", &reduce_axis); v->Visit("scalar_inputs", &scalar_inputs); } - static TensorIntrinCall make(TensorIntrin intrin, - Array tensors, - Array regions, - Array reduce_axis, - Array scalar_inputs); + static TensorIntrinCall make(TensorIntrin intrin, Array tensors, Array regions, + Array reduce_axis, Array scalar_inputs); static constexpr const char* _type_key = "TensorIntrinCall"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object); diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index f7a89f5..b0f409c 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -29,6 +29,7 @@ #include #include #include + #include namespace tvm { @@ -75,8 +76,7 @@ TVM_DLL bool HasSideEffect(const PrimExpr& expr); * \param vset_contains The check function to see if var is in the vset. * \return Whether e uses vset. */ -TVM_DLL bool ExprUseVar(const PrimExpr& expr, - std::function vset_contains); +TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function vset_contains); /*! * \brief Whether e expression used var. @@ -85,12 +85,9 @@ TVM_DLL bool ExprUseVar(const PrimExpr& expr, * \return Whether e uses v. */ inline bool ExprUseVar(const PrimExpr& expr, const Var& var) { - return ExprUseVar(expr, [&](const VarNode* node) { - return var.get() == node; - }); + return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; }); } - /*! * \brief Verifies whether the IR stmt or Expr is in SSA form. * That is: each Var is defined and assigned once(in Let/For) @@ -133,8 +130,7 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); * \return valid Whether it is a valid GPU code * */ -TVM_DLL bool VerifyGPUCode(const PrimFunc& func, - Map constraints); +TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); // Pass variants of verification analysis // directly throws RuntimeError when verification fails. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 08a8e69..5d4e860 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -24,11 +24,11 @@ #ifndef TVM_TIR_BUFFER_H_ #define TVM_TIR_BUFFER_H_ -#include #include +#include #include -#include +#include namespace tvm { namespace tir { @@ -76,8 +76,7 @@ class Buffer : public ObjectRef { * \param content_lanes The number of lanes for the (data) type. * \param offset The offset of ptr. */ - TVM_DLL PrimExpr access_ptr(int access_mask, - DataType ptr_type = DataType::Handle(), + TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0)) const; /*! @@ -155,15 +154,10 @@ class BufferNode : public Object { bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { // Use DefEqual as buffer can define variables // in its semantics, skip name as name is not important. - return - equal.DefEqual(data, other->data) && - equal(dtype, other->dtype) && - equal.DefEqual(shape, other->shape) && - equal.DefEqual(strides, other->strides) && - equal.DefEqual(elem_offset, other->elem_offset) && - equal(scope, other->scope) && - equal(data_alignment, other->data_alignment) && - equal(buffer_type, other->buffer_type); + return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && + equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && + equal.DefEqual(elem_offset, other->elem_offset) && equal(scope, other->scope) && + equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); } void SHashReduce(SHashReducer hash_reduce) const { @@ -184,15 +178,9 @@ class BufferNode : public Object { // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. - TVM_DLL static Buffer make(Var ptr, - DataType dtype, - Array shape, - Array strides, - PrimExpr elem_offset, - std::string name, - std::string scope, - int data_alignment, - int offset_factor, + TVM_DLL static Buffer make(Var ptr, DataType dtype, Array shape, + Array strides, PrimExpr elem_offset, std::string name, + std::string scope, int data_alignment, int offset_factor, BufferType buffer_type); static constexpr const char* _type_key = "Buffer"; @@ -213,8 +201,7 @@ inline const BufferNode* Buffer::operator->() const { * \return The created buffer. * \sa BufferNode::make for complete constructor. */ -TVM_DLL Buffer decl_buffer(Array shape, - DataType dtype = DataType::Float(32), +TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), std::string name = "buffer"); } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 4343370..0a20db6 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -25,16 +25,14 @@ #ifndef TVM_TIR_DATA_LAYOUT_H_ #define TVM_TIR_DATA_LAYOUT_H_ - #include #include -#include +#include #include -#include +#include #include -#include - +#include namespace tvm { namespace tir { @@ -63,18 +61,12 @@ class LayoutAxis { } // return the primal axis. If it is already primal, return itself. - const LayoutAxis& ToPrimal() const { - return IsPrimal() ? *this : ToDual(); - } + const LayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); } // return the subordinate axis. If it is already subordinate, return itself. - const LayoutAxis& ToSubordinate() const { - return IsPrimal() ? ToDual() : *this; - } + const LayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; } - inline bool operator==(const LayoutAxis& rhs) const { - return name_ == rhs.name_; - } + inline bool operator==(const LayoutAxis& rhs) const { return name_ == rhs.name_; } friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) { os << l.name(); @@ -136,7 +128,7 @@ class Layout : public ObjectRef { explicit Layout(const Array& axes); /*! \brief construct from a string */ - Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) + Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) /*! * \brief construct from a string. @@ -146,23 +138,19 @@ class Layout : public ObjectRef { * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ - Layout(const std::string& name); // NOLINT(*) + Layout(const std::string& name); // NOLINT(*) /*! * \brief access the internal node container * \return the pointer to the internal node container */ - const LayoutNode* operator->() const { - return static_cast(get()); - } + const LayoutNode* operator->() const { return static_cast(get()); } /*! * \brief access the internal node container * \return the pointer to the internal node container */ - LayoutNode* operator->() { - return static_cast(get_mutable()); - } + LayoutNode* operator->() { return static_cast(get_mutable()); } /*! * \brief Return an undefined layout. @@ -190,8 +178,7 @@ class Layout : public ObjectRef { * \param factor size of the sub-dimension. * \return A newly constructed Layout object. */ - Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const; - + Layout Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const; /*! \return number of dimensions */ inline size_t ndim() const { @@ -292,9 +279,7 @@ class Layout : public ObjectRef { * \param rhs Another layout. * \return whether the two layouts are equal. */ - inline bool Equals(const Layout &rhs) const { - return name() == rhs.name(); - } + inline bool Equals(const Layout& rhs) const { return name() == rhs.name(); } /*! * \brief allow output string of layout to ostream diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index afa9414..a9f34d2 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -25,20 +25,20 @@ #ifndef TVM_TIR_EXPR_H_ #define TVM_TIR_EXPR_H_ -#include +#include #include #include +#include #include #include -#include -#include #include +#include -#include #include -#include #include #include +#include +#include #include namespace tvm { @@ -62,9 +62,7 @@ class StringImmNode : public PrimExprNode { return equal(value, other->value); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } TVM_DLL PrimExpr static make(std::string value); @@ -110,7 +108,7 @@ class CastNode : public PrimExprNode { * \brief Base template to implement binary ops. * \tparam T The type of the child class. */ -template +template class BinaryOpNode : public PrimExprNode { public: /*! \brief The left operand. */ @@ -125,10 +123,7 @@ class BinaryOpNode : public PrimExprNode { } bool SEqualReduce(const T* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -215,7 +210,7 @@ class MaxNode : public BinaryOpNode { * \brief Base template to implement comparison ops. * \tparam T The type of the child class. */ -template +template class CmpOpNode : public PrimExprNode { public: /*! \brief The left operand. */ @@ -230,10 +225,7 @@ class CmpOpNode : public PrimExprNode { } bool SEqualReduce(const T* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -307,10 +299,7 @@ class AndNode : public PrimExprNode { } bool SEqualReduce(const AndNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -340,10 +329,7 @@ class OrNode : public PrimExprNode { } bool SEqualReduce(const OrNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -408,11 +394,8 @@ class SelectNode : public PrimExprNode { } bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(condition, other->condition) && - equal(true_value, other->true_value) && - equal(false_value, other->false_value); + return equal(dtype, other->dtype) && equal(condition, other->condition) && + equal(true_value, other->true_value) && equal(false_value, other->false_value); } void SHashReduce(SHashReducer hash_reduce) const { @@ -452,10 +435,8 @@ class BufferLoadNode : public PrimExprNode { } bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(buffer, other->buffer) && - equal(indices, other->indices); + return equal(dtype, other->dtype) && equal(buffer, other->buffer) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { @@ -470,8 +451,7 @@ class BufferLoadNode : public PrimExprNode { class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, - Array indices); + TVM_DLL explicit BufferLoad(Buffer buffer, Array indices); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); }; @@ -507,11 +487,8 @@ class LoadNode : public PrimExprNode { } bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(buffer_var, other->buffer_var) && - equal(index, other->index) && - equal(predicate, other->predicate); + return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) && + equal(index, other->index) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { @@ -553,11 +530,8 @@ class RampNode : public PrimExprNode { } bool SEqualReduce(const RampNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(base, other->base) && - equal(stride, other->stride) && - equal(lanes, other->lanes); + return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) && + equal(lanes, other->lanes); } void SHashReduce(SHashReducer hash_reduce) const { @@ -588,10 +562,7 @@ class BroadcastNode : public PrimExprNode { } bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(value, other->value) && - equal(lanes, other->lanes); + return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes); } void SHashReduce(SHashReducer hash_reduce) const { @@ -626,11 +597,8 @@ class LetNode : public PrimExprNode { } bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal.DefEqual(var, other->var) && - equal(value, other->value) && - equal(body, other->body); + return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) && + equal(value, other->value) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -668,8 +636,7 @@ class FunctionBaseNode : public Object { return this == other; } - void SHashReduce(SHashReducer hash_reduce) const { - } + void SHashReduce(SHashReducer hash_reduce) const {} static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -731,13 +698,9 @@ class CallNode : public PrimExprNode { } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(name, other->name) && - equal(args, other->args) && - equal(call_type, other->call_type) && - equal(func, other->func) && - equal(value_index, other->value_index); + return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) && + equal(call_type, other->call_type) && equal(func, other->func) && + equal(value_index, other->value_index); } void SHashReduce(SHashReducer hash_reduce) const { @@ -749,18 +712,13 @@ class CallNode : public PrimExprNode { hash_reduce(value_index); } - TVM_DLL static PrimExpr make(DataType dtype, - std::string name, - Array args, - CallType call_type, - FunctionRef func = FunctionRef(), + TVM_DLL static PrimExpr make(DataType dtype, std::string name, Array args, + CallType call_type, FunctionRef func = FunctionRef(), int value_index = 0); /*! \return Whether call node is pure. */ bool is_pure() const { - return (call_type == PureExtern || - call_type == PureIntrinsic || - call_type == Halide); + return (call_type == PureExtern || call_type == PureIntrinsic || call_type == Halide); } /*! @@ -768,10 +726,7 @@ class CallNode : public PrimExprNode { * \param intrin_name The name of the intrinsic. */ bool is_intrinsic(const char* intrin_name) const { - return - ((call_type == Intrinsic || - call_type == PureIntrinsic) && - name == intrin_name); + return ((call_type == Intrinsic || call_type == PureIntrinsic) && name == intrin_name); } /*! \return Whether call node can be vectorized. */ @@ -818,10 +773,8 @@ class ShuffleNode : public PrimExprNode { } bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(vectors, other->vectors) && - equal(indices, other->indices); + return equal(dtype, other->dtype) && equal(vectors, other->vectors) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { @@ -880,9 +833,7 @@ class CommReducerNode : public Object { /*! \brief Function call operator to combine a and b */ Array operator()(Array a, Array b) const; /*! \brief construct CommReducer from args, result and identity_element */ - TVM_DLL static CommReducer make(Array lhs, - Array rhs, - Array result, + TVM_DLL static CommReducer make(Array lhs, Array rhs, Array result, Array identity_element); void VisitAttrs(AttrVisitor* v) { @@ -893,11 +844,8 @@ class CommReducerNode : public Object { } bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const { - return - equal.DefEqual(lhs, other->lhs) && - equal.DefEqual(rhs, other->rhs) && - equal(result, other->result) && - equal(identity_element, other->identity_element); + return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) && + equal(result, other->result) && equal(identity_element, other->identity_element); } void SHashReduce(SHashReducer hash_reduce) const { @@ -916,9 +864,7 @@ class CommReducerNode : public Object { inline const CommReducerNode* CommReducer::get() const { return static_cast(data_.get()); } -inline const CommReducerNode* CommReducer::operator->() const { - return get(); -} +inline const CommReducerNode* CommReducer::operator->() const { return get(); } /*! \brief Reduction operator operator */ class ReduceNode : public PrimExprNode { @@ -938,11 +884,8 @@ class ReduceNode : public PrimExprNode { int value_index; /*! \brief construct expr from op and rdom */ - TVM_DLL static PrimExpr make(CommReducer combiner, - Array src, - Array rdom, - PrimExpr condition, - int value_index); + TVM_DLL static PrimExpr make(CommReducer combiner, Array src, Array rdom, + PrimExpr condition, int value_index); void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -955,13 +898,9 @@ class ReduceNode : public PrimExprNode { bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const { // check axis first so IterVars can define the necessary variables. - return - equal(dtype, other->dtype) && - equal(axis, other->axis) && - equal(combiner, other->combiner) && - equal(source, other->source) && - equal(condition, other->condition) && - equal(value_index, other->value_index); + return equal(dtype, other->dtype) && equal(axis, other->axis) && + equal(combiner, other->combiner) && equal(source, other->source) && + equal(condition, other->condition) && equal(value_index, other->value_index); } void SHashReduce(SHashReducer hash_reduce) const { @@ -982,17 +921,12 @@ class AnyNode : public PrimExprNode { public: void VisitAttrs(AttrVisitor* v) {} - bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { - return true; - } + bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { return true; } - void SHashReduce(SHashReducer hash_reduce) const { - } + void SHashReduce(SHashReducer hash_reduce) const {} /*! \brief Convert to var. */ - Var ToVar() const { - return Var("any_dim", DataType::Int(32)); - } + Var ToVar() const { return Var("any_dim", DataType::Int(32)); } TVM_DLL static PrimExpr make(); @@ -1000,7 +934,6 @@ class AnyNode : public PrimExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); }; - /* * \brief Template function to convert Map to unordered_map * Sometimes useful for API gluing when internal uses unordered_map @@ -1009,7 +942,7 @@ class AnyNode : public PrimExprNode { * \tparam K the key of the Map. * \tparam V the value of the Map. */ -template +template inline std::unordered_map as_unordered_map(const Map& dmap) { std::unordered_map ret; for (auto kv : dmap) { @@ -1176,7 +1109,7 @@ constexpr const char* tvm_call_packed = "tvm_call_packed"; * return 0; * } */ -constexpr const char *tvm_call_trace_packed = "tvm_call_trace_packed"; +constexpr const char* tvm_call_trace_packed = "tvm_call_trace_packed"; /*! * \brief See pesudo code * Mark the content as thread local context, can get optimized @@ -1223,8 +1156,7 @@ constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered"; * TVMRetValue(value_stack + end, tcode_stack + end)); * } */ -constexpr const char *tvm_call_trace_packed_lowered = - "tvm_call_trace_packed_lowered"; +constexpr const char* tvm_call_trace_packed_lowered = "tvm_call_trace_packed_lowered"; /*! * \brief See pseudo code * @@ -1368,7 +1300,7 @@ enum TVMStructFieldKind : int { kTVMValueContent, kTVMValueKindBound_ }; -} // namespace intrinsic +} // namespace intrinsic } // namespace tir } // namespace tvm @@ -1377,7 +1309,7 @@ namespace tvm { namespace runtime { // Additional implementattion overloads for PackedFunc. -template<> +template <> struct PackedFuncValueConverter { // common rule for RetValue and ArgValue static tvm::Integer From(const TVMPODValue_& val) { @@ -1395,7 +1327,6 @@ struct PackedFuncValueConverter { namespace std { template <> -struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectHash { -}; -} +struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectHash {}; +} // namespace std #endif // TVM_TIR_EXPR_H_ diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index dcf04c3..15ec3d2 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -71,22 +71,19 @@ namespace tir { * \tparam FType function signiture * This type if only defined for FType with function signiture R(const Expr&, Args...) */ -template +template class ExprFunctor; // functions to be overriden. -#define EXPR_FUNCTOR_DEFAULT { \ - return VisitExprDefault_(op, std::forward(args)...); \ - } +#define EXPR_FUNCTOR_DEFAULT \ + { return VisitExprDefault_(op, std::forward(args)...); } -#define IR_EXPR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ +#define IR_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ + }); -template +template class ExprFunctor { private: using TSelf = ExprFunctor; @@ -152,7 +149,7 @@ class ExprFunctor { virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExprDefault_(const Object* op, Args ...) { + virtual R VisitExprDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -205,8 +202,7 @@ class ExprFunctor { /*! * \brief ExprVisitor */ -class TVM_DLL ExprVisitor : - public ExprFunctor { +class TVM_DLL ExprVisitor : public ExprFunctor { public: using ExprFunctor::operator(); @@ -251,8 +247,7 @@ class TVM_DLL ExprVisitor : /*! * \brief ExprMutator that mutates expressions. */ -class TVM_DLL ExprMutator : - protected ExprFunctor { +class TVM_DLL ExprMutator : protected ExprFunctor { public: using ExprFunctor::operator(); diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 1866f2f..919391e 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -25,11 +25,11 @@ #define TVM_TIR_FUNCTION_H_ #include -#include #include +#include #include -#include +#include namespace tvm { namespace tir { @@ -104,12 +104,9 @@ class PrimFuncNode : public BaseFuncNode { bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { // visit params and buffer_map first as they contains defs. - return - equal.DefEqual(params, other->params) && - equal(buffer_map, other->buffer_map) && - equal(ret_type, other->ret_type) && - equal(body, other->body) && - equal(attrs, other->attrs); + return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && + equal(ret_type, other->ret_type) && equal(body, other->body) && + equal(attrs, other->attrs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -146,9 +143,7 @@ class PrimFunc : public BaseFunc { * \param buffer_map The buffer map for parameter buffer unpacking. * \param attrs Additional function attributes. */ - TVM_DLL PrimFunc(Array params, - Stmt body, - Type ret_type = VoidType(), + TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), Map buffer_map = NullValue>(), DictAttrs attrs = NullValue()); diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 3fbdca5..5884942 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -33,9 +33,8 @@ #include #include -#include #include - +#include namespace tvm { @@ -551,7 +550,7 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x) { \ return tir::CallNode::make(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \ - } \ + } TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(exp2); @@ -577,7 +576,6 @@ TVM_DECLARE_INTRIN_UNARY(acosh); TVM_DECLARE_INTRIN_UNARY(asinh); TVM_DECLARE_INTRIN_UNARY(atanh); - namespace tir { /*! * \brief Make a const value with certain data type. @@ -586,8 +584,8 @@ namespace tir { * \return the result expression. * \tparam ValueType The constant value type */ -template::value>::type> +template ::value>::type> inline PrimExpr make_const(DataType t, ValueType value); /*! * \brief Make a const zero expr. @@ -600,17 +598,13 @@ inline PrimExpr make_zero(DataType t); * \param lanes The number of lanes in the bool * \return The result expression. */ -inline PrimExpr const_true(int lanes = 1) { - return make_const(DataType::UInt(1, lanes), 1); -} +inline PrimExpr const_true(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 1); } /*! * \brief Make a constant false expression. * \param lanes The number of lanes in the bool * \return The result expression. */ -inline PrimExpr const_false(int lanes = 1) { - return make_const(DataType::UInt(1, lanes), 0); -} +inline PrimExpr const_false(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 0); } /*! * \brief Get x as constant int expression. * \param x The expression @@ -647,9 +641,7 @@ inline bool is_no_op(const tir::Stmt& stmt); * \note This only return true for integer types. * \return whether x is constant 1 */ -inline bool is_one(const PrimExpr& x) { - return is_const_int(x, 1); -} +inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); } /*! * \brief Check whether x is a constant integer 0 @@ -657,9 +649,7 @@ inline bool is_one(const PrimExpr& x) { * \return whether x is constant 0 * \note This only return true for integer types. */ -inline bool is_zero(const PrimExpr& x) { - return is_const_int(x, 0); -} +inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); } /*! * \brief Check whether x is a constant. @@ -730,7 +720,7 @@ inline bool is_no_op(const tir::Stmt& stmt) { return false; } -template +template inline PrimExpr MakeConstScalar(DataType t, ValueType value) { if (t.is_int()) return IntImm(t, static_cast(value)); if (t.is_uint()) { @@ -757,13 +747,12 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { return PrimExpr(); } -template +template inline PrimExpr make_const(DataType t, ValueType value) { if (t.lanes() == 1) { return MakeConstScalar(t, value); } else { - return tir::BroadcastNode::make( - MakeConstScalar(t.element_of(), value), t.lanes()); + return tir::BroadcastNode::make(MakeConstScalar(t.element_of(), value), t.lanes()); } } @@ -776,44 +765,34 @@ inline PrimExpr make_zero(DataType t) { } // namespace tir // additional const expression overloading -#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ - inline PrimExpr Name(PrimExpr& a, PrimExpr b) {\ - a = OpFunc(a, b); \ - return a; \ +#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ + inline PrimExpr Name(PrimExpr& a, PrimExpr b) { \ + a = OpFunc(a, b); \ + return a; \ } -#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, float b) { \ - return Name(a, PrimExpr(b)); \ - } \ - inline PrimExpr Name(float a, const PrimExpr& b) { \ - return Name(PrimExpr(a), b); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tir::make_const(b.dtype(), a), b); \ - } \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tir::make_const(a.dtype(), b)); \ - } \ - inline PrimExpr Name(const PrimExpr& a, double b) { \ - return Name(a, tir::make_const(DataType::Float(64), b)); \ +#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \ + inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \ + inline PrimExpr Name(int a, const PrimExpr& b) { \ + return Name(tir::make_const(b.dtype(), a), b); \ + } \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ + return Name(a, tir::make_const(a.dtype(), b)); \ + } \ + inline PrimExpr Name(const PrimExpr& a, double b) { \ + return Name(a, tir::make_const(DataType::Float(64), b)); \ } -#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, bool b) { \ - return Name(a, PrimExpr(b)); \ - } \ - inline PrimExpr Name(bool a, const PrimExpr& b) { \ - return Name(PrimExpr(a), b); \ - } +#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); } \ + inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); } -#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tir::make_const(a.dtype(), b)); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tir::make_const(b.dtype(), a), b); \ - } +#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ + return Name(a, tir::make_const(a.dtype(), b)); \ + } \ + inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tir::make_const(b.dtype(), a), b); } TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-); @@ -835,8 +814,8 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod); -TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) -TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^); @@ -849,7 +828,7 @@ TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||); * \note The call to this function will always results in a compiler error. * \tparam TA Any class type. */ -template +template inline void DivAmbiguityError(const TA& a) { constexpr bool div_ambiguity = !std::is_class::value; static_assert(div_ambiguity, @@ -865,19 +844,19 @@ inline void DivAmbiguityError(const TA& a) { // to use the specific division function. // The second template argument is necessary to make sure the // code compiles lazily by the compiler during invocation. -template +template inline PrimExpr operator/(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; } -template +template inline PrimExpr operator/=(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; } -template +template inline PrimExpr operator%(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0d3cf42..115d05c 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -26,10 +26,10 @@ #include -#include #include -#include +#include #include +#include namespace tvm { namespace tir { @@ -69,10 +69,8 @@ class LetStmtNode : public StmtNode { } bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const { - return - equal.DefEqual(var, other->var) && - equal(value, other->value) && - equal(body, other->body); + return equal.DefEqual(var, other->var) && equal(value, other->value) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -116,11 +114,8 @@ class AttrStmtNode : public StmtNode { } bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const { - return - equal(node, other->node) && - equal(attr_key, other->attr_key) && - equal(value, other->value) && - equal(body, other->body); + return equal(node, other->node) && equal(attr_key, other->attr_key) && + equal(value, other->value) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -130,10 +125,7 @@ class AttrStmtNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(ObjectRef node, - std::string type_key, - PrimExpr value, - Stmt body); + TVM_DLL static Stmt make(ObjectRef node, std::string type_key, PrimExpr value, Stmt body); static constexpr const char* _type_key = "AttrStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); @@ -161,10 +153,8 @@ class AssertStmtNode : public StmtNode { } bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const { - return - equal(condition, other->condition) && - equal(message, other->message) && - equal(body, other->body); + return equal(condition, other->condition) && equal(message, other->message) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -216,11 +206,8 @@ class StoreNode : public StmtNode { } bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const { - return - equal(buffer_var, other->buffer_var) && - equal(value, other->value) && - equal(index, other->index) && - equal(predicate, other->predicate); + return equal(buffer_var, other->buffer_var) && equal(value, other->value) && + equal(index, other->index) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { @@ -230,10 +217,7 @@ class StoreNode : public StmtNode { hash_reduce(predicate); } - TVM_DLL static Stmt make(Var buffer_var, - PrimExpr value, - PrimExpr index, - PrimExpr predicate); + TVM_DLL static Stmt make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate); static constexpr const char* _type_key = "Store"; TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); @@ -265,10 +249,8 @@ class BufferStoreNode : public StmtNode { } bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const { - return - equal(buffer, other->buffer) && - equal(value, other->value) && - equal(indices, other->indices); + return equal(buffer, other->buffer) && equal(value, other->value) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { @@ -287,9 +269,7 @@ class BufferStoreNode : public StmtNode { */ class BufferStore : public Stmt { public: - TVM_DLL explicit BufferStore(Buffer buffer, - PrimExpr value, - Array indices); + TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); }; @@ -323,11 +303,8 @@ class BufferRealizeNode : public StmtNode { } bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const { - return - equal(buffer, other->buffer) && - equal(bounds, other->bounds) && - equal(condition, other->condition) && - equal(body, other->body); + return equal(buffer, other->buffer) && equal(bounds, other->bounds) && + equal(condition, other->condition) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -338,12 +315,8 @@ class BufferRealizeNode : public StmtNode { } BufferRealizeNode() = default; - BufferRealizeNode(Buffer buffer, - Array bounds, - PrimExpr condition, - Stmt body) - : buffer(buffer), bounds(bounds), - condition(condition), body(body) {} + BufferRealizeNode(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) + : buffer(buffer), bounds(bounds), condition(condition), body(body) {} static constexpr const char* _type_key = "BufferRealize"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode); @@ -355,10 +328,7 @@ class BufferRealizeNode : public StmtNode { */ class BufferRealize : public Stmt { public: - TVM_DLL explicit BufferRealize(Buffer buffer, - Array bounds, - PrimExpr condition, - Stmt body); + TVM_DLL explicit BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); }; @@ -387,11 +357,8 @@ class ProvideNode : public StmtNode { } bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(value_index, other->value_index) && - equal(value, other->value) && - equal(args, other->args); + return equal(func, other->func) && equal(value_index, other->value_index) && + equal(value, other->value) && equal(args, other->args); } void SHashReduce(SHashReducer hash_reduce) const { @@ -401,10 +368,7 @@ class ProvideNode : public StmtNode { hash_reduce(args); } - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - PrimExpr value, - Array args); + TVM_DLL static Stmt make(FunctionRef func, int value_index, PrimExpr value, Array args); static constexpr const char* _type_key = "Provide"; TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode); @@ -435,12 +399,9 @@ class AllocateNode : public StmtNode { } bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const { - return - equal.DefEqual(buffer_var, other->buffer_var) && - equal(dtype, other->dtype) && - equal(extents, other->extents) && - equal(condition, other->condition) && - equal(body, other->body); + return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && + equal(extents, other->extents) && equal(condition, other->condition) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -451,28 +412,22 @@ class AllocateNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(Var buffer_var, - DataType dtype, - Array extents, - PrimExpr condition, - Stmt body); + TVM_DLL static Stmt make(Var buffer_var, DataType dtype, Array extents, + PrimExpr condition, Stmt body); /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. * \return The result. */ - int32_t constant_allocation_size() const { - return constant_allocation_size(extents); - } + int32_t constant_allocation_size() const { return constant_allocation_size(extents); } /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. * \param extents The extents of the buffer. * \return The result. */ - TVM_DLL static int32_t constant_allocation_size( - const Array& extents); + TVM_DLL static int32_t constant_allocation_size(const Array& extents); static constexpr const char* _type_key = "Allocate"; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); @@ -484,18 +439,13 @@ class FreeNode : public StmtNode { /*! \brief The buffer variable. */ Var buffer_var; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer_var", &buffer_var); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); } bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const { - return - equal(buffer_var, other->buffer_var); + return equal(buffer_var, other->buffer_var); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer_var); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); } TVM_DLL static Stmt make(Var buffer_var); @@ -533,21 +483,13 @@ class RealizeNode : public StmtNode { v->Visit("body", &body); } - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds, - PrimExpr condition, - Stmt body); + TVM_DLL static Stmt make(FunctionRef func, int value_index, DataType dtype, Region bounds, + PrimExpr condition, Stmt body); bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(value_index, other->value_index) && - equal(dtype, other->dtype) && - equal(bounds, other->bounds) && - equal(condition, other->condition) && - equal(body, other->body); + return equal(func, other->func) && equal(value_index, other->value_index) && + equal(dtype, other->dtype) && equal(bounds, other->bounds) && + equal(condition, other->condition) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -573,27 +515,19 @@ class SeqStmtNode : public StmtNode { Array seq; /*! \return get the size of the sequence */ - size_t size() const { - return seq.size(); - } + size_t size() const { return seq.size(); } /*! * \brief Get the index-th element in the sequence. */ - Stmt operator[](size_t index) const { - return seq[index]; - } + Stmt operator[](size_t index) const { return seq[index]; } - void VisitAttrs(AttrVisitor* v) { - v->Visit("seq", &seq); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("seq", &seq); } bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const { return equal(seq, other->seq); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(seq); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); } static constexpr const char* _type_key = "SeqStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); @@ -609,15 +543,11 @@ class SeqStmt : public Stmt { TVM_DLL explicit SeqStmt(Array seq); /*! \return get the size of the sequence */ - size_t size() const { - return operator->()->size(); - } + size_t size() const { return operator->()->size(); } /*! * \brief Get the index-th element in the sequence. */ - Stmt operator[](size_t index) const { - return (*(operator->()))[index]; - } + Stmt operator[](size_t index) const { return (*(operator->()))[index]; } /*! * \brief Construct a sequence statement by flattening * all the arrays and sequences in the arguments @@ -634,19 +564,17 @@ class SeqStmt : public Stmt { * \tparam Args arguments * \return The constructed statement */ - template + template static Stmt Flatten(Args&&... seq_args) { Array seq; - runtime::detail::for_each( - Flattener(&seq), std::forward(seq_args)...); + runtime::detail::for_each(Flattener(&seq), std::forward(seq_args)...); if (seq.size() == 1) return seq[0]; return SeqStmt(seq); } /*! \brief Helper class to flatten sequence of arguments into Array. */ class Flattener { public: - explicit Flattener(Array* seq) - : seq_(seq) {} + explicit Flattener(Array* seq) : seq_(seq) {} void operator()(size_t i, const Stmt& stmt) const { if (!stmt.defined()) return; @@ -657,7 +585,7 @@ class SeqStmt : public Stmt { } } - template + template void operator()(size_t i, const T& seq) const { for (auto v : seq) { this->operator()(0, v); @@ -690,10 +618,8 @@ class IfThenElseNode : public StmtNode { } bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const { - return - equal(condition, other->condition) && - equal(then_case, other->then_case) && - equal(else_case, other->else_case); + return equal(condition, other->condition) && equal(then_case, other->then_case) && + equal(else_case, other->else_case); } void SHashReduce(SHashReducer hash_reduce) const { @@ -719,17 +645,13 @@ class EvaluateNode : public StmtNode { /*! \brief The expression to be evaluated. */ PrimExpr value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("value", &value); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); } bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const { return equal(value, other->value); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } TVM_DLL static Stmt make(PrimExpr v); @@ -752,9 +674,7 @@ enum class ForType : int { // Kevice api of for loop // kept for backward compatibility // consider refactor and remove later. -enum class DeviceAPI: int { - None = 0 -}; +enum class DeviceAPI : int { None = 0 }; /*! * \brief A for loop, with poissible type annotations. @@ -784,12 +704,8 @@ class ForNode : public StmtNode { /*! \brief The body of the for loop. */ Stmt body; - TVM_DLL static Stmt make(Var loop_var, - PrimExpr min, - PrimExpr extent, - ForType for_type, - DeviceAPI device_api, - Stmt body); + TVM_DLL static Stmt make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, + DeviceAPI device_api, Stmt body); void VisitAttrs(AttrVisitor* v) { v->Visit("loop_var", &loop_var); @@ -801,13 +717,9 @@ class ForNode : public StmtNode { } bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { - return - equal.DefEqual(loop_var, other->loop_var) && - equal(min, other->min) && - equal(extent, other->extent) && - equal(for_type, other->for_type) && - equal(device_api, other->device_api) && - equal(body, other->body); + return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) && + equal(extent, other->extent) && equal(for_type, other->for_type) && + equal(device_api, other->device_api) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -819,7 +731,6 @@ class ForNode : public StmtNode { hash_reduce(body); } - static constexpr const char* _type_key = "For"; TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); }; @@ -840,9 +751,7 @@ class PrefetchNode : public StmtNode { } bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { - return - equal(buffer, other->buffer) && - equal(bounds, other->bounds); + return equal(buffer, other->buffer) && equal(bounds, other->bounds); } void SHashReduce(SHashReducer hash_reduce) const { @@ -851,8 +760,7 @@ class PrefetchNode : public StmtNode { } PrefetchNode() = default; - PrefetchNode(Buffer buffer, Array bounds) - : buffer(buffer), bounds(bounds) {} + PrefetchNode(Buffer buffer, Array bounds) : buffer(buffer), bounds(bounds) {} static constexpr const char* _type_key = "Prefetch"; TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); @@ -1024,9 +932,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { * \return Expr a expression with dtype. */ inline PrimExpr TypeAnnotation(DataType dtype) { - return tir::CallNode::make(dtype, - "type_annotation", {}, - tir::CallNode::PureIntrinsic); + return tir::CallNode::make(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic); } // overload printing of for type. diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 0f8038e..052ea92 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -26,14 +26,14 @@ #ifndef TVM_TIR_STMT_FUNCTOR_H_ #define TVM_TIR_STMT_FUNCTOR_H_ -#include #include +#include #include -#include #include +#include -#include #include +#include namespace tvm { namespace tir { @@ -42,22 +42,18 @@ namespace tir { * \tparam FType The function signature. * \sa ExprFunctor */ -template +template class StmtFunctor; -#define STMT_FUNCTOR_DEFAULT { \ - return VisitStmtDefault_(op, std::forward(args)...); \ - } - -#define IR_STMT_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitStmt_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ +#define STMT_FUNCTOR_DEFAULT \ + { return VisitStmtDefault_(op, std::forward(args)...); } +#define IR_STMT_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitStmt_(static_cast(n.get()), std::forward(args)...); \ + }); -template +template class StmtFunctor { private: using TSelf = StmtFunctor; @@ -74,9 +70,7 @@ class StmtFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Stmt& n, Args... args) { - return VisitStmt(n, std::forward(args)...); - } + R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward(args)...); } /*! * \brief The functor call. * \param n The stmt node. @@ -103,7 +97,7 @@ class StmtFunctor { virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmtDefault_(const Object* op, Args ...) { + virtual R VisitStmtDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -137,8 +131,7 @@ class StmtFunctor { /*! * \brief StmtVisitor. */ -class TVM_DLL StmtVisitor : - protected StmtFunctor { +class TVM_DLL StmtVisitor : protected StmtFunctor { public: using StmtFunctor::operator(); @@ -173,8 +166,7 @@ class TVM_DLL StmtVisitor : /*! * \brief StmtMutator that mutates the statements. */ -class TVM_DLL StmtMutator : - protected StmtFunctor { +class TVM_DLL StmtMutator : protected StmtFunctor { public: /*! * \brief Mutate stmt. @@ -210,7 +202,7 @@ class TVM_DLL StmtMutator : * * \return The result object pointer. */ - template + template ObjectPtr CopyOnWrite(const TNode* node) { if (allow_copy_on_write_) { // return the old node. @@ -244,9 +236,7 @@ class TVM_DLL StmtMutator : * or have a class sub-class both StmtMutator and ExprMutator * and redirect Mutate to ExprMutator::Mutate(Expr) */ - virtual PrimExpr VisitExpr(const PrimExpr& e) { - return e; - } + virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; } // statement visitor Stmt VisitStmt_(const AttrStmtNode* op) override; Stmt VisitStmt_(const IfThenElseNode* op) override; @@ -275,8 +265,7 @@ class TVM_DLL StmtMutator : * \param fmutate The mutate function, can be nullptr, which defaults to Visit. * \return The mutated result. */ - Stmt VisitSeqStmt_(const SeqStmtNode* op, - bool flatten_before_visit, + Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, std::function fmutate = nullptr); // internal helper. class Internal; @@ -285,39 +274,31 @@ class TVM_DLL StmtMutator : /*! * \brief Visitor that recursively visit stmts and exprs on them. */ -class StmtExprVisitor : - public StmtVisitor, - public ExprVisitor { +class StmtExprVisitor : public StmtVisitor, public ExprVisitor { public: using StmtVisitor::operator(); using ExprVisitor::operator(); protected: - using StmtVisitor::VisitStmt; using ExprVisitor::VisitExpr; + using StmtVisitor::VisitStmt; - void VisitExpr(const PrimExpr& e) override { - return ExprVisitor::VisitExpr(e); - } + void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); } }; /*! * \brief Mutator that recursively mutates stmts and exprs on them. */ -class StmtExprMutator : - public StmtMutator, - public ExprMutator { +class StmtExprMutator : public StmtMutator, public ExprMutator { public: using StmtMutator::operator(); using ExprMutator::operator(); protected: - using StmtMutator::VisitExpr; using ExprMutator::VisitExpr; + using StmtMutator::VisitExpr; - PrimExpr VisitExpr(const PrimExpr& e) override { - return ExprMutator::VisitExpr(e); - } + PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); } }; /*! @@ -335,8 +316,7 @@ class StmtExprMutator : * If it is not null, preorder/postorder will only be called * when the IRNode's type key is in the list. */ -TVM_DLL Stmt IRTransform(Stmt stmt, - const runtime::PackedFunc& preorder, +TVM_DLL Stmt IRTransform(Stmt stmt, const runtime::PackedFunc& preorder, const runtime::PackedFunc& postorder, Optional> only_enable = NullOpt); @@ -354,8 +334,7 @@ TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function(const Var& var)> vmap); +TVM_DLL Stmt Substitute(Stmt stmt, std::function(const Var& var)> vmap); /*! * \brief Substitute the var specified by vmap. @@ -363,8 +342,7 @@ TVM_DLL Stmt Substitute(Stmt stmt, * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. * \return The result. */ -TVM_DLL PrimExpr Substitute(PrimExpr expr, - std::function(const Var& var)> vmap); +TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function(const Var& var)> vmap); /*! * \brief Sugar for substitute via a given map. @@ -373,7 +351,7 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, * \return The result. * \tparam T the input type, can be PrimExpr or Stmt. */ -template +template inline T Substitute(T input, const Map& value_map) { auto vmap = [&](const Var& var) -> Optional { auto it = value_map.find(var); @@ -390,9 +368,8 @@ inline T Substitute(T input, const Map& value_map) { * \return The result. * \tparam T the input type, can be PrimExpr or Stmt. */ -template -inline T Substitute(T input, - const std::unordered_map& value_map) { +template +inline T Substitute(T input, const std::unordered_map& value_map) { auto vmap = [&](const Var& var) -> Optional { auto it = value_map.find(var.get()); if (it != value_map.end()) return (*it).second; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index abf8b1c..13e1e25 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -35,11 +35,11 @@ namespace tir { namespace transform { using tvm::transform::Pass; -using tvm::transform::PassNode; -using tvm::transform::PassInfo; -using tvm::transform::PassInfoNode; using tvm::transform::PassContext; using tvm::transform::PassContextNode; +using tvm::transform::PassInfo; +using tvm::transform::PassInfoNode; +using tvm::transform::PassNode; using tvm::transform::Sequential; /* @@ -52,12 +52,9 @@ using tvm::transform::Sequential; * * \return The created function pass. */ -TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< - PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); - +TVM_DLL Pass CreatePrimFuncPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, const std::string& name, const tvm::Array& required); /*! * \brief Inject prefetch instructions into stmt. @@ -76,8 +73,7 @@ TVM_DLL Pass InjectPrefetch(); * * \return The Pass */ -TVM_DLL Pass StorageFlatten(int cache_line_size, - bool create_bound_attribute = false); +TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false); /*! * \brief Inject copy intrinsics with optional pad. @@ -92,8 +88,7 @@ TVM_DLL Pass StorageFlatten(int cache_line_size, * Expr pad_value) * \return The pass. */ -TVM_DLL Pass InjectCopyIntrin(std::string pragma_key, - runtime::PackedFunc fintrin); +TVM_DLL Pass InjectCopyIntrin(std::string pragma_key, runtime::PackedFunc fintrin); /*! * \brief Detect and insert sync points to co-processor. @@ -164,9 +159,7 @@ TVM_DLL Pass StorageRewrite(); * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen. * \return The pass. */ -TVM_DLL Pass UnrollLoop(int auto_max_step, - int auto_max_depth, - int auto_max_extent, +TVM_DLL Pass UnrollLoop(int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll); /*! @@ -184,17 +177,17 @@ TVM_DLL Pass RemoveNoOp(); TVM_DLL Pass RewriteUnsafeSelect(); /*! -* \brief Run arithmetic simplifications on the statements and expressions. -* -* \return The pass. -*/ + * \brief Run arithmetic simplifications on the statements and expressions. + * + * \return The pass. + */ TVM_DLL Pass Simplify(); /*! -* \brief Instruments bound checkers. -* -* \return The pass. -*/ + * \brief Instruments bound checkers. + * + * \return The pass. + */ TVM_DLL Pass InstrumentBoundCheckers(); /*! @@ -278,7 +271,6 @@ TVM_DLL Pass SkipAssert(); */ TVM_DLL Pass ThreadSync(std::string storage_scope); - /*! * \brief Lower cross thread alleduce. * @@ -328,7 +320,6 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo(); */ TVM_DLL Pass CombineContextCall(); - /*! * \brief Narrow down PrimExpr datatype in stmt to target_bits. * diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index bb73bf0..a89c665 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -24,9 +24,10 @@ #ifndef TVM_TIR_VAR_H_ #define TVM_TIR_VAR_H_ +#include #include #include -#include + #include namespace tvm { @@ -91,8 +92,7 @@ class Var : public PrimExpr { * \param name_hint variable name * \param dtype data type */ - TVM_DLL explicit Var(std::string name_hint = "v", - DataType dtype = DataType::Int(32)); + TVM_DLL explicit Var(std::string name_hint = "v", DataType dtype = DataType::Int(32)); /*! * \brief Constructor which provides a more detailed type annotation. * \param name_hint variable name. @@ -109,16 +109,12 @@ class Var : public PrimExpr { * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const VarNode* operator->() const { - return get(); - } + const VarNode* operator->() const { return get(); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const VarNode* get() const { - return static_cast(data_.get()); - } + const VarNode* get() const { return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = VarNode; }; @@ -142,27 +138,21 @@ class SizeVar : public Var { * \param name_hint variable name * \param t data type */ - TVM_DLL explicit SizeVar(std::string name_hint = "s", - DataType t = DataType::Int(32)); + TVM_DLL explicit SizeVar(std::string name_hint = "s", DataType t = DataType::Int(32)); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const SizeVarNode* operator->() const { - return get(); - } + const SizeVarNode* operator->() const { return get(); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const SizeVarNode* get() const { - return static_cast(data_.get()); - } + const SizeVarNode* get() const { return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = SizeVarNode; }; - /*! \brief container class of iteration variable. */ class IterVarNode; @@ -292,11 +282,8 @@ class IterVarNode : public Object { } bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const { - return - equal(dom, other->dom) && - equal.DefEqual(var, other->var) && - equal(iter_type, other->iter_type) && - equal(thread_tag, other->thread_tag); + return equal(dom, other->dom) && equal.DefEqual(var, other->var) && + equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag); } void SHashReduce(SHashReducer hash_reduce) const { @@ -306,8 +293,7 @@ class IterVarNode : public Object { hash_reduce(thread_tag); } - TVM_DLL static IterVar make(Range dom, Var var, - IterVarType iter_type, + TVM_DLL static IterVar make(Range dom, Var var, IterVarType iter_type, std::string thread_tag = ""); static constexpr const char* _type_key = "IterVar"; @@ -321,21 +307,28 @@ inline const IterVarNode* IterVar::operator->() const { return static_cast(data_.get()); } -inline IterVar::operator PrimExpr() const { - return (*this)->var; -} +inline IterVar::operator PrimExpr() const { return (*this)->var; } inline const char* IterVarType2String(IterVarType t) { switch (t) { - case kDataPar: return "DataPar"; - case kThreadIndex: return "ThreadIndex"; - case kCommReduce: return "CommReduce"; - case kOrdered: return "Ordered"; - case kOpaque: return "Opaque"; - case kUnrolled: return "Unrolled"; - case kVectorized: return "Vectorized"; - case kParallelized: return "Parallelized"; - case kTensorized: return "Tensorized"; + case kDataPar: + return "DataPar"; + case kThreadIndex: + return "ThreadIndex"; + case kCommReduce: + return "CommReduce"; + case kOrdered: + return "Ordered"; + case kOpaque: + return "Opaque"; + case kUnrolled: + return "Unrolled"; + case kVectorized: + return "Vectorized"; + case kParallelized: + return "Parallelized"; + case kTensorized: + return "Tensorized"; } return "Unknown"; } diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index ce1979c..0f20200 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -26,7 +26,7 @@ #define TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_ // Helper functions for RefXXX getter & setter -jlong getLongField(JNIEnv *env, jobject obj) { +jlong getLongField(JNIEnv* env, jobject obj) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); jlong ret = env->GetLongField(obj, refFid); @@ -34,7 +34,7 @@ jlong getLongField(JNIEnv *env, jobject obj) { return ret; } -jint getIntField(JNIEnv *env, jobject obj) { +jint getIntField(JNIEnv* env, jobject obj) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt"); jfieldID refFid = env->GetFieldID(refClass, "value", "I"); jint ret = env->GetIntField(obj, refFid); @@ -42,21 +42,21 @@ jint getIntField(JNIEnv *env, jobject obj) { return ret; } -void setIntField(JNIEnv *env, jobject obj, jint value) { +void setIntField(JNIEnv* env, jobject obj, jint value) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt"); jfieldID refFid = env->GetFieldID(refClass, "value", "I"); env->SetIntField(obj, refFid, value); env->DeleteLocalRef(refClass); } -void setLongField(JNIEnv *env, jobject obj, jlong value) { +void setLongField(JNIEnv* env, jobject obj, jlong value) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); env->SetLongField(obj, refFid, value); env->DeleteLocalRef(refClass); } -void setStringField(JNIEnv *env, jobject obj, const char *value) { +void setStringField(JNIEnv* env, jobject obj, const char* value) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefString"); jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;"); env->SetObjectField(obj, refFid, env->NewStringUTF(value)); @@ -64,8 +64,8 @@ void setStringField(JNIEnv *env, jobject obj, const char *value) { } // Helper functions for TVMValue -jlong getTVMValueLongField(JNIEnv *env, jobject obj, - const char *clsname = "org/apache/tvm/TVMValueLong") { +jlong getTVMValueLongField(JNIEnv* env, jobject obj, + const char* clsname = "org/apache/tvm/TVMValueLong") { jclass cls = env->FindClass(clsname); jfieldID fid = env->GetFieldID(cls, "value", "J"); jlong ret = env->GetLongField(obj, fid); @@ -73,7 +73,7 @@ jlong getTVMValueLongField(JNIEnv *env, jobject obj, return ret; } -jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) { +jdouble getTVMValueDoubleField(JNIEnv* env, jobject obj) { jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble"); jfieldID fid = env->GetFieldID(cls, "value", "D"); jdouble ret = env->GetDoubleField(obj, fid); @@ -81,7 +81,7 @@ jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) { return ret; } -jstring getTVMValueStringField(JNIEnv *env, jobject obj) { +jstring getTVMValueStringField(JNIEnv* env, jobject obj) { jclass cls = env->FindClass("org/apache/tvm/TVMValueString"); jfieldID fid = env->GetFieldID(cls, "value", "Ljava/lang/String;"); jstring ret = static_cast(env->GetObjectField(obj, fid)); @@ -89,7 +89,7 @@ jstring getTVMValueStringField(JNIEnv *env, jobject obj) { return ret; } -jobject newTVMValueHandle(JNIEnv *env, jlong value) { +jobject newTVMValueHandle(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/TVMValueHandle"); jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -97,7 +97,7 @@ jobject newTVMValueHandle(JNIEnv *env, jlong value) { return object; } -jobject newTVMValueLong(JNIEnv *env, jlong value) { +jobject newTVMValueLong(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/TVMValueLong"); jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -105,7 +105,7 @@ jobject newTVMValueLong(JNIEnv *env, jlong value) { return object; } -jobject newTVMValueDouble(JNIEnv *env, jdouble value) { +jobject newTVMValueDouble(JNIEnv* env, jdouble value) { jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble"); jmethodID constructor = env->GetMethodID(cls, "", "(D)V"); jobject object = env->NewObject(cls, constructor, value); @@ -113,7 +113,7 @@ jobject newTVMValueDouble(JNIEnv *env, jdouble value) { return object; } -jobject newTVMValueString(JNIEnv *env, const char *value) { +jobject newTVMValueString(JNIEnv* env, const char* value) { jstring jvalue = env->NewStringUTF(value); jclass cls = env->FindClass("org/apache/tvm/TVMValueString"); jmethodID constructor = env->GetMethodID(cls, "", "(Ljava/lang/String;)V"); @@ -123,10 +123,10 @@ jobject newTVMValueString(JNIEnv *env, const char *value) { return object; } -jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) { +jobject newTVMValueBytes(JNIEnv* env, const TVMByteArray* arr) { jbyteArray jarr = env->NewByteArray(arr->size); env->SetByteArrayRegion(jarr, 0, arr->size, - reinterpret_cast(const_cast(arr->data))); + reinterpret_cast(const_cast(arr->data))); jclass cls = env->FindClass("org/apache/tvm/TVMValueBytes"); jmethodID constructor = env->GetMethodID(cls, "", "([B)V"); jobject object = env->NewObject(cls, constructor, jarr); @@ -135,7 +135,7 @@ jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) { return object; } -jobject newModule(JNIEnv *env, jlong value) { +jobject newModule(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/Module"); jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -143,7 +143,7 @@ jobject newModule(JNIEnv *env, jlong value) { return object; } -jobject newFunction(JNIEnv *env, jlong value) { +jobject newFunction(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/Function"); jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -151,7 +151,7 @@ jobject newFunction(JNIEnv *env, jlong value) { return object; } -jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) { +jobject newNDArray(JNIEnv* env, jlong handle, jboolean isview) { jclass cls = env->FindClass("org/apache/tvm/NDArrayBase"); jmethodID constructor = env->GetMethodID(cls, "", "(JZ)V"); jobject object = env->NewObject(cls, constructor, handle, isview); @@ -159,7 +159,7 @@ jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) { return object; } -jobject newObject(JNIEnv *env, const char *clsname) { +jobject newObject(JNIEnv* env, const char* clsname) { jclass cls = env->FindClass(clsname); jmethodID constructor = env->GetMethodID(cls, "", "()V"); jobject object = env->NewObject(cls, constructor); @@ -167,7 +167,7 @@ jobject newObject(JNIEnv *env, const char *clsname) { return object; } -void fromJavaDType(JNIEnv *env, jobject jdtype, DLDataType *dtype) { +void fromJavaDType(JNIEnv* env, jobject jdtype, DLDataType* dtype) { jclass tvmTypeClass = env->FindClass("org/apache/tvm/DLDataType"); dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I"))); dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I"))); @@ -175,16 +175,16 @@ void fromJavaDType(JNIEnv *env, jobject jdtype, DLDataType *dtype) { env->DeleteLocalRef(tvmTypeClass); } -void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) { +void fromJavaContext(JNIEnv* env, jobject jctx, TVMContext* ctx) { jclass tvmContextClass = env->FindClass("org/apache/tvm/TVMContext"); - ctx->device_type = static_cast(env->GetIntField(jctx, - env->GetFieldID(tvmContextClass, "deviceType", "I"))); - ctx->device_id = static_cast(env->GetIntField(jctx, - env->GetFieldID(tvmContextClass, "deviceId", "I"))); + ctx->device_type = static_cast( + env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceType", "I"))); + ctx->device_id = + static_cast(env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceId", "I"))); env->DeleteLocalRef(tvmContextClass); } -jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { +jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) { switch (tcode) { case kDLUInt: case kDLInt: @@ -204,7 +204,7 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { case kTVMStr: return newTVMValueString(env, value.v_str); case kTVMBytes: - return newTVMValueBytes(env, reinterpret_cast(value.v_handle)); + return newTVMValueBytes(env, reinterpret_cast(value.v_handle)); case kTVMNullptr: return newObject(env, "org/apache/tvm/TVMValueNull"); default: diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index b599568..6fc316c 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -29,28 +29,28 @@ #include #include #endif -#include #include -#include +#include #include +#include #include "jni_helper_func.h" -JavaVM *_jvm; -void *_tvmHandle = nullptr; +JavaVM* _jvm; +void* _tvmHandle = nullptr; struct TVMFuncArgsThreadLocalEntry { std::vector tvmFuncArgValues; std::vector tvmFuncArgTypes; // for later release - std::vector > tvmFuncArgPushedStrs; - std::vector > tvmFuncArgPushedBytes; + std::vector > tvmFuncArgPushedStrs; + std::vector > tvmFuncArgPushedBytes; }; typedef dmlc::ThreadLocalStore TVMFuncArgsThreadLocalStore; -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit - (JNIEnv *env, jobject obj, jstring jtvmLibFile) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit(JNIEnv* env, jobject obj, + jstring jtvmLibFile) { if (_tvmHandle == NULL && !env->IsSameObject(jtvmLibFile, NULL)) { - const char *tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0); + const char* tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0); _tvmHandle = dlopen(tvmLibFile, RTLD_LAZY | RTLD_GLOBAL); env->ReleaseStringUTFChars(jtvmLibFile, tvmLibFile); if (!_tvmHandle) { @@ -61,70 +61,70 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit return env->GetJavaVM(&_jvm); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv *env, jobject obj) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv* env, jobject obj) { if (_tvmHandle) { dlclose(_tvmHandle); } return 0; } -JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv * env, jobject obj) { +JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv* env, jobject obj) { return env->NewStringUTF(TVMGetLastError()); } // Function -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong( - JNIEnv *env, jobject obj, jlong arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong(JNIEnv* env, jobject obj, + jlong arg) { TVMValue value; value.v_int64 = static_cast(arg); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kDLInt); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble( - JNIEnv *env, jobject obj, jdouble arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble(JNIEnv* env, jobject obj, + jdouble arg) { TVMValue value; value.v_float64 = static_cast(arg); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kDLFloat); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString( - JNIEnv *env, jobject obj, jstring arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString(JNIEnv* env, jobject obj, + jstring arg) { TVMValue value; jstring garg = reinterpret_cast(env->NewGlobalRef(arg)); value.v_str = env->GetStringUTFChars(garg, 0); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kTVMStr); // release string args later e->tvmFuncArgPushedStrs.push_back(std::make_pair(garg, value.v_str)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle( - JNIEnv *env, jobject obj, jlong arg, jint argType) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv* env, jobject obj, + jlong arg, jint argType) { TVMValue value; - value.v_handle = reinterpret_cast(arg); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + value.v_handle = reinterpret_cast(arg); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(static_cast(argType)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes( - JNIEnv *env, jobject obj, jbyteArray arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj, + jbyteArray arg) { jbyteArray garg = reinterpret_cast(env->NewGlobalRef(arg)); - jbyte *data = env->GetByteArrayElements(garg, 0); + jbyte* data = env->GetByteArrayElements(garg, 0); - TVMByteArray *byteArray = new TVMByteArray(); + TVMByteArray* byteArray = new TVMByteArray(); byteArray->size = static_cast(env->GetArrayLength(garg)); - byteArray->data = reinterpret_cast(data); + byteArray->data = reinterpret_cast(data); TVMValue value; - value.v_handle = reinterpret_cast(byteArray); + value.v_handle = reinterpret_cast(byteArray); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kTVMBytes); @@ -132,10 +132,10 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes( // release (garg, data), byteArray later } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames( - JNIEnv *env, jobject obj, jobject jfuncNames) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames(JNIEnv* env, jobject obj, + jobject jfuncNames) { int outSize; - const char **outArray; + const char** outArray; int ret = TVMFuncListGlobalNames(&outSize, &outArray); if (ret) { @@ -157,24 +157,25 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames( return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree( - JNIEnv *env, jobject obj, jlong jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree(JNIEnv* env, jobject obj, + jlong jhandle) { return TVMFuncFree(reinterpret_cast(jhandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal( - JNIEnv *env, jobject obj, jstring jname, jobject jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal(JNIEnv* env, jobject obj, + jstring jname, + jobject jhandle) { TVMFunctionHandle handle; - const char *name = env->GetStringUTFChars(jname, 0); + const char* name = env->GetStringUTFChars(jname, 0); int ret = TVMFuncGetGlobal(name, &handle); env->ReleaseStringUTFChars(jname, name); setLongField(env, jhandle, reinterpret_cast(handle)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( - JNIEnv *env, jobject obj, jlong jhandle, jobject jretVal) { - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall(JNIEnv* env, jobject obj, + jlong jhandle, jobject jretVal) { + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); int numArgs = e->tvmFuncArgValues.size(); TVMValue retVal; @@ -192,8 +193,8 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( e->tvmFuncArgTypes.clear(); e->tvmFuncArgValues.clear(); - int ret = TVMFuncCall(reinterpret_cast(jhandle), - &argValues[0], &argTypes[0], numArgs, &retVal, &retTypeCode); + int ret = TVMFuncCall(reinterpret_cast(jhandle), &argValues[0], &argTypes[0], + numArgs, &retVal, &retTypeCode); if (ret != 0) { return ret; @@ -204,16 +205,15 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( env->DeleteGlobalRef(iter->first); } for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) { - env->ReleaseByteArrayElements(iter->first, - reinterpret_cast(const_cast(iter->second->data)), 0); + env->ReleaseByteArrayElements( + iter->first, reinterpret_cast(const_cast(iter->second->data)), 0); env->DeleteGlobalRef(iter->first); delete iter->second; } // return TVMValue object to Java jclass refTVMValueCls = env->FindClass("org/apache/tvm/Base$RefTVMValue"); - jfieldID refTVMValueFid - = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;"); + jfieldID refTVMValueFid = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;"); env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, retVal, retTypeCode)); @@ -223,16 +223,16 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( } // Callback function -extern "C" int funcInvokeCallback(TVMValue *args, - int *typeCodes, int numArgs, TVMRetValueHandle ret, void *resourceHandle) { - JNIEnv *env; - int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); +extern "C" int funcInvokeCallback(TVMValue* args, int* typeCodes, int numArgs, + TVMRetValueHandle ret, void* resourceHandle) { + JNIEnv* env; + int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (jniStatus == JNI_EDETACHED) { - #ifdef TVM4J_ANDROID +#ifdef TVM4J_ANDROID _jvm->AttachCurrentThread(&env, nullptr); - #else - _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - #endif +#else + _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); +#endif } else { CHECK(jniStatus == JNI_OK); } @@ -242,10 +242,8 @@ extern "C" int funcInvokeCallback(TVMValue *args, for (int i = 0; i < numArgs; ++i) { TVMValue arg = args[i]; int tcode = typeCodes[i]; - if (tcode == kTVMObjectHandle || - tcode == kTVMPackedFuncHandle || - tcode == kTVMObjectRValueRefArg || - tcode == kTVMModuleHandle) { + if (tcode == kTVMObjectHandle || tcode == kTVMPackedFuncHandle || + tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle) { TVMCbArgToReturn(&arg, &tcode); } jobject jarg = tvmRetValueToJava(env, arg, tcode); @@ -253,15 +251,16 @@ extern "C" int funcInvokeCallback(TVMValue *args, } jclass clsFunc = env->FindClass("org/apache/tvm/Function"); - jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID(clsFunc, "invokeRegisteredCbFunc", + jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID( + clsFunc, "invokeRegisteredCbFunc", "(Lorg/apache/tvm/Function$Callback;[Lorg/apache/tvm/TVMValue;)Ljava/lang/Object;"); - jmethodID pushArgToStack = env->GetStaticMethodID(clsFunc, "pushArgToStack", - "(Ljava/lang/Object;)V"); + jmethodID pushArgToStack = + env->GetStaticMethodID(clsFunc, "pushArgToStack", "(Ljava/lang/Object;)V"); jobject jretValue = env->CallStaticObjectMethod(clsFunc, invokeRegisteredCbFunc, - reinterpret_cast(resourceHandle), jargs); + reinterpret_cast(resourceHandle), jargs); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); const size_t prevNumStrArg = e->tvmFuncArgPushedStrs.size(); const size_t prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); @@ -279,16 +278,16 @@ extern "C" int funcInvokeCallback(TVMValue *args, // release allocated strings. if (e->tvmFuncArgPushedStrs.size() > prevNumStrArg) { - const auto &pairArg = e->tvmFuncArgPushedStrs.back(); + const auto& pairArg = e->tvmFuncArgPushedStrs.back(); env->ReleaseStringUTFChars(pairArg.first, pairArg.second); env->DeleteGlobalRef(pairArg.first); e->tvmFuncArgPushedStrs.pop_back(); } // release allocated bytes. if (e->tvmFuncArgPushedBytes.size() > prevNumBytesArg) { - const auto &pairArg = e->tvmFuncArgPushedBytes.back(); - env->ReleaseByteArrayElements(pairArg.first, - reinterpret_cast(const_cast(pairArg.second->data)), 0); + const auto& pairArg = e->tvmFuncArgPushedBytes.back(); + env->ReleaseByteArrayElements( + pairArg.first, reinterpret_cast(const_cast(pairArg.second->data)), 0); env->DeleteGlobalRef(pairArg.first); delete pairArg.second; e->tvmFuncArgPushedBytes.pop_back(); @@ -301,62 +300,64 @@ extern "C" int funcInvokeCallback(TVMValue *args, } // Free callback function -extern "C" void funcFreeCallback(void *resourceHandle) { - JNIEnv *env; - int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); +extern "C" void funcFreeCallback(void* resourceHandle) { + JNIEnv* env; + int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (jniStatus == JNI_EDETACHED) { - #ifdef TVM4J_ANDROID +#ifdef TVM4J_ANDROID _jvm->AttachCurrentThread(&env, nullptr); - #else - _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - #endif +#else + _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); +#endif } else { CHECK(jniStatus == JNI_OK); } env->DeleteGlobalRef(reinterpret_cast(resourceHandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc( - JNIEnv *env, jobject obj, jobject jfunction, jobject jretHandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc(JNIEnv* env, jobject obj, + jobject jfunction, + jobject jretHandle) { TVMFunctionHandle out; - int ret = TVMFuncCreateFromCFunc(reinterpret_cast(&funcInvokeCallback), - reinterpret_cast(env->NewGlobalRef(jfunction)), - reinterpret_cast(&funcFreeCallback), - &out); + int ret = + TVMFuncCreateFromCFunc(reinterpret_cast(&funcInvokeCallback), + reinterpret_cast(env->NewGlobalRef(jfunction)), + reinterpret_cast(&funcFreeCallback), &out); setLongField(env, jretHandle, reinterpret_cast(out)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal( - JNIEnv *env, jobject obj, jstring jname, jlong jhandle, jint joverride) { - const char *name = env->GetStringUTFChars(jname, 0); - int ret = TVMFuncRegisterGlobal( - name, reinterpret_cast(jhandle), reinterpret_cast(joverride)); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal(JNIEnv* env, jobject obj, + jstring jname, + jlong jhandle, + jint joverride) { + const char* name = env->GetStringUTFChars(jname, 0); + int ret = TVMFuncRegisterGlobal(name, reinterpret_cast(jhandle), + reinterpret_cast(joverride)); env->ReleaseStringUTFChars(jname, name); return ret; } // Module -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree( - JNIEnv *env, jobject obj, jlong jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree(JNIEnv* env, jobject obj, + jlong jhandle) { return TVMModFree(reinterpret_cast(jhandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport( - JNIEnv *env, jobject obj, jlong jmod, jlong jdep) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport(JNIEnv* env, jobject obj, + jlong jmod, jlong jdep) { return TVMModImport(reinterpret_cast(jmod), reinterpret_cast(jdep)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction( - JNIEnv *env, jobject obj, jlong jhandle, jstring jname, jint jimport, jobject jret) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction(JNIEnv* env, jobject obj, + jlong jhandle, jstring jname, + jint jimport, jobject jret) { TVMFunctionHandle retFunc; - const char *name = env->GetStringUTFChars(jname, 0); - int ret = TVMModGetFunction(reinterpret_cast(jhandle), - name, - reinterpret_cast(jimport), - &retFunc); + const char* name = env->GetStringUTFChars(jname, 0); + int ret = TVMModGetFunction(reinterpret_cast(jhandle), name, + reinterpret_cast(jimport), &retFunc); env->ReleaseStringUTFChars(jname, name); setLongField(env, jret, reinterpret_cast(retFunc)); @@ -365,28 +366,25 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction( } // NDArray -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree( - JNIEnv *env, jobject obj, jlong jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree(JNIEnv* env, jobject obj, + jlong jhandle) { return TVMArrayFree(reinterpret_cast(jhandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc( - JNIEnv *env, jobject obj, jlongArray jshape, jint jdtypeCode, - jint jdtypeBits, jint jdtypeLanes, jint jdeviceType, jint jdeviceId, jobject jret) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc(JNIEnv* env, jobject obj, + jlongArray jshape, jint jdtypeCode, + jint jdtypeBits, jint jdtypeLanes, + jint jdeviceType, jint jdeviceId, + jobject jret) { int ndim = static_cast(env->GetArrayLength(jshape)); TVMArrayHandle out; - jlong *shapeArray = env->GetLongArrayElements(jshape, NULL); - int ret = TVMArrayAlloc( - reinterpret_cast(shapeArray), - ndim, - static_cast(jdtypeCode), - static_cast(jdtypeBits), - static_cast(jdtypeLanes), - static_cast(jdeviceType), - static_cast(jdeviceId), - &out); + jlong* shapeArray = env->GetLongArrayElements(jshape, NULL); + int ret = TVMArrayAlloc(reinterpret_cast(shapeArray), ndim, + static_cast(jdtypeCode), static_cast(jdtypeBits), + static_cast(jdtypeLanes), static_cast(jdeviceType), + static_cast(jdeviceId), &out); env->ReleaseLongArrayElements(jshape, shapeArray, 0); setLongField(env, jret, reinterpret_cast(out)); @@ -394,10 +392,10 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc( return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape( - JNIEnv *env, jobject obj, jlong jhandle, jobject jshape) { - DLTensor *array = reinterpret_cast(jhandle); - int64_t *shape = array->shape; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape(JNIEnv* env, jobject obj, + jlong jhandle, jobject jshape) { + DLTensor* array = reinterpret_cast(jhandle); + int64_t* shape = array->shape; int ndim = array->ndim; // fill shape buffer @@ -417,18 +415,19 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape( return 0; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo( - JNIEnv *env, jobject obj, jlong jfrom, jlong jto) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo(JNIEnv* env, jobject obj, + jlong jfrom, jlong jto) { return TVMArrayCopyFromTo(reinterpret_cast(jfrom), reinterpret_cast(jto), NULL); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray( - JNIEnv *env, jobject obj, jbyteArray jarr, jlong jfrom, jlong jto) { - jbyte *data = env->GetByteArrayElements(jarr, NULL); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray(JNIEnv* env, jobject obj, + jbyteArray jarr, + jlong jfrom, jlong jto) { + jbyte* data = env->GetByteArrayElements(jarr, NULL); - DLTensor *from = reinterpret_cast(jfrom); - from->data = static_cast(data); + DLTensor* from = reinterpret_cast(jfrom); + from->data = static_cast(data); int ret = TVMArrayCopyFromTo(static_cast(from), reinterpret_cast(jto), NULL); @@ -439,13 +438,14 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray( return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray( - JNIEnv *env, jobject obj, jlong jfrom, jbyteArray jarr) { - DLTensor *from = reinterpret_cast(jfrom); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray(JNIEnv* env, jobject obj, + jlong jfrom, + jbyteArray jarr) { + DLTensor* from = reinterpret_cast(jfrom); int size = static_cast(env->GetArrayLength(jarr)); - jbyte *pdata = env->GetByteArrayElements(jarr, NULL); + jbyte* pdata = env->GetByteArrayElements(jarr, NULL); int ret = 0; - if (memcpy(static_cast(pdata), from->data, size) == NULL) { + if (memcpy(static_cast(pdata), from->data, size) == NULL) { ret = 1; } env->ReleaseByteArrayElements(jarr, pdata, 0); // copy back to java array automatically @@ -453,7 +453,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray( } // Context -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize( - JNIEnv *env, jint deviceType, jint deviceId) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, jint deviceType, + jint deviceId) { return TVMSynchronize(static_cast(deviceType), static_cast(deviceId), NULL); } diff --git a/nnvm/include/nnvm/base.h b/nnvm/include/nnvm/base.h index 678ed4d..b8c5c6c 100644 --- a/nnvm/include/nnvm/base.h +++ b/nnvm/include/nnvm/base.h @@ -24,13 +24,13 @@ #ifndef NNVM_BASE_H_ #define NNVM_BASE_H_ +#include +#include #include #include -#include -#include #include +#include #include -#include namespace nnvm { @@ -52,7 +52,7 @@ enum TypeFlag { kFloat16 = 2, kUint8 = 3, kInt32 = 4, - kInt8 = 5, + kInt8 = 5, kInt64 = 6, // kBool = 7, // 7 is reserved for kBool, in order to keep consistency with MXNet TypeFlag defined in diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index b35e4da..e6efb79 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -41,11 +41,11 @@ typedef unsigned int nn_uint; /*! \brief handle to a function that takes param and creates symbol */ -typedef void *OpHandle; +typedef void* OpHandle; /*! \brief handle to a symbol that can be bind as operator */ -typedef void *SymbolHandle; +typedef void* SymbolHandle; /*! \brief handle to Graph */ -typedef void *GraphHandle; +typedef void* GraphHandle; #ifdef __cplusplus extern "C" { @@ -65,7 +65,7 @@ NNVM_DLL void NNAPISetLastError(const char* msg); * this function is threadsafe and can be called by different thread * \return error info */ -NNVM_DLL const char *NNGetLastError(void); +NNVM_DLL const char* NNGetLastError(void); /*! * \brief list all the available operator names, include entries. @@ -73,16 +73,14 @@ NNVM_DLL const char *NNGetLastError(void); * \param out_array the output operator name array. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNListAllOpNames(nn_uint *out_size, - const char*** out_array); +NNVM_DLL int NNListAllOpNames(nn_uint* out_size, const char*** out_array); /*! * \brief Get operator handle given name. * \param op_name The name of the operator. * \param op_out The returnning op handle. */ -NNVM_DLL int NNGetOpHandle(const char* op_name, - OpHandle* op_out); +NNVM_DLL int NNGetOpHandle(const char* op_name, OpHandle* op_out); /*! * \brief list all the available operators. @@ -93,8 +91,7 @@ NNVM_DLL int NNGetOpHandle(const char* op_name, * \param out_array the output AtomicSymbolCreator array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNListUniqueOps(nn_uint *out_size, - OpHandle **out_array); +NNVM_DLL int NNListUniqueOps(nn_uint* out_size, OpHandle** out_array); /*! * \brief Get the detailed information about atomic symbol. @@ -109,14 +106,10 @@ NNVM_DLL int NNListUniqueOps(nn_uint *out_size, * \param return_type Return type of the function, if any. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGetOpInfo(OpHandle op, - const char **real_name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); +NNVM_DLL int NNGetOpInfo(OpHandle op, const char** real_name, const char** description, + nn_uint* num_doc_args, const char*** arg_names, + const char*** arg_type_infos, const char*** arg_descriptions, + const char** return_type); /*! * \brief Create an AtomicSymbol functor. * \param op The operator handle @@ -126,18 +119,15 @@ NNVM_DLL int NNGetOpInfo(OpHandle op, * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, - nn_uint num_param, - const char **keys, - const char **vals, - SymbolHandle *out); +NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, nn_uint num_param, const char** keys, + const char** vals, SymbolHandle* out); /*! * \brief Create a Variable Symbol. * \param name name of the variable * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); +NNVM_DLL int NNSymbolCreateVariable(const char* name, SymbolHandle* out); /*! * \brief Create a Symbol by grouping list of symbols together * \param num_symbols number of symbols to be grouped @@ -145,16 +135,13 @@ NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, - SymbolHandle *symbols, - SymbolHandle *out); +NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, SymbolHandle* symbols, SymbolHandle* out); /*! * \brief Add src_dep to the handle as control dep. * \param handle The symbol to add dependency edges on. * \param src_dep the source handles. */ -NNVM_DLL int NNAddControlDeps(SymbolHandle handle, - SymbolHandle src_dep); +NNVM_DLL int NNAddControlDeps(SymbolHandle handle, SymbolHandle src_dep); /*! * \brief Free the symbol handle. * \param symbol the symbol @@ -167,14 +154,14 @@ NNVM_DLL int NNSymbolFree(SymbolHandle symbol); * \param out used to hold the result of copy * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out); +NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle* out); /*! * \brief Print the content of symbol, used for debug. * \param symbol the symbol * \param out_str pointer to hold the output string of the printing. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); +NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char** out_str); /*! * \brief Get string attribute from symbol * \param symbol the source symbol @@ -183,13 +170,11 @@ NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); * \param success Whether the result is contained in out. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, - const char* key, - const char** out, - int *success); +NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success); /*! * \brief Set string attribute from symbol. - * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph. + * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic + * graph. * * Safe recommendaton: use immutable graph * - Only allow set attributes during creation of new symbol as optional parameter @@ -204,9 +189,7 @@ NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, * \param values The value to be set * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, - nn_uint num_param, - const char** keys, +NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, const char** keys, const char** values); /*! * \brief Get all attributes from symbol, including all descendents. @@ -216,9 +199,7 @@ NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, * \param out 2*out_size strings representing key value pairs. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, - int recursive_option, - nn_uint *out_size, +NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, int recursive_option, nn_uint* out_size, const char*** out); /*! @@ -232,9 +213,7 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, * \param out_sym_array the output array. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, - int option, - nn_uint *out_size, +NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, int option, nn_uint* out_size, SymbolHandle** out_sym_array); /*! @@ -248,10 +227,8 @@ NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char ***out_str_array); +NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, int option, nn_uint* out_size, + const char*** out_str_array); /*! * \brief List returns names in the symbol. * \param symbol the symbol @@ -259,10 +236,8 @@ NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, - nn_uint *out_size, - const char ***out_str_array); - +NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, nn_uint* out_size, + const char*** out_str_array); /*! * \brief Supply number of outputs of the symbol. @@ -270,8 +245,7 @@ NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, * \param output_count number of outputs * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, - nn_uint *output_count); +NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, nn_uint* output_count); /*! * \brief Get a symbol that contains all the internals. @@ -279,16 +253,14 @@ NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, * \param out The output symbol whose outputs are all the internals. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, - SymbolHandle *out); +NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, SymbolHandle* out); /*! * \brief Get a symbol that contains only direct children. * \param symbol The symbol * \param out The output symbol whose outputs are the direct children. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, - SymbolHandle *out); +NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, SymbolHandle* out); /*! * \brief Get index-th outputs of the symbol. * \param symbol The symbol @@ -296,9 +268,7 @@ NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, * \param out The output symbol whose outputs are the index-th symbol. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, - nn_uint index, - SymbolHandle *out); +NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, nn_uint index, SymbolHandle* out); /*! * \brief Compose the symbol on other symbols. @@ -314,11 +284,8 @@ NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, * \param args arguments to sym * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCompose(SymbolHandle sym, - const char* name, - nn_uint num_args, - const char** keys, - SymbolHandle* args); +NNVM_DLL int NNSymbolCompose(SymbolHandle sym, const char* name, nn_uint num_args, + const char** keys, SymbolHandle* args); // Graph IR API /*! @@ -327,7 +294,7 @@ NNVM_DLL int NNSymbolCompose(SymbolHandle sym, * \param graph The graph handle created. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph); +NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle* graph); /*! * \brief free the graph handle * \param handle The handle to be freed. @@ -339,7 +306,7 @@ NNVM_DLL int NNGraphFree(GraphHandle handle); * \param symbol The corresponding symbol * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); +NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle* symbol); /*! * \brief Get Set a attribute in json format. @@ -351,9 +318,7 @@ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); * Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, - const char* key, - const char* json_value); +NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value); /*! * \brief Get a serialized attrirbute from graph. @@ -367,10 +332,8 @@ NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, * \param success Whether the result is contained in out. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, - const char* key, - const char** json_out, - int *success); +NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, const char* key, const char** json_out, + int* success); /*! * \brief Set a attribute whose type is std::vector in c++ @@ -383,9 +346,7 @@ NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, * \param list The symbol whose outputs represents the list of NodeEntry to be passed. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, - const char* key, - SymbolHandle list); +NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, const char* key, SymbolHandle list); /*! * \brief Apply passes on the src graph. * \param src The source graph handle. @@ -394,10 +355,8 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, * \param dst The result graph. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphApplyPasses(GraphHandle src, - nn_uint num_pass, - const char** pass_names, - GraphHandle *dst); +NNVM_DLL int NNGraphApplyPasses(GraphHandle src, nn_uint num_pass, const char** pass_names, + GraphHandle* dst); #ifdef __cplusplus } /* end extern "C" */ diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index 1911a03..475494e 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -24,13 +24,14 @@ #ifndef NNVM_GRAPH_H_ #define NNVM_GRAPH_H_ -#include -#include -#include #include #include +#include #include #include +#include +#include + #include "base.h" #include "node.h" #include "symbolic.h" @@ -64,7 +65,7 @@ class Graph { * \return the reference to corresponding attribute * \tparam T the type of the attribute. */ - template + template inline const T& GetAttr(const std::string& attr_name) const; /*! * \brief Check whether has a specific attribute. @@ -81,7 +82,7 @@ class Graph { * \return a new copy of the corresponding attribute. * \tparam T the type of the attribute. */ - template + template inline T MoveCopyAttr(const std::string& attr_name); /*! * \brief get a indexed graph of current graph, if not exist, create it on demand @@ -127,13 +128,9 @@ class IndexedGraph { std::weak_ptr weak_ref; }; /*! \return number of nodes in the graph */ - inline size_t num_nodes() const { - return nodes_.size(); - } + inline size_t num_nodes() const { return nodes_.size(); } /*! \return total number of NodeEntry in the graph */ - inline size_t num_node_entries() const { - return entry_rptr_.back(); - } + inline size_t num_node_entries() const { return entry_rptr_.back(); } /*! * \brief Get a unique entry id between 0 to num_node_entries() * for a given IndexedGraph::NodeEntry @@ -150,9 +147,7 @@ class IndexedGraph { * \param e The entry to query for index. * \return the unique index. */ - inline uint32_t entry_id(const NodeEntry& e) const { - return entry_rptr_[e.node_id] + e.index; - } + inline uint32_t entry_id(const NodeEntry& e) const { return entry_rptr_[e.node_id] + e.index; } /*! * \brief Get a unique entry id between 0 to num_node_entries() * for a given NodeEntry. @@ -167,42 +162,30 @@ class IndexedGraph { * \param node The Node to query for index. * \return the node index. */ - inline uint32_t node_id(const nnvm::Node* node) const { - return node2index_.at(node); - } + inline uint32_t node_id(const nnvm::Node* node) const { return node2index_.at(node); } /*! * \brief Get the corresponding Node structure for a given node_id. * \param node_id The node id * \return const reference to the corresponding IndexedGraph::Node */ - inline const Node& operator[](uint32_t node_id) const { - return nodes_[node_id]; - } + inline const Node& operator[](uint32_t node_id) const { return nodes_[node_id]; } /*! * \brief Get the corresponding Node structure * \param node The pointer to the Node structure * \return const reference to the corresponding IndexedGraph::Node */ - inline const Node& operator[](const nnvm::Node* node) const { - return nodes_[node_id(node)]; - } + inline const Node& operator[](const nnvm::Node* node) const { return nodes_[node_id(node)]; } /*! \return list of argument nodes */ - inline const std::vector& input_nodes() const { - return input_nodes_; - } + inline const std::vector& input_nodes() const { return input_nodes_; } /*! \return list of mutable nodes */ inline const std::unordered_set& mutable_input_nodes() const { return mutable_input_nodes_; } /*! \return list of output entries */ - inline const std::vector& outputs() const { - return outputs_; - } + inline const std::vector& outputs() const { return outputs_; } /*! \return whether a node is existed in the indexed graph */ - inline bool exist(const nnvm::Node* node) const { - return node2index_.count(node); - } + inline bool exist(const nnvm::Node* node) const { return node2index_.count(node); } // disalllow copy assign IndexedGraph(const IndexedGraph&) = delete; @@ -239,15 +222,14 @@ class IndexedGraph { * \param fvisit a function of type std::function&)> * \tparam FVisit The function type to perform the visit. */ -template +template inline void DFSVisit(const std::vector& heads, FVisit fvisit); // inline function implementations -template +template inline const T& Graph::GetAttr(const std::string& attr_name) const { auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) - << "Cannot find attribute " << attr_name << " in the graph"; + CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph"; return nnvm::unsafe_get(*it->second); } @@ -256,11 +238,10 @@ inline bool Graph::HasAttr(const std::string& attr_name) const { return it != attrs.end(); } -template +template inline T Graph::MoveCopyAttr(const std::string& attr_name) { auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) - << "Cannot find attribute " << attr_name << " in the graph"; + CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph"; std::shared_ptr sptr = it->second; attrs.erase(it); if (sptr.unique()) { @@ -270,14 +251,10 @@ inline T Graph::MoveCopyAttr(const std::string& attr_name) { } } -template -void PostOrderDFSVisit(const std::vector& heads, - FVisit fvisit, - HashFunc hash, - InDegree indegree, - GetInput getinput) { +template +void PostOrderDFSVisit(const std::vector& heads, FVisit fvisit, HashFunc hash, + InDegree indegree, GetInput getinput) { std::vector > stack; std::unordered_set visited; for (auto& head : heads) { @@ -303,28 +280,20 @@ void PostOrderDFSVisit(const std::vector& heads, } } -template -inline void DFSVisit(const std::vector& heads, - FVisit fvisit) { +template +inline void DFSVisit(const std::vector& heads, FVisit fvisit) { typedef const ObjectPtr* GNode; std::vector head_nodes(heads.size()); std::transform(heads.begin(), heads.end(), head_nodes.begin(), - [](const NodeEntry& e)->GNode { - return &e.node; - }); + [](const NodeEntry& e) -> GNode { return &e.node; }); PostOrderDFSVisit( - head_nodes, - [fvisit](GNode n) { - fvisit(*n); - }, // FVisit - [](GNode n)->Node* { - return n->get(); - }, // HashFunc - [](GNode n)->uint32_t { // InDegree + head_nodes, [fvisit](GNode n) { fvisit(*n); }, // FVisit + [](GNode n) -> Node* { return n->get(); }, // HashFunc + [](GNode n) -> uint32_t { // InDegree if (!(*n)) return 0; return (*n)->inputs.size() + (*n)->control_deps.size(); - }, - [](GNode n, uint32_t index)->GNode { // GetInput + }, + [](GNode n, uint32_t index) -> GNode { // GetInput if (index < (*n)->inputs.size()) { return &(*n)->inputs.at(index).node; } else { diff --git a/nnvm/include/nnvm/graph_attr_types.h b/nnvm/include/nnvm/graph_attr_types.h index acc52a2..9e01855 100644 --- a/nnvm/include/nnvm/graph_attr_types.h +++ b/nnvm/include/nnvm/graph_attr_types.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,11 +24,12 @@ #ifndef NNVM_GRAPH_ATTR_TYPES_H_ #define NNVM_GRAPH_ATTR_TYPES_H_ -#include #include #include -#include "tuple.h" +#include + #include "layout.h" +#include "tuple.h" namespace nnvm { diff --git a/nnvm/include/nnvm/layout.h b/nnvm/include/nnvm/layout.h index 3a81b84..e2e9978 100644 --- a/nnvm/include/nnvm/layout.h +++ b/nnvm/include/nnvm/layout.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,11 +31,12 @@ #define NNVM_LAYOUT_H_ #include -#include + +#include #include -#include +#include #include -#include +#include namespace nnvm { @@ -44,7 +45,7 @@ class Layout { using LayoutDim = char; /*! \brief default constructor */ - Layout() : name_("__undef__") {} // NOLINT(*) + Layout() : name_("__undef__") {} // NOLINT(*) /*! * \brief construct from a string. @@ -54,21 +55,21 @@ class Layout { * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ - inline Layout(const std::string& layout) { // NOLINT(*) + inline Layout(const std::string& layout) { // NOLINT(*) parse(layout); } /*! * \brief copy constructor from another layout * \param s the source layout */ - inline Layout(const Layout& s) { // NOLINT(*) + inline Layout(const Layout& s) { // NOLINT(*) this->parse(s.name_); } /*! * \brief move constructor from Layout * \param src the source layout */ - inline Layout(Layout&& src) { // NOLINT(*) + inline Layout(Layout&& src) { // NOLINT(*) this->swap(src); } /*! @@ -86,7 +87,7 @@ class Layout { * \return reference of self */ inline Layout& operator=(Layout&& src) { - Layout(std::move(src)).swap(*this); // NOLINT(*) + Layout(std::move(src)).swap(*this); // NOLINT(*) return *this; } /*! @@ -102,16 +103,12 @@ class Layout { * \return whether two layout equals * \param s the layout to compare against */ - inline bool operator==(const Layout& s) const { - return name_ == s.name_; - } + inline bool operator==(const Layout& s) const { return name_ == s.name_; } /*! * \return whether two layout not equal * \param s the layout to compare against */ - inline bool operator!=(const Layout& s) const { - return !(*this == s); - } + inline bool operator!=(const Layout& s) const { return !(*this == s); } /*! * \brief Append the current layout by another. @@ -134,18 +131,14 @@ class Layout { * \param dim input dimension * \return Whether a given dimension is a super-dimension. */ - static inline bool is_superdim(LayoutDim dim) { - return dim >= 'A' && dim <= 'Z'; - } + static inline bool is_superdim(LayoutDim dim) { return dim >= 'A' && dim <= 'Z'; } /*! * \brief Check whether a given dimension is a sub-dimension. * \param dim input dimension * \return Whether a given dimension is a sub-dimension. */ - static inline bool is_subdim(LayoutDim dim) { - return dim >= 'a' && dim <= 'z'; - } + static inline bool is_subdim(LayoutDim dim) { return dim >= 'a' && dim <= 'z'; } /*! * \brief Convert a given dimension to super-dimension. @@ -200,7 +193,7 @@ class Layout { * \param dst the target layout * \return Whether can be converted to dst layout. */ - inline bool convertible(const Layout &dst) const { + inline bool convertible(const Layout& dst) const { if (!this->defined() || !dst.defined()) return false; for (size_t i = 0; i < kUniqueDim; ++i) { if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) || @@ -258,13 +251,12 @@ class Layout { * \return A newly constructed Layout object. */ inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const { - CHECK(target_pos <= this->ndim()) << "Invalid split position " - << target_pos << " for layout " << name_; + CHECK(target_pos <= this->ndim()) + << "Invalid split position " << target_pos << " for layout " << name_; CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim; CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_; - CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim - << " has already been split in " - << name_; + CHECK(!this->contains(to_subdim(dim))) + << "Dimension " << dim << " has already been split in " << name_; CHECK(size > 0) << "Invalid split size " << size; std::ostringstream new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { @@ -282,26 +274,16 @@ class Layout { using reverse_iterator = std::vector::const_reverse_iterator; /*! \return begin iterator */ - inline iterator begin() const { - return layout_simplified_.begin(); - } + inline iterator begin() const { return layout_simplified_.begin(); } /*! \return end iterator */ - inline iterator end() const { - return layout_simplified_.end(); - } + inline iterator end() const { return layout_simplified_.end(); } /*! \return rbegin iterator */ - inline reverse_iterator rbegin() const { - return layout_simplified_.rbegin(); - } + inline reverse_iterator rbegin() const { return layout_simplified_.rbegin(); } /*! \return rend iterator */ - inline reverse_iterator rend() const { - return layout_simplified_.rend(); - } + inline reverse_iterator rend() const { return layout_simplified_.rend(); } /*! \return number of dimensions */ - inline size_t ndim() const { - return layout_simplified_.size(); - } + inline size_t ndim() const { return layout_simplified_.size(); } /*! * \brief The description of the \p i-th dimension. @@ -311,8 +293,7 @@ class Layout { * \return the description of the dimension. */ inline std::string at(size_t i) const { - CHECK_LT(i, this->ndim()) << "position " << i - << " exceeds ndim=" << this->ndim(); + CHECK_LT(i, this->ndim()) << "position " << i << " exceeds ndim=" << this->ndim(); std::ostringstream repr; if (is_subdim(layout_simplified_[i])) { auto factor = subsizeof(layout_simplified_[i]); @@ -331,9 +312,12 @@ class Layout { * \return the index or -1 if not found. */ inline int32_t indexof(LayoutDim dim) const { - if (!this->defined()) return -1; - else if (is_superdim(dim)) return superdim_pos_[dim - 'A']; - else if (is_subdim(dim)) return subdim_pos_[dim - 'a']; + if (!this->defined()) + return -1; + else if (is_superdim(dim)) + return superdim_pos_[dim - 'A']; + else if (is_subdim(dim)) + return subdim_pos_[dim - 'a']; return -1; } @@ -359,34 +343,26 @@ class Layout { */ inline bool contains(LayoutDim dim) const { if (is_superdim(dim)) { - return superdim_pos_[dim-'A'] >= 0; + return superdim_pos_[dim - 'A'] >= 0; } else if (is_subdim(dim)) { - return subdim_pos_[dim-'a'] >= 0; + return subdim_pos_[dim - 'a'] >= 0; } return false; } - inline LayoutDim operator[](size_t i) const { - return layout_simplified_[i]; - } + inline LayoutDim operator[](size_t i) const { return layout_simplified_[i]; } /*! \return whether the layout is defined */ - inline bool defined() const { - return name_ != "__undef__"; - } + inline bool defined() const { return name_ != "__undef__"; } /*! \return the string description of the layout */ - inline const std::string& name() const { - return name_; - } + inline const std::string& name() const { return name_; } /*! * \brief Write layout in JSON format. * \param writer JSONWriter */ - inline void Save(dmlc::JSONWriter* writer) const { - writer->Write(name_); - } + inline void Save(dmlc::JSONWriter* writer) const { writer->Write(name_); } /*! * \brief Load layout from JSON. @@ -433,21 +409,20 @@ class Layout { const LayoutDim c = layout.at(i); if (is_superdim(c)) { int pos = c - 'A'; - CHECK_EQ(factor, 0) << "Invalid layout " << layout - << ": invalid factor size " << factor + CHECK_EQ(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor << " before dimension " << c; - CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; + CHECK_EQ(superdim_pos_[pos], -1) + << "Invalid layout " << layout << ": duplicate dimension " << c; superdim_pos_[pos] = curr++; layout_simplified_.push_back(c); } else if (is_subdim(c)) { int pos = c - 'a'; - CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " - << factor << " for dimension " << c; - CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; - CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; + CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor + << " for dimension " << c; + CHECK_EQ(subdim_pos_[pos], -1) + << "Invalid layout " << layout << ": duplicate dimension " << c; + CHECK_EQ(subdim_size_[pos], -1) + << "Invalid layout " << layout << ": duplicate dimension " << c; subdim_pos_[pos] = curr++; subdim_size_[pos] = factor; layout_simplified_.push_back(c); @@ -461,9 +436,8 @@ class Layout { } CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout; for (LayoutDim dim : layout_simplified_) { - CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0) - << "Invalid layout " << layout << ": missing axis " - << static_cast(dim - 'a' + 'A'); + CHECK(is_superdim(dim) || superdim_pos_[dim - 'a'] >= 0) + << "Invalid layout " << layout << ": missing axis " << static_cast(dim - 'a' + 'A'); } } }; diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 2155481..91d13e5 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -26,12 +26,13 @@ #include #include -#include -#include #include +#include +#include + #include "base.h" -#include "op.h" #include "c_api.h" +#include "op.h" namespace nnvm { @@ -49,27 +50,16 @@ using ObjectPtr = std::shared_ptr; /*! \brief an entry that represents output data from a node */ struct NodeEntry { - NodeEntry(ObjectPtr node, uint32_t index, uint32_t version): - node(std::move(node)), - index(index), - version(version) - {} - - explicit NodeEntry(ObjectPtr node): - node(std::move(node)), - index(), - version() - {} + NodeEntry(ObjectPtr node, uint32_t index, uint32_t version) + : node(std::move(node)), index(index), version(version) {} + + explicit NodeEntry(ObjectPtr node) : node(std::move(node)), index(), version() {} /** * MXNet assumes that a node with a null ptr doesn't have a gradient attached. Don't change this * constructor. */ - NodeEntry(): - node(nullptr), - index(), - version() - {} + NodeEntry() : node(nullptr), index(), version() {} /*! \brief the source node of this data */ ObjectPtr node; @@ -79,7 +69,8 @@ struct NodeEntry { * \brief version of input Variable. * This field can only be nonzero when this->node is a Variable node. * version is increased by one each time a Variable get composed to a mutation Op. - * This information can be helpful to decide order of operations when sequence of mutation happens. + * This information can be helpful to decide order of operations when sequence of mutation + * happens. */ uint32_t version; }; @@ -90,9 +81,8 @@ struct NodeEntry { */ struct NodeEntryHash { size_t operator()(const NodeEntry& e) const { - return std::hash()(e.node.get()) ^ - (std::hash()(e.index) << 1 >> 1) ^ - (std::hash()(e.version) << 1); + return std::hash()(e.node.get()) ^ (std::hash()(e.index) << 1 >> 1) ^ + (std::hash()(e.version) << 1); } }; @@ -102,14 +92,12 @@ struct NodeEntryHash { */ struct NodeEntryEqual { size_t operator()(const NodeEntry& a, const NodeEntry& b) const { - return (a.node.get() == b.node.get()) && - (a.index == b.index) && - (a.version == b.version); + return (a.node.get() == b.node.get()) && (a.index == b.index) && (a.version == b.version); } }; /*! use NodeEntry as key in unordered_map */ -template +template using NodeEntryMap = std::unordered_map; /*! @@ -121,7 +109,7 @@ struct NodeAttrs { * \brief The operator this node uses. * For place holder variable, op == nullptr. */ - const Op *op{nullptr}; + const Op* op{nullptr}; /*! \brief name of the node */ std::string name; /*! \brief The dictionary representation of attributes */ @@ -188,7 +176,7 @@ class NNVM_DLL Node { * \brief create a new empty shared_ptr of Node. * \return a created empty node. */ - template + template static ObjectPtr Create(Args&&... args) { return std::make_shared(std::forward(args)...); } @@ -202,12 +190,9 @@ class NNVM_DLL Node { * \param attrs The attributes * \return The created node entry. */ -inline NodeEntry MakeNode( - const char* op_name, - std::string node_name, - std::vector inputs, - std::unordered_map attrs = - std::unordered_map()) { +inline NodeEntry MakeNode(const char* op_name, std::string node_name, std::vector inputs, + std::unordered_map attrs = + std::unordered_map()) { ObjectPtr p = Node::Create(); p->attrs.op = nnvm::Op::Get(op_name); p->attrs.name = std::move(node_name); @@ -220,13 +205,9 @@ inline NodeEntry MakeNode( } // implementation of functions. -inline const Op* Node::op() const { - return this->attrs.op; -} +inline const Op* Node::op() const { return this->attrs.op; } -inline bool Node::is_variable() const { - return this->op() == nullptr; -} +inline bool Node::is_variable() const { return this->op() == nullptr; } inline uint32_t Node::num_outputs() const { if (is_variable()) return 1; diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index 645804c..f53e0f2 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,12 +25,14 @@ #define NNVM_OP_H_ #include + +#include +#include #include -#include -#include #include -#include -#include +#include +#include + #include "base.h" #include "c_api.h" @@ -39,7 +41,7 @@ namespace nnvm { // forward declarations class Node; struct NodeAttrs; -template +template class OpMap; class OpGroup; class OpRegistryEntry; @@ -193,15 +195,14 @@ class NNVM_DLL Op { * \param description Description of the argument. * \return reference to self. */ - inline Op& add_argument(const std::string &name, - const std::string &type, - const std::string &description); + inline Op& add_argument(const std::string& name, const std::string& type, + const std::string& description); /*! * \brief Append list if arguments to the end. * \param args Additional list of arguments. * \return reference to self. */ - inline Op& add_arguments(const std::vector &args); + inline Op& add_arguments(const std::vector& args); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. @@ -219,7 +220,7 @@ class NNVM_DLL Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_inputs(std::function fn); // NOLINT(*) + inline Op& set_num_inputs(std::function fn); // NOLINT(*) /*! * \brief Set the num_outputs * \param n The number of outputs to be set. @@ -231,13 +232,13 @@ class NNVM_DLL Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_outputs(std::function fn); // NOLINT(*) + inline Op& set_num_outputs(std::function fn); // NOLINT(*) /*! * \brief Set the attr_parser function. * \param fn The number of outputs to be set. * \return reference to self. */ - inline Op& set_attr_parser(std::function fn); // NOLINT(*) + inline Op& set_attr_parser(std::function fn); // NOLINT(*) /*! * \brief Register additional attributes to operator. * \param attr_name The name of the attribute. @@ -251,10 +252,9 @@ class NNVM_DLL Op { * * \tparam ValueType The type of the value to be set. */ - template + template inline Op& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 10); + const ValueType& value, int plevel = 10); /*! * \brief Add another alias to this operator. * The same Op can be queried with Op::Get(alias) @@ -284,11 +284,11 @@ class NNVM_DLL Op { * \return An OpMap of specified attr_name. * \tparam ValueType The type of the attribute. */ - template + template static const OpMap& GetAttr(const std::string& attr_name); private: - template + template friend class OpMap; friend class OpGroup; friend class dmlc::Registry; @@ -300,15 +300,13 @@ class NNVM_DLL Op { // get const reference to certain attribute static const any* GetAttrMap(const std::string& key); // update the attribute OpMap - static void UpdateAttrMap(const std::string& key, - std::function updater); + static void UpdateAttrMap(const std::string& key, std::function updater); // add a trigger based on tag matching on certain tag attribute // This will apply trigger on all the op such that // include the corresponding group. // The trigger will also be applied to all future registrations // that calls include - static void AddGroupTrigger(const std::string& group_name, - std::function trigger); + static void AddGroupTrigger(const std::string& group_name, std::function trigger); }; /*! @@ -316,7 +314,7 @@ class NNVM_DLL Op { * and returns ValueType * \tparam ValueType The type of the value stored in map. */ -template +template class OpMap { public: /*! @@ -351,7 +349,7 @@ class OpMap { // internal attribute name std::string attr_name_; // internal data - std::vector > data_; + std::vector> data_; OpMap() = default; }; @@ -376,18 +374,17 @@ class OpGroup { * * \tparam ValueType The type of the value to be set. */ - template + template inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 1); + const ValueType& value, int plevel = 1); }; // internal macros to make -#define NNVM_REGISTER_VAR_DEF(OpName) \ - static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName +#define NNVM_REGISTER_VAR_DEF(OpName) \ + static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op& __make_##NnvmOp##_##OpName -#define NNVM_REGISTER_GVAR_DEF(TagName) \ - static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName +#define NNVM_REGISTER_GVAR_DEF(TagName) \ + static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_##NnvmOpGroup##_##TagName /*! * \def NNVM_REGISTER_OP @@ -404,8 +401,8 @@ class OpGroup { * * \endcode */ -#define NNVM_REGISTER_OP(OpName) \ - DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ +#define NNVM_REGISTER_OP(OpName) \ + DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) /*! @@ -429,85 +426,72 @@ class OpGroup { * * \endcode */ -#define NNVM_REGISTER_OP_GROUP(GroupName) \ - DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \ - ::nnvm::OpGroup {#GroupName} +#define NNVM_REGISTER_OP_GROUP(GroupName) \ + DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = ::nnvm::OpGroup { #GroupName } // implementations of template functions after this. // member function of Op -template +template inline const OpMap& Op::GetAttr(const std::string& key) { const any* ref = GetAttrMap(key); if (ref == nullptr) { // update the attribute map of the key by creating new empty OpMap UpdateAttrMap(key, [key](any* pmap) { - // use callback so it is in lockscope - if (pmap->empty()) { - OpMap pm; - pm.attr_name_ = key; - *pmap = std::move(pm); - } - }); + // use callback so it is in lockscope + if (pmap->empty()) { + OpMap pm; + pm.attr_name_ = key; + *pmap = std::move(pm); + } + }); ref = GetAttrMap(key); } - return nnvm::get >(*ref); + return nnvm::get>(*ref); } -template +template inline Op& Op::set_attr( // NOLINT(*) - const std::string& attr_name, - const ValueType& value, - int plevel) { - CHECK_GT(plevel, 0) - << "plevel in set_attr must be greater than 0"; + const std::string& attr_name, const ValueType& value, int plevel) { + CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; // update the attribute map of the key by creating new empty if needed. - UpdateAttrMap(attr_name, - [this, attr_name, value, plevel](any* pmap) { - // the callback is in lockscope so is threadsafe. - if (pmap->empty()) { - OpMap pm; - pm.attr_name_ = attr_name; - *pmap = std::move(pm); - } - CHECK(pmap->type() == typeid(OpMap)) - << "Attribute " << attr_name - << " of operator " << this->name - << " is registered as inconsistent types" - << " previously " << pmap->type().name() - << " current " << typeid(OpMap).name(); - std::vector >& vec = - nnvm::get >(*pmap).data_; - // resize the value type. - if (vec.size() <= index_) { - vec.resize(index_ + 1, - std::make_pair(ValueType(), 0)); - } - std::pair& p = vec[index_]; - CHECK(p.second != plevel) - << "Attribute " << attr_name - << " of operator " << this->name - << " is already registered with same plevel=" << plevel; - if (p.second < plevel) { - vec[index_] = std::make_pair(value, plevel); - } - }); + UpdateAttrMap(attr_name, [this, attr_name, value, plevel](any* pmap) { + // the callback is in lockscope so is threadsafe. + if (pmap->empty()) { + OpMap pm; + pm.attr_name_ = attr_name; + *pmap = std::move(pm); + } + CHECK(pmap->type() == typeid(OpMap)) + << "Attribute " << attr_name << " of operator " << this->name + << " is registered as inconsistent types" + << " previously " << pmap->type().name() << " current " << typeid(OpMap).name(); + std::vector>& vec = nnvm::get>(*pmap).data_; + // resize the value type. + if (vec.size() <= index_) { + vec.resize(index_ + 1, std::make_pair(ValueType(), 0)); + } + std::pair& p = vec[index_]; + CHECK(p.second != plevel) << "Attribute " << attr_name << " of operator " << this->name + << " is already registered with same plevel=" << plevel; + if (p.second < plevel) { + vec[index_] = std::make_pair(value, plevel); + } + }); return *this; } - inline Op& Op::describe(const std::string& descr) { // NOLINT(*) this->description = descr; return *this; } -inline Op& Op::add_argument(const std::string &name, - const std::string &type, - const std::string &description) { +inline Op& Op::add_argument(const std::string& name, const std::string& type, + const std::string& description) { arguments.push_back({name, type, type, description}); return *this; } -inline Op& Op::add_arguments(const std::vector &args) { +inline Op& Op::add_arguments(const std::vector& args) { this->arguments.insert(arguments.end(), args.begin(), args.end()); return *this; } @@ -522,7 +506,7 @@ inline Op& Op::set_support_level(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) +inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) this->get_num_inputs = fn; return *this; } @@ -532,18 +516,18 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) +inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) this->get_num_outputs = fn; return *this; } -inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) +inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) this->attr_parser = fn; return *this; } // member functions of OpMap -template +template inline int OpMap::count(const Op* op) const { if (contains(op)) { return 1; @@ -552,7 +536,7 @@ inline int OpMap::count(const Op* op) const { } } -template +template inline bool OpMap::contains(const Op* op) const { if (op == nullptr) { return false; @@ -561,17 +545,16 @@ inline bool OpMap::contains(const Op* op) const { return idx < data_.size() ? (data_[idx].second != 0) : false; } -template +template inline const ValueType& OpMap::operator[](const Op* op) const { CHECK(op != nullptr); const uint32_t idx = op->index_; CHECK(idx < data_.size() && data_[idx].second) - << "Attribute " << attr_name_ - << " has not been registered for Operator " << op->name; + << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name; return data_[idx].first; } -template +template inline const ValueType& OpMap::get(const Op* op, const ValueType& def_value) const { if (op == nullptr) return def_value; const uint32_t idx = op->index_; @@ -582,9 +565,8 @@ inline const ValueType& OpMap::get(const Op* op, const ValueType& def } } -template -inline OpGroup& OpGroup::set_attr(const std::string& attr_name, - const ValueType& value, +template +inline OpGroup& OpGroup::set_attr(const std::string& attr_name, const ValueType& value, int plevel) { auto trigger = [attr_name, value, plevel](Op* op) { op->set_attr(attr_name, value, plevel); diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index c2af989..8409536 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -24,15 +24,16 @@ #ifndef NNVM_OP_ATTR_TYPES_H_ #define NNVM_OP_ATTR_TYPES_H_ -#include -#include -#include #include +#include #include +#include +#include + #include "base.h" +#include "layout.h" #include "node.h" #include "tuple.h" -#include "layout.h" namespace nnvm { @@ -48,7 +49,7 @@ namespace nnvm { * * FListInputNames enables automatic variable creation for missing arguments. */ -using FListInputNames = std::function (const NodeAttrs& attrs)>; +using FListInputNames = std::function(const NodeAttrs& attrs)>; /*! * \brief Return number of visible outputs by the user. @@ -60,7 +61,7 @@ using FListInputNames = std::function (const NodeAttrs& * but the additional outputs can be used to pass information from * forward to gradient pass. */ -using FNumVisibleOutputs = std::function; +using FNumVisibleOutputs = std::function; /*! * \brief Return list of output arguments names of each operator. @@ -71,7 +72,7 @@ using FNumVisibleOutputs = std::function; * * FListOutputNames customized naming for operator outputs. */ -using FListOutputNames = std::function (const NodeAttrs& attrs)>; +using FListOutputNames = std::function(const NodeAttrs& attrs)>; /*! * \brief Check whether operator will mutate k-th input. @@ -81,17 +82,16 @@ using FListOutputNames = std::function (const NodeAttrs * \note Register under "FMutateInputs", default return false * FMutateInputs enables mutation order handling correctly. */ -using FMutateInputs = std::function (const NodeAttrs& attrs)>; +using FMutateInputs = std::function(const NodeAttrs& attrs)>; /*! * \brief Inference function of certain type. * \tparam AttrType The type of the attribute to be infered. * \return whether all attributes are inferred. */ -template -using FInferNodeEntryAttr = std::function *in_attrs, - std::vector *out_attrs)>; +template +using FInferNodeEntryAttr = std::function* in_attrs, std::vector* out_attrs)>; /*! * \brief Get attribute dictionary from node. @@ -100,9 +100,8 @@ using FInferNodeEntryAttr = std::function - (const NodeAttrs& attrs)>; +using FGetAttrDict = + std::function(const NodeAttrs& attrs)>; /*! * \brief Shape inference function. @@ -155,8 +154,7 @@ using TIsGhost = bool; * * \note Register under "FInplaceOption", by default no inplace can happen. */ -using FInplaceOption = std::function< - std::vector > (const NodeAttrs& attrs)>; +using FInplaceOption = std::function >(const NodeAttrs& attrs)>; /*! * \brief Get if the inplace option is an identity @@ -168,7 +166,7 @@ using FInplaceOption = std::function< * * \note Register under "FInplaceIdentity", by default no identities. */ -using FInplaceIdentity = std::function (const NodeAttrs& attrs)>; +using FInplaceIdentity = std::function(const NodeAttrs& attrs)>; /*! * \brief Get list of inputs in the op whose content are actually not used by the operator @@ -179,8 +177,7 @@ using FInplaceIdentity = std::function (const NodeAttrs& attrs * * \note Register under "FIgnoreInputs". */ -using FIgnoreInputs = std::function< - std::vector (const NodeAttrs& attrs)>; +using FIgnoreInputs = std::function(const NodeAttrs& attrs)>; /*! * \brief Get the gradient node of the op node @@ -191,9 +188,8 @@ using FIgnoreInputs = std::function< * * \note Register under "FGradient" */ -using FGradient = std::function( - const ObjectPtr& nodeptr, - const std::vector& out_grads)>; +using FGradient = std::function(const ObjectPtr& nodeptr, + const std::vector& out_grads)>; /*! * \brief Set the attributes of input variable. @@ -202,10 +198,8 @@ using FGradient = std::function( * \param var the input variable * \param index index of var in all inputs */ -using FSetInputVarAttrOnCompose = std::function; +using FSetInputVarAttrOnCompose = + std::function; /*! * \brief Infer & correct function of node layout. See \p Layout for layout convention @@ -226,11 +220,9 @@ using FSetInputVarAttrOnCompose = std::function *ilayouts, - const std::vector *last_ilayouts, - std::vector *olayouts)>; +using FCorrectLayout = + std::function* ilayouts, + const std::vector* last_ilayouts, std::vector* olayouts)>; /*! * \brief Get a list of inputs that represent graphs instead of data. diff --git a/nnvm/include/nnvm/pass.h b/nnvm/include/nnvm/pass.h index a6158df..0bccdcc 100644 --- a/nnvm/include/nnvm/pass.h +++ b/nnvm/include/nnvm/pass.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,8 +24,9 @@ #ifndef NNVM_PASS_H_ #define NNVM_PASS_H_ -#include #include +#include + #include "base.h" #include "graph.h" @@ -42,7 +43,7 @@ namespace nnvm { * \param src The graph to be transformed. * \return The generated graph. */ -typedef std::function PassFunction; +typedef std::function PassFunction; /*! * \brief Apply a series of pass transformations on the input graph. @@ -50,8 +51,7 @@ typedef std::function PassFunction; * \param passes A list of pass names to be applied. * \return The transformed graph */ -Graph ApplyPasses(Graph src, - const std::vector& passes); +Graph ApplyPasses(Graph src, const std::vector& passes); /*! * \brief Apply one pass to the graph. @@ -59,17 +59,12 @@ Graph ApplyPasses(Graph src, * \param pass The name of pass to be applied. * \return The transformed graph. */ -inline Graph ApplyPass(Graph src, const std::string& pass) { - return ApplyPasses(src, {pass}); -} - +inline Graph ApplyPass(Graph src, const std::string& pass) { return ApplyPasses(src, {pass}); } /*! * \brief Registry entry for pass functions. */ -struct PassFunctionReg - : public dmlc::FunctionRegEntryBase { +struct PassFunctionReg : public dmlc::FunctionRegEntryBase { /*! * \brief Whether the pass will change graph structure * If this is false, the pass will only change attributes. @@ -138,7 +133,7 @@ struct PassFunctionReg * }); * \endcode */ -#define NNVM_REGISTER_PASS(name) \ +#define NNVM_REGISTER_PASS(name) \ DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name) } // namespace nnvm diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h index a7893c6..3097e20 100644 --- a/nnvm/include/nnvm/pass_functions.h +++ b/nnvm/include/nnvm/pass_functions.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,13 +28,14 @@ #ifndef NNVM_PASS_FUNCTIONS_H_ #define NNVM_PASS_FUNCTIONS_H_ -#include #include -#include +#include #include +#include + #include "base.h" -#include "pass.h" #include "graph_attr_types.h" +#include "pass.h" namespace nnvm { namespace pass { @@ -60,7 +61,6 @@ inline std::string SaveJSON(Graph graph) { return ret.GetAttr("json"); } - /*! * \brief Print graph ir * \param graph The graph to be printed @@ -81,9 +81,7 @@ inline std::string PrintGraphIR(Graph graph) { * \param src The input graph. * \return A graph with proper control flow dependencies added. */ -inline Graph OrderMutation(Graph src) { - return ApplyPass(std::move(src), "OrderMutation"); -} +inline Graph OrderMutation(Graph src) { return ApplyPass(std::move(src), "OrderMutation"); } /*! * \brief Infer shapes in the graph given the information. @@ -94,9 +92,7 @@ inline Graph OrderMutation(Graph src) { * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. * The index of ShapeVector is given by graph.indexed_graph().entry_id. */ -inline Graph InferShape(Graph graph, - ShapeVector shape_inputs, - std::string shape_attr_key = "") { +inline Graph InferShape(Graph graph, ShapeVector shape_inputs, std::string shape_attr_key = "") { if (shape_inputs.size() != 0) { graph.attrs["shape_inputs"] = std::make_shared(std::move(shape_inputs)); } @@ -115,9 +111,7 @@ inline Graph InferShape(Graph graph, * \return A graph with new attribute "dtype" containing inferred type of each NodeEntry. * The index of ShapeVector is given by graph.indexed_graph().entry_id. */ -inline Graph InferType(Graph graph, - DTypeVector dtype_inputs, - std::string dtype_attr_key = "") { +inline Graph InferType(Graph graph, DTypeVector dtype_inputs, std::string dtype_attr_key = "") { if (dtype_inputs.size() != 0) { graph.attrs["dtype_inputs"] = std::make_shared(std::move(dtype_inputs)); } @@ -141,10 +135,8 @@ inline Graph InferType(Graph graph, * \param device_copy_op The name of copy op to be inserted when cross device copy happened. * \return A graph with new attribute "device", cotaining device information of each node. */ -inline Graph PlaceDevice(Graph graph, - std::string device_group_attr_key, - DeviceAssignMap device_assign_map, - std::string device_copy_op) { +inline Graph PlaceDevice(Graph graph, std::string device_group_attr_key, + DeviceAssignMap device_assign_map, std::string device_copy_op) { graph.attrs["device_group_attr_key"] = std::make_shared(std::move(device_group_attr_key)); graph.attrs["device_assign_map"] = std::make_shared(std::move(device_assign_map)); graph.attrs["device_copy_op"] = std::make_shared(std::move(device_copy_op)); @@ -159,22 +151,18 @@ inline Graph PlaceDevice(Graph graph, * \param ys_out_grad The symbol for additional gradient to be propagate back to y. * \param aggregate_fun Aggregation function applied to aggregate the inputs. * \param mirror_fun Optional mirror function to do mirror optimization and save memory. - * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like. - * \param zero_ops Optional, list of operators that outputs a single zero array. The first one - * must be zeros_like. - * \param copy_op_str Optional, name of the copy operation required to handle duplicates - * on the edge of the graph - * \return A new graph, whose outputs correspond to inputs of xs. + * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same + * as like. \param zero_ops Optional, list of operators that outputs a single zero array. The first + * one must be zeros_like. \param copy_op_str Optional, name of the copy operation required to + * handle duplicates on the edge of the graph \return A new graph, whose outputs correspond to + * inputs of xs. */ inline Graph Gradient( - Graph graph, - std::vector ys, - std::vector xs, + Graph graph, std::vector ys, std::vector xs, std::vector ys_out_grad, std::function&& inputs)> aggregate_fun = nullptr, std::function mirror_fun = nullptr, - std::function - attr_hint_fun = nullptr, + std::function attr_hint_fun = nullptr, std::vector zero_ops = std::vector(), std::string copy_op_str = std::string()) { graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); @@ -198,7 +186,7 @@ inline Graph Gradient( } if (copy_op_str != std::string()) { - graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); + graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); } return ApplyPass(std::move(graph), "Gradient"); diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h index d3555ec..77d3855 100644 --- a/nnvm/include/nnvm/symbolic.h +++ b/nnvm/include/nnvm/symbolic.h @@ -29,10 +29,10 @@ #define NNVM_SYMBOLIC_H_ #include -#include #include -#include #include +#include +#include #include "base.h" #include "node.h" @@ -81,13 +81,13 @@ class NNVM_DLL Symbol { * \brief Print the symbol info to output stream. * \param os The output stream to print to. */ - void Print(std::ostream &os) const; // NOLINT(*) + void Print(std::ostream& os) const; // NOLINT(*) /*! * \brief Get the index-th element from the returned tuple. * \param index Index of multi output. * \return The symbol corresponds to the indexed element. */ - Symbol operator[] (size_t index) const; + Symbol operator[](size_t index) const; /*! * \brief List the input variable nodes. * @@ -139,9 +139,9 @@ class NNVM_DLL Symbol { * \param name Name of returned symbol. * \return A new Symbol which is the composition of current symbol with its arguments. */ - Symbol operator () (const array_view& args, - const std::unordered_map& kwargs, - const std::string& name) const; + Symbol operator()(const array_view& args, + const std::unordered_map& kwargs, + const std::string& name) const; /*! * \brief Add control flow dependencies to the operators in symbols. * @@ -201,16 +201,14 @@ class NNVM_DLL Symbol { * * \return The created attribute in format . */ - std::vector > - ListAttrsRecursive() const; + std::vector > ListAttrsRecursive() const; /*! * \brief Create symbolic functor(AtomicSymbol) by given operator and attributes. * \param op The operator. * \param attrs The additional attributes. * \return Symbol that can be used to call compose further. */ - static Symbol CreateFunctor(const Op* op, - std::unordered_map attrs); + static Symbol CreateFunctor(const Op* op, std::unordered_map attrs); /*! * \brief Create symbolic functor(AtomicSymbol) by given node attributes. * \param attrs pre-initialized Node attributes. diff --git a/nnvm/include/nnvm/tuple.h b/nnvm/include/nnvm/tuple.h index a7f2d26..c6d6125 100644 --- a/nnvm/include/nnvm/tuple.h +++ b/nnvm/include/nnvm/tuple.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,12 +24,13 @@ #ifndef NNVM_TUPLE_H_ #define NNVM_TUPLE_H_ -#include -#include #include -#include #include #include +#include +#include +#include + #include "base.h" namespace nnvm { @@ -47,29 +48,23 @@ typedef int64_t dim_t; * \tparam ValueType The type of data stored inside tuple. * \sa TShape */ -template +template class Tuple { public: /*! \brief default constructor */ Tuple() = default; /*! \brief destructor */ - inline ~Tuple() { - delete [] data_heap_; - } + inline ~Tuple() { delete[] data_heap_; } /*! * \brief copy constructor from another tuple * \param s the source tuple */ - inline Tuple(const Tuple& s) { - this->assign(s.begin(), s.end()); - } + inline Tuple(const Tuple& s) { this->assign(s.begin(), s.end()); } /*! * \brief constructor from initializer list * \param init the initializer_list */ - inline Tuple(std::initializer_list init) { - this->assign(init.begin(), init.end()); - } + inline Tuple(std::initializer_list init) { this->assign(init.begin(), init.end()); } /*! * \brief constructor from vector * \param init the vector @@ -82,7 +77,7 @@ class Tuple { * \param src the source shape */ - inline Tuple(Tuple&& src) { // NOLINT(runtime/explicit) + inline Tuple(Tuple&& src) { // NOLINT(runtime/explicit) this->swap(src); } /*! @@ -91,9 +86,8 @@ class Tuple { * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template - inline Tuple(RandomAccessIterator begin, - RandomAccessIterator end) { + template + inline Tuple(RandomAccessIterator begin, RandomAccessIterator end) { this->assign(begin, end); } /*! @@ -102,9 +96,8 @@ class Tuple { * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template - inline void assign(RandomAccessIterator begin, - RandomAccessIterator end) { + template + inline void assign(RandomAccessIterator begin, RandomAccessIterator end) { this->SetDim(end - begin); std::copy(begin, end, this->begin()); } @@ -141,7 +134,7 @@ class Tuple { * \param init the source initializer list * \return reference of self */ - inline Tuple &operator=(std::initializer_list init) { + inline Tuple& operator=(std::initializer_list init) { this->assign(init.begin(), init.end()); return *this; } @@ -149,7 +142,7 @@ class Tuple { * \return whether two tuple equals * \param s the tuple to compare against */ - inline bool operator==(const Tuple &s) const { + inline bool operator==(const Tuple& s) const { if (ndim_ != s.ndim_) return false; return std::equal(begin(), end(), s.begin()); } @@ -157,45 +150,33 @@ class Tuple { * \return whether two tuple not equal * \param s the tuple to compare against */ - inline bool operator!=(const Tuple &s) const { - return !(*this == s); - } + inline bool operator!=(const Tuple& s) const { return !(*this == s); } /*! \return the begin data pointer to content of the tuple */ - inline const ValueType *begin() const { - return ndim_ <= kStackCache ? data_stack_ : data_heap_; - } + inline const ValueType* begin() const { return ndim_ <= kStackCache ? data_stack_ : data_heap_; } /*! \return the begin data pointer to content of the tuple */ - inline ValueType *begin() { - return ndim_ <= kStackCache ? data_stack_ : data_heap_; - } + inline ValueType* begin() { return ndim_ <= kStackCache ? data_stack_ : data_heap_; } /*! \return the data pointer to end of the tuple */ inline const ValueType* end() const { - return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); + return ndim_ <= kStackCache ? (data_stack_ + ndim_) : (data_heap_ + ndim_); } /*! \return the data pointer to end the tuple */ inline ValueType* end() { - return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); + return ndim_ <= kStackCache ? (data_stack_ + ndim_) : (data_heap_ + ndim_); } /*! \return number of dimension of the tuple */ - inline uint32_t ndim() const { - return ndim_; - } + inline uint32_t ndim() const { return ndim_; } /*! * \brief get corresponding index * \param i dimension index * \return the corresponding dimension size */ - inline ValueType& operator[](size_t i) { - return begin()[i]; - } + inline ValueType& operator[](size_t i) { return begin()[i]; } /*! * \brief get corresponding index * \param i dimension index * \return the corresponding dimension size */ - inline const ValueType& operator[](size_t i) const { - return begin()[i]; - } + inline const ValueType& operator[](size_t i) const { return begin()[i]; } /*! * \brief Save Tuple to JSON. * \param writer JSONWriter @@ -219,7 +200,7 @@ class Tuple { * \param t the tuple * \return the ostream */ - friend std::ostream &operator<<(std::ostream &os, const Tuple &t) { + friend std::ostream& operator<<(std::ostream& os, const Tuple& t) { os << '['; const ValueType* begin = t.begin(); const ValueType* end = t.end(); @@ -236,7 +217,7 @@ class Tuple { * \param t The tuple * \return the istream */ - friend std::istream &operator>>(std::istream &is, Tuple &t) { + friend std::istream& operator>>(std::istream& is, Tuple& t) { // get ( while (true) { char ch = is.peek(); @@ -252,7 +233,7 @@ class Tuple { if (!isspace(ch)) { is.setstate(std::ios::failbit); return is; - } + } } // Handle empty tuple while (isspace(is.peek())) { @@ -278,10 +259,12 @@ class Tuple { while (true) { ch = is.peek(); if (isspace(ch)) { - is.get(); continue; + is.get(); + continue; } if (ch == ')' || ch == ']') { - is.get(); break; + is.get(); + break; } break; } @@ -302,8 +285,8 @@ class Tuple { * \tparam DType data type that save to * \tparam TStream any stream type that have write */ - template - inline void Save(TStream *strm) const; + template + inline void Save(TStream* strm) const; /*! * \brief load the content from binary stream * \param strm the output stream @@ -311,8 +294,8 @@ class Tuple { * \tparam TStream any stream type that have write * \return whether the load is successful */ - template - inline bool Load(TStream *strm); + template + inline bool Load(TStream* strm); protected: // stack cache size @@ -327,9 +310,8 @@ class Tuple { ValueType* data_heap_{nullptr}; // internal function to change the dimension inline void SetDim(uint32_t ndim) { - if (ndim > kStackCache && - ndim > num_heap_allocated_) { - delete [] data_heap_; + if (ndim > kStackCache && ndim > num_heap_allocated_) { + delete[] data_heap_; data_heap_ = new ValueType[ndim]; num_heap_allocated_ = ndim; } @@ -356,16 +338,14 @@ class TShape : public Tuple { * \brief copy constructor of TShape * \param s source shape. */ - inline TShape(const Tuple& s) { // NOLINT(*) + inline TShape(const Tuple& s) { // NOLINT(*) this->assign(s.begin(), s.end()); } /*! * \brief constructor from initializer list * \param init the initializer_list */ - inline TShape(std::initializer_list init) { - this->assign(init.begin(), init.end()); - } + inline TShape(std::initializer_list init) { this->assign(init.begin(), init.end()); } /*! * \brief move constructor. * \param s source shape. @@ -379,9 +359,8 @@ class TShape : public Tuple { * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template - inline TShape(RandomAccessIterator begin, - RandomAccessIterator end) { + template + inline TShape(RandomAccessIterator begin, RandomAccessIterator end) { this->assign(begin, end); } /*! @@ -399,13 +378,13 @@ class TShape : public Tuple { * \return self. */ inline TShape& operator=(Tuple&& src) { // NOLINT(*) - TShape(std::move(src)).swap(*this); // NOLINT(*) + TShape(std::move(src)).swap(*this); // NOLINT(*) return *this; } /*! \return total number of elements in the shape */ inline size_t Size() const { dim_t size = 1; - const dim_t* start = begin(), *fin = end(); + const dim_t *start = begin(), *fin = end(); for (const dim_t* it = start; it != fin; ++it) { size *= *it; } @@ -418,28 +397,24 @@ class TShape : public Tuple { */ inline size_t ProdShape(int dimstart, int dimend) const { dim_t num = 1; - const dim_t *d = this->data(); + const dim_t* d = this->data(); for (int i = dimstart; i < dimend; ++i) { num *= d[i]; } return num; } /*! \return the begin data pointer to content of the tuple */ - inline const dim_t *data() const { - return begin(); - } + inline const dim_t* data() const { return begin(); } /*! \return the begin data pointer to content of the tuple */ - inline dim_t *data() { - return begin(); - } + inline dim_t* data() { return begin(); } #ifdef MSHADOW_XINLINE - template - inline TShape(const mshadow::Shape &s) {// NOLINT(*) + template + inline TShape(const mshadow::Shape& s) { // NOLINT(*) this->assign(s.shape_, s.shape_ + dim); } - template - inline TShape(mshadow::Shape &&s) {// NOLINT(*) + template + inline TShape(mshadow::Shape&& s) { // NOLINT(*) this->assign(s.shape_, s.shape_ + dim); } /*! @@ -448,8 +423,8 @@ class TShape : public Tuple { * \tparam dim shape dimension * \return reference of self */ - template - inline TShape &operator=(const mshadow::Shape &shape) { + template + inline TShape& operator=(const mshadow::Shape& shape) { this->assign(shape.shape_, shape.shape_ + dim); return *this; } @@ -458,11 +433,11 @@ class TShape : public Tuple { * \return the shape requested * \tparam dim dimension of the tensor */ - template + template inline mshadow::Shape get() const { CHECK_EQ(dim, static_cast(ndim())) << "dimension do not match target dimension " << dim << " vs " << ndim(); - const dim_t *d = this->data(); + const dim_t* d = this->data(); mshadow::Shape s; for (int i = 0; i < dim; ++i) { s[i] = d[i]; @@ -476,7 +451,7 @@ class TShape : public Tuple { inline mshadow::Shape<2> FlatTo2D(void) const { mshadow::Shape<2> s; if (ndim() == 0) return mshadow::Shape2(0, 0); - const dim_t *d = this->data(); + const dim_t* d = this->data(); s.shape_[1] = d[ndim() - 1]; dim_t ymax = 1; for (size_t i = 1; i < ndim(); ++i) { @@ -495,7 +470,7 @@ class TShape : public Tuple { CHECK(axis_end >= axis_begin); mshadow::Shape<3> s; if (ndim() == 0) return mshadow::Shape3(0, 0, 0); - const dim_t *d = this->data(); + const dim_t* d = this->data(); s.shape_[0] = 1; s.shape_[1] = 1; s.shape_[2] = 1; @@ -516,25 +491,21 @@ class TShape : public Tuple { * \param axis The axis specified. * \return the flat 3d shape */ - inline mshadow::Shape<3> FlatTo3D(size_t axis) const { - return FlatTo3D(axis, axis); - } - inline bool operator==(const TShape &s) const { + inline mshadow::Shape<3> FlatTo3D(size_t axis) const { return FlatTo3D(axis, axis); } + inline bool operator==(const TShape& s) const { if (ndim() != s.ndim()) return false; return std::equal(begin(), end(), s.begin()); } - inline bool operator!=(const TShape &s) const { - return !(*this == s); - } + inline bool operator!=(const TShape& s) const { return !(*this == s); } /*! * \return whether two shape equals * \param s the shape to compare against * \tparam dim dimension of the shape */ - template - inline bool operator==(const mshadow::Shape &s) const { + template + inline bool operator==(const mshadow::Shape& s) const { if (ndim_ != dim) return false; - const dim_t *d = dim <= kStackCache ? data_stack_ : data_heap_; + const dim_t* d = dim <= kStackCache ? data_stack_ : data_heap_; for (size_t i = 0; i < dim; ++i) { if (d[i] != s.shape_[i]) return false; } @@ -545,18 +516,16 @@ class TShape : public Tuple { * \param s the shape to compare against * \tparam dim dimension of the shape */ - template - inline bool operator!=(const mshadow::Shape &s) const { + template + inline bool operator!=(const mshadow::Shape& s) const { return !(*this == s); } #endif }; /*! \brief helper function to cast type of container elements */ -template -inline DstIter ShapeTypeCast(const SrcIter begin, - const SrcIter end, - DstIter dst_begin) { +template +inline DstIter ShapeTypeCast(const SrcIter begin, const SrcIter end, DstIter dst_begin) { typedef typename std::iterator_traits::value_type SrcDType; typedef typename std::iterator_traits::value_type DstDType; auto cast = [](const SrcDType& dim) { return static_cast(dim); }; @@ -564,7 +533,7 @@ inline DstIter ShapeTypeCast(const SrcIter begin, } /*! \brief helper function to transform a container to TShape with type cast */ -template +template inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) { size_t ndim = std::distance(begin, end); TShape res(ndim); @@ -573,9 +542,9 @@ inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) { } /*! \tparam ValueType The type of data stored inside tuple. */ -template -template -inline void Tuple::Save(TStream *strm) const { +template +template +inline void Tuple::Save(TStream* strm) const { strm->Write(&ndim_, sizeof(ndim_)); if (typeid(DType) == typeid(ValueType)) { strm->Write(begin(), sizeof(ValueType) * ndim_); @@ -587,9 +556,9 @@ inline void Tuple::Save(TStream *strm) const { } /*! \tparam ValueType The type of data stored inside tuple. */ -template -template -inline bool Tuple::Load(TStream *strm) { +template +template +inline bool Tuple::Load(TStream* strm) { if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false; this->SetDim(ndim_); size_t nread = sizeof(DType) * ndim_; @@ -607,7 +576,7 @@ inline bool Tuple::Load(TStream *strm) { namespace std { /*! \brief hash function for Tuple. */ -template +template struct hash > { /*! \brief hash a Tuple into unsigned int */ size_t operator()(const nnvm::Tuple& val) const { @@ -621,7 +590,7 @@ struct hash > { }; /*! \brief hash function for TShape. */ -template<> +template <> struct hash { /*! \brief hash a TShape into unsigned int */ size_t operator()(const nnvm::TShape& val) const { @@ -640,11 +609,9 @@ namespace dmlc { DMLC_DECLARE_TYPE_NAME(optional, "Shape or None"); // avoid low version of MSVC #if !defined(_MSC_VER) -template +template struct type_name_helper > { - static inline std::string value() { - return "tuple of <" + type_name() + ">"; - } + static inline std::string value() { return "tuple of <" + type_name() + ">"; } }; #endif } // namespace dmlc diff --git a/nnvm/src/c_api/c_api_common.h b/nnvm/src/c_api/c_api_common.h index b3ff36a..1291947 100644 --- a/nnvm/src/c_api/c_api_common.h +++ b/nnvm/src/c_api/c_api_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,23 +29,34 @@ #include #include #include -#include + #include -#include #include +#include +#include /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { /*! \brief every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR */ -#define API_END() } catch(dmlc::Error &_except_) { return NNAPIHandleException(_except_); } return 0; // NOLINT(*) +#define API_END() \ + } \ + catch (dmlc::Error & _except_) { \ + return NNAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! * \brief every function starts with API_BEGIN(); * and finishes with API_END() or API_END_HANDLE_ERROR * The finally clause contains procedure to cleanup states when an error happens. */ -#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return NNAPIHandleException(_except_); } return 0; // NOLINT(*) - +#define API_END_HANDLE_ERROR(Finalize) \ + } \ + catch (dmlc::Error & _except_) { \ + Finalize; \ + return NNAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! \brief entry to to easily hold returning information */ struct NNAPIThreadLocalEntry { @@ -54,9 +65,9 @@ struct NNAPIThreadLocalEntry { /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; + std::vector ret_vec_charp; /*! \brief result holder for returning handles */ - std::vector ret_handles; + std::vector ret_handles; /*! \brief argument holder to hold symbol */ std::unordered_map kwarg_symbol; }; @@ -69,7 +80,7 @@ typedef dmlc::ThreadLocalStore NNAPIThreadLocalStore; * \param e the exception * \return the return value of API after exception is handled */ -inline int NNAPIHandleException(const dmlc::Error &e) { +inline int NNAPIHandleException(const dmlc::Error& e) { NNAPISetLastError(e.what()); return -1; } diff --git a/nnvm/src/c_api/c_api_error.cc b/nnvm/src/c_api/c_api_error.cc index ba6e1cd..c2f90b1 100644 --- a/nnvm/src/c_api/c_api_error.cc +++ b/nnvm/src/c_api/c_api_error.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,7 @@ * \brief C error handling */ #include + #include "c_api_common.h" struct ErrorEntry { @@ -30,10 +31,6 @@ struct ErrorEntry { typedef dmlc::ThreadLocalStore NNAPIErrorStore; -const char *NNGetLastError() { - return NNAPIErrorStore::Get()->last_error.c_str(); -} +const char* NNGetLastError() { return NNAPIErrorStore::Get()->last_error.c_str(); } -void NNAPISetLastError(const char* msg) { - NNAPIErrorStore::Get()->last_error = msg; -} +void NNAPISetLastError(const char* msg) { NNAPIErrorStore::Get()->last_error = msg; } diff --git a/nnvm/src/c_api/c_api_graph.cc b/nnvm/src/c_api/c_api_graph.cc index cc5449b..a547476 100644 --- a/nnvm/src/c_api/c_api_graph.cc +++ b/nnvm/src/c_api/c_api_graph.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,17 +21,18 @@ * \file c_api_graph.cc * \brief C API related to Graph IR. */ +#include #include -#include -#include #include +#include #include -#include +#include + #include "c_api_common.h" using namespace nnvm; -int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) { +int NNGraphCreate(SymbolHandle symbol, GraphHandle* graph) { Graph* g = new Graph(); API_BEGIN(); g->outputs = static_cast(symbol)->outputs; @@ -45,7 +46,7 @@ int NNGraphFree(GraphHandle handle) { API_END(); } -int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { +int NNGraphGetSymbol(GraphHandle graph, SymbolHandle* symbol) { Symbol* s = new Symbol(); API_BEGIN(); s->outputs = static_cast(graph)->outputs; @@ -53,20 +54,15 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { API_END_HANDLE_ERROR(delete s); } -int NNGraphSetNodeEntryListAttr_(GraphHandle handle, - const char* key, - SymbolHandle list) { +int NNGraphSetNodeEntryListAttr_(GraphHandle handle, const char* key, SymbolHandle list) { API_BEGIN(); Symbol* s = static_cast(list); Graph* g = static_cast(handle); - g->attrs[std::string(key)] - = std::make_shared(s->outputs); + g->attrs[std::string(key)] = std::make_shared(s->outputs); API_END(); } -int NNGraphSetJSONAttr(GraphHandle handle, - const char* key, - const char* json_value) { +int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value) { API_BEGIN(); Graph* g = static_cast(handle); std::string temp(json_value); @@ -78,11 +74,8 @@ int NNGraphSetJSONAttr(GraphHandle handle, API_END(); } -int NNGraphGetJSONAttr(GraphHandle handle, - const char* key, - const char** json_out, - int *success) { - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNGraphGetJSONAttr(GraphHandle handle, const char* key, const char** json_out, int* success) { + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); Graph* g = static_cast(handle); std::string skey(key); @@ -100,10 +93,8 @@ int NNGraphGetJSONAttr(GraphHandle handle, API_END(); } -int NNGraphApplyPasses(GraphHandle src, - nn_uint num_pass, - const char** pass_names, - GraphHandle *dst) { +int NNGraphApplyPasses(GraphHandle src, nn_uint num_pass, const char** pass_names, + GraphHandle* dst) { Graph* g = new Graph(); API_BEGIN(); std::vector vpass; diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index 7ca5603..2127997 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -24,14 +24,14 @@ #include #include #include + #include "c_api_common.h" using namespace nnvm; -int NNListAllOpNames(nn_uint *out_size, - const char*** out_array) { +int NNListAllOpNames(nn_uint* out_size, const char*** out_array) { API_BEGIN(); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); ret->ret_vec_str = dmlc::Registry::ListAllNames(); ret->ret_vec_charp.resize(0); ret->ret_vec_charp.reserve(ret->ret_vec_str.size()); @@ -43,40 +43,31 @@ int NNListAllOpNames(nn_uint *out_size, API_END(); } -int NNGetOpHandle(const char* op_name, - OpHandle* op_out) { +int NNGetOpHandle(const char* op_name, OpHandle* op_out) { API_BEGIN(); *op_out = (OpHandle)Op::Get(op_name); // NOLINT(*) API_END(); } -int NNListUniqueOps(nn_uint *out_size, - OpHandle **out_array) { +int NNListUniqueOps(nn_uint* out_size, OpHandle** out_array) { API_BEGIN(); - auto &vec = dmlc::Registry::List(); + auto& vec = dmlc::Registry::List(); *out_size = static_cast(vec.size()); *out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } -int NNAddControlDeps(SymbolHandle handle, - SymbolHandle src_dep) { +int NNAddControlDeps(SymbolHandle handle, SymbolHandle src_dep) { API_BEGIN(); - static_cast(handle)->AddControlDeps( - *static_cast(src_dep)); + static_cast(handle)->AddControlDeps(*static_cast(src_dep)); API_END(); } -int NNGetOpInfo(OpHandle handle, - const char **name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type) { - const Op *op = static_cast(handle); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNGetOpInfo(OpHandle handle, const char** name, const char** description, nn_uint* num_doc_args, + const char*** arg_names, const char*** arg_type_infos, + const char*** arg_descriptions, const char** return_type) { + const Op* op = static_cast(handle); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); *name = op->name.c_str(); @@ -100,12 +91,9 @@ int NNGetOpInfo(OpHandle handle, API_END(); } -int NNSymbolCreateAtomicSymbol(OpHandle creator, - nn_uint num_param, - const char **keys, - const char **vals, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolCreateAtomicSymbol(OpHandle creator, nn_uint num_param, const char** keys, + const char** vals, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); const Op* op = static_cast(creator); std::unordered_map kwargs; @@ -117,19 +105,17 @@ int NNSymbolCreateAtomicSymbol(OpHandle creator, API_END_HANDLE_ERROR(delete s;); } -int NNSymbolCreateVariable(const char *name, SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolCreateVariable(const char* name, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = Symbol::CreateVariable(name); *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolCreateGroup(nn_uint num_symbols, - SymbolHandle *symbols, - SymbolHandle *out) { - Symbol *s = new Symbol(); - Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*) +int NNSymbolCreateGroup(nn_uint num_symbols, SymbolHandle* symbols, SymbolHandle* out) { + Symbol* s = new Symbol(); + Symbol** sym_arr = (Symbol**)symbols; // NOLINT(*) API_BEGIN(); std::vector syms; for (nn_uint i = 0; i < num_symbols; ++i) { @@ -140,28 +126,24 @@ int NNSymbolCreateGroup(nn_uint num_symbols, API_END_HANDLE_ERROR(delete s); } -int NNSymbolGetOutput(SymbolHandle symbol, - nn_uint index, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolGetOutput(SymbolHandle symbol, nn_uint index, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = (*static_cast(symbol))[index]; *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolGetInternals(SymbolHandle symbol, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolGetInternals(SymbolHandle symbol, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->GetInternals(); *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolGetChildren(SymbolHandle symbol, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolGetChildren(SymbolHandle symbol, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->GetChildren(); *out = s; @@ -174,17 +156,17 @@ int NNSymbolFree(SymbolHandle symbol) { API_END(); } -int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolCopy(SymbolHandle symbol, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->Copy(); *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolPrint(SymbolHandle symbol, const char **out_str) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolPrint(SymbolHandle symbol, const char** out_str) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::ostringstream os; s->Print(os); @@ -193,12 +175,9 @@ int NNSymbolPrint(SymbolHandle symbol, const char **out_str) { API_END(); } -int NNSymbolGetAttr(SymbolHandle symbol, - const char* key, - const char** out, - int* success) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); if (s->GetAttr(key, &(ret->ret_str))) { *out = (ret->ret_str).c_str(); @@ -210,27 +189,20 @@ int NNSymbolGetAttr(SymbolHandle symbol, API_END(); } -int NNSymbolSetAttrs(SymbolHandle symbol, - nn_uint num_param, - const char** keys, - const char** vals) { - Symbol *s = static_cast(symbol); +int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, const char** keys, const char** vals) { + Symbol* s = static_cast(symbol); API_BEGIN(); std::vector > kwargs; for (nn_uint i = 0; i < num_param; ++i) { - kwargs.emplace_back( - std::make_pair(std::string(keys[i]), std::string(vals[i]))); + kwargs.emplace_back(std::make_pair(std::string(keys[i]), std::string(vals[i]))); } s->SetAttrs(kwargs); API_END(); } -int NNSymbolListAttrs(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char*** out) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolListAttrs(SymbolHandle symbol, int option, nn_uint* out_size, const char*** out) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::unordered_map attr = s->ListAttrs(static_cast(option)); // NOLINT(*) @@ -252,12 +224,10 @@ int NNSymbolListAttrs(SymbolHandle symbol, API_END(); } -int NNSymbolListInputVariables(SymbolHandle symbol, - int option, - nn_uint *out_size, +int NNSymbolListInputVariables(SymbolHandle symbol, int option, nn_uint* out_size, SymbolHandle** out_sym_array) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::vector vs = s->ListInputs(Symbol::ListInputOption(option)); ret->ret_handles.resize(0); @@ -272,15 +242,12 @@ int NNSymbolListInputVariables(SymbolHandle symbol, API_END(); } -int NNSymbolListInputNames(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char ***out_str_array) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolListInputNames(SymbolHandle symbol, int option, nn_uint* out_size, + const char*** out_str_array) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); - ret->ret_vec_str = - s->ListInputNames(Symbol::ListInputOption(option)); + ret->ret_vec_str = s->ListInputNames(Symbol::ListInputOption(option)); ret->ret_vec_charp.resize(0); ret->ret_vec_charp.reserve(ret->ret_vec_str.size()); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { @@ -291,11 +258,9 @@ int NNSymbolListInputNames(SymbolHandle symbol, API_END(); } -int NNSymbolListOutputNames(SymbolHandle symbol, - nn_uint *out_size, - const char ***out_str_array) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolListOutputNames(SymbolHandle symbol, nn_uint* out_size, const char*** out_str_array) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); ret->ret_vec_str = s->ListOutputNames(); ret->ret_vec_charp.resize(0); @@ -308,24 +273,19 @@ int NNSymbolListOutputNames(SymbolHandle symbol, API_END(); } -int NNSymbolGetNumOutputs(SymbolHandle symbol, - nn_uint *output_count) { - Symbol *s = static_cast(symbol); +int NNSymbolGetNumOutputs(SymbolHandle symbol, nn_uint* output_count) { + Symbol* s = static_cast(symbol); API_BEGIN(); *output_count = static_cast(s->outputs.size()); API_END(); } -int NNSymbolCompose(SymbolHandle sym, - const char *name, - nn_uint num_args, - const char** keys, +int NNSymbolCompose(SymbolHandle sym, const char* name, nn_uint num_args, const char** keys, SymbolHandle* args) { API_BEGIN(); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); std::string& s_name = ret->ret_str; - std::unordered_map& kwargs - = ret->kwarg_symbol; + std::unordered_map& kwargs = ret->kwarg_symbol; kwargs.clear(); if (name != nullptr) { s_name = name; @@ -335,8 +295,7 @@ int NNSymbolCompose(SymbolHandle sym, Symbol* s = static_cast(sym); if (keys == nullptr && num_args != 0) { kwargs.clear(); - array_view parg( - (Symbol**)args, (Symbol**)args + num_args); // NOLINT(*) + array_view parg((Symbol**)args, (Symbol**)args + num_args); // NOLINT(*) s->Compose(parg, kwargs, s_name); } else { for (nn_uint i = 0; i < num_args; ++i) { diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index c3ae60e..fd5b64f 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -23,6 +23,7 @@ */ #include #include + #include namespace nnvm { @@ -39,23 +40,22 @@ const IndexedGraph& Graph::indexed_graph() const { // e.g. the main graph is level 0 // subgraphs of the main graph is level 1 // subgraphs of the subgraphs of the main graph is level 2 -static void SubgraphSanityCheck(const std::vector> &subgraphs) { +static void SubgraphSanityCheck(const std::vector>& subgraphs) { std::vector*> curr_level; std::vector*> next_level; std::unordered_map node2level; - for (auto &subgraph : subgraphs) - next_level.push_back(&subgraph->outputs); + for (auto& subgraph : subgraphs) next_level.push_back(&subgraph->outputs); for (uint32_t level = 0; !next_level.empty(); ++level) { curr_level.swap(next_level); next_level.clear(); - for (const std::vector *graph_ptr : curr_level) { - const std::vector &graph = *graph_ptr; + for (const std::vector* graph_ptr : curr_level) { + const std::vector& graph = *graph_ptr; DFSVisit(graph, [&next_level, &node2level, level](const ObjectPtr& n) { - nnvm::Node *node = n.get(); + nnvm::Node* node = n.get(); // if the node is visited, but on a different level, then check failed // if check failed here or before, we stop doing anything, but raise an error CHECK(!node2level.count(node) || node2level[node] == level) - << "A subgraph should not depend on the outputs of nodes on higher levels"; + << "A subgraph should not depend on the outputs of nodes on higher levels"; // otherwise, this node belongs to the current level node2level[node] = level; // subgraphs of current node belongs to next level @@ -68,55 +68,51 @@ static void SubgraphSanityCheck(const std::vector> &subg } // implement constructor from graph -IndexedGraph::IndexedGraph(const Graph &g) { +IndexedGraph::IndexedGraph(const Graph& g) { entry_rptr_.push_back(0); std::vector inputs_rptr{0}, control_rptr{0}; std::vector> subgraphs; - DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] - (const ObjectPtr& n) { - const auto& is_ghost = Op::GetAttr("TIsGhost"); - if (!n->is_variable() && is_ghost.get(n->op(), false)) return; - CHECK_LT(nodes_.size(), std::numeric_limits::max()); - uint32_t nid = static_cast(nodes_.size()); - CHECK(n); - for (const auto &subgraph : n->attrs.subgraphs) - subgraphs.push_back(subgraph); - // nodes_ - IndexedGraph::Node new_node; - new_node.source = n.get(); - new_node.weak_ref = n; - nodes_.emplace_back(std::move(new_node)); - // arg_nodes_ - if (n->is_variable()) { - input_nodes_.push_back(nid); - } - // node2index_ - node2index_[n.get()] = nid; - // entry rptr - entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs()); - // input entries - for (const auto& e : n->inputs) { - auto it = node2index_.find(e.node.get()); - if (it == node2index_.end() || it->first != e.node.get()) continue; - input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version}); - } - inputs_rptr.push_back(input_entries_.size()); - // control deps - for (const auto& nptr : n->control_deps) { - if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; - auto it = node2index_.find(nptr.get()); - CHECK(it != node2index_.end()) << "control dep not found in graph"; - control_deps_.push_back(it->second); - } - control_rptr.push_back(control_deps_.size()); + DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs](const ObjectPtr& n) { + const auto& is_ghost = Op::GetAttr("TIsGhost"); + if (!n->is_variable() && is_ghost.get(n->op(), false)) return; + CHECK_LT(nodes_.size(), std::numeric_limits::max()); + uint32_t nid = static_cast(nodes_.size()); + CHECK(n); + for (const auto& subgraph : n->attrs.subgraphs) subgraphs.push_back(subgraph); + // nodes_ + IndexedGraph::Node new_node; + new_node.source = n.get(); + new_node.weak_ref = n; + nodes_.emplace_back(std::move(new_node)); + // arg_nodes_ + if (n->is_variable()) { + input_nodes_.push_back(nid); + } + // node2index_ + node2index_[n.get()] = nid; + // entry rptr + entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs()); + // input entries + for (const auto& e : n->inputs) { + auto it = node2index_.find(e.node.get()); + if (it == node2index_.end() || it->first != e.node.get()) continue; + input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version}); + } + inputs_rptr.push_back(input_entries_.size()); + // control deps + for (const auto& nptr : n->control_deps) { + if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; + auto it = node2index_.find(nptr.get()); + CHECK(it != node2index_.end()) << "control dep not found in graph"; + control_deps_.push_back(it->second); + } + control_rptr.push_back(control_deps_.size()); }); - if (!subgraphs.empty()) - SubgraphSanityCheck(subgraphs); + if (!subgraphs.empty()) SubgraphSanityCheck(subgraphs); for (const auto& e : g.outputs) { - outputs_.emplace_back(NodeEntry{ - node2index_.at(e.node.get()), e.index, e.version}); + outputs_.emplace_back(NodeEntry{node2index_.at(e.node.get()), e.index, e.version}); } static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); @@ -124,10 +120,9 @@ IndexedGraph::IndexedGraph(const Graph &g) { // input_entries_ and control_rptr must not change after this step. const NodeEntry* iptr = dmlc::BeginPtr(input_entries_); for (size_t nid = 0; nid < nodes_.size(); ++nid) { - nodes_[nid].inputs = array_view( - iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); - if (nodes_[nid].source->op() != nullptr && - fmutate_inputs.count(nodes_[nid].source->op())) { + nodes_[nid].inputs = + array_view(iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); + if (nodes_[nid].source->op() != nullptr && fmutate_inputs.count(nodes_[nid].source->op())) { for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) { mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id); } @@ -135,8 +130,8 @@ IndexedGraph::IndexedGraph(const Graph &g) { } const uint32_t* cptr = dmlc::BeginPtr(control_deps_); for (size_t nid = 0; nid < nodes_.size(); ++nid) { - nodes_[nid].control_deps = array_view( - cptr + control_rptr[nid], cptr + control_rptr[nid + 1]); + nodes_[nid].control_deps = + array_view(cptr + control_rptr[nid], cptr + control_rptr[nid + 1]); } } diff --git a/nnvm/src/core/op.cc b/nnvm/src/core/op.cc index eb51d4b..08a11df 100644 --- a/nnvm/src/core/op.cc +++ b/nnvm/src/core/op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,8 +24,8 @@ #include #include -#include #include +#include #include #include @@ -46,7 +46,7 @@ struct OpManager { // storage of additional attribute table. std::unordered_map > attr; // storage of existing triggers - std::unordered_map > > tmap; + std::unordered_map > > tmap; // group of each operator. std::vector > op_group; // get singleton of the @@ -70,14 +70,13 @@ Op& Op::add_alias(const std::string& alias) { // NOLINT(*) // find operator by name const Op* Op::Get(const std::string& name) { const Op* op = dmlc::Registry::Find(name); - CHECK(op != nullptr) - << "Operator " << name << " is not registered"; + CHECK(op != nullptr) << "Operator " << name << " is not registered"; return op; } // Get attribute map by key const any* Op::GetAttrMap(const std::string& key) { - auto& dict = OpManager::Global()->attr; + auto& dict = OpManager::Global()->attr; auto it = dict.find(key); if (it != dict.end()) { return it->second.get(); @@ -87,8 +86,7 @@ const any* Op::GetAttrMap(const std::string& key) { } // update attribute map -void Op::UpdateAttrMap(const std::string& key, - std::function updater) { +void Op::UpdateAttrMap(const std::string& key, std::function updater) { OpManager* mgr = OpManager::Global(); std::lock_guard(mgr->mutex); std::unique_ptr& value = mgr->attr[key]; @@ -96,16 +94,14 @@ void Op::UpdateAttrMap(const std::string& key, if (updater != nullptr) updater(value.get()); } -void Op::AddGroupTrigger(const std::string& group_name, - std::function trigger) { +void Op::AddGroupTrigger(const std::string& group_name, std::function trigger) { OpManager* mgr = OpManager::Global(); std::lock_guard(mgr->mutex); auto& tvec = mgr->tmap[group_name]; tvec.push_back(trigger); auto& op_group = mgr->op_group; for (const Op* op : dmlc::Registry::List()) { - if (op->index_ < op_group.size() && - op_group[op->index_].count(group_name) != 0) { + if (op->index_ < op_group.size() && op_group[op->index_].count(group_name) != 0) { trigger((Op*)op); // NOLINT(*) } } diff --git a/nnvm/src/core/pass.cc b/nnvm/src/core/pass.cc index b43d470..974cd2b 100644 --- a/nnvm/src/core/pass.cc +++ b/nnvm/src/core/pass.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,7 @@ * \brief Support for pass registry. */ #include + #include namespace dmlc { @@ -31,7 +32,7 @@ DMLC_REGISTRY_ENABLE(nnvm::PassFunctionReg); namespace nnvm { -const PassFunctionReg* FindPassDep(const std::string&attr_name) { +const PassFunctionReg* FindPassDep(const std::string& attr_name) { for (auto* r : dmlc::Registry::List()) { for (auto& s : r->graph_attr_targets) { if (s == attr_name) return r; @@ -40,13 +41,11 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) { return nullptr; } -Graph ApplyPasses(Graph g, - const std::vector& pass) { +Graph ApplyPasses(Graph g, const std::vector& pass) { std::vector fpass; for (auto& name : pass) { auto* reg = dmlc::Registry::Find(name); - CHECK(reg != nullptr) - << "Cannot find pass " << name << " in the registry"; + CHECK(reg != nullptr) << "Cannot find pass " << name << " in the registry"; fpass.push_back(reg); } @@ -58,10 +57,8 @@ Graph ApplyPasses(Graph g, if (pass_dep != nullptr) { msg = " The attribute is provided by pass " + pass_dep->name; } - LOG(FATAL) << "Graph attr dependency " << dep - << " is required by pass " << r->name - << " but is not available " - << msg; + LOG(FATAL) << "Graph attr dependency " << dep << " is required by pass " << r->name + << " but is not available " << msg; } } g = r->body(std::move(g)); diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index 86dc7e6..12b8675 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -22,13 +22,13 @@ * \brief Symbolic graph composition API. */ #include -#include #include +#include namespace nnvm { namespace symbol_constants { -const char *kNamespaceSeparator = "$"; +const char* kNamespaceSeparator = "$"; } // namespace symbol_constants // auxililary version attribute in variable. @@ -48,7 +48,7 @@ ObjectPtr CreateVariableNode(const std::string& name) { // If the node's op mutates a certain input variable, // The version of that varaible will increase // version is used to implicitly order the mutation sequences -inline void UpdateNodeVersion(Node *n) { +inline void UpdateNodeVersion(Node* n) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); for (NodeEntry& e : n->inputs) { if (e.node->is_variable()) { @@ -58,16 +58,14 @@ inline void UpdateNodeVersion(Node *n) { if (fmutate_inputs.count(n->op()) != 0) { for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) { NodeEntry& e = n->inputs[i]; - CHECK(e.node->is_variable()) - << "Mutation target can only be Variable"; + CHECK(e.node->is_variable()) << "Mutation target can only be Variable"; // increase the version of the variable. e.version = ++nnvm::get(e.node->attrs.parsed).version; } } } -inline std::string DefaultVarName(const std::string &op_name, - const std::string &arg_name) { +inline std::string DefaultVarName(const std::string& op_name, const std::string& arg_name) { if (op_name.length() == 0) { return arg_name; } else { @@ -75,8 +73,7 @@ inline std::string DefaultVarName(const std::string &op_name, } } -inline void KeywordArgumentMismatch(const char *source, - const std::vector& user_args, +inline void KeywordArgumentMismatch(const char* source, const std::vector& user_args, const array_view& args) { std::unordered_set keys(args.begin(), args.end()); std::ostringstream head, msg; @@ -87,16 +84,13 @@ inline void KeywordArgumentMismatch(const char *source, for (const auto& key : user_args) { if (keys.count(key) == 0) { - LOG(FATAL) << source - << "Keyword argument name " << key << " not found." - << msg.str(); + LOG(FATAL) << source << "Keyword argument name " << key << " not found." << msg.str(); } } } -template -inline std::vector GetKeys( - const std::unordered_map& kwargs) { +template +inline std::vector GetKeys(const std::unordered_map& kwargs) { std::vector keys(kwargs.size()); std::transform(kwargs.begin(), kwargs.end(), keys.begin(), [](decltype(*kwargs.begin())& kv) { return kv.first; }); @@ -117,14 +111,14 @@ Symbol Symbol::Copy() const { std::unordered_map old_new; // use DFSVisit to copy all the nodes DFSVisit(this->outputs, [&old_new](const ObjectPtr& node) { - ObjectPtr np = Node::Create(); - np->attrs = node->attrs; - old_new[node.get()] = std::move(np); - }); + ObjectPtr np = Node::Create(); + np->attrs = node->attrs; + old_new[node.get()] = std::move(np); + }); // connect nodes of new graph - for (const auto &kv : old_new) { + for (const auto& kv : old_new) { for (const NodeEntry& e : kv.first->inputs) { - Node *ptr = e.node.get(); + Node* ptr = e.node.get(); kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); } for (const ObjectPtr& p : kv.first->control_deps) { @@ -133,66 +127,64 @@ Symbol Symbol::Copy() const { } // set the head Symbol ret; - for (const NodeEntry &e : outputs) { + for (const NodeEntry& e : outputs) { ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version}); } return ret; } -void Symbol::Print(std::ostream &os) const { - if (outputs.size() == 1 && - outputs[0].node->inputs.size() == 0 && +void Symbol::Print(std::ostream& os) const { + if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0 && outputs[0].node->control_deps.size() == 0) { if (outputs[0].node->is_variable()) { os << "Variable:" << outputs[0].node->attrs.name << '\n'; } else { - os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n'; + os << "AtomicFunctor " + << " Op:" << outputs[0].node->op()->name << '\n'; } } else { // use DFSVisit to copy all the nodes os << "Symbol Outputs:\n"; for (size_t i = 0; i < outputs.size(); ++i) { - os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name - << '(' << outputs[i].index << ")\n"; + os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name << '(' << outputs[i].index + << ")\n"; } DFSVisit(this->outputs, [&os](const ObjectPtr& node) { - if (node->is_variable()) { - os << "Variable:" << node->attrs.name << '\n'; - } else { - os << "--------------------\n"; - os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n' - << "Inputs:\n"; - for (size_t i = 0; i < node->inputs.size(); ++i) { - const NodeEntry& e = node->inputs[i]; - os << "\targ[" << i << "]=" << e.node->attrs.name - << '(' << e.index << ")"; - if (e.node->is_variable()) { - os << " version=" << e.version << '\n'; - } else { - os << '\n'; - } + if (node->is_variable()) { + os << "Variable:" << node->attrs.name << '\n'; + } else { + os << "--------------------\n"; + os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n' << "Inputs:\n"; + for (size_t i = 0; i < node->inputs.size(); ++i) { + const NodeEntry& e = node->inputs[i]; + os << "\targ[" << i << "]=" << e.node->attrs.name << '(' << e.index << ")"; + if (e.node->is_variable()) { + os << " version=" << e.version << '\n'; + } else { + os << '\n'; } - if (!node->attrs.dict.empty()) { - os << "Attrs:\n"; - // make an ordered copy because unordered_map doesn't guarantee order. - std::map sorted_dict( - node->attrs.dict.begin(), node->attrs.dict.end()); - for (auto &kv : sorted_dict) { - os << '\t' << kv.first << '=' << kv.second << '\n'; - } + } + if (!node->attrs.dict.empty()) { + os << "Attrs:\n"; + // make an ordered copy because unordered_map doesn't guarantee order. + std::map sorted_dict(node->attrs.dict.begin(), + node->attrs.dict.end()); + for (auto& kv : sorted_dict) { + os << '\t' << kv.first << '=' << kv.second << '\n'; } - if (node->control_deps.size() != 0) { - os << "Control deps:\n"; - for (size_t i = 0; i < node->control_deps.size(); ++i) { - os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n'; - } + } + if (node->control_deps.size() != 0) { + os << "Control deps:\n"; + for (size_t i = 0; i < node->control_deps.size(); ++i) { + os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n'; } } - }); + } + }); } } -Symbol Symbol::operator[] (size_t index) const { +Symbol Symbol::operator[](size_t index) const { size_t nreturn = outputs.size(); CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index"; if (nreturn == 1) { @@ -208,25 +200,25 @@ std::vector Symbol::ListInputs(ListInputOption option) const { std::vector ret; if (option == kAll) { ret.reserve(this->outputs.size()); - DFSVisit(this->outputs, [&ret](const ObjectPtr &node) { - if (node->is_variable()) { - ret.push_back(node); - } - }); + DFSVisit(this->outputs, [&ret](const ObjectPtr& node) { + if (node->is_variable()) { + ret.push_back(node); + } + }); } else { std::unordered_set mutable_set; std::vector vlist; vlist.reserve(this->outputs.size()); static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); - DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr &node) { - if (node->is_variable()) { - vlist.push_back(node); - } else if (fmutate_inputs.count(node->op())) { - for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){ - mutable_set.insert(node->inputs[i].node.get()); - } + DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr& node) { + if (node->is_variable()) { + vlist.push_back(node); + } else if (fmutate_inputs.count(node->op())) { + for (uint32_t i : fmutate_inputs[node->op()](node->attrs)) { + mutable_set.insert(node->inputs[i].node.get()); } - }); + } + }); ret.reserve(vlist.size()); for (const ObjectPtr& node : vlist) { if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) || @@ -252,7 +244,7 @@ std::vector Symbol::ListOutputNames() const { std::vector ret; ret.reserve(outputs.size()); - for (auto &head : outputs) { + for (auto& head : outputs) { if (head.node->is_variable()) { ret.push_back(head.node->attrs.name); } else { @@ -291,8 +283,7 @@ void Symbol::Compose(const array_view& args, Node* n = outputs[0].node.get(); FInputGraph fng = fgraph.get(n->op(), nullptr); std::vector garg_idx; - if (fng != nullptr) - garg_idx = fng(n->attrs); + if (fng != nullptr) garg_idx = fng(n->attrs); // The names of the arguments that contain graphs. FListInputNames name_fn = flist_inputs.get(n->op(), nullptr); @@ -300,8 +291,7 @@ void Symbol::Compose(const array_view& args, std::vector garg_names(garg_idx.size()); for (size_t i = 0; i < garg_idx.size(); i++) { size_t idx = garg_idx[i]; - if (idx < arg_names.size()) - garg_names[i] = arg_names[idx]; + if (idx < arg_names.size()) garg_names[i] = arg_names[idx]; } // parameter check. @@ -309,13 +299,13 @@ void Symbol::Compose(const array_view& args, // If the argument isn't a graph, it should have only one output. if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end()) CHECK_EQ(args[i]->outputs.size(), 1U) - << "Argument " << i << " is a tuple, single value is required"; + << "Argument " << i << " is a tuple, single value is required"; } for (const auto& kv : kwargs) { - if (garg_names.empty() - || std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end()) + if (garg_names.empty() || + std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end()) CHECK_EQ(kv.second->outputs.size(), 1U) - << "Keyword Argument " << kv.first << " is a tuple, single value is required"; + << "Keyword Argument " << kv.first << " is a tuple, single value is required"; } // assign new name if (!name.empty()) outputs[0].node->attrs.name = name; @@ -323,14 +313,14 @@ void Symbol::Compose(const array_view& args, // Atomic functor composition. if (IsAtomic(outputs)) { uint32_t n_req = n->num_inputs(); - std::vector arg_vec(args.begin(), args.end()); + std::vector arg_vec(args.begin(), args.end()); std::unordered_map kwarg_map(kwargs.begin(), kwargs.end()); // If one of the input arguments is a graph, we need to remove it from the // list. if (fng != nullptr) { std::vector idxes = fng(n->attrs); for (auto idx : idxes) { - const Symbol *sym; + const Symbol* sym; if (idx < arg_vec.size()) { sym = arg_vec[idx]; } else { @@ -339,8 +329,7 @@ void Symbol::Compose(const array_view& args, sym = it->second; kwarg_map.erase(it); } - if (n_req != kVarg) - n_req--; + if (n_req != kVarg) n_req--; n->attrs.subgraphs.push_back(std::make_shared(*sym)); } // Because idxes does not contain duplicates, the loop below functions well. @@ -358,8 +347,7 @@ void Symbol::Compose(const array_view& args, if (n_req != kVarg) { n->inputs.resize(n_req); CHECK_LE(arg_vec.size(), n_req) - << "Incorrect number of arguments, requires " << n_req - << ", provided " << arg_vec.size(); + << "Incorrect number of arguments, requires " << n_req << ", provided " << arg_vec.size(); for (size_t i = 0; i < arg_vec.size(); ++i) { n->inputs[i] = arg_vec[i]->outputs[0]; } @@ -375,8 +363,7 @@ void Symbol::Compose(const array_view& args, n->inputs[i] = it->second->outputs[0]; ++nmatched; } else { - n->inputs[i] = NodeEntry{ - CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0}; + n->inputs[i] = NodeEntry{CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0}; // copy attribute of parent over automatically created variables n->inputs[i].node->attrs.dict = n->attrs.dict; } @@ -409,20 +396,19 @@ void Symbol::Compose(const array_view& args, } } else { // general composition - CHECK_EQ(args.size(), 0U) - << "General composition only support kwargs for now"; + CHECK_EQ(args.size(), 0U) << "General composition only support kwargs for now"; size_t nmatched = 0; size_t arg_counter = 0; - std::unordered_map replace_map; + std::unordered_map replace_map; // replace map stores the existing replacement plan for arguments node - auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map] - (const ObjectPtr &node) { + auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, + &replace_map](const ObjectPtr& node) { if (node->is_variable()) { if (arg_counter < args.size()) { replace_map[node.get()] = &(args[arg_counter]->outputs[0]); ++arg_counter; } else { - // match kwargs + // match kwargs auto kit = kwargs.find(node->attrs.name); if (kit != kwargs.end()) { replace_map[node.get()] = &(kit->second->outputs[0]); @@ -436,12 +422,11 @@ void Symbol::Compose(const array_view& args, if (nmatched == kwargs.size() && arg_counter <= args.size()) { std::vector update_nodes; std::vector > replace_plan; - auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes] - (const ObjectPtr &node) { + auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes](const ObjectPtr& node) { // visit all the childs, find possible replacement bool repl = false; for (size_t i = 0; i < node->inputs.size(); ++i) { - NodeEntry *e = &(node->inputs[i]); + NodeEntry* e = &(node->inputs[i]); if (e->node->is_variable()) { auto iter = replace_map.find(e->node.get()); if (iter != replace_map.end()) { @@ -479,17 +464,16 @@ void Symbol::Compose(const array_view& args, } } -Symbol Symbol::operator () (const array_view& args, - const std::unordered_map& kwargs, - const std::string& name) const { +Symbol Symbol::operator()(const array_view& args, + const std::unordered_map& kwargs, + const std::string& name) const { Symbol s = this->Copy(); s.Compose(args, kwargs, name); return s; } void Symbol::AddControlDeps(const Symbol& src) { - CHECK_EQ(outputs.size(), 1U) - << "AddControlDeps only works for nongrouped symbol"; + CHECK_EQ(outputs.size(), 1U) << "AddControlDeps only works for nongrouped symbol"; Node* n = outputs[0].node.get(); for (const NodeEntry& sp : src.outputs) { n->control_deps.push_back(sp.node); @@ -500,21 +484,21 @@ Symbol Symbol::GetInternals() const { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol ret; DFSVisit(this->outputs, [&ret](const ObjectPtr& node) { - Node* n = node.get(); - if (n->is_variable()) { - // grab version from variable. - VariableParam& param = nnvm::get(n->attrs.parsed); - ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); - } else { - uint32_t nout = n->num_outputs(); - if (fnum_vis_output.count(n->op())) { - nout = fnum_vis_output[n->op()](n->attrs); - } - for (uint32_t i = 0; i < nout; ++i) { - ret.outputs.emplace_back(NodeEntry{node, i, 0}); - } + Node* n = node.get(); + if (n->is_variable()) { + // grab version from variable. + VariableParam& param = nnvm::get(n->attrs.parsed); + ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); + } else { + uint32_t nout = n->num_outputs(); + if (fnum_vis_output.count(n->op())) { + nout = fnum_vis_output[n->op()](n->attrs); } - }); + for (uint32_t i = 0; i < nout; ++i) { + ret.outputs.emplace_back(NodeEntry{node, i, 0}); + } + } + }); return ret; } @@ -533,8 +517,7 @@ Symbol Symbol::GetChildren() const { void Symbol::SetAttrs(const std::vector >& attrs) { Node* node = outputs[0].node.get(); for (const NodeEntry& e : outputs) { - CHECK(node == e.node.get()) - << "Symbol.SetAttrs only works for non-grouped symbol"; + CHECK(node == e.node.get()) << "Symbol.SetAttrs only works for non-grouped symbol"; } for (const auto& kv : attrs) { if (kv.first == "name") { @@ -583,29 +566,27 @@ std::unordered_map Symbol::ListAttrs(ListAttrOption op if (option == kRecursive) { std::unordered_map ret; DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { - for (const auto& it : n->attrs.dict) { - ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; - } - }); + for (const auto& it : n->attrs.dict) { + ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; + } + }); return ret; } else { return outputs[0].node->attrs.dict; } } -std::vector > - Symbol::ListAttrsRecursive() const { +std::vector > Symbol::ListAttrsRecursive() const { std::vector > ret; DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { - for (const auto& it : n->attrs.dict) { - ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); - } - }); + for (const auto& it : n->attrs.dict) { + ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); + } + }); return ret; } -Symbol Symbol::CreateFunctor(const Op* op, - std::unordered_map attrs) { +Symbol Symbol::CreateFunctor(const Op* op, std::unordered_map attrs) { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol s; ObjectPtr n = Node::Create(); @@ -641,9 +622,9 @@ Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) { return s; } -Symbol Symbol::CreateGroup(const std::vector &symbols) { +Symbol Symbol::CreateGroup(const std::vector& symbols) { Symbol ret; - for (const auto &s : symbols) { + for (const auto& s : symbols) { ret.outputs.insert(ret.outputs.end(), s.outputs.begin(), s.outputs.end()); } return ret; diff --git a/nnvm/src/pass/correct_layout.cc b/nnvm/src/pass/correct_layout.cc index bdb7dba..b9024a5 100644 --- a/nnvm/src/pass/correct_layout.cc +++ b/nnvm/src/pass/correct_layout.cc @@ -22,16 +22,15 @@ * \brief Infer and correct layout. */ #include -#include #include -#include #include +#include +#include namespace nnvm { namespace pass { -nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src, - const Layout& dst) { +nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src, const Layout& dst) { static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__"); static int count = 0; nnvm::ObjectPtr n = nnvm::Node::Create(); @@ -50,8 +49,7 @@ using LayoutAttrDict = std::unordered_map >; * insert layout transform nodes automatically. */ nnvm::Graph CorrectLayout(nnvm::Graph src) { - static auto& op_correct_layout = - nnvm::Op::GetAttr("FCorrectLayout"); + static auto& op_correct_layout = nnvm::Op::GetAttr("FCorrectLayout"); const IndexedGraph& idx = src.indexed_graph(); std::vector mirror_vec(idx.num_nodes(), nullptr); @@ -65,13 +63,12 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { *new_node = *(inode.source); if (new_node->is_variable()) { // Variable node. No operator. Only one output entry. - auto input_iter = std::find( - idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid); + auto input_iter = std::find(idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid); CHECK(input_iter != idx.input_nodes().cend()); int64_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter); if (src.HasAttr("layout_inputs")) { - new_layouts[new_node.get()] = - {src.GetAttr >("layout_inputs")[input_id]}; + new_layouts[new_node.get()] = { + src.GetAttr >("layout_inputs")[input_id]}; } else { new_layouts[new_node.get()] = {Layout::Undef()}; } @@ -110,9 +107,9 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { } if (op_correct_layout.count(new_node->op())) { - const auto &flayout = op_correct_layout[new_node->op()]; + const auto& flayout = op_correct_layout[new_node->op()]; CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts)) - << "Layout infer fail"; + << "Layout infer fail"; CHECK_EQ(request_ilayouts.size(), num_inputs); CHECK_EQ(produce_olayouts.size(), num_outputs); } @@ -175,10 +172,10 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { // register pass NNVM_REGISTER_PASS(CorrectLayout) -.describe("Return a layout-transformed graph of src.") -.set_body(CorrectLayout) -.provide_graph_attr("layout") -.set_change_graph(true); + .describe("Return a layout-transformed graph of src.") + .set_body(CorrectLayout) + .provide_graph_attr("layout") + .set_change_graph(true); DMLC_JSON_ENABLE_ANY(LayoutVector, list_layout); diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 9c30a78..1df3af7 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -22,8 +22,9 @@ * \brief Passes that takes gradient of the graph * This code code was modified based on mxnet codebase by Min Lin */ -#include #include +#include + #include #include @@ -53,8 +54,7 @@ NodeEntry DefaultAggregateGradient(std::vector&& v) { } } -bool CheckGradAllZero(const std::vector& grads, - const std::vector& zero_ops) { +bool CheckGradAllZero(const std::vector& grads, const std::vector& zero_ops) { if (!grads.size() || !zero_ops.size()) return false; for (const auto& g : grads) { bool found = false; @@ -82,22 +82,18 @@ struct GradEntry { Graph Gradient(Graph src) { using nnvm::FGradient; - using MirrorFun = std::function; - using AttrHintFun = std::function; + using MirrorFun = std::function; + using AttrHintFun = std::function; - CHECK_NE(src.attrs.count("grad_ys"), 0U) - << "Gradient require grad_ys to be presented."; + CHECK_NE(src.attrs.count("grad_ys"), 0U) << "Gradient require grad_ys to be presented."; CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U) << "Gradient require grad_ys_out_grad to be presented."; - CHECK_NE(src.attrs.count("grad_xs"), 0U) - << "Gradient require grad_xs to be presented."; - const std::vector& ys = - src.GetAttr >("grad_ys"); + CHECK_NE(src.attrs.count("grad_xs"), 0U) << "Gradient require grad_xs to be presented."; + const std::vector& ys = src.GetAttr >("grad_ys"); const std::vector& ys_out_grad = src.GetAttr >("grad_ys_out_grad"); - const std::vector& xs = - src.GetAttr >("grad_xs"); - using AggFun = std::function&& inputs)>; + const std::vector& xs = src.GetAttr >("grad_xs"); + using AggFun = std::function && inputs)>; AggFun agg_fun = DefaultAggregateGradient; if (src.attrs.count("grad_aggregate_fun") != 0) { agg_fun = src.GetAttr("grad_aggregate_fun"); @@ -114,31 +110,30 @@ Graph Gradient(Graph src) { if (src.attrs.count("zero_ops") != 0) { zero_ops = src.GetAttr >("zero_ops"); } - const Op* copy_op = (src.attrs.count("copy_op") != 0) ? - Op::Get(src.GetAttr("copy_op")) : - nullptr; + const Op* copy_op = + (src.attrs.count("copy_op") != 0) ? Op::Get(src.GetAttr("copy_op")) : nullptr; // topo sort std::vector topo_order; std::unordered_map > output_grads; DFSVisit(ys, [&](const ObjectPtr& node) { - if (output_grads.count(node.get()) == 0) { - output_grads[node.get()].resize(node->num_outputs()); - } - topo_order.push_back(node); - }); + if (output_grads.count(node.get()) == 0) { + output_grads[node.get()].resize(node->num_outputs()); + } + topo_order.push_back(node); + }); CHECK_EQ(ys.size(), ys_out_grad.size()); for (size_t i = 0; i < ys.size(); ++i) { NodeEntry ograd = ys_out_grad[i]; - output_grads[ys[i].node.get()][ys[i].index].grads = { ograd }; + output_grads[ys[i].node.get()][ys[i].index].grads = {ograd}; } // Check that all xs are reachable from ys for (size_t i = 0; i < xs.size(); ++i) { CHECK(output_grads.find(xs[i].node.get()) != output_grads.end()) - << "Cannot differentiate with respect to the " << i+1 << "-th variable " + << "Cannot differentiate with respect to the " << i + 1 << "-th variable " << "because it is unreachable from the outputs."; } @@ -211,8 +206,7 @@ Graph Gradient(Graph src) { LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable " << "because it didn't register FGradient attribute."; } - for (const auto& nodeEntry : input_grads) - CHECK(nodeEntry.node); + for (const auto& nodeEntry : input_grads) CHECK(nodeEntry.node); auto git = input_grads.begin(); CHECK((*rit)->inputs.size() <= input_grads.size()); for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { @@ -252,12 +246,12 @@ Graph Gradient(Graph src) { copy_node->attrs.name = os.str(); copy_node->inputs.emplace_back(entry.sum); if (copy_node->attrs.op->attr_parser != nullptr) { - copy_node->attrs.op->attr_parser(&(copy_node->attrs)); + copy_node->attrs.op->attr_parser(&(copy_node->attrs)); } unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter)); } } else { - ret.outputs[counter] = entry.sum; + ret.outputs[counter] = entry.sum; } ++counter; } @@ -271,12 +265,12 @@ Graph Gradient(Graph src) { // register pass NNVM_REGISTER_PASS(Gradient) -.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") -.set_body(Gradient) -.set_change_graph(true) -.depend_graph_attr("grad_ys") -.depend_graph_attr("grad_xs") -.depend_graph_attr("grad_ys_out_grad"); + .describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") + .set_body(Gradient) + .set_change_graph(true) + .depend_graph_attr("grad_ys") + .depend_graph_attr("grad_xs") + .depend_graph_attr("grad_ys_out_grad"); } // namespace } // namespace pass diff --git a/nnvm/src/pass/graph_algorithm.h b/nnvm/src/pass/graph_algorithm.h index 1d274ff..b305c08 100644 --- a/nnvm/src/pass/graph_algorithm.h +++ b/nnvm/src/pass/graph_algorithm.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,11 +22,12 @@ * \brief This header contains graph algorithms on StaticGraph. * It is used compute informations such as whether two * operations can run in parallel, and helps allocation. -*/ + */ #ifndef NNVM_PASS_GRAPH_ALGORITHM_H_ #define NNVM_PASS_GRAPH_ALGORITHM_H_ #include + #include namespace nnvm { @@ -41,10 +42,8 @@ namespace pass { * \param path the output path of nodes. * \return the total reward of best path. */ -inline uint32_t FindBestPath( - const IndexedGraph& graph, - const std::vector& node_reward, - std::vector* path) { +inline uint32_t FindBestPath(const IndexedGraph& graph, const std::vector& node_reward, + std::vector* path) { const uint32_t num_nodes = static_cast(graph.num_nodes()); CHECK_EQ(num_nodes, node_reward.size()); @@ -71,7 +70,8 @@ inline uint32_t FindBestPath( path->clear(); uint32_t reward = 0; for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) { - path->push_back(nid); reward += node_reward[nid]; + path->push_back(nid); + reward += node_reward[nid]; } CHECK_EQ(reward, best_solution); return best_solution; @@ -88,11 +88,8 @@ inline uint32_t FindBestPath( * \param color the color index of each of the node. * \return the total number of colors. */ -inline uint32_t ColorNodeGroup( - const IndexedGraph &graph, - std::vector node_importance, - uint32_t max_ncolor, - std::vector *color) { +inline uint32_t ColorNodeGroup(const IndexedGraph& graph, std::vector node_importance, + uint32_t max_ncolor, std::vector* color) { CHECK_NE(max_ncolor, 0U); CHECK_EQ(graph.num_nodes(), node_importance.size()); diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index 876dce1..fde1691 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -21,33 +21,24 @@ * \file infer_shape.cc * \brief Inference the shapes given existin information. */ -#include -#include #include +#include +#include namespace nnvm { namespace pass { namespace { -template -Graph InferAttr(Graph &&ret, - const AttrType empty_val, - const char* infer_name, - const char* input_name, - const char* attr_key_name, - const char* attr_name, - const char* unknown_name, - IsNone fis_none, - FDefault fdefault) { +template +Graph InferAttr(Graph&& ret, const AttrType empty_val, const char* infer_name, + const char* input_name, const char* attr_key_name, const char* attr_name, + const char* unknown_name, IsNone fis_none, FDefault fdefault) { using AttrVector = std::vector; const IndexedGraph& idx = ret.indexed_graph(); - static auto& finfer_shape = - Op::GetAttr >(infer_name); - static auto& is_backward = - Op::GetAttr("TIsBackward"); + static auto& finfer_shape = Op::GetAttr>(infer_name); + static auto& is_backward = Op::GetAttr("TIsBackward"); // gradient function, used to get node correspondence. - static auto& fgrad = - Op::GetAttr("FGradient"); + static auto& fgrad = Op::GetAttr("FGradient"); // reshape shape vector AttrVector rshape; if (ret.attrs.count(attr_name) != 0) { @@ -70,8 +61,7 @@ Graph InferAttr(Graph &&ret, // get the shape hints std::string shape_hints_key = std::string(attr_name) + "_hints"; if (ret.attrs.count(shape_hints_key)) { - NodeEntryMap shape_hints = - ret.GetAttr>(shape_hints_key); + NodeEntryMap shape_hints = ret.GetAttr>(shape_hints_key); for (const auto& kv : shape_hints) { NodeEntry e = kv.first; if (idx.exist(e.node.get())) { @@ -110,7 +100,7 @@ Graph InferAttr(Graph &&ret, } } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) { CHECK_GE(inode.control_deps.size(), 1U) - << "BackwardOp need to have control_deps to its forward op"; + << "BackwardOp need to have control_deps to its forward op"; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; ObjectPtr fwd_ptr = inode.source->control_deps[0]; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; @@ -141,7 +131,7 @@ Graph InferAttr(Graph &&ret, } // out grad entries CHECK(igrad_node != nullptr) - << "Cannot find matching backward op for " << inode.source->attrs.name; + << "Cannot find matching backward op for " << inode.source->attrs.name; for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { const NodeEntry& e = igrad_node->inputs[i]; if (e.node == nullptr) { @@ -174,10 +164,9 @@ Graph InferAttr(Graph &&ret, throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what()); } } else { - CHECK(!last_iter) - << "Attribute " << infer_name - << " is not registered by op " << inode.source->op()->name - << " we are not able to complete the inference because of this"; + CHECK(!last_iter) << "Attribute " << infer_name << " is not registered by op " + << inode.source->op()->name + << " we are not able to complete the inference because of this"; } } // Save to the result map. @@ -221,32 +210,30 @@ Graph InferAttr(Graph &&ret, } NNVM_REGISTER_PASS(InferShape) -.describe("Infer the shape of each node entries.") -.set_body([](Graph ret) { - return InferAttr( - std::move(ret), TShape(), - "FInferShape", "shape_inputs", "shape_attr_key", - "shape", "shape_num_unknown_nodes", - [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, - nullptr); - }) -.set_change_graph(false) -.provide_graph_attr("shape"); + .describe("Infer the shape of each node entries.") + .set_body([](Graph ret) { + return InferAttr( + std::move(ret), TShape(), "FInferShape", "shape_inputs", "shape_attr_key", "shape", + "shape_num_unknown_nodes", [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, + nullptr); + }) + .set_change_graph(false) + .provide_graph_attr("shape"); // inference function for same type -inline bool SameType(const NodeAttrs& attrs, - std::vector *iattr, - std::vector *oattr) { +inline bool SameType(const NodeAttrs& attrs, std::vector* iattr, std::vector* oattr) { int def_v = -1; for (int v : *oattr) { if (v != -1) { - def_v = v; break; + def_v = v; + break; } } if (def_v == -1) { for (int v : *iattr) { if (v != -1) { - def_v = v; break; + def_v = v; + break; } } } @@ -261,17 +248,14 @@ inline bool SameType(const NodeAttrs& attrs, } NNVM_REGISTER_PASS(InferType) -.describe("Infer the dtype of each node entries.") -.set_body([](Graph ret) { - return InferAttr( - std::move(ret), -1, - "FInferType", "dtype_inputs", "dtype_attr_key", - "dtype", "dtype_num_unknown_nodes", - [](const int t) { return t == -1; }, - SameType); - }) -.set_change_graph(false) -.provide_graph_attr("dtype"); + .describe("Infer the dtype of each node entries.") + .set_body([](Graph ret) { + return InferAttr( + std::move(ret), -1, "FInferType", "dtype_inputs", "dtype_attr_key", "dtype", + "dtype_num_unknown_nodes", [](const int t) { return t == -1; }, SameType); + }) + .set_change_graph(false) + .provide_graph_attr("dtype"); DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape); DMLC_JSON_ENABLE_ANY(DTypeVector, list_int); diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index b2fa2ca..2575a03 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -23,17 +23,15 @@ * To correctly order mutation and read to resolve * write after read problem and read after write problems. */ -#include #include +#include namespace nnvm { namespace pass { namespace { -template -inline T get_with_default(const std::unordered_map &map, - Node* key, - const T& def) { +template +inline T get_with_default(const std::unordered_map& map, Node* key, const T& def) { auto it = map.find(key); if (it != map.end()) return it->second; return def; @@ -46,19 +44,19 @@ inline bool IsMutate(const std::vector& mutate_inputs, uint32_t i) { Graph OrderMutation(const Graph& src) { std::unordered_map > version_hist; DFSVisit(src.outputs, [&version_hist](const ObjectPtr& n) { - for (const NodeEntry& e : n->inputs) { - if (e.node->is_variable()) { - if (e.version != 0 && version_hist.count(e.node.get()) == 0) { - version_hist[e.node.get()] = std::vector{}; - } + for (const NodeEntry& e : n->inputs) { + if (e.node->is_variable()) { + if (e.version != 0 && version_hist.count(e.node.get()) == 0) { + version_hist[e.node.get()] = std::vector{}; } } - }); + } + }); // no mutation happens, everything if fine. if (version_hist.size() == 0) return src; // start preparing for remapping the nodes. std::unordered_map old_new; - auto prepare = [&version_hist, &old_new] (const ObjectPtr& n) { + auto prepare = [&version_hist, &old_new](const ObjectPtr& n) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); std::vector mutate_inputs; if (!n->is_variable() && fmutate_inputs.count(n->op())) { @@ -91,17 +89,17 @@ Graph OrderMutation(const Graph& src) { }; DFSVisit(src.outputs, prepare); // comparator of history entry - auto comparator = [](const NodeEntry& a, const NodeEntry &b) { + auto comparator = [](const NodeEntry& a, const NodeEntry& b) { if (a.version < b.version) return true; if (a.version > b.version) return false; return a.index > b.index; }; - for (auto &kv : version_hist) { + for (auto& kv : version_hist) { std::sort(kv.second.begin(), kv.second.end(), comparator); } // copy the nodes, as well as add control deps - for (auto &kv : old_new) { + for (auto& kv : old_new) { // copy the nodes for (const NodeEntry& e : kv.first->inputs) { auto it = old_new.find(e.node.get()); @@ -112,8 +110,7 @@ Graph OrderMutation(const Graph& src) { } } for (const ObjectPtr& p : kv.first->control_deps) { - kv.second->control_deps.emplace_back( - get_with_default(old_new, p.get(), p)); + kv.second->control_deps.emplace_back(get_with_default(old_new, p.get(), p)); } // add control deps static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); @@ -127,9 +124,8 @@ Graph OrderMutation(const Graph& src) { const NodeEntry& e = kv.first->inputs[i]; if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) { std::vector& vec = version_hist.at(e.node.get()); - auto it = std::lower_bound(vec.begin(), vec.end(), - NodeEntry{nullptr, 1, e.version}, - comparator); + auto it = + std::lower_bound(vec.begin(), vec.end(), NodeEntry{nullptr, 1, e.version}, comparator); if (IsMutate(mutate_inputs, i)) { int read_dep = 0; while (it != vec.begin()) { @@ -137,37 +133,35 @@ Graph OrderMutation(const Graph& src) { if (it->index != 0) break; ++read_dep; // depend on previous read - kv.second->control_deps.push_back( - get_with_default(old_new, it->node.get(), it->node)); + kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node)); } if (read_dep == 0 && it->index != 0) { // depend on last write - kv.second->control_deps.push_back( - get_with_default(old_new, it->node.get(), it->node)); + kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node)); } } else { // depend on last write if (it->index != 0) { - kv.second->control_deps.push_back( - get_with_default(old_new, it->node.get(), it->node)); + kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node)); } } } } } Graph ret; - for (const NodeEntry &e : src.outputs) { - ret.outputs.emplace_back(NodeEntry{ - get_with_default(old_new, e.node.get(), e.node), e.index, e.version}); + for (const NodeEntry& e : src.outputs) { + ret.outputs.emplace_back( + NodeEntry{get_with_default(old_new, e.node.get(), e.node), e.index, e.version}); } return ret; } NNVM_REGISTER_PASS(OrderMutation) -.describe("Return a new graph that adds control dependencies, "\ - "to order the mutation and reads if mutation exists.") -.set_body(OrderMutation) -.set_change_graph(true); + .describe( + "Return a new graph that adds control dependencies, " + "to order the mutation and reads if mutation exists.") + .set_body(OrderMutation) + .set_change_graph(true); } // namespace } // namespace pass diff --git a/nnvm/src/pass/place_device.cc b/nnvm/src/pass/place_device.cc index 6d6866e..d45658a 100644 --- a/nnvm/src/pass/place_device.cc +++ b/nnvm/src/pass/place_device.cc @@ -22,9 +22,9 @@ * \brief Inference the device of each operator given known information. * Insert a copy node automatically when there is a cross device. */ -#include -#include #include +#include +#include namespace nnvm { namespace pass { @@ -43,8 +43,7 @@ Graph PlaceDevice(Graph src) { const Op* copy_op = Op::Get(src.GetAttr("device_copy_op")); auto& device_assign_map = src.GetAttr("device_assign_map"); const IndexedGraph& idx = src.indexed_graph(); - static auto& is_backward = - Op::GetAttr("TIsBackward"); + static auto& is_backward = Op::GetAttr("TIsBackward"); DeviceVector device; // copy on write semanatics if (src.attrs.count("device") != 0) { @@ -65,15 +64,15 @@ Graph PlaceDevice(Graph src) { << "The device assignment not found for group " << device_group; device[nid] = dit->second; } else { - if (!inode.source->is_variable() && - is_backward.get(inode.source->op(), false)) { + if (!inode.source->is_variable() && is_backward.get(inode.source->op(), false)) { if (device[inode.control_deps[0]] != -1) { device[nid] = device[inode.control_deps[0]]; } } else { for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (device[e.node_id] != -1) { - device[nid] = device[e.node_id]; break; + device[nid] = device[e.node_id]; + break; } } } @@ -121,20 +120,21 @@ Graph PlaceDevice(Graph src) { auto e = inode.inputs[index]; if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { LOG(FATAL) << " mutable state cannot go across device" - << " op=" << inode.source->op()->name - << " input_state_index=" << index; + << " op=" << inode.source->op()->name << " input_state_index=" << index; } } } for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { - need_mutate = true; break; + need_mutate = true; + break; } } if (!need_mutate) { for (const uint32_t cid : inode.control_deps) { - if (new_node_map[cid] != nullptr) { - need_mutate = true; break; + if (new_node_map[cid] != nullptr) { + need_mutate = true; + break; } } } @@ -151,17 +151,15 @@ Graph PlaceDevice(Graph src) { auto copy_key = std::make_tuple(e.node_id, e.index, dev_id); auto it = copy_map.find(copy_key); if (it != copy_map.end() && it->first == copy_key) { - new_node->inputs.emplace_back( - NodeEntry{it->second, 0, 0}); + new_node->inputs.emplace_back(NodeEntry{it->second, 0, 0}); } else { ObjectPtr copy_node = Node::Create(); std::ostringstream os; - os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; + os << inode.source->inputs[i].node->attrs.name << "_" << e.index << "_copy"; copy_node->attrs.op = copy_op; copy_node->attrs.name = os.str(); if (new_node_map[e.node_id] != nullptr) { - copy_node->inputs.emplace_back( - NodeEntry{new_node_map[e.node_id], e.index, 0}); + copy_node->inputs.emplace_back(NodeEntry{new_node_map[e.node_id], e.index, 0}); } else { copy_node->inputs.push_back(inode.source->inputs[i]); } @@ -170,13 +168,11 @@ Graph PlaceDevice(Graph src) { } copy_map[copy_key] = copy_node; new_device_map[copy_node.get()] = dev_id; - new_node->inputs.emplace_back( - NodeEntry{std::move(copy_node), 0, 0}); + new_node->inputs.emplace_back(NodeEntry{std::move(copy_node), 0, 0}); } } else { if (new_node_map[e.node_id] != nullptr) { - new_node->inputs.emplace_back( - NodeEntry{new_node_map[e.node_id], e.index, 0}); + new_node->inputs.emplace_back(NodeEntry{new_node_map[e.node_id], e.index, 0}); } else { new_node->inputs.push_back(inode.source->inputs[i]); } @@ -220,14 +216,15 @@ Graph PlaceDevice(Graph src) { } NNVM_REGISTER_PASS(PlaceDevice) -.describe("Infer the device type of each operator."\ - "Insert a copy node when there is cross device copy") -.set_body(PlaceDevice) -.set_change_graph(true) -.provide_graph_attr("device") -.depend_graph_attr("device_group_attr_key") -.depend_graph_attr("device_assign_map") -.depend_graph_attr("device_copy_op"); + .describe( + "Infer the device type of each operator." + "Insert a copy node when there is cross device copy") + .set_body(PlaceDevice) + .set_change_graph(true) + .provide_graph_attr("device") + .depend_graph_attr("device_group_attr_key") + .depend_graph_attr("device_assign_map") + .depend_graph_attr("device_copy_op"); DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int); diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index 83d8f87..2c36cd2 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -22,10 +22,12 @@ * \brief Assign memory tag to each of the data entries. */ #include -#include #include #include +#include + #include + #include "graph_algorithm.h" namespace nnvm { @@ -82,10 +84,10 @@ class GraphAllocator { auto end = free_.upper_bound(size * match_range_); // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->device_id != dev_id) continue; - if (node_color_.size() != 0 && - node_color_[e->released_by_node] != node_color_[node_id]) continue; + if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id]) + continue; // Use exect matching strategy e->max_bytes = std::max(size, e->max_bytes); // find a exact match, erase from map and return @@ -95,10 +97,10 @@ class GraphAllocator { // then search for memory blocks smaller than requested space for (auto it = mid; it != begin;) { --it; - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->device_id != dev_id) continue; - if (node_color_.size() != 0 && - node_color_[e->released_by_node] != node_color_[node_id]) continue; + if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id]) + continue; // Use exect matching strategy e->max_bytes = std::max(size, e->max_bytes); // erase from map and return @@ -112,7 +114,7 @@ class GraphAllocator { void Release(StorageID id, uint32_t node_id) { CHECK_NE(id, kBadStorageID); if (id == kExternalStorageID || id == kDynamicStorageID) return; - StorageEntry *e = data_[id].get(); + StorageEntry* e = data_[id].get(); e->released_by_node = node_id; free_.insert({e->max_bytes, e}); } @@ -120,7 +122,7 @@ class GraphAllocator { // totoal number of bytes allocated size_t TotalAllocBytes() const { size_t total = 0; - for (auto &p : data_) { + for (auto& p : data_) { total += p->max_bytes; } return total; @@ -142,8 +144,7 @@ class GraphAllocator { if ((*idx_)[nid].source->is_variable()) continue; importance[nid] = 1; } - num_match_color_ = pass::ColorNodeGroup( - *idx_, importance, num_match_color_, &node_color_); + num_match_color_ = pass::ColorNodeGroup(*idx_, importance, num_match_color_, &node_color_); } } @@ -187,18 +188,16 @@ class GraphAllocator { * Internal method to perform the memory allocation for a graph * */ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, - const std::pair& node_range, - StorageVector* storage_ptr, + const std::pair& node_range, StorageVector* storage_ptr, std::vector* storage_inplace_index_ptr, - const std::vector& entry_ref_count, - GraphAllocator* allocator) { + const std::vector& entry_ref_count, GraphAllocator* allocator) { static auto& finplace_option = Op::GetAttr("FInplaceOption"); static auto& finplace_identity = Op::GetAttr("FInplaceIdentity"); static auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); // Get reference - auto &storage = *storage_ptr; - auto &storage_inplace_index = *storage_inplace_index_ptr; + auto& storage = *storage_ptr; + auto& storage_inplace_index = *storage_inplace_index_ptr; // Get attributes from the graph const ShapeVector& shape_vec = ret.GetAttr("shape"); @@ -234,19 +233,16 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, auto sid_out = storage[eid_out]; auto sid_in = storage[eid_in]; bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 && - fignore_inputs[inode.source->op()]( - inode.source->attrs).size() == inode.source->num_inputs()); + fignore_inputs[inode.source->op()](inode.source->attrs).size() == + inode.source->num_inputs()); // Identity should only be true if shape.Size() and types match bool real_identity = identity[ipair] && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && dtype_vec[eid_out] == dtype_vec[eid_in]; - if (taken[kv.first] == false && - sid_out == GraphAllocator::kBadStorageID && - sid_in >= 0 && + if (taken[kv.first] == false && sid_out == GraphAllocator::kBadStorageID && sid_in >= 0 && ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || real_identity) && - entry_ref_count[eid_out] > 0 && - shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && - (dtype_vec[eid_out] == dtype_vec[eid_in] || + entry_ref_count[eid_out] > 0 && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && + (dtype_vec[eid_out] == dtype_vec[eid_in] || GetDTypeSize(dtype_vec[eid_out]) == GetDTypeSize(dtype_vec[eid_in]))) { // inplace optimization taken[kv.first] = true; @@ -267,19 +263,19 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, uint32_t eid = idx.entry_id(nid, index); // only request memory for kBadStorageID if (storage[eid] == GraphAllocator::kBadStorageID) { - auto &eshape = shape_vec[eid]; + auto& eshape = shape_vec[eid]; size_t esize = 0; if (eshape.ndim() != 0) esize = eshape.Size(); eids.insert(std::make_pair(esize, eid)); } } for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) { - uint32_t eid = rit->second; - auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); - if (sid >= 0) { - storage_ref_count[sid] = entry_ref_count[eid]; - } - storage[eid] = sid; + uint32_t eid = rit->second; + auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); + if (sid >= 0) { + storage_ref_count[sid] = entry_ref_count[eid]; + } + storage[eid] = sid; } // check if certain inputs is ignored. std::vector ignore_inputs; @@ -320,7 +316,6 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, return num_not_allocated; } - // function to plan memory Graph PlanMemory(Graph ret) { // setup ref counter @@ -368,7 +363,7 @@ Graph PlanMemory(Graph ret) { size_t min_allocated_bytes = -1; size_t max_match_range = dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16); size_t min_match_range = - dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range; + dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range; for (size_t match_range = min_match_range; match_range <= max_match_range; match_range *= 2) { // Make a copy of related fields StorageVector storage_vec(storage); @@ -378,9 +373,8 @@ Graph PlanMemory(Graph ret) { GraphAllocator allocator(&idx, match_range); // number of entries that are not statically allocated. - size_t storage_num_not_allocated = - AllocMemory(ret, idx, node_range, &storage_vec, &storage_inplace_index, - ref_count, &allocator); + size_t storage_num_not_allocated = AllocMemory(ret, idx, node_range, &storage_vec, + &storage_inplace_index, ref_count, &allocator); size_t storage_allocated_bytes = allocator.TotalAllocBytes(); // Choose the plan which leads to minimal memory usage @@ -400,13 +394,13 @@ Graph PlanMemory(Graph ret) { } NNVM_REGISTER_PASS(PlanMemory) -.describe("Plan the memory allocation of each node entries.") -.set_body(PlanMemory) -.set_change_graph(false) -.depend_graph_attr("dtype") -.depend_graph_attr("shape") -.provide_graph_attr("storage_id") -.provide_graph_attr("storage_inplace_index"); + .describe("Plan the memory allocation of each node entries.") + .set_body(PlanMemory) + .set_change_graph(false) + .depend_graph_attr("dtype") + .depend_graph_attr("shape") + .provide_graph_attr("storage_id") + .provide_graph_attr("storage_inplace_index"); } // namespace } // namespace pass diff --git a/nnvm/src/pass/print_graph_ir.cc b/nnvm/src/pass/print_graph_ir.cc index a0127ab..4fe92e6 100644 --- a/nnvm/src/pass/print_graph_ir.cc +++ b/nnvm/src/pass/print_graph_ir.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,6 +24,7 @@ #include #include #include + #include namespace nnvm { @@ -31,47 +32,39 @@ namespace pass { using AttrPrinter = std::function; // NOLINT(*) -template +template AttrPrinter GetVectorPrinter_(const T& vec) { return [&vec](uint32_t index, std::ostream& os) { // NOLINT(*) os << vec[index]; }; } -AttrPrinter GetVectorPrinter(const Graph& graph, - const std::string& key) { +AttrPrinter GetVectorPrinter(const Graph& graph, const std::string& key) { auto it = graph.attrs.find(key); - CHECK(it != graph.attrs.end()) - << "Cannot find " << key << " in graph attr"; + CHECK(it != graph.attrs.end()) << "Cannot find " << key << " in graph attr"; const any& value = *(it->second); if (value.type() == typeid(std::vector)) { - return GetVectorPrinter_( - nnvm::get >(value)); + return GetVectorPrinter_(nnvm::get >(value)); } else if (value.type() == typeid(std::vector)) { - return GetVectorPrinter_( - nnvm::get >(value)); + return GetVectorPrinter_(nnvm::get >(value)); } else if (value.type() == typeid(std::vector)) { - return GetVectorPrinter_( - nnvm::get >(value)); + return GetVectorPrinter_(nnvm::get >(value)); } else { LOG(FATAL) << "Cannot handle type " << value.type().name(); return nullptr; } } - // print the graph ir in readable format -void PrintGraphIR_(Graph src, - const std::vector& join_entry_attrs, +void PrintGraphIR_(Graph src, const std::vector& join_entry_attrs, const std::vector& join_node_attrs, - std::ostream& os) { // NOLINT(*) + std::ostream& os) { // NOLINT(*) const IndexedGraph& idx = src.indexed_graph(); std::vector > trigger; // NOLINT(*) for (const std::string& key : join_entry_attrs) { AttrPrinter fp = GetVectorPrinter(src, key); - auto fprint = [&idx, key, fp]( - uint32_t nid, std::ostream& os) { // NOLINT(*) + auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*) const IndexedGraph::Node& inode = idx[nid]; os << ", " << key << "="; if (inode.source->num_outputs() != 1) { @@ -89,8 +82,7 @@ void PrintGraphIR_(Graph src, } for (const std::string& key : join_node_attrs) { AttrPrinter fp = GetVectorPrinter(src, key); - auto fprint = [&idx, key, fp]( - uint32_t nid, std::ostream& os) { // NOLINT(*) + auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*) os << ", " << key << "="; fp(idx.entry_id(nid, 0), os); }; @@ -101,7 +93,7 @@ void PrintGraphIR_(Graph src, if (idx.input_nodes().size() < 4) { for (size_t i = 0; i < idx.input_nodes().size(); ++i) { uint32_t nid = idx.input_nodes()[i]; - if (i != 0) { + if (i != 0) { os << ", "; } os << '%' << idx[nid].source->attrs.name; @@ -109,7 +101,7 @@ void PrintGraphIR_(Graph src, } else { for (size_t i = 0; i < idx.input_nodes().size(); ++i) { uint32_t nid = idx.input_nodes()[i]; - if (i != 0) { + if (i != 0) { os << ",\n "; } os << '%' << idx[nid].source->attrs.name; @@ -141,8 +133,8 @@ void PrintGraphIR_(Graph src, for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; - os << " " << "%" << nid << " = " - << inode.source->op()->name << "("; + os << " " + << "%" << nid << " = " << inode.source->op()->name << "("; bool first = true; for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (first) { @@ -213,12 +205,10 @@ Graph PrintGraphIRPass(Graph src) { std::ostringstream os; std::vector join_entry_attrs, join_node_attrs; if (src.attrs.count("join_entry_attrs") != 0) { - join_entry_attrs = src.MoveCopyAttr >( - "join_entry_attrs"); + join_entry_attrs = src.MoveCopyAttr >("join_entry_attrs"); } if (src.attrs.count("join_node_attrs") != 0) { - join_node_attrs = src.MoveCopyAttr >( - "join_node_attrs"); + join_node_attrs = src.MoveCopyAttr >("join_node_attrs"); } PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os); Graph ret; @@ -228,8 +218,8 @@ Graph PrintGraphIRPass(Graph src) { // register pass NNVM_REGISTER_PASS(PrintGraphIR) -.describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]") -.set_body(PrintGraphIRPass); + .describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]") + .set_body(PrintGraphIRPass); } // namespace pass } // namespace nnvm diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 9389995..3916da4 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -21,20 +21,21 @@ * \file saveload_json.cc * \brief Save and load graph to/from JSON file. */ +#include #include #include -#include + #include namespace dmlc { namespace json { // overload handler for shared ptr -template<> -struct Handler > { - inline static void Write(JSONWriter *writer, const std::shared_ptr &data) { +template <> +struct Handler> { + inline static void Write(JSONWriter* writer, const std::shared_ptr& data) { writer->Write(*data); } - inline static void Read(JSONReader *reader, std::shared_ptr *data) { + inline static void Read(JSONReader* reader, std::shared_ptr* data) { any v; reader->Read(&v); *data = std::make_shared(std::move(v)); @@ -60,17 +61,16 @@ struct JSONNode { uint32_t index; uint32_t version; Entry() = default; - Entry(uint32_t node_id, uint32_t index, uint32_t version): - node_id(node_id), index(index), version(version) { - } - void Save(dmlc::JSONWriter *writer) const { + Entry(uint32_t node_id, uint32_t index, uint32_t version) + : node_id(node_id), index(index), version(version) {} + void Save(dmlc::JSONWriter* writer) const { writer->BeginArray(false); writer->WriteArrayItem(node_id); writer->WriteArrayItem(index); writer->WriteArrayItem(version); writer->EndArray(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginArray(); CHECK(reader->NextArrayItem()) << "invalid json format"; reader->Read(&node_id); @@ -95,7 +95,7 @@ struct JSONNode { std::vector subgraphs; // function to save JSON node. - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); if (node->op() != nullptr) { writer->WriteObjectKeyValue("op", node->op()->name); @@ -106,8 +106,7 @@ struct JSONNode { writer->WriteObjectKeyValue("name", node->attrs.name); if (node->attrs.dict.size() != 0) { // write attributes in order; - std::map dict( - node->attrs.dict.begin(), node->attrs.dict.end()); + std::map dict(node->attrs.dict.begin(), node->attrs.dict.end()); writer->WriteObjectKeyValue("attrs", dict); } writer->WriteObjectKeyValue("inputs", inputs); @@ -120,7 +119,7 @@ struct JSONNode { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { node = Node::Create(); control_deps.clear(); dmlc::JSONObjectReadHelper helper; @@ -143,10 +142,10 @@ struct JSONNode { if (op_type_str != "null") { try { node->attrs.op = Op::Get(op_type_str); - } catch (const dmlc::Error &err) { + } catch (const dmlc::Error& err) { std::ostringstream os; - os << "Failed loading Op " << node->attrs.name - << " of type " << op_type_str << ": " << err.what(); + os << "Failed loading Op " << node->attrs.name << " of type " << op_type_str << ": " + << err.what(); throw dmlc::Error(os.str()); } } else { @@ -161,9 +160,9 @@ struct JSONGraph { std::vector arg_nodes; std::vector node_row_ptr; std::vector heads; - std::unordered_map > attrs; + std::unordered_map> attrs; - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("nodes", nodes); writer->WriteObjectKeyValue("arg_nodes", arg_nodes); @@ -175,7 +174,7 @@ struct JSONGraph { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { attrs.clear(); dmlc::JSONObjectReadHelper helper; helper.DeclareField("nodes", &nodes); @@ -187,7 +186,7 @@ struct JSONGraph { } }; -void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { +void Symbol2JSONGraph(std::shared_ptr src, JSONGraph* jgraph) { std::unordered_map node2index; jgraph->node_row_ptr.push_back(0); DFSVisit(src->outputs, [&node2index, jgraph](const ObjectPtr& n) { @@ -212,10 +211,10 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version); } // recursively construct subgraphs - for (JSONNode &jnode : jgraph->nodes) { + for (JSONNode& jnode : jgraph->nodes) { // construct jnode's subgraphs - const std::vector> &subgraphs = jnode.node->attrs.subgraphs; - std::vector &jsubgraphs = jnode.subgraphs; + const std::vector>& subgraphs = jnode.node->attrs.subgraphs; + std::vector& jsubgraphs = jnode.subgraphs; jsubgraphs.resize(subgraphs.size()); for (uint32_t i = 0; i < subgraphs.size(); ++i) { Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]); @@ -223,10 +222,10 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { } } -std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) { - for (const JSONNode &n : jgraph.nodes) { +std::shared_ptr JSONGraph2Symbol(const JSONGraph& jgraph, bool no_parse) { + for (const JSONNode& n : jgraph.nodes) { n.node->inputs.reserve(n.inputs.size()); - for (const JSONNode::Entry &e : n.inputs) { + for (const JSONNode::Entry& e : n.inputs) { CHECK(e.node_id < jgraph.nodes.size()); n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } @@ -235,7 +234,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) CHECK(nid < jgraph.nodes.size()); n.node->control_deps.push_back(jgraph.nodes[nid].node); } - for (const JSONGraph &subgraph : n.subgraphs) { + for (const JSONGraph& subgraph : n.subgraphs) { // The "no_parse" option here, is to be compatible with // commit cfd3075e85807dcd8f9534c37e053583dee87524 // (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524), @@ -248,7 +247,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) n.node->op()->attr_parser(&(n.node->attrs)); } else if (!no_parse && n.node->is_variable()) { n.node->attrs.parsed = - Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed; + Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed; } } // consistency check @@ -258,7 +257,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) } std::shared_ptr symbol = std::make_shared(); symbol->outputs.reserve(jgraph.heads.size()); - for (const JSONNode::Entry &e : jgraph.heads) { + for (const JSONNode::Entry& e : jgraph.heads) { CHECK(e.node_id < jgraph.nodes.size()); symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } @@ -267,10 +266,8 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) // Load a graph from JSON file. Graph LoadJSON(Graph src) { - CHECK_NE(src.attrs.count("json"), 0U) - << "Load JSON require json to be presented."; - const std::string &json_str = - nnvm::get(*src.attrs.at("json")); + CHECK_NE(src.attrs.count("json"), 0U) << "Load JSON require json to be presented."; + const std::string& json_str = nnvm::get(*src.attrs.at("json")); bool no_parse = false; if (src.attrs.count("load_json_no_parse")) { no_parse = nnvm::get(*src.attrs.at("load_json_no_parse")); @@ -305,17 +302,16 @@ Graph SaveJSON(Graph src) { // register pass NNVM_REGISTER_PASS(LoadJSON) -.describe("Return a new Graph, loaded from src.attrs[\"json\"]") -.set_body(LoadJSON) -.set_change_graph(true) -.depend_graph_attr("json"); + .describe("Return a new Graph, loaded from src.attrs[\"json\"]") + .set_body(LoadJSON) + .set_change_graph(true) + .depend_graph_attr("json"); NNVM_REGISTER_PASS(SaveJSON) -.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]") -.set_body(SaveJSON) -.set_change_graph(true) -.provide_graph_attr("json"); - + .describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]") + .set_body(SaveJSON) + .set_change_graph(true) + .provide_graph_attr("json"); DMLC_JSON_ENABLE_ANY(std::string, str); DMLC_JSON_ENABLE_ANY(std::vector, list_int); diff --git a/nnvm/tests/cpp/op_test.cc b/nnvm/tests/cpp/op_test.cc index 4c77165..2ebd146 100644 --- a/nnvm/tests/cpp/op_test.cc +++ b/nnvm/tests/cpp/op_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,16 +20,15 @@ #include #include #include -#include -NNVM_REGISTER_OP(add) -.describe("add two data together") -.set_num_inputs(2) -.set_attr("inplace_pair", std::make_pair(0, 0)); +#include NNVM_REGISTER_OP(add) -.set_attr("nick_name", "plus"); + .describe("add two data together") + .set_num_inputs(2) + .set_attr("inplace_pair", std::make_pair(0, 0)); +NNVM_REGISTER_OP(add).set_attr("nick_name", "plus"); TEST(Op, GetAttr) { using namespace nnvm; @@ -39,7 +38,7 @@ TEST(Op, GetAttr) { CHECK_EQ(nick[add], "plus"); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/nnvm/tests/cpp/tuple_test.cc b/nnvm/tests/cpp/tuple_test.cc index 7bf59b5..2c2c307 100644 --- a/nnvm/tests/cpp/tuple_test.cc +++ b/nnvm/tests/cpp/tuple_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,8 +22,8 @@ #include TEST(Tuple, Basic) { - using nnvm::Tuple; using nnvm::TShape; + using nnvm::Tuple; Tuple x{1, 2, 3}; Tuple y{1, 2, 3, 5, 6}; x = std::move(y); @@ -42,7 +42,7 @@ TEST(Tuple, Basic) { CHECK((s == TShape{1, 2, 3})); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 9199bac..037c766 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -20,9 +20,9 @@ /*! * \file tvm/arith/analyzer.cc */ +#include #include #include -#include #include namespace tvm { @@ -33,8 +33,7 @@ Analyzer::Analyzer() modular_set(this), rewrite_simplify(this), canonical_simplify(this), - int_set(this) { -} + int_set(this) {} void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) { PrimExpr new_expr = expr; @@ -124,63 +123,53 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr) { return res; } -TVM_REGISTER_GLOBAL("arith.CreateAnalyzer") -.set_body([](TVMArgs args, TVMRetValue* ret) { - using runtime::PackedFunc; - using runtime::TypedPackedFunc; - auto self = std::make_shared(); - auto f = [self](std::string name) -> PackedFunc { - if (name == "const_int_bound") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->const_int_bound(args[0]); - }); - } else if (name == "modular_set") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->modular_set(args[0]); - }); - } else if (name == "const_int_bound_update") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - self->const_int_bound.Update(args[0], args[1], args[2]); - }); - } else if (name == "Simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->Simplify(args[0]); - }); - } else if (name == "rewrite_simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->rewrite_simplify(args[0]); - }); - } else if (name == "canonical_simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->canonical_simplify(args[0]); - }); - } else if (name == "int_set") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->int_set(args[0], args[1]); - }); - } else if (name == "bind") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - if (args[1].IsObjectRef()) { - self->Bind(args[0], args[1].operator Range()); - } else { - self->Bind(args[0], args[1].operator PrimExpr()); - } - }); - } else if (name == "enter_constraint_context") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - // can't use make_shared due to noexcept(false) decl in destructor, - // see https://stackoverflow.com/a/43907314 - auto ctx = std::shared_ptr >( - new With(self.get(), args[0])); - auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { - ctx.reset(); - }; - *ret = PackedFunc(fexit); - }); - } - return PackedFunc(); - }; - *ret = TypedPackedFunc(f); +TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) { + using runtime::PackedFunc; + using runtime::TypedPackedFunc; + auto self = std::make_shared(); + auto f = [self](std::string name) -> PackedFunc { + if (name == "const_int_bound") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->const_int_bound(args[0]); }); + } else if (name == "modular_set") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->modular_set(args[0]); }); + } else if (name == "const_int_bound_update") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + self->const_int_bound.Update(args[0], args[1], args[2]); + }); + } else if (name == "Simplify") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { *ret = self->Simplify(args[0]); }); + } else if (name == "rewrite_simplify") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); }); + } else if (name == "canonical_simplify") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); }); + } else if (name == "int_set") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->int_set(args[0], args[1]); }); + } else if (name == "bind") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + if (args[1].IsObjectRef()) { + self->Bind(args[0], args[1].operator Range()); + } else { + self->Bind(args[0], args[1].operator PrimExpr()); + } + }); + } else if (name == "enter_constraint_context") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + // can't use make_shared due to noexcept(false) decl in destructor, + // see https://stackoverflow.com/a/43907314 + auto ctx = std::shared_ptr >( + new With(self.get(), args[0])); + auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); }; + *ret = PackedFunc(fexit); + }); + } + return PackedFunc(); + }; + *ret = TypedPackedFunc(f); }); } // namespace arith diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index eeaaa8a..496eb20 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -21,13 +21,14 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ +#include #include #include #include -#include -#include #include +#include + #include "interval_set.h" namespace tvm { @@ -37,7 +38,7 @@ using namespace tir; // a visitor to find the path to the target variable // from a expression. -class VariablePathFinder: public ExprVisitor { +class VariablePathFinder : public ExprVisitor { public: explicit VariablePathFinder(PrimExpr target) : target_(target) {} @@ -67,17 +68,17 @@ std::vector GetPath(PrimExpr target, PrimExpr expr) { return v.path_; } -enum CompareOp {kGreater, kLess, kEqual}; +enum CompareOp { kGreater, kLess, kEqual }; // a visitor to deduce the bound of a variable from a expression -class BoundDeducer: public ExprVisitor { +class BoundDeducer : public ExprVisitor { public: friend class BoundDeduceInputChecker; friend class Converter; BoundDeducer(PrimExpr target, PrimExpr expr, const std::unordered_map& hint_map, const std::unordered_map& relax_map) - : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} + : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} void Deduce(); @@ -119,7 +120,7 @@ class BoundDeducer: public ExprVisitor { result_ += op->b; } else { result_ -= op->a; - result_ = - result_; + result_ = -result_; comp_op = ReverseOp(comp_op); } this->VisitExpr(left ? op->a : op->b); @@ -148,7 +149,7 @@ class BoundDeducer: public ExprVisitor { // always use relax bound bool divided = analyzer_.CanProve(floormod(result_, operand) == 0); - result_ = floordiv(result_, operand); // rounding down here + result_ = floordiv(result_, operand); // rounding down here if (!divided) { if (comp_op == kGreater) { @@ -193,7 +194,7 @@ class BoundDeducer: public ExprVisitor { Analyzer analyzer_; }; -class BoundDeduceInputChecker: public ExprVisitor { +class BoundDeduceInputChecker : public ExprVisitor { public: bool Check(BoundDeducer* deducer) { deducer_ = deducer; @@ -219,9 +220,12 @@ void BoundDeducer::Init() { CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) { switch (comp_op) { - case kEqual: return kEqual; // IntSet can not represent range for `NE - case kGreater: return kLess; - case kLess: return kGreater; + case kEqual: + return kEqual; // IntSet can not represent range for `NE + case kGreater: + return kLess; + case kLess: + return kGreater; default: LOG(FATAL) << "Not a valid compare op"; return kGreater; // return some default value @@ -318,18 +322,18 @@ void BoundDeducer::Relax() { // Both LHS and RHS of the EQ should behave as constants e.g. i == j, // can not be resolved when either `i` or `j` or both are variables with // some Range OR `i` and `j` both should be a single point in IntSet - if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max()) - || !analyzer_.CanProve(a.min() == a.max()))) { + if (comp_op == kEqual && + (!analyzer_.CanProve(b.min() == b.max()) || !analyzer_.CanProve(a.min() == a.max()))) { success_ = false; return; } - expr_ = (comp_op == kGreater) ? a.min() : a.max(); + expr_ = (comp_op == kGreater) ? a.min() : a.max(); result_ = (comp_op == kGreater) ? b.max() : b.min(); } IntSet DeduceBound(PrimExpr v, PrimExpr e, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) { + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success_) return IntSet::nothing(); @@ -347,8 +351,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. -IntSet DeduceBound(PrimExpr v, PrimExpr e, - const Map& hint_map, +IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, const Map& relax_map) { std::unordered_map hmap; for (auto kv : hint_map) { @@ -361,16 +364,11 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, return DeduceBound(v, e, hmap, rmap); } - TVM_REGISTER_GLOBAL("arith.DeduceBound") -.set_body_typed([]( - PrimExpr v, PrimExpr cond, - const Map hint_map, - const Map relax_map -) { - return DeduceBound(v, cond, hint_map, relax_map); -}); - + .set_body_typed([](PrimExpr v, PrimExpr cond, const Map hint_map, + const Map relax_map) { + return DeduceBound(v, cond, hint_map, relax_map); + }); } // namespace arith } // namespace tvm diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index a10db7a..2738707 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -22,8 +22,8 @@ * \brief Canonical form based simplification. */ #include -#include #include +#include #include "const_fold.h" #include "pattern_match.h" @@ -37,7 +37,6 @@ using namespace tir; class SumExpr; class SplitExpr; - /*! * \brief Base class of all temporary expression introduced * for canonicalization. @@ -53,8 +52,7 @@ class CanonicalExprNode : public PrimExprNode { virtual PrimExpr Normalize() const = 0; // overrides - void VisitAttrs(tvm::AttrVisitor* v) { - } + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "arith.CanonicalExpr"; static constexpr const uint32_t _type_child_slots = 2; @@ -111,9 +109,7 @@ class SplitExprNode : public CanonicalExprNode { DivMode div_mode{kTruncDiv}; /*! \brief verify that this is a valid entry. */ - void Verify() const { - CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); - } + void Verify() const { CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); } PrimExpr NormalizeWithScale(int64_t sscale) const { PrimExpr res = this->index; @@ -135,13 +131,9 @@ class SplitExprNode : public CanonicalExprNode { return res; } - PrimExpr Normalize() const final { - return NormalizeWithScale(1); - } + PrimExpr Normalize() const final { return NormalizeWithScale(1); } - void MulToSelf(int64_t scale) { - this->scale *= scale; - } + void MulToSelf(int64_t scale) { this->scale *= scale; } inline bool IndexEqual(const SplitExpr& other) const; inline bool DivModeCompatibleTo(DivMode mode) const; @@ -186,9 +178,7 @@ class SumExprNode : public CanonicalExprNode { /*! \brief Base value in the summation. */ int64_t base{0}; /*! \brief The expression equals zero. */ - bool IsZero() const { - return base == 0 && args.size() == 0; - } + bool IsZero() const { return base == 0 && args.size() == 0; } /*! * \brief Return the normal Expr that is equivalent to self. * \return The normal expression. @@ -198,9 +188,7 @@ class SumExprNode : public CanonicalExprNode { if (this->args.size() == 0) { return make_const(this->dtype, this->base); } - return Normalize_(this->dtype, - SimplifySplitExprs(args), - base); + return Normalize_(this->dtype, SimplifySplitExprs(args), base); } /*! * \brief Whether self is divisible by scale. @@ -239,9 +227,7 @@ class SumExprNode : public CanonicalExprNode { * \brief add constant value to self. * \param value to be added. */ - void AddToSelf(int64_t value) { - this->base += value; - } + void AddToSelf(int64_t value) { this->base += value; } /*! * \brief self += other * scale; * \param other The expression to be added. @@ -257,8 +243,7 @@ class SumExprNode : public CanonicalExprNode { if (args[start]->IndexEqual(other)) break; } for (size_t j = start; j < args.size(); ++j) { - if (!args[j]->IndexEqual(other) || - other->lower_factor > args[j]->lower_factor) { + if (!args[j]->IndexEqual(other) || other->lower_factor > args[j]->lower_factor) { other.CopyOnWrite()->scale *= scale; this->args.insert(this->args.begin() + j, other); return; @@ -286,8 +271,7 @@ class SumExprNode : public CanonicalExprNode { * \param args The original list of arguments. * \return simplified version. */ - static std::vector - SimplifySplitExprs(std::vector args) { + static std::vector SimplifySplitExprs(std::vector args) { // NOTE: This algorithm relies on the factor that args are divided into segments // and each segment is sorted in descending order of lower_factor. for (size_t i = 0; i < args.size(); ++i) { @@ -297,14 +281,12 @@ class SumExprNode : public CanonicalExprNode { SplitExpr& rhs = args[j]; if (!lhs->IndexEqual(rhs)) break; if (lhs->upper_factor < rhs->lower_factor) break; - if (lhs->upper_factor == rhs->upper_factor && - lhs->lower_factor == rhs->lower_factor && + if (lhs->upper_factor == rhs->upper_factor && lhs->lower_factor == rhs->lower_factor && lhs->DivModeCompatibleTo(rhs->div_mode)) { // folding same co-efficient. rhs.CopyOnWrite()->scale += lhs->scale; lhs.CopyOnWrite()->scale = 0; - } else if (lhs->lower_factor == rhs->upper_factor && - rhs->scale != 0 && + } else if (lhs->lower_factor == rhs->upper_factor && rhs->scale != 0 && lhs->scale % rhs->scale == 0 && lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor && lhs->DivModeCompatibleTo(rhs->div_mode)) { @@ -385,9 +367,7 @@ class SumExprNode : public CanonicalExprNode { std::stable_sort(args.begin(), args.end(), fcompare); return args; } - static PrimExpr Normalize_(DataType dtype, - const std::vector& args, - int64_t base) { + static PrimExpr Normalize_(DataType dtype, const std::vector& args, int64_t base) { // Positive scales first PrimExpr res = make_const(dtype, 0); for (size_t i = 0; i < args.size(); ++i) { @@ -432,9 +412,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { public: using Rewriter = RewriteSimplifier::Impl; - explicit Impl(Analyzer* parent) - : Rewriter(parent) {} - + explicit Impl(Analyzer* parent) : Rewriter(parent) {} PrimExpr CanonicalSimplify(PrimExpr expr) { expr = operator()(expr); @@ -448,9 +426,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { } // Normal mutation without normalization. - PrimExpr CanonicalMutate(PrimExpr expr) { - return Rewriter::VisitExpr(expr); - } + PrimExpr CanonicalMutate(PrimExpr expr) { return Rewriter::VisitExpr(expr); } using Rewriter::VisitExpr_; PrimExpr VisitExpr_(const AddNode* op) final; @@ -486,9 +462,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \param out_divisible The result divisible component. * \param out_non_divisible The non-divisible component. */ - void SeparateDivisibleParts(const SumExprNode* psum, - int64_t coeff, - SumExpr* out_divisible, + void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, SumExpr* out_non_divisible); /*! * \brief Normalize expr to normal expr. @@ -568,8 +542,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { PrimExpr SimplifyReduceCombiner(const ReduceNode* op); }; -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const AddNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -594,8 +567,7 @@ VisitExpr_(const AddNode* op) { return std::move(ret); } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const SubNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -620,9 +592,7 @@ VisitExpr_(const SubNode* op) { return std::move(ret); } - -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const MulNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -660,11 +630,9 @@ VisitExpr_(const MulNode* op) { } } -void CanonicalSimplifier::Impl:: -SeparateDivisibleParts(const SumExprNode* psum, - int64_t coeff, - SumExpr* out_divisible, - SumExpr* out_non_divisible) { +void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, + SumExpr* out_divisible, + SumExpr* out_non_divisible) { auto divisible = make_object(); auto non_divisible = make_object(); divisible->dtype = psum->dtype; @@ -686,8 +654,7 @@ SeparateDivisibleParts(const SumExprNode* psum, *out_non_divisible = SumExpr(non_divisible); } -SplitExpr CanonicalSimplifier::Impl:: -SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { +SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { CHECK_GT(cval, 0); lhs = ConvertDivMode(lhs, div_mode); @@ -728,8 +695,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { return lhs; } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const DivNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -764,8 +730,7 @@ VisitExpr_(const DivNode* op) { } else { // if 0 <= extra < cval, it means the extra can be eliminated. if (TryCompare(temp, cval) != kLT) { - lhs.CopyOnWrite()->AddToSelf( - SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1); + lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1); } } return std::move(lhs); @@ -789,8 +754,7 @@ VisitExpr_(const DivNode* op) { } } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorDivNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -821,8 +785,7 @@ VisitExpr_(const FloorDivNode* op) { } else { // if 0 <= extra < cval, it means the extra can be eliminated. if (!(TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0))) { - lhs.CopyOnWrite()->AddToSelf( - SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); + lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); } } return std::move(lhs); @@ -845,8 +808,7 @@ VisitExpr_(const FloorDivNode* op) { } } -SplitExpr CanonicalSimplifier::Impl:: -SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { +SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { CHECK_GT(cval, 0); lhs = ConvertDivMode(lhs, div_mode); @@ -860,16 +822,14 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { // (x / c1) % c2 => (x % (c1 * c2)) / c2 int64_t new_upper_factor = lhs->lower_factor * scaled_cval; // try to see if we can reduce the existing upper modular. - if (lhs->upper_factor == SplitExprNode::kPosInf || - lhs->upper_factor % new_upper_factor == 0) { + if (lhs->upper_factor == SplitExprNode::kPosInf || lhs->upper_factor % new_upper_factor == 0) { // we gained a new upper factor that is smaller // than the original one // Perhaps there are more chances in simplifying the index // Do a recursive call to simplify the mod with the new factor. - if (new_upper_factor < lhs->upper_factor && - lhs->upper_factor != SplitExprNode::kPosInf) { - auto updated = ToSplitExpr(this->VisitExpr(ModImpl( - lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); + if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { + auto updated = ToSplitExpr(this->VisitExpr( + ModImpl(lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); updated.CopyOnWrite()->scale = lhs->scale; // re-apply the lower_factor if (lhs->lower_factor != 1) { @@ -896,8 +856,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { return lhs; } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const ModNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -941,8 +900,7 @@ VisitExpr_(const ModNode* op) { // (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); int64_t new_base = psum->base % cval; - if (cbound->min_value >= 0 && - cbound->min_value - psum->base + new_base >= 0) { + if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) { SumExpr sum_expr = Downcast(a); sum_expr.CopyOnWrite()->base = new_base; return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv); @@ -966,8 +924,7 @@ VisitExpr_(const ModNode* op) { } } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorModNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -991,8 +948,7 @@ VisitExpr_(const FloorModNode* op) { return floormod(temp, c1.Eval()); } else { // If temp < cval && temp >=0 then can remove the mod. - if (TryCompare(temp, cval) == kLT && - analyzer_->CanProveGreaterEqual(temp, 0)) { + if (TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0)) { return temp; } else { // contonue to use logic below. @@ -1027,8 +983,7 @@ VisitExpr_(const FloorModNode* op) { } // Simplify reduce expression. -PrimExpr CanonicalSimplifier::Impl:: -SimplifyReduceCombiner(const ReduceNode* op) { +PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) { // First simplify the results Array simplified_result; for (const auto& res : op->combiner->result) { @@ -1062,8 +1017,7 @@ SimplifyReduceCombiner(const ReduceNode* op) { // components which have side effects should also be preserved for (size_t i = 0; i < used.size(); ++i) { - if (HasSideEffect(op->source[i]) || - HasSideEffect(op->combiner->identity_element[i]) || + if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) || HasSideEffect(op->combiner->result[i])) { mark_used(i); } @@ -1091,14 +1045,11 @@ SimplifyReduceCombiner(const ReduceNode* op) { } } - CommReducer new_combiner = - CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); - return ReduceNode::make( - new_combiner, new_source, op->axis, op->condition, new_value_index); + CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); + return ReduceNode::make(new_combiner, new_source, op->axis, op->condition, new_value_index); } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const ReduceNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { // Recursively call simplification when necessary. PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op); op = ret.as(); @@ -1109,10 +1060,8 @@ VisitExpr_(const ReduceNode* op) { // assumption we would have to perform a single iteration of the loop, i.e. use // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. - return this->VisitExpr( - SelectNode::make(op->condition, - op->source[op->value_index], - op->combiner->identity_element[op->value_index])); + return this->VisitExpr(SelectNode::make(op->condition, op->source[op->value_index], + op->combiner->identity_element[op->value_index])); } // combiner simplification. ret = SimplifyReduceCombiner(op); @@ -1123,19 +1072,13 @@ PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { return impl_->CanonicalSimplify(expr); } -void CanonicalSimplifier::Update(const Var& var, - const PrimExpr& info, - bool override) { +void CanonicalSimplifier::Update(const Var& var, const PrimExpr& info, bool override) { impl_->Update(var, info, override); } -CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) - : impl_(new Impl(parent)) { -} +CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} -CanonicalSimplifier::~CanonicalSimplifier() { - delete impl_; -} +CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h index fddd34b..39530ff 100644 --- a/src/arith/compute_expr.h +++ b/src/arith/compute_expr.h @@ -26,8 +26,9 @@ #include #include -#include + #include +#include namespace tvm { namespace arith { @@ -39,7 +40,7 @@ namespace arith { * \tparam Op the computation operator * \return The result. */ -template +template inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) { return OP::make(lhs, rhs); } @@ -52,46 +53,45 @@ inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) { * \tparam Op The computation operator * \return The result. */ -template -inline PrimExpr ComputeReduce( - const Array& values, PrimExpr empty_value); +template +inline PrimExpr ComputeReduce(const Array& values, PrimExpr empty_value); -template<> +template <> inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return a + b; } -template<> +template <> inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return a - b; } -template<> +template <> inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return a * b; } -template<> +template <> inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return truncdiv(a, b); } -template<> +template <> inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return truncmod(a, b); } -template<> +template <> inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return max(a, b); } -template<> +template <> inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return min(a, b); } -template +template inline PrimExpr ComputeReduce(const Array& values, PrimExpr empty_value) { if (values.size() == 0U) { CHECK(empty_value.defined()); @@ -106,4 +106,4 @@ inline PrimExpr ComputeReduce(const Array& values, PrimExpr empty_valu } // namespace arith } // namespace tvm -#endif // TVM_ARITH_COMPUTE_EXPR_H_ +#endif // TVM_ARITH_COMPUTE_EXPR_H_ diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index a440af9..ad6570e 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -26,8 +26,10 @@ #include #include + #include #include + #include "int_operator.h" namespace tvm { @@ -43,7 +45,7 @@ namespace arith { * \note a and b Must already matched data types with each other. * \return nullptr if constant fold fails, otherwise return folded result. */ -template +template inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { return PrimExpr(); } @@ -57,7 +59,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { * \note a and b Must already matched data types with each other. * \return nullptr if constant fold fails, otherwise return folded result. */ -template +template inline PrimExpr TryConstFold(PrimExpr a); /*! @@ -70,254 +72,250 @@ inline PrimExpr TryConstFold(PrimExpr a); * \return the checked result. */ inline bool IsIndexType(const DataType& type) { - return type.is_int() && type.lanes() == 1 && - (type.bits() == 32 || type.bits() == 64); + return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64); } - -#define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using tir::FloatImmNode; \ - const IntImmNode* pa = a.as(); \ - const IntImmNode* pb = b.as(); \ - const FloatImmNode* fa = a.as(); \ - const FloatImmNode* fb = b.as(); \ +#define TVM_ARITH_CONST_PROPAGATION(BODY) \ + using tir::FloatImmNode; \ + const IntImmNode* pa = a.as(); \ + const IntImmNode* pb = b.as(); \ + const FloatImmNode* fa = a.as(); \ + const FloatImmNode* fb = b.as(); \ BODY; - -#define TVM_INDEX_CONST_PROPAGATION(BODY) \ - const IntImmNode* pa = a.as(); \ - const IntImmNode* pb = b.as(); \ - const DataType& ta = a.dtype(); \ - const DataType& tb = b.dtype(); \ - if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ - BODY; \ - } \ - +#define TVM_INDEX_CONST_PROPAGATION(BODY) \ + const IntImmNode* pa = a.as(); \ + const IntImmNode* pb = b.as(); \ + const DataType& ta = a.dtype(); \ + const DataType& tb = b.dtype(); \ + if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ + BODY; \ + } // specialization of constant folders. -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value + pb->value); - if (pa && pa->value == 0) return b; - if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm(rtype, fa->value + fb->value); - if (fa && fa->value == 0) return b; - if (fb && fb->value == 0) return a; - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, pa->value + pb->value); + if (pa && pa->value == 0) return b; + if (pb && pb->value == 0) return a; + if (fa && fb) return FloatImm(rtype, fa->value + fb->value); + if (fa && fa->value == 0) return b; + if (fb && fb->value == 0) return a; + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value - pb->value); - if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm(rtype, fa->value - fb->value); - if (fb && fb->value == 0) return a; - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, pa->value - pb->value); + if (pb && pb->value == 0) return a; + if (fa && fb) return FloatImm(rtype, fa->value - fb->value); + if (fb && fb->value == 0) return a; + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value * pb->value); - if (pa) { - if (pa->value == 1) return b; - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return a; - if (pb->value == 0) return b; - } - if (fa && fb) return FloatImm(rtype, fa->value * fb->value); - if (fa) { - if (fa->value == 1) return b; - if (fa->value == 0) return a; - } - if (fb) { - if (fb->value == 1) return a; - if (fb->value == 0) return b; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, pa->value * pb->value); + if (pa) { + if (pa->value == 1) return b; + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + if (pb->value == 0) return b; + } + if (fa && fb) return FloatImm(rtype, fa->value * fb->value); + if (fa) { + if (fa->value == 1) return b; + if (fa->value == 0) return a; + } + if (fb) { + if (fb->value == 1) return a; + if (fb->value == 0) return b; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - // due to division and mod can have different modes - // NOTE: this will assumes truc div. - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, pa->value / pb->value); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return a; - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - if (fa && fb && fb->value != 0) { - return FloatImm(rtype, fa->value / fb->value); - } - if (fa && fa->value == 0) return a; - if (fb) { - if (fb->value == 1) return a; - CHECK_NE(fb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + // due to division and mod can have different modes + // NOTE: this will assumes truc div. + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, pa->value / pb->value); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + if (fa && fb && fb->value != 0) { + return FloatImm(rtype, fa->value / fb->value); + } + if (fa && fa->value == 0) return a; + if (fb) { + if (fb->value == 1) return a; + CHECK_NE(fb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, pa->value % pb->value); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return tir::make_zero(rtype); - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, pa->value % pb->value); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return tir::make_zero(rtype); + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, arith::floordiv(pa->value, pb->value)); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return a; - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - if (fa && fb && fb->value != 0) { - return FloatImm(rtype, std::floor(fa->value / fb->value)); - } - if (fa && fa->value == 0) return a; - if (fb) { - if (fb->value == 1) return a; - CHECK_NE(fb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, arith::floordiv(pa->value, pb->value)); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + if (fa && fb && fb->value != 0) { + return FloatImm(rtype, std::floor(fa->value / fb->value)); + } + if (fa && fa->value == 0) return a; + if (fb) { + if (fb->value == 1) return a; + CHECK_NE(fb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, floormod(pa->value, pb->value)); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return tir::make_zero(rtype); - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, floormod(pa->value, pb->value)); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return tir::make_zero(rtype); + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); + if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); + }); if (a.same_as(b)) return a; return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); + if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); + }); if (a.same_as(b)) return a; return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); @@ -328,7 +326,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); @@ -339,7 +337,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { @@ -364,9 +362,7 @@ struct SymbolicLimits { * * \return positive infinity. */ -inline PrimExpr pos_inf() { - return SymbolicLimits::pos_inf_; -} +inline PrimExpr pos_inf() { return SymbolicLimits::pos_inf_; } /*! * \brief Check if value is positive infinity. @@ -374,9 +370,7 @@ inline PrimExpr pos_inf() { * * \return The check result. */ -inline bool is_pos_inf(const PrimExpr& value) { - return value.same_as(SymbolicLimits::pos_inf_); -} +inline bool is_pos_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::pos_inf_); } /*! * \brief Opaque expression representing negative infinity. @@ -386,9 +380,7 @@ inline bool is_pos_inf(const PrimExpr& value) { * * \return negative infinity. */ -inline PrimExpr neg_inf() { - return SymbolicLimits::neg_inf_; -} +inline PrimExpr neg_inf() { return SymbolicLimits::neg_inf_; } /*! * \brief Check if value is negative infinity. @@ -396,9 +388,7 @@ inline PrimExpr neg_inf() { * * \return The check result. */ -inline bool is_neg_inf(const PrimExpr& value) { - return value.same_as(SymbolicLimits::neg_inf_); -} +inline bool is_neg_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::neg_inf_); } } // namespace arith } // namespace tvm diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 4437225..0f4d9c0 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -20,10 +20,12 @@ /*! * \file tvm/arith/const_int_bound.cc */ -#include #include +#include #include + #include + #include "int_operator.h" #include "pattern_match.h" @@ -34,8 +36,7 @@ using namespace tir; TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); -ConstIntBound::ConstIntBound( - int64_t min_value, int64_t max_value) { +ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) { auto node = make_object(); node->min_value = min_value; node->max_value = max_value; @@ -46,8 +47,7 @@ ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { return ConstIntBound(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.ConstIntBound") -.set_body_typed(MakeConstIntBound); +TVM_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound); inline void PrintBoundValue(std::ostream& os, int64_t val) { if (val == ConstIntBound::kPosInf) { @@ -60,31 +60,29 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "ConstIntBound["; - PrintBoundValue(p->stream, op->min_value); - p->stream << ','; - PrintBoundValue(p->stream, op->max_value); - p->stream << ']'; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "ConstIntBound["; + PrintBoundValue(p->stream, op->min_value); + p->stream << ','; + PrintBoundValue(p->stream, op->max_value); + p->stream << ']'; + }); // internal entry for const int bound struct ConstIntBoundAnalyzer::Entry { int64_t min_value; int64_t max_value; - bool is_const(int64_t value) const { - return min_value == max_value && min_value == value; - } + bool is_const(int64_t value) const { return min_value == max_value && min_value == value; } bool operator==(const Entry& other) const { return min_value == other.min_value && max_value == other.max_value; } }; -class ConstIntBoundAnalyzer::Impl : - public ExprFunctor { +class ConstIntBoundAnalyzer::Impl + : public ExprFunctor { public: /*! \brief additional bound info about expr \in bound */ struct BoundInfo { @@ -94,9 +92,7 @@ class ConstIntBoundAnalyzer::Impl : Entry bound; BoundInfo() {} - BoundInfo(PrimExpr expr, Entry bound) - : expr(expr), bound(bound) { - } + BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {} }; void Bind(const Var& var, const Range& range, bool override) { @@ -108,32 +104,27 @@ class ConstIntBoundAnalyzer::Impl : Update(var, ret, override); } - void Update(const Var& var, - const Entry& info, - bool override) { + void Update(const Var& var, const Entry& info, bool override) { if (!override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(it->second == info) - << "Trying to update var \'" << var << "\'" - << " with a different const bound: " - << "original=" << ConstIntBound(it->second.min_value, it->second.max_value) - << ", new=" << ConstIntBound(info.min_value, info.max_value); + CHECK(it->second == info) << "Trying to update var \'" << var << "\'" + << " with a different const bound: " + << "original=" + << ConstIntBound(it->second.min_value, it->second.max_value) + << ", new=" << ConstIntBound(info.min_value, info.max_value); } } var_map_[var] = info; } - void Update(const Var& var, - const ConstIntBound& info, - bool override) { + void Update(const Var& var, const ConstIntBound& info, bool override) { Update(var, MakeBound(info->min_value, info->max_value), override); } // Override visitor behaviors Entry VisitExprDefault_(const Object* op) final { - return Everything( - static_cast(op)->dtype); + return Everything(static_cast(op)->dtype); } Entry VisitExpr(const PrimExpr& expr) final { @@ -177,9 +168,7 @@ class ConstIntBoundAnalyzer::Impl : return Intersect(a, b); } - Entry VisitExpr_(const IntImmNode* op) final { - return MakeBound(op->value, op->value); - } + Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); } Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); @@ -224,8 +213,7 @@ class ConstIntBoundAnalyzer::Impl : // 0 <= [a_min, a_max] < b_min if (a.max_value < b.min_value) return a; // other case, we can get close to 0 - return MakeBound(0, - std::min(a.max_value, b_max_cap)); + return MakeBound(0, std::min(a.max_value, b_max_cap)); } else { return MakeBound(std::max(a.min_value, -b_max_cap), std::min(std::max(a.max_value, (int64_t)0), b_max_cap)); @@ -383,7 +371,7 @@ class ConstIntBoundAnalyzer::Impl : * \tparam F the operator function type. * \return The result. */ - template + template static Entry BinaryOpBoundry(Entry a, Entry b, const F& op) { Entry ret; // The boundary point must be shihft of the original boundary. @@ -561,17 +549,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) { return ConstIntBound(ret.min_value, ret.max_value); } -ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, - BoundMapType* bound) { +ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapType* bound) { impl_->bound_ = bound; Entry ret = impl_->VisitExpr(expr); impl_->bound_ = nullptr; return ConstIntBound(ret.min_value, ret.max_value); } -void ConstIntBoundAnalyzer::Update(const Var& var, - const ConstIntBound& info, - bool override) { +void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool override) { impl_->Update(var, info, override); } @@ -583,13 +568,9 @@ std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& con return impl_->EnterConstraint(constraint); } -ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) - : impl_(new Impl()) { -} +ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {} -ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { - delete impl_; -} +ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index c7f90f5..2bc7209 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -21,13 +21,13 @@ * \file detect_linear_equation.cc * \brief Utility to detect patterns in the expression. */ +#include #include -#include #include -#include +#include #include +#include #include -#include namespace tvm { namespace arith { @@ -45,11 +45,9 @@ struct IntervalEntry { PrimExpr max_value; }; -class LinearEqDetector - : public ExprFunctor { +class LinearEqDetector : public ExprFunctor { public: - explicit LinearEqDetector(Var var) - : var_(var) {} + explicit LinearEqDetector(Var var) : var_(var) {} bool Detect(const PrimExpr& e, LinearEqEntry* ret) { *ret = VisitExpr(e, e); @@ -142,8 +140,7 @@ class LinearEqDetector } }; -Array DetectLinearEquation(const PrimExpr& e, - const Array& vars) { +Array DetectLinearEquation(const PrimExpr& e, const Array& vars) { PrimExpr base = e; Array coeff; @@ -157,9 +154,7 @@ Array DetectLinearEquation(const PrimExpr& e, } std::unordered_set vset; - auto vset_contains = [&](const VarNode* node) { - return vset.count(node) != 0; - }; + auto vset_contains = [&](const VarNode* node) { return vset.count(node) != 0; }; for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); @@ -173,9 +168,8 @@ Array DetectLinearEquation(const PrimExpr& e, } // Detect clip condition as min max value -bool DetectClipBound( - const PrimExpr& cond, - std::unordered_map* bmap) { +bool DetectClipBound(const PrimExpr& cond, + std::unordered_map* bmap) { int flag = 0; Var var; auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) { @@ -237,8 +231,7 @@ bool DetectClipBound( return false; } - -template +template void SplitCommExpr(const PrimExpr& e, std::vector* ret) { if (const OP* op = e.as()) { SplitCommExpr(op->a, ret); @@ -276,12 +269,11 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { return ret; } -TVM_REGISTER_GLOBAL("arith.DetectLinearEquation") -.set_body_typed(DetectLinearEquation); +TVM_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation); TVM_REGISTER_GLOBAL("arith.DetectClipBound") -.set_body_typed([](const PrimExpr& e, const Array& vars) { - return DetectClipBound(e, vars); -}); + .set_body_typed([](const PrimExpr& e, const Array& vars) { + return DetectClipBound(e, vars); + }); } // namespace arith } // namespace tvm diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 81443db..0ac4a89 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -21,13 +21,13 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ +#include +#include #include #include -#include -#include -#include #include +#include namespace tvm { namespace arith { @@ -37,12 +37,8 @@ using namespace tir; // Find Read region of the tensor in the stmt. class BufferTouchedDomain final : public StmtExprVisitor { public: - BufferTouchedDomain(const Buffer &buffer, - bool consider_loads, - bool consider_stores) - : buffer_(buffer), - consider_loads_(consider_loads), - consider_stores_(consider_stores) {} + BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores) + : buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {} Domain Find(const Stmt& stmt) { operator()(stmt); @@ -54,17 +50,15 @@ class BufferTouchedDomain final : public StmtExprVisitor { return ret; } - void VisitStmt_(const ForNode *op) final { + void VisitStmt_(const ForNode* op) final { const VarNode* var = op->loop_var.get(); - dom_map_[var] = IntSet::range( - Range::make_by_min_extent(op->min, op->extent)); + dom_map_[var] = IntSet::range(Range::make_by_min_extent(op->min, op->extent)); StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); } void VisitStmt_(const LetStmtNode* op) final { - dom_map_[op->var.get()] = - arith::EvalSet(op->value, dom_map_); + dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_); StmtExprVisitor::VisitStmt_(op); dom_map_.erase(op->var.get()); } @@ -107,21 +101,18 @@ class BufferTouchedDomain final : public StmtExprVisitor { } } - const Buffer &buffer_; + const Buffer& buffer_; bool consider_loads_, consider_stores_; std::vector > bounds_; std::unordered_map dom_map_; }; -Domain DomainTouched(const Stmt& stmt, - const Buffer& buffer, - bool consider_loads, +Domain DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, bool consider_stores) { return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt); } -TVM_REGISTER_GLOBAL("arith.DomainTouched") -.set_body_typed(DomainTouched); +TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched); } // namespace arith } // namespace tvm diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 34efa98..62858d2 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -22,19 +22,18 @@ * \brief The integer constraints data structures. */ #include +#include #include #include -#include -#include #include #include +#include namespace tvm { namespace arith { -IntConstraints::IntConstraints(Array variables, - Map ranges, +IntConstraints::IntConstraints(Array variables, Map ranges, Array relations) { ObjectPtr node = make_object(); if (!variables.defined()) { @@ -46,7 +45,7 @@ IntConstraints::IntConstraints(Array variables, CHECK(relations.defined()); for (const auto& var : variables) { CHECK(var.dtype().is_int() || var.dtype().is_uint()) - << "Variables in IntConstraints must be integers"; + << "Variables in IntConstraints must be integers"; } node->variables = std::move(variables); node->ranges = std::move(ranges); @@ -57,18 +56,13 @@ IntConstraints::IntConstraints(Array variables, TVM_REGISTER_NODE_TYPE(IntConstraintsNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntConstraints(" - << op->variables - << ", " << op->ranges - << ", " << op->relations - << ")"; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntConstraints(" << op->variables << ", " << op->ranges << ", " << op->relations + << ")"; + }); -IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, - IntConstraints dst, +IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst, Map src_to_dst, Map dst_to_src) { ObjectPtr node = make_object(); @@ -82,15 +76,12 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntConstraintsTransform(" - << "\n\t" << op->src - << "\n\t" << op->dst - << "\n\t" << op->src_to_dst - << "\n\t" << op->dst_to_src - << "\n)"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntConstraintsTransform(" + << "\n\t" << op->src << "\n\t" << op->dst << "\n\t" << op->src_to_dst << "\n\t" + << op->dst_to_src << "\n)"; + }); } // namespace arith } // namespace tvm diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index 3be34b6..8e4dda0 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -38,56 +38,41 @@ namespace arith { * \return Whether overflow can happen. * \tparam Op The integer operator. */ -template -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { if ((y > 0) && (x > max_value - y)) return true; if ((y < 0) && (x < min_value - y)) return true; return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { if ((y > 0) && (x < min_value + y)) return true; if ((y < 0) && (x > max_value + y)) return true; return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { if (y == 0) return false; if (y > 0) { - if (x < min_value / y) return true; - if (x > max_value / y) return true; + if (x < min_value / y) return true; + if (x > max_value / y) return true; } else { if (y == -1 && x == std::numeric_limits::min()) return true; - if (x > min_value / y) return true; - if (x < max_value / y) return true; + if (x > min_value / y) return true; + if (x < max_value / y) return true; } return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { return y == 0; } @@ -97,9 +82,7 @@ inline bool WillOverflow(int64_t x, * \param y The right operand. * \return the result. */ -inline int64_t truncdiv(int64_t x, int64_t y) { - return x / y; -} +inline int64_t truncdiv(int64_t x, int64_t y) { return x / y; } /*! * \brief Compute the truncdiv remainder of two integers. @@ -107,9 +90,7 @@ inline int64_t truncdiv(int64_t x, int64_t y) { * \param y The right operand. * \return the result. */ -inline int64_t truncmod(int64_t x, int64_t y) { - return x % y; -} +inline int64_t truncmod(int64_t x, int64_t y) { return x % y; } /*! * \brief Peform floor division of two integers. @@ -120,13 +101,10 @@ inline int64_t truncmod(int64_t x, int64_t y) { inline int64_t floordiv(int64_t x, int64_t y) { int64_t rdiv = x / y; int64_t rmod = x % y; - bool is_floor_div = - (y >= 0 && rmod >= 0) || - (y < 0 && rmod <= 0); + bool is_floor_div = (y >= 0 && rmod >= 0) || (y < 0 && rmod <= 0); return is_floor_div ? rdiv : (rdiv - 1); } - /*! * \brief Compute the floordiv remainder of two integers. * \param x The left operand. @@ -135,9 +113,7 @@ inline int64_t floordiv(int64_t x, int64_t y) { */ inline int64_t floormod(int64_t x, int64_t y) { int64_t rmod = x % y; - bool is_floor_div = - (y >= 0 && rmod >= 0) || - (y < 0 && rmod <= 0); + bool is_floor_div = (y >= 0 && rmod >= 0) || (y < 0 && rmod <= 0); return is_floor_div ? rmod : rmod + y; } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index d2d43d6..7462808 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -22,23 +22,24 @@ * \brief The integer set functions */ #include +#include #include #include -#include -#include #include #include +#include + #include "interval_set.h" #include "pattern_match.h" namespace tvm { namespace arith { +using tir::is_one; +using tir::is_zero; using tir::make_const; using tir::make_zero; -using tir::is_zero; -using tir::is_one; PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); @@ -54,9 +55,7 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { return IntervalSet(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.IntervalSet") -.set_body_typed(MakeIntervalSet); - +TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet); IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { PrimExpr max_value = min(a->max_value, b->max_value); @@ -77,15 +76,15 @@ IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { } // type traits -template +template struct is_logical_op { static const bool value = false; }; -#define TVM_DECLARE_LOGICAL_OP(OP) \ - template<> \ - struct is_logical_op { \ - static const bool value = true; \ +#define TVM_DECLARE_LOGICAL_OP(OP) \ + template <> \ + struct is_logical_op { \ + static const bool value = true; \ }; TVM_DECLARE_LOGICAL_OP(AndNode); @@ -102,18 +101,15 @@ TVM_DECLARE_LOGICAL_OP(NotNode); * \brief Combine two interval set under arithmetic operations. * \note this can possibly relax the set. */ -template -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr res = TryConstFold(a->min_value, b->min_value); if (!res.defined()) res = Op::make(a->min_value, b->min_value); return IntervalSet::SinglePoint(res); } if (is_logical_op::value) { - return IntervalSet(make_const(a->min_value.dtype(), 0), - make_const(a->min_value.dtype(), 1)); + return IntervalSet(make_const(a->min_value.dtype(), 0), make_const(a->min_value.dtype(), 1)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; @@ -122,47 +118,36 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; PrimExpr min_value = - a->HasLowerBound() && b->HasLowerBound() ? - a->min_value + b->min_value : neg_inf(); + a->HasLowerBound() && b->HasLowerBound() ? a->min_value + b->min_value : neg_inf(); PrimExpr max_value = - a->HasUpperBound() && b->HasUpperBound() ? - a->max_value + b->max_value : pos_inf(); + a->HasUpperBound() && b->HasUpperBound() ? a->max_value + b->max_value : pos_inf(); return IntervalSet(min_value, max_value); } -template<> -inline IntervalSet Combine(Analyzer* analyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; PrimExpr min_value = - a->HasLowerBound() && b->HasUpperBound() ? - a->min_value - b->max_value : neg_inf(); + a->HasLowerBound() && b->HasUpperBound() ? a->min_value - b->max_value : neg_inf(); PrimExpr max_value = - a->HasUpperBound() && b->HasLowerBound() ? - a->max_value - b->min_value : pos_inf(); + a->HasUpperBound() && b->HasLowerBound() ? a->max_value - b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } - -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -194,10 +179,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -229,10 +212,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -259,11 +240,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } - -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -295,10 +273,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -331,30 +307,24 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analzyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { - return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); + return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; - return IntervalSet(max(a->min_value, b->min_value), - max(a->max_value, b->max_value)); + return IntervalSet(max(a->min_value, b->min_value), max(a->max_value, b->max_value)); } -template<> -inline IntervalSet Combine(Analyzer* analzyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; - return IntervalSet(min(a->min_value, b->min_value), - min(a->max_value, b->max_value)); + return IntervalSet(min(a->min_value, b->min_value), min(a->max_value, b->max_value)); } // internal helper function to get an interval set @@ -370,20 +340,12 @@ using namespace tir; // Simplified version of int set evaluator that operates on IntervalSet // We might use better set analysis in the future to replace the intervalset. -class IntervalSetEvaluator : - public ExprFunctor { +class IntervalSetEvaluator : public ExprFunctor { public: - IntervalSetEvaluator(Analyzer* analyzer, - const Map& dom_map, - bool eval_vec = false) - : analyzer_(analyzer), - dom_map_(dom_map), - eval_vec_(eval_vec) { - } + IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, bool eval_vec = false) + : analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {} - IntervalSet Eval(const PrimExpr& val) { - return this->VisitExpr(val); - } + IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); } // evaluate and relax the set IntervalSet Eval(IntervalSet val) { // avoid recursive indefinite recursive expansion. @@ -404,8 +366,7 @@ class IntervalSetEvaluator : auto it = dom_map_.find(var); if (it != dom_map_.end()) { IntervalSet res = ToIntervalSet((*it).second); - if (res->min_value.same_as(var) && - res->max_value.same_as(var)) { + if (res->min_value.same_as(var) && res->max_value.same_as(var)) { return res; } // recursively evaluate mapped result @@ -416,74 +377,39 @@ class IntervalSetEvaluator : } } + IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const AddNode* op) final { - return VisitBinaryExpr_(op); - } - - IntervalSet VisitExpr_(const SubNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MulNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const DivNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const ModNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorDivNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorModNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MinNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MaxNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const EQNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const NENode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LTNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LENode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GTNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GENode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const AndNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const OrNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_(op); } IntervalSet VisitExpr_(const RampNode* op) final { CHECK(eval_vec_); @@ -492,16 +418,12 @@ class IntervalSetEvaluator : if (stride.Match(op->stride)) { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; - if (vstride> 0) { - return Combine( - analyzer_, - base, - IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); + if (vstride > 0) { + return Combine(analyzer_, base, + IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); } else { - return Combine( - analyzer_, - base, - IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); + return Combine(analyzer_, base, + IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); } } DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); @@ -526,12 +448,11 @@ class IntervalSetEvaluator : private: // whether set is exactly single point that equals value. - bool MatchPoint(const IntervalSet& set, - const PrimExpr& value) const { + bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const { return set->min_value.same_as(value) && set->max_value.same_as(value); } - template + template inline IntervalSet VisitBinaryExpr_(const T* op) { IntervalSet a = this->Eval(op->a); IntervalSet b = this->Eval(op->b); @@ -551,9 +472,7 @@ class IntervalSetEvaluator : class IntSetAnalyzer::Impl { public: - explicit Impl(Analyzer* analyzer) - : analyzer_(analyzer) { - } + explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {} IntSet Eval(const PrimExpr& expr, const Map& dom_map) const { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); @@ -563,16 +482,11 @@ class IntSetAnalyzer::Impl { Analyzer* analyzer_; }; -IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) - : impl_(new Impl(parent)) { -} +IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} -IntSetAnalyzer::~IntSetAnalyzer() { - delete impl_; -} +IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; } -IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, - const Map& dom_map) { +IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map& dom_map) { return impl_->Eval(expr, dom_map); } @@ -584,8 +498,8 @@ Range IntSet::cover_range(Range max_range) const { const IntervalSetNode* s_int = (*this).as(); CHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { - return Range::make_by_min_extent( - s_int->min_value, analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); + return Range::make_by_min_extent(s_int->min_value, + analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); } return max_range; } @@ -664,17 +578,11 @@ PrimExpr IntSet::point_value() const { return s_int->min_value; } -IntSet IntSet::nothing() { - return IntervalSet::Empty(); -} +IntSet IntSet::nothing() { return IntervalSet::Empty(); } -IntSet IntSet::everything() { - return IntervalSet::Everything(); -} +IntSet IntSet::everything() { return IntervalSet::Everything(); } -IntSet IntSet::single_point(PrimExpr x) { - return IntervalSet::SinglePoint(x); -} +IntSet IntSet::single_point(PrimExpr x) { return IntervalSet::SinglePoint(x); } IntSet IntSet::interval(PrimExpr min, PrimExpr max) { if (min.same_as(max)) { @@ -702,7 +610,7 @@ bool IntSet::match_range(const Range& b) const { if (!a_int) return false; Analyzer ana; return ProveEqual(&ana, a_int->min_value, b->min) && - ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); + ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); } IntSet Union(const Array& sets) { @@ -713,8 +621,7 @@ IntSet Union(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Union(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(ana.Simplify(x->min_value), - ana.Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } IntSet Intersect(const Array& sets) { @@ -725,8 +632,7 @@ IntSet Intersect(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Intersect(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(ana.Simplify(x->min_value), - ana.Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } Map ConvertDomMap(const Map& dom_map) { @@ -737,8 +643,7 @@ Map ConvertDomMap(const Map& dom_map) { return dmap; } -Map ConvertDomMap( - const std::unordered_map& dom_map) { +Map ConvertDomMap(const std::unordered_map& dom_map) { Map dmap; for (auto kv : dom_map) { dmap.Set(GetRef(kv.first), kv.second); @@ -746,8 +651,7 @@ Map ConvertDomMap( return dmap; } -IntSet EvalSet(PrimExpr e, - const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const Map& dom_map) { Analyzer ana; return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); } @@ -758,49 +662,40 @@ IntSet IntSet::vector(PrimExpr x) { return IntervalSetEvaluator(&ana, dmap, true).Eval(x); } -IntSet EvalSet(PrimExpr e, - const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const Map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map) { +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, - const Map& dom_map) { +IntSet EvalSet(Range r, const Map& dom_map) { Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); // Simplifying first can give tighter bounds if r->min and r->extent share variables PrimExpr sum = r->min + r->extent - 1; - auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); + auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); return std::move(res); } -IntSet EvalSet(Range r, - const std::unordered_map& dom_map) { +IntSet EvalSet(Range r, const std::unordered_map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } -IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map) { +IntSet EvalSet(IntSet s, const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); IntervalSetEvaluator m(&ana, dmap); const IntervalSetNode* s_int = s.as(); - PrimExpr vmax = s_int->HasUpperBound() ? - m.Eval(s_int->max_value).max() : s_int->max_value; - PrimExpr vmin = s_int->HasLowerBound() ? - m.Eval(s_int->min_value).min() : s_int->min_value; + PrimExpr vmax = s_int->HasUpperBound() ? m.Eval(s_int->max_value).max() : s_int->max_value; + PrimExpr vmin = s_int->HasLowerBound() ? m.Eval(s_int->min_value).min() : s_int->min_value; return IntervalSet(vmin, vmax); } class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { public: - explicit SubExprIntervalSetEvaluator( - Analyzer* analyzer, - const Map& dom_map) + explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map) : IntervalSetEvaluator(analyzer, dom_map) {} IntervalSet VisitExpr(const PrimExpr& n) final { @@ -812,9 +707,8 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { ExprIntSetMap expr_map; }; -ExprIntSetMap EvalSetForEachSubExpr( - PrimExpr e, - const std::unordered_map& dom_map) { +ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, + const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); SubExprIntervalSetEvaluator m(&ana, dmap); @@ -822,42 +716,32 @@ ExprIntSetMap EvalSetForEachSubExpr( return m.expr_map; } -IntSet EvalSet(Range r, - const Map& dom_map) { +IntSet EvalSet(Range r, const Map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } TVM_REGISTER_NODE_TYPE(IntervalSetNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntervalSet" - << "[" << op->min_value << ", " - << op->max_value << ']'; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntervalSet" + << "[" << op->min_value << ", " << op->max_value << ']'; + }); -TVM_REGISTER_GLOBAL("arith.intset_single_point") -.set_body_typed(IntSet::single_point); +TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::single_point); -TVM_REGISTER_GLOBAL("arith.intset_vector") -.set_body_typed(IntSet::vector); +TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::vector); -TVM_REGISTER_GLOBAL("arith.intset_interval") -.set_body_typed(IntSet::interval); +TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::interval); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin") -.set_body_method(&IntSet::min); +TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax") -.set_body_method(&IntSet::max); +TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max); -TVM_REGISTER_GLOBAL("arith.IntSetIsNothing") -.set_body_method(&IntSet::is_nothing); +TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::is_nothing); -TVM_REGISTER_GLOBAL("arith.IntSetIsEverything") -.set_body_method(&IntSet::is_everything); +TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::is_everything); } // namespace arith } // namespace tvm diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index 51b500a..eb308dd 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -26,7 +26,9 @@ #include #include + #include + #include "const_fold.h" namespace tvm { @@ -53,26 +55,18 @@ class IntervalSetNode : public IntSetNode { } /*! \return Whether the interval has upper bound. */ - bool HasUpperBound() const { - return !is_pos_inf(max_value) && !IsEmpty(); - } + bool HasUpperBound() const { return !is_pos_inf(max_value) && !IsEmpty(); } /*! \return Whether the interval has lower bound. */ - bool HasLowerBound() const { - return !is_neg_inf(min_value) && !IsEmpty(); - } + bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); } /*! \return Whether the interval is a single point. */ - bool IsSinglePoint() const { - return min_value.same_as(max_value); - } + bool IsSinglePoint() const { return min_value.same_as(max_value); } /*! \return whether interval represent nothing */ bool IsEmpty() const { // during computations, either extreme could occur. return is_pos_inf(min_value) || is_neg_inf(max_value); } /*! \return whether interval represent everything */ - bool IsEverything() const { - return is_neg_inf(min_value) && is_pos_inf(max_value); - } + bool IsEverything() const { return is_neg_inf(min_value) && is_pos_inf(max_value); } static constexpr const char* _type_key = "arith.IntervalSet"; TVM_DECLARE_FINAL_OBJECT_INFO(IntervalSetNode, IntSetNode); @@ -97,24 +91,18 @@ class IntervalSet : public IntSet { * \param value The value to be represented. * \return The result set. */ - static IntervalSet SinglePoint(PrimExpr value) { - return IntervalSet(value, value); - } + static IntervalSet SinglePoint(PrimExpr value) { return IntervalSet(value, value); } /*! * \brief Create an IntervalSet that represents everything. * \param value The value to be represented. * \return The result set. */ - static IntervalSet Everything() { - return IntervalSet(neg_inf(), pos_inf()); - } + static IntervalSet Everything() { return IntervalSet(neg_inf(), pos_inf()); } /*! * \brief Create an empty eet. * \return The result set. */ - static IntervalSet Empty() { - return IntervalSet(pos_inf(), neg_inf()); - } + static IntervalSet Empty() { return IntervalSet(pos_inf(), neg_inf()); } TVM_DEFINE_OBJECT_REF_COW_METHOD(IntervalSetNode); TVM_DEFINE_OBJECT_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); @@ -136,7 +124,7 @@ TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b); * \param b The second set. * \return The result set. */ -TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b); +TVM_DLL IntervalSet Intersect(Analyzer* analzyer, IntervalSet a, IntervalSet b); } // namespace arith } // namespace tvm diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 0ae9841..e09ff1d 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -20,24 +20,22 @@ /*! * \file tvm/arith/ir_mutator_with_analyzer.cc */ +#include "ir_mutator_with_analyzer.h" + #include #include -#include "ir_mutator_with_analyzer.h" namespace tvm { namespace arith { using namespace tir; -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const ForNode* op) { - analyzer_->Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); +Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) { + analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); return StmtExprMutator::VisitStmt_(op); } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const LetStmtNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); if (!tir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); @@ -45,8 +43,7 @@ VisitStmt_(const LetStmtNode* op) { // We keep the let-binding here // as sub-class may or maynot choose to replace it. Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -56,8 +53,7 @@ VisitStmt_(const LetStmtNode* op) { } } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const IfThenElseNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case, else_case; { @@ -65,9 +61,8 @@ VisitStmt_(const IfThenElseNode* op) { then_case = this->VisitStmt(op->then_case); } if (op->else_case.defined()) { - With ctx(analyzer_, - analyzer_->rewrite_simplify(NotNode::make(condition))); - else_case = this->VisitStmt(op->else_case); + With ctx(analyzer_, analyzer_->rewrite_simplify(NotNode::make(condition))); + else_case = this->VisitStmt(op->else_case); } if (is_one(condition)) return then_case; if (is_zero(condition)) { @@ -77,8 +72,7 @@ VisitStmt_(const IfThenElseNode* op) { return EvaluateNode::make(0); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { @@ -90,14 +84,11 @@ VisitStmt_(const IfThenElseNode* op) { } } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == tir::attr::thread_extent || - op->attr_key == tir::attr::virtual_thread) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); - analyzer_->Bind(iv->var, - Range::make_by_min_extent(0, op->value)); + analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); Stmt stmt = StmtExprMutator::VisitStmt_(op); return stmt; } else { @@ -105,16 +96,13 @@ VisitStmt_(const AttrStmtNode* op) { } } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const AssertStmtNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr message = this->VisitExpr(op->message); With ctx(analyzer_, condition); Stmt body = this->VisitStmt(op->body); - if (condition.same_as(op->condition) && - message.same_as(op->message) && - body.same_as(op->body)) { + if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -125,8 +113,7 @@ VisitStmt_(const AssertStmtNode* op) { } } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const CallNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) { PrimExpr cond = this->VisitExpr(op->args[0]); @@ -146,21 +133,17 @@ VisitExpr_(const CallNode* op) { if (is_one(cond)) { return true_value; } - if (cond.same_as(op->args[0]) && - true_value.same_as(op->args[1]) && + if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && false_value.same_as(op->args[2])) { return GetRef(op); } else { - return CallNode::make(op->dtype, op->name, - {cond, true_value, false_value}, - op->call_type); + return CallNode::make(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); } } return StmtExprMutator::VisitExpr_(op); } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const LetNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); if (!tir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); @@ -168,16 +151,14 @@ VisitExpr_(const LetNode* op) { // We keep the let-binding here // as sub-class may or maynot choose to replace it. PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); } } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const SelectNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) { PrimExpr cond = this->VisitExpr(op->condition); PrimExpr true_value, false_value; { @@ -185,8 +166,7 @@ VisitExpr_(const SelectNode* op) { true_value = VisitExpr(op->true_value); } { - With constraint(analyzer_, - analyzer_->rewrite_simplify(NotNode::make(cond))); + With constraint(analyzer_, analyzer_->rewrite_simplify(NotNode::make(cond))); false_value = VisitExpr(op->false_value); } if (is_zero(cond)) { @@ -196,8 +176,7 @@ VisitExpr_(const SelectNode* op) { return true_value; } // normal path - if (cond.same_as(op->condition) && - true_value.same_as(op->true_value) && + if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { return GetRef(op); } else { @@ -205,8 +184,7 @@ VisitExpr_(const SelectNode* op) { } } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const ReduceNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const ReduceNode* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { analyzer_->Bind(iv->var, iv->dom); diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index f6004e2..004265b 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -24,8 +24,9 @@ #ifndef TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ #define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ -#include #include +#include + #include namespace tvm { @@ -42,11 +43,10 @@ namespace arith { */ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { public: - explicit IRMutatorWithAnalyzer(Analyzer* analyzer) - : analyzer_(analyzer) {} + explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {} - using StmtExprMutator::VisitStmt_; using StmtExprMutator::VisitExpr_; + using StmtExprMutator::VisitStmt_; // override functions that need to populate the context information. tir::Stmt VisitStmt_(const tir::ForNode* op) override; diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index b2dbe9d..810949b 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -34,23 +34,18 @@ namespace tir { class IRVisitorWithAnalyzer final : public StmtExprVisitor { public: - PrimExpr Simplify(const PrimExpr& expr) { - return analyzer_.Simplify(expr); - } + PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); } void VisitStmt_(const ForNode* op) { - analyzer_.Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); + analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); return StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, - Range::make_by_min_extent(0, op->value)); + analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 40cd7f8..7ddb8f5 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -21,13 +21,15 @@ * \file modular_set.cc * \brief Modular set analysis */ -#include #include -#include +#include #include +#include + #include -#include #include +#include + #include "pattern_match.h" namespace tvm { @@ -46,19 +48,15 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "ModularSet(" - << "coeff=" << op->coeff << ", base=" - << op->base << ')'; - }); - -ModularSet MakeModularSet(int64_t coeff, int64_t base) { - return ModularSet(coeff, base); -} + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "ModularSet(" + << "coeff=" << op->coeff << ", base=" << op->base << ')'; + }); + +ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } -TVM_REGISTER_GLOBAL("arith.ModularSet") -.set_body_typed(MakeModularSet); +TVM_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet); // internal entry for const int bound struct ModularSetAnalyzer::Entry { @@ -77,37 +75,27 @@ struct ModularSetAnalyzer::Entry { this->base = base; } - bool is_const() const { - return coeff == 0; - } + bool is_const() const { return coeff == 0; } - bool operator==(const Entry& other) const { - return coeff == other.coeff && base == other.base; - } + bool operator==(const Entry& other) const { return coeff == other.coeff && base == other.base; } bool operator==(const ModularSet& other) const { - return other.defined() && - coeff == other->coeff && base == other->base; + return other.defined() && coeff == other->coeff && base == other->base; } }; -class ModularSetAnalyzer::Impl : - public ExprFunctor { +class ModularSetAnalyzer::Impl : public ExprFunctor { public: - explicit Impl(Analyzer* parent) - : parent_(parent) {} + explicit Impl(Analyzer* parent) : parent_(parent) {} - void Update(const Var& var, - const ModularSet& info, - bool override) { + void Update(const Var& var, const ModularSet& info, bool override) { if (!override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(it->second == info) - << "Trying to update var \'" << var << "\'" - << " with a different const bound: " - << "original=" << ModularSet(it->second.coeff, it->second.base) - << ", new=" << info; + CHECK(it->second == info) << "Trying to update var \'" << var << "\'" + << " with a different const bound: " + << "original=" << ModularSet(it->second.coeff, it->second.base) + << ", new=" << info; } } var_map_[var] = Entry(info->coeff, info->base); @@ -127,17 +115,11 @@ class ModularSetAnalyzer::Impl : } // Override visitor behaviors - Entry VisitExprDefault_(const Object* op) final { - return Everything(); - } + Entry VisitExprDefault_(const Object* op) final { return Everything(); } - Entry VisitExpr_(const CastNode* op) final { - return VisitExpr(op->value); - } + Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } - Entry VisitExpr_(const IntImmNode* op) final { - return Entry(0, op->value); - } + Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); } Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); @@ -167,9 +149,7 @@ class ModularSetAnalyzer::Impl : return Entry(coeff, a.base * b.base); } - Entry DivByConst(const PrimExpr& lhs, - int64_t val, - bool round_down) { + Entry DivByConst(const PrimExpr& lhs, int64_t val, bool round_down) { Entry a = VisitExpr(lhs); CHECK_NE(val, 0); if (a.coeff % val == 0) { @@ -179,8 +159,7 @@ class ModularSetAnalyzer::Impl : } // positive division have a clear rounding mode. // Only handle case where we clearly know we need to round down. - if (a.base > 0 && val > 0 && - (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { + if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { return Entry(a.coeff / val, a.base / val); } } @@ -269,9 +248,7 @@ class ModularSetAnalyzer::Impl : } var_map_[var] = Intersect(old, entry); // reover function. - return [this, old, var]() { - var_map_[var] = old; - }; + return [this, old, var]() { var_map_[var] = old; }; } /*! * \brief Create union of two sets. @@ -385,16 +362,12 @@ class ModularSetAnalyzer::Impl : * \brief return everything dtype can represent. * \return Bound that represent everything dtype can represent. */ - static Entry Everything() { - return Entry(1, 0); - } + static Entry Everything() { return Entry(1, 0); } /*! * \brief return an empty set * \return Bound that represent everything dtype can represent. */ - static Entry Nothing() { - return Entry(0, 1); - } + static Entry Nothing() { return Entry(0, 1); } }; ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { @@ -402,9 +375,7 @@ ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { return ModularSet(ret.coeff, ret.base); } -void ModularSetAnalyzer::Update(const Var& var, - const ModularSet& info, - bool override) { +void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool override) { impl_->Update(var, info, override); } @@ -412,13 +383,9 @@ std::function ModularSetAnalyzer::EnterConstraint(const PrimExpr& constr return impl_->EnterConstraint(constraint); } -ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) - : impl_(new Impl(parent)) { -} +ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} -ModularSetAnalyzer::~ModularSetAnalyzer() { - delete impl_; -} +ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 14cfbd6..2a02303 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -67,7 +67,9 @@ #include #include + #include + #include "const_fold.h" namespace tvm { @@ -84,7 +86,7 @@ namespace arith { * * \tparam Derived The type of the derived class. */ -template +template class Pattern { public: /*! @@ -108,30 +110,26 @@ class Pattern { * * \return whether value matches the pattern. */ - template + template bool Match(const NodeType& value) const { derived().InitMatch_(); return derived().Match_(value); } /*! \return Derived instance of current class. */ - const Derived& derived() const { - return *static_cast(this); - } + const Derived& derived() const { return *static_cast(this); } }; /*! * \brief Default deep equality checker * \tparam T the comparison point. */ -template +template class PEqualChecker { public: - bool operator()(const T& lhs, const T& rhs) const { - return lhs == rhs; - } + bool operator()(const T& lhs, const T& rhs) const { return lhs == rhs; } }; -template<> +template <> class PEqualChecker { public: bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { @@ -140,20 +138,16 @@ class PEqualChecker { } }; -template<> +template <> class PEqualChecker { public: - bool operator()(const IntImm& lhs, const IntImm& rhs) const { - return lhs->value == rhs->value; - } + bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } }; -template<> +template <> class PEqualChecker { public: - bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { - return lhs.same_as(rhs); - } + bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { return lhs.same_as(rhs); } }; /*! @@ -166,15 +160,13 @@ class PEqualChecker { * \note PVar is not thread safe. * Do not use the same PVar in multiple threads. */ -template -class PVar : public Pattern > { +template +class PVar : public Pattern> { public: // Store PVars by reference in the expression. using Nested = const PVar&; - void InitMatch_() const { - filled_ = false; - } + void InitMatch_() const { filled_ = false; } bool Match_(const T& value) const { if (!filled_) { @@ -186,9 +178,8 @@ class PVar : public Pattern > { } } - template::value>::type> + template ::value>::type> bool Match_(const NodeRefType& value) const { if (const auto* ptr = value.template as()) { return Match_(GetRef(ptr)); @@ -214,21 +205,17 @@ class PVar : public Pattern > { * * \tparam T the type of the hole. */ -template -class PConst : public Pattern > { +template +class PConst : public Pattern> { public: PConst(T value) // NOLINT(*) : value_(value) {} void InitMatch_() const {} - bool Match_(const T& value) const { - return PEqualChecker()(value_, value); - } + bool Match_(const T& value) const { return PEqualChecker()(value_, value); } - T Eval() const { - return value_; - } + T Eval() const { return value_; } private: const T value_; @@ -240,9 +227,8 @@ class PConst : public Pattern > { * \tparam TA The pattern type of the first operand. * \tparam TB The pattern type of the second operand. */ -template -class PBinaryExpr : - public Pattern > { +template +class PBinaryExpr : public Pattern> { public: PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {} @@ -274,12 +260,10 @@ class PBinaryExpr : typename TB::Nested b_; }; -template -class PConstWithTypeLike : - public Pattern > { +template +class PConstWithTypeLike : public Pattern> { public: - PConstWithTypeLike(const TA& ref, int64_t value) - : ref_(ref), value_(value) {} + PConstWithTypeLike(const TA& ref, int64_t value) : ref_(ref), value_(value) {} void InitMatch_() const {} @@ -291,39 +275,33 @@ class PConstWithTypeLike : } } - PrimExpr Eval() const { - return tir::make_const(ref_.Eval().dtype(), value_); - } + PrimExpr Eval() const { return tir::make_const(ref_.Eval().dtype(), value_); } private: typename TA::Nested ref_; int64_t value_; }; - -#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \ - template \ - inline PBinaryExpr \ - FuncName(const Pattern& a, const Pattern& b) { \ - CheckStep; \ - return PBinaryExpr(a.derived(), b.derived()); \ - } \ - template \ - inline PBinaryExpr > \ - FuncName(const Pattern& a, int64_t b) { \ - CheckStep; \ - return FuncName(a, PConstWithTypeLike(a.derived(), b)); \ - } \ - template \ - inline PBinaryExpr, TA> \ - FuncName(int64_t b, const Pattern& a) { \ - CheckStep; \ - return FuncName(PConstWithTypeLike(a.derived(), b), a); \ - } - -#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \ - TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) - +#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \ + template \ + inline PBinaryExpr FuncName(const Pattern& a, const Pattern& b) { \ + CheckStep; \ + return PBinaryExpr(a.derived(), b.derived()); \ + } \ + template \ + inline PBinaryExpr> FuncName(const Pattern& a, \ + int64_t b) { \ + CheckStep; \ + return FuncName(a, PConstWithTypeLike(a.derived(), b)); \ + } \ + template \ + inline PBinaryExpr, TA> FuncName(int64_t b, \ + const Pattern& a) { \ + CheckStep; \ + return FuncName(PConstWithTypeLike(a.derived(), b), a); \ + } + +#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) // raise ambiguity error for operator overload of / and % TVM_PATTERN_BINARY_OP_EX(operator/, tir::DivNode, DivAmbiguityError(a)); @@ -355,15 +333,12 @@ TVM_PATTERN_BINARY_OP(operator||, tir::OrNode); * \brief Pattern not expression. * \tparam TA The pattern type of the true operand. */ -template -class PNotExpr : public Pattern > { +template +class PNotExpr : public Pattern> { public: - explicit PNotExpr(const TA& value) - : value_(value) {} + explicit PNotExpr(const TA& value) : value_(value) {} - void InitMatch_() const { - value_.InitMatch_(); - } + void InitMatch_() const { value_.InitMatch_(); } bool Match_(const ObjectRef& node) const { if (const tir::NotNode* ptr = node.as()) { @@ -374,15 +349,13 @@ class PNotExpr : public Pattern > { } } - PrimExpr Eval() const { - return tir::NotNode::make(value_.Eval()); - } + PrimExpr Eval() const { return tir::NotNode::make(value_.Eval()); } private: typename TA::Nested value_; }; -template +template inline PNotExpr operator!(const Pattern& value) { return PNotExpr(value.derived()); } @@ -394,16 +367,11 @@ inline PNotExpr operator!(const Pattern& value) { * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ -template -class PSelectExpr : - public Pattern > { +template +class PSelectExpr : public Pattern> { public: - PSelectExpr(const TCond& condition, - const TA& true_value, - const TB& false_value) - : condition_(condition), - true_value_(true_value), - false_value_(false_value) {} + PSelectExpr(const TCond& condition, const TA& true_value, const TB& false_value) + : condition_(condition), true_value_(true_value), false_value_(false_value) {} void InitMatch_() const { condition_.InitMatch_(); @@ -423,8 +391,7 @@ class PSelectExpr : } PrimExpr Eval() const { - return tir::SelectNode::make( - condition_.Eval(), true_value_.Eval(), false_value_.Eval()); + return tir::SelectNode::make(condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } private: @@ -446,13 +413,12 @@ class PSelectExpr : * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ -template -inline PSelectExpr -select(const Pattern& condition, - const Pattern& true_value, - const Pattern& false_value) { - return PSelectExpr( - condition.derived(), true_value.derived(), false_value.derived()); +template +inline PSelectExpr select(const Pattern& condition, + const Pattern& true_value, + const Pattern& false_value) { + return PSelectExpr(condition.derived(), true_value.derived(), + false_value.derived()); } /*! @@ -460,13 +426,10 @@ select(const Pattern& condition, * \tparam DType The Pattern type of dtype. * \tparam TA The pattern type of the first operand. */ -template -class PCastExpr : - public Pattern > { +template +class PCastExpr : public Pattern> { public: - PCastExpr(const DType& dtype, const TA& value) - : dtype_(dtype), value_(value) { - } + PCastExpr(const DType& dtype, const TA& value) : dtype_(dtype), value_(value) {} void InitMatch_() const { dtype_.InitMatch_(); @@ -483,9 +446,7 @@ class PCastExpr : } } - PrimExpr Eval() const { - return tir::CastNode::make(dtype_.Eval(), value_.Eval()); - } + PrimExpr Eval() const { return tir::CastNode::make(dtype_.Eval(), value_.Eval()); } private: typename DType::Nested dtype_; @@ -503,9 +464,8 @@ class PCastExpr : * \tparam DType The pattern type of type. * \tparam TA The pattern type of value. */ -template -inline PCastExpr -cast(const Pattern& dtype, const Pattern& value) { +template +inline PCastExpr cast(const Pattern& dtype, const Pattern& value) { return PCastExpr(dtype.derived(), value.derived()); } @@ -515,15 +475,11 @@ cast(const Pattern& dtype, const Pattern& value) { * \tparam TStride The pattern type of the stride. * \tparam TLanes The pattern type of the lanes. */ -template -class PRampExpr : - public Pattern > { +template +class PRampExpr : public Pattern> { public: - PRampExpr(const TBase& base, - const TStride& stride, - const TLanes& lanes) - : base_(base), stride_(stride), lanes_(lanes) { - } + PRampExpr(const TBase& base, const TStride& stride, const TLanes& lanes) + : base_(base), stride_(stride), lanes_(lanes) {} void InitMatch_() const { base_.InitMatch_(); @@ -542,9 +498,7 @@ class PRampExpr : } } - PrimExpr Eval() const { - return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); - } + PrimExpr Eval() const { return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); } private: typename TBase::Nested base_; @@ -565,24 +519,18 @@ class PRampExpr : * \tparam TStride The pattern type of the stride. * \tparam TLanes The pattern type of the lanes. */ -template -inline PRampExpr -ramp(const Pattern& base, - const Pattern& stride, - const Pattern& lanes) { - return PRampExpr( - base.derived(), stride.derived(), lanes.derived()); +template +inline PRampExpr ramp(const Pattern& base, + const Pattern& stride, + const Pattern& lanes) { + return PRampExpr(base.derived(), stride.derived(), lanes.derived()); } -template -inline PRampExpr, PConst> -ramp(const Pattern& base, - int stride, - int lanes) { +template +inline PRampExpr, PConst> ramp(const Pattern& base, + int stride, int lanes) { return PRampExpr, PConst>( - base.derived(), - PConstWithTypeLike(base.derived(), stride), - PConst(lanes)); + base.derived(), PConstWithTypeLike(base.derived(), stride), PConst(lanes)); } /*! @@ -590,14 +538,10 @@ ramp(const Pattern& base, * \tparam TA The pattern type of the value. * \tparam TLanes The pattern type of the lanes. */ -template -class PBroadcastExpr : - public Pattern > { +template +class PBroadcastExpr : public Pattern> { public: - PBroadcastExpr(const TA& value, - const TLanes& lanes) - : value_(value), lanes_(lanes) { - } + PBroadcastExpr(const TA& value, const TLanes& lanes) : value_(value), lanes_(lanes) {} void InitMatch_() const { value_.InitMatch_(); @@ -614,9 +558,7 @@ class PBroadcastExpr : } } - PrimExpr Eval() const { - return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); - } + PrimExpr Eval() const { return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); } private: typename TA::Nested value_; @@ -634,40 +576,37 @@ class PBroadcastExpr : * \tparam TA The pattern type of the value. * \tparam TLanes The pattern type of the lanes. */ -template -inline PBroadcastExpr -broadcast(const Pattern& value, const Pattern& lanes) { +template +inline PBroadcastExpr broadcast(const Pattern& value, + const Pattern& lanes) { return PBroadcastExpr(value.derived(), lanes.derived()); } // internal namespace namespace detail { // implementation details for CallExpr -template +template struct tuple_for_each_dispatcher { - template - static void run(F& f, const TTuple& tuple) { // NOLINT(*) + template + static void run(F& f, const TTuple& tuple) { // NOLINT(*) f(I, std::get(tuple)); - tuple_for_each_dispatcher< - (I + 1) == std::tuple_size::value, (I + 1), F> - ::run(f, tuple); + tuple_for_each_dispatcher<(I + 1) == std::tuple_size::value, (I + 1), F>::run(f, tuple); } }; -template +template struct tuple_for_each_dispatcher { - template - static void run(F& f, const TTuple& tuple) {} // NOLINT(*) + template + static void run(F& f, const TTuple& tuple) {} // NOLINT(*) }; -template +template inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*) - tuple_for_each_dispatcher::value == 0, 0, F> - ::run(f, tuple); + tuple_for_each_dispatcher::value == 0, 0, F>::run(f, tuple); } struct PCallExprInitMatchFunctor { - template + template void operator()(size_t i, const T& pattern) const { pattern.InitMatch_(); } @@ -677,10 +616,9 @@ struct PCallExprMatchFunctor { const tir::CallNode* call_; bool matched_{true}; - explicit PCallExprMatchFunctor(const tir::CallNode* call) - : call_(call) {} + explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {} - template + template void operator()(size_t i, const T& pattern) { matched_ = matched_ && pattern.Match_(call_->args[i]); } @@ -689,7 +627,7 @@ struct PCallExprMatchFunctor { struct PCallExprEvalArgsFunctor { Array args_; - template + template void operator()(size_t i, const T& pattern) { args_.push_back(pattern.Eval()); } @@ -703,13 +641,10 @@ struct PCallExprEvalArgsFunctor { * \note Op functor contains the name of the function and * the implementation of Eval. */ -template -class PCallExpr : - public Pattern > { +template +class PCallExpr : public Pattern> { public: - explicit PCallExpr(const TArgs&... args) - : args_(args...) { - } + explicit PCallExpr(const TArgs&... args) : args_(args...) {} void InitMatch_() const { detail::PCallExprInitMatchFunctor finit; @@ -739,18 +674,16 @@ class PCallExpr : }; // arithemetic intrinsics -#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::CallNode::make(args[0].dtype(), kName, args, \ - tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr \ - FuncName(const Pattern& a, const Pattern& b) { \ - return PCallExpr(a.derived(), b.derived()); \ +#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::CallNode::make(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a, const Pattern& b) { \ + return PCallExpr(a.derived(), b.derived()); \ } TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left"); @@ -760,18 +693,16 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or"); TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::CallNode::make(args[0].dtype(), kName, args, \ - tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr \ - FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::CallNode::make(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); @@ -779,9 +710,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(Array args) { - return tir::CallNode::make( - args[1].dtype(), kName, args, - tir::CallNode::PureIntrinsic); + return tir::CallNode::make(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic); } static constexpr const char* kName = "tvm_if_then_else"; }; @@ -799,13 +728,12 @@ struct PIfThenElseOp { * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ -template -inline PCallExpr -if_then_else(const Pattern& cond, - const Pattern& true_value, - const Pattern& false_value) { - return PCallExpr( - cond.derived(), true_value.derived(), false_value.derived()); +template +inline PCallExpr if_then_else(const Pattern& cond, + const Pattern& true_value, + const Pattern& false_value) { + return PCallExpr(cond.derived(), true_value.derived(), + false_value.derived()); } } // namespace arith diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 1263108..3b8ccfb 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -22,12 +22,15 @@ * \brief Rewrite-rule based simplification. */ // Acknowledgement: Most rewrite-rules are from Halide. +#include "rewrite_simplify.h" + #include #include + #include + #include "const_fold.h" #include "pattern_match.h" -#include "rewrite_simplify.h" namespace tvm { namespace arith { @@ -35,9 +38,9 @@ namespace arith { using namespace tir; // macro for doing simple rewrite -#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ - if ((SrcExpr).Match(ret)) { \ - return (ResExpr).Eval(); \ +#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ + if ((SrcExpr).Match(ret)) { \ + return (ResExpr).Eval(); \ } // macro for rewrite + recursively rewrite ResExpr @@ -47,15 +50,15 @@ using namespace tir; } // macro rewrite only if CondExor is true after match. -#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ - if ((SrcExpr).Match(ret) && (CondExpr)) { \ - return (ResExpr).Eval(); \ +#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return (ResExpr).Eval(); \ } // macro rewrite + recursive_rewrite only if CondExor is true after match. -#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ - if ((SrcExpr).Match(ret) && (CondExpr)) { \ - return RecursiveRewrite((ResExpr).Eval()); \ +#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return RecursiveRewrite((ResExpr).Eval()); \ } // NOTE for developers: @@ -66,8 +69,8 @@ using namespace tir; // // try to prove x equals val -RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl:: -TryCompare(const PrimExpr& x, int64_t val) { +RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, + int64_t val) { PrimExpr diff = this->VisitExpr(x); if (const auto* ptr = diff.as()) { if (ptr->value == val) { @@ -100,23 +103,19 @@ TryCompare(const PrimExpr& x, int64_t val) { return kUnknown; } -void RewriteSimplifier::Impl:: -Update(const Var& var, const PrimExpr& info, bool can_override) { +void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) { if (!can_override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(ExprDeepEqual()(it->second, info)) - << "Trying to update var \'" << var << "\'" - << " with a different value: " - << "original=" << it->second - << ", new=" << info; + CHECK(ExprDeepEqual()(it->second, info)) << "Trying to update var \'" << var << "\'" + << " with a different value: " + << "original=" << it->second << ", new=" << info; } } var_map_[var] = info; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const AddNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -129,14 +128,10 @@ VisitExpr_(const AddNode* op) { PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), - ramp(b1 + b2, s1 + s2, lanes)); - TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), - ramp(b1 + x, s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), - ramp(x + b1, s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), - broadcast(x + y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes)); } if (IsIndexType(op->dtype)) { @@ -167,14 +162,10 @@ VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y); TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y); - TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), - c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), - c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), - c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), - c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), c1.Eval()->value == -c2.Eval()->value); // constant folding // NOTE: canonicalization might better at this. @@ -213,8 +204,7 @@ VisitExpr_(const AddNode* op) { } // condition rules. - TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), - select(x, b1 + s1, b2 + s2)); + TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), select(x, b1 + s1, b2 + s2)); // default value return ret; } @@ -230,8 +220,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c return frecover; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const SubNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -244,14 +233,10 @@ VisitExpr_(const SubNode* op) { PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), - ramp(b1 - b2, s1 - s2, lanes)); - TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), - ramp(b1 - x, s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), - ramp(x - b1, 0 - s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), - broadcast(x - y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0 - s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes)); } if (IsIndexType(op->dtype)) { @@ -293,20 +278,20 @@ VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE((y + x) - (z + x), y - z); TVM_TRY_REWRITE((y + x) - (x + z), y - z); - TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x)); - TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x)); - TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y)); - TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y)); + TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x)); + TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x)); + TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y)); + TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y)); - TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x)); - TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x)); - TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y)); - TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y)); + TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x)); + TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x)); + TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y)); + TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y)); - TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); - TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); - TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); - TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y)); + TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); + TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); + TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); + TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y)); TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x)); @@ -324,10 +309,8 @@ VisitExpr_(const SubNode* op) { // DivMod rules // trucdiv // NOTE: c*(x/c) + x % c == x is true all division mode. - TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), - c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), - c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y, c1.Eval()->value != 0); TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1), @@ -337,45 +320,40 @@ VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); // Proof in the case of floordiv, need positive condition. // let x = a * c3 + r // (x + c1) / c3 - x / c3 => (r + c1) / c3 // NOTE: the use of floormod(c2, c3) was intentional to simplify the const. - TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3), + TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3), truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3), CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && - c1.Eval()->value >= c2.Eval()->value && - c3.Eval()->value > 0); - TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x, c3), - truncdiv(truncmod(x, c3) + c1, c3), - CanProveGreaterEqual(x.Eval(), 0) && - c1.Eval()->value >= 0 && - c3.Eval()->value > 0); + c1.Eval()->value >= c2.Eval()->value && c3.Eval()->value > 0); + TVM_TRY_REWRITE_IF( + truncdiv(x + c1, c3) - truncdiv(x, c3), truncdiv(truncmod(x, c3) + c1, c3), + CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value >= 0 && c3.Eval()->value > 0); // floordiv - TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), - c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), - c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(x - floordiv(x + y, c1) * c1, floormod(x + y, c1) - y, c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c1 - x, y - floormod(x + y, c1), @@ -385,30 +363,29 @@ VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3), floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3), c3.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), - floordiv(floormod(x, c3) + c1, c3), + TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), floordiv(floormod(x, c3) + c1, c3), c3.Eval()->value > 0); // canonicalization rule @@ -420,17 +397,13 @@ VisitExpr_(const SubNode* op) { } // condition rules. - TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), - select(x, b1 - s1, b2 - s2)); - TVM_TRY_REWRITE(select(x, y, z) - z, - select(x, y - z, ZeroWithTypeLike(z))); - TVM_TRY_REWRITE(select(x, y, z) - y, - select(x, ZeroWithTypeLike(y), z - y)); + TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), select(x, b1 - s1, b2 - s2)); + TVM_TRY_REWRITE(select(x, y, z) - z, select(x, y - z, ZeroWithTypeLike(z))); + TVM_TRY_REWRITE(select(x, y, z) - y, select(x, ZeroWithTypeLike(y), z - y)); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const MulNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -443,12 +416,9 @@ VisitExpr_(const MulNode* op) { PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), - broadcast(x * y, lanes)); - TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), - ramp(b1 * x, s1 * x, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), - ramp(b1 * x, s1 * x, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes)); } if (IsIndexType(op->dtype)) { @@ -461,15 +431,12 @@ VisitExpr_(const MulNode* op) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); - TVM_TRY_RECURSIVE_REWRITE_IF( - (x - y) * c1, (y - x) * (0 - c1), - c1.Eval()->value < 0); + TVM_TRY_RECURSIVE_REWRITE_IF((x - y) * c1, (y - x) * (0 - c1), c1.Eval()->value < 0); } return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const DivNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -490,8 +457,7 @@ VisitExpr_(const DivNode* op) { // Vector rules if (op->dtype.lanes() != 1) { // NOTE: use div as the pattern also works for float. - TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), - broadcast(div(x, y), lanes)); + TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes)); // ramp / bcast if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; @@ -532,10 +498,8 @@ VisitExpr_(const DivNode* op) { c1.Eval()->value > 0 && c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3), - c1.Eval()->value > 0 && - c2.Eval()->value >= 0 && - c3.Eval()->value > 0 && - CanProveGreaterEqual(x.Eval(), 0)); + c1.Eval()->value > 0 && c2.Eval()->value >= 0 && c3.Eval()->value > 0 && + CanProveGreaterEqual(x.Eval(), 0)); if (truncdiv(x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; @@ -551,147 +515,102 @@ VisitExpr_(const DivNode* op) { TVM_TRY_REWRITE(truncdiv(c1 * x, x), c1); // Rules involving 2-operands. - TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), - x * truncdiv(c1, c2) + truncdiv(y, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), - min(x * truncdiv(c1, c2), truncdiv(y, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), - max(x * truncdiv(c1, c2), truncdiv(y, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), - truncdiv(y, c2) + x * truncdiv(c1, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), - min(truncdiv(y, c2), x * truncdiv(c1, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), - max(truncdiv(y, c2), x * truncdiv(c1, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), x * truncdiv(c1, c2) + truncdiv(y, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), min(x * truncdiv(c1, c2), truncdiv(y, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), max(x * truncdiv(c1, c2), truncdiv(y, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), truncdiv(y, c2) + x * truncdiv(c1, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), min(truncdiv(y, c2), x * truncdiv(c1, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), max(truncdiv(y, c2), x * truncdiv(c1, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); // Rules involving 3-operands. - TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y + z, c2), - x * truncdiv(c1, c2) + truncdiv(y + z, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(x * c1 - y + z, c2), - x * truncdiv(c1, c2) + truncdiv(z - y, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((z - y).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y - z, c2), - x * truncdiv(c1, c2) + truncdiv(y - z, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y - z).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(y + x * c1 + z, c2), - x * truncdiv(c1, c2) + truncdiv(y + z, c2), - c1.Eval()->value > 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), - truncdiv(x, c2) + truncdiv(c1, c2), - c1.Eval()->value > 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv(x * c1 + y + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv(x * c1 - y + z, c2), x * truncdiv(c1, c2) + truncdiv(z - y, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((z - y).Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv(x * c1 + y - z, c2), x * truncdiv(c1, c2) + truncdiv(y - z, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y - z).Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv(y + x * c1 + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), truncdiv(x, c2) + truncdiv(c1, c2), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y + x, x), truncdiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv((x + y) + z, x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(truncdiv((y + x) + z, x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(truncdiv(y + (z + x), x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(truncdiv(y + (x + z), x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv((x + y) + z, x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv((y + x) + z, x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv(y + (z + x), x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv(y + (x + z), x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(x * y, y), x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y * x, y), x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(x * z + y, z), x + truncdiv(y, z), - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(z * x + y, z), x + truncdiv(y, z), - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y + x * z, z), truncdiv(y, z) + x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y + z * x, z), truncdiv(y, z) + x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); } return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const ModNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -725,8 +644,7 @@ VisitExpr_(const ModNode* op) { if (ramp_min == ramp_max) { return ramp(truncmod(bmod->base, c2), c1, lanes).Eval(); } else { - return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), - broadcast(c2, lanes)).Eval(); + return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } } @@ -738,41 +656,34 @@ VisitExpr_(const ModNode* op) { // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. TVM_TRY_REWRITE_IF(truncmod(x * c1, c2), ZeroWithTypeLike(x), - c2.Eval()->value != 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(truncmod(x * c1 + y, c2), truncmod(y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual((x * c1).Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual((x * c1).Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncmod(x + c1, c2), truncmod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value >= 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value >= 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(truncmod(x + y * c1, c2), truncmod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y * c1).Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y * c1).Eval(), 0)); // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required TVM_TRY_RECURSIVE_REWRITE_IF( - truncmod(x, c1), - truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), + truncmod(x, c1), truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), c1.Eval()->value < 0); // try modular analysis if (truncmod(x, c1).Match(ret)) { ModularSet mod = analyzer_->modular_set(x.Eval()); int64_t c1val = c1.Eval()->value; - if (mod->coeff % c1val == 0 && - c1val > 0 && - CanProveGreaterEqual(x.Eval(), 0)) { + if (mod->coeff % c1val == 0 && c1val > 0 && CanProveGreaterEqual(x.Eval(), 0)) { return truncmod(mod->base, c1).Eval(); } } @@ -780,8 +691,7 @@ VisitExpr_(const ModNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const FloorDivNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -836,67 +746,43 @@ VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE(floordiv(c1 * x, x), c1); // Rules involving 2-operands. - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), - x * floordiv(c1, c2) + floordiv(y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), - min(x * floordiv(c1, c2), floordiv(y, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), - max(x * floordiv(c1, c2), floordiv(y, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), - floordiv(y, c2) + x * floordiv(c1, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), - min(floordiv(y, c2), x * floordiv(c1, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), - max(floordiv(y, c2), x * floordiv(c1, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); // Rules involving 3-operands. - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), - x * floordiv(c1, c2) + floordiv(y + z, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), - x * floordiv(c1, c2) + floordiv(z - y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2), - x * floordiv(c1, c2) + floordiv(y - z, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), - x * floordiv(c1, c2) + floordiv(y + z, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), - floordiv(x, c2) + floordiv(c1, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), x * floordiv(c1, c2) + floordiv(z - y, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2), x * floordiv(c1, c2) + floordiv(y - z, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); + + TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv((x + y) + z, x), floordiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -907,10 +793,8 @@ VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(y + (x + z), x), floordiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); - TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, - CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, - CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(x * z + y, z), x + floordiv(y, z), CanProveGreaterEqual(z.Eval(), 0)); @@ -924,8 +808,7 @@ VisitExpr_(const FloorDivNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const FloorModNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -967,20 +850,16 @@ VisitExpr_(const FloorModNode* op) { if (IsIndexType(op->dtype)) { // Be-aware of the division rules: we use floordiv/floormod here TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x), - c2.Eval()->value != 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); // try modular analysis if (floormod(x, c1).Match(ret)) { @@ -994,8 +873,7 @@ VisitExpr_(const FloorModNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const MinNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1009,8 +887,7 @@ VisitExpr_(const MinNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), - broadcast(min(x, y), lanes)); + TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes)); TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)), min(x, broadcast(min(y, z), lanes))); } @@ -1035,8 +912,7 @@ VisitExpr_(const MinNode* op) { return (x + c2).Eval(); } } - if (min(x + c1, x).Match(ret) || - min(x, x + c1).Match(ret)) { + if (min(x + c1, x).Match(ret) || min(x, x + c1).Match(ret)) { if (c1.Eval()->value < 0) { return (x + c1).Eval(); } else { @@ -1055,40 +931,30 @@ VisitExpr_(const MinNode* op) { // Divide up rounding: truc div // NOTE: trucdiv(x, y) >= floordiv(x, y) TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, x), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value && - CanProveGreaterEqual(x.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value && - CanProveGreaterEqual(x.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); // Divide up rounding: floor div TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, x), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, max(x, c2)), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(x, floordiv(x + c1, c2) * c2), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(max(x, c2), floordiv(x + c1, c2) * c2), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, - c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, c2.Eval()->value > 0); TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y)); TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y)); @@ -1168,19 +1034,15 @@ VisitExpr_(const MinNode* op) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE_IF( - min(c1 - x, c2), c1 - max(x, c1 - c2), - c2.Eval()->value != 0); + TVM_TRY_RECURSIVE_REWRITE_IF(min(c1 - x, c2), c1 - max(x, c1 - c2), c2.Eval()->value != 0); } // condition rules. - TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), - select(x, min(y, s1), min(z, s2))); + TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), select(x, min(y, s1), min(z, s2))); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const MaxNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1194,8 +1056,7 @@ VisitExpr_(const MaxNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), - broadcast(max(x, y), lanes)); + TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes)); TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)), max(x, broadcast(max(y, z), lanes))); } @@ -1220,8 +1081,7 @@ VisitExpr_(const MaxNode* op) { return (x + c2).Eval(); } } - if (max(x + c1, x).Match(ret) || - max(x, x + c1).Match(ret)) { + if (max(x + c1, x).Match(ret) || max(x, x + c1).Match(ret)) { if (c1.Eval()->value > 0) { return (x + c1).Eval(); } else { @@ -1239,27 +1099,19 @@ VisitExpr_(const MaxNode* op) { // DivMod rules // Divide up rounding: truc div // NOTE: trucdiv(x, y) >= floordiv(x, y) - TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), - truncdiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), - truncdiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), truncdiv(x + c1, c2) * c2, + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), truncdiv(x + c1, c2) * c2, + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); // Divide up rounding: floor div TVM_TRY_REWRITE_IF(max(floordiv(x + c1, c2) * c2, x), floordiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(max(x, floordiv(x + c1, c2) * c2), floordiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x, - c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x, - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x, c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x, c2.Eval()->value > 0); TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y)); TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y)); @@ -1342,18 +1194,15 @@ VisitExpr_(const MaxNode* op) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE_IF( - max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0); + TVM_TRY_RECURSIVE_REWRITE_IF(max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0); } // condition rules. - TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), - select(x, max(y, s1), max(z, s2))); + TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), select(x, max(y, s1), max(z, s2))); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const EQNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1367,8 +1216,7 @@ VisitExpr_(const EQNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), - broadcast(x == y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes)); } if (IsIndexType(op->a.dtype())) { @@ -1386,28 +1234,23 @@ VisitExpr_(const EQNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const NENode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { return this->VisitExpr(NotNode::make(op->a == op->b)); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const LENode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { return this->VisitExpr(NotNode::make(op->b < op->a)); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const GTNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const GENode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) { return this->VisitExpr(NotNode::make(op->a < op->b)); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const LTNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1421,10 +1264,8 @@ VisitExpr_(const LTNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), - broadcast(x < y, lanes)); - TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), - broadcast(x < y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes)); + TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes)); } if (IsIndexType(op->a.dtype())) { @@ -1436,6 +1277,7 @@ VisitExpr_(const LTNode* op) { return make_const(op->dtype, false); } + // clang-format off TVM_TRY_REWRITE(x + y < x + z, y < z); TVM_TRY_REWRITE(x + y < z + x, y < z); TVM_TRY_REWRITE(y + x < x + z, y < z); @@ -1449,100 +1291,76 @@ VisitExpr_(const LTNode* op) { TVM_TRY_REWRITE(c1 < x + c2, c1 - c2 < x); TVM_TRY_REWRITE(c1 < c2 - x, x < c2 - c1); - TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, - c1.Eval()->value < 0); + TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, c1.Eval()->value < 0); // constant cancelation: only need to make use of one mod // truc div - TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1 - 1, c2) + 1, - c1.Eval()->value > 0 && - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(x * c2 < c1, + x < truncdiv(c1 - 1, c2) + 1, c1.Eval()->value > 0 && c2.Eval()->value > 0); // NOTE: trunc div required TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1, c2), - c1.Eval()->value <= 0 && - c2.Eval()->value > 0); + c1.Eval()->value <= 0 && c2.Eval()->value > 0); // NOTE: trunc div required (euclidean is ok too, floored is not) - TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, - c1.Eval()->value > 0 && + TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, c1.Eval()->value > 0 && c2.Eval()->value < 0); // NOTE: trunc div required (floored is ok too, euclidean is not) TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1, c2) < x, - c1.Eval()->value <= 0 && - c2.Eval()->value < 0); + c1.Eval()->value <= 0 && c2.Eval()->value < 0); // NOTE: trunc div required TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1 + 1, c2) - 1 < x, - c1.Eval()->value < 0 && - c2.Eval()->value > 0); + c1.Eval()->value < 0 && c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1, c2) < x, - c1.Eval()->value >= 0 && - c2.Eval()->value > 0); + c1.Eval()->value >= 0 && c2.Eval()->value > 0); // NOTE: trunc div required (floored is ok too, euclidean is not) TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1 + 1, c2) + 1, - c1.Eval()->value < 0 && - c2.Eval()->value < 0); + c1.Eval()->value < 0 && c2.Eval()->value < 0); // NOTE: trunc div required (euclidean is ok too, floored is not) TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value < 0); + c1.Eval()->value >= 0 && c2.Eval()->value < 0); // DivMod rules // trucdiv - TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * c2, - c1.Eval()->value > 0 && - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, + xvalue> 0 && c2.Eval()->value > 0); // NOTE: trunc div required - TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * (c2 - 1) + 1, - c1.Eval()->value > 0 && - c2.Eval()->value <= 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, + xvalue> 0 && c2.Eval()->value <= 0); TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), (c1 + 1) * c2 - 1 < x, - c1.Eval()->value >= 0 && - c2.Eval()->value > 0); + c1.Eval()->value >= 0 && c2.Eval()->value > 0); // NOTE: trunc div required TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), c1 * c2 < x, - c1.Eval()->value < 0 && - c2.Eval()->value > 0); + c1.Eval()->value < 0 && c2.Eval()->value > 0); // invariance for any div mod: x - (x / c1) * c1 == x % c1 - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, 0 < truncmod(x, c1) + y, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, y < truncmod(x, c1), - c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, + 0 < truncmod(x, c1) + y, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, + y < truncmod(x, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x, - c2 < truncmod(x + c2, c1), - c1.Eval()->value > 0); + c2 < truncmod(x + c2, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x + y, - c2 < truncmod(x + c2, c1) + y, - c1.Eval()->value > 0); + c2 < truncmod(x + c2, c1) + y, c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x - y, - y < truncmod(x + c2, c1) + (0 - c2), - c1.Eval()->value > 0); + y < truncmod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0); // floordiv - TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x, - c2.Eval()->value > 0); - - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y, 0 < floormod(x, c1) + y, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y, y < floormod(x, c1), - c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x, c2.Eval()->value > 0); + + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y, + 0 < floormod(x, c1) + y, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y, + y < floormod(x, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x, - c2 < floormod(x + c2, c1), - c1.Eval()->value > 0); + c2 < floormod(x + c2, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x + y, - c2 < floormod(x + c2, c1) + y, - c1.Eval()->value > 0); + c2 < floormod(x + c2, c1) + y, c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x - y, - y < floormod(x + c2, c1) + (0 - c2), - c1.Eval()->value > 0); + y < floormod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0); // canonicalization rule TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z); @@ -1558,12 +1376,12 @@ VisitExpr_(const LTNode* op) { TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1); TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1); TVM_TRY_REWRITE(x - c1 < 0, x < c1); + // clang-format on } return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const NotNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a); @@ -1587,8 +1405,7 @@ VisitExpr_(const NotNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const AndNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1601,8 +1418,7 @@ VisitExpr_(const AndNode* op) { PVar lanes; if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), - broadcast(x && y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); } auto cfalse = PConst(make_const(op->dtype, false)); @@ -1612,32 +1428,23 @@ VisitExpr_(const AndNode* op) { TVM_TRY_REWRITE(x <= y && y < x, cfalse); TVM_TRY_REWRITE(y < x && x <= y, cfalse); - TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, - c2.Eval()->value + 1 >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, - c2.Eval()->value + 1 >= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, - c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, - c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, - c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, - c2.Eval()->value >= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, - c2.Eval()->value > c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, - c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value); TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2); TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const OrNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1650,8 +1457,7 @@ VisitExpr_(const OrNode* op) { PVar lanes; if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), - broadcast(x || y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); } auto ctrue = PConst(make_const(op->dtype, true)); @@ -1662,32 +1468,23 @@ VisitExpr_(const OrNode* op) { TVM_TRY_REWRITE(x <= y || y < x, ctrue); TVM_TRY_REWRITE(y < x || x <= y, ctrue); - TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, - c2.Eval()->value < c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, - c2.Eval()->value < c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, - c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, - c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, - c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, - c2.Eval()->value <= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, - c2.Eval()->value <= c1.Eval()->value + 1); - TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, - c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2); TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const SelectNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SelectNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; @@ -1697,8 +1494,7 @@ VisitExpr_(const SelectNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const CallNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // add condition context to if_then_else PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); @@ -1728,8 +1524,7 @@ VisitExpr_(const CallNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const VarNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { Var var = GetRef(op); auto it = var_map_.find(var); if (it != var_map_.end()) { @@ -1738,15 +1533,13 @@ VisitExpr_(const VarNode* op) { return GetRef(op); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const CastNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); return cast(op->dtype, op->value); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const LetNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); if (!tir::HasSideEffect(value)) { // it is fine to discard the let binding @@ -1755,8 +1548,7 @@ VisitExpr_(const LetNode* op) { return this->VisitExpr(op->body); } PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); @@ -1775,9 +1567,7 @@ PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) { return res; } -void RewriteSimplifier::Update(const Var& var, - const PrimExpr& info, - bool override) { +void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool override) { impl_->Update(var, info, override); } @@ -1785,13 +1575,9 @@ std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constra return impl_->EnterConstraint(constraint); } -RewriteSimplifier::RewriteSimplifier(Analyzer* parent) - : impl_(new Impl(parent)) { -} +RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} -RewriteSimplifier::~RewriteSimplifier() { - delete impl_; -} +RewriteSimplifier::~RewriteSimplifier() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 8798df9..fd248b9 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -26,11 +26,13 @@ #include #include + #include #include + #include "const_fold.h" -#include "pattern_match.h" #include "ir_mutator_with_analyzer.h" +#include "pattern_match.h" namespace tvm { namespace arith { @@ -46,8 +48,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { public: using IRMutatorWithAnalyzer::VisitExpr_; - explicit Impl(Analyzer* parent) - : IRMutatorWithAnalyzer(parent) {} + explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {} void Update(const Var& var, const PrimExpr& info, bool override_info); PrimExpr VisitExpr_(const AddNode* op) override; @@ -78,15 +79,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { protected: /*! \brief internal structure for comparison. */ - enum CompareResult { - kUnknown, - kEQ, - kGT, - kGE, - kLT, - kLE, - kNE - }; + enum CompareResult { kUnknown, kEQ, kGT, kGE, kLT, kLE, kNE }; // counter to record recursive rewrite depth. int recur_depth_{0}; // internal variable map @@ -127,18 +120,17 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { return res; } - template + template PConstWithTypeLike ZeroWithTypeLike(const Pattern& pattern) { return PConstWithTypeLike(pattern.derived(), 0); } - template + template PConstWithTypeLike OneWithTypeLike(const Pattern& pattern) { return PConstWithTypeLike(pattern.derived(), 1); } }; - } // namespace arith } // namespace tvm #endif // TVM_ARITH_REWRITE_SIMPLIFY_H_ diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index a89cebe..50a3243 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -21,26 +21,23 @@ * \file tvm/arith/solve_linear_equation.cc * \brief Solve linear equations. */ -#include -#include #include #include -#include #include - +#include +#include +#include +#include #include #include -#include namespace tvm { namespace arith { using namespace tvm::runtime; -void SmithNormalFormDiag(std::vector >* S, - std::vector >* V, - std::vector* x, - std::vector* y) { +void SmithNormalFormDiag(std::vector>* S, std::vector>* V, + std::vector* x, std::vector* y) { if (S->empty() || V->empty()) return; size_t m = S->size(); size_t n = (*S)[0].size(); // n is # of variables @@ -124,9 +121,9 @@ void SmithNormalFormDiag(std::vector >* S, for (size_t j = index; j < (*S)[i].size(); ++j) { // Multiply index-th row by a and add the i-th row multiplied by b // This will make the index-th diagonal element equal to the gcd - int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j]; + int64_t new_index_j = a * (*S)[index][j] + b * (*S)[i][j]; // This transformation performs zeroing of matrix[i][index] - int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j]; + int64_t new_i_j = n_g * (*S)[index][j] - m_g * (*S)[i][j]; (*S)[index][j] = new_index_j; (*S)[i][j] = new_i_j; } @@ -135,8 +132,8 @@ void SmithNormalFormDiag(std::vector >* S, PrimExpr eb = tir::make_const((*y)[i].dtype(), b); PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g); PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g); - PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i]; - PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i]; + PrimExpr new_index_rhs = ea * (*y)[index] + eb * (*y)[i]; + PrimExpr new_i_rhs = e_n_g * (*y)[index] - e_m_g * (*y)[i]; (*y)[index] = new_index_rhs; (*y)[i] = new_i_rhs; } @@ -178,15 +175,15 @@ void SmithNormalFormDiag(std::vector >* S, int64_t n_g = (*S)[index][j] / g; for (size_t i = index; i < m; ++i) { - int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j]; - int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j]; + int64_t new_i_index = a * (*S)[i][index] + b * (*S)[i][j]; + int64_t new_i_j = n_g * (*S)[i][index] - m_g * (*S)[i][j]; (*S)[i][index] = new_i_index; (*S)[i][j] = new_i_j; } // We do exactly the same transformations with V for (size_t i = 0; i < n; ++i) { - int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j]; - int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j]; + int64_t new_i_index = a * (*V)[i][index] + b * (*V)[i][j]; + int64_t new_i_j = n_g * (*V)[i][index] - m_g * (*V)[i][j]; (*V)[i][index] = new_i_index; (*V)[i][j] = new_i_j; } @@ -195,8 +192,8 @@ void SmithNormalFormDiag(std::vector >* S, PrimExpr eb = tir::make_const((*x)[index].dtype(), b); PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g); PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g); - PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j]; - PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j]; + PrimExpr new_index = e_m_g * (*x)[index] + e_n_g * (*x)[j]; + PrimExpr new_j = eb * (*x)[index] - ea * (*x)[j]; (*x)[index] = new_index; (*x)[j] = new_j; } @@ -210,8 +207,7 @@ void SmithNormalFormDiag(std::vector >* S, } } -Map InferRange(const Map& vars_to_infer, - const Array& ori_vars, +Map InferRange(const Map& vars_to_infer, const Array& ori_vars, const Map& ori_ranges) { // The resulting ranges Map new_ranges; @@ -245,8 +241,7 @@ Map InferRange(const Map& vars_to_infer, // pretty print matrix equation void DebugPrint(const std::vector>& S, - const std::vector>& V, - const std::vector& V_inv_x, + const std::vector>& V, const std::vector& V_inv_x, const std::vector& rhs) { std::cout << "S:\n"; for (size_t i = 0; i < S.size(); ++i) { @@ -267,7 +262,7 @@ void DebugPrint(const std::vector>& S, std::cout << "\n" << std::endl; } -IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) { +IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve) { // m: # of equations // n: # of variables // we first construct A_{mxn} x_{nx1} = y_{mx1} @@ -275,10 +270,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // S_{mxn} = U_{mxm} A_{mxn} V_{nxn} // => U^{-1} S V^{-1} x = y // S V^{-1} x = U y - std::vector Uy; // mx1 + std::vector Uy; // mx1 std::vector> S; // mxn std::vector> V; // nxn - std::vector V_inv_x; // V^{-1} x, nx1 + std::vector V_inv_x; // V^{-1} x, nx1 // Conditions we don't know what to do with std::vector rest; @@ -301,9 +296,8 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol for (const PrimExpr& equation : system_to_solve->relations) { if (const tir::EQNode* eq = equation.as()) { // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] - Array coeffs = arith::DetectLinearEquation( - analyzer_problem.Simplify(eq->a - eq->b), - system_to_solve->variables); + Array coeffs = arith::DetectLinearEquation(analyzer_problem.Simplify(eq->a - eq->b), + system_to_solve->variables); if (!coeffs.empty()) { std::vector row; for (size_t j = 0; j < coeffs.size() - 1; ++j) { @@ -365,13 +359,12 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol new_relation = analyzer_problem.Simplify(new_relation); if (tir::is_const_int(new_relation, 0)) { // unable to solve the system. - return IntConstraintsTransform( - system_to_solve, - IntConstraints( - /*variables=*/{}, - /*ranges=*/{}, - /*relations=*/{tir::make_zero(DataType::Bool())}), - {}, {}); + return IntConstraintsTransform(system_to_solve, + IntConstraints( + /*variables=*/{}, + /*ranges=*/{}, + /*relations=*/{tir::make_zero(DataType::Bool())}), + {}, {}); } else if (!tir::is_const_int(new_relation, 1)) { new_relations.push_back(new_relation); } @@ -405,14 +398,12 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // S^{-1}_{nxm} Uy_{mxn} if (S[j][j] >= 0) { PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]); - solution_for_V_inv_x.push_back( - analyzer_problem.Simplify(floordiv(Uy[j], a))); + solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]); - solution_for_V_inv_x.push_back( - analyzer_problem.Simplify(floordiv(-Uy[j], a))); + solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(-Uy[j], a))); } } } @@ -421,15 +412,15 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol for (size_t i = 0; i < num_vars; ++i) { PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype()); for (size_t j = 0; j < num_vars; ++j) { - e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j]; + e = e + tir::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; } e = analyzer_problem.Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); } // The resulting ranges - Map new_ranges = InferRange( - new_to_old_map, system_to_solve->variables, system_to_solve->ranges); + Map new_ranges = + InferRange(new_to_old_map, system_to_solve->variables, system_to_solve->ranges); Analyzer analyzer_solution; analyzer_solution.Bind(new_ranges); @@ -440,10 +431,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol const Range& old_range = p.second; if (old_to_new_map.count(old_var)) { PrimExpr express_by_new_vars = old_to_new_map[old_var]; - PrimExpr lower_cond = analyzer_solution.Simplify( - old_range->min <= express_by_new_vars); - PrimExpr upper_cond = analyzer_solution.Simplify( - express_by_new_vars < old_range->min + old_range->extent); + PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); + PrimExpr upper_cond = + analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); if (!tir::is_const_int(lower_cond, 1)) { new_relations.push_back(lower_cond); } @@ -459,23 +449,21 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol } IntConstraints solution(new_vars, new_ranges, new_relations); - IntConstraintsTransform transform( - system_to_solve, solution, old_to_new_map, new_to_old_map); + IntConstraintsTransform transform(system_to_solve, solution, old_to_new_map, new_to_old_map); return transform; } -TVM_REGISTER_GLOBAL("arith.SolveLinearEquations") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() == 1) { - *ret = SolveLinearEquations(args[0]); - } else if (args.size() == 3) { - IntConstraints problem(args[0], args[1], args[2]); - *ret = SolveLinearEquations(problem); - } else { - LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); - } - }); +TVM_REGISTER_GLOBAL("arith.SolveLinearEquations").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = SolveLinearEquations(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = SolveLinearEquations(problem); + } else { + LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); + } +}); } // namespace arith } // namespace tvm diff --git a/src/arith/util.cc b/src/arith/util.cc index 058c3e9..7b71892 100644 --- a/src/arith/util.cc +++ b/src/arith/util.cc @@ -21,8 +21,8 @@ * \file util.cc * \brief The utils for arithmetic analysis. */ -#include #include +#include namespace tvm { namespace arith { @@ -44,7 +44,7 @@ std::tuple xgcd(int64_t a, int64_t b) { CHECK_EQ(a % old_r, 0); CHECK_EQ(b % old_r, 0); - CHECK(old_r == old_s*a + old_t*b); + CHECK(old_r == old_s * a + old_t * b); return std::make_tuple(old_r, old_s, old_t); } diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index da044ba..54fc252 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -30,10 +30,9 @@ namespace autotvm { // for loop void FeatureVisitor::VisitStmt_(const ForNode* op) { - const auto *extent = op->extent.as(); + const auto* extent = op->extent.as(); int64_t loop_extent = -1; - if (extent != nullptr) - loop_extent = extent->value; + if (extent != nullptr) loop_extent = extent->value; AnnotationType ann = kSerial; switch (op->for_type) { case ForType ::Parallel: @@ -58,10 +57,9 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) { // parallel axis, virtual thread void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { Var var = op->node.as()->var; - const auto *extent = op->value.as(); + const auto* extent = op->value.as(); CHECK(extent); std::string name = var.get()->name_hint; diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h index 5391bdd..8180839 100644 --- a/src/autotvm/feature_visitor.h +++ b/src/autotvm/feature_visitor.h @@ -29,6 +29,7 @@ #include #include #include + #include namespace tvm { @@ -40,8 +41,17 @@ using namespace tvm::tir; * \brief Type of for loop, used as one-hot encoding in features */ enum AnnotationType { - kBlockX, kBlockY, kBlockZ, kThreadX, kThreadY, kThreadZ, - kUnrolled, kVectorized, kParallel, kSerial, kVirtualThread, + kBlockX, + kBlockY, + kBlockZ, + kThreadX, + kThreadY, + kThreadZ, + kUnrolled, + kVectorized, + kParallel, + kSerial, + kVirtualThread, kNum, }; @@ -59,17 +69,17 @@ class FeatureVisitor : public StmtExprVisitor { void VisitExpr_(const LoadNode* op) final; void VisitStmt_(const StoreNode* op) final; - using StmtExprVisitor::VisitStmt_; using StmtExprVisitor::VisitExpr_; + using StmtExprVisitor::VisitStmt_; protected: /*! - * \brief Enter a for loop node - * \param var The expression to be printed. - * \param length The output stream - * \param ann_type The type for the for loop - * \return skip Whether skip this node - */ + * \brief Enter a for loop node + * \param var The expression to be printed. + * \param length The output stream + * \param ann_type The type for the for loop + * \return skip Whether skip this node + */ virtual bool EnterItervar_(tir::Var var, int64_t length, AnnotationType ann_type) = 0; /*! \brief Exit a for loop subtree */ virtual void ExitItervar_() = 0; diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index fbd0829..02dae64 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -24,9 +24,9 @@ #include "touch_extractor.h" -#include #include #include +#include #include namespace tvm { @@ -34,9 +34,14 @@ namespace autotvm { int ParallelLevel(AnnotationType ann) { switch (ann) { - case kBlockX: case kBlockY: case kBlockZ: + case kBlockX: + case kBlockY: + case kBlockZ: return 2; - case kThreadX: case kThreadY: case kThreadZ: case kParallel: + case kThreadX: + case kThreadY: + case kThreadZ: + case kParallel: return 1; default: return 0; @@ -44,7 +49,7 @@ int ParallelLevel(AnnotationType ann) { } // get touch pattern from index expression -class IndexParser: public ExprVisitor { +class IndexParser : public ExprVisitor { public: void Parse(PrimExpr expr) { pattern_map.clear(); @@ -95,11 +100,9 @@ bool TouchExtractor::EnterItervar_(Var var, int64_t length, AnnotationType ann_t itervar_map.erase(var); } - itervar_map.insert({var, ItervarFeature(var, length, - static_cast(itervar_stack_.size()), - ann_type, - topdown_product_, - static_cast(itervar_counter_++))}); + itervar_map.insert( + {var, ItervarFeature(var, length, static_cast(itervar_stack_.size()), ann_type, + topdown_product_, static_cast(itervar_counter_++))}); } return true; @@ -120,7 +123,7 @@ void TouchExtractor::ExitItervar_() { CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end()); touch_pattern->second.count *= itervar_map[var].length; } - } else { // multiply reuse ratio + } else { // multiply reuse ratio for (auto stack_var : itervar_stack_) { auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first); CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end()); @@ -131,8 +134,7 @@ void TouchExtractor::ExitItervar_() { itervar_stack_.pop_back(); int64_t length = itervar_map[var].length; - if (length != 0) - topdown_product_ /= length; + if (length != 0) topdown_product_ /= length; int64_t bottomup_product = -1; for (auto kv : itervar_map[var].touch_feature) { bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse); @@ -188,8 +190,7 @@ void TouchExtractor::EnterMem_(Var buffer_var, PrimExpr index) { } } -void TouchExtractor::ExitMem_() { -} +void TouchExtractor::ExitMem_() {} /*! * \brief Get axis-based feature for all axes @@ -219,7 +220,7 @@ void TouchExtractor::ExitMem_() { * \note If you want to flatten these features as the input of your model, * You can use the faster one GetItervarFeatureFlatten below. */ -void GetItervarFeature(Stmt stmt, bool take_log, Array > > *ret_feature) { +void GetItervarFeature(Stmt stmt, bool take_log, Array > >* ret_feature) { // extract TouchExtractor touch_analyzer; touch_analyzer.Analyze(stmt); @@ -229,7 +230,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > for (auto kv : touch_analyzer.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; }); @@ -237,28 +238,26 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > std::function trans; if (take_log) { trans = [](int64_t x) { - if (x < 0) - return -std::log(-x+1) / std::log(2); + if (x < 0) return -std::log(-x + 1) / std::log(2); x = x + 1; return std::log(x) / std::log(2); }; } else { - trans = [](int64_t x) { - return x; - }; + trans = [](int64_t x) { return x; }; } // serialize for front end for (auto var : vars) { Array > feature_row; - ItervarFeature &fea = touch_analyzer.itervar_map[var]; + ItervarFeature& fea = touch_analyzer.itervar_map[var]; feature_row.push_back(Array{tvm::tir::StringImmNode::make("_itervar_"), var}); - Array attr{tvm::tir::StringImmNode::make("_attr_"), - FloatImm(DataType::Float(32), trans(fea.length)), - IntImm(DataType::Int(32), fea.nest_level), - FloatImm(DataType::Float(32), trans(fea.topdown_product)), - FloatImm(DataType::Float(32), trans(fea.bottomup_product)), + Array attr{ + tvm::tir::StringImmNode::make("_attr_"), + FloatImm(DataType::Float(32), trans(fea.length)), + IntImm(DataType::Int(32), fea.nest_level), + FloatImm(DataType::Float(32), trans(fea.topdown_product)), + FloatImm(DataType::Float(32), trans(fea.bottomup_product)), }; // one hot annotation for (int i = 0; i < kNum; i++) { @@ -267,10 +266,11 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > feature_row.push_back(attr); // arithmetic - feature_row.push_back(Array{tvm::tir::StringImmNode::make("_arith_"), - FloatImm(DataType::Float(32), trans(fea.add_ct)), - FloatImm(DataType::Float(32), trans(fea.mul_ct)), - FloatImm(DataType::Float(32), trans(fea.div_ct)), + feature_row.push_back(Array{ + tvm::tir::StringImmNode::make("_arith_"), + FloatImm(DataType::Float(32), trans(fea.add_ct)), + FloatImm(DataType::Float(32), trans(fea.mul_ct)), + FloatImm(DataType::Float(32), trans(fea.div_ct)), }); // touch map @@ -280,16 +280,16 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > } std::sort(bufs.begin(), bufs.end()); for (auto k : bufs) { - TouchPattern &v = fea.touch_feature[k]; - feature_row.push_back( - Array{tvm::tir::StringImmNode::make(k), - FloatImm(DataType::Float(32), trans(v.stride)), - FloatImm(DataType::Float(32), trans(v.mod)), - FloatImm(DataType::Float(32), trans(v.count)), - FloatImm(DataType::Float(32), trans(v.reuse)), - FloatImm(DataType::Float(32), trans(v.thread_count)), - FloatImm(DataType::Float(32), trans(v.thread_reuse)), - }); + TouchPattern& v = fea.touch_feature[k]; + feature_row.push_back(Array{ + tvm::tir::StringImmNode::make(k), + FloatImm(DataType::Float(32), trans(v.stride)), + FloatImm(DataType::Float(32), trans(v.mod)), + FloatImm(DataType::Float(32), trans(v.count)), + FloatImm(DataType::Float(32), trans(v.reuse)), + FloatImm(DataType::Float(32), trans(v.thread_count)), + FloatImm(DataType::Float(32), trans(v.thread_reuse)), + }); } ret_feature->push_back(feature_row); @@ -305,7 +305,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > * \note See GetItervarFeature for more details about the return value. * This is an optimized version of GetItervarFeature + Flatten. This runs much faster. */ -void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_feature) { +void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector* ret_feature) { // extract touch feature TouchExtractor touch_analyzer; touch_analyzer.Analyze(stmt); @@ -315,7 +315,7 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ for (auto kv : touch_analyzer.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; }); @@ -323,20 +323,17 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ std::function trans; if (take_log) { trans = [](int64_t x) { - if (x < 0) - return -std::log(-x+1) / std::log(2); + if (x < 0) return -std::log(-x + 1) / std::log(2); x = x + 1; return std::log(x) / std::log(2); }; } else { - trans = [](int64_t x) { - return x; - }; + trans = [](int64_t x) { return x; }; } // serialize for front end for (auto var : vars) { - ItervarFeature &fea = touch_analyzer.itervar_map[var]; + ItervarFeature& fea = touch_analyzer.itervar_map[var]; ret_feature->push_back(trans(fea.length)); ret_feature->push_back(fea.nest_level); @@ -360,7 +357,7 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ } std::sort(bufs.begin(), bufs.end()); for (auto k : bufs) { - TouchPattern &v = fea.touch_feature[k]; + TouchPattern& v = fea.touch_feature[k]; ret_feature->push_back(trans(v.stride)); ret_feature->push_back(trans(v.mod)); ret_feature->push_back(trans(v.count)); @@ -372,12 +369,12 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ } /*! - * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional vector. - * \param stmt The statement to be extracted - * \param sample_n The number of points used for sampling a curve (along one dimension) - * \param ret_feature The buffer where the return value is stored + * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional + * vector. \param stmt The statement to be extracted \param sample_n The number of points used for + * sampling a curve (along one dimension) \param ret_feature The buffer where the return value is + * stored */ -void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *ret_feature) { +void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector* ret_feature) { // extract touch feature TouchExtractor touch_ext; touch_ext.Analyze(stmt); @@ -387,7 +384,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r for (auto kv : touch_ext.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order; }); @@ -401,14 +398,14 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // find maximum depth of loop nest for (auto var : vars) { - ItervarFeature &fea = touch_ext.itervar_map[var]; + ItervarFeature& fea = touch_ext.itervar_map[var]; max_depth = std::max(max_depth, fea.nest_level); } // mark inner most buffer for (auto iter = vars.rbegin(); iter != vars.rend(); iter++) { auto var = *iter; - ItervarFeature &fea = touch_ext.itervar_map[var]; + ItervarFeature& fea = touch_ext.itervar_map[var]; if (fea.nest_level == max_depth) { for (auto kv : fea.touch_feature) { // delete buffer no (e.g. 'A_0' -> 'A', 'A_1' -> 'A') @@ -416,8 +413,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // delete memory scope (e.g. 'A.local' -> 'A', 'A.shared' -> 'A') size_t pos = raw_name.find("."); - if (pos < kv.first.size()) - raw_name = raw_name.substr(0, pos); + if (pos < kv.first.size()) raw_name = raw_name.substr(0, pos); // If there are multiple innermost buffers that are derived from a same raw buffer // We only record the last occurrence (note the `iter` is in reverse order) @@ -441,7 +437,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // extract curves for (auto var : vars) { - ItervarFeature &fea = touch_ext.itervar_map[var]; + ItervarFeature& fea = touch_ext.itervar_map[var]; for (auto kv : fea.touch_feature) { if (innermost_buffers.find(kv.first) != innermost_buffers.end()) { reuse_curve[kv.first].emplace_back(std::log(kv.second.reuse) / std::log(2)); @@ -453,7 +449,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r } // sample relation in the curve - auto sample_curve = [&](const std::vector &x, const std::vector &y, + auto sample_curve = [&](const std::vector& x, const std::vector& y, double weight) { for (int i = 0; i < sample_n; i++) { double xx = i * weight; @@ -469,9 +465,9 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // serialize to frontend for (auto k : innermost_buffers) { - std::vector &count = count_curve[k]; - std::vector &reuse = reuse_curve[k]; - std::vector &top_down = topdown_curve[k]; + std::vector& count = count_curve[k]; + std::vector& reuse = reuse_curve[k]; + std::vector& top_down = topdown_curve[k]; std::sort(count.begin(), count.end()); std::sort(reuse.begin(), reuse.end()); @@ -484,49 +480,45 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r } } - // register API for front end TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeature") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Stmt stmt = args[0]; - bool take_log = args[1]; - Array > > ret_feature; + .set_body([](TVMArgs args, TVMRetValue* ret) { + Stmt stmt = args[0]; + bool take_log = args[1]; + Array > > ret_feature; - GetItervarFeature(stmt, take_log, &ret_feature); - - *ret = ret_feature; -}); + GetItervarFeature(stmt, take_log, &ret_feature); + *ret = ret_feature; + }); TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeatureFlatten") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Stmt stmt = args[0]; - bool take_log = args[1]; - std::vector ret_feature; + .set_body([](TVMArgs args, TVMRetValue* ret) { + Stmt stmt = args[0]; + bool take_log = args[1]; + std::vector ret_feature; - GetItervarFeatureFlatten(stmt, take_log, &ret_feature); - - TVMByteArray arr; - arr.size = sizeof(float) * ret_feature.size(); - arr.data = reinterpret_cast(ret_feature.data()); - *ret = arr; -}); + GetItervarFeatureFlatten(stmt, take_log, &ret_feature); + TVMByteArray arr; + arr.size = sizeof(float) * ret_feature.size(); + arr.data = reinterpret_cast(ret_feature.data()); + *ret = arr; + }); TVM_REGISTER_GLOBAL("autotvm.feature.GetCurveSampleFeatureFlatten") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Stmt stmt = args[0]; - int sample_n = args[1]; - std::vector ret_feature; + .set_body([](TVMArgs args, TVMRetValue* ret) { + Stmt stmt = args[0]; + int sample_n = args[1]; + std::vector ret_feature; - GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature); - - TVMByteArray arr; - arr.size = sizeof(float) * ret_feature.size(); - arr.data = reinterpret_cast(ret_feature.data()); - *ret = arr; -}); + GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature); + TVMByteArray arr; + arr.size = sizeof(float) * ret_feature.size(); + arr.data = reinterpret_cast(ret_feature.data()); + *ret = arr; + }); } // namespace autotvm } // namespace tvm diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 23fbc54..973efb3 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -25,16 +25,17 @@ #ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ #define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ +#include #include #include -#include -#include -#include +#include #include +#include #include -#include #include +#include + #include "feature_visitor.h" namespace tvm { @@ -55,11 +56,7 @@ struct TouchPattern { // all the feature of an iter var struct ItervarFeature { - ItervarFeature(Var var, - int64_t extent, - int nest, - AnnotationType ann_type, - int64_t topdown, + ItervarFeature(Var var, int64_t extent, int nest, AnnotationType ann_type, int64_t topdown, int counter) : length(extent), nest_level(nest), ann(ann_type), topdown_product(topdown), order(counter) {} ItervarFeature() {} @@ -67,9 +64,9 @@ struct ItervarFeature { // Axis Attributes int64_t length; int nest_level; - AnnotationType ann; // one-hot axis type - int64_t topdown_product; // accumulative product of axis length, in top-down order - int64_t bottomup_product; // accumulative product of axis length, in bottom-up order + AnnotationType ann; // one-hot axis type + int64_t topdown_product; // accumulative product of axis length, in top-down order + int64_t bottomup_product; // accumulative product of axis length, in bottom-up order // bottomup_product = reuse * count for any touched buffer int order; // used for soring axis @@ -86,38 +83,31 @@ struct ItervarFeature { // extract iter vars and their touch pattern from ir class TouchExtractor : public FeatureVisitor { public: - void Analyze(const Stmt& stmt) { - operator()(stmt); - } + void Analyze(const Stmt& stmt) { operator()(stmt); } // arithmetic stats void VisitExpr_(const AddNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].add_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const SubNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].add_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const MulNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].mul_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const DivNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].div_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const ModNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].div_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); } diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index bb97900..f61ad33 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -20,10 +20,12 @@ /*! * \file codegen_hybrid.cc */ +#include "codegen_hybrid.h" + #include -#include + #include -#include "codegen_hybrid.h" +#include namespace tvm { namespace contrib { @@ -34,7 +36,7 @@ using runtime::TVMRetValue; using namespace tir; std::string dot_to_underscore(std::string s) { - for (auto &ch : s) + for (auto& ch : s) if (ch == '.') ch = '_'; return s; } @@ -57,11 +59,9 @@ std::string CodeGenHybrid::GetUniqueName(std::string prefix) { return prefix; } -std::string CodeGenHybrid::Finish() { - return stream.str(); -} +std::string CodeGenHybrid::Finish() { return stream.str(); } -void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { +void CodeGenHybrid::PrintType(DataType t, std::ostream& os) { if (t.is_float()) { os << "float"; CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64); @@ -80,20 +80,19 @@ void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOL os << op->value; } -void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; } -void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) os << "'" << op->value << "'"; } -template -inline void PrintBinaryExpr(const T* op, - const char* opstr, +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { - CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; + CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; if (isalpha(opstr[0])) { os << opstr << '('; p->PrintExpr(op->a, os); @@ -111,11 +110,10 @@ inline void PrintBinaryExpr(const T* op, } } -inline void PrintBinaryIntrinsitc(const CallNode* op, - const char* opstr, +inline void PrintBinaryIntrinsitc(const CallNode* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { - CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; + CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; CHECK_EQ(op->args.size(), 2U); os << '('; p->PrintExpr(op->args[0], os); @@ -252,9 +250,7 @@ void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLIN LOG(FATAL) << "Phase 0 has no Load(s)!"; } -void CodeGenHybrid::VisitStmt_(const StoreNode* op) { - LOG(FATAL) << "Phase 0 has no Store(s)!"; -} +void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; } void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Phase 0 has no Let(s)!"; @@ -268,7 +264,7 @@ void CodeGenHybrid::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLIN LOG(FATAL) << "Ramp to be supported yet"; } -void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Broadcast: not supported "; } @@ -293,8 +289,8 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { CHECK(iter_var); binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint); PrintIndent(); - stream << "for " << binds_[iter_var->var.get()] << " in bind('" - << iter_var->var->name_hint << "', "; + stream << "for " << binds_[iter_var->var.get()] << " in bind('" << iter_var->var->name_hint + << "', "; PrintExpr(op->value, stream); stream << "):\n"; indent_ += tab_; @@ -355,17 +351,16 @@ void CodeGenHybrid::VisitStmt_(const ForNode* op) { std::string extent = PrintExpr(op->extent); PrintIndent(); std::string vid = GetVarID(op->loop_var.get()); - stream << "for " << vid << " in " << "range(" << extent << "):\n"; + stream << "for " << vid << " in " + << "range(" << extent << "):\n"; indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; } -bool is_noop(const Stmt &stmt) { - if (!stmt.defined()) - return true; - if (auto eval = stmt.as()) - return is_const(eval->value); +bool is_noop(const Stmt& stmt) { + if (!stmt.defined()) return true; + if (auto eval = stmt.as()) return is_const(eval->value); return false; } @@ -395,17 +390,13 @@ void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) { void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; std::string str = PrintExpr(op->value); - if (!str.empty()) - stream << str << "\n"; + if (!str.empty()) stream << str << "\n"; } -void CodeGenHybrid::PrintIndent() { - stream << std::string(indent_, ' '); -} +void CodeGenHybrid::PrintIndent() { stream << std::string(indent_, ' '); } -std::string CodeGenHybrid::GetVarID(const VarNode *v) { - if (binds_.count(v)) - return binds_[v]; +std::string CodeGenHybrid::GetVarID(const VarNode* v) { + if (binds_.count(v)) return binds_[v]; auto key = std::make_pair(static_cast(v), 0); if (id_map_.count(key)) { return id_map_[key]; @@ -413,7 +404,7 @@ std::string CodeGenHybrid::GetVarID(const VarNode *v) { return id_map_[key] = GetUniqueName(v->name_hint); } -std::string CodeGenHybrid::GetTensorID(const FunctionRef &func, int value_index) { +std::string CodeGenHybrid::GetTensorID(const FunctionRef& func, int value_index) { auto key = std::make_pair(func.get(), value_index); if (id_map_.count(key)) { return id_map_[key]; @@ -469,10 +460,8 @@ void CodeGenHybrid::ReserveKeywords() { GetUniqueName("max_num_threads"); } -void CodeGenHybrid::DumpStmt(const Stmt &stmt, - const Array &inputs, - const Array &outputs, - const std::string &name) { +void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array& inputs, + const Array& outputs, const std::string& name) { ReserveKeywords(); GetUniqueName(name); @@ -491,14 +480,12 @@ void CodeGenHybrid::DumpStmt(const Stmt &stmt, indent_ += tab_; for (size_t i = 0; i < outputs.size(); ++i) { PrintIndent(); - stream << GetTensorID(outputs[i]->op, outputs[i]->value_index) - << " = output_tensor(("; + stream << GetTensorID(outputs[i]->op, outputs[i]->value_index) << " = output_tensor(("; for (size_t j = 0; j < outputs[i]->shape.size(); ++j) { if (j) stream << ", "; PrintExpr(outputs[i]->shape[j], stream); } - if (outputs[i]->shape.size() == 1) - stream << ", "; + if (outputs[i]->shape.size() == 1) stream << ", "; stream << "), '" << outputs[i]->dtype << "')\n"; } PrintStmt(stmt); @@ -511,14 +498,13 @@ void CodeGenHybrid::DumpStmt(const Stmt &stmt, stream << "\n"; } -TVM_REGISTER_GLOBAL("hybrid._Dump") -.set_body([](TVMArgs args, TVMRetValue* rv) { - CodeGenHybrid codegen; - if (args.size() == 4) - codegen.DumpStmt(args[0], args[1], args[2], args[3]); - else - codegen.DumpStmt(args[0], args[1], args[2]); - *rv = codegen.Finish(); - }); +TVM_REGISTER_GLOBAL("hybrid._Dump").set_body([](TVMArgs args, TVMRetValue* rv) { + CodeGenHybrid codegen; + if (args.size() == 4) + codegen.DumpStmt(args[0], args[1], args[2], args[3]); + else + codegen.DumpStmt(args[0], args[1], args[2]); + *rv = codegen.Finish(); +}); } // namespace contrib } // namespace tvm diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index d282edb..78a22b5 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -24,10 +24,11 @@ #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ -#include -#include #include #include +#include +#include + #include #include #include @@ -45,9 +46,8 @@ using namespace tir; * **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3. * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``. */ -class CodeGenHybrid : - public ExprFunctor, - public StmtFunctor { +class CodeGenHybrid : public ExprFunctor, + public StmtFunctor { public: /*! * \brief Dump the given function body to hybrid script. @@ -56,8 +56,8 @@ class CodeGenHybrid : * \param outputs Output tensors of this schedule. * \param name The name of the function. */ - void DumpStmt(const Stmt &stmt, const Array &inputs, const Array &outputs, - const std::string &name = "hybrid_func"); + void DumpStmt(const Stmt& stmt, const Array& inputs, const Array& outputs, + const std::string& name = "hybrid_func"); /*! * \brief Finalize the compilation and return the code. * \return The code. @@ -69,55 +69,51 @@ class CodeGenHybrid : * \brief Print the Stmt n to CodeGenHybrid->stream * \param n The statement to be printed. */ - void PrintStmt(const Stmt &n) { - this->VisitStmt(n); - } + void PrintStmt(const Stmt& n) { this->VisitStmt(n); } /*! * \brief Print the expression n(or its ssa id if in ssa mode) into os * \param n The expression to be printed. * \param os The output stream */ - void PrintExpr(const PrimExpr &n, std::ostream &os) { - this->VisitExpr(n, os); - } + void PrintExpr(const PrimExpr& n, std::ostream& os) { this->VisitExpr(n, os); } /*! * \brief Same as PrintExpr, but simply returns result string * \param n The expression to be printed. */ - std::string PrintExpr(const PrimExpr &n) { + std::string PrintExpr(const PrimExpr& n) { std::ostringstream os; PrintExpr(n, os); return os.str(); } // expression - void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment void VisitStmt_(const LetStmtNode* op) override; @@ -136,7 +132,7 @@ class CodeGenHybrid : * \param t The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) + virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) private: /*! \brief The current indent of the code dump. */ @@ -150,9 +146,9 @@ class CodeGenHybrid : /*! * \brief Keys are either (tensors, value_index) or (variables, 0). * Values are the corresponding IDs.*/ - std::map, std::string> id_map_; + std::map, std::string> id_map_; /*! \brief Variables (keys) binded to the threads (values). */ - std::map binds_; + std::map binds_; /*! * \brief Find an unallocated name for the given prefix. * \param prefix The given prefix. @@ -164,13 +160,13 @@ class CodeGenHybrid : * \brief Get or allocate the ID for the given variable. * \param v The given variable. */ - std::string GetVarID(const VarNode *v); + std::string GetVarID(const VarNode* v); /*! * \brief Get or allocate the ID for the given tensor. * \param func The tensor to allocate a name. * \param value_index The value index of the given tensor. */ - std::string GetTensorID(const FunctionRef &func, int value_index); + std::string GetTensorID(const FunctionRef& func, int value_index); /*! \brief the storage scope of allocation */ std::map alloc_storage_scope_; }; diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 8231c1b..cdd9d54 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -23,13 +23,12 @@ */ #include #include -#include - -#include -#include -#include #include #include +#include +#include +#include +#include #include #include @@ -37,9 +36,9 @@ namespace tvm { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; bool LLVMEnabled() { const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); @@ -59,12 +58,8 @@ Target DefaultTargetHost(Target target) { } } -tir::Buffer BufferWithOffsetAlignment(Array shape, - DataType dtype, - std::string name, - int data_alignment, - int offset_factor, - bool compact) { +tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std::string name, + int data_alignment, int offset_factor, bool compact) { auto data = tir::Var(name, DataType::Handle()); bool has_any = false; if (!compact) { @@ -85,21 +80,19 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, } return tir::BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - data_alignment, offset_factor, buffer_type); + data_alignment, offset_factor, buffer_type); } -void GetBinds(const Array& args, - bool compact, +void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, - Map* out_binds, - Array* out_arg_list, + Map* out_binds, Array* out_arg_list, const BuildConfig& config) { *out_binds = binds; - for (const auto &x : args) { + for (const auto& x : args) { if (out_binds->find(x) == out_binds->end()) { - auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, - config->data_alignment, config->offset_factor, compact); + auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, config->data_alignment, + config->offset_factor, compact); out_binds->Set(x, buf); out_arg_list->push_back(buf); } else { @@ -115,8 +108,7 @@ transform::Pass BindTarget(Target target) { return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); } - -template +template transform::Pass Filter(FCond fcond) { auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { if (fcond(f)) { @@ -128,10 +120,7 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } - -IRModule lower(te::Schedule sch, - const Array& args, - const std::string& name, +IRModule lower(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, const BuildConfig& config) { Array out_arg_list; @@ -147,8 +136,7 @@ IRModule lower(te::Schedule sch, GetBinds(args, compact, binds, &out_binds, &out_arg_list, config); // build the function - tir::PrimFunc f = te::SchedulePostProcToPrimFunc( - out_arg_list, std::move(stmt), out_binds); + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); if (config->restricted_func) { f = WithAttr(std::move(f), "tir.noalias", Integer(1)); @@ -159,8 +147,7 @@ IRModule lower(te::Schedule sch, // Phase 0 pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back( - tir::transform::StorageFlatten(64, config->instrument_bound_checkers)); + pass_list.push_back(tir::transform::StorageFlatten(64, config->instrument_bound_checkers)); // Phase 1 pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -170,10 +157,8 @@ IRModule lower(te::Schedule sch, pass_list.push_back(tir::transform::InjectDoubleBuffer(config->double_buffer_split_loop)); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back( - tir::transform::UnrollLoop(config->auto_unroll_max_step, - config->auto_unroll_max_depth, - config->auto_unroll_max_extent, - config->unroll_explicit)); + tir::transform::UnrollLoop(config->auto_unroll_max_step, config->auto_unroll_max_depth, + config->auto_unroll_max_extent, config->unroll_explicit)); // Phase 2 pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); @@ -189,16 +174,11 @@ IRModule lower(te::Schedule sch, return mod; } - -std::pair -split_dev_host_funcs(IRModule mod_mixed, - const Target& target, - const Target& target_host, - const BuildConfig& config) { - Array mixed_pass_list = { - BindTarget(target), - tir::transform::VerifyMemory() - }; +std::pair split_dev_host_funcs(IRModule mod_mixed, const Target& target, + const Target& target_host, + const BuildConfig& config) { + Array mixed_pass_list = {BindTarget(target), + tir::transform::VerifyMemory()}; if (config->detect_global_barrier) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } @@ -212,32 +192,30 @@ split_dev_host_funcs(IRModule mod_mixed, mod_mixed = opt_mixed(std::move(mod_mixed)); auto host_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target_host), - tir::transform::LowerTVMBuiltin(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - tir::transform::CombineContextCall(), + Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != + CallingConv::kDeviceKernelLaunch; + }), + BindTarget(target_host), + tir::transform::LowerTVMBuiltin(), + tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), + tir::transform::CombineContextCall(), }; auto opt_host = transform::Sequential(host_pass_list); auto mhost = opt_host(mod_mixed); // device pipeline auto device_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target), - tir::transform::LowerWarpMemory(), - tir::transform::Simplify(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), + Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDeviceKernelLaunch; + }), + BindTarget(target), + tir::transform::LowerWarpMemory(), + tir::transform::Simplify(), + tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), }; auto opt_device = transform::Sequential(device_pass_list); auto mdevice = opt_device(mod_mixed); @@ -246,26 +224,21 @@ split_dev_host_funcs(IRModule mod_mixed, auto keys = target->keys(); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); if (target_is_gpu && mdevice->functions.size() == 0) { - LOG(WARNING) << "Specified target " - << target->str() + LOG(WARNING) << "Specified target " << target->str() << " but cannot find device code. Did you forget to bind?"; } - if (target->device_type == target::llvm()->device_type && - target_host == target) { - CHECK(mdevice->functions.empty()) - << "No device code should be generated when target " - << "and host_target are both llvm target." - << "\n"; + if (target->device_type == target::llvm()->device_type && target_host == target) { + CHECK(mdevice->functions.empty()) << "No device code should be generated when target " + << "and host_target are both llvm target." + << "\n"; } return {mhost, mdevice}; } - // Build for heterogeneous execution. -runtime::Module build(const Map& inputs, - const Target& target_host, +runtime::Module build(const Map& inputs, const Target& target_host, const BuildConfig& config) { std::vector device_modules; @@ -286,8 +259,7 @@ runtime::Module build(const Map& inputs, IRModule mhost_all = IRModule(Map()); for (const auto& it : inputs) { - auto pair = - split_dev_host_funcs(it.second, it.first, target_host_val, config); + auto pair = split_dev_host_funcs(it.second, it.first, target_host_val, config); auto& mhost = pair.first; auto& mdevice = pair.second; @@ -308,8 +280,7 @@ runtime::Module build(const Map& inputs, } // Build for heterogeneous execution when target is a string. -runtime::Module build(const Map& inputs, - const Target& target_host, +runtime::Module build(const Map& inputs, const Target& target_host, const BuildConfig& config) { Map updated_input; for (const auto& it : inputs) { @@ -323,9 +294,7 @@ runtime::Module build(const Map& inputs, } // Build for homogeneous execution. -runtime::Module build(const IRModule& funcs, - const Target& target, - const Target& target_host, +runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host, const BuildConfig& config) { Map inputs = {{target, funcs}}; return build(inputs, target_host, config); diff --git a/src/ir/adt.cc b/src/ir/adt.cc index 4650a3b..957905d 100644 --- a/src/ir/adt.cc +++ b/src/ir/adt.cc @@ -21,14 +21,12 @@ * \file src/ir/adt.cc * \brief ADT type definitions. */ -#include #include +#include namespace tvm { -Constructor::Constructor(std::string name_hint, - tvm::Array inputs, - GlobalTypeVar belong_to) { +Constructor::Constructor(std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->inputs = std::move(inputs); @@ -39,21 +37,18 @@ Constructor::Constructor(std::string name_hint, TVM_REGISTER_NODE_TYPE(ConstructorNode); TVM_REGISTER_GLOBAL("ir.Constructor") -.set_body_typed([](std::string name_hint, - tvm::Array inputs, - GlobalTypeVar belong_to) { - return Constructor(name_hint, inputs, belong_to); -}); + .set_body_typed([](std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { + return Constructor(name_hint, inputs, belong_to); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ConstructorNode(" << node->name_hint << ", " - << node->inputs << ", " << node->belong_to << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstructorNode(" << node->name_hint << ", " << node->inputs << ", " + << node->belong_to << ")"; + }); -TypeData::TypeData(GlobalTypeVar header, - tvm::Array type_vars, +TypeData::TypeData(GlobalTypeVar header, tvm::Array type_vars, tvm::Array constructors) { ObjectPtr n = make_object(); n->header = std::move(header); @@ -65,17 +60,16 @@ TypeData::TypeData(GlobalTypeVar header, TVM_REGISTER_NODE_TYPE(TypeDataNode); TVM_REGISTER_GLOBAL("ir.TypeData") -.set_body_typed([](GlobalTypeVar header, - tvm::Array type_vars, - tvm::Array constructors) { - return TypeData(header, type_vars, constructors); -}); + .set_body_typed([](GlobalTypeVar header, tvm::Array type_vars, + tvm::Array constructors) { + return TypeData(header, type_vars, constructors); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " - << node->constructors << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " + << node->constructors << ")"; + }); } // namespace tvm diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index dbd5a4f..56a561b 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -32,6 +32,7 @@ #include #include + #include namespace tvm { @@ -39,16 +40,13 @@ namespace tvm { template class AttrFunctor; -#define ATTR_FUNCTOR_DEFAULT \ +#define ATTR_FUNCTOR_DEFAULT \ { return VisitAttrDefault_(op, std::forward(args)...); } - -#define ATTR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitAttr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ +#define ATTR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitAttr_(static_cast(n.get()), std::forward(args)...); \ + }); // A functor for common attribute information. template diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index bee103d..edc81ae 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -22,20 +22,16 @@ */ #include #include + #include "attr_functor.h" namespace tvm { -void DictAttrsNode::VisitAttrs(AttrVisitor* v) { - v->Visit("__dict__", &dict); -} +void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } -void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { - v->Visit("__dict__", &dict); -} +void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } -void DictAttrsNode::InitByPackedArgs( - const runtime::TVMArgs& args, bool allow_unknown) { +void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; runtime::TVMArgValue val = args[i + 1]; @@ -49,9 +45,7 @@ void DictAttrsNode::InitByPackedArgs( } } -Array DictAttrsNode::ListFieldInfo() const { - return {}; -} +Array DictAttrsNode::ListFieldInfo() const { return {}; } DictAttrs::DictAttrs(Map dict) { ObjectPtr n = make_object(); @@ -60,22 +54,20 @@ DictAttrs::DictAttrs(Map dict) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->dict; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->dict; + }); TVM_REGISTER_NODE_TYPE(DictAttrsNode); TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); -TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict") -.set_body_typed([](DictAttrs attrs) { +TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs) { return attrs->dict; }); -TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo") -.set_body_typed([](Attrs attrs) { +TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo").set_body_typed([](Attrs attrs) { return attrs->ListFieldInfo(); }); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 4d3ed30..7deff90 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -26,16 +26,15 @@ namespace tvm { - using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "EnvFunc(" << op->name << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "EnvFunc(" << op->name << ")"; + }); ObjectPtr CreateEnvNode(const std::string& name) { auto* f = runtime::Registry::Get(name); @@ -46,31 +45,24 @@ ObjectPtr CreateEnvNode(const std::string& name) { return n; } -EnvFunc EnvFunc::Get(const std::string& name) { - return EnvFunc(CreateEnvNode(name)); -} +EnvFunc EnvFunc::Get(const std::string& name) { return EnvFunc(CreateEnvNode(name)); } -TVM_REGISTER_GLOBAL("ir.EnvFuncGet") -.set_body_typed(EnvFunc::Get); +TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); -TVM_REGISTER_GLOBAL("ir.EnvFuncCall") -.set_body([](TVMArgs args, TVMRetValue* rv) { - EnvFunc env = args[0]; - CHECK_GE(args.size(), 1); - env->func.CallPacked(TVMArgs(args.values + 1, - args.type_codes + 1, - args.size() - 1), rv); - }); +TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body([](TVMArgs args, TVMRetValue* rv) { + EnvFunc env = args[0]; + CHECK_GE(args.size(), 1); + env->func.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), rv); +}); -TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc") -.set_body_typed([](const EnvFunc&n) { - return n->func; - }); +TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc").set_body_typed([](const EnvFunc& n) { + return n->func; +}); TVM_REGISTER_NODE_TYPE(EnvFuncNode) -.set_creator(CreateEnvNode) -.set_repr_bytes([](const Object* n) -> std::string { - return static_cast(n)->name; - }); + .set_creator(CreateEnvNode) + .set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; + }); } // namespace tvm diff --git a/src/ir/error.cc b/src/ir/error.cc index 9d49828..9db61a0 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -22,8 +22,8 @@ * \brief Utilities for error tracking and reporting. */ -#include #include +#include // NOTE: reverse dependency on relay. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. @@ -31,13 +31,15 @@ // Rationale: use relay's printer for astext. #include +// clang-fomat off #include #include #include +// clang-format on namespace tvm { -template +template using NodeMap = std::unordered_map; void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { @@ -76,9 +78,9 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { // Setup error map. auto it = error_maps.find(global); if (it != error_maps.end()) { - it->second.insert({ node, err_msg.str() }); + it->second.insert({node, err_msg.str()}); } else { - error_maps.insert({ global, { { node, err_msg.str() }}}); + error_maps.insert({global, {{node, err_msg.str()}}}); } } @@ -87,10 +89,10 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { std::stringstream annotated_prog; // First we output a header for the errors. - annotated_prog << - rang::style::bold << std::endl << - "Error(s) have occurred. The program has been annotated with them:" - << std::endl << std::endl << rang::style::reset; + annotated_prog << rang::style::bold << std::endl + << "Error(s) have occurred. The program has been annotated with them:" << std::endl + << std::endl + << rang::style::reset; // For each global function which contains errors, we will // construct an annotated function. @@ -101,11 +103,8 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { // We output the name of the function before displaying // the annotated program. - annotated_prog << - rang::style::bold << - "In `" << global->name_hint << "`: " << - std::endl << - rang::style::reset; + annotated_prog << rang::style::bold << "In `" << global->name_hint << "`: " << std::endl + << rang::style::reset; // We then call into the Relay printer to generate the program. // @@ -140,9 +139,9 @@ void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, con if (it != this->node_to_error_.end()) { it->second.push_back(index_to_insert); } else { - this->node_to_error_.insert({ node, { index_to_insert }}); + this->node_to_error_.insert({node, {index_to_insert}}); } - this->node_to_gv_.insert({ node, global }); + this->node_to_gv_.insert({node, global}); } } // namespace tvm diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 7272213..000305b 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -21,9 +21,9 @@ * \file src/ir/expr.cc * \brief The expression AST nodes for the common IR infra. */ -#include #include #include +#include // NOTE: reverse dependency on top/tir. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. @@ -34,11 +34,9 @@ namespace tvm { -PrimExpr::PrimExpr(int32_t value) - : PrimExpr(IntImm(DataType::Int(32), value)) {} +PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} -PrimExpr::PrimExpr(float value) - : PrimExpr(FloatImm(DataType::Float(32), value)) {} +PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; @@ -52,17 +50,14 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { return tir::StringImmNode::make(GetRef(ptr)); } CHECK(ObjectTypeChecker::Check(ref.get())) - << "Expect type " << ObjectTypeChecker::TypeName() - << " but get " << ref->GetTypeKey(); + << "Expect type " << ObjectTypeChecker::TypeName() << " but get " + << ref->GetTypeKey(); return Downcast(ref); } - IntImm::IntImm(DataType dtype, int64_t value) { - CHECK(dtype.is_scalar()) - << "ValueError: IntImm can only take scalar."; - CHECK(dtype.is_int() || dtype.is_uint()) - << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_int() || dtype.is_uint()) << "ValueError: IntImm can only take scalar."; if (dtype.is_uint()) { CHECK_GE(value, 0U); } @@ -72,86 +67,75 @@ IntImm::IntImm(DataType dtype, int64_t value) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.IntImm") -.set_body_typed([](DataType dtype, int64_t value) { +TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value) { return IntImm(dtype, value); }); TVM_REGISTER_NODE_TYPE(IntImmNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - if (op->dtype == DataType::Int(32)) { - p->stream << op->value; - } else { - p->stream << "(" << op->dtype << ")" << op->value; - } - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + if (op->dtype == DataType::Int(32)) { + p->stream << op->value; + } else { + p->stream << "(" << op->dtype << ")" << op->value; + } + }); FloatImm::FloatImm(DataType dtype, double value) { - CHECK_EQ(dtype.lanes(), 1) - << "ValueError: FloatImm can only take scalar."; + CHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; ObjectPtr node = make_object(); node->dtype = dtype; node->value = value; data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.FloatImm") -.set_body_typed([](DataType dtype, double value) { +TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value) { return FloatImm(dtype, value); }); TVM_REGISTER_NODE_TYPE(FloatImmNode); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - auto& stream = p->stream; - switch (op->dtype.bits()) { - case 64: - stream << op->value; - break; - case 32: - stream << op->value << 'f'; - break; - case 16: - stream << op->value << 'h'; - break; - default: - LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); - } - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + auto& stream = p->stream; + switch (op->dtype.bits()) { + case 64: + stream << op->value; + break; + case 32: + stream << op->value << 'f'; + break; + case 16: + stream << op->value << 'h'; + break; + default: + LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); + } + }); Range::Range(PrimExpr begin, PrimExpr end) - : Range(make_object( - begin, - tir::is_zero(begin) ? end : (end - begin))) { -} + : Range(make_object(begin, tir::is_zero(begin) ? end : (end - begin))) {} Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { return Range(make_object(min, extent)); } -TVM_REGISTER_GLOBAL("ir.range_by_min_extent") -.set_body_typed(Range::make_by_min_extent); +TVM_REGISTER_GLOBAL("ir.range_by_min_extent").set_body_typed(Range::make_by_min_extent); -TVM_REGISTER_GLOBAL("ir.Range") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = Range(args[0], args[1]); - }); +}); TVM_REGISTER_NODE_TYPE(RangeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; + }); GlobalVar::GlobalVar(std::string name_hint) { ObjectPtr n = make_object(); @@ -161,57 +145,56 @@ GlobalVar::GlobalVar(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalVar") -.set_body_typed([](std::string name){ +TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](std::string name) { return GlobalVar(name); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalVar(" << node->name_hint << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "GlobalVar(" << node->name_hint << ")"; + }); // Container printer TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '['; - for (size_t i = 0 ; i < op->data.size(); ++i) { - if (i != 0) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '['; + for (size_t i = 0; i < op->data.size(); ++i) { + if (i != 0) { + p->stream << ", "; + } + p->Print(op->data[i]); } - p->Print(op->data[i]); - } - p->stream << ']'; -}); + p->stream << ']'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->data.begin(); it != op->data.end(); ++it) { - if (it != op->data.begin()) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '{'; + for (auto it = op->data.begin(); it != op->data.end(); ++it) { + if (it != op->data.begin()) { + p->stream << ", "; + } + p->Print(it->first); + p->stream << ": "; + p->Print(it->second); } - p->Print(it->first); - p->stream << ": "; - p->Print(it->second); - } - p->stream << '}'; - }); + p->stream << '}'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->data.begin(); it != op->data.end(); ++it) { - if (it != op->data.begin()) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '{'; + for (auto it = op->data.begin(); it != op->data.end(); ++it) { + if (it != op->data.begin()) { + p->stream << ", "; + } + p->stream << '\"' << it->first << "\": "; + p->Print(it->second); } - p->stream << '\"' << it->first << "\": "; - p->Print(it->second); - } - p->stream << '}'; - }); + p->stream << '}'; + }); } // namespace tvm diff --git a/src/ir/function.cc b/src/ir/function.cc index 08cdc93..57d62b4 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -21,40 +21,32 @@ * \file src/ir/function.cc * \brief The function data structure. */ -#include #include +#include // NOTE: reverse dependency on relay, tir/ // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. // // Rationale: We calls into the type specific WithAttr function -#include #include - +#include namespace tvm { -TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs") -.set_body_typed([](BaseFunc func) { - return func->attrs; -}); +TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { return func->attrs; }); -TVM_REGISTER_GLOBAL("ir.BaseFuncCopy") -.set_body_typed([](BaseFunc func) { - return func; -}); +TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; }); TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") -.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc { - if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); - return func; - } -}); - + .set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc { + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); } // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 1be58f3..c739374 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -21,9 +21,9 @@ * \file module.cc * \brief The global module in Relay. */ -#include #include #include +#include // NOTE: reverse dependency on relay. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. @@ -32,8 +32,8 @@ #include #include -#include #include +#include #include namespace tvm { @@ -52,14 +52,14 @@ IRModule::IRModule(tvm::Map functions, for (const auto& kv : n->functions) { // set global var map CHECK(n->global_var_map_.count(kv.first->name_hint) == 0) - << "Duplicate global function name " << kv.first->name_hint; + << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } for (const auto& kv : n->type_definitions) { // set global typevar map CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0) - << "Duplicate global type definition name " << kv.first->name_hint; + << "Duplicate global type definition name " << kv.first->name_hint; n->global_type_var_map_.Set(kv.first->name_hint, kv.first); n->RegisterConstructors(kv.first, kv.second); } @@ -87,9 +87,8 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { auto reduce_temp = [&]() { // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { - return lhs.first < rhs.first; - }); + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); hash_reduce(static_cast(temp.size())); // hash the content @@ -150,7 +149,7 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const { CHECK(global_type_var_map_.defined()); auto it = global_type_var_map_.find(name); CHECK(it != global_type_var_map_.end()) - << "Cannot find global type var " << name << " in the Module"; + << "Cannot find global type var " << name << " in the Module"; return (*it).second; } @@ -174,7 +173,7 @@ tvm::Array IRModuleNode::GetGlobalTypeVars() const { return tvm::Array(global_type_vars); } -template +template tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { tvm::Array ret(l); for (const T& t : r) { @@ -184,55 +183,37 @@ tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { } // helper function to run type check -relay::Function RunTypeCheck(const IRModule& mod, - const GlobalVar& var, - relay::Function f) { +relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::Function f) { auto func = Downcast(relay::DeDup(std::move(f))); // Type check the item before we add it to the module. auto fv = relay::FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod); if (fv.size() != 0) { - LOG(WARNING) - << "There are free variables: " - << fv - << " in function: " - << AsText(func, false) - << std::endl; + LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false) + << std::endl; } if (ftv.size() != 0) { - LOG(WARNING) - << "There are free type variables: " - << ftv - << " in function: " - << AsText(func, false) - << std::endl; + LOG(WARNING) << "There are free type variables: " << ftv + << " in function: " << AsText(func, false) << std::endl; } - func = relay::Function(concat(func->params, fv), - func->body, - func->ret_type, - concat(func->type_params, ftv), - func->attrs); + func = relay::Function(concat(func->params, fv), func->body, func->ret_type, + concat(func->type_params, ftv), func->attrs); // Type check the item before we add it to the module. relay::Function checked_func = InferType(func, mod, var); return checked_func; } -void IRModuleNode::Add(const GlobalVar& var, - const BaseFunc& f, - bool update) { +void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { BaseFunc checked_func = f; if (auto* ptr = f.as()) { - checked_func = RunTypeCheck(GetRef(this), - var, - GetRef(ptr)); + checked_func = RunTypeCheck(GetRef(this), var, GetRef(ptr)); } Type type = checked_func->checked_type(); CHECK(type.as() == nullptr); if (functions.find(var) != functions.end()) { - CHECK(update) - << "Already have definition for " << var->name_hint; + CHECK(update) << "Already have definition for " << var->name_hint; auto old_type = functions[var]->checked_type(); CHECK(tvm::StructuralEqual()(type, old_type)) << "Module#update changes type, not possible in this mode."; @@ -241,8 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } -void IRModuleNode::AddUnchecked(const GlobalVar& var, - const BaseFunc& func) { +void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { this->functions.Set(var, func); auto it = global_var_map_.find(var->name_hint); @@ -268,36 +248,31 @@ void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData } } -void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, - const TypeData& type, - bool update) { +void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) { AddTypeDefUnchecked(var, type, update); // need to kind check at the end because the check can look up // a definition potentially CHECK(relay::KindCheck(type, GetRef(this)) == TypeKind::kTypeData) - << "Invalid or malformed typedata given to module: " << type; + << "Invalid or malformed typedata given to module: " << type; } -void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, - const TypeData& type, +void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) { this->type_definitions.Set(var, type); if (!update) { // set global type var map CHECK(global_type_var_map_.count(var->name_hint) == 0) - << "Duplicate global type definition name " << var->name_hint; + << "Duplicate global type definition name " << var->name_hint; } global_type_var_map_.Set(var->name_hint, var); RegisterConstructors(var, type); } -void IRModuleNode::Update(const GlobalVar& var, - const BaseFunc& func) { +void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) { this->Add(var, func, true); } -void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, - const TypeData& type) { +void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) { this->AddTypeDef(var, type, true); } @@ -310,8 +285,7 @@ void IRModuleNode::Remove(const GlobalVar& var) { BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); - CHECK(it != functions.end()) - << "There is no definition of " << var->name_hint; + CHECK(it != functions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } @@ -322,8 +296,7 @@ BaseFunc IRModuleNode::Lookup(const String& name) const { TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); - CHECK(it != type_definitions.end()) - << "There is no definition of " << var->name_hint; + CHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } @@ -334,8 +307,7 @@ TypeData IRModuleNode::LookupTypeDef(const String& name) const { Constructor IRModuleNode::LookupTag(const int32_t tag) { auto it = constructor_tag_map_.find(tag); - CHECK(it != constructor_tag_map_.end()) - << "There is no constructor with the tag " << tag; + CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag; return (*it).second; } @@ -356,10 +328,9 @@ void IRModuleNode::Update(const IRModule& mod) { } } -IRModule IRModule::FromExpr( - const RelayExpr& expr, - const tvm::Map& global_funcs, - const tvm::Map& type_definitions) { +IRModule IRModule::FromExpr(const RelayExpr& expr, + const tvm::Map& global_funcs, + const tvm::Map& type_definitions) { auto mod = IRModule(global_funcs, type_definitions); BaseFunc func; std::string gv_name = "main"; @@ -371,8 +342,7 @@ IRModule IRModule::FromExpr( } } else { - func = relay::Function(relay::FreeVars(expr), expr, Type(), - relay::FreeTypeVars(expr, mod), {}); + func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); } auto main_gv = GlobalVar(gv_name); mod->Add(main_gv, func); @@ -384,9 +354,8 @@ void IRModuleNode::Import(const String& path) { this->import_set_.insert(path); DLOG(INFO) << "Importing: " << path; std::fstream src_file(path, std::fstream::in); - std::string file_contents { - std::istreambuf_iterator(src_file), - std::istreambuf_iterator() }; + std::string file_contents{std::istreambuf_iterator(src_file), + std::istreambuf_iterator()}; auto mod_to_import = IRModule::FromText(file_contents, path); Update(mod_to_import); } @@ -399,9 +368,7 @@ void IRModuleNode::ImportFromStd(const String& path) { this->Import(std_path + "/" + path.operator std::string()); } -std::unordered_set IRModuleNode::Imports() const { - return this->import_set_; -} +std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } IRModule IRModule::FromText(const String& text, const String& source_path) { auto* f = tvm::runtime::Registry::Get("relay.fromtext"); @@ -413,13 +380,12 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") -.set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); -}); + .set_body_typed([](tvm::Map funcs, + tvm::Map types) { + return IRModule(funcs, types, {}); + }); -TVM_REGISTER_GLOBAL("ir.Module_Add") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) { IRModule mod = args[0]; GlobalVar var = args[1]; ObjectRef val = args[2]; @@ -443,75 +409,65 @@ TVM_REGISTER_GLOBAL("ir.Module_Add") *ret = mod; }); -TVM_REGISTER_GLOBAL("ir.Module_AddDef") -.set_body_method(&IRModuleNode::AddTypeDef); +TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") -.set_body_method(&IRModuleNode::GetGlobalVar); + .set_body_method(&IRModuleNode::GetGlobalVar); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars") -.set_body_method(&IRModuleNode::GetGlobalVars); + .set_body_method(&IRModuleNode::GetGlobalVars); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") -.set_body_method(&IRModuleNode::GetGlobalTypeVars); + .set_body_method(&IRModuleNode::GetGlobalTypeVars); TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") -.set_body_method(&IRModuleNode::ContainGlobalVar); + .set_body_method(&IRModuleNode::ContainGlobalVar); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") -.set_body_method(&IRModuleNode::GetGlobalTypeVar); + .set_body_method(&IRModuleNode::GetGlobalTypeVar); -TVM_REGISTER_GLOBAL("ir.Module_Lookup") -.set_body_typed([](IRModule mod, GlobalVar var) { +TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_Lookup_str") -.set_body_typed([](IRModule mod, String var) { +TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_LookupDef") -.set_body_typed([](IRModule mod, GlobalTypeVar var) { +TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) { return mod->LookupTypeDef(var); }); -TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str") -.set_body_typed([](IRModule mod, String var) { +TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) { return mod->LookupTypeDef(var); }); -TVM_REGISTER_GLOBAL("ir.Module_LookupTag") -.set_body_typed([](IRModule mod, int32_t tag) { - return mod->LookupTag(tag); - }); +TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) { + return mod->LookupTag(tag); +}); TVM_REGISTER_GLOBAL("ir.Module_FromExpr") -.set_body_typed([](RelayExpr e, - tvm::Map funcs, - tvm::Map type_defs) { - return IRModule::FromExpr(e, funcs, type_defs); -}); + .set_body_typed([](RelayExpr e, tvm::Map funcs, + tvm::Map type_defs) { + return IRModule::FromExpr(e, funcs, type_defs); + }); -TVM_REGISTER_GLOBAL("ir.Module_Update") -.set_body_typed([](IRModule mod, IRModule from) { +TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { mod->Update(from); }); -TVM_REGISTER_GLOBAL("ir.Module_Import") -.set_body_typed([](IRModule mod, String path) { +TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { mod->Import(path); }); -TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd") -.set_body_typed([](IRModule mod, String path) { +TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) { mod->ImportFromStd(path); -});; +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IRModuleNode( " << node->functions << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IRModuleNode( " << node->functions << ")"; + }); } // namespace tvm diff --git a/src/ir/op.cc b/src/ir/op.cc index bd8a6e2..8f58768 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -23,8 +23,8 @@ */ #include #include -#include #include +#include #include #include @@ -37,13 +37,11 @@ DMLC_REGISTRY_ENABLE(::tvm::OpRegistry); namespace tvm { -using runtime::TVMRetValue; -using runtime::TVMArgs; using runtime::PackedFunc; +using runtime::TVMArgs; +using runtime::TVMRetValue; -::dmlc::Registry* OpRegistry::Registry() { - return ::dmlc::Registry::Get(); -} +::dmlc::Registry* OpRegistry::Registry() { return ::dmlc::Registry::Get(); } // single manager of operator information. struct OpManager { @@ -112,9 +110,7 @@ void OpRegistry::reset_attr(const std::string& key) { } } -void OpRegistry::UpdateAttr(const std::string& key, - TVMRetValue value, - int plevel) { +void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, int plevel) { OpManager* mgr = OpManager::Global(); std::lock_guard lock(mgr->mutex); std::unique_ptr& op_map = mgr->attr[key]; @@ -127,94 +123,81 @@ void OpRegistry::UpdateAttr(const std::string& key, op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); } std::pair& p = op_map->data_[index]; - CHECK(p.second != plevel) - << "Attribute " << key << " of operator " << this->name - << " is already registered with same plevel=" << plevel; + CHECK(p.second != plevel) << "Attribute " << key << " of operator " << this->name + << " is already registered with same plevel=" << plevel; CHECK(value.type_code() != kTVMNullptr) - << "Registered packed_func is Null for " << key - << " of operator " << this->name; + << "Registered packed_func is Null for " << key << " of operator " << this->name; if (p.second < plevel && value.type_code() != kTVMNullptr) { op_map->data_[index] = std::make_pair(value, plevel); } } // Frontend APIs -TVM_REGISTER_GLOBAL("relay.op._ListOpNames") -.set_body_typed([]() { - Array ret; - for (const std::string& name : dmlc::Registry::ListAllNames()) { - ret.push_back(name); - } - return ret; - }); +TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() { + Array ret; + for (const std::string& name : dmlc::Registry::ListAllNames()) { + ret.push_back(name); + } + return ret; +}); -TVM_REGISTER_GLOBAL("relay.op._GetOp") -.set_body_typed([](std::string name) -> Op { +TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](std::string name) -> Op { return Op::Get(name); }); -TVM_REGISTER_GLOBAL("relay.op._OpGetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - auto op_map = Op::GetAttr(attr_name); - if (op_map.count(op)) { - *rv = op_map[op]; - } - }); - -TVM_REGISTER_GLOBAL("relay.op._OpSetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = - OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name(); - reg.set_attr(attr_name, value, plevel); - }); - -TVM_REGISTER_GLOBAL("relay.op._OpResetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - auto& reg = - OpRegistry::Registry()->__REGISTER_OR_GET__(op->name); - reg.reset_attr(attr_name); - }); - -TVM_REGISTER_GLOBAL("relay.op._Register") -.set_body([](TVMArgs args, TVMRetValue* rv) { - std::string op_name = args[0]; - std::string attr_key = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = - OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); - // enable resgiteration and override of certain properties - if (attr_key == "num_inputs" && plevel > 128) { - reg.set_num_inputs(value); - } else if (attr_key == "attrs_type_key" && plevel > 128) { - LOG(FATAL) << "attrs type key no longer supported"; +TVM_REGISTER_GLOBAL("relay.op._OpGetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto op_map = Op::GetAttr(attr_name); + if (op_map.count(op)) { + *rv = op_map[op]; + } +}); + +TVM_REGISTER_GLOBAL("relay.op._OpSetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name(); + reg.set_attr(attr_name, value, plevel); +}); + +TVM_REGISTER_GLOBAL("relay.op._OpResetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name); + reg.reset_attr(attr_name); +}); + +TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string op_name = args[0]; + std::string attr_key = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); + // enable resgiteration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + LOG(FATAL) << "attrs type key no longer supported"; + } else { + // normal attr table override. + if (args[2].type_code() == kTVMPackedFuncHandle) { + // do an eager copy of the PackedFunc + PackedFunc f = args[2]; + // If we get a function from frontend, avoid deleting it. + OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); + reg.set_attr(attr_key, f, plevel); } else { - // normal attr table override. - if (args[2].type_code() == kTVMPackedFuncHandle) { - // do an eager copy of the PackedFunc - PackedFunc f = args[2]; - // If we get a function from frontend, avoid deleting it. - OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); - reg.set_attr(attr_key, f, plevel); - } else { - reg.set_attr(attr_key, args[2], plevel); - } + reg.set_attr(attr_key, args[2], plevel); } - }); + } +}); // helper to get internal dev function in objectref. struct Op2ObjectPtr : public ObjectRef { - static ObjectPtr Get(const Op& op) { - return GetDataPtr(op); - } + static ObjectPtr Get(const Op& op) { return GetDataPtr(op); } }; ObjectPtr CreateOp(const std::string& name) { @@ -224,16 +207,14 @@ ObjectPtr CreateOp(const std::string& name) { return Op2ObjectPtr::Get(op); } -TVM_REGISTER_NODE_TYPE(OpNode) -.set_creator(CreateOp) -.set_repr_bytes([](const Object* n) { - return static_cast(n)->name; - }); +TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes([](const Object* n) { + return static_cast(n)->name; +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Op(" << node->name << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Op(" << node->name << ")"; + }); } // namespace tvm diff --git a/src/ir/span.cc b/src/ir/span.cc index 5a06a10..742c985 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -45,24 +45,21 @@ ObjectPtr GetSourceNameNodeByStr(const std::string& name) { return GetSourceNameNode(name); } -SourceName SourceName::Get(const String& name) { - return SourceName(GetSourceNameNode(name)); -} +SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } -TVM_REGISTER_GLOBAL("ir.SourceName") -.set_body_typed(SourceName::Get); +TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "SourceName(" << node->name << ", " << node << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "SourceName(" << node->name << ", " << node << ")"; + }); TVM_REGISTER_NODE_TYPE(SourceNameNode) -.set_creator(GetSourceNameNodeByStr) -.set_repr_bytes([](const Object* n) -> std::string { - return static_cast(n)->name; -}); + .set_creator(GetSourceNameNodeByStr) + .set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; + }); Span SpanNode::make(SourceName source, int lineno, int col_offset) { auto n = make_object(); @@ -74,13 +71,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span") -.set_body_typed(SpanNode::make); +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed(SpanNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source << ", " << node->lineno << ", " - << node->col_offset << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Span(" << node->source << ", " << node->lineno << ", " << node->col_offset + << ")"; + }); } // namespace tvm diff --git a/src/ir/tensor_type.cc b/src/ir/tensor_type.cc index 92f0ea2..0fab0ac 100644 --- a/src/ir/tensor_type.cc +++ b/src/ir/tensor_type.cc @@ -21,8 +21,8 @@ * \file src/ir/tensor_type.cc * \brief The type system AST nodes of Relay. */ -#include #include +#include #include namespace tvm { @@ -37,9 +37,7 @@ TensorType::TensorType(Array shape, DataType dtype) { data_ = std::move(n); } -TensorType TensorType::Scalar(DataType dtype) { - return TensorType({}, dtype); -} +TensorType TensorType::Scalar(DataType dtype) { return TensorType({}, dtype); } PrimExpr TensorTypeNode::Size() const { if (shape.size() == 0) { @@ -55,15 +53,14 @@ PrimExpr TensorTypeNode::Size() const { TVM_REGISTER_NODE_TYPE(TensorTypeNode); -TVM_REGISTER_GLOBAL("ir.TensorType") -.set_body_typed([](Array shape, DataType dtype) { +TVM_REGISTER_GLOBAL("ir.TensorType").set_body_typed([](Array shape, DataType dtype) { return TensorType(shape, dtype); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; + }); } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index c1547d5..d7d9b06 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -22,11 +22,11 @@ * \brief Infrastructure for transformation passes. */ #include -#include +#include +#include #include #include -#include -#include +#include // TODO(tqchen): Update to use String container after it is merged. #include @@ -37,9 +37,9 @@ namespace tvm { namespace transform { +using tvm::ReprPrinter; using tvm::runtime::TVMArgs; using tvm::runtime::TVMRetValue; -using tvm::ReprPrinter; struct PassContextThreadLocalEntry { /*! \brief The default pass context. */ @@ -48,32 +48,26 @@ struct PassContextThreadLocalEntry { /*! \brief The current pass context. */ std::stack context_stack; - PassContextThreadLocalEntry() { - default_context = PassContext(make_object()); - } + PassContextThreadLocalEntry() { default_context = PassContext(make_object()); } }; /*! \brief Thread local store to hold the pass context. */ -typedef dmlc::ThreadLocalStore - RelayPassContextThreadLocalStore; +typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; void PassContext::EnterWithScope() { - PassContextThreadLocalEntry* entry = - RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); } void PassContext::ExitWithScope() { - PassContextThreadLocalEntry* entry = - RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } PassContext PassContext::Current() { - PassContextThreadLocalEntry* entry = - RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); if (!entry->context_stack.empty()) { return entry->context_stack.top(); } else { @@ -81,15 +75,13 @@ PassContext PassContext::Current() { } } -PassContext PassContext::Create() { - return PassContext(make_object()); -} +PassContext PassContext::Create() { return PassContext(make_object()); } void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { - auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->trace_func != nullptr) { - pass_ctx_node->trace_func(module, info, is_before); - } + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->trace_func != nullptr) { + pass_ctx_node->trace_func(module, info, is_before); + } } class ModulePass; @@ -114,9 +106,7 @@ class ModulePassNode : public PassNode { ModulePassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! * \brief Run a module pass on given pass context. @@ -211,9 +201,7 @@ class SequentialNode : public PassNode { TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); }; -PassInfo::PassInfo(int opt_level, - std::string name, - tvm::Array required) { +PassInfo::PassInfo(int opt_level, std::string name, tvm::Array required) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); @@ -221,9 +209,8 @@ PassInfo::PassInfo(int opt_level, data_ = std::move(pass_info); } -ModulePass::ModulePass( - runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { +ModulePass::ModulePass(runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); @@ -231,13 +218,10 @@ ModulePass::ModulePass( } // Module -> Module optimizations. -IRModule ModulePassNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); - DLOG(INFO) << "Executing module pass : " - << pass_info->name - << " with opt level: " - << pass_info->opt_level; + DLOG(INFO) << "Executing module pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; CHECK(mod.defined()); pass_ctx.Trace(mod, pass_info, true); @@ -307,20 +291,18 @@ Pass GetPass(const std::string& pass_name) { // pass } else if ((f = Registry::Get("relay._transform." + pass_name))) { } - CHECK(f != nullptr) << "Cannot use " << pass_name - << "to create the pass"; + CHECK(f != nullptr) << "Cannot use " << pass_name << "to create the pass"; return (*f)(); } // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. -IRModule SequentialNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); - if (!PassEnabled(pass_info)) continue; + if (!PassEnabled(pass_info)) continue; // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); @@ -330,11 +312,9 @@ IRModule SequentialNode::operator()(IRModule mod, return mod; } -Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { +Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, + int opt_level, const std::string& name, + const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); return ModulePass(pass_func, pass_info); } @@ -342,55 +322,50 @@ Pass CreateModulePass( TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") -.set_body_typed([](int opt_level, std::string name, tvm::Array required) { - return PassInfo(opt_level, name, required); -}); + .set_body_typed([](int opt_level, std::string name, tvm::Array required) { + return PassInfo(opt_level, name, required); + }); -TVM_REGISTER_GLOBAL("transform.Info") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; *ret = pass->Info(); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, tvm::ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "The meta data of the pass: "; - p->stream << "pass name: " << node->name; - p->stream << "opt_level: " << node->opt_level; - p->stream << "required passes: [" << "\n"; - for (const auto& it : node->required) { - p->stream << it << ", "; - } - p->stream << "]\n"; -}); + .set_dispatch([](const ObjectRef& ref, tvm::ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "The meta data of the pass: "; + p->stream << "pass name: " << node->name; + p->stream << "opt_level: " << node->opt_level; + p->stream << "required passes: [" + << "\n"; + for (const auto& it : node->required) { + p->stream << it << ", "; + } + p->stream << "]\n"; + }); TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_GLOBAL("transform.MakeModulePass") -.set_body_typed( - [](runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - return ModulePass(pass_func, pass_info); -}); + .set_body_typed([](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return ModulePass(pass_func, pass_info); }); -TVM_REGISTER_GLOBAL("transform.RunPass") -.set_body_typed([](Pass pass, IRModule mod) { +TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) { return pass(std::move(mod)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Module pass: " << info->name - << " at the optimization level " << info->opt_level; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Module pass: " << info->name << " at the optimization level " + << info->opt_level; + }); TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_GLOBAL("transform.Sequential") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValue* ret) { tvm::Array passes = args[0]; int opt_level = args[1]; std::string name = args[2]; @@ -400,23 +375,22 @@ TVM_REGISTER_GLOBAL("transform.Sequential") }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Sequential pass: " << info->name - << " at the optimization level " << info->opt_level << ". "; - p->stream << "The passes will be executed are: ["; - for (const auto& it : node->passes) { - const PassInfo pass_info = it->Info(); - p->stream << pass_info->name << " "; - } - p->stream << "]"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Sequential pass: " << info->name << " at the optimization level " + << info->opt_level << ". "; + p->stream << "The passes will be executed are: ["; + for (const auto& it : node->passes) { + const PassInfo pass_info = it->Info(); + p->stream << pass_info->name << " "; + } + p->stream << "]"; + }); TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_GLOBAL("transform.PassContext") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("transform.PassContext").set_body([](TVMArgs args, TVMRetValue* ret) { auto pctx = PassContext::Create(); int opt_level = args[0]; int fallback_device = args[1]; @@ -432,59 +406,48 @@ TVM_REGISTER_GLOBAL("transform.PassContext") }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Pass context information: " << "\n"; - p->stream << "\topt_level: " << node->opt_level << "\n"; - p->stream << "\tfallback device: " - << runtime::DeviceName(node->fallback_device) - << "\n"; - - p->stream << "\trequired passes: [" << node->opt_level; - for (const auto& it : node->required_pass) { - p->stream << it << " "; - } - p->stream << "]\n"; - - p->stream << "\tdisabled passes: [" << node->opt_level; - for (const auto& it : node->disabled_pass) { - p->stream << it << " "; - } - p->stream << "]"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Pass context information: " + << "\n"; + p->stream << "\topt_level: " << node->opt_level << "\n"; + p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n"; + + p->stream << "\trequired passes: [" << node->opt_level; + for (const auto& it : node->required_pass) { + p->stream << it << " "; + } + p->stream << "]\n"; + + p->stream << "\tdisabled passes: [" << node->opt_level; + for (const auto& it : node->disabled_pass) { + p->stream << it << " "; + } + p->stream << "]"; + }); class PassContext::Internal { public: - static void EnterScope(PassContext pass_ctx) { - pass_ctx.EnterWithScope(); - } + static void EnterScope(PassContext pass_ctx) { pass_ctx.EnterWithScope(); } - static void ExitScope(PassContext pass_ctx) { - pass_ctx.ExitWithScope(); - } + static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; -TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext") -.set_body_typed(PassContext::Current); - -TVM_REGISTER_GLOBAL("transform.EnterPassContext") -.set_body_typed(PassContext::Internal::EnterScope); +TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); -TVM_REGISTER_GLOBAL("transform.ExitPassContext") -.set_body_typed(PassContext::Internal::ExitScope); +TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); +TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); Pass PrintIR(std::string header, bool show_meta_data) { - auto pass_func =[header, show_meta_data](IRModule mod, const PassContext& ctx) { - LOG(INFO) << "PrintIR(" << header << "):\n" - << AsText(mod, show_meta_data); + auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { + LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); return mod; }; return CreateModulePass(pass_func, 0, "PrintIR", {}); } -TVM_REGISTER_GLOBAL("transform.PrintIR") -.set_body_typed(PrintIR); +TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); } // namespace transform } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index 5d46893..212a6e5 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -33,17 +33,15 @@ PrimType::PrimType(runtime::DataType dtype) { TVM_REGISTER_NODE_TYPE(PrimTypeNode); -TVM_REGISTER_GLOBAL("ir.PrimType") -.set_body_typed([](runtime::DataType dtype) { +TVM_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) { return PrimType(dtype); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << node->dtype; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << node->dtype; + }); PointerType::PointerType(Type element_type) { ObjectPtr n = make_object(); @@ -53,18 +51,16 @@ PointerType::PointerType(Type element_type) { TVM_REGISTER_NODE_TYPE(PointerTypeNode); -TVM_REGISTER_GLOBAL("ir.PointerType") -.set_body_typed([](Type element_type) { +TVM_REGISTER_GLOBAL("ir.PointerType").set_body_typed([](Type element_type) { return PointerType(element_type); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->Print(node->element_type); - p->stream << '*'; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->Print(node->element_type); + p->stream << '*'; + }); TypeVar::TypeVar(String name, TypeKind kind) { ObjectPtr n = make_object(); @@ -75,18 +71,15 @@ TypeVar::TypeVar(String name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(TypeVarNode); -TVM_REGISTER_GLOBAL("ir.TypeVar") -.set_body_typed([](String name, int kind) { +TVM_REGISTER_GLOBAL("ir.TypeVar").set_body_typed([](String name, int kind) { return TypeVar(name, static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeVar(" << node->name_hint << ", " - << node->kind << ")"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { ObjectPtr n = make_object(); @@ -97,21 +90,17 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalTypeVar") -.set_body_typed([](std::string name, int kind) { +TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](std::string name, int kind) { return GlobalTypeVar(name, static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalTypeVar(" << node->name_hint << ", " - << node->kind << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); -FuncType::FuncType(tvm::Array arg_types, - Type ret_type, - tvm::Array type_params, +FuncType::FuncType(tvm::Array arg_types, Type ret_type, tvm::Array type_params, tvm::Array type_constraints) { ObjectPtr n = make_object(); n->arg_types = std::move(arg_types); @@ -124,21 +113,17 @@ FuncType::FuncType(tvm::Array arg_types, TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_GLOBAL("ir.FuncType") -.set_body_typed([](tvm::Array arg_types, - Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints) { - return FuncType(arg_types, ret_type, type_params, type_constraints); -}); + .set_body_typed([](tvm::Array arg_types, Type ret_type, tvm::Array type_params, + tvm::Array type_constraints) { + return FuncType(arg_types, ret_type, type_params, type_constraints); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FuncType(" << node->type_params << ", " - << node->arg_types << ", " << node->ret_type << ", " - << node->type_constraints << ")"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FuncType(" << node->type_params << ", " << node->arg_types << ", " + << node->ret_type << ", " << node->type_constraints << ")"; + }); TupleType::TupleType(Array fields) { ObjectPtr n = make_object(); @@ -146,23 +131,19 @@ TupleType::TupleType(Array fields) { data_ = std::move(n); } -TupleType TupleType::Empty() { - return TupleType(Array()); -} +TupleType TupleType::Empty() { return TupleType(Array()); } TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("ir.TupleType") -.set_body_typed([](Array fields) { +TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { return TupleType(fields); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleTypeNode(" << node->fields << ")"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleTypeNode(" << node->fields << ")"; + }); IncompleteType::IncompleteType(TypeKind kind) { auto n = make_object(); @@ -172,17 +153,15 @@ IncompleteType::IncompleteType(TypeKind kind) { TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); -TVM_REGISTER_GLOBAL("ir.IncompleteType") -.set_body_typed([](int kind) { - return IncompleteType(static_cast(kind)); - }); +TVM_REGISTER_GLOBAL("ir.IncompleteType").set_body_typed([](int kind) { + return IncompleteType(static_cast(kind)); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; - }); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); RelayRefType::RelayRefType(Type value) { ObjectPtr n = make_object(); @@ -190,17 +169,16 @@ RelayRefType::RelayRefType(Type value) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("ir.RelayRefType") -.set_body_typed([](Type value) { +TVM_REGISTER_GLOBAL("ir.RelayRefType").set_body_typed([](Type value) { return RelayRefType(value); }); TVM_REGISTER_NODE_TYPE(RelayRefTypeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RelayRefTypeNode(" << node->value << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RelayRefTypeNode(" << node->value << ")"; + }); } // namespace tvm diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index 9d9167f..21ce3d0 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -22,18 +22,16 @@ * \brief Implementations of type functors. */ #include + #include namespace tvm { -void TypeVisitor::VisitType_(const TypeVarNode* op) { -} +void TypeVisitor::VisitType_(const TypeVarNode* op) {} -void TypeVisitor::VisitType_(const TensorTypeNode* op) { -} +void TypeVisitor::VisitType_(const TensorTypeNode* op) {} -void TypeVisitor::VisitType_(const IncompleteTypeNode* op) { -} +void TypeVisitor::VisitType_(const IncompleteTypeNode* op) {} void TypeVisitor::VisitType_(const FuncTypeNode* op) { for (auto type_param : op->type_params) { @@ -56,9 +54,7 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) { } } -void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { - this->VisitType(op->value); -} +void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { this->VisitType(op->value); } void TypeVisitor::VisitType_(const TypeRelationNode* op) { for (const Type& t : op->args) { @@ -66,8 +62,7 @@ void TypeVisitor::VisitType_(const TypeRelationNode* op) { } } -void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) { -} +void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) {} void TypeVisitor::VisitType_(const TypeCallNode* op) { this->VisitType(op->func); @@ -90,12 +85,9 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { } } -void TypeVisitor::VisitType_(const PrimTypeNode* op) { -} +void TypeVisitor::VisitType_(const PrimTypeNode* op) {} -void TypeVisitor::VisitType_(const PointerTypeNode* op) { - this->VisitType(op->element_type); -} +void TypeVisitor::VisitType_(const PointerTypeNode* op) { this->VisitType(op->element_type); } Type TypeMutator::VisitType(const Type& t) { return t.defined() ? TypeFunctor::VisitType(t) : t; @@ -115,18 +107,14 @@ Array TypeMutator::MutateArray(Array arr) { return arr; } -Type TypeMutator::VisitType_(const TypeVarNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const TypeVarNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const TensorTypeNode* op) { // TODO(tvm-team) recursively visit to replace Var return GetRef(op); } -Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const FuncTypeNode* op) { bool changed = false; @@ -145,8 +133,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { for (auto type_cs : op->type_constraints) { auto new_type_cs = VisitType(type_cs); changed = changed || !new_type_cs.same_as(type_cs); - if (const TypeConstraintNode* tin = - new_type_cs.as()) { + if (const TypeConstraintNode* tin = new_type_cs.as()) { type_constraints.push_back(GetRef(tin)); } else { LOG(FATAL) << new_type_cs; @@ -160,10 +147,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { changed = changed || !new_ret_type.same_as(op->ret_type); if (!changed) return GetRef(op); - return FuncType(new_args, - new_ret_type, - type_params, - type_constraints); + return FuncType(new_args, new_ret_type, type_params, type_constraints); } Type TypeMutator::VisitType_(const TupleTypeNode* op) { @@ -184,16 +168,11 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { if (new_args.same_as(type_rel->args)) { return GetRef(type_rel); } else { - return TypeRelation(type_rel->func, - new_args, - type_rel->num_inputs, - type_rel->attrs); + return TypeRelation(type_rel->func, new_args, type_rel->num_inputs, type_rel->attrs); } } -Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const TypeCallNode* op) { Type new_func = VisitType(op->func); @@ -205,13 +184,9 @@ Type TypeMutator::VisitType_(const TypeCallNode* op) { } } -Type TypeMutator::VisitType_(const TypeDataNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const TypeDataNode* op) { return GetRef(op); } -Type TypeMutator::VisitType_(const PrimTypeNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const PrimTypeNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const PointerTypeNode* op) { Type element_type = VisitType(op->element_type); @@ -226,8 +201,7 @@ Type TypeMutator::VisitType_(const PointerTypeNode* op) { // Implements bind. class TypeBinder : public TypeMutator { public: - explicit TypeBinder(const tvm::Map& args_map) - : args_map_(args_map) {} + explicit TypeBinder(const tvm::Map& args_map) : args_map_(args_map) {} Type VisitType_(const TypeVarNode* op) override { auto id = GetRef(op); diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc index ab479e7..f038a66 100644 --- a/src/ir/type_relation.cc +++ b/src/ir/type_relation.cc @@ -35,22 +35,17 @@ TypeCall::TypeCall(Type func, tvm::Array args) { TVM_REGISTER_NODE_TYPE(TypeCallNode); -TVM_REGISTER_GLOBAL("ir.TypeCall") -.set_body_typed([](Type func, Array type) { +TVM_REGISTER_GLOBAL("ir.TypeCall").set_body_typed([](Type func, Array type) { return TypeCall(func, type); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeCallNode(" << node->func << ", " - << node->args << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; + }); -TypeRelation::TypeRelation(TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs) { +TypeRelation::TypeRelation(TypeRelationFn func, Array args, int num_inputs, Attrs attrs) { ObjectPtr n = make_object(); n->func = std::move(func); n->args = std::move(args); @@ -62,18 +57,13 @@ TypeRelation::TypeRelation(TypeRelationFn func, TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_GLOBAL("ir.TypeRelation") -.set_body_typed([](TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs) { - return TypeRelation(func, args, num_inputs, attrs); -}); + .set_body_typed([](TypeRelationFn func, Array args, int num_inputs, Attrs attrs) { + return TypeRelation(func, args, num_inputs, attrs); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeRelationNode(" - << node->func->name - << ", " << node->args << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeRelationNode(" << node->func->name << ", " << node->args << ")"; + }); } // namespace tvm diff --git a/src/node/container.cc b/src/node/container.cc index 52e4bf1..a5e7669 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -20,10 +20,11 @@ * Expose container API to frontend. * \file src/node/container.cc */ -#include -#include #include +#include +#include #include + #include "../support/str_escape.h" namespace tvm { @@ -32,14 +33,11 @@ namespace tvm { struct StringObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::StringObj* key, - SHashReducer hash_reduce) { - hash_reduce->SHashReduceHashedValue( - runtime::String::HashBytes(key->data, key->size)); + static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size)); } - static bool SEqualReduce(const runtime::StringObj* lhs, - const runtime::StringObj* rhs, + static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs, SEqualReducer equal) { if (lhs == rhs) return true; if (lhs->size != rhs->size) return false; @@ -49,32 +47,29 @@ struct StringObjTrait { }; struct RefToObjectPtr : public ObjectRef { - static ObjectPtr Get(const ObjectRef& ref) { - return GetDataPtr(ref); - } + static ObjectPtr Get(const ObjectRef& ref) { return GetDataPtr(ref); } }; TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) -.set_creator([](const std::string& bytes) { - return RefToObjectPtr::Get(runtime::String(bytes)); -}) -.set_repr_bytes([](const Object* n) -> std::string { - return GetRef( - static_cast(n)).operator std::string(); -}); + .set_creator([](const std::string& bytes) { + return RefToObjectPtr::Get(runtime::String(bytes)); + }) + .set_repr_bytes([](const Object* n) -> std::string { + return GetRef(static_cast(n)) + . + operator std::string(); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; -}); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; + }); struct ADTObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::ADTObj* key, - SHashReducer hash_reduce) { + static void SHashReduce(const runtime::ADTObj* key, SHashReducer hash_reduce) { hash_reduce(key->tag); hash_reduce(static_cast(key->size)); for (uint32_t i = 0; i < key->size; ++i) { @@ -82,8 +77,7 @@ struct ADTObjTrait { } } - static bool SEqualReduce(const runtime::ADTObj* lhs, - const runtime::ADTObj* rhs, + static bool SEqualReduce(const runtime::ADTObj* lhs, const runtime::ADTObj* rhs, SEqualReducer equal) { if (lhs == rhs) return true; if (lhs->tag != rhs->tag) return false; @@ -98,39 +92,31 @@ struct ADTObjTrait { TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); - struct NDArrayContainerTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::NDArray::Container* key, - SHashReducer hash_reduce) { + static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) { CHECK_EQ(key->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - CHECK(runtime::IsContiguous(key->dl_tensor)) - << "Can only hash contiguous tensor"; + CHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; hash_reduce(runtime::DataType(key->dl_tensor.dtype)); hash_reduce(key->dl_tensor.ndim); for (int i = 0; i < key->dl_tensor.ndim; ++i) { hash_reduce(key->dl_tensor.shape[i]); } - hash_reduce->SHashReduceHashedValue( - runtime::String::HashBytes( - static_cast(key->dl_tensor.data), - runtime::GetDataSize(key->dl_tensor))); + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( + static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); } static bool SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, - SEqualReducer equal) { + const runtime::NDArray::Container* rhs, SEqualReducer equal) { if (lhs == rhs) return true; auto ldt = lhs->dl_tensor.dtype; auto rdt = rhs->dl_tensor.dtype; CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - CHECK(runtime::IsContiguous(lhs->dl_tensor)) - << "Can only compare contiguous tensor"; - CHECK(runtime::IsContiguous(rhs->dl_tensor)) - << "Can only compare contiguous tensor"; + CHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; + CHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { @@ -147,21 +133,17 @@ struct NDArrayContainerTrait { TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait); - struct ArrayNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ArrayNode* key, - SHashReducer hash_reduce) { + static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) { hash_reduce(static_cast(key->data.size())); for (size_t i = 0; i < key->data.size(); ++i) { hash_reduce(key->data[i]); } } - static bool SEqualReduce(const ArrayNode* lhs, - const ArrayNode* rhs, - SEqualReducer equal) { + static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) { if (lhs->data.size() != rhs->data.size()) return false; for (size_t i = 0; i < lhs->data.size(); ++i) { if (!equal(lhs->data[i], rhs->data[i])) return false; @@ -172,53 +154,45 @@ struct ArrayNodeTrait { TVM_REGISTER_OBJECT_TYPE(ArrayNode); TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - - -TVM_REGISTER_GLOBAL("node.Array") -.set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector data; - for (int i = 0; i < args.size(); ++i) { - if (args[i].type_code() != kTVMNullptr) { - data.push_back(args[i].operator ObjectRef()); - } else { - data.push_back(ObjectRef(nullptr)); - } + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + +TVM_REGISTER_GLOBAL("node.Array").set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector data; + for (int i = 0; i < args.size(); ++i) { + if (args[i].type_code() != kTVMNullptr) { + data.push_back(args[i].operator ObjectRef()); + } else { + data.push_back(ObjectRef(nullptr)); } - auto node = make_object(); - node->data = std::move(data); - *ret = Array(node); - }); + } + auto node = make_object(); + node->data = std::move(data); + *ret = Array(node); +}); -TVM_REGISTER_GLOBAL("node.ArrayGetItem") -.set_body([](TVMArgs args, TVMRetValue* ret) { - int64_t i = args[1]; - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - CHECK_LT(static_cast(i), n->data.size()) - << "out of bound of array"; - *ret = n->data[static_cast(i)]; - }); - -TVM_REGISTER_GLOBAL("node.ArraySize") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - CHECK(ptr->IsInstance()); - *ret = static_cast( - static_cast(ptr)->data.size()); - }); +TVM_REGISTER_GLOBAL("node.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + int64_t i = args[1]; + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + CHECK_LT(static_cast(i), n->data.size()) << "out of bound of array"; + *ret = n->data[static_cast(i)]; +}); +TVM_REGISTER_GLOBAL("node.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + *ret = static_cast(static_cast(ptr)->data.size()); +}); struct MapNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const MapNode* key, - SHashReducer hash_reduce) { + static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) { // SHash's var handling depends on the determinism of traversal. // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store @@ -233,15 +207,15 @@ struct MapNodeTrait { } } // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { - return lhs.first < rhs.first; - }); + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); // add size to the hash hash_reduce(static_cast(key->data.size())); // hash the content for (size_t i = 0; i < temp.size();) { size_t k = i + 1; - for (; k < temp.size() && temp[k].first == temp[i].first; ++k) {} + for (; k < temp.size() && temp[k].first == temp[i].first; ++k) { + } // ties are rare, but we need to skip them to make the hash determinsitic if (k == i + 1) { hash_reduce->SHashReduceHashedValue(temp[i].first); @@ -251,9 +225,7 @@ struct MapNodeTrait { } } - static bool SEqualReduce(const MapNode* lhs, - const MapNode* rhs, - SEqualReducer equal) { + static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { if (rhs->data.size() != lhs->data.size()) return false; for (const auto& kv : lhs->data) { // Only allow equal checking if the keys are already mapped @@ -272,16 +244,14 @@ struct MapNodeTrait { TVM_REGISTER_OBJECT_TYPE(MapNode); TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); struct StrMapNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const StrMapNode* key, - SHashReducer hash_reduce) { + static void SHashReduce(const StrMapNode* key, SHashReducer hash_reduce) { // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store // Map where Var is defined in the function @@ -289,9 +259,8 @@ struct StrMapNodeTrait { using KV = std::pair; std::vector temp(key->data.begin(), key->data.end()); // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { - return lhs.first < rhs.first; - }); + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); // NOTE: we won't have ties // add size to the hash after sorting. hash_reduce(static_cast(key->data.size())); @@ -302,9 +271,7 @@ struct StrMapNodeTrait { } } - static bool SEqualReduce(const StrMapNode* lhs, - const StrMapNode* rhs, - SEqualReducer equal) { + static bool SEqualReduce(const StrMapNode* lhs, const StrMapNode* rhs, SEqualReducer equal) { if (rhs->data.size() != lhs->data.size()) return false; for (const auto& kv : lhs->data) { auto it = rhs->data.find(kv.first); @@ -317,120 +284,104 @@ struct StrMapNodeTrait { TVM_REGISTER_OBJECT_TYPE(StrMapNode); TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - - -TVM_REGISTER_GLOBAL("node.Map") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args.size() % 2, 0); - if (args.size() != 0 && args[0].type_code() == kTVMStr) { - // StrMap - StrMapNode::ContainerType data; - for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kTVMStr) - << "key of str map need to be str"; - CHECK(args[i + 1].IsObjectRef()) - << "value of the map to be NodeRef"; - data.emplace(std::make_pair(args[i].operator std::string(), - args[i + 1].operator ObjectRef())); - } - auto node = make_object(); - node->data = std::move(data); - *ret = Map(node); - } else { - // Container node. - MapNode::ContainerType data; - for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].IsObjectRef()) - << "key of str map need to be object"; - CHECK(args[i + 1].IsObjectRef()) - << "value of map to be NodeRef"; - data.emplace(std::make_pair(args[i].operator ObjectRef(), - args[i + 1].operator ObjectRef())); - } - auto node = make_object(); - node->data = std::move(data); - *ret = Map(node); + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + +TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size() % 2, 0); + if (args.size() != 0 && args[0].type_code() == kTVMStr) { + // StrMap + StrMapNode::ContainerType data; + for (int i = 0; i < args.num_args; i += 2) { + CHECK(args[i].type_code() == kTVMStr) << "key of str map need to be str"; + CHECK(args[i + 1].IsObjectRef()) << "value of the map to be NodeRef"; + data.emplace( + std::make_pair(args[i].operator std::string(), args[i + 1].operator ObjectRef())); + } + auto node = make_object(); + node->data = std::move(data); + *ret = Map(node); + } else { + // Container node. + MapNode::ContainerType data; + for (int i = 0; i < args.num_args; i += 2) { + CHECK(args[i].IsObjectRef()) << "key of str map need to be object"; + CHECK(args[i + 1].IsObjectRef()) << "value of map to be NodeRef"; + data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef())); } - }); + auto node = make_object(); + node->data = std::move(data); + *ret = Map(node); + } +}); +TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); + *ret = static_cast(n->data.size()); + } else { + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + *ret = static_cast(n->data.size()); + } +}); -TVM_REGISTER_GLOBAL("node.MapSize") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - *ret = static_cast(n->data.size()); - } else { - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - *ret = static_cast(n->data.size()); - } - }); +TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); + auto it = n->data.find(args[1].operator ObjectRef()); + CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; + *ret = (*it).second; + } else { + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + auto it = n->data.find(args[1].operator std::string()); + CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; + *ret = (*it).second; + } +}); -TVM_REGISTER_GLOBAL("node.MapGetItem") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - auto it = n->data.find(args[1].operator ObjectRef()); - CHECK(it != n->data.end()) - << "cannot find the corresponding key in the Map"; - *ret = (*it).second; - } else { - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - auto it = n->data.find(args[1].operator std::string()); - CHECK(it != n->data.end()) - << "cannot find the corresponding key in the Map"; - *ret = (*it).second; - } - }); +TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); -TVM_REGISTER_GLOBAL("node.MapCount") -.set_body([](TVMArgs args, TVMRetValue* ret) { + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); + *ret = static_cast(n->data.count(args[1].operator ObjectRef())); + } else { + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + *ret = static_cast(n->data.count(args[1].operator std::string())); + } +}); - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - *ret = static_cast( - n->data.count(args[1].operator ObjectRef())); - } else { - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - *ret = static_cast( - n->data.count(args[1].operator std::string())); - } - }); +TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); -TVM_REGISTER_GLOBAL("node.MapItems") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - auto rkvs = make_object(); - for (const auto& kv : n->data) { - rkvs->data.push_back(kv.first); - rkvs->data.push_back(kv.second); - } - *ret = Array(rkvs); - } else { - auto* n = static_cast(ptr); - auto rkvs = make_object(); - for (const auto& kv : n->data) { - rkvs->data.push_back(tir::StringImmNode::make(kv.first)); - rkvs->data.push_back(kv.second); - } - *ret = Array(rkvs); + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); + auto rkvs = make_object(); + for (const auto& kv : n->data) { + rkvs->data.push_back(kv.first); + rkvs->data.push_back(kv.second); + } + *ret = Array(rkvs); + } else { + auto* n = static_cast(ptr); + auto rkvs = make_object(); + for (const auto& kv : n->data) { + rkvs->data.push_back(tir::StringImmNode::make(kv.first)); + rkvs->data.push_back(kv.second); } - }); + *ret = Array(rkvs); + } +}); } // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 08a914f..c3397e7 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -21,17 +21,17 @@ * Reflection utilities. * \file node/reflection.cc */ -#include -#include +#include #include +#include #include -#include +#include namespace tvm { -using runtime::TVMRetValue; -using runtime::TVMArgs; using runtime::PackedFunc; +using runtime::TVMArgs; +using runtime::TVMRetValue; // Attr getter. class AttrGetter : public AttrVisitor { @@ -39,9 +39,7 @@ class AttrGetter : public AttrVisitor { const std::string& skey; TVMRetValue* ret; - AttrGetter(const std::string &skey, - TVMRetValue* ret) - : skey(skey), ret(ret) {} + AttrGetter(const std::string& skey, TVMRetValue* ret) : skey(skey), ret(ret) {} bool found_ref_object{false}; @@ -86,8 +84,7 @@ class AttrGetter : public AttrVisitor { } }; -runtime::TVMRetValue ReflectionVTable::GetAttr( - Object* self, const std::string& field_name) const { +runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const std::string& field_name) const { runtime::TVMRetValue ret; AttrGetter getter(field_name, &ret); @@ -110,8 +107,8 @@ runtime::TVMRetValue ReflectionVTable::GetAttr( } } if (!success) { - LOG(FATAL) << "AttributeError: " << self->GetTypeKey() - << " object has no attributed " << getter.skey; + LOG(FATAL) << "AttributeError: " << self->GetTypeKey() << " object has no attributed " + << getter.skey; } return ret; } @@ -121,40 +118,19 @@ class AttrDir : public AttrVisitor { public: std::vector* names; - void Visit(const char* key, double* value) final { - names->push_back(key); - } - void Visit(const char* key, int64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, uint64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, bool* value) final { - names->push_back(key); - } - void Visit(const char* key, int* value) final { - names->push_back(key); - } - void Visit(const char* key, void** value) final { - names->push_back(key); - } - void Visit(const char* key, DataType* value) final { - names->push_back(key); - } - void Visit(const char* key, std::string* value) final { - names->push_back(key); - } - void Visit(const char* key, runtime::NDArray* value) final { - names->push_back(key); - } - void Visit(const char* key, runtime::ObjectRef* value) final { - names->push_back(key); - } + void Visit(const char* key, double* value) final { names->push_back(key); } + void Visit(const char* key, int64_t* value) final { names->push_back(key); } + void Visit(const char* key, uint64_t* value) final { names->push_back(key); } + void Visit(const char* key, bool* value) final { names->push_back(key); } + void Visit(const char* key, int* value) final { names->push_back(key); } + void Visit(const char* key, void** value) final { names->push_back(key); } + void Visit(const char* key, DataType* value) final { names->push_back(key); } + void Visit(const char* key, std::string* value) final { names->push_back(key); } + void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); } + void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); } }; -std::vector -ReflectionVTable::ListAttrNames(Object* self) const { +std::vector ReflectionVTable::ListAttrNames(Object* self) const { std::vector names; AttrDir dir; dir.names = &names; @@ -176,13 +152,11 @@ ReflectionVTable* ReflectionVTable::Global() { return &inst; } -ObjectPtr -ReflectionVTable::CreateInitObject(const std::string& type_key, - const std::string& repr_bytes) const { +ObjectPtr ReflectionVTable::CreateInitObject(const std::string& type_key, + const std::string& repr_bytes) const { uint32_t tindex = Object::TypeKey2Index(type_key); if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) { - LOG(FATAL) << "TypeError: " << type_key - << " is not registered via TVM_REGISTER_NODE_TYPE"; + LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE"; } return fcreate_[tindex](repr_bytes); } @@ -192,30 +166,16 @@ class NodeAttrSetter : public AttrVisitor { std::string type_key; std::unordered_map attrs; - void Visit(const char* key, double* value) final { - *value = GetAttr(key).operator double(); - } - void Visit(const char* key, int64_t* value) final { - *value = GetAttr(key).operator int64_t(); - } - void Visit(const char* key, uint64_t* value) final { - *value = GetAttr(key).operator uint64_t(); - } - void Visit(const char* key, int* value) final { - *value = GetAttr(key).operator int(); - } - void Visit(const char* key, bool* value) final { - *value = GetAttr(key).operator bool(); - } + void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); } + void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); } + void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); } + void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); } + void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); } void Visit(const char* key, std::string* value) final { *value = GetAttr(key).operator std::string(); } - void Visit(const char* key, void** value) final { - *value = GetAttr(key).operator void*(); - } - void Visit(const char* key, DataType* value) final { - *value = GetAttr(key).operator DataType(); - } + void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); } + void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); } void Visit(const char* key, runtime::NDArray* value) final { *value = GetAttr(key).operator runtime::NDArray(); } @@ -240,8 +200,7 @@ void InitNodeByPackedArgs(Object* n, const TVMArgs& args) { setter.type_key = n->GetTypeKey(); CHECK_EQ(args.size() % 2, 0); for (int i = 0; i < args.size(); i += 2) { - setter.attrs.emplace(args[i].operator std::string(), - args[i + 1]); + setter.attrs.emplace(args[i].operator std::string(), args[i + 1]); } auto* reflection = ReflectionVTable::Global(); reflection->VisitAttrs(n, &setter); @@ -249,7 +208,7 @@ void InitNodeByPackedArgs(Object* n, const TVMArgs& args) { if (setter.attrs.size() != 0) { std::ostringstream os; os << setter.type_key << " does not contain field "; - for (const auto &kv : setter.attrs) { + for (const auto& kv : setter.attrs) { os << " " << kv.first; } LOG(FATAL) << os.str(); @@ -267,17 +226,17 @@ void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kTVMObjectHandle); Object* self = static_cast(args[0].value().v_handle); - auto names = std::make_shared >( - ReflectionVTable::Global()->ListAttrNames(self)); + auto names = + std::make_shared >(ReflectionVTable::Global()->ListAttrNames(self)); - *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) { - int64_t i = args[0]; - if (i == -1) { - *rv = static_cast(names->size()); - } else { - *rv = (*names)[i]; - } - }); + *ret = PackedFunc([names](TVMArgs args, TVMRetValue* rv) { + int64_t i = args[0]; + if (i == -1) { + *rv = static_cast(names->size()); + } else { + *rv = (*names)[i]; + } + }); } // API function to make node. @@ -297,13 +256,9 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) { *rv = ObjectRef(n); } +TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr); -TVM_REGISTER_GLOBAL("node.NodeGetAttr") -.set_body(NodeGetAttr); - -TVM_REGISTER_GLOBAL("node.NodeListAttrNames") -.set_body(NodeListAttrNames); +TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames); -TVM_REGISTER_GLOBAL("node.MakeNode") -.set_body(MakeNode); +TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode); } // namespace tvm diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index bf41c82..ea26343 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -21,8 +21,8 @@ * Printer utilities * \file node/repr_printer.cc */ -#include #include +#include namespace tvm { @@ -51,16 +51,11 @@ ReprPrinter::FType& ReprPrinter::vtable() { return inst; } -void Dump(const runtime::ObjectRef& n) { - std::cerr << n << "\n"; -} +void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } -void Dump(const runtime::Object* n) { - Dump(runtime::GetRef(n)); -} +void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } -TVM_REGISTER_GLOBAL("node.AsRepr") -.set_body_typed([](runtime::ObjectRef obj) { +TVM_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](runtime::ObjectRef obj) { std::ostringstream os; os << obj; return os.str(); diff --git a/src/node/serialization.cc b/src/node/serialization.cc index ee6072d..4675c53 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -23,29 +23,25 @@ */ #include #include -#include -#include -#include +#include #include #include #include -#include +#include +#include +#include -#include #include #include +#include #include "../support/base64.h" namespace tvm { -inline std::string Type2String(const DataType& t) { - return runtime::DLDataType2String(t); -} +inline std::string Type2String(const DataType& t) { return runtime::DLDataType2String(t); } -inline DataType String2Type(std::string s) { - return DataType(runtime::String2DLDataType(s)); -} +inline DataType String2Type(std::string s) { return DataType(runtime::String2DLDataType(s)); } inline std::string Base64Decode(std::string s) { dmlc::MemoryStringStream mstrm(&s); @@ -148,7 +144,7 @@ struct JSONNode { /*! \brief values of a map or array. */ std::vector data; - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("type_key", type_key); if (repr_bytes.size() != 0) { @@ -173,7 +169,7 @@ struct JSONNode { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { attrs.clear(); data.clear(); repr_bytes.clear(); @@ -213,36 +209,23 @@ class JSONAttrGetter : public AttrVisitor { s << (*value); node_->attrs[key] = s.str(); } - void Visit(const char* key, int64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, uint64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, int* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, bool* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, std::string* value) final { - node_->attrs[key] = *value; - } + void Visit(const char* key, int64_t* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, uint64_t* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, int* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, bool* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, std::string* value) final { node_->attrs[key] = *value; } void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to serialize a pointer"; } - void Visit(const char* key, DataType* value) final { - node_->attrs[key] = Type2String(*value); - } + void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); } void Visit(const char* key, runtime::NDArray* value) final { - node_->attrs[key] = std::to_string( - tensor_index_->at(const_cast((*value).operator->()))); + node_->attrs[key] = + std::to_string(tensor_index_->at(const_cast((*value).operator->()))); } void Visit(const char* key, ObjectRef* value) final { - node_->attrs[key] = std::to_string( - node_index_->at(const_cast(value->get()))); + node_->attrs[key] = std::to_string(node_index_->at(const_cast(value->get()))); } // Get the node @@ -262,23 +245,19 @@ class JSONAttrGetter : public AttrVisitor { if (node->IsInstance()) { ArrayNode* n = static_cast(node); for (size_t i = 0; i < n->data.size(); ++i) { - node_->data.push_back( - node_index_->at(const_cast(n->data[i].get()))); + node_->data.push_back(node_index_->at(const_cast(n->data[i].get()))); } } else if (node->IsInstance()) { MapNode* n = static_cast(node); for (const auto& kv : n->data) { - node_->data.push_back( - node_index_->at(const_cast(kv.first.get()))); - node_->data.push_back( - node_index_->at(const_cast(kv.second.get()))); + node_->data.push_back(node_index_->at(const_cast(kv.first.get()))); + node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); } } else if (node->IsInstance()) { StrMapNode* n = static_cast(node); for (const auto& kv : n->data) { node_->keys.push_back(kv.first); - node_->data.push_back( - node_index_->at(const_cast(kv.second.get()))); + node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); } } else { // recursively index normal object. @@ -304,7 +283,7 @@ class JSONAttrSetter : public AttrVisitor { } return it->second; } - template + template void ParseValue(const char* key, T* value) const { std::istringstream is(GetValue(key)); is >> *value; @@ -312,24 +291,12 @@ class JSONAttrSetter : public AttrVisitor { LOG(FATAL) << "Wrong value format for field " << key; } } - void Visit(const char* key, double* value) final { - ParseValue(key, value); - } - void Visit(const char* key, int64_t* value) final { - ParseValue(key, value); - } - void Visit(const char* key, uint64_t* value) final { - ParseValue(key, value); - } - void Visit(const char* key, int* value) final { - ParseValue(key, value); - } - void Visit(const char* key, bool* value) final { - ParseValue(key, value); - } - void Visit(const char* key, std::string* value) final { - *value = GetValue(key); - } + void Visit(const char* key, double* value) final { ParseValue(key, value); } + void Visit(const char* key, int64_t* value) final { ParseValue(key, value); } + void Visit(const char* key, uint64_t* value) final { ParseValue(key, value); } + void Visit(const char* key, int* value) final { ParseValue(key, value); } + void Visit(const char* key, bool* value) final { ParseValue(key, value); } + void Visit(const char* key, std::string* value) final { *value = GetValue(key); } void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to deserialize a pointer"; } @@ -363,15 +330,14 @@ class JSONAttrSetter : public AttrVisitor { MapNode* n = static_cast(node); CHECK_EQ(node_->data.size() % 2, 0U); for (size_t i = 0; i < node_->data.size(); i += 2) { - n->data[ObjectRef(node_list_->at(node_->data[i]))] - = ObjectRef(node_list_->at(node_->data[i + 1])); + n->data[ObjectRef(node_list_->at(node_->data[i]))] = + ObjectRef(node_list_->at(node_->data[i + 1])); } } else if (node->IsInstance()) { StrMapNode* n = static_cast(node); CHECK_EQ(node_->data.size(), node_->keys.size()); for (size_t i = 0; i < node_->data.size(); ++i) { - n->data[node_->keys[i]] - = ObjectRef(node_list_->at(node_->data[i])); + n->data[node_->keys[i]] = ObjectRef(node_list_->at(node_->data[i])); } } else { reflection_->VisitAttrs(node, this); @@ -390,7 +356,7 @@ struct JSONGraph { // global attributes AttrMap attrs; - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("root", root); writer->WriteObjectKeyValue("nodes", nodes); @@ -401,7 +367,7 @@ struct JSONGraph { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { attrs.clear(); dmlc::JSONObjectReadHelper helper; helper.DeclareField("root", &root); @@ -471,8 +437,7 @@ ObjectRef LoadJSON(std::string json_str) { for (const JSONNode& jnode : jgraph.nodes) { if (jnode.type_key.length() != 0) { - ObjectPtr node = - reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); + ObjectPtr node = reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); nodes.emplace_back(node); } else { nodes.emplace_back(ObjectPtr()); @@ -488,8 +453,7 @@ ObjectRef LoadJSON(std::string json_str) { // Skip the nodes that has an repr bytes representation. // NOTE: the second condition is used to guard the case // where the repr bytes itself is an empty string "". - if (setter.node_->repr_bytes.length() == 0 && - nodes[i] != nullptr && + if (setter.node_->repr_bytes.length() == 0 && nodes[i] != nullptr && !reflection->GetReprBytes(nodes[i].get(), nullptr)) { setter.Set(nodes[i].get()); } @@ -497,9 +461,7 @@ ObjectRef LoadJSON(std::string json_str) { return ObjectRef(nodes.at(jgraph.root)); } -TVM_REGISTER_GLOBAL("node.SaveJSON") -.set_body_typed(SaveJSON); +TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON); -TVM_REGISTER_GLOBAL("node.LoadJSON") -.set_body_typed(LoadJSON); +TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON); } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 03cdf9c..b135315 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -19,10 +19,10 @@ /*! * \file src/node/structural_equal.cc */ -#include -#include #include #include +#include +#include #include #include @@ -30,13 +30,13 @@ namespace tvm { // Define the dispatch functio here since primary user is in this file. -bool ReflectionVTable:: -SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const { +bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, + SEqualReducer equal) const { uint32_t tindex = self->type_index(); if (tindex >= fsequal_reduce_.size() || fsequal_reduce_[tindex] == nullptr) { LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey() - << " is not registered via TVM_REGISTER_NODE_TYPE." - << " Did you forget to set _type_has_method_sequal_reduce=true?"; + << " is not registered via TVM_REGISTER_NODE_TYPE." + << " Did you forget to set _type_has_method_sequal_reduce=true?"; } return fsequal_reduce_[tindex](self, other, equal); } @@ -50,11 +50,9 @@ SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const * The order of SEqual being called is the same as the order as if we * eagerly do recursive calls in SEqualReduce. */ -class RemapVarSEqualHandler : - public SEqualReducer::Handler { +class RemapVarSEqualHandler : public SEqualReducer::Handler { public: - explicit RemapVarSEqualHandler(bool assert_mode) - : assert_mode_(assert_mode) {} + explicit RemapVarSEqualHandler(bool assert_mode) : assert_mode_(assert_mode) {} bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final { // We cannot use check lhs.same_as(rhs) to check equality. @@ -121,9 +119,8 @@ class RemapVarSEqualHandler : // Check the result. bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { if (assert_mode_ && !result) { - LOG(FATAL) - << "ValueError: StructuralEqual check failed, caused by\n" - << "lhs = " << lhs << "\nrhs = " << rhs; + LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n" + << "lhs = " << lhs << "\nrhs = " << rhs; } return result; } @@ -177,9 +174,7 @@ class RemapVarSEqualHandler : // The default equal as registered in the structural equal vtable. bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { auto compute = [=]() { - CHECK(lhs.defined() && - rhs.defined() && - lhs->type_index() == rhs->type_index()); + CHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index()); // skip entries that already have equality maps. auto it = equal_map_lhs_.find(lhs); if (it != equal_map_lhs_.end()) { @@ -227,15 +222,12 @@ class RemapVarSEqualHandler : }; TVM_REGISTER_GLOBAL("node.StructuralEqual") -.set_body_typed([](const ObjectRef& lhs, - const ObjectRef& rhs, - bool assert_mode, - bool map_free_vars) { - return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars); -}); - -bool StructuralEqual::operator()(const ObjectRef& lhs, - const ObjectRef& rhs) const { + .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode, + bool map_free_vars) { + return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars); + }); + +bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { return RemapVarSEqualHandler(false).Equal(lhs, rhs, false); } diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index a29340c..91a2524 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -19,25 +19,23 @@ /*! * \file src/node/structural_hash.cc */ -#include -#include #include #include +#include +#include #include -#include #include - +#include namespace tvm { // Define the dispatch functio here since primary user is in this file. -void ReflectionVTable:: -SHashReduce(const Object* self, SHashReducer reducer) const { +void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) const { uint32_t tindex = self->type_index(); if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) { LOG(FATAL) << "TypeError: SHashReduce of " << self->GetTypeKey() - << " is not registered via TVM_REGISTER_NODE_TYPE"; + << " is not registered via TVM_REGISTER_NODE_TYPE"; } fshash_reduce_[tindex](self, reducer); } @@ -49,8 +47,7 @@ SHashReduce(const Object* self, SHashReducer reducer) const { // In particular, when we traverse unordered_map, we should first sort // the entries by keys(or hash of keys) before traversing. -class VarCountingSHashHandler : - public SHashReducer::Handler { +class VarCountingSHashHandler : public SHashReducer::Handler { public: /*! \brief Pending reduce tasks. */ struct Task { @@ -76,7 +73,6 @@ class VarCountingSHashHandler : : object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {} }; - VarCountingSHashHandler() {} void MarkGraphNode() final { @@ -95,8 +91,7 @@ class VarCountingSHashHandler : } void SHashReduceHashedValue(size_t hashed_value) final { - pending_tasks_.emplace_back( - Task(ObjectRef(nullptr), hashed_value, false)); + pending_tasks_.emplace_back(Task(ObjectRef(nullptr), hashed_value, false)); } void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) final { @@ -104,13 +99,11 @@ class VarCountingSHashHandler : if (map_free_vars) { // use counter value. size_t value = std::hash()(free_var_counter_++); - pending_tasks_.emplace_back( - Task(ObjectRef(nullptr), value, false)); + pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); } else { // use pointer hash size_t value = std::hash()(var); - pending_tasks_.emplace_back( - Task(ObjectRef(nullptr), value, false)); + pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); } } @@ -124,12 +117,10 @@ class VarCountingSHashHandler : } auto it = hash_memo_.find(object); if (it != hash_memo_.end()) { - pending_tasks_.emplace_back( - Task(ObjectRef(nullptr), it->second, false)); + pending_tasks_.emplace_back(Task(ObjectRef(nullptr), it->second, false)); } else { // Push a pending task with initial value. - pending_tasks_.emplace_back( - Task(object, object->GetTypeKeyHash(), map_free_vars)); + pending_tasks_.emplace_back(Task(object, object->GetTypeKeyHash(), map_free_vars)); } } @@ -195,9 +186,8 @@ class VarCountingSHashHandler : // Append the graph node counter to the hash // so that we can distinguish DAG from trees. if (entry.graph_node_hash) { - entry.reduced_hash = HashCombine( - entry.reduced_hash, - std::hash()(graph_node_counter_++)); + entry.reduced_hash = + HashCombine(entry.reduced_hash, std::hash()(graph_node_counter_++)); } hash_memo_[entry.object] = entry.reduced_hash; } @@ -268,13 +258,11 @@ class VarCountingSHashHandler : std::unordered_map hash_memo_; }; - TVM_REGISTER_GLOBAL("node.StructuralHash") -.set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { - size_t hashed_value = - VarCountingSHashHandler().Hash(object, map_free_vars); - return static_cast(hashed_value); -}); + .set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { + size_t hashed_value = VarCountingSHashHandler().Hash(object, map_free_vars); + return static_cast(hashed_value); + }); size_t StructuralHash::operator()(const ObjectRef& object) const { return VarCountingSHashHandler().Hash(object, false); diff --git a/src/printer/doc.cc b/src/printer/doc.cc index ee260f4..d487e3e 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -23,10 +23,12 @@ * * Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98 */ +#include "doc.h" + #include -#include + #include -#include "doc.h" +#include namespace tvm { @@ -38,9 +40,7 @@ class DocTextNode : public DocAtomNode { /*! \brief The str content in the text. */ std::string str; - explicit DocTextNode(std::string str_val) - : str(str_val) { - } + explicit DocTextNode(std::string str_val) : str(str_val) {} static constexpr const char* _type_key = "printer.DocText"; TVM_DECLARE_FINAL_OBJECT_INFO(DocTextNode, DocAtomNode); @@ -68,8 +68,7 @@ class DocLineNode : public DocAtomNode { /*! \brief The amount of indent in newline. */ int indent; - explicit DocLineNode(int indent) - : indent(indent) {} + explicit DocLineNode(int indent) : indent(indent) {} static constexpr const char* _type_key = "printer.DocLine"; TVM_DECLARE_FINAL_OBJECT_INFO(DocLineNode, DocAtomNode); @@ -79,9 +78,7 @@ TVM_REGISTER_OBJECT_TYPE(DocLineNode); class DocLine : public DocAtom { public: - explicit DocLine(int indent) { - data_ = runtime::make_object(indent); - } + explicit DocLine(int indent) { data_ = runtime::make_object(indent); } TVM_DEFINE_OBJECT_REF_METHODS(DocLine, DocAtom, DocLineNode); }; @@ -89,14 +86,11 @@ class DocLine : public DocAtom { // DSL function implementations Doc& Doc::operator<<(const Doc& right) { CHECK(this != &right); - this->stream_.insert( - this->stream_.end(), right.stream_.begin(), right.stream_.end()); + this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end()); return *this; } -Doc& Doc::operator<<(std::string right) { - return *this << DocText(right); -} +Doc& Doc::operator<<(std::string right) { return *this << DocText(right); } Doc& Doc::operator<<(const DocAtom& right) { this->stream_.push_back(right); @@ -117,13 +111,9 @@ std::string Doc::str() { return os.str(); } -Doc Doc::NewLine(int indent) { - return Doc() << DocLine(indent); -} +Doc Doc::NewLine(int indent) { return Doc() << DocLine(indent); } -Doc Doc::Text(std::string text) { - return Doc() << DocText(text); -} +Doc Doc::Text(std::string text) { return Doc() << DocText(text); } Doc Doc::RawText(std::string text) { return Doc() << DocAtom(runtime::make_object(text)); @@ -152,10 +142,7 @@ Doc Doc::PyBoolLiteral(bool value) { } } -Doc Doc::Brace(std::string open, - const Doc& body, - std::string close, - int indent) { +Doc Doc::Brace(std::string open, const Doc& body, std::string close, int indent) { Doc doc; doc << open; doc << Indent(indent, NewLine() << body) << NewLine(); diff --git a/src/printer/doc.h b/src/printer/doc.h index 7d8d72e..dc6ba89 100644 --- a/src/printer/doc.h +++ b/src/printer/doc.h @@ -26,12 +26,13 @@ #ifndef TVM_PRINTER_DOC_H_ #define TVM_PRINTER_DOC_H_ +#include #include #include -#include + #include -#include #include +#include namespace tvm { @@ -48,7 +49,7 @@ class DocAtomNode : public Object { /*! * \brief Managed reference to DocAtomNode. * \sa DocAtomNode. -*/ + */ class DocAtom : public ObjectRef { public: TVM_DEFINE_OBJECT_REF_METHODS(DocAtom, ObjectRef, DocAtomNode); @@ -93,8 +94,7 @@ class Doc { * \tparam T the type of the value. * \return reference to self. */ - template::value>::type> + template ::value>::type> Doc& operator<<(const T& value) { std::ostringstream os; os << value; @@ -149,10 +149,7 @@ class Doc { * \param indent amount of indentation. * \return The created doc. */ - static Doc Brace(std::string open, - const Doc& body, - std::string close, - int indent = 2); + static Doc Brace(std::string open, const Doc& body, std::string close, int indent = 2); /*! * \brief Create a doc by concatenating together with separator. * \param vec The docs to be concatenated. diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h index 8bf58ec..ebc76dc 100644 --- a/src/printer/meta_data.h +++ b/src/printer/meta_data.h @@ -24,10 +24,12 @@ #ifndef TVM_PRINTER_META_DATA_H_ #define TVM_PRINTER_META_DATA_H_ -#include #include +#include + #include #include + #include "doc.h" namespace tvm { @@ -98,8 +100,7 @@ class TextMetaDataContext { } std::string type_key = node->GetTypeKey(); CHECK(!type_key.empty()); - Array& mvector = - meta_data_[type_key]; + Array& mvector = meta_data_[type_key]; int64_t index = static_cast(mvector.size()); mvector.push_back(node); Doc doc; @@ -113,9 +114,7 @@ class TextMetaDataContext { * \param node The query node * \return whether the node has been put in meta */ - bool InMeta(const ObjectRef& node) { - return meta_repr_.find(node) != meta_repr_.end(); - } + bool InMeta(const ObjectRef& node) { return meta_repr_.find(node) != meta_repr_.end(); } /*! * \brief Print a key value pair @@ -135,9 +134,7 @@ class TextMetaDataContext { } /*! \return whether the meta data context is empty. */ - bool empty() const { - return meta_data_.empty(); - } + bool empty() const { return meta_data_.empty(); } private: /*! \brief additional metadata stored in TVM json format */ diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 9e6abee..3c545ef 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -32,24 +32,25 @@ * - Var * - Otherwise, inline if the node is at the end of a scope and is used at most once. */ -#include #include -#include +#include #include #include +#include + +#include "../ir/attr_functor.h" +#include "../relay/analysis/dependency_graph.h" #include "doc.h" #include "meta_data.h" -#include "../relay/analysis/dependency_graph.h" -#include "../ir/attr_functor.h" #include "text_printer.h" namespace tvm { namespace relay { /*! - * \brief Print additional info about expr in comment. - * \param expr The expression. - */ + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { Doc doc; // default annotations @@ -90,8 +91,7 @@ Doc RelayTextPrinter::PrintScope(const ObjectRef& node) { } Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) { - if (node->IsInstance() && - !node->IsInstance()) { + if (node->IsInstance() && !node->IsInstance()) { // Temporarily skip non-relay functions. // TODO(tvm-team) enhance the code to work for all functions } else if (node.as()) { @@ -106,8 +106,7 @@ Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) { Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { bool is_non_relay_func = - node->IsInstance() && - !node->IsInstance(); + node->IsInstance() && !node->IsInstance(); if (node.as() && !is_non_relay_func) { return PrintExpr(Downcast(node), meta, try_inline); } else if (node.as()) { @@ -129,15 +128,13 @@ Doc RelayTextPrinter::TempVar(int n) { return doc << "%" << n; } -Doc RelayTextPrinter::AllocTemp() { - return TempVar(temp_var_counter_++); -} +Doc RelayTextPrinter::AllocTemp() { return TempVar(temp_var_counter_++); } /*! - * \brief get a unique name with the corresponding prefix - * \param prefix The prefix of the name - * \return The returned name. - */ + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) { std::string unique_prefix = prefix; auto it = name_alloc_map_.find(prefix); @@ -158,21 +155,21 @@ Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) { Doc RelayTextPrinter::Print(Kind k) { switch (k) { - case kType: - return Doc::Text("Type"); - case kShapeVar: - return Doc::Text("Shape"); - case kBaseType: - return Doc::Text("BaseType"); - case kConstraint: - return Doc::Text("Constraint"); - case kAdtHandle: - return Doc::Text("AdtHandle"); - case kTypeData: - return Doc::Text("TypeData"); - default: - LOG(ERROR) << "Unknown Kind"; - throw; + case kType: + return Doc::Text("Type"); + case kShapeVar: + return Doc::Text("Shape"); + case kBaseType: + return Doc::Text("BaseType"); + case kConstraint: + return Doc::Text("Constraint"); + case kAdtHandle: + return Doc::Text("AdtHandle"); + case kTypeData: + return Doc::Text("TypeData"); + default: + LOG(ERROR) << "Unknown Kind"; + throw; } } /*! @@ -290,16 +287,14 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { // Should only be triggered when op is a free variable being visited for the // first time. -Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { - return AllocVar(GetRef(op)); -} +Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef(op)); } /*! * \brief special method to print out const scalar * \param dtype The data type * \param value The value to be printed. */ -template +template Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) { std::ostringstream os; if (dtype == DataType::Int(32)) { @@ -369,13 +364,8 @@ Doc RelayTextPrinter::VisitExpr_(const IfNode* op) { Doc RelayTextPrinter::VisitExpr_(const LetNode* op) { Doc doc; - doc - << "let " - << AllocVar(op->var) - << " = " - << Print(op->value, false, true) - << ";" - << Doc::NewLine(); + doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << ";" + << Doc::NewLine(); // we use a scope here so GNF hoisting doesn't escape too far // and nested, unique lets are not hoisted doc << PrintScope(op->body); @@ -420,7 +410,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) { } else { // def @xyz = meta['ExternalFunc'][id] Doc doc; - doc << prefix << " = " << meta_->GetMetaNode(base_func); + doc << prefix << " = " << meta_->GetMetaNode(base_func); return doc; } } @@ -456,13 +446,9 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { return PrintFunc(Doc::Text("fn "), GetRef(op)); } -Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { - return Doc::Text('@' + op->name_hint); -} +Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text('@' + op->name_hint); } -Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { - return Doc::Text(op->name); -} +Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); } Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { Doc doc; @@ -569,13 +555,9 @@ Doc RelayTextPrinter::VisitPattern_(const PatternTupleNode* pt) { return doc; } -Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) { - return Doc::Text("_"); -} +Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) { return Doc::Text("_"); } -Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) { - return AllocVar(pv->var); -} +Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) { return AllocVar(pv->var); } Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) { Doc doc; @@ -612,9 +594,7 @@ Doc RelayTextPrinter::VisitTypeDefault_(const Object* node) { return Print(GetRef(node), true); } -Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) { - return Doc::Text(node->name_hint); -} +Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) { return Doc::Text(node->name_hint); } Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) { return Doc::Text(node->name_hint); @@ -770,17 +750,14 @@ Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { return Doc::StrLiteral(op->value); } - /*! * \brief Attribute printer which prints the attributes in the call. */ -class RelayTextPrinter::AttrPrinter : - public AttrVisitor { +class RelayTextPrinter::AttrPrinter : public AttrVisitor { public: - AttrPrinter(std::vector* doc, RelayTextPrinter* parent) - : docs(doc), parent_(parent) {} + AttrPrinter(std::vector* doc, RelayTextPrinter* parent) : docs(doc), parent_(parent) {} - template + template void PrintKV(const char* key, const T& value) { Doc doc; doc << key << "=" << value; @@ -792,24 +769,12 @@ class RelayTextPrinter::AttrPrinter : doc << key << "=" << *value << "f"; docs->push_back(doc); } - void Visit(const char* key, int64_t* value) final { - PrintKV(key, *value); - } - void Visit(const char* key, uint64_t* value) final { - PrintKV(key, *value); - } - void Visit(const char* key, int* value) final { - PrintKV(key, *value); - } - void Visit(const char* key, bool* value) final { - PrintKV(key, Doc::PyBoolLiteral(*value)); - } - void Visit(const char* key, std::string* value) final { - PrintKV(key, Doc::StrLiteral(*value)); - } - void Visit(const char* key, void** value) final { - LOG(FATAL) << "do not allow void as argument"; - } + void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, int* value) final { PrintKV(key, *value); } + void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); } + void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); } + void Visit(const char* key, void** value) final { LOG(FATAL) << "do not allow void as argument"; } void Visit(const char* key, DataType* value) final { PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); } @@ -825,8 +790,7 @@ class RelayTextPrinter::AttrPrinter : RelayTextPrinter* parent_; }; -std::vector RelayTextPrinter::PrintCallAttrs( - const Attrs& attrs, const Expr& op) { +std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { std::vector docs; if (!attrs.defined()) return docs; const auto* op_node = op.as(); diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 592aabe..2993d38 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -23,9 +23,11 @@ * that can be parsed by a parser. */ +#include "text_printer.h" + #include + #include -#include "text_printer.h" namespace tvm { @@ -79,26 +81,21 @@ String PrettyPrint(const ObjectRef& node) { return doc.str(); } -String AsText(const ObjectRef& node, - bool show_meta_data, +String AsText(const ObjectRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { Doc doc; doc << kSemVer << Doc::NewLine(); runtime::TypedPackedFunc ftyped = nullptr; if (annotate != nullptr) { ftyped = runtime::TypedPackedFunc( - [&annotate](const ObjectRef& expr) -> std::string { - return annotate(expr); - }); + [&annotate](const ObjectRef& expr) -> std::string { return annotate(expr); }); } doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node); return doc.str(); } -TVM_REGISTER_GLOBAL("ir.PrettyPrint") -.set_body_typed(PrettyPrint); +TVM_REGISTER_GLOBAL("ir.PrettyPrint").set_body_typed(PrettyPrint); -TVM_REGISTER_GLOBAL("ir.AsText") -.set_body_typed(AsText); +TVM_REGISTER_GLOBAL("ir.AsText").set_body_typed(AsText); } // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 63767af..00b6fb9 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -26,22 +26,21 @@ #ifndef TVM_PRINTER_TEXT_PRINTER_H_ #define TVM_PRINTER_TEXT_PRINTER_H_ +#include +#include #include #include #include -#include -#include -#include -#include #include -#include -#include +#include +#include + +#include #include #include -#include -#include "../relay/analysis/dependency_graph.h" -#include "../ir/attr_functor.h" +#include "../ir/attr_functor.h" +#include "../relay/analysis/dependency_graph.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -53,21 +52,19 @@ class TextPrinter; namespace tvm { namespace relay { -class RelayTextPrinter : - public ExprFunctor, - public PatternFunctor, - public TypeFunctor, - public AttrFunctor { +class RelayTextPrinter : public ExprFunctor, + public PatternFunctor, + public TypeFunctor, + public AttrFunctor { public: - explicit RelayTextPrinter(bool show_meta_data, - TextMetaDataContext* meta, + explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta, runtime::TypedPackedFunc annotate) : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {} /*! - * \brief Print additional info about expr in comment. - * \param expr The expression. - */ + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ Doc PrintOptionalInfo(const Expr& expr); // indent a new body Doc PrintBody(const ObjectRef& node, int indent = 2); @@ -83,10 +80,10 @@ class RelayTextPrinter : Doc TempVar(int n); Doc AllocTemp(); /*! - * \brief get a unique name with the corresponding prefix - * \param prefix The prefix of the name - * \return The returned name. - */ + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ Doc GetUniqueName(const std::string& prefix); Doc Print(Kind k); /*! @@ -213,8 +210,8 @@ class MetaCollector : public StmtExprVisitor { void Collect(const ObjectRef& n) { // these nodes can be print directly(StringLiteral or use identifier to identify) - if (!n.defined() || n.as() || n.as() || n.as() - || n.as() || n.as() || n.as()) { + if (!n.defined() || n.as() || n.as() || n.as() || + n.as() || n.as() || n.as()) { return; } if (n->IsInstance()) { @@ -243,7 +240,7 @@ class TIRTextPrinter : public StmtFunctor, public TypeFunctor { public: explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta) - : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {} + : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {} /*! \brief Print the node */ Doc Print(const ObjectRef& node); @@ -323,9 +320,7 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintIterVar(const IterVarNode* op); Doc PrintRange(const RangeNode* op); Doc PrintBuffer(const BufferNode* op); - Doc PrintString(const StringObj* op) { - return Doc::StrLiteral(op->data); - } + Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } /*! * \brief special method to print out data type @@ -360,7 +355,8 @@ class TextPrinter { public: explicit TextPrinter(bool show_meta_data, const runtime::TypedPackedFunc& annotate) - : show_meta_data_(show_meta_data), annotate_(annotate), + : show_meta_data_(show_meta_data), + annotate_(annotate), relay_text_printer_(show_meta_data, &meta_, annotate), tir_text_printer_(show_meta_data, &meta_) {} @@ -379,8 +375,8 @@ class TextPrinter { Doc doc; if (node->IsInstance()) { doc << PrintMod(Downcast(node)); - } else if (node->IsInstance() || node->IsInstance() - || node->IsInstance()) { + } else if (node->IsInstance() || node->IsInstance() || + node->IsInstance()) { doc << tir_text_printer_.Print(node); } else { doc << relay_text_printer_.PrintFinal(node); diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index a5754d7..511a243 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -86,7 +86,8 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) { } // print PrimFunc Doc doc; - doc << "primfn" << "("; + doc << "primfn" + << "("; // print params and its type annotation std::vector params; for (const auto& param : op->params) { @@ -109,10 +110,9 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) { std::vector buffer_docs; for (const auto& it : memo_buf_) { const auto& buf = it.first; - buffer_docs.push_back(Print(buf) - << Doc::Text(": Buffer(") << Print(buf->data) << ", " - << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " - << Print(buf->strides)); + buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << Print(buf->data) << ", " + << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " + << Print(buf->strides)); if (!is_zero(buf->elem_offset)) { buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset); } @@ -138,8 +138,8 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) { for (const auto& it : op->buffer_map) { buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second)); } - doc << Doc::Indent(2, Doc::NewLine() - << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); + doc << Doc::Indent( + 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); doc << PrintBody(op->body); return doc; } @@ -226,12 +226,12 @@ Doc TIRTextPrinter::VisitExpr_(const VarNode* op) { return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef(op)); } -#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \ - Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \ - Doc doc; \ - doc << "(" << Print(op->a) << OpString; \ - doc << Print(op->b) << ")"; \ - return doc; \ +#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \ + Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \ + Doc doc; \ + doc << "(" << Print(op->a) << OpString; \ + doc << Print(op->b) << ")"; \ + return doc; \ } TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ") @@ -293,8 +293,8 @@ Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) { Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) { Doc doc; - doc << "(" << PrintDType(op->dtype) << "*)" - << Print(op->buffer_var) << "[" << Print(op->index) << "])"; + doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index) + << "])"; if (!is_one(op->predicate)) { doc << " if " << Print(op->predicate); } @@ -321,12 +321,18 @@ Doc TIRTextPrinter::VisitExpr_(const LetNode* op) { inline const char* CallType2String(CallNode::CallType t) { switch (t) { - case CallNode::Extern:return "extern"; - case CallNode::ExternCPlusPlus:return "extern_cpp"; - case CallNode::PureExtern:return "pure_extern"; - case CallNode::Halide:return "halide"; - case CallNode::Intrinsic:return "intrin"; - case CallNode::PureIntrinsic:return "pure_intrin"; + case CallNode::Extern: + return "extern"; + case CallNode::ExternCPlusPlus: + return "extern_cpp"; + case CallNode::PureExtern: + return "pure_extern"; + case CallNode::Halide: + return "halide"; + case CallNode::Intrinsic: + return "intrin"; + case CallNode::PureIntrinsic: + return "pure_intrin"; } LOG(FATAL) << "Unknown CallType"; return "Unknown"; @@ -339,8 +345,7 @@ Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { for (const auto& arg : op->args) { args.push_back(Print(arg)); } - doc << PrintSep(args, Doc::Text(", ")) - << ", dtype=" << PrintDType(op->dtype) + doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) << ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) << ", index=" << op->value_index << ")"; return doc; @@ -455,10 +460,14 @@ Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { inline const char* ForType2String(ForType t) { switch (t) { - case ForType::Serial:return "serial"; - case ForType::Parallel:return "parallel"; - case ForType::Vectorized:return "vectorized"; - case ForType::Unrolled:return "unroll"; + case ForType::Serial: + return "serial"; + case ForType::Parallel: + return "parallel"; + case ForType::Vectorized: + return "vectorized"; + case ForType::Unrolled: + return "unroll"; } LOG(FATAL) << "Unknown ForType"; return "Unknown"; @@ -525,9 +534,15 @@ Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) { } doc << Doc::Text(os.str()); switch (dtype.code()) { - case kDLInt: doc << "i"; break; - case kDLUInt: doc << "u"; break; - case kDLFloat: doc << "f"; break; + case kDLInt: + doc << "i"; + break; + case kDLUInt: + doc << "u"; + break; + case kDLFloat: + doc << "f"; + break; } doc << Doc::Text(std::to_string(dtype.bits())); if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes())); @@ -540,8 +555,8 @@ Doc TIRTextPrinter::GetUniqueName(std::string prefix) { std::string unique_prefix = prefix; auto it = name_alloc_map_.find(prefix); if (it != name_alloc_map_.end()) { - while (name_alloc_map_.count( - unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {} + while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) { + } } name_alloc_map_[unique_prefix] = 0; return Doc::Text(unique_prefix); diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 103ddcb..587add3 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -19,14 +19,13 @@ #include "annotated_region_set.h" -#include #include +#include #include #include #include - namespace tvm { namespace relay { @@ -39,8 +38,7 @@ AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { return AnnotatedRegion(nullptr); } -void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, - AnnotatedRegion dest) { +void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, AnnotatedRegion dest) { if (dest == src) { return; } @@ -104,12 +102,12 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { for (auto arg : args) { const CallNode* end = arg.as(); if (end && end->op == end_op_) { // Ignore closed regions. - continue; + continue; } region = region_set_->GetRegion(arg); if (region.defined()) { - break; + break; } } @@ -117,7 +115,7 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { for (auto arg : args) { const CallNode* end = arg.as(); if (end && end->op == end_op_) { // Ignore closed regions. - continue; + continue; } auto arg_region = region_set_->GetRegion(arg); @@ -171,9 +169,7 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { } } - void VisitExpr_(const TupleNode* op) { - AddToArgRegion(GetRef(op), op->fields); - } + void VisitExpr_(const TupleNode* op) { AddToArgRegion(GetRef(op), op->fields); } void VisitExpr_(const TupleGetItemNode* g) { Array args = {g->tuple}; @@ -227,15 +223,14 @@ TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode); TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode); TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet") -.set_body_typed([](Expr expr, Op begin, Op end) { - return AnnotatedRegionSet::Create(expr, begin, end); -}); + .set_body_typed([](Expr expr, Op begin, Op end) { + return AnnotatedRegionSet::Create(expr, begin, end); + }); TVM_REGISTER_GLOBAL("relay.analysis.GetRegion") -.set_body_typed([](AnnotatedRegionSet region_set, Expr expr) { - return region_set->GetRegion(expr); -}); - + .set_body_typed([](AnnotatedRegionSet region_set, Expr expr) { + return region_set->GetRegion(expr); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index 3bd5693..f12db6a 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -27,19 +27,19 @@ #ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ #define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ +#include #include #include #include -#include #include -#include #include +#include +#include #include #include #include #include -#include namespace tvm { namespace relay { @@ -61,29 +61,19 @@ class AnnotatedRegionNode : public Object { } /*! \brief Get the region ID. */ - int GetID() const { - return id_; - } + int GetID() const { return id_; } /*! \brief Get the region target. */ - std::string GetTarget() const { - return target_; - } + std::string GetTarget() const { return target_; } /*! \brief Get the region's inputs. */ - std::list GetInputs() const { - return ins_; - } + std::list GetInputs() const { return ins_; } /*! \brief Get the region's outputs. */ - std::list GetOutputs() const { - return outs_; - } + std::list GetOutputs() const { return outs_; } /*! \brief Get the region's nodes. */ - std::unordered_set GetNodes() const { - return nodes_; - } + std::unordered_set GetNodes() const { return nodes_; } static constexpr const char* _type_key = "relay.AnnotatedRegion"; TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object); @@ -107,7 +97,7 @@ class AnnotatedRegionNode : public Object { /*! * \brief An object to hold the properties of a region as used by the * AnnotatedRegionSet class. This should be considered read-only. -*/ + */ class AnnotatedRegion : public ObjectRef { public: AnnotatedRegion() { @@ -116,9 +106,9 @@ class AnnotatedRegion : public ObjectRef { } /*! - * \brief Construct from an object pointer. - * \param n The object pointer. - */ + * \brief Construct from an object pointer. + * \param n The object pointer. + */ explicit AnnotatedRegion(ObjectPtr n) : ObjectRef(n) {} /*! \return Mutable pointers to the node. */ @@ -130,8 +120,7 @@ class AnnotatedRegion : public ObjectRef { }; class AnnotatedRegionSetNode : public Object { - using UnorderedRegionSet = - std::unordered_set; + using UnorderedRegionSet = std::unordered_set; // Create iterator alias for a RegionSet object. using iterator = UnorderedRegionSet::iterator; using const_iterator = UnorderedRegionSet::const_iterator; @@ -141,21 +130,13 @@ class AnnotatedRegionSetNode : public Object { AnnotatedRegionSetNode() = default; /*! \return The begin iterator */ - iterator begin() { - return regions_.begin(); - } + iterator begin() { return regions_.begin(); } /*! \return The end iterator */ - iterator end() { - return regions_.end(); - } + iterator end() { return regions_.end(); } /*! \return The const begin iterator */ - const_iterator begin() const { - return regions_.begin(); - } + const_iterator begin() const { return regions_.begin(); } /*! \return The const end iterator */ - const_iterator end() const { - return regions_.end(); - } + const_iterator end() const { return regions_.end(); } /*! * \brief Get the region that an expression belongs to. @@ -168,11 +149,11 @@ class AnnotatedRegionSetNode : public Object { AnnotatedRegion GetRegion(const Expr& expr) const; /*! - * \brief Merge src region into dest region. - * - * \param src The region to merge - will be erased. - * \param dest The region into which src will be merged. - */ + * \brief Merge src region into dest region. + * + * \param src The region to merge - will be erased. + * \param dest The region into which src will be merged. + */ void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest); void VisitAttrs(AttrVisitor* v) { @@ -214,8 +195,7 @@ class AnnotatedRegionSetNode : public Object { * to update and query regions. */ class AnnotatedRegionSet : public ObjectRef { - using UnorderedRegionSet = - std::unordered_set; + using UnorderedRegionSet = std::unordered_set; // Create iterator alias for a RegionSet object. using iterator = UnorderedRegionSet::iterator; using const_iterator = UnorderedRegionSet::const_iterator; @@ -227,10 +207,10 @@ class AnnotatedRegionSet : public ObjectRef { } /*! - * \brief Construct from an object pointer. - * - * \param n The object pointer. - */ + * \brief Construct from an object pointer. + * + * \param n The object pointer. + */ explicit AnnotatedRegionSet(ObjectPtr n) : ObjectRef(n) {} /*! \return The begin iterator. */ @@ -253,7 +233,7 @@ class AnnotatedRegionSet : public ObjectRef { } /*! \return The end iterator. */ const_iterator end() const { - const auto *n = operator->(); + const auto* n = operator->(); CHECK(n); return n->end(); } @@ -267,7 +247,7 @@ class AnnotatedRegionSet : public ObjectRef { /*! \return The region an expression belongs to. */ AnnotatedRegion operator[](const Expr& expr) { - const auto *n = operator->(); + const auto* n = operator->(); CHECK(n); return n->GetRegion(expr); } @@ -280,9 +260,7 @@ class AnnotatedRegionSet : public ObjectRef { * * \return The created RegionSet for the expression. */ - static AnnotatedRegionSet Create(const Expr& expr, - const Op& begin, - const Op& end); + static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end); private: /*! \brief Helper class to construct a RegionSet from an expr.*/ diff --git a/src/relay/analysis/call_graph.cc b/src/relay/analysis/call_graph.cc index a12d23d..0d3fedc 100644 --- a/src/relay/analysis/call_graph.cc +++ b/src/relay/analysis/call_graph.cc @@ -26,6 +26,7 @@ #include #include + #include #include #include @@ -72,22 +73,21 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) { const CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) const { const_iterator cit = call_graph_.find(gv); - CHECK(cit != call_graph_.end()) - << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint + << " not found in the call graph!"; return cit->second.get(); } CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) { const_iterator cit = call_graph_.find(gv); - CHECK(cit != call_graph_.end()) - << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint + << " not found in the call graph!"; return cit->second.get(); } BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const { CHECK(module->ContainGlobalVar(var->name_hint)) - << "GlobalVar " << var->name_hint - << " not found in the current ir module"; + << "GlobalVar " << var->name_hint << " not found in the current ir module"; return module->Lookup(var); } @@ -120,8 +120,8 @@ GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph) { CHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1)) << "Cannot remove global var " << cg_node->GetNameHint() - << " from call graph, because it still calls " - << cg_node->size() << " other global functions"; + << " from call graph, because it still calls " << cg_node->size() + << " other global functions"; if (update_call_graph) { // Update the call graph by removing all edges that point to the node @@ -172,8 +172,7 @@ std::vector CallGraphNode::TopologicalOrder() const { << " with # refs = " << (*this)[it.first]->GetRefCount(); } } - LOG(FATAL) << "Expected " << module->functions.size() - << " globals, but received " + LOG(FATAL) << "Expected " << module->functions.size() << " globals, but received " << ret.size(); } @@ -184,8 +183,7 @@ std::vector CallGraphNode::TopologicalOrder() const { // that are visited by previous CallGraphEntry entries can be memoized. This // helps us to make sure no entry will be visited multiple times when collecting // the nodes for an entire call graph. -std::vector CallGraphEntry::TopologicalOrder( - CallGraphEntrySet* visited) const { +std::vector CallGraphEntry::TopologicalOrder(CallGraphEntrySet* visited) const { std::vector ret; std::vector current_nodes; if (visited->find(this) == visited->end()) { @@ -234,8 +232,7 @@ inline void CallGraphEntry::AddCalledGlobal(CallGraphEntry* cg_node) { // Remove an edge from the current global function to the callee. void CallGraphEntry::RemoveCallTo(const GlobalVar& callee) { for (auto it = begin();; ++it) { - CHECK(it != end()) << "Cannot find global function " - << callee->name_hint << " to remove!"; + CHECK(it != end()) << "Cannot find global function " << callee->name_hint << " to remove!"; if (it->second->GetGlobalVar() == callee) { // Only remove one occurrence of the call site. it->second->DecRef(); @@ -260,8 +257,7 @@ void CallGraphEntry::RemoveAllCallTo(CallGraphEntry* callee) { } // Make sure all references to the callee are removed. CHECK_EQ(callee->GetRefCount(), 0U) - << "All references to " << callee->GetNameHint() - << " should have been removed"; + << "All references to " << callee->GetNameHint() << " should have been removed"; } void CallGraphEntry::Print(std::ostream& os) const { @@ -293,54 +289,51 @@ std::ostream& operator<<(std::ostream& os, const CallGraphEntry& cgn) { TVM_REGISTER_NODE_TYPE(CallGraphNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - CHECK(node); - p->stream << "CallGraph: \n" << GetRef(node); -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + CHECK(node); + p->stream << "CallGraph: \n" << GetRef(node); + }); -TVM_REGISTER_GLOBAL("relay.analysis.CallGraph") -.set_body_typed([](IRModule module) { +TVM_REGISTER_GLOBAL("relay.analysis.CallGraph").set_body_typed([](IRModule module) { return CallGraph(module); }); -TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph") -.set_body_typed([](CallGraph call_graph) { +TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph").set_body_typed([](CallGraph call_graph) { std::stringstream ss; ss << call_graph; return ss.str(); }); -TVM_REGISTER_GLOBAL("relay.analysis.GetModule") -.set_body_typed([](CallGraph call_graph) { +TVM_REGISTER_GLOBAL("relay.analysis.GetModule").set_body_typed([](CallGraph call_graph) { return call_graph->module; }); TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - std::stringstream ss; - ss << *entry_node; - return ss.str(); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + std::stringstream ss; + ss << *entry_node; + return ss.str(); + }); TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - return static_cast(entry_node->GetRefCount()); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->GetRefCount()); + }); TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - return static_cast(entry_node->size()); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->size()); + }); TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - return entry_node->IsRecursive(); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return entry_node->IsRecursive(); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/call_graph.h b/src/relay/analysis/call_graph.h index 86bc646..387d2d3 100644 --- a/src/relay/analysis/call_graph.h +++ b/src/relay/analysis/call_graph.h @@ -32,6 +32,7 @@ #include #include #include + #include #include #include @@ -47,8 +48,7 @@ class CallGraph; class CallGraphNode : public Object { using CallGraphMap = - std::unordered_map, ObjectHash, - ObjectEqual>; + std::unordered_map, ObjectHash, ObjectEqual>; // Create iterator alias for a CallGraphNode object. using iterator = CallGraphMap::iterator; using const_iterator = CallGraphMap::const_iterator; @@ -60,9 +60,7 @@ class CallGraphNode : public Object { /*! \brief Default constructor. */ CallGraphNode() {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("module", &module); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("module", &module); } /*! * \brief Print the call graph. @@ -72,21 +70,13 @@ class CallGraphNode : public Object { void Print(std::ostream& os) const; /*! \return The begin iterator. */ - iterator begin() { - return call_graph_.begin(); - } + iterator begin() { return call_graph_.begin(); } /*! \return The end iterator. */ - iterator end() { - return call_graph_.end(); - } + iterator end() { return call_graph_.end(); } /*! \return The begin iterator. */ - const_iterator begin() const { - return call_graph_.begin(); - } + const_iterator begin() const { return call_graph_.begin(); } /*! \return The end iterator. */ - const_iterator end() const { - return call_graph_.end(); - } + const_iterator end() const { return call_graph_.end(); } /*! * \brief Get an element from the CallGraphNode using a GlobalVar. @@ -157,8 +147,7 @@ class CallGraphNode : public Object { * * \return The GlobalVar removed from the current module. */ - GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, - bool update_call_graph = false); + GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph = false); /*! * \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for @@ -207,8 +196,7 @@ class CallGraphNode : public Object { */ class CallGraph : public ObjectRef { using CallGraphMap = - std::unordered_map, ObjectHash, - ObjectEqual>; + std::unordered_map, ObjectHash, ObjectEqual>; // Create iterator alias for a CallGraph object. using iterator = CallGraphMap::iterator; using const_iterator = CallGraphMap::const_iterator; @@ -340,30 +328,20 @@ class CallGraphEntry { CallGraphEntry& operator=(const CallGraphEntry&) = delete; /*! \return The begin iterator */ - iterator begin() { - return called_globals_.begin(); - } + iterator begin() { return called_globals_.begin(); } /*! \return The end iterator */ - iterator end() { - return called_globals_.end(); - } + iterator end() { return called_globals_.end(); } /*! \return The const begin iterator */ - const_iterator begin() const { - return called_globals_.begin(); - } + const_iterator begin() const { return called_globals_.begin(); } /*! \return The const end iterator */ - const_iterator end() const { - return called_globals_.end(); - } + const_iterator end() const { return called_globals_.end(); } /*! * \brief Return if the list of called nodes is empty. * * \return true if the list is empty. Otherwise, false. */ - bool empty() const { - return called_globals_.empty(); - } + bool empty() const { return called_globals_.empty(); } /*! * \brief Return the size of the list that represents the nodes are called by @@ -371,9 +349,7 @@ class CallGraphEntry { * * \return The number of called nodes. */ - uint32_t size() const { - return static_cast(called_globals_.size()); - } + uint32_t size() const { return static_cast(called_globals_.size()); } /*! * \brief Fetch the i-th CallGraphEntry from the list of nodes that are called @@ -400,27 +376,21 @@ class CallGraphEntry { * * \return The count. */ - uint32_t GetRefCount() const { - return ref_cnt_; - } + uint32_t GetRefCount() const { return ref_cnt_; } /*! * \brief Return the GlobalVar stored in the current CallGraphEntry. * * \return The GlobalVar. */ - GlobalVar GetGlobalVar() const { - return global_; - } + GlobalVar GetGlobalVar() const { return global_; } /*! * \brief Return the name hint of the GlobalVar stored in the CallGraphEntry. * * \return The name hint of the global function. */ - std::string GetNameHint() const { - return global_->name_hint; - } + std::string GetNameHint() const { return global_->name_hint; } /*! * \brief Return if the global function corresponding to the current @@ -428,9 +398,7 @@ class CallGraphEntry { * * \return true if it is recursive. Otherwise, false. */ - bool IsRecursive() const { - return is_recursive_; - } + bool IsRecursive() const { return is_recursive_; } /*! * \brief Return if the global function corresponding to the current @@ -439,9 +407,7 @@ class CallGraphEntry { * * \return true if it is both a recursive function and an entry. Otherwise, false. */ - bool IsRecursiveEntry() const { - return GetRefCount() == 1 && IsRecursive(); - } + bool IsRecursiveEntry() const { return GetRefCount() == 1 && IsRecursive(); } /*! * \brief Return the topological order of the CallGraphEntry. diff --git a/src/relay/analysis/dependency_graph.cc b/src/relay/analysis/dependency_graph.cc index 7e48d12..a583e9a 100644 --- a/src/relay/analysis/dependency_graph.cc +++ b/src/relay/analysis/dependency_graph.cc @@ -22,7 +22,9 @@ * \brief Implementation of dependency graph APIs. */ #include "dependency_graph.h" + #include + #include #include @@ -32,8 +34,7 @@ namespace relay { // Creator of DependencyGraph class DependencyGraph::Creator : private ExprFunctor { public: - explicit Creator(support::Arena* arena) - : arena_(arena) {} + explicit Creator(support::Arena* arena) : arena_(arena) {} DependencyGraph Create(const Expr& body) { this->VisitExpr(body); @@ -164,15 +165,15 @@ class DependencyGraph::Creator : private ExprFunctor { } } - void VisitExpr_(const VarNode* v) final { } + void VisitExpr_(const VarNode* v) final {} - void VisitExpr_(const GlobalVarNode* v) final { } + void VisitExpr_(const GlobalVarNode* v) final {} - void VisitExpr_(const ConstantNode* c) final { } + void VisitExpr_(const ConstantNode* c) final {} - void VisitExpr_(const OpNode* o) final { } + void VisitExpr_(const OpNode* o) final {} - void VisitExpr_(const ConstructorNode* c) final { } + void VisitExpr_(const ConstructorNode* c) final {} }; DependencyGraph DependencyGraph::Create(support::Arena* arena, const Expr& body) { diff --git a/src/relay/analysis/dependency_graph.h b/src/relay/analysis/dependency_graph.h index 5e2dc0c..4aad95e 100644 --- a/src/relay/analysis/dependency_graph.h +++ b/src/relay/analysis/dependency_graph.h @@ -25,16 +25,18 @@ #define TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_ #include + #include #include -#include "../transforms/let_list.h" + #include "../../support/arena.h" +#include "../transforms/let_list.h" namespace tvm { namespace relay { -using support::LinkNode; using support::LinkedList; +using support::LinkNode; /* DependencyGraph track input and output of an Expr. * Additionally, dummy scope is created to model scope. diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index 95c2f73..9e94459 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -21,11 +21,12 @@ * \file feature.cc * \brief Detect features used in Expr/Module */ -#include +#include #include #include #include -#include +#include + #include "../transforms/pass_util.h" namespace tvm { @@ -49,34 +50,30 @@ FeatureSet DetectFeature(const Expr& expr) { } } } -#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \ - void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \ - STMT \ - fs += f##CONSTRUCT_NAME; \ - } -#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \ - ExprVisitor::VisitExpr_(op); \ - }) +#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \ + void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { STMT fs += f##CONSTRUCT_NAME; } +#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) \ + DETECT_CONSTRUCT(CONSTRUCT_NAME, { ExprVisitor::VisitExpr_(op); }) DETECT_DEFAULT_CONSTRUCT(Var) DETECT_DEFAULT_CONSTRUCT(GlobalVar) DETECT_DEFAULT_CONSTRUCT(Constant) DETECT_DEFAULT_CONSTRUCT(Tuple) DETECT_DEFAULT_CONSTRUCT(TupleGetItem) DETECT_CONSTRUCT(Function, { - if (!op->HasNonzeroAttr(attr::kPrimitive)) { - ExprVisitor::VisitExpr_(op); - } - }) + if (!op->HasNonzeroAttr(attr::kPrimitive)) { + ExprVisitor::VisitExpr_(op); + } + }) DETECT_DEFAULT_CONSTRUCT(Op) DETECT_DEFAULT_CONSTRUCT(Call) DETECT_CONSTRUCT(Let, { - for (const Var& v : FreeVars(op->value)) { - if (op->var == v) { - fs += fLetRec; - } + for (const Var& v : FreeVars(op->value)) { + if (op->var == v) { + fs += fLetRec; } - ExprVisitor::VisitExpr_(op); - }) + } + ExprVisitor::VisitExpr_(op); + }) DETECT_DEFAULT_CONSTRUCT(If) DETECT_DEFAULT_CONSTRUCT(RefCreate) DETECT_DEFAULT_CONSTRUCT(RefRead) @@ -104,8 +101,7 @@ Array PyDetectFeature(const Expr& expr, const IRModule& mod) { return static_cast>(fs); } -TVM_REGISTER_GLOBAL("relay.analysis.detect_feature") -.set_body_typed(PyDetectFeature); +TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc index b4835cc..ac0abc0 100644 --- a/src/relay/analysis/kind_check.cc +++ b/src/relay/analysis/kind_check.cc @@ -31,9 +31,9 @@ * We check this by ensuring the `dtype` field of a Tensor always * contains a data type such as `int`, `float`, `uint`. */ +#include #include #include -#include namespace tvm { namespace relay { @@ -51,40 +51,28 @@ struct KindChecker : TypeFunctor { this->err_reporter.RenderErrors(mod); } - void CheckKindMatches(const Type& t, const Type& outer, - Kind expected, const std::string& description) { + void CheckKindMatches(const Type& t, const Type& outer, Kind expected, + const std::string& description) { Kind k = this->VisitType(t); if (k != expected) { ReportFatalError(ErrorBuilder() - << "Incorrect kind for a " << description - << ". Type " << t << " inside " << outer - << " is of kind " << k - << " but was expected to be " - << expected); + << "Incorrect kind for a " << description << ". Type " << t << " inside " + << outer << " is of kind " << k << " but was expected to be " << expected); } } - Kind VisitType_(const IncompleteTypeNode* op) override { - return op->kind; - } + Kind VisitType_(const IncompleteTypeNode* op) override { return op->kind; } - Kind VisitType_(const TypeVarNode* op) override { - return op->kind; - } + Kind VisitType_(const TypeVarNode* op) override { return op->kind; } - Kind VisitType_(const GlobalTypeVarNode* op) override { - return op->kind; - } + Kind VisitType_(const GlobalTypeVarNode* op) override { return op->kind; } - Kind VisitType_(const TensorTypeNode* op) override { - return Kind::kType; - } + Kind VisitType_(const TensorTypeNode* op) override { return Kind::kType; } Kind VisitType_(const TupleTypeNode* op) override { // tuples should only contain normal types for (const Type& t : op->fields) { - CheckKindMatches(t, GetRef(op), Kind::kType, - "tuple member"); + CheckKindMatches(t, GetRef(op), Kind::kType, "tuple member"); } return Kind::kType; } @@ -117,8 +105,7 @@ struct KindChecker : TypeFunctor { Kind VisitType_(const TypeRelationNode* op) override { // arguments to type relation should be normal types for (const Type& t : op->args) { - CheckKindMatches(t, GetRef(op), Kind::kType, - "argument to type relation"); + CheckKindMatches(t, GetRef(op), Kind::kType, "argument to type relation"); } return Kind::kConstraint; } @@ -128,9 +115,8 @@ struct KindChecker : TypeFunctor { TypeCall tc = GetRef(op); const auto* gtv = op->func.as(); if (gtv == nullptr) { - ReportFatalError( - ErrorBuilder() <<"The callee in " << tc - << " is not a global type var, but is " << op->func); + ReportFatalError(ErrorBuilder() << "The callee in " << tc + << " is not a global type var, but is " << op->func); } CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function"); @@ -143,9 +129,8 @@ struct KindChecker : TypeFunctor { auto var = GetRef(gtv); auto data = mod->LookupTypeDef(var); if (data->type_vars.size() != op->args.size()) { - ReportFatalError(ErrorBuilder() - << "Expected " << data->type_vars.size() << "arguments for " << tc - << "; got " << op->args.size()); + ReportFatalError(ErrorBuilder() << "Expected " << data->type_vars.size() << "arguments for " + << tc << "; got " << op->args.size()); } return Kind::kType; } @@ -164,9 +149,8 @@ struct KindChecker : TypeFunctor { for (const auto& con : op->constructors) { if (!con->belong_to.same_as(op->header)) { - ReportFatalError(ErrorBuilder() - <belong_to - << " but " << op << " has header " << op->header); + ReportFatalError(ErrorBuilder() << con << " has header " << con->belong_to << " but " << op + << " has header " << op->header); } for (const Type& t : con->inputs) { @@ -176,9 +160,7 @@ struct KindChecker : TypeFunctor { return Kind::kTypeData; } - Kind Check(const Type& t) { - return this->VisitType(t); - } + Kind Check(const Type& t) { return this->VisitType(t); } }; Kind KindCheck(const Type& t, const IRModule& mod) { @@ -186,14 +168,13 @@ Kind KindCheck(const Type& t, const IRModule& mod) { return kc.Check(t); } -TVM_REGISTER_GLOBAL("relay.analysis.check_kind") -.set_body([](TVMArgs args, TVMRetValue* ret) { - if (args.size() == 1) { - *ret = KindCheck(args[0], IRModule({}, {})); - } else { - *ret = KindCheck(args[0], args[1]); - } - }); +TVM_REGISTER_GLOBAL("relay.analysis.check_kind").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = KindCheck(args[0], IRModule({}, {})); + } else { + *ret = KindCheck(args[0], args[1]); + } +}); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/mac_count.cc b/src/relay/analysis/mac_count.cc index fecde3c..882bba9 100644 --- a/src/relay/analysis/mac_count.cc +++ b/src/relay/analysis/mac_count.cc @@ -26,11 +26,12 @@ * otherwise the count is 0. */ -#include +#include #include #include -#include +#include #include + #include "../transforms/pattern_util.h" namespace tvm { @@ -52,8 +53,7 @@ inline int64_t GetCartesianProd(Array arr) { * \param call_node The call node. * \return The number of MACs. */ -using FMacCount = runtime::TypedPackedFunc< - int64_t(const Call& call_node)>; +using FMacCount = runtime::TypedPackedFunc; //---------------------------------------------- // Per operator defs for MAC count @@ -65,30 +65,26 @@ int64_t ConvMacCount(const Call& call_node) { return 0; } Array args = call_node->args; - CHECK_EQ(args.size(), 2) - << "The number of input arguments of a CONV 2D node should be 2."; + CHECK_EQ(args.size(), 2) << "The number of input arguments of a CONV 2D node should be 2."; const auto* conv_2d_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); Array data_shape = data_type->shape; std::string data_layout = conv_2d_attr->data_layout; int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); - CHECK_NE(C_ind, -1) - << "There is no input channel dimension."; + CHECK_NE(C_ind, -1) << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); - if (c_ind != -1) - input_channel *= static_cast(data_shape[c_ind].as()->value); + if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_attr->kernel_size; - CHECK_EQ(kernel_size.size(), 2) - << "The dimension of the kernel in Conv 2D should be 2."; + CHECK_EQ(kernel_size.size(), 2) << "The dimension of the kernel in Conv 2D should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) - << "The dimension of the output tensor in Conv 2D should be 4 or 5."; + << "The dimension of the output tensor in Conv 2D should be 4 or 5."; int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); CHECK_EQ(input_channel % conv_2d_attr->groups, 0) - << "The number of input channels is not divisble by groups."; - count *= input_channel/conv_2d_attr->groups; + << "The number of input channels is not divisble by groups."; + count *= input_channel / conv_2d_attr->groups; return count; } @@ -99,29 +95,27 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) { } Array args = call_node->args; CHECK_EQ(args.size(), 2) - << "The number of input arguments of a CONV 2D Transpose node should be 2."; + << "The number of input arguments of a CONV 2D Transpose node should be 2."; const auto* conv_2d_transpose_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); Array data_shape = data_type->shape; std::string data_layout = conv_2d_transpose_attr->data_layout; int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); - CHECK_NE(C_ind, -1) - << "There is no input channel dimension."; + CHECK_NE(C_ind, -1) << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); - if (c_ind != -1) - input_channel *= static_cast(data_shape[c_ind].as()->value); + if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_transpose_attr->kernel_size; CHECK_EQ(kernel_size.size(), 2) - << "The dimension of the kernel in Conv 2D Transpose should be 2."; + << "The dimension of the kernel in Conv 2D Transpose should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) - << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5."; + << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5."; int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); CHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0) - << "The number of input channels is not divisble by groups."; - count *= input_channel/conv_2d_transpose_attr->groups; + << "The number of input channels is not divisble by groups."; + count *= input_channel / conv_2d_transpose_attr->groups; return count; } @@ -131,20 +125,18 @@ int64_t DenseMacCount(const Call& call_node) { return 0; } Array args = call_node->args; - CHECK_EQ(args.size(), 2) - << "The number of input arguments of a Dense node should be 2."; + CHECK_EQ(args.size(), 2) << "The number of input arguments of a Dense node should be 2."; const auto* data_type = args[0]->checked_type().as(); const auto* weight_type = args[1]->checked_type().as(); Array data_shape = data_type->shape; Array weight_shape = weight_type->shape; CHECK(data_shape.size() == 2 && weight_shape.size() == 2) - << "The dimension of an input tensor to Dense node should be 2."; + << "The dimension of an input tensor to Dense node should be 2."; int64_t d1 = static_cast(data_shape[0].as()->value); int64_t d2 = static_cast(data_shape[1].as()->value); int64_t d3 = static_cast(weight_shape[0].as()->value); int64_t d4 = static_cast(weight_shape[1].as()->value); - CHECK_EQ(d2, d4) - << "The dimensions of input arguments do not match."; + CHECK_EQ(d2, d4) << "The dimensions of input arguments do not match."; int64_t count = d1 * d2 * d3; return count; } @@ -165,23 +157,17 @@ int64_t BatchMatmulMacCount(const Call& call_node) { return batch * m * k * n; } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FMacCount", ConvMacCount); +RELAY_REGISTER_OP("nn.conv2d").set_attr("FMacCount", ConvMacCount); -RELAY_REGISTER_OP("nn.conv2d_transpose") -.set_attr("FMacCount", Conv2dTransposeMacCount); +RELAY_REGISTER_OP("nn.conv2d_transpose").set_attr("FMacCount", Conv2dTransposeMacCount); -RELAY_REGISTER_OP("nn.dense") -.set_attr("FMacCount", DenseMacCount); +RELAY_REGISTER_OP("nn.dense").set_attr("FMacCount", DenseMacCount); -RELAY_REGISTER_OP("nn.batch_matmul") -.set_attr("FMacCount", BatchMatmulMacCount); +RELAY_REGISTER_OP("nn.batch_matmul").set_attr("FMacCount", BatchMatmulMacCount); class MacCounter : private ExprVisitor { public: - MacCounter() { - count_ = 0; - } + MacCounter() { count_ = 0; } static int64_t GetTotalMacNumber(const Expr& expr) { LOG(INFO) << "This pass only counts MACs in direct conv2d, " << "conv2d_transpose, dense, and batch_matmul ops"; @@ -192,8 +178,7 @@ class MacCounter : private ExprVisitor { private: void VisitExpr_(const CallNode* call_node) final { - static const auto& fprep = - Op::GetAttr("FMacCount"); + static const auto& fprep = Op::GetAttr("FMacCount"); auto f = fprep.get(call_node->op, nullptr); if (f != nullptr) count_ += f(GetRef(call_node)); ExprVisitor::VisitExpr_(call_node); @@ -202,12 +187,9 @@ class MacCounter : private ExprVisitor { int64_t count_; }; -int64_t GetTotalMacNumber(const Expr& expr) { - return MacCounter::GetTotalMacNumber(expr); -} +int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); } -TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber") -.set_body_typed(GetTotalMacNumber); +TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber").set_body_typed(GetTotalMacNumber); } // namespace mac_count } // namespace relay diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc index eeb7fce..96dab6b 100644 --- a/src/relay/analysis/match_exhaustion.cc +++ b/src/relay/analysis/match_exhaustion.cc @@ -27,10 +27,11 @@ * code correctness, since hitting an unmatched case results in a * dynamic error unless exhaustiveness is checked in advance. */ -#include #include +#include #include #include + #include namespace tvm { @@ -154,17 +155,14 @@ Array> CartesianProduct(Array> fields) { } Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, - const Pattern& cand, - const IRModule& mod); + const Pattern& cand, const IRModule& mod); -Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, - const Pattern& cand, +Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand, const IRModule& mod); // Expands all wildcards in the candidate pattern once // Returns a list of all possible expansions. -Array ExpandWildcards(const Pattern& clause_pat, - const Pattern& cand, +Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, const IRModule& mod) { if (auto clause_ctor = clause_pat.as()) { return ExpandWildcardsConstructor(GetRef(clause_ctor), cand, mod); @@ -179,8 +177,7 @@ Array ExpandWildcards(const Pattern& clause_pat, // Use the pattern to decide which constructors to insert. // Returns a list of all possible expansions. Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, - const Pattern& cand, - const IRModule& mod) { + const Pattern& cand, const IRModule& mod) { auto gtv = Downcast(clause_ctor->constructor->belong_to); // for a wildcard node, create constructor nodes with wildcards for all args. @@ -203,9 +200,8 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, // for constructors, we will expand the wildcards in any field that is an ADT. Array> values_by_field; for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { - values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i], - ctor_cand->patterns[i], - mod)); + values_by_field.push_back( + ExpandWildcards(clause_ctor->patterns[i], ctor_cand->patterns[i], mod)); } // generate new candidates using a cartesian product. @@ -219,8 +215,7 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, // Expands all wildcards in the candidate pattern once. // Returns a list of all possible expansions. -Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, - const Pattern& cand, +Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand, const IRModule& mod) { // for a wildcard node, create constructor nodes with wildcards for all args. if (cand.as()) { @@ -236,9 +231,8 @@ Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, // for constructors, we will expand the wildcards in any field that is an ADT. Array> values_by_field; for (size_t i = 0; i < tuple_cand->patterns.size(); i++) { - values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i], - tuple_cand->patterns[i], - mod)); + values_by_field.push_back( + ExpandWildcards(clause_tuple->patterns[i], tuple_cand->patterns[i], mod)); } // generate new candidates using a cartesian product @@ -311,14 +305,13 @@ Array UnmatchedCases(const Match& match, const IRModule& mod) { // expose for testing only TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases") -.set_body_typed( - [](const Match& match, const IRModule& mod_ref) { - IRModule call_mod = mod_ref; - if (!call_mod.defined()) { - call_mod = IRModule({}, {}); - } - return UnmatchedCases(match, call_mod); - }); + .set_body_typed([](const Match& match, const IRModule& mod_ref) { + IRModule call_mod = mod_ref; + if (!call_mod.defined()) { + call_mod = IRModule({}, {}); + } + return UnmatchedCases(match, call_mod); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 650403c..05e231a 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -21,26 +21,25 @@ * \file type_solver.cc * \brief Type solver implementations. */ -#include +#include "type_solver.h" + #include +#include #include -#include + #include +#include #include #include -#include "type_solver.h" namespace tvm { namespace relay { class TypeSolver::Reporter : public TypeReporterNode { public: - explicit Reporter(TypeSolver* solver) - : solver_(solver) {} + explicit Reporter(TypeSolver* solver) : solver_(solver) {} - void Assign(const Type& dst, const Type& src) final { - solver_->Unify(dst, src, location); - } + void Assign(const Type& dst, const Type& src) final { solver_->Unify(dst, src, location); } bool Assert(const IndexExpr& cond) final { if (const int64_t* pdiff = tir::as_const_int(cond)) { @@ -58,13 +57,9 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } - TVM_DLL void SetLocation(const ObjectRef& ref) final { - location = ref; - } + TVM_DLL void SetLocation(const ObjectRef& ref) final { location = ref; } - TVM_DLL IRModule GetModule() final { - return this->solver_->module_; - } + TVM_DLL IRModule GetModule() final { return this->solver_->module_; } private: /*! \brief The location to report unification errors at. */ @@ -76,7 +71,7 @@ class TypeSolver::Reporter : public TypeReporterNode { class TypeSolver::OccursChecker : public TypeVisitor { public: explicit OccursChecker(TypeSolver* solver, TypeNode* var) - : solver_(solver), var_(var), found_(false) {} + : solver_(solver), var_(var), found_(false) {} bool Check(const Type& t) { VisitType(t); @@ -112,25 +107,24 @@ class TypeSolver::Unifier : public TypeFunctor { if (lhs->resolved_type.as()) { CHECK(!OccursCheck(lhs, rhs->resolved_type)) - << "Incomplete type " << lhs->resolved_type << " occurs in " - << rhs->resolved_type << ", cannot unify"; + << "Incomplete type " << lhs->resolved_type << " occurs in " << rhs->resolved_type + << ", cannot unify"; solver_->MergeFromTo(lhs, rhs); return rhs->resolved_type; } else if (rhs->resolved_type.as()) { CHECK(!OccursCheck(rhs, lhs->resolved_type)) - << "Incomplete type " << rhs->resolved_type << " occurs in " - << lhs->resolved_type << ", cannot unify"; + << "Incomplete type " << rhs->resolved_type << " occurs in " << lhs->resolved_type + << ", cannot unify"; solver_->MergeFromTo(rhs, lhs); return lhs->resolved_type; } else { Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); if (!resolved.defined()) { - solver_->ReportError( - ErrorBuilder() << "unable to unify: " - << "`" << PrettyPrint(lhs->resolved_type) << "` and `" - << PrettyPrint(rhs->resolved_type) << "`", - this->loc); + solver_->ReportError(ErrorBuilder() << "unable to unify: " + << "`" << PrettyPrint(lhs->resolved_type) << "` and `" + << PrettyPrint(rhs->resolved_type) << "`", + this->loc); return lhs->resolved_type; } else { TypeNode* top = solver_->GetTypeNode(resolved); @@ -227,14 +221,11 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->ReportError( - ErrorBuilder() << - "tensor type `" << PrettyPrint(tt1) << - "` has " << tt1->shape.size() << - " dimensions, while `" << - PrettyPrint(tt2) << - "` has " << tt2->shape.size() << - " dimensions", this->loc); + this->solver_->ReportError(ErrorBuilder() << "tensor type `" << PrettyPrint(tt1) << "` has " + << tt1->shape.size() << " dimensions, while `" + << PrettyPrint(tt2) << "` has " << tt2->shape.size() + << " dimensions", + this->loc); return Type(nullptr); } @@ -259,12 +250,8 @@ class TypeSolver::Unifier : public TypeFunctor { ErrorBuilder err; err << "in particular "; for (auto mismatch : mismatches) { - err << "dimension " - << std::get<0>(mismatch) - << " conflicts " - << std::get<1>(mismatch) - << " does not match " - << std::get<2>(mismatch); + err << "dimension " << std::get<0>(mismatch) << " conflicts " << std::get<1>(mismatch) + << " does not match " << std::get<2>(mismatch); } Error error(err); this->solver_->ReportError(error, this->loc); @@ -293,9 +280,8 @@ class TypeSolver::Unifier : public TypeFunctor { Type VisitType_(const FuncTypeNode* op, const Type& tn) final { const auto* ftn = tn.as(); - if (!ftn - || op->arg_types.size() != ftn->arg_types.size() - || op->type_constraints.size() != ftn->type_constraints.size()) { + if (!ftn || op->arg_types.size() != ftn->arg_types.size() || + op->type_constraints.size() != ftn->type_constraints.size()) { return Type(nullptr); } @@ -316,10 +302,7 @@ class TypeSolver::Unifier : public TypeFunctor { subst_map.Set(op->type_params[i], IncompleteType(kType)); } - FuncType ft = FuncType(op->arg_types, - op->ret_type, - ft_type_params, - op->type_constraints); + FuncType ft = FuncType(op->arg_types, op->ret_type, ft_type_params, op->type_constraints); auto ft1 = Downcast(Bind(ft, subst_map)); auto ft2 = GetRef(ftn); @@ -333,8 +316,7 @@ class TypeSolver::Unifier : public TypeFunctor { std::vector type_constraints; for (size_t i = 0; i < ft1->type_constraints.size(); ++i) { - Type unified_constraint = Unify(ft1->type_constraints[i], - ft2->type_constraints[i]); + Type unified_constraint = Unify(ft1->type_constraints[i], ft2->type_constraints[i]); const auto* tcn = unified_constraint.as(); CHECK(tcn) << "Two type constraints unified into a non-constraint?" << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; @@ -397,12 +379,10 @@ class TypeSolver::Resolver : public TypeMutator { class TypeSolver::Propagator : public TypeFunctor { public: explicit Propagator(TypeSolver* solver, const std::unordered_set* rels) - : solver_(solver), rels_(rels) {} + : solver_(solver), rels_(rels) {} // adds the relation node to t and all child types of t - void Propagate(const Type& t) { - VisitType(t); - } + void Propagate(const Type& t) { VisitType(t); } void UpdateRelSet(const Type& t) { TypeNode* tnode = solver_->GetTypeNode(t); @@ -532,10 +512,8 @@ class TypeSolver::Merger : public TypeFunctor { }; // constructor -TypeSolver::TypeSolver( - const GlobalVar& current_func, - const IRModule& module, - ErrorReporter* err_reporter) +TypeSolver::TypeSolver(const GlobalVar& current_func, const IRModule& module, + ErrorReporter* err_reporter) : reporter_(make_object(this)), current_func(current_func), err_reporter_(err_reporter), @@ -566,7 +544,7 @@ Type TypeSolver::Unify(const Type& dst, const Type& src, const ObjectRef& loc) { return unifier.Unify(dst, src); } -void TypeSolver::ReportError(const Error& err, const ObjectRef& location) { +void TypeSolver::ReportError(const Error& err, const ObjectRef& location) { CHECK(location.defined()); CHECK(current_func.defined()); err_reporter_->ReportAt(current_func, location, err); @@ -583,20 +561,19 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint, const ObjectRef // populate the type information. for (size_t i = 0; i < op->args.size(); ++i) { // insert link to the type list - LinkNode* tlink = arena_.make >(); + LinkNode* tlink = arena_.make>(); TypeNode* tnode = GetTypeNode(op->args[i]); tlink->value = tnode; rnode->type_list.Push(tlink); // insert type->relation node - std::unordered_set singleton { rnode }; + std::unordered_set singleton{rnode}; Propagator prop(this, &singleton); prop.Propagate(tnode->resolved_type); } // add the relation to the working queue. this->AddToQueue(rnode); } else { - LOG(FATAL) << "Do not know how to handle constraint type" - << constraint->GetTypeKey(); + LOG(FATAL) << "Do not know how to handle constraint type" << constraint->GetTypeKey(); } } @@ -642,11 +619,9 @@ bool TypeSolver::Solve() { rnode->resolved = false; } catch (const dmlc::Error& err) { rnode->resolved = false; - this->ReportError( - ErrorBuilder() << "an internal invariant was violated while " - << "typechecking your program " - << err.what(), - rnode->location); + this->ReportError(ErrorBuilder() << "an internal invariant was violated while " + << "typechecking your program " << err.what(), + rnode->location); } // Mark inqueue as false after the function call @@ -661,45 +636,40 @@ bool TypeSolver::Solve() { // Expose type solver only for debugging purposes. TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") -.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { - using runtime::PackedFunc; - using runtime::TypedPackedFunc; - ErrorReporter *err_reporter = new ErrorReporter(); - auto module = IRModule({}, {}); - auto dummy_fn_name = GlobalVar("test"); - module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); - auto solver = std::make_shared(dummy_fn_name, module, err_reporter); - - auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { - if (name == "Solve") { - return TypedPackedFunc([solver]() { - return solver->Solve(); - }); - } else if (name == "Unify") { - return TypedPackedFunc( - [module, solver, err_reporter](Type lhs, Type rhs) { - auto res = solver->Unify(lhs, rhs, lhs); - if (err_reporter->AnyErrors()) { - err_reporter->RenderErrors(module, true); - } - return res; - }); - } else if (name == "Resolve") { - return TypedPackedFunc([solver](Type t) { - return solver->Resolve(t); - }); - } else if (name == "AddConstraint") { - return TypedPackedFunc([solver](TypeConstraint c) { - Expr e = Var("dummy_var", - IncompleteType(Kind::kType)); + .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { + using runtime::PackedFunc; + using runtime::TypedPackedFunc; + ErrorReporter* err_reporter = new ErrorReporter(); + auto module = IRModule({}, {}); + auto dummy_fn_name = GlobalVar("test"); + module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); + auto solver = std::make_shared(dummy_fn_name, module, err_reporter); + + auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { + if (name == "Solve") { + return TypedPackedFunc([solver]() { return solver->Solve(); }); + } else if (name == "Unify") { + return TypedPackedFunc( + [module, solver, err_reporter](Type lhs, Type rhs) { + auto res = solver->Unify(lhs, rhs, lhs); + if (err_reporter->AnyErrors()) { + err_reporter->RenderErrors(module, true); + } + return res; + }); + } else if (name == "Resolve") { + return TypedPackedFunc([solver](Type t) { return solver->Resolve(t); }); + } else if (name == "AddConstraint") { + return TypedPackedFunc([solver](TypeConstraint c) { + Expr e = Var("dummy_var", IncompleteType(Kind::kType)); return solver->AddConstraint(c, e); }); - } else { - return PackedFunc(); - } - }; - *ret = runtime::TypedPackedFunc(mod); - }); + } else { + return PackedFunc(); + } + }; + *ret = runtime::TypedPackedFunc(mod); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 8ccc2c7..9b7c06c 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -24,21 +24,23 @@ #ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ #define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ +#include +#include #include #include -#include -#include -#include + #include #include #include +#include + #include "../../support/arena.h" namespace tvm { namespace relay { -using support::LinkNode; using support::LinkedList; +using support::LinkNode; /*! * \brief Interface of type solver used in type inference. diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index a86faeb..1d84016 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -28,12 +28,13 @@ #include #include #include + #include "../transforms/pass_util.h" namespace tvm { namespace relay { -template +template struct InsertionSet { std::unordered_set set; std::vector data; @@ -47,10 +48,8 @@ struct InsertionSet { class TypeVarTVisitor : public TypeVisitor { public: - TypeVarTVisitor( - InsertionSet* type_vars, - InsertionSet* bound_type_vars) - : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } + TypeVarTVisitor(InsertionSet* type_vars, InsertionSet* bound_type_vars) + : type_vars_(type_vars), bound_type_vars_(bound_type_vars) {} void VisitType_(const TypeVarNode* tp) final { TypeVar var = GetRef(tp); @@ -149,8 +148,7 @@ class TypeVarEVisitor : private ExprVisitor { } void VisitType(const Type& t) final { - TypeVarTVisitor(&type_vars_, &bound_type_vars_) - .VisitType(t); + TypeVarTVisitor(&type_vars_, &bound_type_vars_).VisitType(t); } private: @@ -204,9 +202,7 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { vars_.Insert(v); } - void VisitExpr_(const VarNode* var) final { - vars_.Insert(GetRef(var)); - } + void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { @@ -221,13 +217,9 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { VisitExpr(op->body); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } - void VisitPattern_(const PatternVarNode* op) final { - MarkBounded(op->var); - } + void VisitPattern_(const PatternVarNode* op) final { MarkBounded(op->var); } private: InsertionSet vars_; @@ -258,82 +250,66 @@ tvm::Array AllTypeVars(const Type& type, const IRModule& mod) { return TypeVarEVisitor(mod).All(type); } -tvm::Array FreeVars(const Expr& expr) { - return VarVisitor().Free(expr); -} +tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } -tvm::Array BoundVars(const Expr& expr) { - return VarVisitor().Bound(expr); -} +tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } -tvm::Array BoundVars(const Pattern& pat) { - return VarVisitor().Bound(pat); -} +tvm::Array BoundVars(const Pattern& pat) { return VarVisitor().Bound(pat); } -tvm::Array AllVars(const Expr& expr) { - return VarVisitor().All(expr); -} +tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } -TVM_REGISTER_GLOBAL("relay.analysis.free_vars") -.set_body_typed(FreeVars); +TVM_REGISTER_GLOBAL("relay.analysis.free_vars").set_body_typed(FreeVars); -TVM_REGISTER_GLOBAL("relay.analysis.bound_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - if (x.as()) { - *ret = BoundVars(Downcast(x)); - } else { - *ret = BoundVars(Downcast(x)); - } - }); +TVM_REGISTER_GLOBAL("relay.analysis.bound_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + if (x.as()) { + *ret = BoundVars(Downcast(x)); + } else { + *ret = BoundVars(Downcast(x)); + } +}); -TVM_REGISTER_GLOBAL("relay.analysis.all_vars") -.set_body_typed(AllVars); +TVM_REGISTER_GLOBAL("relay.analysis.all_vars").set_body_typed(AllVars); -TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars") -.set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - IRModule mod = args[1]; - if (x.as()) { - *ret = FreeTypeVars(Downcast(x), mod); - } else { - *ret = FreeTypeVars(Downcast(x), mod); - } - }); - -TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - IRModule mod = args[1]; - if (x.as()) { - *ret = BoundTypeVars(Downcast(x), mod); - } else { - *ret = BoundTypeVars(Downcast(x), mod); - } - }); - -TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - IRModule mod = args[1]; - if (x.as()) { - *ret = AllTypeVars(Downcast(x), mod); - } else { - *ret = AllTypeVars(Downcast(x), mod); - } - }); +TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + IRModule mod = args[1]; + if (x.as()) { + *ret = FreeTypeVars(Downcast(x), mod); + } else { + *ret = FreeTypeVars(Downcast(x), mod); + } +}); + +TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + IRModule mod = args[1]; + if (x.as()) { + *ret = BoundTypeVars(Downcast(x), mod); + } else { + *ret = BoundTypeVars(Downcast(x), mod); + } +}); + +TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + IRModule mod = args[1]; + if (x.as()) { + *ret = AllTypeVars(Downcast(x), mod); + } else { + *ret = AllTypeVars(Downcast(x), mod); + } +}); /*! * \brief Get reference counter of each internal ExprNode in body. * \param body The body expression. * \return The reference count mapping. */ -std::unordered_map -GetExprRefCount(const Expr& body) { +std::unordered_map GetExprRefCount(const Expr& body) { class ExprRefCounter : private MixedModeVisitor { public: - std::unordered_map - Get(const Expr& body) { + std::unordered_map Get(const Expr& body) { this->VisitExpr(body); return std::move(this->visit_counter_); } @@ -391,9 +367,7 @@ bool IsAllPositiveConstant(const Expr& expr) { } } else if (const auto* op = expr.as()) { // tail recursion. - if (op->op == expand_dims_op || - op->op == reshape_op || - op->op == transpose_op || + if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op || op->op == squeeze_op) { return IsAllPositiveConstant(op->args[0]); } else { @@ -419,17 +393,11 @@ Type TypeSubst(const Type& type, const tvm::Map& subst_map) { Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { class TypeSubstMutator : public ExprMutator, public PatternMutator { public: - explicit TypeSubstMutator(const tvm::Map& subst_map) : subst_map_(subst_map) { } - Type VisitType(const Type& t) final { - return TypeSubst(t, subst_map_); - } - Var VisitVar(const Var& v) final { - return Downcast(VisitExpr(v)); - } + explicit TypeSubstMutator(const tvm::Map& subst_map) : subst_map_(subst_map) {} + Type VisitType(const Type& t) final { return TypeSubst(t, subst_map_); } + Var VisitVar(const Var& v) final { return Downcast(VisitExpr(v)); } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Clause VisitClause(const Clause& c) final { Pattern pat = VisitPattern(c->lhs); diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index f3a2cad..33f52c9 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -24,12 +24,12 @@ #include #include #include + #include namespace tvm { namespace relay { - //! brief make sure each Var is bound at most once in a scope. class WellFormedChecker : private ExprVisitor, PatternVisitor { bool well_formed = true; @@ -41,9 +41,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { struct Scope { WellFormedChecker* wfc; - explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { - wfc->scope.push_back({{}}); - } + explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { wfc->scope.push_back({{}}); } ~Scope() { CHECK_GE(wfc->scope.size(), 0); for (const Var& v : wfc->scope.back()) { @@ -98,13 +96,9 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { VisitExpr(c->rhs); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } - void VisitVar(const Var& v) final { - Bound(v); - } + void VisitVar(const Var& v) final { Bound(v); } void VisitExpr(const Expr& e) final { if (auto v = e.as()) { @@ -121,12 +115,9 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { } }; -bool WellFormed(const Expr& e) { - return WellFormedChecker().CheckWellFormed(e); -} +bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); } -TVM_REGISTER_GLOBAL("relay.analysis.well_formed") -.set_body_typed(WellFormed); +TVM_REGISTER_GLOBAL("relay.analysis.well_formed").set_body_typed(WellFormed); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index c26228e..ef273c3 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -21,13 +21,14 @@ * \file relay/backend/build_module.cc * \brief Code generation for TVM's graph runtime. */ -#include #include -#include -#include +#include #include -#include #include +#include +#include +#include + #include #include "../../target/source/codegen_source_base.h" @@ -37,7 +38,6 @@ namespace tvm { namespace relay { namespace backend { - using TargetsMap = Map; using namespace tvm::relay::transform; @@ -63,17 +63,11 @@ struct GraphCodegen { } ~GraphCodegen() {} - void Init(runtime::Module* m, TargetsMap targets) { - CallFunc("init", m, targets); - } + void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } - void Codegen(const Function& func) { - CallFunc("codegen", func); - } + void Codegen(const Function& func) { CallFunc("codegen", func); } - std::string GetJSON() { - return CallFunc("get_graph_json", nullptr); - } + std::string GetJSON() { return CallFunc("get_graph_json", nullptr); } Array GetExternalModules() { return CallFunc>("get_external_modules", nullptr); @@ -96,13 +90,13 @@ struct GraphCodegen { protected: tvm::runtime::Module mod; - template - R CallFunc(const std::string &name, Args... args) { + template + R CallFunc(const std::string& name, Args... args) { auto pf = mod.GetFunction(name, false); return pf(std::forward(args)...); } - template - void CallFunc(const std::string &name, Args... args) { + template + void CallFunc(const std::string& name, Args... args) { auto pf = mod.GetFunction(name, false); pf(std::forward(args)...); return; @@ -121,29 +115,24 @@ class RelayBuildModule : public runtime::ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "get_graph_json") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetGraphJSON(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); }); } else if (name == "get_module") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetModule(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); this->Build(args[0], args[1], args[2]); }); } else if (name == "list_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->ListParamNames(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->ListParamNames(); }); } else if (name == "get_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetParams(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map params = args[0]; @@ -153,11 +142,11 @@ class RelayBuildModule : public runtime::ModuleNode { }); } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetIRModule(); + *rv = this->graph_codegen_->GetIRModule(); }); } else if (name == "get_external_modules") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetExternalModules(); + *rv = this->graph_codegen_->GetExternalModules(); }); } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -175,18 +164,14 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return const std::string graph_json */ - const std::string& GetGraphJSON() { - return ret_.graph_json; - } + const std::string& GetGraphJSON() { return ret_.graph_json; } /*! * \brief Get the Module object * * \return runtime::Module */ - runtime::Module GetModule() { - return ret_.mod; - } + runtime::Module GetModule() { return ret_.mod; } /*! * \brief List all paramter names @@ -220,18 +205,14 @@ class RelayBuildModule : public runtime::ModuleNode { * \param name name of parameter * \param data_in input DLTensor */ - void SetParam(const std::string& name, runtime::NDArray data_in) { - params_[name] = data_in; - } + void SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } /*! * \brief type key * * \return const char* */ - const char* type_key() const final { - return "RelayBuildModule"; - } + const char* type_key() const final { return "RelayBuildModule"; } /*! * \brief Build relay IRModule for graph runtime @@ -240,9 +221,7 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target Target device * \param target_host Host target device */ - void Build(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host) { + void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; BuildRelay(mod, params_); @@ -258,13 +237,10 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return relay::IRModule The updated Relay IR module after optimization. */ - IRModule Optimize( - IRModule relay_module, - const TargetsMap& targets, - const std::unordered_map& params) { + IRModule Optimize(IRModule relay_module, const TargetsMap& targets, + const std::unordered_map& params) { if (params.size()) { - CHECK(relay_module->ContainGlobalVar("main")) - << "Missing the main entry function"; + CHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function"; GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); Function main_func = Downcast(relay_module->Lookup(main_glb_var)); auto new_main = BindParamsByName(main_func, params); @@ -328,8 +304,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Handle heterogeneous compilation. transform::PassContext pass_ctx = PassContext::Current(); if (targets_.size() > 1) { - relay_module = - RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); + relay_module = RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); } // Fuse the operations if it is needed. @@ -386,8 +361,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return updated_module The updated module after device annotation. */ - IRModule RunDeviceAnnotationPass(const IRModule& relay_module, - int fallback_device) { + IRModule RunDeviceAnnotationPass(const IRModule& relay_module, int fallback_device) { UpdateHeterogeneousInputs(fallback_device); auto rewrite = transform::RewriteAnnotatedOps(fallback_device); auto updated_module = rewrite(relay_module); @@ -416,12 +390,11 @@ class RelayBuildModule : public runtime::ModuleNode { break; } for (auto kv : annotation_map) { - CHECK_EQ(kv.second->value, dev_type) - << "Expressions in the function are " - << "annotated with various device types," - << "but not device copy operators " - << "found. Please check the " - << "RewriteAnnotation pass."; + CHECK_EQ(kv.second->value, dev_type) << "Expressions in the function are " + << "annotated with various device types," + << "but not device copy operators " + << "found. Please check the " + << "RewriteAnnotation pass."; } targets_.Set(0, CreateDefaultTarget(dev_type)); } @@ -435,9 +408,8 @@ class RelayBuildModule : public runtime::ModuleNode { * \param relay_module The Relay IR module. * \param params The parameters. */ - void BuildRelay( - IRModule relay_module, - const std::unordered_map& params) { + void BuildRelay(IRModule relay_module, + const std::unordered_map& params) { // Relay IRModule -> IRModule optimizations. relay_module = Optimize(relay_module, targets_, params); // Get the updated function. @@ -473,23 +445,19 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.mod = tvm::codegen::CSourceModuleCreate(";", ""); } } else { - ret_.mod = tvm::build( - lowered_funcs, - target_host_, - BuildConfig::Current()); + ret_.mod = tvm::build(lowered_funcs, target_host_, BuildConfig::Current()); } Array ext_mods = graph_codegen_->GetExternalModules(); // Import all external runtime modules. - for (const auto& it : ext_mods) - ret_.mod.Import(it); + for (const auto& it : ext_mods) ret_.mod.Import(it); } private: Target GetTargetHost() { Target target_host = target_host_; if (!target_host_.defined()) { - for (const auto &it : targets_) { + for (const auto& it : targets_) { if (it.second->device_type == kDLCPU) { target_host = it.second; break; @@ -516,20 +484,19 @@ runtime::Module RelayBuildCreate() { return runtime::Module(exec); } -TVM_REGISTER_GLOBAL("relay.build_module._BuildModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = RelayBuildCreate(); }); TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Map params = args[1]; - std::unordered_map params_; - for (const auto& kv : params) { - params_[kv.first] = kv.second->data; - } - *rv = relay::backend::BindParamsByName(args[0], params_); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + Map params = args[1]; + std::unordered_map params_; + for (const auto& kv : params) { + params_[kv.first] = kv.second->data; + } + *rv = relay::backend::BindParamsByName(args[0], params_); + }); } // namespace backend } // namespace relay diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ce0a314..3851de1 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -89,8 +89,7 @@ bool IsDynamic(const Type& ty) { } // TODO(@jroesch): MOVE ME -TVM_REGISTER_GLOBAL("relay.ir.IsDynamic") -.set_body_typed(IsDynamic); +TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic); Array GetShape(const Array& shape) { // for now, we always use int32 shape when possible @@ -124,8 +123,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> for (Var param : prim_func->params) { Array inputs; if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder( - GetShape(ttype->shape), ttype->dtype); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); cache_node->inputs.push_back(tensor); inputs.push_back(tensor); } else { @@ -135,8 +133,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> const auto* ttype = field.as(); // TODO(@icemelon): Allow recursive tuple CHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder( - GetShape(ttype->shape), ttype->dtype); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); cache_node->inputs.push_back(tensor); inputs.push_back(tensor); } @@ -149,7 +146,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> constexpr static size_t kMaxFuncNameLength = 80; if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); truncated_name << "_" << std::hash{}(candidate_name) << "_"; candidate_name = truncated_name.str(); } @@ -190,29 +187,31 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> CHECK(op->is_scalar()); void* data = op->data->data; DataType dtype = DataType(op->data->dtype); - auto value = te::compute({}, [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, "compile_engine_const", topi::kBroadcast); + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); scalars_.push_back(value->op); return {value}; } Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = - Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttr("TOpPattern"); static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); CHECK(flower_call) << "relay.backend.lower_call is not registered."; @@ -227,12 +226,10 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> } } if (count_tuple) { - CHECK_EQ(call_node->args.size(), 1U) - << "Only allow function with a single tuple input"; + CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; } - CHECK(call_node->op.as()) - << "Primitive function only allows call into primitive ops"; + CHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); Array outputs; @@ -240,8 +237,8 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> // Skip fcompute for device copy operators as it is not registered. if (op == device_copy_op_) { const auto* copy_input = inputs[0].operator->(); - outputs.push_back(te::TensorNode::make(copy_input->shape, copy_input->dtype, - te::Operation(), 0)); + outputs.push_back( + te::TensorNode::make(copy_input->shape, copy_input->dtype, te::Operation(), 0)); } else { LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); outputs = lowered_out->outputs; @@ -251,8 +248,8 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> int op_pattern = fpattern[op]; if (op_pattern >= kCommReduce) { CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce) - << "Two complicated op in a primitive function " - << " master=" << master_op_ << " current=" << op; + << "Two complicated op in a primitive function " + << " master=" << master_op_ << " current=" << op; } if (op_pattern >= master_op_pattern_) { master_op_ = op; @@ -261,8 +258,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> master_implementation_ = impl; } if (outputs.size() != 1) { - const auto* tuple_type = - call_node->checked_type().as(); + const auto* tuple_type = call_node->checked_type().as(); CHECK(tuple_type) << "Expect output to be a tuple type"; CHECK_EQ(tuple_type->fields.size(), outputs.size()); } @@ -292,8 +288,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> Array VisitExpr_(const TupleNode* op) final { Array fields; for (Expr field : op->fields) { - CHECK(field->checked_type().as()) - << "Only allow Tuple of Tensor"; + CHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; Array res = VisitExpr(field); CHECK_EQ(res.size(), 1); fields.push_back(res[0]); @@ -349,15 +344,15 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> shape_inputs.push_back(shape_tensor); }; - if (const auto *ttype = param->checked_type().as()) { + if (const auto* ttype = param->checked_type().as()) { add_placeholder(ttype); } else { // flatten tuple of tensor type. - const auto *tuple_type = param->type_as(); + const auto* tuple_type = param->type_as(); // TODO(@icemelon): Support recursive tuple CHECK(tuple_type); for (Type field : tuple_type->fields) { - const auto *ttype = field.as(); + const auto* ttype = field.as(); CHECK(ttype); add_placeholder(ttype); } @@ -372,7 +367,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> constexpr static size_t kMaxFuncNameLength = 80; if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); truncated_name << "_" << std::hash{}(candidate_name) << "_"; candidate_name = truncated_name.str(); } @@ -448,28 +443,31 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> if (data_dependant) { void* data = op->data->data; DataType dtype = DataType(op->data->dtype); - auto value = tvm::te::compute({}, [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, "data_const", topi::kBroadcast); + auto value = tvm::te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "data_const", topi::kBroadcast); scalars_.push_back(value); return {value}; } else { - auto value = tvm::te::compute({}, [&](const Array&) { - return tir::make_const(DataType::Int(64), 0); - }, "shape_const", topi::kBroadcast); + auto value = tvm::te::compute( + {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, + "shape_const", topi::kBroadcast); scalars_.push_back(value); return {value}; } @@ -477,18 +475,15 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> Array VisitExpr_(const CallNode* call_node) final { static auto fshape_func = Op::GetAttr("FShapeFunc"); - static auto tshape_data_dependant = Op::GetAttr( - "TShapeDataDependant"); - CHECK(call_node->op.as()) - << "Primitive function only allows call into primitive ops"; + static auto tshape_data_dependant = Op::GetAttr("TShapeDataDependant"); + CHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); CHECK(data_dependants_.empty() || !data_dependants_.back()) - << "Error in op fusion: output of the shape func is fed to a " - << "data-dependant shape func"; - CHECK_GT(fshape_func.count(op), 0) - << "Internal error, cannot find ShapeFunc for " << op->name; + << "Error in op fusion: output of the shape func is fed to a " + << "data-dependant shape func"; + CHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name; CHECK_GT(tshape_data_dependant.count(op), 0) - << "Internal error, cannot find TShapeDataDependant for " << op->name; + << "Internal error, cannot find TShapeDataDependant for " << op->name; data_dependants_.push_back(tshape_data_dependant[op]); // Visit all inputs @@ -503,8 +498,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> } } if (count_tuple) { - CHECK_EQ(call_node->args.size(), 1U) - << "Only allow function with a single tuple input"; + CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; } // Get output ndims auto ret_type = call_node->checked_type(); @@ -543,8 +537,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> Array VisitExpr_(const TupleNode* op) final { Array fields; for (Expr field : op->fields) { - CHECK(field->checked_type().as()) - << "Only allow Tuple of Tensor"; + CHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; Array res = VisitExpr(field); CHECK_EQ(res.size(), 1); fields.push_back(res[0]); @@ -570,9 +563,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. - CachedFunc Lower(const CCacheKey& key) { - return LowerInternal(key)->cached_func; - } + CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; } // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { @@ -633,9 +624,7 @@ class CompileEngineImpl : public CompileEngineNode { return ret; } - void Clear() final { - cache_.clear(); - } + void Clear() final { cache_.clear(); } // List all items in the cache. Array ListItems() { std::lock_guard lock(mutex_); @@ -659,7 +648,7 @@ class CompileEngineImpl : public CompileEngineNode { private: // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key) { + CCacheValue LowerInternal(const CCacheKey& key) { std::lock_guard lock(mutex_); CCacheValue value; auto it = cache_.find(key); @@ -676,10 +665,8 @@ class CompileEngineImpl : public CompileEngineNode { // codegen tool once and lower all functions together. if (key->source_func->GetAttr(attr::kCompiler).defined()) { auto cache_node = make_object(); - const auto name_node = - key->source_func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(name_node.defined()) - << "External function has not been attached a name yet."; + const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(name_node.defined()) << "External function has not been attached a name yet."; cache_node->func_name = std::string(name_node.value()); cache_node->target = tvm::target::ext_dev(); value->cached_func = CachedFunc(cache_node); @@ -690,8 +677,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto cfunc = CreateSchedule(key->source_func, key->target); - auto cache_node = make_object( - *(cfunc.operator->())); + auto cache_node = make_object(*(cfunc.operator->())); // Skip lowering for device copy node. const Expr body = (key->source_func)->body; @@ -710,13 +696,11 @@ class CompileEngineImpl : public CompileEngineNode { } // lower the function if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)( - cfunc->schedule, all_args, cache_node->func_name, key->source_func); + cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); } else { tvm::BuildConfig bcfg = BuildConfig::Create(); std::unordered_map binds; - cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, - binds, bcfg); + cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds, bcfg); } value->cached_func = CachedFunc(cache_node); return value; @@ -740,8 +724,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto spair = MakeShapeFunc().Create(key->source_func); - auto cache_node = make_object( - *(spair.second.operator->())); + auto cache_node = make_object(*(spair.second.operator->())); cache_node->func_name = GetUniqueName(cache_node->func_name); cache_node->target = key->target; @@ -792,57 +775,41 @@ class CompileEngineImpl : public CompileEngineNode { const CompileEngine& CompileEngine::Global() { // intentionally allocate raw pointer to avoid // free during destructuion. - static CompileEngine* inst = new CompileEngine( - make_object()); + static CompileEngine* inst = new CompileEngine(make_object()); return *inst; } TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") -.set_body_typed([](tvm::Array outputs, OpImplementation impl) { - return LoweredOutput(outputs, impl); -}); + .set_body_typed([](tvm::Array outputs, OpImplementation impl) { + return LoweredOutput(outputs, impl); + }); TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") -.set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); -}); + .set_body_typed([](Function source_func, Target target) { + return CCacheKey(source_func, target); + }); -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal") -.set_body_typed([]() { +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { return CompileEngine::Global(); }); -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear") -.set_body_typed([](CompileEngine self) { +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) { self->Clear(); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") -.set_body_typed( - [](CompileEngine self, CCacheKey key) { - return self->Lower(key); -}); + .set_body_typed([](CompileEngine self, CCacheKey key) { return self->Lower(key); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") -.set_body_typed( - [](CompileEngine self, CCacheKey key) { - return self->LowerShapeFunc(key); -}); + .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") -.set_body_typed([](CompileEngine self) { - return self->LowerExternalFunctions(); -}); + .set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") -.set_body_typed( - [](CompileEngine self, CCacheKey key) { - return self->JIT(key); -}); + .set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); }); -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems") -.set_body_typed( - [](CompileEngine self){ +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) { return static_cast(self.operator->())->ListItems(); }); } // namespace relay diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 4a3a04d..9abe80c 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -27,13 +27,14 @@ #include #include -#include #include #include -#include #include -#include +#include +#include + #include +#include namespace tvm { namespace relay { @@ -150,9 +151,7 @@ class CCacheKey : public ObjectRef { */ TVM_DLL CCacheKey(Function source_func, Target target); - const CCacheKeyNode* operator->() const { - return static_cast(get()); - } + const CCacheKeyNode* operator->() const { return static_cast(get()); } // comparator inline bool operator==(const CCacheKey& other) const { CHECK(defined() && other.defined()); @@ -184,12 +183,8 @@ class CCacheValue : public ObjectRef { public: CCacheValue() {} explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} - CCacheValueNode* operator->() { - return static_cast(get_mutable()); - } - const CCacheValueNode* operator->() const { - return static_cast(get()); - } + CCacheValueNode* operator->() { return static_cast(get_mutable()); } + const CCacheValueNode* operator->() const { return static_cast(get()); } using ContainerType = CCacheValueNode; }; @@ -240,9 +235,7 @@ class CompileEngine : public ObjectRef { public: CompileEngine() {} explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} - CompileEngineNode* operator->() { - return static_cast(get_mutable()); - } + CompileEngineNode* operator->() { return static_cast(get_mutable()); } using ContainerType = CompileEngineNode; /*! \brief The global compile engine. */ TVM_DLL static const CompileEngine& Global(); @@ -260,17 +253,15 @@ inline size_t CCacheKeyNode::Hash() const { if (hash_ != 0) return hash_; // do structral hash, avoid 0. hash_ = tvm::StructuralHash()(this->source_func); - hash_ = dmlc::HashCombine( - hash_, std::hash()(target->str())); + hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); if (hash_ == 0) hash_ = 1; return hash_; } -inline bool CCacheKeyNode::Equal( - const CCacheKeyNode* other) const { +inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && - tvm::StructuralEqual()(this->source_func, other->source_func); + tvm::StructuralEqual()(this->source_func, other->source_func); } } // namespace relay @@ -278,7 +269,7 @@ inline bool CCacheKeyNode::Equal( namespace std { // overload hash -template<> +template <> struct hash<::tvm::relay::CCacheKey> { size_t operator()(const ::tvm::relay::CCacheKey& key) const { CHECK(key.defined()); diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index ed36fda..b8803d4 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -151,8 +151,8 @@ class CodegenC : public MemoizedExprTranslator>, public Code for (size_t i = 0; i < out_shape.size(); ++i) { out_size *= out_shape[i]; } - buf_stream << dtype << "* " << out << - " = (" << dtype << "*)std::malloc(4 * " << out_size << ");"; + buf_stream << dtype << "* " << out << " = (" << dtype << "*)std::malloc(4 * " << out_size + << ");"; buf_decl_.push_back(buf_stream.str()); decl_stream << ", " << out << ");"; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 9226386..2ee68ce 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -25,9 +25,10 @@ #define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_ #include -#include #include +#include #include + #include #include #include @@ -69,8 +70,7 @@ class CSourceModuleCodegenBase { * \return An external symbol. */ std::string GetExtSymbol(const Function& func) const { - const auto name_node = - func->GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; return std::string(name_node.value()); } @@ -124,8 +124,7 @@ class CodegenCBase { * * \endcode */ - void GenerateBackendCFunc(const std::string& func_name, - const Array& args, + void GenerateBackendCFunc(const std::string& func_name, const Array& args, const Output& out) { // Print signature code_stream_ << "\n"; @@ -158,8 +157,8 @@ class CodegenCBase { code_stream_ << "}\n\n"; // Generate the macro - code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " - << func_name << "_wrapper_);\n\n"; + code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " << func_name + << "_wrapper_);\n\n"; } /*! @@ -187,8 +186,7 @@ class CodegenCBase { */ std::string JitImpl(const std::string& ext_func_id, const Array& args, const std::vector& buf_decl, - const std::vector& body, - const std::vector& out) { + const std::vector& body, const std::vector& out) { // Create the signature. For example, it could be: // extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {} code_stream_ << "extern \"C\" void " << ext_func_id << "_("; diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 5e45e94..3db5dc4 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -174,7 +174,7 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C // Allocate large arrays on the static section to avoid stakc overflow. // Note that this would probably increase compilation time as the source // file could be really large. - buf_stream << "static float " << output.name << "[" << num_elems <<"] = {"; + buf_stream << "static float " << output.name << "[" << num_elems << "] = {"; for (int64_t i = 0; i < num_elems - 1; i++) { buf_stream << ptr[i] << ","; } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 736509d..820e17f 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -22,10 +22,11 @@ * \brief Memory index assignment pass for executing * the program in the graph runtime. */ -#include +#include #include #include -#include +#include + #include "../../support/arena.h" namespace tvm { @@ -60,9 +61,7 @@ class StorageAllocaBaseVisitor : public ExprVisitor { } } - void VisitExpr_(const ConstantNode* op) final { - this->CreateToken(op, false); - } + void VisitExpr_(const ConstantNode* op) final { this->CreateToken(op, false); } void VisitExpr_(const VarNode* op) final { // Do nothing. @@ -96,9 +95,7 @@ class StorageAllocaBaseVisitor : public ExprVisitor { token_map_[op] = {tok[op->index]}; } - void VisitExpr_(const IfNode* op) final { - LOG(FATAL) << "if is not supported."; - } + void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } void VisitExpr_(const LetNode* op) final { auto token = GetToken(op->value); @@ -131,12 +128,11 @@ class StorageAllocaBaseVisitor : public ExprVisitor { class StorageAllocaInit : protected StorageAllocaBaseVisitor { public: - explicit StorageAllocaInit(support::Arena* arena) - : arena_(arena) {} + explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {} /*! \return The internal token map */ - std::unordered_map > - GetInitTokenMap(const Function& func) { + std::unordered_map > GetInitTokenMap( + const Function& func) { node_device_map_ = CollectDeviceInfo(func); this->Run(func); return std::move(token_map_); @@ -145,12 +141,11 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { protected: using StorageAllocaBaseVisitor::VisitExpr_; - void CreateToken(const ExprNode* op, bool can_realloc) final { + void CreateToken(const ExprNode* op, bool can_realloc) final { CHECK(!token_map_.count(op)); std::vector tokens; - int device_type = node_device_map_.count(GetRef(op)) - ? node_device_map_[GetRef(op)]->value - : 0; + int device_type = + node_device_map_.count(GetRef(op)) ? node_device_map_[GetRef(op)]->value : 0; if (const auto* tuple_type = op->checked_type().as()) { for (Type t : tuple_type->fields) { const auto* ttype = t.as(); @@ -227,10 +222,9 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { - LOG(FATAL) - << num_annotated_nodes << " out of " << num_nodes - << "expressions are assigned with virtual device types. Either all " - "or none of the expressions are expected to be annotated."; + LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes + << "expressions are assigned with virtual device types. Either all " + "or none of the expressions are expected to be annotated."; } return smap; } @@ -296,12 +290,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { size_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = tir::as_const_int(dim); - CHECK(pval != nullptr) - << "Cannot allocate memory symbolic tensor shape " - << ttype->shape; - CHECK_GE(*pval, 0) - << "Cannot allocate memory for tensor with negative shape" - << *pval; + CHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape; + CHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval; size *= static_cast(pval[0]); } size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8); @@ -324,7 +314,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { auto end = free_.upper_bound(size * match_range_); // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { - StorageToken *tok = it->second; + StorageToken* tok = it->second; if (tok->device_type != prototype->device_type) continue; CHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy @@ -337,7 +327,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // then search for memory blocks smaller than requested space for (auto it = mid; it != begin;) { --it; - StorageToken *tok = it->second; + StorageToken* tok = it->second; if (tok->device_type != prototype->device_type) continue; CHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy @@ -390,8 +380,7 @@ Map > GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } -TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory") -.set_body_typed(GraphPlanMemory); +TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 7b686c7..c8ec1bf 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -44,7 +44,7 @@ class GraphInputNode; class GraphOpNode; using IntegerArray = Array; -using ShapeVector = std::vector >; +using ShapeVector = std::vector>; using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; @@ -70,8 +70,7 @@ class GraphNodeRef { public: GraphNodeRef() {} GraphNodeRef(int ident, int index, int version = 0) - : ident_(ident), index_(index), version_(version) {} - + : ident_(ident), index_(index), version_(version) {} inline void Save(dmlc::JSONWriter* writer) const { writer->BeginArray(); @@ -81,9 +80,7 @@ class GraphNodeRef { writer->EndArray(); } - inline void Load(dmlc::JSONReader* reader) { - LOG(FATAL) << "Not implemented."; - } + inline void Load(dmlc::JSONReader* reader) { LOG(FATAL) << "Not implemented."; } protected: int ident_; @@ -136,11 +133,8 @@ class GraphInputNode : public GraphNode { class GraphOpNode : public GraphNode { public: GraphOpNode() {} - GraphOpNode(const std::string& name, - const GraphAttrs& nd_attrs, - const std::string& op_name, - const std::vector& inputs, - const GraphAttrs& attrs, + GraphOpNode(const std::string& name, const GraphAttrs& nd_attrs, const std::string& op_name, + const std::vector& inputs, const GraphAttrs& attrs, size_t num_outputs = 1) { name_ = name; attrs_ = nd_attrs; @@ -173,8 +167,7 @@ class GraphOpNode : public GraphNode { const GraphAttrs& nd_attrs, const std::string& op_name, const std::vector& inputs, - const GraphAttrs& attrs, - size_t num_outputs = 1) { + const GraphAttrs& attrs, size_t num_outputs = 1) { auto ptr = std::make_shared(name, nd_attrs, op_name, inputs, attrs, num_outputs); return std::dynamic_pointer_cast(ptr); } @@ -335,8 +328,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator GraphAddCallNode(const CallNode* op, - const std::string& op_name, + std::vector GraphAddCallNode(const CallNode* op, const std::string& op_name, const std::string& func_name) { std::vector inputs; for (auto arg : op->args) { @@ -345,11 +337,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator(op)); } @@ -384,11 +372,11 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorvalue; // Normal Relay Function if (targets_.size() == 1) { - // homogeneous execution. + // homogeneous execution. const auto& it = targets_.begin(); target = (*it).second; } else { @@ -400,8 +388,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorstr()] = IRModule::Empty(); } lowered_funcs_[target->str()]->Update(lowered_func->funcs); - return GraphAddCallNode(op, - _GetUniqueName(lowered_func->func_name), - lowered_func->func_name); + return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name); } std::vector VisitExpr_(const LetNode* op) override { @@ -560,37 +545,34 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator& sptr_to_self) { - if (name == "init") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 2) - << "The expected of arguments are: " - << "runtime::Module mod and Map targets"; - void* mod = args[0]; - Map tmp = args[1]; - TargetsMap targets; - for (const auto& it : tmp) { - auto dev_type = it.first.as(); - CHECK(dev_type); - targets[dev_type->value] = it.second; - } - codegen_ = std::make_shared( - reinterpret_cast(mod), targets); - }); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "init") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 2) << "The expected of arguments are: " + << "runtime::Module mod and Map targets"; + void* mod = args[0]; + Map tmp = args[1]; + TargetsMap targets; + for (const auto& it : tmp) { + auto dev_type = it.first.as(); + CHECK(dev_type); + targets[dev_type->value] = it.second; + } + codegen_ = + std::make_shared(reinterpret_cast(mod), targets); + }); } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Function func = args[0]; this->output_ = this->codegen_->Codegen(func); }); } else if (name == "get_graph_json") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->output_.graph_json; - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.graph_json; }); } else if (name == "list_params_name") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Array ret; - for (const auto &kv : this->output_.params) { + for (const auto& kv : this->output_.params) { ret.push_back(kv.first); } *rv = ret; @@ -614,9 +596,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { } } - const char* type_key() const final { - return "RelayGraphRuntimeCodegenModule"; - } + const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } private: std::shared_ptr codegen_; @@ -629,9 +609,7 @@ runtime::Module CreateGraphCodegenMod() { } TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CreateGraphCodegenMod(); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); }); } // namespace backend } // namespace relay diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 465f788..c529997 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -21,16 +21,16 @@ * \file src/relay/interpreter.cc * \brief An interpreter for the Relay IR. */ -#include -#include -#include -#include -#include -#include +#include #include #include +#include #include -#include +#include +#include +#include +#include +#include #include "compile_engine.h" @@ -39,8 +39,7 @@ namespace relay { using namespace runtime; -InterpreterClosure::InterpreterClosure(tvm::Map env, - Function func) { +InterpreterClosure::InterpreterClosure(tvm::Map env, Function func) { ObjectPtr n = make_object(); n->env = std::move(env); n->func = std::move(func); @@ -48,10 +47,10 @@ InterpreterClosure::InterpreterClosure(tvm::Map env, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")"; + }); inline const PackedFunc& GetPackedFunc(const std::string& name) { const PackedFunc* pf = tvm::runtime::Registry::Get(name); @@ -69,10 +68,10 @@ RecClosure::RecClosure(InterpreterClosure clos, Var bind) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RecClosureObj(" << node->clos << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RecClosureObj(" << node->clos << ")"; + }); RefValue::RefValue(ObjectRef value) { ObjectPtr n = make_object(); @@ -80,21 +79,19 @@ RefValue::RefValue(ObjectRef value) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relay._make.RefValue") -.set_body_typed([](ObjectRef value){ +TVM_REGISTER_GLOBAL("relay._make.RefValue").set_body_typed([](ObjectRef value) { return RefValue(value); }); TVM_REGISTER_NODE_TYPE(RefValueObj); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefValueObj(" << node->value << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefValueObj(" << node->value << ")"; + }); -ConstructorValue::ConstructorValue(int32_t tag, - tvm::Array fields, +ConstructorValue::ConstructorValue(int32_t tag, tvm::Array fields, Constructor constructor) { ObjectPtr n = make_object(); n->tag = tag; @@ -104,19 +101,17 @@ ConstructorValue::ConstructorValue(int32_t tag, } TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") -.set_body_typed([](int32_t tag, tvm::Array fields, - Constructor constructor) { - return ConstructorValue(tag, fields, constructor); -}); + .set_body_typed([](int32_t tag, tvm::Array fields, Constructor constructor) { + return ConstructorValue(tag, fields, constructor); + }); TVM_REGISTER_NODE_TYPE(ConstructorValueObj); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ConstructorValueObj(" << node->tag << "," - << node->fields << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstructorValueObj(" << node->tag << "," << node->fields << ")"; + }); /*! * \brief A stack frame in the Relay interpreter. @@ -161,9 +156,7 @@ struct Stack { */ struct LocalFrame { Stack& st; - explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { - st.frames.push_back(fr); - } + explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { st.frames.push_back(fr); } ~LocalFrame() { st.frames.pop_back(); } }; }; @@ -213,9 +206,8 @@ InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) { // contains DAG in dataflow-form. // // Conversion to ANF is recommended before running the interpretation. -class Interpreter : - public ExprFunctor, - PatternFunctor { +class Interpreter : public ExprFunctor, + PatternFunctor { public: Interpreter(IRModule mod, DLContext context, Target target) : mod_(mod), @@ -232,21 +224,13 @@ class Interpreter : return f(); } - void extend(const Var& id, ObjectRef v) { - stack_.current_frame().locals.Set(id, v); - } + void extend(const Var& id, ObjectRef v) { stack_.current_frame().locals.Set(id, v); } - ObjectRef Lookup(const Var& local) { - return stack_.Lookup(local); - } + ObjectRef Lookup(const Var& local) { return stack_.Lookup(local); } - ObjectRef Eval(const Expr& expr) { - return VisitExpr(expr); - } + ObjectRef Eval(const Expr& expr) { return VisitExpr(expr); } - ObjectRef VisitExpr_(const VarNode* var_node) final { - return Lookup(GetRef(var_node)); - } + ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef(var_node)); } ObjectRef VisitExpr_(const GlobalVarNode* op) final { return Eval(mod_->Lookup(GetRef(op))); @@ -260,9 +244,7 @@ class Interpreter : return ObjectRef(); } - ObjectRef VisitExpr_(const ConstantNode* op) final { - return op->data.CopyTo(context_); - } + ObjectRef VisitExpr_(const ConstantNode* op) final { return op->data.CopyTo(context_); } ObjectRef VisitExpr_(const TupleNode* op) final { std::vector values; @@ -302,8 +284,7 @@ class Interpreter : return MakeClosure(func); } - Array ComputeDynamicShape(const Function& func, - const Array& args) { + Array ComputeDynamicShape(const Function& func, const Array& args) { CCacheKey key(func, Target::Create("llvm")); auto cfunc = engine_->LowerShapeFunc(key); size_t arity = cfunc->inputs.size() + cfunc->outputs.size(); @@ -319,26 +300,26 @@ class Interpreter : cpu_ctx.device_id = 0; auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) { - auto nd_array = Downcast(val); - if (need_shape) { - int64_t ndim = nd_array.Shape().size(); - NDArray shape_arr; - if (ndim == 0) { - shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx); - } else { - shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); - int64_t* data = reinterpret_cast(shape_arr->data); - for (auto j = 0; j < ndim; ++j) { - data[j] = nd_array.Shape()[j]; - } - } - inputs[i] = shape_arr; - setter(i, shape_arr); + auto nd_array = Downcast(val); + if (need_shape) { + int64_t ndim = nd_array.Shape().size(); + NDArray shape_arr; + if (ndim == 0) { + shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx); } else { - auto arr = nd_array.CopyTo(cpu_ctx); - inputs[i] = arr; - setter(i, arr); + shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); + int64_t* data = reinterpret_cast(shape_arr->data); + for (auto j = 0; j < ndim; ++j) { + data[j] = nd_array.Shape()[j]; + } } + inputs[i] = shape_arr; + setter(i, shape_arr); + } else { + auto arr = nd_array.CopyTo(cpu_ctx); + inputs[i] = arr; + setter(i, arr); + } }; size_t arg_counter = 0; @@ -367,17 +348,16 @@ class Interpreter : } } } - CHECK_EQ(arg_counter, cfunc->inputs.size()) - << "Shape function input sizes mismatch"; + CHECK_EQ(arg_counter, cfunc->inputs.size()) << "Shape function input sizes mismatch"; auto fset_shape_output = [&](size_t i, Type val_type) { - // TODO(@icemelon): allow recursive tuple - const TensorTypeNode* rtype = val_type.as(); - CHECK(rtype != nullptr); - int64_t ndim = rtype->shape.size(); - auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); - outputs[i] = arr; - setter(arg_counter + i, arr); + // TODO(@icemelon): allow recursive tuple + const TensorTypeNode* rtype = val_type.as(); + CHECK(rtype != nullptr); + int64_t ndim = rtype->shape.size(); + auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); + outputs[i] = arr; + setter(arg_counter + i, arr); }; auto ret_type = func->body->checked_type(); @@ -392,8 +372,7 @@ class Interpreter : auto tt = Downcast(ret_type); fset_shape_output(0, tt); } - CHECK_EQ(cfunc->outputs.size(), out_cnt) - << "Shape function output sizes mismatch"; + CHECK_EQ(cfunc->outputs.size(), out_cnt) << "Shape function output sizes mismatch"; PackedFunc shape_func; Module m; @@ -419,8 +398,7 @@ class Interpreter : return out_shapes; } - ObjectRef InvokePrimitiveOp(const Function& func, - const Array& args) { + ObjectRef InvokePrimitiveOp(const Function& func, const Array& args) { const auto* call_node = func->body.as(); if (call_node && call_node->op == debug_op_) { @@ -451,8 +429,7 @@ class Interpreter : if (const auto* tuple_type = func->body->checked_type().as()) { arg_len += tuple_type->fields.size(); } else { - CHECK(func->body->checked_type().as()) - << func->body->checked_type(); + CHECK(func->body->checked_type().as()) << func->body->checked_type(); arg_len += 1; } std::vector values(arg_len); @@ -463,16 +440,14 @@ class Interpreter : const auto nd_array = Downcast(val); setter(i, nd_array); DLContext arg_ctx = nd_array->ctx; - CHECK(arg_ctx.device_type == context_.device_type && - arg_ctx.device_id == context_.device_id) - << "Interpreter expect context to be " - << context_ << ", but get " << arg_ctx; + CHECK(arg_ctx.device_type == context_.device_type && arg_ctx.device_id == context_.device_id) + << "Interpreter expect context to be " << context_ << ", but get " << arg_ctx; }; int arg_counter = 0; for (ObjectRef arg : args) { if (arg->IsInstance()) { - fset_input(arg_counter++, arg); + fset_input(arg_counter++, arg); } else { auto adt = Downcast(arg); for (size_t i = 0; i < adt.size(); ++i) { @@ -547,8 +522,7 @@ class Interpreter : } // Invoke the closure - ObjectRef Invoke(const InterpreterClosure& closure, - const tvm::Array& args, + ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. if (closure->func->HasNonzeroAttr(attr::kPrimitive)) { @@ -625,11 +599,9 @@ class Interpreter : ObjectRef VisitExpr_(const TupleGetItemNode* op) final { ObjectRef val = Eval(op->tuple); const auto* adt_obj = val.as(); - CHECK(adt_obj) - << "interal error: when evaluating TupleGetItem expected an ADT value"; + CHECK(adt_obj) << "interal error: when evaluating TupleGetItem expected an ADT value"; auto adt = GetRef(adt_obj); - CHECK_LT(static_cast(op->index), adt.size()) - << "internal error: index out of bounds"; + CHECK_LT(static_cast(op->index), adt.size()) << "internal error: index out of bounds"; return adt[op->index]; } @@ -665,9 +637,7 @@ class Interpreter : } } - ObjectRef VisitExpr_(const RefCreateNode* op) final { - return RefValue(Eval(op->value)); - } + ObjectRef VisitExpr_(const RefCreateNode* op) final { return RefValue(Eval(op->value)); } ObjectRef VisitExpr_(const RefReadNode* op) final { ObjectRef r = Eval(op->ref); @@ -718,9 +688,7 @@ class Interpreter : return true; } - bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { - return true; - } + bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { return true; } bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final { extend(op->var, v); @@ -754,17 +722,11 @@ class Interpreter : const Op& shape_of_op_; }; - -TypedPackedFunc -CreateInterpreter( - IRModule mod, - DLContext context, - Target target) { +TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, Target target) { if (mod.defined()) { // eta expand to support constructors in argument position - transform::Sequential seq({ - transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)}); + transform::Sequential seq({transform::EtaExpand( + /* expand_constructor */ true, /* expand_global_var */ false)}); transform::PassContext pass_ctx = transform::PassContext::Current(); tvm::With ctx(pass_ctx); mod = seq(mod); @@ -779,8 +741,7 @@ CreateInterpreter( return TypedPackedFunc(packed); } -TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter") -.set_body_typed(CreateInterpreter); +TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter").set_body_typed(CreateInterpreter); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/param_dict.cc b/src/relay/backend/param_dict.cc index e517fee..cd760b8 100644 --- a/src/relay/backend/param_dict.cc +++ b/src/relay/backend/param_dict.cc @@ -22,86 +22,77 @@ * \brief Implementation and registration of parameter dictionary * serializing/deserializing functions. */ -#include +#include "param_dict.h" + #include +#include #include -#include #include - -#include "param_dict.h" - - +#include namespace tvm { namespace relay { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict") -.set_body([](TVMArgs args, TVMRetValue *rv) { - CHECK_EQ(args.size() % 2, 0u); - // `args` is in the form "key, value, key, value, ..." - size_t num_params = args.size() / 2; - std::vector names; - names.reserve(num_params); - std::vector arrays; - arrays.reserve(num_params); - for (size_t i = 0; i < num_params * 2; i += 2) { - names.emplace_back(args[i].operator std::string()); - arrays.emplace_back(args[i + 1].operator DLTensor*()); - } - std::string bytes; - dmlc::MemoryStringStream strm(&bytes); - dmlc::Stream* fo = &strm; - uint64_t header = kTVMNDArrayListMagic, reserved = 0; - fo->Write(header); - fo->Write(reserved); - fo->Write(names); - { - uint64_t sz = static_cast(arrays.size()); - fo->Write(sz); - for (size_t i = 0; i < sz; ++i) { - tvm::runtime::SaveDLTensor(fo, arrays[i]); - } +TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size() % 2, 0u); + // `args` is in the form "key, value, key, value, ..." + size_t num_params = args.size() / 2; + std::vector names; + names.reserve(num_params); + std::vector arrays; + arrays.reserve(num_params); + for (size_t i = 0; i < num_params * 2; i += 2) { + names.emplace_back(args[i].operator std::string()); + arrays.emplace_back(args[i + 1].operator DLTensor*()); + } + std::string bytes; + dmlc::MemoryStringStream strm(&bytes); + dmlc::Stream* fo = &strm; + uint64_t header = kTVMNDArrayListMagic, reserved = 0; + fo->Write(header); + fo->Write(reserved); + fo->Write(names); + { + uint64_t sz = static_cast(arrays.size()); + fo->Write(sz); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(fo, arrays[i]); } - TVMByteArray arr; - arr.data = bytes.c_str(); - arr.size = bytes.length(); - *rv = arr; - }); + } + TVMByteArray arr; + arr.data = bytes.c_str(); + arr.size = bytes.length(); + *rv = arr; +}); -TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string bytes = args[0]; - std::vector names; - dmlc::MemoryStringStream memstrm(&bytes); - dmlc::Stream* strm = &memstrm; - uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid parameters file format"; - CHECK(header == kTVMNDArrayListMagic) - << "Invalid parameters file format"; - CHECK(strm->Read(&reserved)) - << "Invalid parameters file format"; - CHECK(strm->Read(&names)) - << "Invalid parameters file format"; - uint64_t sz; - strm->Read(&sz, sizeof(sz)); - size_t size = static_cast(sz); - CHECK(size == names.size()) - << "Invalid parameters file format"; - tvm::Array ret; - for (size_t i = 0; i < size; ++i) { - tvm::runtime::NDArray temp; - temp.Load(strm); - auto n = tvm::make_object(); - n->name = std::move(names[i]); - n->array = temp; - ret.push_back(NamedNDArray(n)); - } - *rv = ret; - }); +TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string bytes = args[0]; + std::vector names; + dmlc::MemoryStringStream memstrm(&bytes); + dmlc::Stream* strm = &memstrm; + uint64_t header, reserved; + CHECK(strm->Read(&header)) << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; + CHECK(strm->Read(&names)) << "Invalid parameters file format"; + uint64_t sz; + strm->Read(&sz, sizeof(sz)); + size_t size = static_cast(sz); + CHECK(size == names.size()) << "Invalid parameters file format"; + tvm::Array ret; + for (size_t i = 0; i < size; ++i) { + tvm::runtime::NDArray temp; + temp.Load(strm); + auto n = tvm::make_object(); + n->name = std::move(names[i]); + n->array = temp; + ret.push_back(NamedNDArray(n)); + } + *rv = ret; +}); TVM_REGISTER_NODE_TYPE(NamedNDArrayNode); diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h index c829e54..384201f 100644 --- a/src/relay/backend/param_dict.h +++ b/src/relay/backend/param_dict.h @@ -25,9 +25,9 @@ #define TVM_RELAY_BACKEND_PARAM_DICT_H_ #include -#include #include #include +#include #include diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 5c98728..b19d272 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -215,5 +215,4 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth, } // namespace relay } // namespace tvm - #endif // TVM_RELAY_BACKEND_UTILS_H_ diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 9cdd365..b2a5e83 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -22,27 +22,29 @@ * \brief A compiler from relay::Module to the VM byte code. */ -#include +#include "compiler.h" + +#include #include +#include #include #include #include -#include #include #include -#include -#include +#include +#include #include #include #include #include #include -#include "../utils.h" + #include "../../backend/compile_engine.h" -#include "../../transforms/pass_util.h" #include "../../op/op_common.h" -#include "compiler.h" +#include "../../transforms/pass_util.h" +#include "../utils.h" namespace tvm { namespace relay { @@ -93,8 +95,7 @@ struct AccessField : MatchValue { // Runtime register num after compiling the access field path RegName reg{-1}; - AccessField(MatchValuePtr parent, size_t index) - : parent(parent), index(index) {} + AccessField(MatchValuePtr parent, size_t index) : parent(parent), index(index) {} ~AccessField() {} }; @@ -115,8 +116,7 @@ struct VarBinding : ConditionNode { Var var; MatchValuePtr val; - VarBinding(Var var, MatchValuePtr val) - : var(var), val(val) {} + VarBinding(Var var, MatchValuePtr val) : var(var), val(val) {} ~VarBinding() {} }; @@ -131,9 +131,7 @@ struct TagCompare : ConditionNode { /*! \brief The expected tag */ int target_tag; - TagCompare(MatchValuePtr obj, size_t target) - : obj(obj), target_tag(target) { - } + TagCompare(MatchValuePtr obj, size_t target) : obj(obj), target_tag(target) {} ~TagCompare() {} }; @@ -143,10 +141,8 @@ using TreeLeafNode = relay::TreeLeafNode; using TreeLeafFatalNode = relay::TreeLeafFatalNode; using TreeBranchNode = relay::TreeBranchNode; -TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, - Pattern pattern, - TreeObjectPtr then_branch, - TreeObjectPtr else_branch) { +TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, Pattern pattern, + TreeObjectPtr then_branch, TreeObjectPtr else_branch) { if (pattern.as()) { // We ignore wildcard binding since it's not producing new vars return then_branch; @@ -176,11 +172,10 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, } } -TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, - Clause clause, - TreeObjectPtr else_branch) { - return BuildDecisionTreeFromPattern(data, clause->lhs, - TreeLeafNode::Make(clause->rhs), else_branch); +TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, Clause clause, + TreeObjectPtr else_branch) { + return BuildDecisionTreeFromPattern(data, clause->lhs, TreeLeafNode::Make(clause->rhs), + else_branch); } TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array clauses) { @@ -196,12 +191,11 @@ TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array ToAllocTensorShape(NDArray shape) { std::vector raw_shape; CHECK_EQ(shape->ndim, 1u); - CHECK_EQ(shape->dtype.code, 0U) - << "The dtype of constant shape must be int32 or int64, but got " - << DLDataType2String(shape->dtype); + CHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got " + << DLDataType2String(shape->dtype); CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32) - << "The dtype of constant shape must be int32 or int64, but got" - << DLDataType2String(shape->dtype); + << "The dtype of constant shape must be int32 or int64, but got" + << DLDataType2String(shape->dtype); if (shape->dtype.bits == 64) { int64_t* int_ptr = reinterpret_cast(shape->data); @@ -217,7 +211,6 @@ std::vector ToAllocTensorShape(NDArray shape) { return raw_shape; } - class VMFunctionCompiler : ExprFunctor { public: VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) @@ -310,11 +303,7 @@ class VMFunctionCompiler : ExprFunctor { } // TODO(@jroesch): use correct tag - Emit(Instruction::AllocADT( - 0, - tuple->fields.size(), - fields_registers, - NewRegister())); + Emit(Instruction::AllocADT(0, tuple->fields.size(), fields_registers, NewRegister())); } void VisitExpr_(const MatchNode* match_node) { @@ -415,52 +404,46 @@ class VMFunctionCompiler : ExprFunctor { for (auto input : inputs) { auto reg = var_register_map_.find(Downcast(input)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } for (auto output : outputs) { auto reg = var_register_map_.find(Downcast(output)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } - Emit(Instruction::InvokePacked(op_index, - argument_registers.size(), - outputs.size(), - argument_registers)); + Emit(Instruction::InvokePacked(op_index, argument_registers.size(), outputs.size(), + argument_registers)); } - void EmitInvokeTVMOp(const Function& func, - const Expr& inputs, - const Expr& outputs) { + void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) { std::vector argument_registers; CHECK(func->GetAttr(attr::kPrimitive, 0) != 0) - << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; + << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; auto input_tuple = inputs.as(); - CHECK(input_tuple) - << "internal error: invoke_tvm_op inputs must be a tuple," - << "please file a bug in the memory manifestation pass"; + CHECK(input_tuple) << "internal error: invoke_tvm_op inputs must be a tuple," + << "please file a bug in the memory manifestation pass"; auto output_tuple = outputs.as(); - CHECK(output_tuple) - << "internal error: invoke_tvm_op outputs must be a tuple," - << "please file a bug in the memory manifestation pass"; + CHECK(output_tuple) << "internal error: invoke_tvm_op outputs must be a tuple," + << "please file a bug in the memory manifestation pass"; for (auto input : input_tuple->fields) { auto reg = var_register_map_.find(Downcast(input)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } for (auto output : output_tuple->fields) { auto reg = var_register_map_.find(Downcast(output)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } @@ -500,10 +483,8 @@ class VMFunctionCompiler : ExprFunctor { } } - Emit(Instruction::InvokePacked(op_index, - argument_registers.size(), - output_tuple->fields.size(), - argument_registers)); + Emit(Instruction::InvokePacked(op_index, argument_registers.size(), output_tuple->fields.size(), + argument_registers)); } void VisitExpr_(const CallNode* call_node) { @@ -514,70 +495,73 @@ class VMFunctionCompiler : ExprFunctor { // allocation operations. if (op.as()) { OpMatch matcher; - matcher.Match("memory.invoke_tvm_op", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 3); - EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); - }).Match("memory.alloc_tensor", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 2); - - // Get the attributes. - auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) - << "must be the alloc tensor attrs"; - auto dtype = alloc_attrs->dtype; - - // The storage will be passed dynamically. - this->VisitExpr(args[0]); - auto storage_register = last_register_; - - // If the shape is constant then we will emit a static tensor allocation instruction. - auto const_shape = args[1].as(); - - if (const_shape) { - NDArray shape = const_shape->data; - // TODO(@jroesch): we need to get an RFC done to standarize shape dtype - std::vector raw_shape = ToAllocTensorShape(shape); - // Add context field. - Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister())); - } else { - this->VisitExpr(args[1]); - auto shape_register = last_register_; - Emit(Instruction::AllocTensorReg( - storage_register, - shape_register, - dtype, - NewRegister())); - } - }).Match("memory.alloc_storage", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 2); - // Compute the size of the allocation. - this->VisitExpr(args[0]); - auto size_register = last_register_; - - this->VisitExpr(args[1]); - auto alignment_register = last_register_; - - // Get the dtype hint from the attributes. - auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) - << "must be the alloc tensor attrs"; - auto dtype = alloc_attrs->dtype; - - Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, NewRegister())); - }).Match("memory.shape_func", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 3); - auto shape_func = Downcast(args[0]); - auto inputs = Downcast(args[1]); - auto outputs = Downcast(args[2]); - EmitShapeFunc(shape_func, inputs->fields, outputs->fields); - }).Match("memory.kill", - [](const Array& args, const Attrs& attrs, const Array& type_arg) { - LOG(FATAL) << "memory.kill is not yet supported"; - }); + matcher + .Match("memory.invoke_tvm_op", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); + EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); + }) + .Match( + "memory.alloc_tensor", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 2); + + // Get the attributes. + auto alloc_attrs = attrs.as(); + CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + auto dtype = alloc_attrs->dtype; + + // The storage will be passed dynamically. + this->VisitExpr(args[0]); + auto storage_register = last_register_; + + // If the shape is constant then we will emit a static tensor allocation + // instruction. + auto const_shape = args[1].as(); + + if (const_shape) { + NDArray shape = const_shape->data; + // TODO(@jroesch): we need to get an RFC done to standarize shape dtype + std::vector raw_shape = ToAllocTensorShape(shape); + // Add context field. + Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister())); + } else { + this->VisitExpr(args[1]); + auto shape_register = last_register_; + Emit(Instruction::AllocTensorReg(storage_register, shape_register, dtype, + NewRegister())); + } + }) + .Match("memory.alloc_storage", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 2); + // Compute the size of the allocation. + this->VisitExpr(args[0]); + auto size_register = last_register_; + + this->VisitExpr(args[1]); + auto alignment_register = last_register_; + + // Get the dtype hint from the attributes. + auto alloc_attrs = attrs.as(); + CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + auto dtype = alloc_attrs->dtype; + + Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, + NewRegister())); + }) + .Match("memory.shape_func", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); + auto shape_func = Downcast(args[0]); + auto inputs = Downcast(args[1]); + auto outputs = Downcast(args[2]); + EmitShapeFunc(shape_func, inputs->fields, outputs->fields); + }) + .Match("memory.kill", + [](const Array& args, const Attrs& attrs, const Array& type_arg) { + LOG(FATAL) << "memory.kill is not yet supported"; + }); matcher(GetRef(call_node)); return; } @@ -600,14 +584,13 @@ class VMFunctionCompiler : ExprFunctor { auto it = context_->global_map.find(global); CHECK(it != context_->global_map.end()); DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint - << " with func_index=" << it->second; + << " with func_index=" << it->second; // TODO(tvm-team): // Think about mixed call into global that is not a relay::Function // perhaps establish as an invariance(all functions in mod must be relay::Function) auto func = Downcast(context_->module->Lookup(global)); - if (IsClosure(func)) { auto arity = func->params.size(); Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); @@ -738,9 +721,7 @@ class VMFunctionCompiler : ExprFunctor { Target target_host_; }; - -PackedFunc VMCompiler::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "lower") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); @@ -753,9 +734,8 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, this->Codegen(); }); } else if (name == "get_executable") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = runtime::Module(exec_); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = runtime::Module(exec_); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map params = args[0]; @@ -786,11 +766,8 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host) { - CHECK_EQ(targets.size(), 1) - << "Currently VM compiler doesn't support heterogeneous compilation"; +void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { + CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation"; if (params_.size()) { BaseFunc base_func = mod->Lookup("main"); CHECK(base_func->IsInstance()) @@ -867,7 +844,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // eta expand to support constructors in argument position pass_seqs.push_back(transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)); + /* expand_constructor */ true, /* expand_global_var */ false)); pass_seqs.push_back(transform::SimplifyInference()); PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { @@ -949,7 +926,7 @@ void VMCompiler::Codegen() { LOG(WARNING) << "Did you forget to call VMCompiler::Lower?"; return; } - auto const &cached_funcs = context_.cached_funcs; + auto const& cached_funcs = context_.cached_funcs; if (cached_funcs.size() == 0) { return; } @@ -999,8 +976,7 @@ runtime::Module CreateVMCompiler() { return runtime::Module(exec); } -TVM_REGISTER_GLOBAL("relay._vm._VMCompiler") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay._vm._VMCompiler").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateVMCompiler(); }); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index c1040f1..7faab9d 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -28,9 +28,11 @@ #include #include #include -#include #include #include +#include +#include + #include #include #include @@ -38,8 +40,9 @@ #include #include #include -#include "../../../runtime/vm/profiler/vm.h" + #include "../../../runtime/vm/naive_allocator.h" +#include "../../../runtime/vm/profiler/vm.h" #include "../../backend/compile_engine.h" #include "../../transforms/pass_util.h" @@ -79,17 +82,13 @@ struct VMCompilerContext { std::unordered_map seen_funcs; }; - class VMCompiler : public runtime::ModuleNode { public: virtual ~VMCompiler() {} - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); - const char* type_key() const { - return "VMCompiler"; - } + const char* type_key() const { return "VMCompiler"; } /*! * \brief Set the parameters @@ -107,9 +106,7 @@ class VMCompiler : public runtime::ModuleNode { to target mapping. For homogeneous compilation, it is a build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host); + void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 12113b0..8e960a7 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -24,9 +24,10 @@ #include #include -#include #include #include +#include + #include #include @@ -125,18 +126,13 @@ struct PrimitiveInliner : ExprMutator { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - DLOG(INFO) << "Before inlining primitives: " << global - << std::endl << AsText(func, false); + DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false); - func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); module_->Add(global, func, true); - DLOG(INFO) << "After inlining primitives: " << global - << std::endl << AsText(func, false); + DLOG(INFO) << "After inlining primitives: " << global << std::endl << AsText(func, false); } } return module_; @@ -149,16 +145,13 @@ namespace transform { Pass InlinePrimitives() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::vm::PrimitiveInliner(m).Inline(); - }; + [=](IRModule m, PassContext pc) { return relay::vm::PrimitiveInliner(m).Inline(); }; auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {}); // Eliminate dead code for each function after inlining. return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives"); } -TVM_REGISTER_GLOBAL("relay._transform.InlinePrimitives") -.set_body_typed(InlinePrimitives); +TVM_REGISTER_GLOBAL("relay._transform.InlinePrimitives").set_body_typed(InlinePrimitives); } // namespace transform diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index bfbefd5..1d3fff7 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -24,12 +24,13 @@ #include #include +#include #include #include -#include -#include #include #include +#include + #include #include @@ -44,9 +45,7 @@ inline std::string GenerateName(const Function& func) { return std::string("lifted_name") + std::to_string(hash); } -bool IsClosure(const Function& func) { - return func->GetAttr(attr::kClosure, 0) != 0; -} +bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } Function MarkClosure(Function func) { return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1)); @@ -85,8 +84,7 @@ class LambdaLifter : public ExprMutator { if (!letrec_.empty() && var == letrec_.back()) { auto it = lambda_map_.find(var); CHECK(it != lambda_map_.end()); - return Call(it->second, call->args, call_node->attrs, - call_node->type_args); + return Call(it->second, call->args, call_node->attrs, call_node->type_args); } } return std::move(call); @@ -153,18 +151,15 @@ class LambdaLifter : public ExprMutator { if (captured_vars.size() == 0 && free_type_vars.size() == 0) { lifted_func = Function(body->params, body->body, body->ret_type, body->type_params); } else { - lifted_func = - Function(captured_vars, body, func->func_type_annotation(), free_type_vars); + lifted_func = Function(captured_vars, body, func->func_type_annotation(), free_type_vars); lifted_func = MarkClosure(lifted_func); } CHECK(lifted_func.defined()); - if (module_->ContainGlobalVar(name)) { const auto existing_func = module_->Lookup(name); - CHECK(tvm::StructuralEqual()(lifted_func, existing_func)) - << "lifted function hash collision"; + CHECK(tvm::StructuralEqual()(lifted_func, existing_func)) << "lifted function hash collision"; // If an identical function already exists, use its global var. global = module_->GetGlobalVar(name); } else { @@ -192,10 +187,7 @@ class LambdaLifter : public ExprMutator { if (auto* n = pair.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); module_->Add(pair.first, func, true); } @@ -215,14 +207,11 @@ namespace transform { Pass LambdaLift() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::vm::LambdaLifter(m).Lift(); - }; + [=](IRModule m, PassContext pc) { return relay::vm::LambdaLifter(m).Lift(); }; return CreateModulePass(pass_func, 1, "LambdaLift", {}); } -TVM_REGISTER_GLOBAL("relay._transform.LambdaLift") -.set_body_typed(LambdaLift); +TVM_REGISTER_GLOBAL("relay._transform.LambdaLift").set_body_typed(LambdaLift); } // namespace transform diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index c2fe37f..64ddbe3 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -22,12 +22,13 @@ * \brief Remove unused global relay functions in a relay module. */ +#include #include #include -#include -#include #include #include +#include + #include #include #include @@ -48,10 +49,7 @@ struct CallTracer : ExprVisitor { // Record the expressions that are being visited std::unordered_set visiting_; - explicit CallTracer(const IRModule& module) - : module_{module}, - called_funcs_{}, - visiting_{} {} + explicit CallTracer(const IRModule& module) : module_{module}, called_funcs_{}, visiting_{} {} void VisitExpr_(const GlobalVarNode* op) final { called_funcs_.insert(op->name_hint); @@ -86,8 +84,7 @@ struct CallTracer : ExprVisitor { * * \return The module with dead functions removed. */ -IRModule RemoveUnusedFunctions(const IRModule& module, - Array entry_funcs) { +IRModule RemoveUnusedFunctions(const IRModule& module, Array entry_funcs) { std::unordered_set called_funcs{}; for (auto entry : entry_funcs) { auto funcs = CallTracer(module).Trace(entry); @@ -108,15 +105,14 @@ IRModule RemoveUnusedFunctions(const IRModule& module, namespace transform { Pass RemoveUnusedFunctions(Array entry_functions) { - runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { return relay::vm::RemoveUnusedFunctions(m, entry_functions); }; return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {}); } -TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions") -.set_body_typed(RemoveUnusedFunctions); +TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions); } // namespace transform diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 11c2cbb..d808351 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -21,8 +21,8 @@ * \file src/ir/adt.cc * \brief AST nodes for Relay algebraic data types (ADTs). */ -#include #include +#include namespace tvm { namespace relay { @@ -34,15 +34,12 @@ PatternWildcard::PatternWildcard() { TVM_REGISTER_NODE_TYPE(PatternWildcardNode); -TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard") -.set_body_typed([]() { - return PatternWildcard(); -}); +TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard").set_body_typed([]() { return PatternWildcard(); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "PatternWildcardNode()"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "PatternWildcardNode()"; + }); PatternVar::PatternVar(tvm::relay::Var var) { ObjectPtr n = make_object(); @@ -52,19 +49,17 @@ PatternVar::PatternVar(tvm::relay::Var var) { TVM_REGISTER_NODE_TYPE(PatternVarNode); -TVM_REGISTER_GLOBAL("relay.ir.PatternVar") -.set_body_typed([](tvm::relay::Var var) { +TVM_REGISTER_GLOBAL("relay.ir.PatternVar").set_body_typed([](tvm::relay::Var var) { return PatternVar(var); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PatternVarNode(" << node->var << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PatternVarNode(" << node->var << ")"; + }); -PatternConstructor::PatternConstructor(Constructor constructor, - tvm::Array patterns) { +PatternConstructor::PatternConstructor(Constructor constructor, tvm::Array patterns) { ObjectPtr n = make_object(); n->constructor = std::move(constructor); n->patterns = std::move(patterns); @@ -74,16 +69,15 @@ PatternConstructor::PatternConstructor(Constructor constructor, TVM_REGISTER_NODE_TYPE(PatternConstructorNode); TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor") -.set_body_typed([](Constructor constructor, tvm::Array patterns) { - return PatternConstructor(constructor, patterns); -}); + .set_body_typed([](Constructor constructor, tvm::Array patterns) { + return PatternConstructor(constructor, patterns); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PatternConstructorNode(" << node->constructor - << ", " << node->patterns << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PatternConstructorNode(" << node->constructor << ", " << node->patterns << ")"; + }); PatternTuple::PatternTuple(tvm::Array patterns) { ObjectPtr n = make_object(); @@ -93,16 +87,15 @@ PatternTuple::PatternTuple(tvm::Array patterns) { TVM_REGISTER_NODE_TYPE(PatternTupleNode); -TVM_REGISTER_GLOBAL("relay.ir.PatternTuple") -.set_body_typed([](tvm::Array patterns) { +TVM_REGISTER_GLOBAL("relay.ir.PatternTuple").set_body_typed([](tvm::Array patterns) { return PatternTuple(patterns); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PatternTupleNode(" << node->patterns << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PatternTupleNode(" << node->patterns << ")"; + }); Clause::Clause(Pattern lhs, Expr rhs) { ObjectPtr n = make_object(); @@ -113,17 +106,15 @@ Clause::Clause(Pattern lhs, Expr rhs) { TVM_REGISTER_NODE_TYPE(ClauseNode); -TVM_REGISTER_GLOBAL("relay.ir.Clause") -.set_body_typed([](Pattern lhs, Expr rhs) { +TVM_REGISTER_GLOBAL("relay.ir.Clause").set_body_typed([](Pattern lhs, Expr rhs) { return Clause(lhs, rhs); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ClauseNode(" << node->lhs << ", " - << node->rhs << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")"; + }); Match::Match(Expr data, tvm::Array clauses, bool complete) { ObjectPtr n = make_object(); @@ -136,16 +127,16 @@ Match::Match(Expr data, tvm::Array clauses, bool complete) { TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_GLOBAL("relay.ir.Match") -.set_body_typed([](Expr data, tvm::Array clauses, bool complete) { - return Match(data, clauses, complete); -}); + .set_body_typed([](Expr data, tvm::Array clauses, bool complete) { + return Match(data, clauses, complete); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "MatchNode(" << node->data << ", " - << node->clauses << ", " << node->complete << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "MatchNode(" << node->data << ", " << node->clauses << ", " << node->complete + << ")"; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 76a3f9d..37b0ff5 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -23,8 +23,8 @@ */ #include -#include #include +#include namespace tvm { namespace relay { @@ -39,8 +39,7 @@ Id::Id(std::string name_hint) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("ir.NodeSetSpan") -.set_body_typed([](ObjectRef node_ref, Span sp) { +TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { rn->span = sp; } else if (auto* rn = node_ref.as()) { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 169db62..5ac5805 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -38,19 +38,18 @@ Constant::Constant(runtime::NDArray data) { TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_GLOBAL("relay.ir.Constant") -.set_body_typed([](runtime::NDArray data) { +TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data) { return Constant(data); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PackedFunc* fprint = Registry::Get("relay._constant_repr"); - CHECK(fprint) << "unable to find printing function for constants"; - std::string data = (*fprint)(GetRef(node)); - p->stream << "Constant(" << data << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PackedFunc* fprint = Registry::Get("relay._constant_repr"); + CHECK(fprint) << "unable to find printing function for constants"; + std::string data = (*fprint)(GetRef(node)); + p->stream << "Constant(" << data << ")"; + }); TensorType ConstantNode::tensor_type() const { auto dtype = DataType(data->dtype); @@ -58,8 +57,7 @@ TensorType ConstantNode::tensor_type() const { for (int i = 0; i < data->ndim; i++) { CHECK_LE(data->shape[i], std::numeric_limits::max()); CHECK_GE(data->shape[i], std::numeric_limits::min()); - shape.push_back( - tvm::IntImm(DataType::Int(32), data->shape[i])); + shape.push_back(tvm::IntImm(DataType::Int(32), data->shape[i])); } return TensorType(shape, dtype); @@ -73,17 +71,15 @@ Tuple::Tuple(tvm::Array fields) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_REGISTER_GLOBAL("relay.ir.Tuple") -.set_body_typed([](tvm::Array fields) { +TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields) { return Tuple(fields); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Tuple(" << node->fields << ")"; - }); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Tuple(" << node->fields << ")"; + }); Var::Var(Id vid, Type type_annotation) { ObjectPtr n = make_object(); @@ -94,21 +90,20 @@ Var::Var(Id vid, Type type_annotation) { TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_GLOBAL("relay.ir.Var") -.set_body_typed([](std::string str, Type type_annotation) { +TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](std::string str, Type type_annotation) { return Var(str, type_annotation); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Var(" << node->name_hint(); - if (node->type_annotation.defined()) { - p->stream << ", ty="; - p->Print(node->type_annotation); - } - p->stream << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Var(" << node->name_hint(); + if (node->type_annotation.defined()) { + p->stream << ", ty="; + p->Print(node->type_annotation); + } + p->stream << ")"; + }); Call::Call(Expr op, Array args, Attrs attrs, Array type_args) { ObjectPtr n = make_object(); @@ -122,16 +117,16 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args) { TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay.ir.Call") -.set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args) { - return Call(op, args, attrs, type_args); -}); + .set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args) { + return Call(op, args, attrs, type_args); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "CallNode(" << node->op << ", " << node->args << ", " - << node->attrs << ", " << node->type_args << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " + << node->type_args << ")"; + }); Let::Let(Var var, Expr value, Expr body) { ObjectPtr n = make_object(); @@ -143,17 +138,15 @@ Let::Let(Var var, Expr value, Expr body) { TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_GLOBAL("relay.ir.Let") -.set_body_typed([](Var var, Expr value, Expr body) { +TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body) { return Let(var, value, body); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "LetNode(" << node->var << ", " << node->value - << ", " << node->body << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")"; + }); If::If(Expr cond, Expr true_branch, Expr false_branch) { ObjectPtr n = make_object(); @@ -166,16 +159,16 @@ If::If(Expr cond, Expr true_branch, Expr false_branch) { TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") -.set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { - return If(cond, true_branch, false_branch); -}); + .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { + return If(cond, true_branch, false_branch); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IfNode(" << node->cond << ", " << node->true_branch - << ", " << node->false_branch << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " + << node->false_branch << ")"; + }); TupleGetItem::TupleGetItem(Expr tuple, int index) { ObjectPtr n = make_object(); @@ -186,16 +179,15 @@ TupleGetItem::TupleGetItem(Expr tuple, int index) { TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem") -.set_body_typed([](Expr tuple, int index) { +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { return TupleGetItem(tuple, index); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; + }); RefCreate::RefCreate(Expr value) { ObjectPtr n = make_object(); @@ -205,16 +197,15 @@ RefCreate::RefCreate(Expr value) { TVM_REGISTER_NODE_TYPE(RefCreateNode); -TVM_REGISTER_GLOBAL("relay.ir.RefCreate") -.set_body_typed([](Expr value) { +TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) { return RefCreate(value); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefCreateNode(" << node->value << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefCreateNode(" << node->value << ")"; + }); RefRead::RefRead(Expr ref) { ObjectPtr n = make_object(); @@ -224,16 +215,13 @@ RefRead::RefRead(Expr ref) { TVM_REGISTER_NODE_TYPE(RefReadNode); -TVM_REGISTER_GLOBAL("relay.ir.RefRead") -.set_body_typed([](Expr ref) { - return RefRead(ref); -}); +TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return RefRead(ref); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefReadNode(" << node->ref << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefReadNode(" << node->ref << ")"; + }); RefWrite::RefWrite(Expr ref, Expr value) { ObjectPtr n = make_object(); @@ -244,24 +232,21 @@ RefWrite::RefWrite(Expr ref, Expr value) { TVM_REGISTER_NODE_TYPE(RefWriteNode); -TVM_REGISTER_GLOBAL("relay.ir.RefWrite") -.set_body_typed([](Expr ref, Expr value) { +TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value) { return RefWrite(ref, value); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; + }); -TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize") -.set_body_typed([](TempExpr temp) { +TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize").set_body_typed([](TempExpr temp) { return temp->Realize(); }); -TVM_REGISTER_GLOBAL("relay.ir.Any") -.set_body_typed([]() { return Any::make(); }); +TVM_REGISTER_GLOBAL("relay.ir.Any").set_body_typed([]() { return Any::make(); }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index cb5d06f..18fd1c7 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -154,9 +154,7 @@ bool MixedModeMutator::CheckVisited(const Expr& expr) { } } -Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) { - return ExprMutator::VisitExpr(expr); -} +Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) { return ExprMutator::VisitExpr(expr); } Expr MixedModeMutator::VisitExpr(const Expr& expr) { auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); }; @@ -178,6 +176,7 @@ class PostOrderRewriter : public MixedModeMutator { auto post = ExprFunctor::VisitExpr(expr); return rewriter_->Rewrite(expr, post); } + protected: ExprRewriter* rewriter_; }; @@ -208,17 +207,11 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const ConstantNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const OpNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef(op); } Expr ExprMutator::VisitExpr_(const TupleNode* op) { tvm::Array fields; @@ -257,9 +250,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { auto ret_type = this->VisitType(op->ret_type); auto body = this->Mutate(op->body); - if (all_ty_params_unchanged && - all_params_unchanged && - ret_type.same_as(op->ret_type) && + if (all_ty_params_unchanged && all_params_unchanged && ret_type.same_as(op->ret_type) && body.same_as(op->body)) { return GetRef(op); } else { @@ -297,9 +288,7 @@ Expr ExprMutator::VisitExpr_(const LetNode* op) { auto value = this->Mutate(op->value); auto body = this->Mutate(op->body); - if (var.same_as(op->var) && - value.same_as(op->value) && - body.same_as(op->body)) { + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return Let(var, value, body); @@ -310,10 +299,9 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { auto guard = this->Mutate(op->cond); auto true_b = this->Mutate(op->true_branch); auto false_b = this->Mutate(op->false_branch); - if (op->cond.same_as(guard) && - op->true_branch.same_as(true_b) && + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b)) { - return GetRef(op);; + return GetRef(op); } else { return If(guard, true_b, false_b); } @@ -356,9 +344,7 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { } } -Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { - return GetRef(c); -} +Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef(c); } Expr ExprMutator::VisitExpr_(const MatchNode* m) { std::vector clauses; @@ -394,11 +380,9 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { } } -void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { -} +void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {} -void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { -} +void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {} void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { for (auto field : op->fields) { @@ -440,17 +424,11 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { void ExprVisitor::VisitExpr_(const OpNode* op) { return; } -void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { - this->VisitExpr(op->tuple); -} +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } -void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) { - this->VisitExpr(op->value); -} +void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) { - this->VisitExpr(op->ref); -} +void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) { this->VisitExpr(op->ref); } void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) { this->VisitExpr(op->ref); @@ -501,30 +479,23 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit") -.set_body_typed([](Expr expr, PackedFunc f) { - PostOrderVisit(expr, [f](const Expr& n) { - f(n); - }); - }); +TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); +}); // Implement bind. class ExprBinder : public ExprMutator, PatternMutator { public: - explicit ExprBinder(const tvm::Map& args_map) - : args_map_(args_map) { - } + explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} Expr VisitExpr_(const LetNode* op) final { - CHECK(!args_map_.count(op->var)) - << "Cannot bind an internel variable in let"; + CHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let"; return ExprMutator::VisitExpr_(op); } Expr VisitExpr_(const FunctionNode* op) final { for (Var param : op->params) { - CHECK(!args_map_.count(param)) - << "Cannnot bind an internal function parameter"; + CHECK(!args_map_.count(param)) << "Cannnot bind an internal function parameter"; } return ExprMutator::VisitExpr_(op); } @@ -539,9 +510,7 @@ class ExprBinder : public ExprMutator, PatternMutator { } } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Clause VisitClause(const Clause& c) final { Pattern pat = VisitPattern(c->lhs); @@ -549,8 +518,7 @@ class ExprBinder : public ExprMutator, PatternMutator { } Var VisitVar(const Var& v) final { - CHECK(!args_map_.count(v)) - << "Cannnot bind an internal pattern variable"; + CHECK(!args_map_.count(v)) << "Cannnot bind an internal pattern variable"; return v; } @@ -567,15 +535,10 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { new_params.push_back(param); } } - if (new_body.same_as(func->body) && - new_params.size() == func->params.size()) { + if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { return expr; } - auto ret = Function(new_params, - new_body, - func->ret_type, - func->type_params, - func->attrs); + auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); @@ -585,11 +548,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { new_params.push_back(v); } } - ret = Function(new_params, - new_body, - func->ret_type, - func->type_params, - func->attrs); + ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { @@ -597,15 +556,14 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } -TVM_REGISTER_GLOBAL("relay.ir.Bind") -.set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef input = args[0]; - if (input->IsInstance()) { - *ret = Bind(Downcast(input), args[1]); - } else { - CHECK(input->IsInstance()); - *ret = Bind(Downcast(input), args[1]); - } - }); +TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef input = args[0]; + if (input->IsInstance()) { + *ret = Bind(Downcast(input), args[1]); + } else { + CHECK(input->IsInstance()); + *ret = Bind(Downcast(input), args[1]); + } +}); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 12a80c5..5312e6d 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -26,11 +26,8 @@ namespace tvm { namespace relay { -Function::Function(tvm::Array params, - Expr body, - Type ret_type, - tvm::Array type_params, - DictAttrs attrs) { +Function::Function(tvm::Array params, Expr body, Type ret_type, + tvm::Array type_params, DictAttrs attrs) { ObjectPtr n = make_object(); CHECK(params.defined()); CHECK(type_params.defined()); @@ -45,34 +42,29 @@ Function::Function(tvm::Array params, FuncType FunctionNode::func_type_annotation() const { Array param_types; for (auto param : this->params) { - Type param_type = (param->type_annotation.defined()) ? param->type_annotation - : IncompleteType(Kind::kType); + Type param_type = + (param->type_annotation.defined()) ? param->type_annotation : IncompleteType(Kind::kType); param_types.push_back(param_type); } - Type ret_type = (this->ret_type.defined()) ? this->ret_type - : IncompleteType(Kind::kType); + Type ret_type = (this->ret_type.defined()) ? this->ret_type : IncompleteType(Kind::kType); return FuncType(param_types, ret_type, this->type_params, {}); } TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay.ir.Function") -.set_body_typed([](tvm::Array params, - Expr body, - Type ret_type, - tvm::Array ty_params, - tvm::DictAttrs attrs) { - return Function(params, body, ret_type, ty_params, attrs); -}); + .set_body_typed([](tvm::Array params, Expr body, Type ret_type, + tvm::Array ty_params, tvm::DictAttrs attrs) { + return Function(params, body, ret_type, ty_params, attrs); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FunctionNode(" << node->params << ", " << node->ret_type - << ", " << node->body << ", " << node->type_params << ", " - << node->attrs << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body + << ", " << node->type_params << ", " << node->attrs << ")"; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/op_strategy.cc b/src/relay/ir/op_strategy.cc index 4e407db..989e3a6 100644 --- a/src/relay/ir/op_strategy.cc +++ b/src/relay/ir/op_strategy.cc @@ -31,21 +31,18 @@ TVM_REGISTER_NODE_TYPE(OpImplementationNode); TVM_REGISTER_NODE_TYPE(OpSpecializationNode); TVM_REGISTER_NODE_TYPE(OpStrategyNode); -Array OpImplementation::Compute(const Attrs& attrs, - const Array& inputs, +Array OpImplementation::Compute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return (*this)->fcompute(attrs, inputs, out_type); } -te::Schedule OpImplementation::Schedule(const Attrs& attrs, - const Array &outs, +te::Schedule OpImplementation::Schedule(const Attrs& attrs, const Array& outs, const Target& target) { return (*this)->fschedule(attrs, outs, target); } void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute, - tvm::relay::FTVMSchedule fschedule, - std::string name, + tvm::relay::FTVMSchedule fschedule, std::string name, int plevel) { auto n = make_object(); n->fcompute = fcompute; @@ -55,9 +52,7 @@ void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute, (*this)->implementations.push_back(OpImplementation(n)); } -void OpStrategy::AddImplementation(FTVMCompute fcompute, - FTVMSchedule fschedule, - std::string name, +void OpStrategy::AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name, int plevel) { auto curr_cond = te::SpecializedCondition::Current(); auto self = this->operator->(); @@ -77,38 +72,37 @@ void OpStrategy::AddImplementation(FTVMCompute fcompute, } TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute") -.set_body([](TVMArgs args, TVMRetValue* rv) { - OpImplementation imp = args[0]; - Attrs attrs = args[1]; - Array inputs = args[2]; - Type out_type = args[3]; - *rv = imp.Compute(attrs, inputs, out_type); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + OpImplementation imp = args[0]; + Attrs attrs = args[1]; + Array inputs = args[2]; + Type out_type = args[3]; + *rv = imp.Compute(attrs, inputs, out_type); + }); TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - OpImplementation imp = args[0]; - Attrs attrs = args[1]; - Array outs = args[2]; - Target target = args[3]; - *rv = imp.Schedule(attrs, outs, target); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + OpImplementation imp = args[0]; + Attrs attrs = args[1]; + Array outs = args[2]; + Target target = args[3]; + *rv = imp.Schedule(attrs, outs, target); + }); -TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy") -.set_body([](TVMArgs args, TVMRetValue* rv) { - ObjectPtr n = make_object(); - *rv = OpStrategy(n); +TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy").set_body([](TVMArgs args, TVMRetValue* rv) { + ObjectPtr n = make_object(); + *rv = OpStrategy(n); }); TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation") -.set_body([](TVMArgs args, TVMRetValue* rv) { - OpStrategy strategy = args[0]; - FTVMCompute compute = args[1]; - FTVMSchedule schedule = args[2]; - std::string name = args[3]; - int plevel = args[4]; - strategy.AddImplementation(compute, schedule, name, plevel); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + OpStrategy strategy = args[0]; + FTVMCompute compute = args[1]; + FTVMSchedule schedule = args[2]; + std::string name = args[3]; + int plevel = args[4]; + strategy.AddImplementation(compute, schedule, name, plevel); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc index 6795884..8c366ba 100644 --- a/src/relay/ir/pattern_functor.cc +++ b/src/relay/ir/pattern_functor.cc @@ -27,13 +27,9 @@ namespace tvm { namespace relay { -Pattern PatternMutator::Mutate(const Pattern& pat) { - return (*this)(pat); -} +Pattern PatternMutator::Mutate(const Pattern& pat) { return (*this)(pat); } -Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { - return GetRef(op); -} +Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { return GetRef(op); } Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) { return PatternVar(VisitVar(op->var)); @@ -55,28 +51,20 @@ Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) { return PatternTuple(pat); } -Type PatternMutator::VisitType(const Type& t) { - return t; -} +Type PatternMutator::VisitType(const Type& t) { return t; } Var PatternMutator::VisitVar(const Var& v) { if (var_map_.count(v) == 0) { - var_map_.insert(std::pair(v, - Var(v->name_hint(), - VisitType(v->type_annotation)))); + var_map_.insert(std::pair(v, Var(v->name_hint(), VisitType(v->type_annotation)))); } return var_map_.at(v); } -Constructor PatternMutator::VisitConstructor(const Constructor& v) { - return v; -} +Constructor PatternMutator::VisitConstructor(const Constructor& v) { return v; } -void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { } +void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) {} -void PatternVisitor::VisitPattern_(const PatternVarNode* op) { - VisitVar(op->var); -} +void PatternVisitor::VisitPattern_(const PatternVarNode* op) { VisitVar(op->var); } void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { VisitConstructor(op->constructor); @@ -91,11 +79,9 @@ void PatternVisitor::VisitPattern_(const PatternTupleNode* op) { } } -void PatternVisitor::VisitType(const Type& t) { } +void PatternVisitor::VisitType(const Type& t) {} -void PatternVisitor::VisitVar(const Var& v) { - VisitType(v->type_annotation); -} +void PatternVisitor::VisitVar(const Var& v) { VisitType(v->type_annotation); } void PatternVisitor::VisitConstructor(const Constructor& c) { for (const auto& inp : c->inputs) { diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 06dd2b1..6b99c93 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -22,10 +22,9 @@ * \brief Relay specific transformation passes. */ #include -#include #include #include - +#include namespace tvm { namespace relay { @@ -56,9 +55,7 @@ class FunctionPassNode : public PassNode { FunctionPassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! * \brief Run a function pass on given pass context. @@ -113,14 +110,11 @@ FunctionPass::FunctionPass( } // Perform Module -> Module optimizations at the Function level. -IRModule FunctionPassNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); - DLOG(INFO) << "Executing function pass : " - << pass_info->name - << " with opt level: " - << pass_info->opt_level; + DLOG(INFO) << "Executing function pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; pass_ctx.Trace(mod, pass_info, true); // Execute the pass function and return a new module. @@ -130,9 +124,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, // only picks up relay::Function if (auto* n = it.second.as()) { Function func = GetRef(n); - auto updated_func = SkipFunction(func) - ? func - : pass_func(func, updated_mod, pass_ctx); + auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx); updates.push_back({it.first, updated_func}); } } @@ -146,14 +138,12 @@ IRModule FunctionPassNode::operator()(IRModule mod, bool FunctionPassNode::SkipFunction(const Function& func) const { return (func->GetAttr(attr::kCompiler).defined()) || - func->GetAttr(attr::kSkipOptimization, 0) != 0; + func->GetAttr(attr::kSkipOptimization, 0) != 0; } Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { + int opt_level, const std::string& name, const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); return FunctionPass(pass_func, pass_info); } @@ -161,18 +151,17 @@ Pass CreateFunctionPass( TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass") -.set_body_typed([](runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - return FunctionPass(pass_func, pass_info); -}); + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return FunctionPass(pass_func, pass_info); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Function pass: " << info->name - << " at the optimization level " << info->opt_level; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name << " at the optimization level " + << info->opt_level; + }); } // namespace transform } // namespace relay diff --git a/src/relay/op/algorithm/argsort.cc b/src/relay/op/algorithm/argsort.cc index 5b03cee..a240974 100644 --- a/src/relay/op/algorithm/argsort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -21,17 +21,15 @@ * \file argsort.cc * \brief Argsort operators */ -#include #include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(ArgsortAttrs); -bool ArgsortRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ArgsortRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] const ArgsortAttrs* param = attrs.as(); @@ -39,18 +37,14 @@ bool ArgsortRel(const Array& types, const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "Argsort: expect input type to be TensorType but get " - << types[0]; + << "Argsort: expect input type to be TensorType but get " << types[0]; return false; } reporter->Assign(types[1], TensorType(data->shape, param->dtype)); return true; } -Expr MakeArgsort(Expr data, - int axis, - bool is_ascend, - DataType dtype) { +Expr MakeArgsort(Expr data, int axis, bool is_ascend, DataType dtype) { auto attrs = make_object(); attrs->axis = axis; attrs->is_ascend = is_ascend; @@ -59,19 +53,17 @@ Expr MakeArgsort(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op._make.argsort") -.set_body_typed(MakeArgsort); +TVM_REGISTER_GLOBAL("relay.op._make.argsort").set_body_typed(MakeArgsort); RELAY_REGISTER_OP("argsort") -.describe(R"doc(Returns the indices that would sort an + .describe(R"doc(Returns the indices that would sort an input array along the given axis. )doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "Input data.") -.set_support_level(6) -.add_type_rel("Argsort", ArgsortRel); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .set_support_level(6) + .add_type_rel("Argsort", ArgsortRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index 225575c..f641f84 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -21,17 +21,15 @@ * \file topk.cc * \brief TopK operators */ -#include #include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(TopKAttrs); -bool TopKRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TopKRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] const TopKAttrs* param = attrs.as(); @@ -66,12 +64,7 @@ bool TopKRel(const Array& types, return true; } -Expr MakeTopK(Expr data, - int k, - int axis, - std::string ret_type, - bool is_ascend, - DataType dtype) { +Expr MakeTopK(Expr data, int k, int axis, std::string ret_type, bool is_ascend, DataType dtype) { auto attrs = make_object(); attrs->k = k; attrs->axis = axis; @@ -82,19 +75,16 @@ Expr MakeTopK(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op._make.topk") -.set_body_typed(MakeTopK); +TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK); RELAY_REGISTER_OP("topk") -.describe(R"doc(Get the top k elements in an input tensor along the given axis. + .describe(R"doc(Get the top k elements in an input tensor along the given axis. )doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "Input data.") -.set_support_level(6) -.add_type_rel("TopK", TopKRel); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .set_support_level(6) + .add_type_rel("TopK", TopKRel); } // namespace relay } // namespace tvm - diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index dd1bcdc..2e93b58 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -23,12 +23,12 @@ * \brief Registration of annotation operators. */ -#include +#include #include #include #include #include -#include +#include #include "../../transforms/infer_layout_util.h" #include "../type_relations.h" @@ -40,48 +40,46 @@ namespace relay { TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") -.set_body_typed([](Expr data, int device_type) { - auto attrs = make_object(); - attrs->device_type = device_type; - static const Op& op = Op::Get("on_device"); - return Call(op, {data}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr data, int device_type) { + auto attrs = make_object(); + attrs->device_type = device_type; + static const Op& op = Op::Get("on_device"); + return Call(op, {data}, Attrs(attrs), {}); + }); RELAY_REGISTER_OP("on_device") -.describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout); + .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); Expr StopFusion(Expr data) { static const Op& op = Op::Get("annotation.stop_fusion"); return Call(op, {data}, Attrs{}, {}); } -TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion") -.set_body_typed([](Expr data) { - return StopFusion(data); +TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion").set_body_typed([](Expr data) { + return StopFusion(data); }); RELAY_REGISTER_OP("annotation.stop_fusion") -.describe(R"code(Annotate an expression to prevent it being fused with previous expressions.)code" -TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input data.") -.add_type_rel("Identity", IdentityRel) -.set_support_level(10) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .describe( + R"code(Annotate an expression to prevent it being fused with previous expressions.)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .add_type_rel("Identity", IdentityRel) + .set_support_level(10) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); // relay.annotation.cast_hint TVM_REGISTER_NODE_TYPE(CastHintAttrs); @@ -94,134 +92,127 @@ Expr CastHint(Expr data, DataType dtype) { } RELAY_REGISTER_OP("annotation.cast_hint") -.describe(R"code(Annotate an expression to be cast into specific data type.)code" -TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input data.") -.add_type_rel("Identity", IdentityRel) -.set_support_level(10) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); - + .describe( + R"code(Annotate an expression to be cast into specific data type.)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .add_type_rel("Identity", IdentityRel) + .set_support_level(10) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); RELAY_REGISTER_OP("annotation.bitpack_start") -.describe(R"code( + .describe(R"code( Mark the start of bitpacking. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); RELAY_REGISTER_OP("annotation.bitpack_end") -.describe(R"code( + .describe(R"code( Mark the end of bitpacking. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); - -TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint") -.set_body_typed([](Expr data) { + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint").set_body_typed([](Expr data) { static const Op& op = Op::Get("annotation.checkpoint"); return Call(op, {data}, Attrs{}, {}); }); RELAY_REGISTER_OP("annotation.checkpoint") -.describe(R"code( + .describe(R"code( Mark a checkpoint for checkpointing memory optimization. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - Array outputs; - for (size_t i = 0; i < inputs.size(); ++i) { - outputs.push_back(topi::identity(inputs[i])); - } - return outputs; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + Array outputs; + for (size_t i = 0; i < inputs.size(); ++i) { + outputs.push_back(topi::identity(inputs[i])); + } + return outputs; + }); TVM_REGISTER_NODE_TYPE(CompilerAttrs); RELAY_REGISTER_OP("annotation.compiler_begin") -.describe(R"code( + .describe(R"code( Beginning of a region that is handled by a given compiler. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin") -.set_body_typed([](Expr expr, std::string compiler) { - auto attrs = make_object(); - attrs->compiler = compiler; - static const Op& op = Op::Get("annotation.compiler_begin"); - return Call(op, {expr}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr expr, std::string compiler) { + auto attrs = make_object(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.compiler_begin"); + return Call(op, {expr}, Attrs(attrs), {}); + }); RELAY_REGISTER_OP("annotation.compiler_end") -.describe(R"code( + .describe(R"code( End of a region that is handled by a given compiler. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end") -.set_body_typed([](Expr expr, std::string compiler) { - auto attrs = make_object(); - attrs->compiler = compiler; - static const Op& op = Op::Get("annotation.compiler_end"); - return Call(op, {expr}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr expr, std::string compiler) { + auto attrs = make_object(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.compiler_end"); + return Call(op, {expr}, Attrs(attrs), {}); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc index 8e8586f..790f1ee 100644 --- a/src/relay/op/debug.cc +++ b/src/relay/op/debug.cc @@ -22,36 +22,37 @@ * \brief Property def of nn operators. */ -#include -#include -#include #include +#include +#include +#include + #include -#include "./type_relations.h" + #include "./op_common.h" +#include "./type_relations.h" namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(DebugAttrs); -Array DebugCompute(const Attrs& attrs, - const Array& inputs, +Array DebugCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return Array{ topi::identity(inputs[0]) }; + return Array{topi::identity(inputs[0])}; } RELAY_REGISTER_OP("debug") -.describe(R"code(Enter the interpreter's debugger. + .describe(R"code(Enter the interpreter's debugger. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("program", "Tuple", "The program to execute before debugging.") -.set_support_level(1) -.set_attrs_type() -.add_type_rel("Debug", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("FTVMCompute", DebugCompute); + .set_num_inputs(1) + .add_argument("program", "Tuple", "The program to execute before debugging.") + .set_support_level(1) + .set_attrs_type() + .add_type_rel("Debug", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("FTVMCompute", DebugCompute); Expr MakeDebug(Expr expr, std::string name) { auto dattrs = make_object(); @@ -64,9 +65,7 @@ Expr MakeDebug(Expr expr, std::string name) { return Call(op, {expr}, Attrs(dattrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.debug") -.set_body_typed(MakeDebug); +TVM_REGISTER_GLOBAL("relay.op._make.debug").set_body_typed(MakeDebug); } // namespace relay } // namespace tvm - diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 4aae549..923965f 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -26,14 +26,14 @@ * used as "barrier" to avoid fusing operators belonging to differen devices. */ -#include #include #include #include #include +#include -#include "type_relations.h" #include "../transforms/infer_layout_util.h" +#include "type_relations.h" namespace tvm { namespace relay { @@ -42,27 +42,25 @@ namespace relay { TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); TVM_REGISTER_GLOBAL("relay.op._make.device_copy") -.set_body_typed([](Expr data, int src_dev_type, - int dst_dev_type) { - auto attrs = make_object(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; - static const Op& op = Op::Get("device_copy"); - return Call(op, {data}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr data, int src_dev_type, int dst_dev_type) { + auto attrs = make_object(); + attrs->src_dev_type = src_dev_type; + attrs->dst_dev_type = dst_dev_type; + static const Op& op = Op::Get("device_copy"); + return Call(op, {data}, Attrs(attrs), {}); + }); RELAY_REGISTER_OP("device_copy") -.describe(R"code( + .describe(R"code( Copy data from one tensor to another. The source and destination might be on different devices. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); } // namespace relay } // namespace tvm diff --git a/src/relay/op/image/dilation2d.cc b/src/relay/op/image/dilation2d.cc index 7146f37..43ec856 100644 --- a/src/relay/op/image/dilation2d.cc +++ b/src/relay/op/image/dilation2d.cc @@ -21,9 +21,10 @@ * \file dilation2d.cc * \brief Morphological dilation operator */ -#include -#include #include +#include +#include + #include "../op_common.h" namespace tvm { @@ -32,27 +33,20 @@ namespace relay { // relay.image.dilation2d TVM_REGISTER_NODE_TYPE(Dilation2DAttrs); -template -Array > Dilation2DInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +template +Array > Dilation2DInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { const T* params = attrs.as(); - return Array >{{params->data_layout, params->kernel_layout}, - {params->data_layout}}; + return Array >{{params->data_layout, params->kernel_layout}, {params->data_layout}}; } // Positional relay function to create dilation2d operator // used by frontend FFI. -Expr MakeDilation2D(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilations, - std::string data_layout, - std::string kernel_layout, +Expr MakeDilation2D(Expr data, Expr weight, Array strides, Array padding, + Array dilations, std::string data_layout, std::string kernel_layout, DataType out_dtype) { auto attrs = make_object(); attrs->strides = std::move(strides); @@ -67,7 +61,7 @@ Expr MakeDilation2D(Expr data, template bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* weight = types[1].as(); @@ -113,15 +107,13 @@ bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); + oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); + oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); } else { oshape.Set(3, dshape_nchw[3]); } @@ -136,26 +128,24 @@ bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d") -.set_body_typed(MakeDilation2D); - +TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d").set_body_typed(MakeDilation2D); RELAY_REGISTER_OP("image.dilation2d") -.describe(R"code(Computes grayscale dilation of 4D input and 3D filter. + .describe(R"code(Computes grayscale dilation of 4D input and 3D filter. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, in_channels, height, width) if `layout` is `NCHW`. - **weight**: (in_channels, height, width) - **out**: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Dilation2D", Dilation2DRel) -.set_attr("FInferCorrectLayout", - Dilation2DInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Dilation2D", Dilation2DRel) + .set_attr("FInferCorrectLayout", + Dilation2DInferCorrectLayout); } // namespace relay } // namespace tvm diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index c8f9762..efd815b 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -21,9 +21,10 @@ * \file resize.cc * \brief Image resize operators */ -#include -#include #include +#include +#include + #include "../op_common.h" namespace tvm { @@ -31,9 +32,7 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ResizeAttrs); -bool ResizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -46,8 +45,8 @@ bool ResizeRel(const Array& types, const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) - << "Resize only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "Resize only support input layouts that are convertible from NCHW." + << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); oshape.Set(2, param->size[0]); @@ -59,20 +58,14 @@ bool ResizeRel(const Array& types, } // assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), - out_dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); return true; } // Positional relay function to create image operator // used by frontend FFI. -Expr MakeResize(Expr data, - Array size, - std::string layout, - std::string method, - std::string coordinate_transformation_mode, - DataType out_dtype) { +Expr MakeResize(Expr data, Array size, std::string layout, std::string method, + std::string coordinate_transformation_mode, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); @@ -83,13 +76,10 @@ Expr MakeResize(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.image._make.resize") -.set_body_typed(MakeResize); - +TVM_REGISTER_GLOBAL("relay.op.image._make.resize").set_body_typed(MakeResize); RELAY_REGISTER_OP("image.resize") -.describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. + .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. - **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW @@ -102,26 +92,22 @@ RELAY_REGISTER_OP("image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(5) -.add_type_rel("Resize", ResizeRel) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("Resize", ResizeRel) + .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs); -bool CropAndResizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CropAndResizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); const auto* boxes = types[1].as(); const auto* box_indices = types[2].as(); - if (data == nullptr || boxes == nullptr || - box_indices == nullptr) return false; + if (data == nullptr || boxes == nullptr || box_indices == nullptr) return false; const CropAndResizeAttrs* param = attrs.as(); CHECK(param != nullptr); @@ -142,19 +128,12 @@ bool CropAndResizeRel(const Array& types, oshape.Set(3, crop_size[1]); auto bshape = layout_converter.BackwardShape(oshape); // assign output type - reporter->Assign(types[3], - TensorType(layout_converter.BackwardShape(oshape), - out_dtype)); + reporter->Assign(types[3], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); return true; } -Expr MakeCropAndResize(Expr data, - Expr boxes, - Expr box_indices, - Array crop_size, - std::string layout, - std::string method, - double extrapolation_value, +Expr MakeCropAndResize(Expr data, Expr boxes, Expr box_indices, Array crop_size, + std::string layout, std::string method, double extrapolation_value, DataType out_dtype) { auto attrs = make_object(); attrs->crop_size = std::move(crop_size); @@ -166,12 +145,11 @@ Expr MakeCropAndResize(Expr data, return Call(op, {data, boxes, box_indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize") -.set_body_typed(MakeCropAndResize); - +TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize").set_body_typed(MakeCropAndResize); RELAY_REGISTER_OP("image.crop_and_resize") - .describe(R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation. + .describe( + R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation. - **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW @@ -184,14 +162,14 @@ RELAY_REGISTER_OP("image.crop_and_resize") for layout NHWC (batch_size, crop_size[0], crop_size[1], channels) )code" TVM_ADD_FILELINE) -.set_num_inputs(3) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("boxes", "Tensor", "The boxes tensor.") -.add_argument("box_indices", "Tensor", "The box indices tensor.") -.set_attrs_type() -.set_support_level(5) -.add_type_rel("CropAndResize", CropAndResizeRel) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("boxes", "Tensor", "The boxes tensor.") + .add_argument("box_indices", "Tensor", "The box indices tensor.") + .set_attrs_type() + .set_support_level(5) + .add_type_rel("CropAndResize", CropAndResizeRel) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index c7ffc95..ec96e23 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -23,11 +23,11 @@ */ #include -#include #include #include #include #include +#include #include "../../transforms/infer_layout_util.h" #include "../op_common.h" @@ -109,12 +109,11 @@ std::vector FromConstShape(Constant konst) { runtime::NDArray shape = konst->data; std::vector raw_shape; CHECK_EQ(shape->ndim, 1u); - CHECK_EQ(shape->dtype.code, 0U) - << "The dtype of constant shape must be int32 or int64, but got " - << runtime::DLDataType2String(shape->dtype); + CHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got " + << runtime::DLDataType2String(shape->dtype); CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32) - << "The dtype of constant shape must be int32 or int64, but got" - << runtime::DLDataType2String(shape->dtype); + << "The dtype of constant shape must be int32 or int64, but got" + << runtime::DLDataType2String(shape->dtype); if (shape->dtype.bits == 32) { const int32_t* int_ptr = reinterpret_cast(shape->data); @@ -331,14 +330,12 @@ Expr ToTupleType(const Type& t, const std::vector& exprs) { } } -TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType") -.set_body_typed([](Type type) { +TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType").set_body_typed([](Type type) { auto types = FlattenTupleType(type); return Array(types.begin(), types.end()); }); -TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType") -.set_body_typed([](Type type, Expr expr) { +TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType").set_body_typed([](Type type, Expr expr) { auto exprs = FromTupleType(type, expr); return Array(exprs.begin(), exprs.end()); }); diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index d217457..08637d9 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -22,12 +22,12 @@ * \brief Property def of bitserial operators. */ -#include #include #include +#include -#include "../op_common.h" #include "../../transforms/infer_layout_util.h" +#include "../op_common.h" namespace tvm { namespace relay { @@ -109,11 +109,11 @@ efficient implementation of bitserial operations. packed must be divisible by number of bits. - **out**: Packed tensor with shape appropriately compressed. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "Input data.") -.set_support_level(2) -.add_type_rel("BitPack", BitPackRel); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .set_support_level(2) + .add_type_rel("BitPack", BitPackRel); // relay.nn.bitserial_conv2d TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs); @@ -137,10 +137,8 @@ bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attr Array oshape({dshape_nchw[0], param->channels, 0, 0}); IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set( - 2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1); - oshape.Set( - 3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1); + oshape.Set(2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1); + oshape.Set(3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1); DataType out_dtype = param->out_dtype; oshape = trans_in_layout.BackwardShape(oshape); // assign output type @@ -187,14 +185,14 @@ on some platforms. - **out**: Output with same layout as input. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("BinaryConv2D", BinaryConv2DRel) -.set_attr("FInferCorrectLayout", - BinaryConv2DInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("BinaryConv2D", BinaryConv2DRel) + .set_attr("FInferCorrectLayout", + BinaryConv2DInferCorrectLayout); // relay.nn.bitserial_dense TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs); @@ -248,12 +246,12 @@ RELAY_REGISTER_OP("nn.bitserial_dense") - **out**: `(x1, x2, ..., xn, units)`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "2D Tensor", "Input data.") -.add_argument("weight", "2D Tensor", "Weight matrix.") -.set_support_level(1) -.add_type_rel("BinaryDense", BinaryDenseRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "2D Tensor", "Input data.") + .add_argument("weight", "2D Tensor", "Weight matrix.") + .set_support_level(1) + .add_type_rel("BinaryDense", BinaryDenseRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index b3e1772..4a307c5 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -21,32 +21,25 @@ * \file convolution.cc * \brief Convolution operators */ -#include -#include +#include "convolution.h" + #include +#include +#include + #include #include "../../transforms/infer_layout_util.h" #include "../op_common.h" -#include "convolution.h" namespace tvm { namespace relay { template -Expr MakeConv(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype, - std::string op_name) { +Expr MakeConv(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, std::string kernel_layout, + std::string out_layout, DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -63,19 +56,10 @@ Expr MakeConv(Expr data, } template -Expr MakeConvWinograd(Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype, +Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->tile_size = tile_size; @@ -93,9 +77,7 @@ Expr MakeConvWinograd(Expr data, return Call(op, {data, weight}, Attrs(attrs), {}); } -Expr MakeConvWinogradWeightTransform(Expr weight, - int tile_size, - std::string op_name) { +Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) { auto attrs = make_object(); attrs->tile_size = tile_size; const Op& op = Op::Get(op_name); @@ -103,20 +85,11 @@ Expr MakeConvWinogradWeightTransform(Expr weight, } template -Expr MakeConvTranspose(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype, - std::string op_name) { +Expr MakeConvTranspose(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, + Array output_padding, DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -134,21 +107,11 @@ Expr MakeConvTranspose(Expr data, } template -Expr MakeDeformableConv(Expr data, - Expr offset, - Expr weight, - Array strides, - Array padding, - Array dilation, - int deformable_groups, - int groups, - int channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype, - std::string op_name) { +Expr MakeDeformableConv(Expr data, Expr offset, Expr weight, Array strides, + Array padding, Array dilation, int deformable_groups, + int groups, int channels, Array kernel_size, + std::string data_layout, std::string kernel_layout, std::string out_layout, + DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->strides = strides; attrs->padding = padding; @@ -165,32 +128,21 @@ Expr MakeDeformableConv(Expr data, return Call(op, {data, offset, weight}, Attrs{attrs}, {}); } - // relay.nn.conv1d TVM_REGISTER_NODE_TYPE(Conv1DAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.conv1d"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv1d"); + }); RELAY_REGISTER_OP("nn.conv1d") -.describe(R"code(1D convolution layer (e.g. spatial convolution over sequences). + .describe(R"code(1D convolution layer (e.g. spatial convolution over sequences). This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. @@ -202,40 +154,29 @@ with the layer input to produce a tensor of outputs. (batch_size, channels, out_width) if `layout` is `NCW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv1D", Conv1DRel) -.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv1D", Conv1DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv2d TVM_REGISTER_NODE_TYPE(Conv2DAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.conv2d"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv2d"); + }); RELAY_REGISTER_OP("nn.conv2d") -.describe(R"code(2D convolution layer (e.g. spatial convolution over images). + .describe(R"code(2D convolution layer (e.g. spatial convolution over images). This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. @@ -247,40 +188,29 @@ with the layer input to produce a tensor of outputs. (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv2D", Conv2DRel) -.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv2D", Conv2DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv3d TVM_REGISTER_NODE_TYPE(Conv3DAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.conv3d"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv3d"); + }); RELAY_REGISTER_OP("nn.conv3d") -.describe(R"code(3D convolution layer (e.g. convolution over 3D image data, + .describe(R"code(3D convolution layer (e.g. convolution over 3D image data, like Magnetic Resonance Imaging (MRI) data in medicine). This layer creates a convolution kernel that is convolved @@ -293,40 +223,30 @@ with the layer input to produce a tensor of outputs. (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv3D", Conv3DRel) -.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv3D", Conv3DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype) { - return MakeConvTranspose( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose"); -}); + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, + Array output_padding, DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose"); + }); RELAY_REGISTER_OP("nn.conv2d_transpose") -.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). + .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction @@ -347,40 +267,31 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout) -.add_type_rel("Conv2DTranspose", Conv2DTransposeRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout) + .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); // relay.nn.conv1d_transpose TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype) { - return MakeConvTranspose( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose"); -}); + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, + Array output_padding, DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose"); + }); RELAY_REGISTER_OP("nn.conv1d_transpose") -.describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution). + .describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution). The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction @@ -400,39 +311,29 @@ said convolution. out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv1DTranspose", Conv1DTransposeRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv1DTranspose", Conv1DTransposeRel); // relay.nn.contrib_conv2d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform") -.set_body_typed([](Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConvWinograd( - data, weight, tile_size, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_winograd_without_weight_transform"); -}); - + .set_body_typed([](Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConvWinograd( + data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size, + data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_conv2d_winograd_without_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") -.describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout. + .describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout. This operator assumes the weight tensor is already pre-transformed by nn.contrib_conv2d_winograd_weight_transform. @@ -443,64 +344,54 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") - **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DWinograd", Conv2DWinogradRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DWinograd", Conv2DWinogradRel) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout); // relay.nn.contrib_conv2d_winograd_weight_transform TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform") -.set_body_typed([](Expr weight, - int tile_size) { - return MakeConvWinogradWeightTransform( - weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform"); -}); + .set_body_typed([](Expr weight, int tile_size) { + return MakeConvWinogradWeightTransform(weight, tile_size, + "nn.contrib_conv2d_winograd_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform") -.describe(R"code(Weight transformation of winograd fast convolution algorithm. + .describe(R"code(Weight transformation of winograd fast convolution algorithm. Separate this into another operator in order to enable Precompute Pass to compute the weight transformation in advance. - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel); // relay.nn.contrib_conv3d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform") -.set_body_typed([](Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConvWinograd( - data, weight, tile_size, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform"); -}); + .set_body_typed([](Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConvWinograd( + data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size, + data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_conv3d_winograd_without_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform") -.describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout. + .describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout. This operator assumes the weight tensor is already pre-transformed by nn.contrib_conv3d_winograd_weight_transform. @@ -511,22 +402,21 @@ RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform") - **out**: Output is 5D array of shape (batch_size, channels, depth, out_height, out_width) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv3DWinograd", Conv3DWinogradRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv3DWinograd", Conv3DWinogradRel) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout); // relay.nn.contrib_conv3d_winograd_weight_transform TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform") -.set_body_typed([](Expr weight, - int tile_size) { - return MakeConvWinogradWeightTransform( - weight, tile_size, "nn.contrib_conv3d_winograd_weight_transform"); -}); + .set_body_typed([](Expr weight, int tile_size) { + return MakeConvWinogradWeightTransform(weight, tile_size, + "nn.contrib_conv3d_winograd_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform") .describe(R"code(Weight transformation of winograd fast 3d convolution algorithm. @@ -536,18 +426,16 @@ weight transformation in advance. - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2]) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel); // relay.nn.contrib_conv2d_winograd_nnpack_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs); -Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, - int convolution_algorithm, +Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm, DataType out_dtype) { auto attrs = make_object(); attrs->convolution_algorithm = convolution_algorithm; @@ -557,99 +445,75 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, } TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform") -.set_body_typed(MakeConv2DWinogradNNPACKWeightTransform); + .set_body_typed(MakeConv2DWinogradNNPACKWeightTransform); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform") -.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK. + .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK. Separate this into another symbol in order to enable Precompute Pass to compute the weight transformation in advance. - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel); // Positional relay function to create conv2d NCHWc operator // used by frontend FFI. TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_NCHWc"); -}); + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_conv2d_NCHWc"); + }); RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") -.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. + .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. - **data**: Input is 5D packed tensor. - **weight**: 6D packed tensor. - **out**: Output is 5D packed tensor )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DNCHWc", Conv2DWinogradRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // Positional relay function to create depthwise conv2d NCHWc operator // used by frontend FFI. TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_depthwise_conv2d_NCHWc"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_depthwise_conv2d_NCHWc"); + }); RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") -.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. + .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. - **data**: Input is 5D packed tensor. - **weight**: 6D packed tensor. - **out**: Output is 5D packed tensor )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2D", Conv2DRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2D", Conv2DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs); @@ -673,36 +537,26 @@ along the channel axis, and also evenly split `weight` along the first dimension the convolution on the *i*-th part of the data with the *i*-th weight part. The output is obtained by concating all the *g* results. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("offset", "Tensor", "The offset tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(5) -.add_type_rel("DeformableConv2D", DeformableConv2DRel); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("offset", "Tensor", "The offset tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(5) + .add_type_rel("DeformableConv2D", DeformableConv2DRel); // Positional relay function to create deformable_conv2d operator // used by frontend FFI. TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d") -.set_body_typed([](Expr data, - Expr offset, - Expr weight, - Array strides, - Array padding, - Array dilation, - int deformable_groups, - int groups, - int channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeDeformableConv( - data, offset, weight, strides, padding, dilation, - deformable_groups, groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d"); -}); + .set_body_typed([](Expr data, Expr offset, Expr weight, Array strides, + Array padding, Array dilation, int deformable_groups, + int groups, int channels, Array kernel_size, + std::string data_layout, std::string kernel_layout, std::string out_layout, + DataType out_dtype) { + return MakeDeformableConv( + data, offset, weight, strides, padding, dilation, deformable_groups, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d"); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index f33cd7e..5dc649b 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -35,7 +35,6 @@ namespace tvm { namespace relay { - // Standard convolution operator shape relations template bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -92,7 +91,7 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, auto wshape = trans_kernel_layout.ForwardShape(weight->shape); if (param->kernel_size.defined()) { // check the size - CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) ) + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) << "Conv1D: shape of weight is inconsistent with kernel_size, " << " kernel_size=" << param->kernel_size << " wshape=" << wshape; } @@ -110,7 +109,8 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, if (!dshape_ncw[2].as()) { oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize, - param->strides[0]) + 1); + param->strides[0]) + + 1); } else { oshape.Set(2, dshape_ncw[2]); } @@ -159,8 +159,8 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); bool is_depthwise = false; if (param->groups > 1) { - CHECK(weight && weight->shape.defined()) << - "Weight shape must be specified when groups is greater than 1."; + CHECK(weight && weight->shape.defined()) + << "Weight shape must be specified when groups is greater than 1."; Array wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape); if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) && tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) { @@ -222,15 +222,13 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); + oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); + oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); } else { oshape.Set(3, dshape_nchw[3]); } @@ -336,22 +334,19 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_d, pad_h, pad_w; GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); if (!dshape_ncdhw[2].as()) { - oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, - param->strides[0]) + 1); + oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1); } else { oshape.Set(2, dshape_ncdhw[2]); } if (!dshape_ncdhw[3].as()) { - oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, - param->strides[1]) + 1); + oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1); } else { oshape.Set(3, dshape_ncdhw[3]); } if (!dshape_ncdhw[4].as()) { - oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, - param->strides[2]) + 1); + oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1); } else { oshape.Set(4, dshape_ncdhw[4]); } @@ -365,7 +360,6 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } - // Winograd convolution shape relations inline bool Conv2DWinogradWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -378,15 +372,14 @@ inline bool Conv2DWinogradWeightTransformRel(const Array& types, int num_i CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; - std::vector oshape { + std::vector oshape{ param->tile_size + data->shape[2] - 1, param->tile_size + data->shape[3] - 1, data->shape[0], data->shape[1], }; - reporter->Assign(types[1], TensorType(Array(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } @@ -404,7 +397,7 @@ inline bool Conv3DWinogradWeightTransformRel(const Array& types, int num_i // Shape of packed weights depends on whether depth is being transformed or not. Array oshape({0, 0, 0, data->shape[0], data->shape[1]}); auto* depth_imm = data->shape[2].as(); - bool transform_depth = (depth_imm->value > 2)&&(depth_imm->value < 8); + bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8); if (transform_depth) { oshape.Set(0, param->tile_size + data->shape[2] - 1); oshape.Set(1, param->tile_size + data->shape[3] - 1); @@ -449,10 +442,8 @@ inline bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, int return true; } -template -bool Conv2DWinogradRel(const Array& types, - int num_inputs, - const Attrs& attrs, +template +bool Conv2DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -467,13 +458,13 @@ bool Conv2DWinogradRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); @@ -508,14 +499,12 @@ bool Conv2DWinogradRel(const Array& types, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, (dshape_nchw[2] + pad_h - - dilated_ksize_y) / param->strides[0] + 1); + oshape.Set(2, (dshape_nchw[2] + pad_h - dilated_ksize_y) / param->strides[0] + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, (dshape_nchw[3] + pad_w - - dilated_ksize_x) / param->strides[1] + 1); + oshape.Set(3, (dshape_nchw[3] + pad_w - dilated_ksize_x) / param->strides[1] + 1); } else { oshape.Set(3, dshape_nchw[3]); } @@ -530,11 +519,8 @@ bool Conv2DWinogradRel(const Array& types, return true; } - -template -bool Conv3DWinogradRel(const Array& types, - int num_inputs, - const Attrs& attrs, +template +bool Conv3DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -549,13 +535,13 @@ bool Conv3DWinogradRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCDHW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIDHW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIDHW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); @@ -591,20 +577,17 @@ bool Conv3DWinogradRel(const Array& types, IndexExpr pad_d, pad_h, pad_w; GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); if (!dshape_ncdhw[2].as()) { - oshape.Set(2, (dshape_ncdhw[2] + pad_d - - dilated_ksize_d) / param->strides[0] + 1); + oshape.Set(2, (dshape_ncdhw[2] + pad_d - dilated_ksize_d) / param->strides[0] + 1); } else { oshape.Set(2, dshape_ncdhw[2]); } if (!dshape_ncdhw[2].as()) { - oshape.Set(3, (dshape_ncdhw[3] + pad_h - - dilated_ksize_y) / param->strides[1] + 1); + oshape.Set(3, (dshape_ncdhw[3] + pad_h - dilated_ksize_y) / param->strides[1] + 1); } else { oshape.Set(3, dshape_ncdhw[3]); } if (!dshape_ncdhw[4].as()) { - oshape.Set(4, (dshape_ncdhw[4] + pad_w - - dilated_ksize_x) / param->strides[2] + 1); + oshape.Set(4, (dshape_ncdhw[4] + pad_w - dilated_ksize_x) / param->strides[2] + 1); } else { oshape.Set(4, dshape_ncdhw[4]); } @@ -619,12 +602,9 @@ bool Conv3DWinogradRel(const Array& types, return true; } - // Transposed convolution shape relations template -bool Conv1DTransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -641,19 +621,19 @@ bool Conv1DTransposeRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCW." - << " But got " << out_layout; + << "Conv only support output layouts that are convertible from NCW." + << " But got " << out_layout; IndexExpr channels, dilated_ksize_y, dilated_ksize_x; @@ -664,9 +644,8 @@ bool Conv1DTransposeRel(const Array& types, CHECK_EQ(param->kernel_size.size(), 1); CHECK_EQ(param->dilation.size(), 1); - Array wshape({dshape_ncw[1], - indexdiv(param->channels, param->groups), - param->kernel_size[0]}); + Array wshape( + {dshape_ncw[1], indexdiv(param->channels, param->groups), param->kernel_size[0]}); wshape = trans_kernel_layout.BackwardShape(wshape); dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; @@ -683,14 +662,12 @@ bool Conv1DTransposeRel(const Array& types, // check the size CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) << "Conv1D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << Array(wshape); + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); } if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[1])) << "Conv1D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << Array(wshape); + << " channels=" << param->channels << " wshape=" << Array(wshape); } CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); channels = wshape[1]; @@ -700,8 +677,8 @@ bool Conv1DTransposeRel(const Array& types, IndexExpr pad_w; GetPaddingWidth(param->padding, &pad_w); Array oshape({dshape_ncw[0], channels, 0}); - oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - - pad_w + param->output_padding[0])); + oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + + param->output_padding[0])); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -712,11 +689,8 @@ bool Conv1DTransposeRel(const Array& types, return true; } - template -bool Conv2DTransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -733,19 +707,19 @@ bool Conv2DTransposeRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCHW." - << " But got " << out_layout; + << "Conv only support output layouts that are convertible from NCHW." + << " But got " << out_layout; IndexExpr channels, dilated_ksize_y, dilated_ksize_x; @@ -756,10 +730,8 @@ bool Conv2DTransposeRel(const Array& types, CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->dilation.size(), 2); - Array wshape({dshape_nchw[1], - indexdiv(param->channels, param->groups), - param->kernel_size[0], - param->kernel_size[1]}); + Array wshape({dshape_nchw[1], indexdiv(param->channels, param->groups), + param->kernel_size[0], param->kernel_size[1]}); wshape = trans_kernel_layout.BackwardShape(wshape); dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; @@ -778,14 +750,12 @@ bool Conv2DTransposeRel(const Array& types, CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && reporter->AssertEQ(param->kernel_size[1], wshape[3])) << "Conv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << Array(wshape); + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); } if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[1])) << "Conv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << Array(wshape); + << " channels=" << param->channels << " wshape=" << Array(wshape); } CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); channels = wshape[1]; @@ -796,10 +766,10 @@ bool Conv2DTransposeRel(const Array& types, Array oshape({dshape_nchw[0], channels, 0, 0}); IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - - pad_h + param->output_padding[0])); - oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - - pad_w + param->output_padding[1])); + oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + + param->output_padding[0])); + oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w + + param->output_padding[1])); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -810,7 +780,6 @@ bool Conv2DTransposeRel(const Array& types, return true; } - // Deformable Convolution shape relations. template bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -830,11 +799,8 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& if (param->kernel_size.defined() && param->channels.defined()) { CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->dilation.size(), 2); - Array wshape( - {param->channels, - indexdiv(data->shape[1], param->groups), - param->kernel_size[0], - param->kernel_size[1]}); + Array wshape({param->channels, indexdiv(data->shape[1], param->groups), + param->kernel_size[0], param->kernel_size[1]}); channels = param->channels; ksize_y = param->kernel_size[0]; ksize_x = param->kernel_size[1]; @@ -852,14 +818,12 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && reporter->AssertEQ(param->kernel_size[1], wshape[3])) << "DeformableConv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << wshape; + << " kernel_size=" << param->kernel_size << " wshape=" << wshape; } if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[0])) << "DeformableConv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << wshape; + << " channels=" << param->channels << " wshape=" << wshape; } CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1])); channels = wshape[0]; @@ -873,15 +837,13 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); - oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); + oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); + oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); DataType out_dtype = param->out_dtype; // infer offset shape - Array offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, - oshape[2], oshape[3]}); + Array offset_shape( + {data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]}); reporter->Assign(types[1], TensorType(offset_shape, data->dtype)); if (out_dtype.bits() == 0) { out_dtype = data->dtype; @@ -891,23 +853,20 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& return true; } - -template -Array > ConvInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +template +Array > ConvInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { const T* params = attrs.as(); // We always make other operators to fit the layouts of convolution layers // So this inference ignores all inputs - return Array >{{params->data_layout, params->kernel_layout}, - {params->out_layout == "" ? - params->data_layout : params->out_layout}}; + return Array >{ + {params->data_layout, params->kernel_layout}, + {params->out_layout == "" ? params->data_layout : params->out_layout}}; } - } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_CONVOLUTION_H_ diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 33a9235..670878d 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -22,20 +22,23 @@ * \brief Property def of nn operators. */ -#include -#include -#include -#include +#include "nn.h" + #include #include -#include #include -#include +#include +#include +#include +#include +#include + #include -#include "../type_relations.h" +#include + #include "../../transforms/infer_layout_util.h" #include "../op_common.h" -#include "nn.h" +#include "../type_relations.h" namespace tvm { namespace relay { @@ -43,9 +46,7 @@ namespace relay { // relay.nn.bias_add TVM_REGISTER_NODE_TYPE(BiasAddAttrs); -bool BiasAddRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -61,45 +62,36 @@ bool BiasAddRel(const Array& types, << "axis " << param->axis << " is out of range"; // assign output type - reporter->Assign(types[1], TensorType( - {data->shape[axis]}, data->dtype)); + reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype)); reporter->Assign(types[2], types[0]); return true; } - // Positional relay function to create dense operator used by frontend FFI. -Expr MakeBiasAdd(Expr data, - Expr bias, - int axis) { +Expr MakeBiasAdd(Expr data, Expr bias, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.bias_add"); return Call(op, {data, bias}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add") -.set_body_typed(MakeBiasAdd); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add").set_body_typed(MakeBiasAdd); RELAY_REGISTER_OP("nn.bias_add") -.describe(R"code(Add bias to an axis of the input. + .describe(R"code(Add bias to an axis of the input. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "nD Tensor", "Input data.") -.add_argument("bias", "1D Tensor", "Bias.") -.set_support_level(1) -.add_type_rel("BiasAdd", BiasAddRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - return tvm::Array{topi::nn::bias_add(inputs[0], inputs[1], param->axis)}; -}); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("bias", "1D Tensor", "Bias.") + .set_support_level(1) + .add_type_rel("BiasAdd", BiasAddRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return tvm::Array{topi::nn::bias_add(inputs[0], inputs[1], param->axis)}; + }); // relay.nn.fifo_buffer TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs); @@ -111,9 +103,7 @@ Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) { return Call(op, {input, buffer}, Attrs(attrs), {}); } -bool FIFOBufferRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool FIFOBufferRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* input = types[0].as(); @@ -125,9 +115,8 @@ bool FIFOBufferRel(const Array& types, CHECK(param != nullptr); CHECK_EQ(input->shape.size(), buffer->shape.size()); - const size_t buffer_axis - = static_cast(param->axis < 0 ? static_cast(buffer->shape.size()) + param->axis - : param->axis); + const size_t buffer_axis = static_cast( + param->axis < 0 ? static_cast(buffer->shape.size()) + param->axis : param->axis); reporter->Assert(buffer_axis < buffer->shape.size()); for (size_t i = 0; i < buffer->shape.size(); ++i) { @@ -143,11 +132,10 @@ bool FIFOBufferRel(const Array& types, return true; } -TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer") -.set_body_typed(MakeFIFOBuffer); +TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer").set_body_typed(MakeFIFOBuffer); RELAY_REGISTER_OP("nn.fifo_buffer") -.describe(R"code(FIFO buffer + .describe(R"code(FIFO buffer Compute equivalent of ``` @@ -159,23 +147,18 @@ Useful for * Encoding explicit re-use of computation in convolution ops operated on a sliding window input * Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "Latest input") -.add_argument("buffer", "Tensor", - "Buffer storing latest [length_buffer] inputs") -.set_support_level(3) -.add_type_rel("FIFOBuffer", FIFOBufferRel); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "Latest input") + .add_argument("buffer", "Tensor", "Buffer storing latest [length_buffer] inputs") + .set_support_level(3) + .add_type_rel("FIFOBuffer", FIFOBufferRel); // relay.nn.dense TVM_REGISTER_NODE_TYPE(DenseAttrs); // Positional relay function to create dense operator used by frontend FFI. -Expr MakeDense(Expr data, - Expr weight, - IndexExpr units, - DataType out_dtype) { +Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; @@ -183,70 +166,58 @@ Expr MakeDense(Expr data, return Call(op, {data, weight}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.dense") -.set_body_typed(MakeDense); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense); RELAY_REGISTER_OP("nn.dense") -.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. + .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. - **data**: `(x1, x2, ..., xn, input_dim)` - **weight**: `(units, input_dim)` - **out**: `(x1, x2, ..., xn, units)`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "nD Tensor", "Input data.") -.add_argument("weight", "2D Tensor", "Weight matrix.") -.set_support_level(1) -.add_type_rel("Dense", DenseRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("weight", "2D Tensor", "Weight matrix.") + .set_support_level(1) + .add_type_rel("Dense", DenseRel); // relay.leaky_relu TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); // Positional relay function to create leaky relu operator used by frontend FFI. -Expr MakeLeakyRelu(Expr data, - double alpha) { +Expr MakeLeakyRelu(Expr data, double alpha) { auto attrs = make_object(); attrs->alpha = alpha; static const Op& op = Op::Get("nn.leaky_relu"); return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu") -.set_body_typed(MakeLeakyRelu); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu").set_body_typed(MakeLeakyRelu); RELAY_REGISTER_OP("nn.leaky_relu") -.describe(R"code(Leaky version of a Rectified Linear Unit. + .describe(R"code(Leaky version of a Rectified Linear Unit. `y = x > 0 ? x : alpha * x` )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "Input data.") -.set_support_level(3) -.add_type_rel("Identity", IdentityRel) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - return Array{ topi::leaky_relu(inputs[0], param->alpha) }; -}); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input data.") + .set_support_level(3) + .add_type_rel("Identity", IdentityRel) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return Array{topi::leaky_relu(inputs[0], param->alpha)}; + }); // relay.prelu TVM_REGISTER_NODE_TYPE(PReluAttrs); -bool PReluRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool PReluRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -256,7 +227,7 @@ bool PReluRel(const Array& types, CHECK(param != nullptr); CHECK(param->axis < static_cast(data->shape.size())) - << "Wrong axis (" << param->axis << ")value."; + << "Wrong axis (" << param->axis << ")value."; // assign alpha type Array alpha_shape({data->shape[param->axis]}); @@ -267,72 +238,59 @@ bool PReluRel(const Array& types, return true; } -template -Array > PReluInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { - +template +Array> PReluInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { CHECK_EQ(old_in_layouts.size(), 2U); CHECK_EQ(old_in_types.size(), 2U); Layout data_layout = old_in_layouts[0]; if (new_in_layouts.defined()) { CHECK_EQ(new_in_layouts.size(), 2U); } - return Array >{{data_layout, Layout("C")}, - {data_layout}}; + return Array>{{data_layout, Layout("C")}, {data_layout}}; } // Positional relay function to create prelu operator used by frontend FFI. -Expr MakePRelu(Expr data, - Expr alpha, - int axis) { +Expr MakePRelu(Expr data, Expr alpha, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.prelu"); return Call(op, {data, alpha}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu") -.set_body_typed(MakePRelu); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu").set_body_typed(MakePRelu); RELAY_REGISTER_OP("nn.prelu") -.describe(R"code(Parametric version of a Rectified Linear Unit. + .describe(R"code(Parametric version of a Rectified Linear Unit. It accepts two arguments: an input ``x`` and a channelwise slope ``alpha`` and computes the output as :math:`PReLU(x) y = x > 0 ? x : alpha * x`, where :math:`*` is an channelwise multiplication for each sample in the batch. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "Input data.") -.add_argument("alpha", "Tensor", "Input channelwise alpha.") -.set_support_level(3) -.add_type_rel("PRelu", PReluRel) -.set_attr("FInferCorrectLayout", PReluInferCorrectLayout) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - return Array{ topi::prelu(inputs[0], inputs[1], param->axis)}; -}); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "Input data.") + .add_argument("alpha", "Tensor", "Input channelwise alpha.") + .set_support_level(3) + .add_type_rel("PRelu", PReluRel) + .set_attr("FInferCorrectLayout", PReluInferCorrectLayout) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return Array{topi::prelu(inputs[0], inputs[1], param->axis)}; + }); // relay.softmax TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); -TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax") -.set_body_typed([](Expr data, int axis) { +TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax").set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.softmax"); return Call(op, {data}, Attrs(attrs), {}); }); - RELAY_REGISTER_OP("nn.softmax") .describe(R"code(Softmax layer. @@ -343,16 +301,14 @@ RELAY_REGISTER_OP("nn.softmax") - **data**: The input data )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("Identity", IdentityRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("Identity", IdentityRel); // relay.nn.log_softmax -TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax") -.set_body_typed([](Expr data, int axis) { +TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.log_softmax"); @@ -369,26 +325,22 @@ RELAY_REGISTER_OP("nn.log_softmax") - **data**: The input data )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("Identity", IdentityRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - CHECK(param != nullptr); - CHECK(param->axis == -1 || param->axis == static_cast(inputs[0].ndim()) - 1) - << "log_softmax currently only works on last dimension"; - return Array{ topi::nn::log_softmax(inputs[0]) }; -}); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("Identity", IdentityRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + CHECK(param->axis == -1 || param->axis == static_cast(inputs[0].ndim()) - 1) + << "log_softmax currently only works on last dimension"; + return Array{topi::nn::log_softmax(inputs[0])}; + }); // relay.nn.batch_flatten -bool BatchFlattenRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BatchFlattenRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -418,13 +370,10 @@ Expr MakeBatchFlatten(Expr data) { return Call(op, {data}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten") -.set_body_typed(MakeBatchFlatten); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten").set_body_typed(MakeBatchFlatten); RELAY_REGISTER_OP("nn.batch_flatten") -.describe(R"code(Flattens the input into a 2-D array. + .describe(R"code(Flattens the input into a 2-D array. For an input array with shape ``(d1, d2, ..., dk)``, `batch_flatten` operation reshapes the input array into an output array of shape ``(d1, d2*...*dk)``. @@ -445,53 +394,42 @@ Example:: [ 1., 2., 3., 4., 5., 6., 7., 8., 9.]] )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("BatchFlatten", BatchFlattenRel) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - return Array{ topi::nn::flatten(inputs[0]) }; -}); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("BatchFlatten", BatchFlattenRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + return Array{topi::nn::flatten(inputs[0])}; + }); // relu -TVM_REGISTER_GLOBAL("relay.op.nn._make.relu") -.set_body_typed([](Expr data) { - static const Op& op = Op::Get("nn.relu"); - return Call(op, {data}, Attrs(), {}); - }); +TVM_REGISTER_GLOBAL("relay.op.nn._make.relu").set_body_typed([](Expr data) { + static const Op& op = Op::Get("nn.relu"); + return Call(op, {data}, Attrs(), {}); +}); RELAY_REGISTER_OP("nn.relu") -.describe(R"code(Returns the relu input array, computed element-wise. + .describe(R"code(Returns the relu input array, computed element-wise. .. math:: max(x, 0) )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("Identity", IdentityRel) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - return Array{ topi::relu(inputs[0], 0.0f) }; -}); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("Identity", IdentityRel) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + return Array{topi::relu(inputs[0], 0.0f)}; + }); // Positional relay function to create LRN operator used by frontend FFI. TVM_REGISTER_NODE_TYPE(LRNAttrs); -Expr MakeLRN(Expr data, - int size, - int axis, - double alpha, - double beta, - double bias) { +Expr MakeLRN(Expr data, int size, int axis, double alpha, double beta, double bias) { auto attrs = make_object(); attrs->size = size; attrs->axis = axis; @@ -502,11 +440,10 @@ Expr MakeLRN(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn") -.set_body_typed(MakeLRN); +TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn").set_body_typed(MakeLRN); RELAY_REGISTER_OP("nn.lrn") -.describe(R"code(LRN layer. + .describe(R"code(LRN layer. Normalize the input in a local region across or within feature maps. Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta, @@ -519,19 +456,16 @@ centered at that value (zero padding is added where necessary). - **data**: The input tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("Identity", IdentityRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("Identity", IdentityRel); // Positional relay function to create L2Normalize operator used by frontend FFI. TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs); -Expr MakeL2Normalize(Expr data, - double eps, - Array axis) { +Expr MakeL2Normalize(Expr data, double eps, Array axis) { auto attrs = make_object(); attrs->eps = eps; attrs->axis = std::move(axis); @@ -539,11 +473,10 @@ Expr MakeL2Normalize(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize") -.set_body_typed(MakeL2Normalize); +TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize").set_body_typed(MakeL2Normalize); RELAY_REGISTER_OP("nn.l2_normalize") -.describe(R"code(L2 Normalization layer. + .describe(R"code(L2 Normalization layer. Normalizes along dimension axis using an L2 norm @@ -552,19 +485,17 @@ Normalizes along dimension axis using an L2 norm - **data**: The input tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.add_type_rel("Identity", IdentityRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .add_type_rel("Identity", IdentityRel); // Dropout TVM_REGISTER_NODE_TYPE(DropoutAttrs); -bool DropoutRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool DropoutRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -584,22 +515,21 @@ Expr MakeDropout(Expr data, double rate) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout") -.set_body_typed(MakeDropout); +TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout").set_body_typed(MakeDropout); RELAY_REGISTER_OP("nn.dropout") -.describe(R"code(Applies the dropout operation to the input array. + .describe(R"code(Applies the dropout operation to the input array. During training, each element of the input is set to zero with probability ``p``. The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "Input to which dropout will be applied.") -.set_support_level(1) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.add_type_rel("Dropout", DropoutRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input to which dropout will be applied.") + .set_support_level(1) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .add_type_rel("Dropout", DropoutRel); // batch_norm TVM_REGISTER_NODE_TYPE(BatchNormAttrs); @@ -638,9 +568,7 @@ Array> BatchNormInferCorrectLayout(const Attrs& attrs, {ret, c_layout, c_layout}}; } -bool BatchNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BatchNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 6); const auto* data = types[0].as(); @@ -662,8 +590,7 @@ bool BatchNormRel(const Array& types, // output is a tuple of the normed data (same shape as input), new running mean, // and new running average (the latter two are both vectors of length dim) std::vector fields; - auto vec_ty = TensorType(Array({data->shape[axis]}), - data->dtype); + auto vec_ty = TensorType(Array({data->shape[axis]}), data->dtype); fields.push_back(TensorType(data->shape, data->dtype)); fields.push_back(vec_ty); fields.push_back(vec_ty); @@ -671,8 +598,8 @@ bool BatchNormRel(const Array& types, return true; } -Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, - int axis, double epsilon, bool center, bool scale) { +Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, int axis, + double epsilon, bool center, bool scale) { auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -682,11 +609,10 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi return Call(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm") -.set_body_typed(MakeBatchNorm); +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm").set_body_typed(MakeBatchNorm); RELAY_REGISTER_OP("nn.batch_norm") -.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014). + .describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014). Normalizes the input at each batch, i.e. applies a transformation that maintains the mean activation close to 0 and the activation standard deviation close to 1. @@ -722,24 +648,21 @@ axis to be the last item in the input shape. .. note:: This operator can be optimized away for inference. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(5) -.add_argument("data", "Tensor", "Input to which batch_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.add_argument("moving_mean", "Tensor", "Running mean of input.") -.add_argument("moving_var", "Tensor", "Running variance of input.") -.set_attr("FInferCorrectLayout", BatchNormInferCorrectLayout) -.set_support_level(1) -.add_type_rel("BatchNorm", BatchNormRel); - + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .add_argument("moving_mean", "Tensor", "Running mean of input.") + .add_argument("moving_var", "Tensor", "Running variance of input.") + .set_attr("FInferCorrectLayout", BatchNormInferCorrectLayout) + .set_support_level(1) + .add_type_rel("BatchNorm", BatchNormRel); // instance_norm TVM_REGISTER_NODE_TYPE(InstanceNormAttrs); -bool InstanceNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool InstanceNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); @@ -754,8 +677,8 @@ bool InstanceNormRel(const Array& types, return true; } -Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, - bool center, bool scale) { +Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, + bool scale) { auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -766,12 +689,12 @@ Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon } TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeInstanceNorm, args, rv); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeInstanceNorm, args, rv); + }); RELAY_REGISTER_OP("nn.instance_norm") -.describe(R"code(Instance Normalization (Ulyanov and et al., 2016) + .describe(R"code(Instance Normalization (Ulyanov and et al., 2016) Applies instance normalization to the n-dimensional input array. .. math:: @@ -795,21 +718,18 @@ to be the last item in the input shape. This operator can be optimized away for inference. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "Input to which instance_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.set_support_level(1) -.add_type_rel("InstanceNorm", InstanceNormRel); - + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which instance_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_support_level(1) + .add_type_rel("InstanceNorm", InstanceNormRel); // layer_norm TVM_REGISTER_NODE_TYPE(LayerNormAttrs); -bool LayerNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool LayerNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); @@ -824,8 +744,8 @@ bool LayerNormRel(const Array& types, return true; } -Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, - bool center, bool scale) { +Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, + bool scale) { auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -836,27 +756,25 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, } TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeLayerNorm, args, rv); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeLayerNorm, args, rv); + }); RELAY_REGISTER_OP("nn.layer_norm") -.describe(R"code( + .describe(R"code( )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "Input to which layer_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.set_support_level(1) -.add_type_rel("LayerNorm", LayerNormRel); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which layer_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_support_level(1) + .add_type_rel("LayerNorm", LayerNormRel); // group_norm TVM_REGISTER_NODE_TYPE(GroupNormAttrs); -bool GroupNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GroupNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); @@ -871,10 +789,10 @@ bool GroupNormRel(const Array& types, return true; } -Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, - int axis, double epsilon, bool center, bool scale) { +Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, int axis, double epsilon, + bool center, bool scale) { auto attrs = make_object(); - attrs->num_groups = num_groups; + attrs->num_groups = num_groups; attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -884,12 +802,12 @@ Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, } TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeGroupNorm, args, rv); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeGroupNorm, args, rv); + }); RELAY_REGISTER_OP("nn.group_norm") -.describe(R"code( + .describe(R"code( Group normalization normalizes over group of channels for each training examples. We can say that, Group Norm is in between Instance Norm and Layer Norm. When we put all the channels into a single group, group normalization becomes Layer normalization. @@ -916,19 +834,16 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,). This operator can be optimized away for inference. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "Input to which group_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.set_support_level(1) -.add_type_rel("GroupNorm", GroupNormRel); - + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which group_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_support_level(1) + .add_type_rel("GroupNorm", GroupNormRel); // relay.nn.batch_matmul -bool BatchMatmulRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* x = types[0].as(); @@ -937,12 +852,10 @@ bool BatchMatmulRel(const Array& types, CHECK(x->shape.size() == 3 && y->shape.size() == 3); CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) << "BatchDot: batch dimension doesn't match, " - << " x shape=" << x->shape - << ", y shape=" << y->shape; + << " x shape=" << x->shape << ", y shape=" << y->shape; CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) << "BatchDot: shapes of x and y is inconsistent, " - << " x shape=" << x->shape - << ", y shape=" << y->shape; + << " x shape=" << x->shape << ", y shape=" << y->shape; Array oshape = x->shape; oshape.Set(2, y->shape[1]); @@ -952,21 +865,16 @@ bool BatchMatmulRel(const Array& types, return true; } - // Positional relay function to create batch_matmul operator used by frontend FFI. -Expr MakeBatchMatmul(Expr x, - Expr y) { +Expr MakeBatchMatmul(Expr x, Expr y) { static const Op& op = Op::Get("nn.batch_matmul"); return Call(op, {x, y}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul") -.set_body_typed(MakeBatchMatmul); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul); RELAY_REGISTER_OP("nn.batch_matmul") -.describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y` + .describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y` are data in batch. .. math:: @@ -978,34 +886,31 @@ are data in batch. - **out**: `(b, m, n)`. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("x", "3D Tensor", "First input.") -.add_argument("y", "3D Tensor", "Second input.") -.set_support_level(10) -.add_type_rel("BatchMatmul", BatchMatmulRel); - + .set_num_inputs(2) + .add_argument("x", "3D Tensor", "First input.") + .add_argument("y", "3D Tensor", "Second input.") + .set_support_level(10) + .add_type_rel("BatchMatmul", BatchMatmulRel); // relay.nn.cross_entropy -bool CrossEntropyRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { +bool CrossEntropyRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* x = types[0].as(); const auto* y = types[1].as(); if (x == nullptr || y == nullptr) return false; CHECK(x->shape.size() == 2 && y->shape.size() == 2) - << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape = " << x->shape << ", " - << "y shape = " << y->shape; + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) - << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape = " << x->shape << ", " - << "y shape = " << y->shape; + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; CHECK(reporter->AssertEQ(x->shape[1], y->shape[1])) - << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape = " << x->shape << ", " - << "y shape = " << y->shape; + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; // assign output type reporter->Assign(types[2], TensorType({}, x->dtype)); return true; @@ -1017,29 +922,23 @@ Expr MakeCrossEntropy(Expr predictions, Expr targets) { return Call(op, {predictions, targets}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy") -.set_body_typed(MakeCrossEntropy); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy").set_body_typed(MakeCrossEntropy); RELAY_REGISTER_OP("nn.cross_entropy") -.describe(R"code( + .describe(R"code( Computes cross entropy given predictions and targets. Do log on the data - do not accept logits. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("x", "1D Tensor", "Predictions.") -.add_argument("y", "1D Tensor", "Targets.") -.set_support_level(10) -.add_type_rel("CrossEntropy", CrossEntropyRel); - + .set_num_inputs(2) + .add_argument("x", "1D Tensor", "Predictions.") + .add_argument("y", "1D Tensor", "Targets.") + .set_support_level(10) + .add_type_rel("CrossEntropy", CrossEntropyRel); // relay.nn.dilate TVM_REGISTER_NODE_TYPE(DilateAttrs); -bool DilateRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool DilateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* x = types[0].as(); @@ -1068,19 +967,16 @@ Expr MakeDilate(Expr data, Array strides) { return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate") -.set_body_typed(MakeDilate); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate").set_body_typed(MakeDilate); RELAY_REGISTER_OP("nn.dilate") -.describe(R"code( + .describe(R"code( Dilate data with zeros. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("x", "1D Tensor", "Data to dilate.") -.set_support_level(10) -.add_type_rel("Dilate", DilateRel); + .set_num_inputs(1) + .add_argument("x", "1D Tensor", "Data to dilate.") + .set_support_level(10) + .add_type_rel("Dilate", DilateRel); // Positional relay function to create cross_entropy_with_logits operator used by frontend FFI. Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { @@ -1088,21 +984,19 @@ Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { return Call(op, {predictions, targets}, Attrs(), {}); } - TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy_with_logits") -.set_body_typed(MakeCrossEntropyWithLogits); - + .set_body_typed(MakeCrossEntropyWithLogits); RELAY_REGISTER_OP("nn.cross_entropy_with_logits") -.describe(R"code( + .describe(R"code( Computes cross entropy given predictions and targets. Accept logits. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("x", "1D Tensor", "Predictions.") -.add_argument("y", "1D Tensor", "Targets.") -.set_support_level(10) -.add_type_rel("CrossEntropy", CrossEntropyRel); + .set_num_inputs(2) + .add_argument("x", "1D Tensor", "Predictions.") + .add_argument("y", "1D Tensor", "Targets.") + .set_support_level(10) + .add_type_rel("CrossEntropy", CrossEntropyRel); // Depth to space and space to depth TVM_REGISTER_NODE_TYPE(SubPixelAttrs); @@ -1130,8 +1024,7 @@ bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attr oshape.Set(3, oshape[3] * block_size); // Assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } @@ -1188,8 +1081,7 @@ bool SpaceToDepthRel(const Array& types, int num_inputs, const Attrs& attr oshape.Set(3, indexdiv(oshape[3], block_size)); // Assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index dc876e8..0fb0263 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -24,6 +24,11 @@ #ifndef TVM_RELAY_OP_NN_NN_H_ #define TVM_RELAY_OP_NN_NN_H_ +#include +#include +#include +#include + #include namespace tvm { @@ -58,8 +63,7 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, if (weight == nullptr) return false; Array wshape = weight->shape; CHECK(static_cast(weight->shape.size()) == 2); - CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], - weight->shape[1])) + CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1])) << "DenseRel: input dimension doesn't match," << " data shape=" << data->shape << ", weight shape=" << weight->shape; oshape.Set((oshape.size() - 1), wshape[0]); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index abff06e..e416a06 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -21,12 +21,14 @@ * \file pad.cc * \brief Implementation of operator pad */ +#include +#include +#include #include #include -#include -#include -#include + #include + #include "../op_common.h" namespace tvm { @@ -35,13 +37,11 @@ namespace relay { // relay.nn.pad TVM_REGISTER_NODE_TYPE(PadAttrs); -Array > PadInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +Array> PadInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. - PadAttrs *params = const_cast(attrs.as()); + PadAttrs* params = const_cast(attrs.as()); Layout ret; // If new_in_layouts are defined, this code tries to modify the layout. @@ -108,12 +108,10 @@ Array > PadInferCorrectLayout( } } - return Array >{{ret}, {ret}}; + return Array>{{ret}, {ret}}; } -bool PadRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool PadRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -124,28 +122,26 @@ bool PadRel(const Array& types, // check that pad widths match lengths CHECK(data->shape.size() == param->pad_width.size()) - << "There should be as many pad width pairs as shape dimensions " - << "but the shape has " << data->shape.size() << " dimensions " - << "and there are " << param->pad_width.size() << " pad width pairs."; + << "There should be as many pad width pairs as shape dimensions " + << "but the shape has " << data->shape.size() << " dimensions " + << "and there are " << param->pad_width.size() << " pad width pairs."; // each pad width element should be a pair of positive integers std::vector oshape; for (size_t i = 0; i < param->pad_width.size(); i++) { CHECK(param->pad_width[i].size() == 2) - << "Each pad width element should be a pair but at index " << i - << " there are " << param->pad_width[i].size() << " elements."; + << "Each pad width element should be a pair but at index " << i << " there are " + << param->pad_width[i].size() << " elements."; auto width1 = tir::as_const_int(param->pad_width[i][0]); auto width2 = tir::as_const_int(param->pad_width[i][1]); CHECK(width1 != nullptr); CHECK(width2 != nullptr); - CHECK(*width1 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width1 << "."; - CHECK(*width2 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width2 << "."; + CHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width1 << "."; + CHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width2 << "."; if (!data->shape[i].as()) { auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); @@ -155,21 +151,17 @@ bool PadRel(const Array& types, } } - reporter->Assign(types[1], TensorType(Array(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } -Array PadCompute(const Attrs& attrs, - const Array& inputs, +Array PadCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); auto pad_width = param->pad_width; - CHECK(pad_width.size() == inputs[0].ndim() && - pad_width[0].size() == 2) - << "Illegal pad_width"; + CHECK(pad_width.size() == inputs[0].ndim() && pad_width[0].size() == 2) << "Illegal pad_width"; Array pad_before; for (size_t i = 0; i < pad_width.size(); ++i) { pad_before.push_back(pad_width[i][0]); @@ -179,18 +171,13 @@ Array PadCompute(const Attrs& attrs, pad_after.push_back(pad_width[i][1]); } const auto* out_ttype = out_type.as(); - return Array{ topi::pad(inputs[0], pad_before, pad_after, - tvm::tir::make_const(out_ttype->dtype, param->pad_value), - "T_pad", - topi::kElementWise, - param->pad_mode) }; + return Array{topi::pad(inputs[0], pad_before, pad_after, + tvm::tir::make_const(out_ttype->dtype, param->pad_value), + "T_pad", topi::kElementWise, param->pad_mode)}; } // Handler to create a call to the padding op used by front-end FFI -Expr MakePad(Expr data, - Array > pad_width, - double pad_value, - std::string pad_mode) { +Expr MakePad(Expr data, Array> pad_width, double pad_value, std::string pad_mode) { auto attrs = make_object(); attrs->pad_value = pad_value; attrs->pad_width = std::move(pad_width); @@ -199,29 +186,25 @@ Expr MakePad(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.pad") -.set_body_typed(MakePad); +TVM_REGISTER_GLOBAL("relay.op.nn._make.pad").set_body_typed(MakePad); RELAY_REGISTER_OP("nn.pad") -.describe(R"code(Pad for n-D tensor. + .describe(R"code(Pad for n-D tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("Pad", PadRel) -.set_attr("FInferCorrectLayout", PadInferCorrectLayout) -.set_attr("TOpPattern", kInjective) -.set_attr("FTVMCompute", PadCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("Pad", PadRel) + .set_attr("FInferCorrectLayout", PadInferCorrectLayout) + .set_attr("TOpPattern", kInjective) + .set_attr("FTVMCompute", PadCompute); // relay.nn.mirror_pad TVM_REGISTER_NODE_TYPE(MirrorPadAttrs); -bool MirrorPadRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool MirrorPadRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -232,40 +215,37 @@ bool MirrorPadRel(const Array& types, // check that pad widths match lengths CHECK(data->shape.size() == param->pad_width.size()) - << "There should be as many pad width pairs as shape dimensions " - << "but the shape has " << data->shape.size() << " dimensions " - << "and there are " << param->pad_width.size() << " pad width pairs."; + << "There should be as many pad width pairs as shape dimensions " + << "but the shape has " << data->shape.size() << " dimensions " + << "and there are " << param->pad_width.size() << " pad width pairs."; // each pad width element should be a pair of positive integers std::vector oshape; for (size_t i = 0; i < param->pad_width.size(); i++) { CHECK(param->pad_width[i].size() == 2) - << "Each pad width element should be a pair but at index " << i - << " there are " << param->pad_width[i].size() << " elements."; + << "Each pad width element should be a pair but at index " << i << " there are " + << param->pad_width[i].size() << " elements."; auto width1 = tir::as_const_int(param->pad_width[i][0]); auto width2 = tir::as_const_int(param->pad_width[i][1]); CHECK(width1 != nullptr); CHECK(width2 != nullptr); - CHECK(*width1 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width1 << "."; - CHECK(*width2 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width2 << "."; + CHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width1 << "."; + CHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width2 << "."; auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); oshape.push_back(data->shape[i] + padding); } - reporter->Assign(types[1], TensorType(Array(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } // Handler to create a call to the padding op used by front-end FFI -Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mode) { +Expr MakeMirrorPad(Expr data, Array> pad_width, std::string mode) { auto attrs = make_object(); attrs->mode = mode; attrs->pad_width = std::move(pad_width); @@ -273,19 +253,18 @@ Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mo return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad") -.set_body_typed(MakeMirrorPad); +TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad").set_body_typed(MakeMirrorPad); RELAY_REGISTER_OP("nn.mirror_pad") -.describe(R"code(MirrorPad for n-D tensor. + .describe(R"code(MirrorPad for n-D tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MirrorPad", MirrorPadRel) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MirrorPad", MirrorPadRel) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index c20793d..dd64951 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -21,12 +21,14 @@ * \file pooling.cc * \brief Pooling operators */ -#include +#include +#include #include #include -#include -#include +#include + #include + #include "../../transforms/infer_layout_util.h" namespace tvm { @@ -37,13 +39,12 @@ TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); template -Array > PoolInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +Array > PoolInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. - T *params = const_cast(attrs.as()); + T* params = const_cast(attrs.as()); if (new_in_layouts.defined()) { // Set the pool with the new layout. @@ -56,12 +57,8 @@ Array > PoolInferCorrectLayout( } template -Expr MakeMaxPool(Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, +Expr MakeMaxPool(Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode, std::string op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); @@ -74,14 +71,9 @@ Expr MakeMaxPool(Expr data, } template -Expr MakeAvgPool(Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad, - std::string op_name) { +Expr MakeAvgPool(Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode, + bool count_include_pad, std::string op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -94,9 +86,7 @@ Expr MakeAvgPool(Expr data, } template -bool Pool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Pool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -112,8 +102,7 @@ bool Pool2DRel(const Array& types, Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool2D layout must have H and W, which cannot be split"; + << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split"; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); @@ -140,8 +129,9 @@ bool Pool2DRel(const Array& types, oshape[hidx] = dshape[hidx]; } else { if (param->ceil_mode) { - oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + - param->strides[0] - 1) / param->strides[0]) + 1; + oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + param->strides[0] - 1) / + param->strides[0]) + + 1; } else { oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1; } @@ -150,8 +140,9 @@ bool Pool2DRel(const Array& types, oshape[widx] = dshape[widx]; } else { if (param->ceil_mode) { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + - param->strides[1] - 1) / param->strides[1]) + 1; + oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + param->strides[1] - 1) / + param->strides[1]) + + 1; } else { oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1; } @@ -162,9 +153,8 @@ bool Pool2DRel(const Array& types, return true; } -template -Array Pool2DCompute(const Attrs& attrs, - const Array& inputs, +template +Array Pool2DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); @@ -182,9 +172,7 @@ Array Pool2DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool2d does not support input split on width"; - CHECK(inputs[0].ndim() == 4U || - inputs[0].ndim() == 5U || - inputs[0].ndim() == 6U) + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) << "Pool2D only support 4-D input (e.g., NCHW)" << " or 5-D input (e.g. NCHWc on for vector instructions)" << " or 6-D input (e.g. NCHWnc for tensor accelerators)"; @@ -199,30 +187,23 @@ Array Pool2DCompute(const Attrs& attrs, } if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{ - topi::nn::pool(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode, + layout.name(), count_include_pad)}; } else { return Array{ - topi::nn::pool(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool2d"); -}); - + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool2d"); + }); RELAY_REGISTER_OP("nn.max_pool2d") -.describe(R"code(Max pooling operation for two dimensional data. + .describe(R"code(Max pooling operation for two dimensional data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -242,30 +223,25 @@ RELAY_REGISTER_OP("nn.max_pool2d") equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool2D", Pool2DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool2DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool2D", Pool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool2DCompute); // AvgPool2D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool2d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool2d"); + }); RELAY_REGISTER_OP("nn.avg_pool2d") -.describe(R"code( + .describe(R"code( Average pooling operation for one dimensional data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape @@ -286,24 +262,24 @@ Average pooling operation for one dimensional data. equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AvgPool2D", Pool2DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool2DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("AvgPool2D", Pool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool2DCompute); // relay.nn.global_pool_2d & relay.nn.max_pool_2d TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs); -bool GlobalPool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GlobalPool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - if (data == nullptr) { return false; } + if (data == nullptr) { + return false; + } const auto dshape = data->shape; CHECK_GE(dshape.size(), 2U) << "Pool2D only support input >= 2-D: input must have height and width"; @@ -313,8 +289,7 @@ bool GlobalPool2DRel(const Array& types, Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool2D layout must have H and W, which cannot be split"; + << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split"; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); @@ -327,44 +302,38 @@ bool GlobalPool2DRel(const Array& types, return true; } - -template -Array GlobalPool2DCompute(const Attrs& attrs, - const Array& inputs, +template +Array GlobalPool2DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) - << "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; + << "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "global_avg_pool2d does not support input split on height"; + << "global_avg_pool2d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "global_avg_pool2d does not support input split on width"; + << "global_avg_pool2d does not support input split on width"; CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) - << "Pool2D only support 4-D input (e.g., NCHW)" - << " or 5-D input (last dimension is a split of channel)"; - return Array{ - topi::nn::global_pool(inputs[0], mode, layout.name()) }; + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + return Array{topi::nn::global_pool(inputs[0], mode, layout.name())}; } -Expr MakeGlobalAvgPool2D(Expr data, - std::string layout) { +Expr MakeGlobalAvgPool2D(Expr data, std::string layout) { auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_avg_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.global_avg_pool2d") -.set_body_typed(MakeGlobalAvgPool2D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.global_avg_pool2d").set_body_typed(MakeGlobalAvgPool2D); // GlobalAvgPool RELAY_REGISTER_OP("nn.global_avg_pool2d") -.describe(R"code(Global average pooling operation for 2D data. + .describe(R"code(Global average pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -372,30 +341,26 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") (batch_size, channels, 1, 1) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", GlobalPool2DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", GlobalPool2DCompute); // GlobalMaxPool -Expr MakeGlobalMaxPool2D(Expr data, - std::string layout) { +Expr MakeGlobalMaxPool2D(Expr data, std::string layout) { auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_max_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d") -.set_body_typed(MakeGlobalMaxPool2D); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d").set_body_typed(MakeGlobalMaxPool2D); RELAY_REGISTER_OP("nn.global_max_pool2d") -.describe(R"code(Global max pooling operation for 2D data. + .describe(R"code(Global max pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -403,44 +368,40 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") (batch_size, channels, 1, 1) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", GlobalPool2DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", GlobalPool2DCompute); // relay.nn.adaptive_pool_2d TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs); -bool AdaptivePool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool AdaptivePool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - if (data == nullptr) { return false; } + if (data == nullptr) { + return false; + } const auto dshape = data->shape; CHECK_GE(dshape.size(), 2U) - << "Pool2D only support input >= 2-D: input must have height and width"; + << "Pool2D only support input >= 2-D: input must have height and width"; const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool2D layout must have H and W, which cannot be split"; + << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split"; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); Array oshape(dshape); auto output_size = param->output_size; - CHECK_LE(output_size.size(), 2U) - << "output_size can have up to 2 elements."; + CHECK_LE(output_size.size(), 2U) << "output_size can have up to 2 elements."; IndexExpr output_height, output_width; if (output_size.empty()) { output_height = dshape[hidx]; @@ -461,24 +422,23 @@ bool AdaptivePool2DRel(const Array& types, return true; } -template -Array AdaptivePool2DCompute(const Attrs& attrs, - const Array& inputs, +template +Array AdaptivePool2DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) - << "Adaptive pool2d currently only supports layouts that are convertible from NCHW"; + << "Adaptive pool2d currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "Adaptive pool2d does not support input split on height"; + << "Adaptive pool2d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "Adaptive pool2d does not support input split on width"; + << "Adaptive pool2d does not support input split on width"; CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) - << "Pool2D only support 4-D input (e.g., NCHW)" - << " or 5-D input (last dimension is a split of channel)"; + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; auto output_size = param->output_size; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); @@ -494,15 +454,12 @@ Array AdaptivePool2DCompute(const Attrs& attrs, output_height = output_size[0]; output_width = output_size[1]; } - return Array{ - topi::nn::adaptive_pool(inputs[0], Array{ output_height, output_width }, - mode, layout.name()) }; + return Array{topi::nn::adaptive_pool( + inputs[0], Array{output_height, output_width}, mode, layout.name())}; } // relay.nn.adaptive_avg_pool2d -Expr MakeAdaptiveAvgPool2D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveAvgPool2D(Expr data, Array output_size, std::string layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -510,11 +467,10 @@ Expr MakeAdaptiveAvgPool2D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d") -.set_body_typed(MakeAdaptiveAvgPool2D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d").set_body_typed(MakeAdaptiveAvgPool2D); RELAY_REGISTER_OP("nn.adaptive_avg_pool2d") - .describe(R"code(Adaptive average pooling operation for 2D data. + .describe(R"code(Adaptive average pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -528,19 +484,17 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool2d") (batch_size, channels, output_height, output_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool2DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool2DCompute); // relay.nn.adaptive_max_pool2d -Expr MakeAdaptiveMaxPool2D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveMaxPool2D(Expr data, Array output_size, std::string layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -548,11 +502,10 @@ Expr MakeAdaptiveMaxPool2D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d") -.set_body_typed(MakeAdaptiveMaxPool2D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d").set_body_typed(MakeAdaptiveMaxPool2D); RELAY_REGISTER_OP("nn.adaptive_max_pool2d") - .describe(R"code(Adaptive max pooling operation for 2D data. + .describe(R"code(Adaptive max pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -566,45 +519,43 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool2d") (batch_size, channels, output_height, output_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool2DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool2DCompute); TVM_REGISTER_NODE_TYPE(AdaptivePool3DAttrs); -bool AdaptivePool3DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool AdaptivePool3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - if (data == nullptr) { return false; } + if (data == nullptr) { + return false; + } const auto dshape = data->shape; CHECK_GE(dshape.size(), 3U) - << "Pool3D only support input >= 3-D: input must have depth, height and width"; + << "Pool3D only support input >= 3-D: input must have depth, height and width"; const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && - !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool3D layout must have D, H and W, which cannot be split"; + !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) + << "Invalid layout " << layout + << ". Pool3D layout must have D, H and W, which cannot be split"; const auto didx = layout.IndexOf(LayoutAxis::Get('D')); const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); Array oshape(dshape); auto output_size = param->output_size; - CHECK_LE(output_size.size(), 3U) - << "output_size can have up to 3 elements."; + CHECK_LE(output_size.size(), 3U) << "output_size can have up to 3 elements."; IndexExpr output_depth, output_height, output_width; if (output_size.empty()) { output_depth = dshape[didx]; @@ -629,26 +580,25 @@ bool AdaptivePool3DRel(const Array& types, return true; } -template -Array AdaptivePool3DCompute(const Attrs& attrs, - const Array& inputs, +template +Array AdaptivePool3DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCDHW("NCDHW"); const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(tir::BijectiveLayout(layout, kNCDHW).defined()) - << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW"; + << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1) - << "Adaptive pool3d does not support input split on depth"; + << "Adaptive pool3d does not support input split on depth"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "Adaptive pool3d does not support input split on height"; + << "Adaptive pool3d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "Adaptive pool3d does not support input split on width"; + << "Adaptive pool3d does not support input split on width"; CHECK(inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) - << "Pool3D only support 5-D input (e.g., NCDHW)" - << " or 6-D input (last dimension is a split of channel)"; + << "Pool3D only support 5-D input (e.g., NCDHW)" + << " or 6-D input (last dimension is a split of channel)"; auto output_size = param->output_size; const auto didx = layout.IndexOf(LayoutAxis::Get('D')); @@ -669,16 +619,12 @@ Array AdaptivePool3DCompute(const Attrs& attrs, output_width = output_size[2]; } - auto osize = Array{ output_depth, output_height, output_width }; - return Array { - topi::nn::adaptive_pool3d(inputs[0], osize, mode, layout.name()) - }; + auto osize = Array{output_depth, output_height, output_width}; + return Array{topi::nn::adaptive_pool3d(inputs[0], osize, mode, layout.name())}; } // relay.nn.adaptive_max_pool3d -Expr MakeAdaptiveMaxPool3D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveMaxPool3D(Expr data, Array output_size, std::string layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -686,11 +632,10 @@ Expr MakeAdaptiveMaxPool3D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d") -.set_body_typed(MakeAdaptiveMaxPool3D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d").set_body_typed(MakeAdaptiveMaxPool3D); RELAY_REGISTER_OP("nn.adaptive_max_pool3d") - .describe(R"code(Adaptive max pooling operation for 3D data. + .describe(R"code(Adaptive max pooling operation for 3D data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. @@ -704,19 +649,17 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool3d") (batch_size, channels, output_depth, output_height, output_width) if `layout` is `NCDHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveMaxPool3D", AdaptivePool3DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool3DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveMaxPool3D", AdaptivePool3DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool3DCompute); // relay.nn.adaptive_max_pool3d -Expr MakeAdaptiveAvgPool3D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveAvgPool3D(Expr data, Array output_size, std::string layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -724,11 +667,10 @@ Expr MakeAdaptiveAvgPool3D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d") -.set_body_typed(MakeAdaptiveAvgPool3D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d").set_body_typed(MakeAdaptiveAvgPool3D); RELAY_REGISTER_OP("nn.adaptive_avg_pool3d") - .describe(R"code(Adaptive avg pooling operation for 3D data. + .describe(R"code(Adaptive avg pooling operation for 3D data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. - **output_size**: If this argument is not provided, input depth, height and width will be used @@ -740,15 +682,14 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool3d") - **out**: This depends on the `layout` parameter. Output is 5D array of shape (batch_size, channels, output_depth, output_height, output_width) if `layout` is `NCDHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveAvgPool3D", AdaptivePool3DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool3DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveAvgPool3D", AdaptivePool3DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool3DCompute); bool Pool2DGradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -763,8 +704,7 @@ bool Pool2DGradRel(const Array& types, int num_inputs, const Attrs& attrs, } template -Array Pool2DGradCompute(const Attrs& attrs, - const Array& inputs, +Array Pool2DGradCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); @@ -802,17 +742,18 @@ Array Pool2DGradCompute(const Attrs& attrs, if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; return Array{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + mode, ceil_mode, layout.name(), + count_include_pad)}; } else { return Array{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + mode, ceil_mode, layout.name())}; } } - // MaxPool2DGrad Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, - Array strides, Array padding, std::string layout, bool ceil_mode) { + Array strides, Array padding, std::string layout, + bool ceil_mode) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -825,7 +766,6 @@ Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad); - RELAY_REGISTER_OP("nn.max_pool2d_grad") .describe(R"code(Gradient of max pooling operation for two dimensional data. @@ -849,18 +789,17 @@ RELAY_REGISTER_OP("nn.max_pool2d_grad") (batch_size, channels, height, width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool2DGrad", Pool2DGradRel) -.set_attr("FTVMCompute", Pool2DGradCompute); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool2DGrad", Pool2DGradRel) + .set_attr("FTVMCompute", Pool2DGradCompute); // AvgPool2DGrad Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, - Array strides, Array padding, std::string layout, bool ceil_mode, - bool count_include_pad) { + Array strides, Array padding, std::string layout, + bool ceil_mode, bool count_include_pad) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -874,7 +813,6 @@ Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad); - RELAY_REGISTER_OP("nn.avg_pool2d_grad") .describe(R"code(Gradient of average pooling operation for two dimensional data. @@ -898,22 +836,19 @@ RELAY_REGISTER_OP("nn.avg_pool2d_grad") (batch_size, channels, height, width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool2DGrad", Pool2DGradRel) -.set_attr("FTVMCompute", Pool2DGradCompute); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool2DGrad", Pool2DGradRel) + .set_attr("FTVMCompute", Pool2DGradCompute); // relay.nn.max_pool1d & relay.nn.avg_pool1d TVM_REGISTER_NODE_TYPE(MaxPool1DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool1DAttrs); template -bool Pool1DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Pool1DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -921,15 +856,13 @@ bool Pool1DRel(const Array& types, if (data == nullptr) return false; const auto dshape = data->shape; - CHECK_GE(dshape.size(), 1U) - << "Pool1D only support input >= 1-D: input must have width"; + CHECK_GE(dshape.size(), 1U) << "Pool1D only support input >= 1-D: input must have width"; const auto param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool1D layout must have W, which cannot be split"; + << "Invalid layout " << layout << ". Pool1D layout must have W, which cannot be split"; const auto widx = layout.IndexOf(LayoutAxis::Get('W')); @@ -949,8 +882,9 @@ bool Pool1DRel(const Array& types, oshape[widx] = dshape[widx]; } else { if (param->ceil_mode) { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] + - param->strides[0] - 1) / param->strides[0]) + 1; + oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] + param->strides[0] - 1) / + param->strides[0]) + + 1; } else { oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0]) / param->strides[0]) + 1; } @@ -961,10 +895,8 @@ bool Pool1DRel(const Array& types, return true; } - -template -Array Pool1DCompute(const Attrs& attrs, - const Array& inputs, +template +Array Pool1DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCW("NCW"); const auto* param = attrs.as(); @@ -980,9 +912,7 @@ Array Pool1DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool1d does not support input split on width"; - CHECK(inputs[0].ndim() == 3U || - inputs[0].ndim() == 4U || - inputs[0].ndim() == 5U) + CHECK(inputs[0].ndim() == 3U || inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) << "Pool1D only support 3-D input (e.g., NCW)" << " or 4-D input (e.g. NCWc on for vector instructions)" << " or 5-D input (e.g. NCWnc for tensor accelerators)"; @@ -993,29 +923,23 @@ Array Pool1DCompute(const Attrs& attrs, if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{ - topi::nn::pool1d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode, + ceil_mode, layout.name(), count_include_pad)}; } else { return Array{ - topi::nn::pool1d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool1d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool1d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool1d"); + }); RELAY_REGISTER_OP("nn.max_pool1d") -.describe(R"code(Max pooling operation for one dimensional data. + .describe(R"code(Max pooling operation for one dimensional data. - **data**: This depends on the `layout` parameter. Input is 3D array of shape (batch_size, channels, width) if `layout` is `NCW`. @@ -1033,30 +957,25 @@ RELAY_REGISTER_OP("nn.max_pool1d") equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool1D", Pool1DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool1DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool1D", Pool1DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool1DCompute); // AvgPool1D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool1d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool1d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool1d"); + }); RELAY_REGISTER_OP("nn.avg_pool1d") -.describe(R"code( + .describe(R"code( Average pooling operation for one dimensional data. - **data**: This depends on the `layout` parameter. Input is 3D array of shape @@ -1075,23 +994,20 @@ Average pooling operation for one dimensional data. equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AvgPool1D", Pool1DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool1DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("AvgPool1D", Pool1DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool1DCompute); // relay.nn.max_pool3d & relay.nn.avg_pool3d TVM_REGISTER_NODE_TYPE(MaxPool3DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool3DAttrs); template -bool Pool3DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Pool3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -1108,8 +1024,8 @@ bool Pool3DRel(const Array& types, CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool3D layout must have D, H and W, which cannot be split"; + << "Invalid layout " << layout + << ". Pool3D layout must have D, H and W, which cannot be split"; const auto didx = layout.IndexOf(LayoutAxis::Get('D')); const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); @@ -1143,8 +1059,9 @@ bool Pool3DRel(const Array& types, oshape[ii] = dshape[ii]; } else { if (param->ceil_mode) { - oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i] + - param->strides[i] - 1) / param->strides[i]) + 1; + oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i] + param->strides[i] - 1) / + param->strides[i]) + + 1; } else { oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i]) / param->strides[i]) + 1; } @@ -1156,10 +1073,8 @@ bool Pool3DRel(const Array& types, return true; } - -template -Array Pool3DCompute(const Attrs& attrs, - const Array& inputs, +template +Array Pool3DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCDHW("NCDHW"); const auto* param = attrs.as(); @@ -1179,9 +1094,7 @@ Array Pool3DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool3d does not support input split on width"; - CHECK(inputs[0].ndim() == 4U || - inputs[0].ndim() == 5U || - inputs[0].ndim() == 6U) + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) << "Pool3D only support 5-D input (e.g., NCDHW)" << " or 6-D input (e.g. NCDHWc on for vector instructions)" << " or 7-D input (e.g. NCDHWnc for tensor accelerators)"; @@ -1197,29 +1110,23 @@ Array Pool3DCompute(const Attrs& attrs, } if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{ - topi::nn::pool3d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode, + ceil_mode, layout.name(), count_include_pad)}; } else { return Array{ - topi::nn::pool3d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool3d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool3d"); + }); RELAY_REGISTER_OP("nn.max_pool3d") -.describe(R"code(Max pooling operation for three dimensional data. + .describe(R"code(Max pooling operation for three dimensional data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. @@ -1240,30 +1147,25 @@ RELAY_REGISTER_OP("nn.max_pool3d") equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool3D", Pool3DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool3DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool3D", Pool3DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool3DCompute); // AvgPool3D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool3d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool3d"); + }); RELAY_REGISTER_OP("nn.avg_pool3d") -.describe(R"code( + .describe(R"code( Average pooling operation for three dimensional data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape @@ -1285,13 +1187,13 @@ Average pooling operation for three dimensional data. equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AvgPool3D", Pool3DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool3DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("AvgPool3D", Pool3DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool3DCompute); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index c761c3f..0aca00c 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -22,9 +22,10 @@ * \brief Property def of nn.sparse_dense operator. */ -#include -#include #include +#include +#include + #include #include "../../transforms/infer_layout_util.h" @@ -53,9 +54,8 @@ bool SparseDenseRel(const Array& types, int num_inputs, const Attrs& attrs if (weight_data->shape.size() == 3) { // BSR case. - Array oshape({ - data->shape[0], - (weight_indptr->shape[0] - 1) * weight_data->shape[1]}); + Array oshape( + {data->shape[0], (weight_indptr->shape[0] - 1) * weight_data->shape[1]}); reporter->Assign(types[4], TensorType(oshape, data->dtype)); return true; } @@ -71,32 +71,32 @@ Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weig } TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeSparseDense, args, rv); -}); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeSparseDense, args, rv); + }); RELAY_REGISTER_OP("nn.sparse_dense") -.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse. + .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse. - **data**: `(x1, x2, ..., xn, input_dim)` - **weight**: `(units, input_dim)` - **out**: `(x1, x2, ..., xn, units)`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(4) -.add_argument("data", "nD Tensor", "Input data.") -.add_argument("weight_data", "1D Tensor", "Weight data matrix.") -.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.") -.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.") -.set_support_level(1) -.add_type_rel("SparseDense", SparseDenseRel); + .set_attrs_type() + .set_num_inputs(4) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("weight_data", "1D Tensor", "Weight data matrix.") + .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.") + .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.") + .set_support_level(1) + .add_type_rel("SparseDense", SparseDenseRel); // relay.nn.sparse_transpose TVM_REGISTER_NODE_TYPE(SparseTransposeAttrs); bool SparseTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* sparse_data = types[0].as(); CHECK_EQ(sparse_data->shape.size(), 1); @@ -119,24 +119,22 @@ Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indp return Call(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose") -.set_body_typed(MakeSparseTranspose); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose").set_body_typed(MakeSparseTranspose); RELAY_REGISTER_OP("nn.sparse_transpose") -.describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix + .describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix - **input**: `(N, N)` - **out**: `(N, N)`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("sparse_data", "1D Tensor", "Sparse data matrix.") -.add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.") -.add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.") -.set_support_level(1) -.add_type_rel("SparseTranspose", SparseTransposeRel); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("sparse_data", "1D Tensor", "Sparse data matrix.") + .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.") + .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.") + .set_support_level(1) + .add_type_rel("SparseTranspose", SparseTransposeRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 63bd42d..7f5e683 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -21,11 +21,13 @@ * \file upsampling.cc * \brief upsampling operator */ -#include -#include #include +#include #include +#include + #include + #include "../op_common.h" namespace tvm { @@ -35,13 +37,12 @@ TVM_REGISTER_NODE_TYPE(UpSamplingAttrs); TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs); template -Array > UpsamplingInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +Array > UpsamplingInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. - T *params = const_cast(attrs.as()); + T* params = const_cast(attrs.as()); if (new_in_layouts.defined()) { CHECK_EQ(new_in_layouts.size(), 1); @@ -49,12 +50,12 @@ Array > UpsamplingInferCorrectLayout( Layout raw_layout(params->layout); Layout input = new_in_layouts[0]; if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) && - input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && - !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))&& + input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && + !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) && (input.IndexOf(LayoutAxis::Get('D')) == -1 || - (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) && - !input.Contains(LayoutAxis::Get('d'))))) { - params->layout = input.name(); // modify self to follow the input layout + (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) && + !input.Contains(LayoutAxis::Get('d'))))) { + params->layout = input.name(); // modify self to follow the input layout } } @@ -62,9 +63,7 @@ Array > UpsamplingInferCorrectLayout( return Array >{{inferred_layout}, {inferred_layout}}; } -bool UpSamplingRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -78,29 +77,22 @@ bool UpSamplingRel(const Array& types, auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) - << "UpSampling only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "UpSampling only support input layouts that are convertible from NCHW." + << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); // assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } - // Positional relay function to create upsampling operator // used by frontend FFI. -Expr MakeUpSampling(Expr data, - double scale_h, - double scale_w, - std::string layout, - std::string method, - bool align_corners) { +Expr MakeUpSampling(Expr data, double scale_h, double scale_w, std::string layout, + std::string method, bool align_corners) { auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); @@ -111,12 +103,11 @@ Expr MakeUpSampling(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling") -.set_body_typed(MakeUpSampling); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling").set_body_typed(MakeUpSampling); RELAY_REGISTER_OP("nn.upsampling") -.describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. + .describe( + R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. - **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW @@ -130,20 +121,17 @@ RELAY_REGISTER_OP("nn.upsampling") (batch_size, in_height*scale, in_width*scale, channels) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("UpSampling", UpSamplingRel) -.set_attr("FInferCorrectLayout", - UpsamplingInferCorrectLayout) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("UpSampling", UpSamplingRel) + .set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) + .set_attr("TOpPattern", kInjective); // UpSampling3D -bool UpSampling3DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool UpSampling3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -157,8 +145,8 @@ bool UpSampling3DRel(const Array& types, auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); CHECK(layout_converter.defined()) - << "UpSampling3D only support input layouts that are convertible from NCDHW." - << " But got " << in_layout; + << "UpSampling3D only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); @@ -166,21 +154,14 @@ bool UpSampling3DRel(const Array& types, oshape.Set(4, tir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); // assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } // Positional relay function to create upsampling3d operator // used by frontend FFI. -Expr MakeUpSampling3D(Expr data, - double scale_d, - double scale_h, - double scale_w, - std::string layout, - std::string method, - std::string coordinate_transformation_mode) { +Expr MakeUpSampling3D(Expr data, double scale_d, double scale_h, double scale_w, std::string layout, + std::string method, std::string coordinate_transformation_mode) { auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); @@ -192,12 +173,10 @@ Expr MakeUpSampling3D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d") -.set_body_typed(MakeUpSampling3D); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d").set_body_typed(MakeUpSampling3D); RELAY_REGISTER_OP("nn.upsampling3d") -.describe(R"code(Perform upsampling on input array with nearest neighbour or + .describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. - **data**: data is 5D array of shape @@ -212,14 +191,14 @@ bilinear interpolation. (batch_size, in_depth*scale, in_height*scale, in_width*scale, channels) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("UpSampling3D", UpSampling3DRel) -.set_attr("FInferCorrectLayout", - UpsamplingInferCorrectLayout) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("UpSampling3D", UpSampling3DRel) + .set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 2d89d77..b560aa3 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -28,11 +28,13 @@ #include #include #include -#include + #include #include -#include "type_relations.h" +#include + #include "../transforms/infer_layout_util.h" +#include "type_relations.h" namespace tvm { namespace relay { @@ -47,21 +49,18 @@ namespace relay { * \param OpName the name of registry. */ -#define RELAY_REGISTER_UNARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr data) { \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {data}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") \ - .add_type_rel("Identity", IdentityRel) \ - .set_attr("TOpPattern", kElemWise) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - ElemwiseArbitraryLayout) \ - +#define RELAY_REGISTER_UNARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") \ + .add_type_rel("Identity", IdentityRel) \ + .set_attr("TOpPattern", kElemWise) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) /*! Quick helper macro * - Expose a positional make function to construct the node. @@ -73,42 +72,37 @@ namespace relay { * * \param OpName the name of registry. */ -#define RELAY_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("Broadcast", BroadcastRel) \ - .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - BinaryBroadcastLayout) +#define RELAY_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("Broadcast", BroadcastRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", BinaryBroadcastLayout) // Comparisons -#define RELAY_REGISTER_CMP_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("BroadcastComp", BroadcastCompRel) \ - .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - BinaryBroadcastLayout) - +#define RELAY_REGISTER_CMP_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("BroadcastComp", BroadcastCompRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", BinaryBroadcastLayout) /*! \brief A helper class for matching and rewriting operators. */ -template +template class OpMatch { public: using MatchFunc = @@ -157,8 +151,7 @@ inline void GetPaddingWidth(const Array& padding, IndexExpr* pad_w) { } else if (padding.size() == 2) { *pad_w = padding[0] + padding[1]; } else { - CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " - << padding.size(); + CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " << padding.size(); } } @@ -175,8 +168,7 @@ inline void GetPaddingHeightWidth(const Array& padding, IndexExpr* pa *pad_h = padding[0] + padding[2]; *pad_w = padding[1] + padding[3]; } else { - CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " - << padding.size(); + CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " << padding.size(); } } @@ -196,8 +188,7 @@ inline void GetPaddingDepthHeightWidth(const Array& padding, IndexExp *pad_h = padding[1] + padding[4]; *pad_w = padding[2] + padding[5]; } else { - CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " - << padding.size(); + CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " << padding.size(); } } diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc index 0f47c9a..026dfc2 100644 --- a/src/relay/op/tensor/binary.cc +++ b/src/relay/op/tensor/binary.cc @@ -21,166 +21,145 @@ * \file binary.cc * \brief binary broadcast operators. */ +#include #include #include -#include -#include "../type_relations.h" + #include "../op_common.h" +#include "../type_relations.h" namespace tvm { namespace relay { -#define RELAY_BINARY_COMPUTE(FTOPI) \ - [] (const Attrs& attrs, \ - const Array& inputs, \ - const Type& out_type) -> Array { \ - CHECK_EQ(inputs.size(), 2U); \ - return {FTOPI(inputs[0], inputs[1])}; \ - } \ +#define RELAY_BINARY_COMPUTE(FTOPI) \ + [](const Attrs& attrs, const Array& inputs, \ + const Type& out_type) -> Array { \ + CHECK_EQ(inputs.size(), 2U); \ + return {FTOPI(inputs[0], inputs[1])}; \ + } // Addition RELAY_REGISTER_BINARY_OP("add") -.describe("Elementwise add with with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add)); + .describe("Elementwise add with with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add)); // Subtraction RELAY_REGISTER_BINARY_OP("subtract") -.describe("Elementwise substract with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract)); + .describe("Elementwise substract with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract)); // Right shift RELAY_REGISTER_BINARY_OP("right_shift") -.describe("Elementwise right shift with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift)); - + .describe("Elementwise right shift with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift)); RELAY_REGISTER_BINARY_OP("left_shift") -.describe("Elementwise left shift with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift)); - + .describe("Elementwise left shift with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift)); RELAY_REGISTER_BINARY_OP("maximum") -.describe("Elementwise maximum of two tensors with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum)); - + .describe("Elementwise maximum of two tensors with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum)); RELAY_REGISTER_BINARY_OP("minimum") -.describe("Elementwise minimum of two tensors with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum)); - + .describe("Elementwise minimum of two tensors with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum)); RELAY_REGISTER_BINARY_OP("divide") -.describe("Elementwise divide with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide)); - + .describe("Elementwise divide with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide)); RELAY_REGISTER_BINARY_OP("floor_divide") -.describe("Elementwise floor divide with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_divide)); - + .describe("Elementwise floor divide with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_divide)); RELAY_REGISTER_BINARY_OP("multiply") -.describe("Elementwise multiply with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply)); - + .describe("Elementwise multiply with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply)); RELAY_REGISTER_BINARY_OP("power") -.describe("Elementwise power with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); - + .describe("Elementwise power with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); RELAY_REGISTER_BINARY_OP("mod") -.describe("Elementwise mod with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); - + .describe("Elementwise mod with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); RELAY_REGISTER_BINARY_OP("floor_mod") - .describe("Elementwise floor mod with broadcasting") - .set_support_level(1) - .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_mod)); - + .describe("Elementwise floor mod with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_mod)); RELAY_REGISTER_BINARY_OP("logical_and") -.describe("Elementwise logical AND with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and)); - + .describe("Elementwise logical AND with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and)); RELAY_REGISTER_BINARY_OP("logical_or") -.describe("Elementwise logical OR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or)); - + .describe("Elementwise logical OR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or)); RELAY_REGISTER_BINARY_OP("logical_xor") -.describe("Elementwise logical XOR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor)); - + .describe("Elementwise logical XOR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor)); RELAY_REGISTER_BINARY_OP("bitwise_and") -.describe("Elementwise bitwise AND with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and)); - + .describe("Elementwise bitwise AND with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and)); RELAY_REGISTER_BINARY_OP("bitwise_or") -.describe("Elementwise bitwise OR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or)); - + .describe("Elementwise bitwise OR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or)); RELAY_REGISTER_BINARY_OP("bitwise_xor") -.describe("Elementwise bitwise XOR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor)); - + .describe("Elementwise bitwise XOR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor)); RELAY_REGISTER_CMP_OP("equal") -.describe("Elementwise equal compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal)); - + .describe("Elementwise equal compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal)); RELAY_REGISTER_CMP_OP("not_equal") -.describe("Elementwise not equal with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal)); - + .describe("Elementwise not equal with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal)); RELAY_REGISTER_CMP_OP("less") -.describe("Elementwise less than with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less)); - + .describe("Elementwise less than with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less)); RELAY_REGISTER_CMP_OP("less_equal") -.describe("Elementwise less than or equal compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal)); - + .describe("Elementwise less than or equal compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal)); RELAY_REGISTER_CMP_OP("greater") -.describe("Elementwise greater than compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater)); - + .describe("Elementwise greater than compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater)); RELAY_REGISTER_CMP_OP("greater_equal") -.describe("Elementwise greater than or equal compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal)); + .describe("Elementwise greater than or equal compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal)); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 3f220fb..d526cef 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -21,13 +21,15 @@ * \file reduce.cc * \brief Reduction operators. */ -#include -#include -#include #include #include -#include +#include +#include +#include + #include +#include + #include "../op_common.h" #include "../type_relations.h" @@ -37,14 +39,13 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ReduceAttrs); /*! -* \brief GetReduceAxes, get the new axis from indim and other arguments -* \param indim Number of dimensions of input data. -* \param axis The input axis vector. -* \param exclude Whether 'axis' input given is the excluded axis. -* \return r_axes The new reduced axes of the output. -*/ -inline std::vector GetReduceAxes(const uint32_t indim, - const Array& inaxis, + * \brief GetReduceAxes, get the new axis from indim and other arguments + * \param indim Number of dimensions of input data. + * \param axis The input axis vector. + * \param exclude Whether 'axis' input given is the excluded axis. + * \return r_axes The new reduced axes of the output. + */ +inline std::vector GetReduceAxes(const uint32_t indim, const Array& inaxis, bool exclude) { if (!inaxis.defined()) { std::vector r_axes(indim); @@ -60,16 +61,13 @@ inline std::vector GetReduceAxes(const uint32_t indim, } // Check out of bounds error - CHECK(axis >= 0) - << "Axis out of bounds in reduce operator."; - CHECK(axis < indim) - << "Axis out of bounds in reduce operator."; + CHECK(axis >= 0) << "Axis out of bounds in reduce operator."; + CHECK(axis < indim) << "Axis out of bounds in reduce operator."; in_axes.push_back(axis); } CHECK(in_axes[in_axes.size() - 1] < indim) - << "Reduction axis " << in_axes[in_axes.size() - 1] - << " exceeds input dimensions " << indim; + << "Reduction axis " << in_axes[in_axes.size() - 1] << " exceeds input dimensions " << indim; std::sort(in_axes.begin(), in_axes.end()); @@ -81,18 +79,16 @@ inline std::vector GetReduceAxes(const uint32_t indim, std::vector r_axes(r_size); for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) { if (j < in_axes.size() && in_axes[j] == i) { - ++j; - continue; + ++j; + continue; } r_axes[k++] = i; } return r_axes; } - // Get axis under exclude condition. -Array GetExcludeAxes(size_t indim, - const Array& inaxis) { +Array GetExcludeAxes(size_t indim, const Array& inaxis) { CHECK(inaxis.defined()) << "Cannot set exclude when axis=None"; std::vector axis_flag(indim, true); for (auto i : inaxis) { @@ -101,10 +97,8 @@ Array GetExcludeAxes(size_t indim, axis = axis + static_cast(indim); } // Check out of bounds error - CHECK_GE(axis, 0) - << "Axis out of bounds in reduce operator."; - CHECK_LT(axis, static_cast(indim)) - << "Axis out of bounds in reduce operator."; + CHECK_GE(axis, 0) << "Axis out of bounds in reduce operator."; + CHECK_LT(axis, static_cast(indim)) << "Axis out of bounds in reduce operator."; axis_flag[axis] = false; } @@ -177,34 +171,32 @@ Array> ReduceInferCorrectLayout(const Attrs& attrs, return Array>{{ret}, {ret}}; } -template -Array ReduceCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type, - F f) { +template +Array ReduceCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type, F f) { const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); if (inputs[0]->shape.size() == 0) { - return { topi::identity(inputs[0]) }; + return {topi::identity(inputs[0])}; } auto axes = param->axis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); if (axes.size() == 0) { - return { topi::identity(inputs[0]) }; + return {topi::identity(inputs[0])}; } } - return { f(inputs[0], axes, param->keepdims, false) }; + return {f(inputs[0], axes, param->keepdims, false)}; } /*! -* \brief ReduceShapeImpl get the outshape for the reduction operator -* \param in_shape Shape of input data. -* \param param ReduceAttrs details. -* \param reporter The reporter to report solution to. -* \return oshape Output shape inferred. -*/ -inline std::vector ReduceShapeImpl(const std::vector &in_shape, + * \brief ReduceShapeImpl get the outshape for the reduction operator + * \param in_shape Shape of input data. + * \param param ReduceAttrs details. + * \param reporter The reporter to report solution to. + * \return oshape Output shape inferred. + */ +inline std::vector ReduceShapeImpl(const std::vector& in_shape, const ReduceAttrs* param, const TypeReporter& reporter) { uint32_t indim = in_shape.size(); @@ -225,9 +217,9 @@ inline std::vector ReduceShapeImpl(const std::vector &in_s } if (is_dynamic_input) { - CHECK(reporter->Assert(max_shape < tir::make_const( - DataType::Int(64), std::numeric_limits::max()))) - << "The maximum possible index of reduced shape cannot be more than int32 max."; + CHECK(reporter->Assert(max_shape < + tir::make_const(DataType::Int(64), std::numeric_limits::max()))) + << "The maximum possible index of reduced shape cannot be more than int32 max."; } if (param->keepdims) { @@ -255,16 +247,14 @@ inline std::vector ReduceShapeImpl(const std::vector &in_s } /*! -* \brief ArgReduceRel Output type and shape relation evaluation function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. -*/ -bool ArgReduceRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { + * \brief ArgReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; @@ -281,15 +271,13 @@ bool ArgReduceRel(const Array& types, } /*! -* \brief ReduceRel Output type and shape relation evaluation function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. -*/ -bool ReduceRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief ReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool ReduceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -305,70 +293,57 @@ bool ReduceRel(const Array& types, return true; } -#define RELAY_REGISTER_REDUCE_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([]( \ - Expr data, \ - Array axis, \ - bool keepdims, \ - bool exclude) { \ - auto attrs = make_object(); \ - attrs->axis = std::move(axis); \ - attrs->keepdims = keepdims; \ - attrs->exclude = exclude; \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {data}, Attrs(attrs), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") - - -Array ArgMaxCompute(const Attrs& attrs, - const Array& inputs, +#define RELAY_REGISTER_REDUCE_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ + .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude) { \ + auto attrs = make_object(); \ + attrs->axis = std::move(axis); \ + attrs->keepdims = keepdims; \ + attrs->exclude = exclude; \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {data}, Attrs(attrs), {}); \ + }); \ + RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") + +Array ArgMaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::argmax); } - RELAY_REGISTER_REDUCE_OP("argmax") -.describe(R"code(Creates an operation that finds the indices of the maximum + .describe(R"code(Creates an operation that finds the indices of the maximum values over a given axis. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("ArgReduce", ArgReduceRel) -.set_attr("FTVMCompute", ArgMaxCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("ArgReduce", ArgReduceRel) + .set_attr("FTVMCompute", ArgMaxCompute) + .set_attr("TOpPattern", kCommReduce); -Array ArgMinCompute(const Attrs& attrs, - const Array& inputs, +Array ArgMinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::argmin); } RELAY_REGISTER_REDUCE_OP("argmin") -.describe(R"code(Creates an operation that finds the indices of the minimum + .describe(R"code(Creates an operation that finds the indices of the minimum values over a given axis. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("ArgReduce", ArgReduceRel) -.set_attr("FTVMCompute", ArgMinCompute) -.set_attr("TOpPattern", kCommReduce); - -Array SumCompute(const Attrs& attrs, - const Array& inputs, + .set_attrs_type() + .set_support_level(4) + .add_type_rel("ArgReduce", ArgReduceRel) + .set_attr("FTVMCompute", ArgMinCompute) + .set_attr("TOpPattern", kCommReduce); + +Array SumCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::sum); } - RELAY_REGISTER_REDUCE_OP("sum") -.describe(R"code(Computes the sum of array elements over given axes. + .describe(R"code(Computes the sum of array elements over given axes. Example:: @@ -385,23 +360,20 @@ Example:: [ 12. 19. 27.] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) -.set_attr("FTVMCompute", SumCompute) -.set_attr("TOpPattern", kCommReduce); - - -Array AllCompute(const Attrs& attrs, - const Array& inputs, + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) + .set_attr("FTVMCompute", SumCompute) + .set_attr("TOpPattern", kCommReduce); + +Array AllCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::all); } - RELAY_REGISTER_REDUCE_OP("all") -.describe(R"code(Computes the logical AND of boolean array elements over given axes. + .describe(R"code(Computes the logical AND of boolean array elements over given axes. Example:: @@ -422,22 +394,19 @@ Example:: [False, True, False]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", AllCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", AllCompute) + .set_attr("TOpPattern", kCommReduce); -Array AnyCompute(const Attrs& attrs, - const Array& inputs, +Array AnyCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::any); } - RELAY_REGISTER_REDUCE_OP("any") -.describe(R"code(Computes the logical OR of boolean array elements over given axes. + .describe(R"code(Computes the logical OR of boolean array elements over given axes. Example:: @@ -458,56 +427,49 @@ Example:: [False, True, True]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", AnyCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", AnyCompute) + .set_attr("TOpPattern", kCommReduce); -Array MaxCompute(const Attrs& attrs, - const Array& inputs, +Array MaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::max); } RELAY_REGISTER_REDUCE_OP("max") -.describe(R"code(Computes the max of array elements over given axes. + .describe(R"code(Computes the max of array elements over given axes. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", MaxCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", MaxCompute) + .set_attr("TOpPattern", kCommReduce); -Array MinCompute(const Attrs& attrs, - const Array& inputs, +Array MinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::min); } - RELAY_REGISTER_REDUCE_OP("min") -.describe(R"code(Computes the min of array elements over given axes. + .describe(R"code(Computes the min of array elements over given axes. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", MinCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", MinCompute) + .set_attr("TOpPattern", kCommReduce); -Array ProdCompute(const Attrs& attrs, - const Array& inputs, +Array ProdCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::prod); } RELAY_REGISTER_REDUCE_OP("prod") -.describe(R"code(Computes the products of array elements over given axes. + .describe(R"code(Computes the products of array elements over given axes. Example:: @@ -522,32 +484,27 @@ Example:: [ 36 480 2058] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", ProdCompute) -.set_attr("TOpPattern", kCommReduce); + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", ProdCompute) + .set_attr("TOpPattern", kCommReduce); - -Array MeanCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type) { +Array MeanCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { IndexExpr count = tir::make_const(inputs[0]->dtype, 1); const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); auto axes = param->axis; - for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), - param->axis, - param->exclude)) { + for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), param->axis, param->exclude)) { count *= inputs[0]->shape[i]; } auto res = ReduceCompute(attrs, inputs, out_type, topi::sum); return {topi::divide(res[0], count)}; } - RELAY_REGISTER_REDUCE_OP("mean") -.describe(R"code(Computes the mean of array elements over given axes. + .describe(R"code(Computes the mean of array elements over given axes. Example:: @@ -562,16 +519,13 @@ Example:: [ 2. 3.16666667 4.5] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", MeanCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", MeanCompute) + .set_attr("TOpPattern", kCommReduce); -bool VarianceRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool VarianceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -593,8 +547,7 @@ bool VarianceRel(const Array& types, return true; } -Array VarianceCompute(const Attrs& attrs, - const Array& inputs, +Array VarianceCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { IndexExpr count = tir::make_const(inputs[0]->dtype, 1); const ReduceAttrs* param = attrs.as(); @@ -602,9 +555,7 @@ Array VarianceCompute(const Attrs& attrs, auto axes = param->axis; auto data = inputs[0]; auto mean = inputs[1]; - for (int64_t i : GetReduceAxes(data->shape.size(), - param->axis, - param->exclude)) { + for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) { count *= data->shape[i]; } std::vector expand_shape; @@ -614,11 +565,7 @@ Array VarianceCompute(const Attrs& attrs, return {var}; } -Expr MakeVariance(Expr data, - Expr mean, - Array axis, - bool keepdims, - bool exclude) { +Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; @@ -627,23 +574,22 @@ Expr MakeVariance(Expr data, return Call(op, {data, mean}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make._variance") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call(MakeVariance, args, rv); }); RELAY_REGISTER_OP("variance") -.describe(R"code(Computes the variance of array elements over given axes. + .describe(R"code(Computes the variance of array elements over given axes. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("mean", "Tensor", "The mean tensor.") -.add_type_rel("Variance", VarianceRel) -.set_attr("FTVMCompute", VarianceCompute) -.set_attr("TOpPattern", kCommReduce); + .set_attrs_type() + .set_support_level(4) + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("mean", "Tensor", "The mean tensor.") + .add_type_rel("Variance", VarianceRel) + .set_attr("FTVMCompute", VarianceCompute) + .set_attr("TOpPattern", kCommReduce); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 6d78ba8..15761f6 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -21,24 +21,27 @@ * \file transform.cc * \brief Transform operators. */ -#include +#include "transform.h" + +#include +#include +#include +#include +#include #include #include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include +#include +#include +#include + #include -#include "../op_common.h" + #include "../../../arith/compute_expr.h" #include "../../transforms/infer_layout_util.h" #include "../../transforms/pattern_util.h" -#include "transform.h" +#include "../op_common.h" namespace tvm { namespace relay { @@ -47,115 +50,95 @@ using tir::IntImmNode; // relay.cast TVM_REGISTER_NODE_TYPE(CastAttrs); -bool CastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "cast: expect input type to be TensorType but get " - << types[0]; + << "cast: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); - reporter->Assign(types[1], TensorType( - data->shape, param->dtype)); + reporter->Assign(types[1], TensorType(data->shape, param->dtype)); return true; } -Array CastCompute(const Attrs& attrs, - const Array& inputs, +Array CastCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const CastAttrs *param = attrs.as(); + const CastAttrs* param = attrs.as(); CHECK(param != nullptr); DataType dtype = param->dtype; - return { topi::cast(inputs[0], dtype) }; + return {topi::cast(inputs[0], dtype)}; } -Expr MakeCast(Expr data, - DataType dtype) { +Expr MakeCast(Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("cast"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.ir.cast") -.set_body_typed(MakeCast); +TVM_REGISTER_GLOBAL("relay.ir.cast").set_body_typed(MakeCast); RELAY_REGISTER_OP("cast") -.describe(R"code(Cast the data into a new data type. + .describe(R"code(Cast the data into a new data type. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Cast", CastRel) -.set_attr("FTVMCompute", CastCompute) -.set_attr("TOpPattern", kElemWise) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); - + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Cast", CastRel) + .set_attr("FTVMCompute", CastCompute) + .set_attr("TOpPattern", kElemWise) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // relay.cast_like -bool CastLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CastLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "cast: expect input type to be TensorType but get " - << types[0]; + << "cast: expect input type to be TensorType but get " << types[0]; return false; } const auto* dtype_like = types[1].as(); if (dtype_like == nullptr) { CHECK(types[1].as()) - << "cast: expect input type to be TensorType but get " - << types[1]; + << "cast: expect input type to be TensorType but get " << types[1]; return false; } reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype)); return true; } - -Array CastLikeCompute(const Attrs& attrs, - const Array& inputs, +Array CastLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::cast(inputs[0], inputs[1]->dtype) }; + return {topi::cast(inputs[0], inputs[1]->dtype)}; } - -Expr MakeCastLike(Expr data, - Expr dtype_like) { +Expr MakeCastLike(Expr data, Expr dtype_like) { static const Op& op = Op::Get("cast_like"); return Call(op, {data, dtype_like}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.ir.cast_like") -.set_body_typed(MakeCastLike); +TVM_REGISTER_GLOBAL("relay.ir.cast_like").set_body_typed(MakeCastLike); RELAY_REGISTER_OP("cast_like") -.describe(R"code(Cast the data into the type of another tensor. + .describe(R"code(Cast the data into the type of another tensor. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("dtype_like", "Tensor", "The tensor to cast to.") -.set_support_level(3) -.add_type_rel("CastLike", CastLikeRel) -.set_attr("FTVMCompute", CastLikeCompute) -.set_attr("TOpPattern", kElemWise) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); - - -Array ReinterpretCompute(const Attrs& attrs, - const Array& inputs, + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("dtype_like", "Tensor", "The tensor to cast to.") + .set_support_level(3) + .add_type_rel("CastLike", CastLikeRel) + .set_attr("FTVMCompute", CastLikeCompute) + .set_attr("TOpPattern", kElemWise) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + +Array ReinterpretCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const CastAttrs* param = attrs.as(); CHECK(param != nullptr); @@ -175,44 +158,39 @@ TVM_REGISTER_GLOBAL("relay._make.reinterpret").set_body([](const TVMArgs& args, }); RELAY_REGISTER_OP("reinterpret") -.describe(R"code(Reinterpret the data into a new data type. + .describe(R"code(Reinterpret the data into a new data type. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Reinterpret", CastRel) -.set_attr("FTVMCompute", ReinterpretCompute) -.set_attr("TOpPattern", kElemWise) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Reinterpret", CastRel) + .set_attr("FTVMCompute", ReinterpretCompute) + .set_attr("TOpPattern", kElemWise) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // relay.expand_dims TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); -bool ExpandDimsRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ExpandDimsRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "expand_dims: expect input type to be TensorType but get " - << types[0]; + << "expand_dims: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const int ndim = static_cast(data->shape.size()); const int axis = param->axis; const int num_newaxis = param->num_newaxis; - CHECK(num_newaxis >= 0) - << "expand_dims only accepts `num_newaxis >= 0`" - << ", but got num_newaxis = " << num_newaxis; + CHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" + << ", but got num_newaxis = " << num_newaxis; CHECK(-ndim - 1 <= axis && axis <= ndim) - << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; const int pivot = axis < 0 ? ndim + axis + 1 : axis; std::vector oshape; oshape.reserve(ndim + num_newaxis); @@ -229,17 +207,14 @@ bool ExpandDimsRel(const Array& types, return true; } -Array ExpandDimsCompute(const Attrs& attrs, - const Array& inputs, +Array ExpandDimsCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const ExpandDimsAttrs *param = attrs.as(); + const ExpandDimsAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::expand_dims(inputs[0], param->axis, param->num_newaxis) }; + return {topi::expand_dims(inputs[0], param->axis, param->num_newaxis)}; } -Expr MakeExpandDims(Expr data, - int axis, - int num_newaxis) { +Expr MakeExpandDims(Expr data, int axis, int num_newaxis) { auto attrs = make_object(); attrs->axis = axis; attrs->num_newaxis = num_newaxis; @@ -247,75 +222,68 @@ Expr MakeExpandDims(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.expand_dims") -.set_body_typed(MakeExpandDims); +TVM_REGISTER_GLOBAL("relay.op._make.expand_dims").set_body_typed(MakeExpandDims); RELAY_REGISTER_OP("expand_dims") -.describe(R"code(Insert `num_newaxis` axises at the position given by `axis` + .describe(R"code(Insert `num_newaxis` axises at the position given by `axis` - **data**: The input data to the operator. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("ExpandDims", ExpandDimsRel) -.set_attr("FTVMCompute", ExpandDimsCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("ExpandDims", ExpandDimsRel) + .set_attr("FTVMCompute", ExpandDimsCompute) + .set_attr("TOpPattern", kBroadcast); // relay.concatenate TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); -Array ConcatenateCompute(const Attrs& attrs, - const Array& inputs, +Array ConcatenateCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const ConcatenateAttrs *param = attrs.as(); + const ConcatenateAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::concatenate(inputs, param->axis) }; + return {topi::concatenate(inputs, param->axis)}; } -Expr MakeConcatenate(Expr data, - int axis) { +Expr MakeConcatenate(Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("concatenate"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.concatenate") -.set_body_typed(MakeConcatenate); +TVM_REGISTER_GLOBAL("relay.op._make.concatenate").set_body_typed(MakeConcatenate); RELAY_REGISTER_OP("concatenate") -.describe(R"code(Concatenate the input tensors along the given axis. + .describe(R"code(Concatenate the input tensors along the given axis. - **data** : A list of tensors. - **axis** : The axis along which the tensors are concatenated. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input list of tensors.") -.set_support_level(1) -.add_type_rel("Concatenate", ConcatenateRel) -.set_attr("FInferCorrectLayout", ConcatenateLayout) -.set_attr("FTVMCompute", ConcatenateCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input list of tensors.") + .set_support_level(1) + .add_type_rel("Concatenate", ConcatenateRel) + .set_attr("FInferCorrectLayout", ConcatenateLayout) + .set_attr("FTVMCompute", ConcatenateCompute) + .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(StackAttrs); -bool StackRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool StackRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { CHECK(types[0].as()) - << "cast: expect input type to be TupleType but get " - << types[0]; + << "cast: expect input type to be TupleType but get " << types[0]; return false; } const auto* param = attrs.as(); @@ -324,11 +292,9 @@ bool StackRel(const Array& types, // Sanity check: axis int axis = param->axis; - CHECK(-ndim <= axis && axis < ndim) - << "stack only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis - << ", and ndim = " << ndim; - axis = axis < 0 ? ndim + axis + 1: axis; + CHECK(-ndim <= axis && axis < ndim) << "stack only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; + axis = axis < 0 ? ndim + axis + 1 : axis; // Sanity check: ndim and dtype. const DataType dtype = first->dtype; @@ -341,8 +307,9 @@ bool StackRel(const Array& types, for (size_t j = 0; j < first->shape.size(); ++j) { if (j == static_cast(axis)) continue; if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue; - throw Error("relay.stack requires all tensors have the same shape " - "on non-stacking axes"); + throw Error( + "relay.stack requires all tensors have the same shape " + "on non-stacking axes"); } } @@ -361,55 +328,49 @@ bool StackRel(const Array& types, return true; } -Array StackCompute(const Attrs& attrs, - const Array& inputs, +Array StackCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const StackAttrs *param = attrs.as(); + const StackAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::stack(inputs, param->axis) }; + return {topi::stack(inputs, param->axis)}; } -Expr MakeStack(Expr data, - int axis) { +Expr MakeStack(Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("stack"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.stack") -.set_body_typed(MakeStack); +TVM_REGISTER_GLOBAL("relay.op._make.stack").set_body_typed(MakeStack); RELAY_REGISTER_OP("stack") -.describe(R"code(Stack the input tensors along the given axis. + .describe(R"code(Stack the input tensors along the given axis. - **data** : A list of tensors. - **axis** : The axis along which the tensors are stacked. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input list of tensors.") -.set_support_level(3) -.add_type_rel("Stack", StackRel) -.set_attr("FTVMCompute", StackCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input list of tensors.") + .set_support_level(3) + .add_type_rel("Stack", StackRel) + .set_attr("FTVMCompute", StackCompute) + .set_attr("TOpPattern", kInjective); /* relay.transpose */ TVM_REGISTER_NODE_TYPE(TransposeAttrs); -bool TransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "transpose: expect input type to be TensorType but get " - << types[0]; + << "transpose: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); @@ -417,8 +378,8 @@ bool TransposeRel(const Array& types, const Array& axes = param->axes; // check dimension match CHECK(!axes.defined() || static_cast(axes.size()) == ndim) - << "Dimension mismatch: axes has " << axes.size() << " elements" - << ", but data.ndim = " << ndim; + << "Dimension mismatch: axes has " << axes.size() << " elements" + << ", but data.ndim = " << ndim; // construct int_axes std::vector int_axes; int_axes.reserve(ndim); @@ -433,9 +394,8 @@ bool TransposeRel(const Array& types, int64_t axis = e; // sanity check for axis and ndim CHECK(-ndim <= axis && axis < ndim) - << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; axis = axis < 0 ? axis + ndim : axis; // sanity check for duplication CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis; @@ -452,55 +412,49 @@ bool TransposeRel(const Array& types, return true; } -Array TransposeCompute(const Attrs& attrs, - const Array& inputs, +Array TransposeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array{ topi::transpose(inputs[0], param->axes) }; + return Array{topi::transpose(inputs[0], param->axes)}; } -Expr MakeTranspose(Expr data, - Array axes) { +Expr MakeTranspose(Expr data, Array axes) { auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("transpose"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.transpose") -.set_body_typed(MakeTranspose); +TVM_REGISTER_GLOBAL("relay.op._make.transpose").set_body_typed(MakeTranspose); RELAY_REGISTER_OP("transpose") -.describe(R"code(Permutes the dimensions of an array. + .describe(R"code(Permutes the dimensions of an array. - **data**: The input data to the operator. - **axes**: The target axes order, reverse order if not specified. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Transpose", TransposeRel) -.set_attr("FTVMCompute", TransposeCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Transpose", TransposeRel) + .set_attr("FTVMCompute", TransposeCompute) + .set_attr("TOpPattern", kInjective); /* relay.reshape */ TVM_REGISTER_NODE_TYPE(ReshapeAttrs); -bool ReshapeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "reshape: expect input type to be TensorType but get " - << types[0]; + << "reshape: expect input type to be TensorType but get " << types[0]; return false; } @@ -534,8 +488,7 @@ bool ReshapeRel(const Array& types, oshape.push_back(data_shape[src_idx++]); } else if (svalue == -1) { // inference based on rest - CHECK_LT(infer_idx, 0) - << "One and only one dim can be inferred"; + CHECK_LT(infer_idx, 0) << "One and only one dim can be inferred"; infer_idx = i; oshape.push_back(1); ++src_idx; @@ -569,8 +522,7 @@ bool ReshapeRel(const Array& types, Integer d1 = newshape[++i]; Integer d2 = newshape[++i]; if (d1->value == -1) { - CHECK(d2->value != -1) - << "Split dims cannot both be -1."; + CHECK(d2->value != -1) << "Split dims cannot both be -1."; used_output_dims.insert(oshape.size()); if (d0.as()) { oshape.push_back(Any::make()); @@ -626,16 +578,15 @@ bool ReshapeRel(const Array& types, } if (param->reverse) { - reporter->Assign(types[1], TensorType( - Array(oshape.rbegin(), oshape.rend()), data->dtype)); + reporter->Assign(types[1], + TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); } else { reporter->Assign(types[1], TensorType(oshape, data->dtype)); } return true; } -Array ReshapeCompute(const Attrs& attrs, - const Array& inputs, +Array ReshapeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); @@ -647,11 +598,10 @@ Array ReshapeCompute(const Attrs& attrs, newshape.push_back(val); } } - return { topi::reshape(inputs[0], newshape) }; + return {topi::reshape(inputs[0], newshape)}; } -Expr MakeReshape(Expr data, - Array newshape) { +Expr MakeReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = false; @@ -659,11 +609,10 @@ Expr MakeReshape(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.reshape") -.set_body_typed(MakeReshape); +TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape); RELAY_REGISTER_OP("reshape") -.describe(R"code(Reshapes the input array. + .describe(R"code(Reshapes the input array. Example:: @@ -713,26 +662,23 @@ Example:: - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Reshape", ReshapeRel) -.set_attr("FTVMCompute", ReshapeCompute) -.set_attr("TOpPattern", kInjective); - + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Reshape", ReshapeRel) + .set_attr("FTVMCompute", ReshapeCompute) + .set_attr("TOpPattern", kInjective); /*! -* \brief ReshapeLikeRel User defined type constraint function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return False if the relation has not been resolved, it might be resolved later. -* True if this relation has been resolved. -*/ -bool ReshapeLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief ReshapeLikeRel User defined type constraint function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return False if the relation has not been resolved, it might be resolved later. + * True if this relation has been resolved. + */ +bool ReshapeLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -753,43 +699,36 @@ bool ReshapeLikeRel(const Array& types, } if (is_static_shape) { CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) - << "Reshape inputs size should be compatible."; + << "Reshape inputs size should be compatible."; } reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype)); return true; } - -Expr MakeReshapeLike(Expr data, - Expr shape_like) { +Expr MakeReshapeLike(Expr data, Expr shape_like) { static const Op& op = Op::Get("reshape_like"); return Call(op, {data, shape_like}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op._make.reshape_like") -.set_body_typed(MakeReshapeLike); - +TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike); RELAY_REGISTER_OP("reshape_like") -.describe(R"code(Reshapes the input array by the size of another array. + .describe(R"code(Reshapes the input array by the size of another array. For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes the input array into an output array with the same shape as the second input array. .. note:: Sizes for both array should be compatible. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("shape_like", "Tensor", "Shape tensor.") -.set_support_level(3) -.add_type_rel("ReshapeLike", ReshapeLikeRel) -.set_attr("FTVMCompute", ReshapeCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape_like", "Tensor", "Shape tensor.") + .set_support_level(3) + .add_type_rel("ReshapeLike", ReshapeLikeRel) + .set_attr("FTVMCompute", ReshapeCompute) + .set_attr("TOpPattern", kInjective); // ArgWhere -bool ArgWhereRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ArgWhereRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); @@ -803,28 +742,25 @@ bool ArgWhereRel(const Array& types, return true; } -TVM_REGISTER_GLOBAL("relay.op._make.argwhere") -.set_body_typed([](Expr data) { +TVM_REGISTER_GLOBAL("relay.op._make.argwhere").set_body_typed([](Expr data) { static const Op& op = Op::Get("argwhere"); return Call(op, {data}, Attrs(), {}); }); RELAY_REGISTER_OP("argwhere") -.describe(R"doc(Find the indices of elements of a tensor that are + .describe(R"doc(Find the indices of elements of a tensor that are non-zero)doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("condition", "Tensor", "The input condition tensor.") -.add_type_rel("ArgWhere", ArgWhereRel) -.set_attr("TOpIsStateful", false) -.set_attr("TOpPattern", kOpaque) -.set_support_level(10); + .set_num_inputs(1) + .add_argument("condition", "Tensor", "The input condition tensor.") + .add_type_rel("ArgWhere", ArgWhereRel) + .set_attr("TOpIsStateful", false) + .set_attr("TOpPattern", kOpaque) + .set_support_level(10); // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); -bool TakeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TakeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, indices, result] CHECK_EQ(types.size(), 3); @@ -851,9 +787,8 @@ bool TakeRel(const Array& types, const auto ndim_indices = static_cast(indices->shape.size()); int axis = static_cast(param->axis->value); if (axis < 0) axis += ndim_data; - CHECK_LE(axis, ndim_data) - << "axis should be with in data shape" - << ", but got = " << axis; + CHECK_LE(axis, ndim_data) << "axis should be with in data shape" + << ", but got = " << axis; oshape.reserve(ndim_data - 1 + ndim_indices); for (int i = 0; i < axis; ++i) { @@ -862,7 +797,7 @@ bool TakeRel(const Array& types, for (int i = 0; i < ndim_indices; ++i) { oshape.emplace_back(indices->shape[i]); } - for (int i = axis+1; i < ndim_data; ++i) { + for (int i = axis + 1; i < ndim_data; ++i) { oshape.emplace_back(data->shape[i]); } @@ -870,22 +805,18 @@ bool TakeRel(const Array& types, return true; } -Array TakeCompute(const Attrs& attrs, - const Array& inputs, +Array TakeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); if (!param->axis.defined()) { - return Array{ topi::take(inputs[0], inputs[1], param->mode) }; + return Array{topi::take(inputs[0], inputs[1], param->mode)}; } else { - return Array{ topi::take(inputs[0], inputs[1], param->axis, param->mode) }; + return Array{topi::take(inputs[0], inputs[1], param->axis, param->mode)}; } } -Expr MakeTake(Expr data, - Expr indices, - Integer axis, - std::string mode) { +Expr MakeTake(Expr data, Expr indices, Integer axis, std::string mode) { auto attrs = make_object(); attrs->axis = std::move(axis); attrs->mode = std::move(mode); @@ -893,11 +824,10 @@ Expr MakeTake(Expr data, return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.take") -.set_body_typed(MakeTake); +TVM_REGISTER_GLOBAL("relay.op._make.take").set_body_typed(MakeTake); RELAY_REGISTER_OP("take") -.describe(R"code(Take elements from an array along an axis. + .describe(R"code(Take elements from an array along an axis. When axis is not None, this function does the same thing as 'fancy' indexing (indexing arrays using arrays); however, it can be easier to use if you need @@ -919,22 +849,19 @@ Examples:: [ 4., 3.]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("indices", "Tensor", "The indices tensor.") -.set_support_level(3) -.add_type_rel("Take", TakeRel) -.set_attr("FTVMCompute", TakeCompute) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .set_support_level(3) + .add_type_rel("Take", TakeRel) + .set_attr("FTVMCompute", TakeCompute) + .set_attr("TOpPattern", kInjective); // Init ops TVM_REGISTER_NODE_TYPE(InitOpAttrs); -bool FullRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool FullRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const InitOpAttrs* param = attrs.as(); @@ -949,23 +876,19 @@ bool FullRel(const Array& types, } CHECK_EQ(fill_value->shape.size(), 0) - << "Fill value should be a scalar but has dimension " - << fill_value->shape.size() << "."; + << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << "."; reporter->Assign(types[1], TensorType(param->shape, out_dtype)); return true; } -Array FullCompute(const Attrs& attrs, - const Array& inputs, +Array FullCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); - return { topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]()) }; + return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())}; } -Expr MakeFull(Expr fill_value, - Array shape, - DataType dtype) { +Expr MakeFull(Expr fill_value, Array shape, DataType dtype) { auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); @@ -973,24 +896,21 @@ Expr MakeFull(Expr fill_value, return Call(op, {fill_value}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.full") -.set_body_typed(MakeFull); +TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull); RELAY_REGISTER_OP("full") -.describe(R"code(Fill array with scalar value. + .describe(R"code(Fill array with scalar value. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("fill_value", "double", "The value to fill.") -.set_support_level(3) -.add_type_rel("Full", FullRel) -.set_attr("FTVMCompute", FullCompute) -.set_attr("TOpPattern", kElemWise); - -bool InitOpRel(const Array& types, - int num_inputs, - const Attrs& attrs, + .set_attrs_type() + .set_num_inputs(1) + .add_argument("fill_value", "double", "The value to fill.") + .set_support_level(3) + .add_type_rel("Full", FullRel) + .set_attr("FTVMCompute", FullCompute) + .set_attr("TOpPattern", kElemWise); + +bool InitOpRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 1); const InitOpAttrs* param = attrs.as(); @@ -999,8 +919,7 @@ bool InitOpRel(const Array& types, return true; } -Expr MakeZeros(Array shape, - DataType dtype) { +Expr MakeZeros(Array shape, DataType dtype) { auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); @@ -1008,20 +927,18 @@ Expr MakeZeros(Array shape, return Call(op, {}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.zeros") -.set_body_typed(MakeZeros); +TVM_REGISTER_GLOBAL("relay.op._make.zeros").set_body_typed(MakeZeros); RELAY_REGISTER_OP("zeros") -.describe(R"code(Fill array with zeros. + .describe(R"code(Fill array with zeros. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(0) -.set_support_level(3) -.add_type_rel("InitOp", InitOpRel); + .set_attrs_type() + .set_num_inputs(0) + .set_support_level(3) + .add_type_rel("InitOp", InitOpRel); -Expr MakeOnes(Array shape, - DataType dtype) { +Expr MakeOnes(Array shape, DataType dtype) { auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); @@ -1029,21 +946,18 @@ Expr MakeOnes(Array shape, return Call(op, {}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.ones") -.set_body_typed(MakeOnes); +TVM_REGISTER_GLOBAL("relay.op._make.ones").set_body_typed(MakeOnes); RELAY_REGISTER_OP("ones") -.describe(R"code(Fill array with ones. + .describe(R"code(Fill array with ones. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(0) -.set_support_level(3) -.add_type_rel("InitOp", InitOpRel); - -bool FullLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, + .set_attrs_type() + .set_num_inputs(0) + .set_support_level(3) + .add_type_rel("InitOp", InitOpRel); + +bool FullLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -1056,40 +970,37 @@ bool FullLikeRel(const Array& types, } CHECK_EQ(fill_value->shape.size(), 0) - << "The fill value should be a scalar but here it has dimension " - << fill_value->shape.size() << "."; + << "The fill value should be a scalar but here it has dimension " << fill_value->shape.size() + << "."; reporter->Assign(types[2], TensorType(data->shape, data->dtype)); return true; } -Array FullLikeCompute(const Attrs& attrs, - const Array& inputs, +Array FullLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::full_like(inputs[0], inputs[1]()) }; + return {topi::full_like(inputs[0], inputs[1]())}; } -Expr MakeFullLike(Expr data, - Expr fill_value) { +Expr MakeFullLike(Expr data, Expr fill_value) { static const Op& op = Op::Get("full_like"); return Call(op, {data, fill_value}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.full_like") -.set_body_typed(MakeFullLike); +TVM_REGISTER_GLOBAL("relay.op._make.full_like").set_body_typed(MakeFullLike); RELAY_REGISTER_OP("full_like") -.describe(R"code(Return an scalar value array with the same shape + .describe(R"code(Return an scalar value array with the same shape and type as the input array. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("fill_value", "double", "Scalar value to fill.") -.set_support_level(3) -.add_type_rel("FullLike", FullLikeRel) -.set_attr("FTVMCompute", FullLikeCompute) -.set_attr("TOpPattern", kElemWise); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("fill_value", "double", "Scalar value to fill.") + .set_support_level(3) + .add_type_rel("FullLike", FullLikeRel) + .set_attr("FTVMCompute", FullLikeCompute) + .set_attr("TOpPattern", kElemWise); // arange operator TVM_REGISTER_NODE_TYPE(ArangeAttrs); @@ -1132,9 +1043,7 @@ double ToScalar(const runtime::NDArray& array) { return -std::numeric_limits::infinity(); } -bool ArangeRel(const Array& types, - int num_inputs, - const Attrs& raw_attrs, +bool ArangeRel(const Array& types, int num_inputs, const Attrs& raw_attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const ArangeAttrs* attrs = raw_attrs.as(); @@ -1144,16 +1053,14 @@ bool ArangeRel(const Array& types, reporter->Assign(types[1], types[2]); reporter->Assign(types[2], TensorType({}, attrs->dtype)); - if ((cstart = attrs->start.as()) && - (cstop = attrs->stop.as()) && + if ((cstart = attrs->start.as()) && (cstop = attrs->stop.as()) && (cstep = attrs->step.as())) { double start = ToScalar(cstart->data); double stop = ToScalar(cstop->data); double step = ToScalar(cstep->data); int32_t num_elem = static_cast(std::ceil((stop - start) / step)); - CHECK_GT(num_elem, 0) - << "Invalid arange attributes (start, stop, step): " << attrs->start - << ", " << attrs->stop << ", " << attrs->step; + CHECK_GT(num_elem, 0) << "Invalid arange attributes (start, stop, step): " << attrs->start + << ", " << attrs->stop << ", " << attrs->step; reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype)); return true; } else { @@ -1162,32 +1069,28 @@ bool ArangeRel(const Array& types, } } -inline te::Tensor DynamicArange(const te::Tensor& start, - const te::Tensor& stop, - const te::Tensor& step, - tvm::DataType dtype, - std::string name = "tensor", - std::string tag = topi::kInjective) { +inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop, + const te::Tensor& step, tvm::DataType dtype, + std::string name = "tensor", std::string tag = topi::kInjective) { tvm::PrimExpr num_elem = tvm::tir::Var("num_elem"); - return te::compute({num_elem}, [&](const Array& indices) { - return tvm::cast(dtype, start[0] + step[0] * indices[0]); - }, name, tag); + return te::compute( + {num_elem}, + [&](const Array& indices) { + return tvm::cast(dtype, start[0] + step[0] * indices[0]); + }, + name, tag); } -Array ArangeCompute(const Attrs& attrs, - const Array& inputs, +Array ArangeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const ArangeAttrs* param = attrs.as(); te::Tensor start = inputs[0]; - te::Tensor stop = inputs[1]; + te::Tensor stop = inputs[1]; te::Tensor step = inputs[2]; - return { DynamicArange(start, stop, step, param->dtype) }; + return {DynamicArange(start, stop, step, param->dtype)}; } -Expr MakeArange(Expr start, - Expr stop, - Expr step, - DataType dtype) { +Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype) { auto attrs = make_object(); attrs->start = start; attrs->stop = stop; @@ -1197,8 +1100,7 @@ Expr MakeArange(Expr start, return Call(op, {start, stop, step}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.arange") -.set_body_typed(MakeArange); +TVM_REGISTER_GLOBAL("relay.op._make.arange").set_body_typed(MakeArange); // An issue with the existing design is that we require dependency // to type the operator precisely. @@ -1214,45 +1116,40 @@ TVM_REGISTER_GLOBAL("relay.op._make.arange") // In general I think we should avoid this pattern, and introduce // a secondary shape analysis to recover more precise information. RELAY_REGISTER_OP("arange") -.describe(R"code(Returns evenly spaced values within a given interval. + .describe(R"code(Returns evenly spaced values within a given interval. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.set_support_level(3) -.add_type_rel("Arange", ArangeRel) -.set_attr("FTVMCompute", ArangeCompute) -// TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape -.set_attr("TOpPattern", kOpaque) -.set_attr("AnyCodegenStrategy", kVariableDimensions); + .set_attrs_type() + .set_num_inputs(3) + .set_support_level(3) + .add_type_rel("Arange", ArangeRel) + .set_attr("FTVMCompute", ArangeCompute) + // TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape + .set_attr("TOpPattern", kOpaque) + .set_attr("AnyCodegenStrategy", kVariableDimensions); // repeat operator TVM_REGISTER_NODE_TYPE(RepeatAttrs); -bool RepeatRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool RepeatRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "repeat: expect input type to be TensorType but get " - << types[0]; + << "repeat: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const int ndim = static_cast(data->shape.size()); const int repeats = param->repeats; const int axis = param->axis; - CHECK(repeats >= 1) - << "repeat only accepts `repeats >= 1`" - << ", but got repeats = " << repeats; + CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; CHECK(-ndim - 1 <= axis && axis <= ndim) - << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; const int pivot = axis < 0 ? ndim + axis : axis; std::vector oshape; oshape.reserve(ndim + repeats); @@ -1267,17 +1164,14 @@ bool RepeatRel(const Array& types, return true; } -Array RepeatCompute(const Attrs& attrs, - const Array& inputs, +Array RepeatCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const RepeatAttrs *param = attrs.as(); + const RepeatAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::repeat(inputs[0], param->repeats, param->axis) }; + return {topi::repeat(inputs[0], param->repeats, param->axis)}; } -Expr MakeRepeat(Expr data, - int repeats, - int axis) { +Expr MakeRepeat(Expr data, int repeats, int axis) { auto attrs = make_object(); attrs->repeats = repeats; attrs->axis = axis; @@ -1285,50 +1179,45 @@ Expr MakeRepeat(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.repeat") -.set_body_typed(MakeRepeat); +TVM_REGISTER_GLOBAL("relay.op._make.repeat").set_body_typed(MakeRepeat); RELAY_REGISTER_OP("repeat") -.describe(R"code(Repeat elements of an array `repeats` times along axis `axis` + .describe(R"code(Repeat elements of an array `repeats` times along axis `axis` - **data**: The input data to the operator. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Repeat", RepeatRel) -.set_attr("FTVMCompute", RepeatCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Repeat", RepeatRel) + .set_attr("FTVMCompute", RepeatCompute) + .set_attr("TOpPattern", kBroadcast); // tile operator TVM_REGISTER_NODE_TYPE(TileAttrs); -bool TileRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TileRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "tile: expect input type to be TensorType but get " - << types[0]; + << "tile: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const size_t ndim = data->shape.size(); const Array& reps = param->reps; // check dimension match - CHECK(reps.defined()) - << "repetition array is not defined. data.ndim = " << ndim; + CHECK(reps.defined()) << "repetition array is not defined. data.ndim = " << ndim; const size_t rndim = reps.size(); for (size_t i = 0; i < rndim; ++i) { if (const tvm::tir::IntImmNode* val = reps[i].as()) { - CHECK_GT(val->value, 0) - << "Tile reps value should always be larger than 0, but get: " << val->value; + CHECK_GT(val->value, 0) << "Tile reps value should always be larger than 0, but get: " + << val->value; } } size_t tndim = (ndim > rndim) ? ndim : rndim; @@ -1377,103 +1266,91 @@ bool TileRel(const Array& types, return true; } -Array TileCompute(const Attrs& attrs, - const Array& inputs, +Array TileCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const TileAttrs *param = attrs.as(); + const TileAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::tile(inputs[0], param->reps) }; + return {topi::tile(inputs[0], param->reps)}; } -Expr MakeTile(Expr data, - Array reps) { +Expr MakeTile(Expr data, Array reps) { auto attrs = make_object(); attrs->reps = reps; static const Op& op = Op::Get("tile"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.tile") -.set_body_typed(MakeTile); +TVM_REGISTER_GLOBAL("relay.op._make.tile").set_body_typed(MakeTile); RELAY_REGISTER_OP("tile") -.describe(R"code(Repeat the whole array multiple times. + .describe(R"code(Repeat the whole array multiple times. - **data**: The input data to the operator. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Tile", TileRel) -.set_attr("FTVMCompute", TileCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Tile", TileRel) + .set_attr("FTVMCompute", TileCompute) + .set_attr("TOpPattern", kBroadcast); // reverse operator TVM_REGISTER_NODE_TYPE(ReverseAttrs); -bool ReverseRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { +bool ReverseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "reverse: expect input type to be TensorType but get " - << types[0]; + << "reverse: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const int ndim = static_cast(data->shape.size()); const int axis = param->axis; CHECK(-ndim <= axis && axis < ndim) - << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; reporter->Assign(types[1], types[0]); return true; } -Array ReverseCompute(const Attrs& attrs, - const Array& inputs, +Array ReverseCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const ReverseAttrs *param = attrs.as(); + const ReverseAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::flip(inputs[0], param->axis) }; + return {topi::flip(inputs[0], param->axis)}; } -Expr MakeReverse(Expr data, - int axis) { +Expr MakeReverse(Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("reverse"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.reverse") -.set_body_typed(MakeReverse); +TVM_REGISTER_GLOBAL("relay.op._make.reverse").set_body_typed(MakeReverse); RELAY_REGISTER_OP("reverse") -.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape. + .describe(R"code(Reverses the order of elements along given `axis` while preserving array shape. - **data**: The input data to the operator. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Reverse", ReverseRel) -.set_attr("FTVMCompute", ReverseCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Reverse", ReverseRel) + .set_attr("FTVMCompute", ReverseCompute) + .set_attr("TOpPattern", kInjective); // where operator -bool WhereRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool WhereRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4U); const auto* condition = types[0].as(); @@ -1487,17 +1364,16 @@ bool WhereRel(const Array& types, CHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size"; if (cond_shape.size() != x_shape.size()) { - CHECK_EQ(cond_shape.size(), 1) - << "Shape of condition " << condition->shape - << " must be either equal to x or has dimension of 1."; + CHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape + << " must be either equal to x or has dimension of 1."; } for (size_t i = 0; i < x_shape.size(); i++) { CHECK(reporter->AssertEQ(x_shape[i], y_shape[i])) << "x and y must have the same shape: " << x_shape << " vs " << y_shape; if (i < cond_shape.size()) { - CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i])) - << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape; + CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i])) + << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape; } } reporter->Assign(types[3], TensorType(x_shape, x->dtype)); @@ -1510,17 +1386,15 @@ Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) { return Call(op, {condition, x, y}); } -Array WhereCompute(const Attrs& attrs, - const Array& inputs, +Array WhereCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::where(inputs[0], inputs[1], inputs[2]) }; + return {topi::where(inputs[0], inputs[1], inputs[2])}; } -TVM_REGISTER_GLOBAL("relay.op._make.where") -.set_body_typed(MakeWhere); +TVM_REGISTER_GLOBAL("relay.op._make.where").set_body_typed(MakeWhere); RELAY_REGISTER_OP("where") -.describe(R"code( + .describe(R"code( Return the elements, either from x or y, depending on the condition. Given three ndarrays, condition, x, and y, return an ndarray with the elements @@ -1548,34 +1422,28 @@ Examples:: where(cond, x, y) = [[1, 2], [7, 8]] )code" TVM_ADD_FILELINE) -.add_argument("condition", "Tensor", "Condition array") -.add_argument("x", "Tensor", "First array to be selected") -.add_argument("y", "Tensor", "Second array to be selected") -.set_num_inputs(3) -.set_support_level(4) -.add_type_rel("Where", WhereRel) -.set_attr("FTVMCompute", WhereCompute) -.set_attr("TOpPattern", kBroadcast); - + .add_argument("condition", "Tensor", "Condition array") + .add_argument("x", "Tensor", "First array to be selected") + .add_argument("y", "Tensor", "Second array to be selected") + .set_num_inputs(3) + .set_support_level(4) + .add_type_rel("Where", WhereRel) + .set_attr("FTVMCompute", WhereCompute) + .set_attr("TOpPattern", kBroadcast); // Squeeze TVM_REGISTER_NODE_TYPE(SqueezeAttrs); -Expr MakeSqueeze(Expr data, - Array axis) { +Expr MakeSqueeze(Expr data, Array axis) { auto attrs = make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("squeeze"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.squeeze") -.set_body_typed(MakeSqueeze); +TVM_REGISTER_GLOBAL("relay.op._make.squeeze").set_body_typed(MakeSqueeze); - -bool SqueezeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SqueezeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -1599,7 +1467,7 @@ bool SqueezeRel(const Array& types, } } else { // pair up original shape with a boolean which control whether it will be in the final shape. - std::vector > original_shape; + std::vector> original_shape; for (const auto& e : data->shape) { original_shape.push_back(std::pair(e, true)); } @@ -1626,78 +1494,70 @@ bool SqueezeRel(const Array& types, return true; } -Array SqueezeCompute(const Attrs& attrs, - const Array& inputs, +Array SqueezeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const SqueezeAttrs *param = attrs.as(); + const SqueezeAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::squeeze(inputs[0], param->axis) }; + return {topi::squeeze(inputs[0], param->axis)}; } - RELAY_REGISTER_OP("squeeze") -.describe(R"code(Squeeze the input tensor at the dimensions given by axes + .describe(R"code(Squeeze the input tensor at the dimensions given by axes - **data**: The input data to the operator. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Squeeze", SqueezeRel) -.set_attr("FTVMCompute", SqueezeCompute) -.set_attr("TOpPattern", kInjective); - + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Squeeze", SqueezeRel) + .set_attr("FTVMCompute", SqueezeCompute) + .set_attr("TOpPattern", kInjective); // CollapseSumLike: -> B where BroadCast(A, B) = A -bool CollapseSumLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CollapseSumLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); reporter->Assign(types[2], types[1]); return BroadcastRel({types[0], types[1], types[0]}, 2, Attrs(), reporter); } -Expr MakeCollapseSumLike(Expr data, - Expr collapse_type) { +Expr MakeCollapseSumLike(Expr data, Expr collapse_type) { static const Op& op = Op::Get("collapse_sum_like"); return Call(op, {data, collapse_type}, Attrs(), {}); } -Array CollapseSumLikeCompute(const Attrs& attrs, - const Array& inputs, +Array CollapseSumLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); - return { topi::collapse_sum(inputs[0], out_ttype->shape) }; + return {topi::collapse_sum(inputs[0], out_ttype->shape)}; } -TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like") -.set_body_typed(MakeCollapseSumLike); +TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like").set_body_typed(MakeCollapseSumLike); RELAY_REGISTER_OP("collapse_sum_like") -.describe(R"code(Collapse the first input to match the shape of the second input. + .describe(R"code(Collapse the first input to match the shape of the second input. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.") -.set_support_level(10) -.add_type_rel("CollapseSumLike", CollapseSumLikeRel) -.set_attr("FTVMCompute", CollapseSumLikeCompute) -.set_attr("TOpPattern", kCommReduce); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("collapse_type", "Tensor", "Provide the type to collapse to.") + .set_support_level(10) + .add_type_rel("CollapseSumLike", CollapseSumLikeRel) + .set_attr("FTVMCompute", CollapseSumLikeCompute) + .set_attr("TOpPattern", kCommReduce); // BroadCastTo: -> B where BroadCast(A, B) = B -bool BroadCastToRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadCastToRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); auto ioattrs = attrs.as(); CHECK(ioattrs); auto intt = types[0].as(); - if (intt == nullptr) { return false; } + if (intt == nullptr) { + return false; + } auto type = TensorType(ioattrs->shape, intt->dtype); reporter->Assign(types[1], type); return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter); @@ -1710,87 +1570,75 @@ Expr MakeBroadCastTo(Expr data, Array shape) { return Call(op, {data}, Attrs(attrs), {}); } -Array BroadCastToCompute(const Attrs& attrs, - const Array& inputs, +Array BroadCastToCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { auto ioattrs = attrs.as(); CHECK(ioattrs != nullptr); - return { topi::broadcast_to(inputs[0], ioattrs->shape) }; + return {topi::broadcast_to(inputs[0], ioattrs->shape)}; } -TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to") -.set_body_typed(MakeBroadCastTo); +TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastTo); RELAY_REGISTER_OP("broadcast_to") -.describe(R"code(Broadcast the first input to match the shape argument. + .describe(R"code(Broadcast the first input to match the shape argument. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(4) -.add_type_rel("BroadCastTo", BroadCastToRel) -.set_attr("FTVMCompute", BroadCastToCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(4) + .add_type_rel("BroadCastTo", BroadCastToRel) + .set_attr("FTVMCompute", BroadCastToCompute) + .set_attr("TOpPattern", kBroadcast); // BroadCastToLike: -> B where BroadCast(A, B) = B -bool BroadCastToLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadCastToLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); reporter->Assign(types[2], types[1]); return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter); } -Expr MakeBroadCastToLike(Expr data, - Expr broadcast_type) { +Expr MakeBroadCastToLike(Expr data, Expr broadcast_type) { static const Op& op = Op::Get("broadcast_to_like"); return Call(op, {data, broadcast_type}, Attrs(), {}); } -Array BroadCastToLikeCompute(const Attrs& attrs, - const Array& inputs, +Array BroadCastToLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); - return { topi::broadcast_to(inputs[0], out_ttype->shape) }; + return {topi::broadcast_to(inputs[0], out_ttype->shape)}; } -TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like") -.set_body_typed(MakeBroadCastToLike); +TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like").set_body_typed(MakeBroadCastToLike); RELAY_REGISTER_OP("broadcast_to_like") -.describe(R"code(Broadcast the first input to match the shape of the second input. + .describe(R"code(Broadcast the first input to match the shape of the second input. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.") -.set_support_level(10) -.add_type_rel("BroadCastToLike", BroadCastToLikeRel) -.set_attr("FTVMCompute", BroadCastToLikeCompute) -.set_attr("TOpPattern", kBroadcast); - + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.") + .set_support_level(10) + .add_type_rel("BroadCastToLike", BroadCastToLikeRel) + .set_attr("FTVMCompute", BroadCastToLikeCompute) + .set_attr("TOpPattern", kBroadcast); // Adapter function to make int array. Array GetIntArray(Array arr) { for (size_t i = 0; i < arr.size(); ++i) { - CHECK(!arr[i].defined() || arr[i].as()) - << "Expect an int array"; + CHECK(!arr[i].defined() || arr[i].as()) << "Expect an int array"; } - return Downcast >(arr); + return Downcast>(arr); } - // strided_slice TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); -bool StridedSliceRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; - const StridedSliceAttrs *param = attrs.as(); + const StridedSliceAttrs* param = attrs.as(); CHECK(param != nullptr); auto dshape = data->shape; @@ -1838,12 +1686,8 @@ bool StridedSliceRel(const Array& types, int64_t begin_v = begin_vec[i]; int64_t end_v = end_vec[i]; - if ((stride_v == 1 && - begin_v == 0 && - end_v == max_range) || - (stride_v == -1 && - begin_v == max_range && - end_v == 0)) { + if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || + (stride_v == -1 && begin_v == max_range && end_v == 0)) { // Quick path, do not slice this dimension. oshape[i] = dshape[i]; continue; @@ -1852,8 +1696,7 @@ bool StridedSliceRel(const Array& types, // Require concrete integer as symbolic inference of min/max // can get complicated and not very helpful. const int64_t* p_dim_size = tir::as_const_int(dshape[i]); - CHECK(p_dim_size) - << "strided_slice requires sliced dimension to be concrete int"; + CHECK(p_dim_size) << "strided_slice requires sliced dimension to be concrete int"; int64_t dim_size = p_dim_size[0]; begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; end_v = (end_v < 0) ? dim_size + end_v : end_v; @@ -1861,16 +1704,14 @@ bool StridedSliceRel(const Array& types, int64_t slice_range, step; if (stride_v < 0) { if (end_v < -1) end_v = -1; - CHECK_LT(end_v, begin_v) - << "strided_slice get empty slice at axis " << i; + CHECK_LT(end_v, begin_v) << "strided_slice get empty slice at axis " << i; begin_v = std::min(dim_size - 1, begin_v); slice_range = begin_v - end_v; step = -stride_v; } else { if (begin_v < 0) begin_v = 0; CHECK_GE(stride_v, 0); - CHECK_LT(begin_v, end_v) - << "strided_slice get empty slice at axis " << i; + CHECK_LT(begin_v, end_v) << "strided_slice get empty slice at axis " << i; end_v = std::min(dim_size, end_v); slice_range = end_v - begin_v; step = stride_v; @@ -1881,13 +1722,10 @@ bool StridedSliceRel(const Array& types, return true; } - -Array > StridedSliceInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array& old_in_types) { - +Array> StridedSliceInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { Array> old_in_shapes; for (auto old_in_t : old_in_types) { CHECK(old_in_t.as()); @@ -1906,7 +1744,7 @@ Array > StridedSliceInferCorrectLayout( auto shape = old_in_shapes[0]; // NOTE: Discard "const" qualifier here. - auto *params = const_cast(attrs.as()); + auto* params = const_cast(attrs.as()); Array new_begin, new_end; @@ -1929,8 +1767,8 @@ Array > StridedSliceInferCorrectLayout( } } int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0; - int64_t end = params->end[i].defined() ? params->end[i]->value : - shape[i].as()->value; + int64_t end = + params->end[i].defined() ? params->end[i]->value : shape[i].as()->value; if (begin % factor || end % factor) { // transform to original layout return {{Layout::Undef()}, {Layout::Undef()}}; @@ -1946,12 +1784,8 @@ Array > StridedSliceInferCorrectLayout( return {{layout}, {layout}}; } - // Positional relay function to create StridedSlice operator used by frontend FFI. -Expr MakeStridedSlice(Expr data, - Array begin, - Array end, - Array strides) { +Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides) { auto attrs = make_object(); attrs->begin = std::move(begin); attrs->end = std::move(end); @@ -1960,20 +1794,15 @@ Expr MakeStridedSlice(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -Array StridedSliceCompute(const Attrs& attrs, - const Array& inputs, +Array StridedSliceCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const StridedSliceAttrs *param = attrs.as(); + const StridedSliceAttrs* param = attrs.as(); CHECK(param != nullptr); return Array{ - topi::strided_slice(inputs[0], param->begin, param->end, param->strides) - }; + topi::strided_slice(inputs[0], param->begin, param->end, param->strides)}; } - -TVM_REGISTER_GLOBAL("relay.op._make.strided_slice") -.set_body_typed(MakeStridedSlice); - +TVM_REGISTER_GLOBAL("relay.op._make.strided_slice").set_body_typed(MakeStridedSlice); RELAY_REGISTER_OP("strided_slice") .describe(R"code(Strided slice of an array. @@ -1999,40 +1828,32 @@ Examples:: [[ 5., 6.], [ 7., 8.]]] )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(4) -.set_attrs_type() -.add_type_rel("StridedSlice", StridedSliceRel) -.set_attr("FTVMCompute", StridedSliceCompute) -.set_attr("TOpPattern", kInjective) -.set_attr("FInferCorrectLayout", StridedSliceInferCorrectLayout); + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(4) + .set_attrs_type() + .add_type_rel("StridedSlice", StridedSliceRel) + .set_attr("FTVMCompute", StridedSliceCompute) + .set_attr("TOpPattern", kInjective) + .set_attr("FInferCorrectLayout", StridedSliceInferCorrectLayout); // strided_set -bool StridedSetRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool StridedSetRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 6); reporter->Assign(types[5], types[0]); return true; } -Expr MakeStridedSet(Expr data, - Expr v, - Expr begin, - Expr end, - Expr strides) { +Expr MakeStridedSet(Expr data, Expr v, Expr begin, Expr end, Expr strides) { static const Op& op = Op::Get("strided_set"); return Call(op, {data, v, begin, end, strides}, {}); } -TVM_REGISTER_GLOBAL("relay.op._make.strided_set") -.set_body_typed(MakeStridedSet); - +TVM_REGISTER_GLOBAL("relay.op._make.strided_set").set_body_typed(MakeStridedSet); RELAY_REGISTER_OP("strided_set") - .describe(R"code(Strided set of an array. + .describe(R"code(Strided set of an array. Example:: x = [[ 1., 4., 7., 10.], @@ -2047,22 +1868,20 @@ Example:: [ 2., 44., 55., 66.], [ 3., 6., 9., 12.]] )code" TVM_ADD_FILELINE) -.set_num_inputs(5) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("v", "Tensor", "The data to set.") -.add_argument("begin", "Tensor", "Indices for the start of the slice.") -.add_argument("end", "Tensor", "Indices indicating the end of the slice.") -.add_argument("strides", "Tensor", "The strides values.") -.set_support_level(4) -.set_attr("TOpPattern", kInjective) -.add_type_rel("StridedSet", StridedSetRel); + .set_num_inputs(5) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("v", "Tensor", "The data to set.") + .add_argument("begin", "Tensor", "Indices for the start of the slice.") + .add_argument("end", "Tensor", "Indices indicating the end of the slice.") + .add_argument("strides", "Tensor", "The strides values.") + .set_support_level(4) + .set_attr("TOpPattern", kInjective) + .add_type_rel("StridedSet", StridedSetRel); // relay.split TVM_REGISTER_NODE_TYPE(SplitAttrs); -bool SplitRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); @@ -2075,21 +1894,19 @@ bool SplitRel(const Array& types, if (axis < 0) { axis += data->shape.size(); } - CHECK_LT(axis, data->shape.size()) - << "axis should be within the input dimension range."; - CHECK_GE(axis, 0) - << "axis should be within the input dimension range."; + CHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range."; + CHECK_GE(axis, 0) << "axis should be within the input dimension range."; if (const IntImmNode* sections = param->indices_or_sections.as()) { - CHECK(reporter->Assert(indexmod(data->shape[axis], - sections->value) == tir::make_zero(DataType::Int(64)))) + CHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == + tir::make_zero(DataType::Int(64)))) << "indices_or_sections need to be able to divide input.shape[axis]"; std::vector fields; for (int i = 0; i < sections->value; ++i) { - std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] = indexdiv(oshape[axis], sections->value); - auto vec_type = TensorType(oshape, data->dtype); - fields.push_back(vec_type); + std::vector oshape(data->shape.begin(), data->shape.end()); + oshape[axis] = indexdiv(oshape[axis], sections->value); + auto vec_type = TensorType(oshape, data->dtype); + fields.push_back(vec_type); } reporter->Assign(types[1], TupleType(Array(fields))); } else { @@ -2116,25 +1933,21 @@ bool SplitRel(const Array& types, return true; } -Array SplitCompute(const Attrs& attrs, - const Array& inputs, +Array SplitCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto param = attrs.as(); CHECK(param != nullptr); if (const IntImmNode* sections = param->indices_or_sections.as()) { int64_t num_sections = sections->value; - return Array{ - topi::split_sections(inputs[0], num_sections, param->axis) }; + return Array{topi::split_sections(inputs[0], num_sections, param->axis)}; } else { - auto indices = Downcast >(param->indices_or_sections); - return Array{ topi::split(inputs[0], indices, param->axis) }; + auto indices = Downcast>(param->indices_or_sections); + return Array{topi::split(inputs[0], indices, param->axis)}; } } -Expr MakeSplit(Expr data, - ObjectRef indices_or_sections, - int axis) { +Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { auto attrs = make_object(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); @@ -2142,22 +1955,20 @@ Expr MakeSplit(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.split") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - if (args.type_codes[1] == kDLInt) { - // Note: we change it from Int(64) to Int(32) for now as - // combine_parallel_dense will transform the graph with Int(32). - // More invetigation is needs to check which one we should use. - *rv = MakeSplit(args[0], - tir::make_const(DataType::Int(32), static_cast(args[1])), - args[2]); - } else { - *rv = MakeSplit(args[0], args[1], args[2]); - } +TVM_REGISTER_GLOBAL("relay.op._make.split").set_body([](const TVMArgs& args, TVMRetValue* rv) { + if (args.type_codes[1] == kDLInt) { + // Note: we change it from Int(64) to Int(32) for now as + // combine_parallel_dense will transform the graph with Int(32). + // More invetigation is needs to check which one we should use. + *rv = + MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast(args[1])), args[2]); + } else { + *rv = MakeSplit(args[0], args[1], args[2]); + } }); RELAY_REGISTER_OP("split") -.describe(R"code(Splits an array along a particular axis into multiple sub-arrays. + .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. Indices or sections to split into. Accepts an int or a tuple If indices_or_sections is an integer, the input will be divided equally @@ -2167,29 +1978,26 @@ If indices_or_sections is a tuple of sorted integers, the entries indicate where along axis the array is split. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Split", SplitRel) -.set_attr("FTVMCompute", SplitCompute) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Split", SplitRel) + .set_attr("FTVMCompute", SplitCompute) + .set_attr("TOpPattern", kInjective); // relay.slice_like TVM_REGISTER_NODE_TYPE(SliceLikeAttrs); /*! -* \brief SliceLikeRel User defined type constraint function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return False if the relation has not been resolved, it might be resolved later. -* True if this relation has been resolved. -*/ -bool SliceLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief SliceLikeRel User defined type constraint function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return False if the relation has not been resolved, it might be resolved later. + * True if this relation has been resolved. + */ +bool SliceLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -2214,8 +2022,8 @@ bool SliceLikeRel(const Array& types, if (i < target_shape.size()) { oshape[i] = target_shape[i]; CHECK(reporter->Assert(oshape[i] <= dshape[i])) - << "End index of axis " << i << " exceeds input shape: " - << oshape[i] << " vs " << dshape[i]; + << "End index of axis " << i << " exceeds input shape: " << oshape[i] << " vs " + << dshape[i]; } } } else { @@ -2226,12 +2034,11 @@ bool SliceLikeRel(const Array& types, axis += dshape.size(); } CHECK(axis < static_cast(target_shape.size())) - << "Axis " << axis << " exceeds dimension " - << target_shape.size() << " of target_shape."; + << "Axis " << axis << " exceeds dimension " << target_shape.size() << " of target_shape."; oshape[axis] = target_shape[axis]; CHECK(reporter->Assert(oshape[axis] <= dshape[axis])) - << "End index of axis " << axis << " exceeds input shape: " - << oshape[axis] << " vs " << dshape[axis]; + << "End index of axis " << axis << " exceeds input shape: " << oshape[axis] << " vs " + << dshape[axis]; } } @@ -2239,18 +2046,14 @@ bool SliceLikeRel(const Array& types, return true; } - -Expr MakeSliceLike(Expr data, - Expr shape_like, - Array axes) { +Expr MakeSliceLike(Expr data, Expr shape_like, Array axes) { auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("slice_like"); return Call(op, {data, shape_like}, Attrs(attrs), {}); } -Array SliceLikeCompute(const Attrs& attrs, - const Array& inputs, +Array SliceLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); @@ -2266,11 +2069,10 @@ Array SliceLikeCompute(const Attrs& attrs, for (size_t i = 0; i < src_shape.size(); ++i) { if (i < target_shape.size()) { end_idx.Set(i, target_shape[i]); - CHECK_LE(topi::GetConstInt(end_idx[i]), - topi::GetConstInt(src_shape[i])) - << "End index of axis " << i << " exceeds input shape: " - << topi::GetConstInt(end_idx[i]) << " vs " - << topi::GetConstInt(src_shape[i]); + CHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i])) + << "End index of axis " << i + << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs " + << topi::GetConstInt(src_shape[i]); } } } else { @@ -2279,60 +2081,46 @@ Array SliceLikeCompute(const Attrs& attrs, axis = static_cast(src_shape.size()) + axis; } end_idx.Set(axis, target_shape[axis]); - CHECK_LE(topi::GetConstInt(end_idx[axis]), - topi::GetConstInt(src_shape[axis])) - << "End index of axis " << axis << " exceeds input shape: " - << topi::GetConstInt(end_idx[axis]) << " vs " - << topi::GetConstInt(src_shape[axis]); + CHECK_LE(topi::GetConstInt(end_idx[axis]), topi::GetConstInt(src_shape[axis])) + << "End index of axis " << axis + << " exceeds input shape: " << topi::GetConstInt(end_idx[axis]) << " vs " + << topi::GetConstInt(src_shape[axis]); } } - return Array{ - topi::strided_slice(inputs[0], - GetIntArray(begin_idx), - GetIntArray(end_idx), - GetIntArray(strides)) - }; + return Array{topi::strided_slice(inputs[0], GetIntArray(begin_idx), + GetIntArray(end_idx), GetIntArray(strides))}; } - -TVM_REGISTER_GLOBAL("relay.op._make.slice_like") -.set_body_typed(MakeSliceLike); - +TVM_REGISTER_GLOBAL("relay.op._make.slice_like").set_body_typed(MakeSliceLike); RELAY_REGISTER_OP("slice_like") -.describe(R"code(Slice the first input respect to the second input. + .describe(R"code(Slice the first input respect to the second input. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("shape_like", "Tensor", "Shape tensor.") -.set_support_level(10) -.add_type_rel("SliceLike", SliceLikeRel) -.set_attr("FTVMCompute", SliceLikeCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape_like", "Tensor", "Shape tensor.") + .set_support_level(10) + .add_type_rel("SliceLike", SliceLikeRel) + .set_attr("FTVMCompute", SliceLikeCompute) + .set_attr("TOpPattern", kInjective); // relay.layout_transform TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); -Array LayoutTransformCompute(const Attrs& attrs, - const Array& inputs, +Array LayoutTransformCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array{ - topi::layout_transform(inputs[0], param->src_layout, param->dst_layout) - }; + return Array{topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)}; } -bool LayoutTransformRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool LayoutTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "LayoutTransform: expect input data type to be TensorType but get " - << types[0]; + << "LayoutTransform: expect input data type to be TensorType but get " << types[0]; return false; } const LayoutTransformAttrs* params = attrs.as(); @@ -2340,20 +2128,17 @@ bool LayoutTransformRel(const Array& types, Layout src_layout(params->src_layout); Layout dst_layout(params->dst_layout); - CHECK(src_layout.defined() && dst_layout.defined()) - << "cannot convert from/to undefined layout"; + CHECK(src_layout.defined() && dst_layout.defined()) << "cannot convert from/to undefined layout"; auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout); CHECK(layout_converter.defined()) - << "cannot convert from " << params->src_layout << " to " << params->dst_layout; + << "cannot convert from " << params->src_layout << " to " << params->dst_layout; const auto& out_shape = layout_converter.ForwardShape(data->shape); reporter->Assign(types[1], TensorType(out_shape, data->dtype)); return true; } -Expr MakeLayoutTransform(Expr data, - std::string src_layout, - std::string dst_layout) { +Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout) { auto attrs = make_object(); attrs->src_layout = std::move(src_layout); attrs->dst_layout = std::move(dst_layout); @@ -2361,27 +2146,24 @@ Expr MakeLayoutTransform(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.layout_transform") -.set_body_typed(MakeLayoutTransform); +TVM_REGISTER_GLOBAL("relay.op._make.layout_transform").set_body_typed(MakeLayoutTransform); RELAY_REGISTER_OP("layout_transform") -.describe(R"code(Transform the input data layout. + .describe(R"code(Transform the input data layout. For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("layout_transform", LayoutTransformRel) -.set_support_level(5) -.set_attr("FTVMCompute", LayoutTransformCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("layout_transform", LayoutTransformRel) + .set_support_level(5) + .set_attr("FTVMCompute", LayoutTransformCompute); /* relay._contrib_reverse_reshape */ -Expr MakeReverseReshape(Expr data, - Array newshape) { +Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = true; @@ -2389,11 +2171,10 @@ Expr MakeReverseReshape(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape") -.set_body_typed(MakeReverseReshape); +TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape").set_body_typed(MakeReverseReshape); RELAY_REGISTER_OP("_contrib_reverse_reshape") -.describe(R"code(Reshapes the input array where the special values are inferred from + .describe(R"code(Reshapes the input array where the special values are inferred from right to left. Example:: @@ -2406,18 +2187,16 @@ example below:: - data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5) )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("Reshape", ReshapeRel) -.set_attr("FTVMCompute", ReshapeCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("Reshape", ReshapeRel) + .set_attr("FTVMCompute", ReshapeCompute) + .set_attr("TOpPattern", kInjective); // gather_nd operator -bool GatherNDRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GatherNDRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, indices, result] CHECK_EQ(types.size(), 3); @@ -2425,48 +2204,40 @@ bool GatherNDRel(const Array& types, const auto* indices = types[1].as(); if (data == nullptr) { CHECK(types[0].as()) - << "GatherND: expect input data type to be TensorType but get " - << types[0]; + << "GatherND: expect input data type to be TensorType but get " << types[0]; return false; } if (indices == nullptr) { CHECK(types[1].as()) - << "GatherND: expect indices type to be TensorType but get " - << types[1]; + << "GatherND: expect indices type to be TensorType but get " << types[1]; return false; } const size_t ndim = data->shape.size(); const IntImmNode* mdim = indices->shape[0].as(); const size_t kdim = indices->shape.size() - 1; - CHECK(size_t(mdim->value) <= ndim) - << "GatherND: indices shape does satisfy."; + CHECK(size_t(mdim->value) <= ndim) << "GatherND: indices shape does satisfy."; Array oshape; - for (size_t i = 1; i < kdim + 1; ++i) - oshape.push_back(indices->shape[i]); - for (size_t i = mdim->value; i < ndim; ++i) - oshape.push_back(data->shape[i]); + for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]); + for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]); reporter->Assign(types[2], TensorType(oshape, data->dtype)); return true; } -Array GatherNDCompute(const Attrs& attrs, - const Array& inputs, +Array GatherNDCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::gather_nd(inputs[0], inputs[1]) }; + return {topi::gather_nd(inputs[0], inputs[1])}; } -Expr MakeGatherND(Expr data, - Expr indices) { +Expr MakeGatherND(Expr data, Expr indices) { static const Op& op = Op::Get("gather_nd"); return Call(op, {data, indices}, {}); } -TVM_REGISTER_GLOBAL("relay.op._make.gather_nd") -.set_body_typed(MakeGatherND); +TVM_REGISTER_GLOBAL("relay.op._make.gather_nd").set_body_typed(MakeGatherND); RELAY_REGISTER_OP("gather_nd") -.describe(R"code(Gather elements or slices from data and store to + .describe(R"code(Gather elements or slices from data and store to a tensor whose shape is defined by indices. Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with @@ -2474,19 +2245,17 @@ shape (M, Y_0, ..., Y_{K-1}), the output will have shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N, output shape will simply be (Y_0, ..., Y_{K-1}). )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("GatherND", GatherNDRel) -.set_attr("FTVMCompute", GatherNDCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("GatherND", GatherNDRel) + .set_attr("FTVMCompute", GatherNDCompute) + .set_attr("TOpPattern", kInjective); // relay.sequence_mask TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs); -bool SequenceMaskRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SequenceMaskRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, valid_length, result] CHECK_EQ(types.size(), 3); @@ -2503,19 +2272,15 @@ bool SequenceMaskRel(const Array& types, return true; } -Array SequenceMaskCompute(const Attrs& attrs, - const Array& inputs, +Array SequenceMaskCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); return Array{ - topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) }; + topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis)}; } -Expr MakeSequenceMask(Expr data, - Expr valid_length, - double mask_value, - int axis) { +Expr MakeSequenceMask(Expr data, Expr valid_length, double mask_value, int axis) { auto attrs = make_object(); attrs->mask_value = std::move(mask_value); attrs->axis = std::move(axis); @@ -2523,11 +2288,11 @@ Expr MakeSequenceMask(Expr data, return Call(op, {data, valid_length}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask") -.set_body_typed(MakeSequenceMask); +TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask").set_body_typed(MakeSequenceMask); RELAY_REGISTER_OP("sequence_mask") -.describe(R"code(Sets all elements outside the expected length of the sequence to a constant value. + .describe( + R"code(Sets all elements outside the expected length of the sequence to a constant value. This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or [batch_size, MAX_LENGTH, ...] and returns an array of the same shape. @@ -2575,21 +2340,19 @@ Examples:: [[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.") -.set_support_level(10) -.add_type_rel("SequenceMask", SequenceMaskRel) -.set_attr("FTVMCompute", SequenceMaskCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.") + .set_support_level(10) + .add_type_rel("SequenceMask", SequenceMaskRel) + .set_attr("FTVMCompute", SequenceMaskCompute) + .set_attr("TOpPattern", kInjective); // relay.one_hot TVM_REGISTER_NODE_TYPE(OneHotAttrs); -bool OneHotRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool OneHotRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [indices, on_value, off_value, result] CHECK_EQ(types.size(), 4); @@ -2615,27 +2378,15 @@ bool OneHotRel(const Array& types, return true; } -Array OneHotCompute(const Attrs& attrs, - const Array& inputs, +Array OneHotCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array { - topi::one_hot(inputs[0], - inputs[1](), - inputs[2](), - param->depth, - param->axis, - param->dtype) - }; -} - -Expr MakeOneHot(Expr indices, - Expr on_value, - Expr off_value, - int depth, - int axis, - DataType dtype) { + return Array{ + topi::one_hot(inputs[0], inputs[1](), inputs[2](), param->depth, param->axis, param->dtype)}; +} + +Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype) { auto attrs = make_object(); attrs->depth = std::move(depth); attrs->axis = axis; @@ -2644,11 +2395,10 @@ Expr MakeOneHot(Expr indices, return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.one_hot") -.set_body_typed(MakeOneHot); +TVM_REGISTER_GLOBAL("relay.op._make.one_hot").set_body_typed(MakeOneHot); RELAY_REGISTER_OP("one_hot") -.describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1, + .describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1, other locations take value 0. Final dimension is x depth. **indices** Locations to set to 1. @@ -2662,42 +2412,36 @@ RELAY_REGISTER_OP("one_hot") **axis** Axis to fill. **dtype**)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("indices", "Tensor", "Locations to set to on_value.") -.add_argument("on_value", "Expr", "Value to fill at indices.") -.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.") -.set_support_level(10) -.add_type_rel("OneHot", OneHotRel) -.set_attr("FTVMCompute", OneHotCompute) -.set_attr("TOpPattern", kOutEWiseFusable); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("indices", "Tensor", "Locations to set to on_value.") + .add_argument("on_value", "Expr", "Value to fill at indices.") + .add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.") + .set_support_level(10) + .add_type_rel("OneHot", OneHotRel) + .set_attr("FTVMCompute", OneHotCompute) + .set_attr("TOpPattern", kOutEWiseFusable); /* relay.unravel_index */ -bool UnRavelIndexRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool UnRavelIndexRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* indices = types[0].as(); if (indices == nullptr) { CHECK(types[0].as()) - << "unravel_index: expect input type to be TensorType but get " - << types[0]; + << "unravel_index: expect input type to be TensorType but get " << types[0]; return false; } - CHECK(indices->dtype.is_int()) - << "indices of unravel_index must be tensor of integer"; + CHECK(indices->dtype.is_int()) << "indices of unravel_index must be tensor of integer"; const auto* shape = types[1].as(); if (shape == nullptr) { CHECK(types[1].as()) - << "unravel_index: expect input type to be TensorType but get " - << types[1]; + << "unravel_index: expect input type to be TensorType but get " << types[1]; return false; } - CHECK(indices->dtype.is_int()) - << "shape of unravel_index must be tensor of integer"; + CHECK(indices->dtype.is_int()) << "shape of unravel_index must be tensor of integer"; Array indices_shape; Array shape_shape; @@ -2713,32 +2457,30 @@ bool UnRavelIndexRel(const Array& types, return true; } -Array UnRavelIndexCompute(const Attrs& attrs, - const Array& inputs, +Array UnRavelIndexCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return Array{topi::unravel_index(inputs[0], inputs[1])}; } -Expr MakeUnRavelIndex(Expr data, - Expr shape) { +Expr MakeUnRavelIndex(Expr data, Expr shape) { static const Op& op = Op::Get("unravel_index"); return Call(op, {data, shape}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.unravel_index") -.set_body_typed(MakeUnRavelIndex); +TVM_REGISTER_GLOBAL("relay.op._make.unravel_index").set_body_typed(MakeUnRavelIndex); RELAY_REGISTER_OP("unravel_index") -.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays. + .describe( + R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays. Example:: - unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]] )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.set_support_level(3) -.add_type_rel("UnRavelIndexRel", UnRavelIndexRel) -.set_attr("FTVMCompute", UnRavelIndexCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(2) + .set_support_level(3) + .add_type_rel("UnRavelIndexRel", UnRavelIndexRel) + .set_attr("FTVMCompute", UnRavelIndexCompute) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index a64dcd5..62433c2 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -26,33 +26,32 @@ #include #include -#include +#include + #include #include #include #include #include +#include namespace tvm { namespace relay { template -bool ConcatenateRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); /* If we receive a tuple we can continue, if we receive * anything but an incomplete type we should signal an * error. - */ + */ const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { throw Error( - ErrorBuilder() - << "concatenate requires a tuple of tensors as the first argument, found " - << PrettyPrint(types[0])); + ErrorBuilder() << "concatenate requires a tuple of tensors as the first argument, found " + << PrettyPrint(types[0])); } else if (types[0].as() != nullptr) { return false; } @@ -69,10 +68,8 @@ bool ConcatenateRel(const Array& types, // Sanity check: axis int axis = param->axis; if (!(-ndim <= axis && axis < ndim)) { - throw Error(ErrorBuilder() << - "concatenate only accepts `axis` in [-ndim, ndim)" << - ", but got axis = " << axis << - ", and ndim = " << ndim); + throw Error(ErrorBuilder() << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim); } axis = axis < 0 ? ndim + axis : axis; @@ -94,14 +91,15 @@ bool ConcatenateRel(const Array& types, for (size_t j = 0; j < first->shape.size(); ++j) { if (j == static_cast(axis)) continue; if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue; - throw Error("relay.concatenate requires all tensors have the same shape " - "on non-concatenating axes"); + throw Error( + "relay.concatenate requires all tensors have the same shape " + "on non-concatenating axes"); } } // Calculate shape std::vector oshape(first->shape.begin(), first->shape.end()); - IndexExpr &concat_dim = oshape[axis]; + IndexExpr& concat_dim = oshape[axis]; bool has_any = false; if (concat_dim.as()) { has_any = true; @@ -125,11 +123,10 @@ bool ConcatenateRel(const Array& types, return true; } -static inline Array> ConcatenateLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +static inline Array> ConcatenateLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { ConcatenateAttrs* param = const_cast(attrs.as()); Array> old_in_shapes; @@ -141,8 +138,8 @@ static inline Array> ConcatenateLayout( } } - size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : - static_cast(param->axis); + size_t axis = + param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast(param->axis); Layout ret; bool is_new_layout_selected = false; @@ -175,11 +172,11 @@ static inline Array> ConcatenateLayout( } if (ret.ndim() <= axis || !ret[axis].IsPrimal()) { - return Array > {{Layout::Undef()}, {Layout::Undef()}}; + return Array>{{Layout::Undef()}, {Layout::Undef()}}; } } - return Array > {Array(old_in_layouts.size(), ret), {ret}}; + return Array>{Array(old_in_layouts.size(), ret), {ret}}; } } // namespace relay diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 5714379..ccf6dd2 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -21,112 +21,101 @@ * \file unary.cc * \brief Unary operators. */ -#include -#include -#include #include #include -#include "../type_relations.h" +#include +#include +#include + #include "../op_common.h" +#include "../type_relations.h" namespace tvm { namespace relay { -#define RELAY_UNARY_COMPUTE(FTOPI) \ - [] (const Attrs& attrs, \ - const Array& inputs, \ - const Type& out_type) -> Array { \ - return {FTOPI(inputs[0])}; \ - } \ - +#define RELAY_UNARY_COMPUTE(FTOPI) \ + [](const Attrs& attrs, const Array& inputs, \ + const Type& out_type) -> Array { return {FTOPI(inputs[0])}; } RELAY_REGISTER_UNARY_OP("log") -.describe(R"code(Returns the log input array, computed element-wise. + .describe(R"code(Returns the log input array, computed element-wise. .. math:: log(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); RELAY_REGISTER_UNARY_OP("log2") -.describe(R"code(Returns the log to base 2 of input array, computed element-wise. + .describe(R"code(Returns the log to base 2 of input array, computed element-wise. .. math:: log2(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2)); RELAY_REGISTER_UNARY_OP("log10") -.describe(R"code(Returns the log to base 10 of input array, computed element-wise. + .describe(R"code(Returns the log to base 10 of input array, computed element-wise. .. math:: log10(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10)); RELAY_REGISTER_UNARY_OP("tan") -.describe(R"code(Returns the tan of input array, computed element-wise. + .describe(R"code(Returns the tan of input array, computed element-wise. .. math:: Y = tan(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan)); RELAY_REGISTER_UNARY_OP("cos") -.describe(R"code(Returns the cos of input array, computed element-wise. + .describe(R"code(Returns the cos of input array, computed element-wise. .. math:: Y = cos(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos)); RELAY_REGISTER_UNARY_OP("cosh") -.describe(R"code(Returns the cosh of input array, computed element-wise. + .describe(R"code(Returns the cosh of input array, computed element-wise. .. math:: Y = cosh(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh)); RELAY_REGISTER_UNARY_OP("sin") -.describe(R"code(Returns the sin of input array, computed element-wise. + .describe(R"code(Returns the sin of input array, computed element-wise. .. math:: Y = sin(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin)); RELAY_REGISTER_UNARY_OP("sinh") -.describe(R"code(Returns the sinh of input array, computed element-wise. + .describe(R"code(Returns the sinh of input array, computed element-wise. .. math:: Y = sinh(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh)); RELAY_REGISTER_UNARY_OP("acos") .describe(R"code(Returns the acos of input array, computed element-wise. @@ -173,15 +162,14 @@ RELAY_REGISTER_UNARY_OP("asinh") RELAY_REGISTER_UNARY_OP("atan") -.describe(R"code(Returns the atan of input array, computed element-wise. + .describe(R"code(Returns the atan of input array, computed element-wise. .. math:: Y = atan(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atan)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atan)); RELAY_REGISTER_UNARY_OP("atanh") .describe(R"code(Returns the atanh of input array, computed element-wise. @@ -195,243 +183,225 @@ RELAY_REGISTER_UNARY_OP("atanh") RELAY_REGISTER_UNARY_OP("exp") -.describe(R"code(Returns the exp input array, computed element-wise. + .describe(R"code(Returns the exp input array, computed element-wise. .. math:: \exp(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); RELAY_REGISTER_UNARY_OP("fast_exp") -.describe(R"code(Returns the fast_exp input array, computed element-wise. + .describe(R"code(Returns the fast_exp input array, computed element-wise. .. math:: \fast_exp(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp)); RELAY_REGISTER_UNARY_OP("erf") -.describe(R"code(Returns the error function value for input array, computed element-wise. + .describe(R"code(Returns the error function value for input array, computed element-wise. .. math:: \erf(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf)); RELAY_REGISTER_UNARY_OP("fast_erf") -.describe(R"code(Returns the error function value for input array, computed element-wise. + .describe(R"code(Returns the error function value for input array, computed element-wise. .. math:: \fast_erf(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf)); RELAY_REGISTER_UNARY_OP("sqrt") -.describe(R"code(Returns the sqrt input array, computed element-wise. + .describe(R"code(Returns the sqrt input array, computed element-wise. .. math:: sqrt(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); RELAY_REGISTER_UNARY_OP("rsqrt") -.describe(R"code(Returns the rsqrt input array, computed element-wise. + .describe(R"code(Returns the rsqrt input array, computed element-wise. .. math:: 1/sqrt(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt)); RELAY_REGISTER_UNARY_OP("zeros_like") -.describe(R"code(Returns an array of zeros, with same type and shape as the input. + .describe(R"code(Returns an array of zeros, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(4); + .set_support_level(4); RELAY_REGISTER_UNARY_OP("ones_like") -.describe(R"code(Returns an array of ones, with same type and shape as the input. + .describe(R"code(Returns an array of ones, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(4); + .set_support_level(4); RELAY_REGISTER_UNARY_OP("sigmoid") -.describe(R"code(Returns the sigmoid input array, computed element-wise. + .describe(R"code(Returns the sigmoid input array, computed element-wise. .. math:: sigmoid(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); RELAY_REGISTER_UNARY_OP("copy") -.describe(R"code(Copy a tensor. + .describe(R"code(Copy a tensor. )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); // relay.clip TVM_REGISTER_NODE_TYPE(ClipAttrs); -TVM_REGISTER_GLOBAL("relay.op._make.clip") -.set_body_typed([](Expr a, double a_min, double a_max) { - auto attrs = make_object(); - attrs->a_min = a_min; - attrs->a_max = a_max; - static const Op& op = Op::Get("clip"); +TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed([](Expr a, double a_min, double a_max) { + auto attrs = make_object(); + attrs->a_min = a_min; + attrs->a_max = a_max; + static const Op& op = Op::Get("clip"); return Call(op, {a}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("clip") -.describe(R"code(Clip tensor values. + .describe(R"code(Clip tensor values. This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kElemWise) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attrs_type() -.set_support_level(3); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kElemWise) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attrs_type() + .set_support_level(3); RELAY_REGISTER_UNARY_OP("floor") -.describe(R"code(Returns the floor of input array, computed element-wise. + .describe(R"code(Returns the floor of input array, computed element-wise. )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); RELAY_REGISTER_UNARY_OP("ceil") -.describe(R"code(Returns the ceil of input array, computed element-wise. + .describe(R"code(Returns the ceil of input array, computed element-wise. .. math:: ceil(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); RELAY_REGISTER_UNARY_OP("trunc") -.describe(R"code(Returns the trunc of input array, computed element-wise. + .describe(R"code(Returns the trunc of input array, computed element-wise. .. math:: trunc(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); RELAY_REGISTER_UNARY_OP("round") -.describe(R"code(Returns the round of input array, computed element-wise. + .describe(R"code(Returns the round of input array, computed element-wise. .. math:: round(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); RELAY_REGISTER_UNARY_OP("sign") -.describe(R"code(Returns the sign of input array, computed element-wise. + .describe(R"code(Returns the sign of input array, computed element-wise. .. numpy:: sign(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign)); RELAY_REGISTER_UNARY_OP("abs") -.describe(R"code(Returns the abs of input array, computed element-wise. + .describe(R"code(Returns the abs of input array, computed element-wise. .. math:: abs(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); RELAY_REGISTER_UNARY_OP("tanh") -.describe(R"code(Returns the tanh of input array, computed element-wise. + .describe(R"code(Returns the tanh of input array, computed element-wise. .. math:: Y = sinh(X) / cosh(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); RELAY_REGISTER_UNARY_OP("fast_tanh") -.describe(R"code(Returns the fast_tanh of input array, computed element-wise. + .describe(R"code(Returns the fast_tanh of input array, computed element-wise. .. math:: Y = sinh(X) / cosh(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh)); RELAY_REGISTER_UNARY_OP("negative") -.describe(R"code(Returns the numeric negative of input array, computed element-wise. + .describe(R"code(Returns the numeric negative of input array, computed element-wise. .. math:: -(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); RELAY_REGISTER_UNARY_OP("logical_not") -.describe(R"code(Returns the logical inverse of input array, computed element-wise. + .describe(R"code(Returns the logical inverse of input array, computed element-wise. .. math:: !(x) )code" TVM_ADD_FILELINE) -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); - + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); RELAY_REGISTER_UNARY_OP("bitwise_not") -.describe(R"code(Returns the bitwise inverse of input array, computed element-wise. + .describe(R"code(Returns the bitwise inverse of input array, computed element-wise. .. math:: ~(x) )code" TVM_ADD_FILELINE) -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not)); - + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not)); // shape_of TVM_REGISTER_NODE_TYPE(ShapeOfAttrs); -bool ShapeOfRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ShapeOfRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); @@ -445,8 +415,7 @@ bool ShapeOfRel(const Array& types, return true; } -Array ShapeOfCompute(const Attrs& attrs, - const Array& inputs, +Array ShapeOfCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { CHECK_EQ(inputs.size(), 1); const auto* param = attrs.as(); @@ -454,8 +423,7 @@ Array ShapeOfCompute(const Attrs& attrs, return {topi::shape(inputs[0], param->dtype)}; } -TVM_REGISTER_GLOBAL("relay.op._make.shape_of") -.set_body_typed([](Expr data, DataType dtype) { +TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("shape_of"); @@ -463,29 +431,25 @@ TVM_REGISTER_GLOBAL("relay.op._make.shape_of") }); RELAY_REGISTER_OP("shape_of") -.describe(R"code(Returns a tensor representing the shape of a tensor. + .describe(R"code(Returns a tensor representing the shape of a tensor. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("ShapeOf", ShapeOfRel) -.set_attr("TOpIsStateful", false) -// Use kOpaque for shape_of op for now since it won't be performance critic, -// and it makes things easier for dynamic shape func -.set_attr("TOpPattern", kOpaque) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_support_level(10) -.set_attr("FTVMCompute", ShapeOfCompute); - + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("ShapeOf", ShapeOfRel) + .set_attr("TOpIsStateful", false) + // Use kOpaque for shape_of op for now since it won't be performance critic, + // and it makes things easier for dynamic shape func + .set_attr("TOpPattern", kOpaque) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_support_level(10) + .set_attr("FTVMCompute", ShapeOfCompute); TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs); -bool NdarraySizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { +bool NdarraySizeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); CHECK(tt != nullptr); @@ -495,8 +459,7 @@ bool NdarraySizeRel(const Array& types, return true; } -Array NdarraySizeCompute(const Attrs& attrs, - const Array& inputs, +Array NdarraySizeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { CHECK_EQ(inputs.size(), 1); const auto* param = attrs.as(); @@ -504,8 +467,7 @@ Array NdarraySizeCompute(const Attrs& attrs, return Array{topi::ndarray_size(inputs[0], param->dtype)}; } -TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size") -.set_body_typed([](Expr data, DataType dtype) { +TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size").set_body_typed([](Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("ndarray_size"); @@ -513,46 +475,45 @@ TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size") }); RELAY_REGISTER_OP("ndarray_size") -.describe(R"code(Returns a tensor representing the number of elements of input tensor. + .describe(R"code(Returns a tensor representing the number of elements of input tensor. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("NdarraySize", NdarraySizeRel) -.set_attr("TOpIsStateful", false) -.set_attr("TOpPattern", kInjective) -.set_attr("FInferCorrectLayout", -ElemwiseArbitraryLayout) -.set_support_level(10) -.set_attr("FTVMCompute", NdarraySizeCompute); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("NdarraySize", NdarraySizeRel) + .set_attr("TOpIsStateful", false) + .set_attr("TOpPattern", kInjective) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_support_level(10) + .set_attr("FTVMCompute", NdarraySizeCompute); RELAY_REGISTER_UNARY_OP("isnan") -.describe(R"code(Returns whether the input contains any NaN, computed element-wise. + .describe(R"code(Returns whether the input contains any NaN, computed element-wise. .. math:: isnan(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.add_type_rel("IdentityCompRel", IdentityCompRel) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan)); + .set_support_level(3) + .add_type_rel("IdentityCompRel", IdentityCompRel) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan)); RELAY_REGISTER_UNARY_OP("isfinite") -.describe(R"code(Returns the finiteness of input, computed element-wise. + .describe(R"code(Returns the finiteness of input, computed element-wise. .. math:: isfinite(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.add_type_rel("IdentityCompRel", IdentityCompRel) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite)); + .set_support_level(3) + .add_type_rel("IdentityCompRel", IdentityCompRel) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite)); RELAY_REGISTER_UNARY_OP("isinf") -.describe(R"code(Returns the infiniteness of input, computed element-wise. + .describe(R"code(Returns the infiniteness of input, computed element-wise. .. math:: isinf(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.add_type_rel("IdentityCompRel", IdentityCompRel) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf)); + .set_support_level(3) + .add_type_rel("IdentityCompRel", IdentityCompRel) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf)); } // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index e2e7f49..677683c 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -22,19 +22,19 @@ * \brief A set of utilities and common functionality * for type relations. */ +#include "./type_relations.h" + #include -#include #include #include +#include + #include -#include "./type_relations.h" namespace tvm { namespace relay { -bool IdentityRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { for (size_t i = 1; i < types.size(); ++i) { reporter->Assign(types[i], types[0]); @@ -42,8 +42,7 @@ bool IdentityRel(const Array& types, return true; } -bool EqualCheck(const IndexExpr& lhs, - const IndexExpr& rhs) { +bool EqualCheck(const IndexExpr& lhs, const IndexExpr& rhs) { IndexExpr diff = lhs - rhs; if (const int64_t* pdiff = tir::as_const_int(diff)) { return pdiff[0] == 0; @@ -64,9 +63,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) { return false; } -Type ConcreteBroadcast(const TensorType& t1, - const TensorType& t2, - DataType output_dtype) { +Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { std::vector oshape; size_t ndim1 = t1->shape.size(); size_t ndim2 = t2->shape.size(); @@ -87,9 +84,7 @@ Type ConcreteBroadcast(const TensorType& t1, } else if (EqualCheck(s1, s2)) { oshape.push_back(s1); } else { - throw Error(ErrorBuilder() - << "Incompatible broadcast type " - << t1 << " and " << t2); + throw Error(ErrorBuilder() << "Incompatible broadcast type " << t1 << " and " << t2); } } @@ -98,13 +93,10 @@ Type ConcreteBroadcast(const TensorType& t1, for (; i <= max_ndim; ++i) { oshape.push_back(rshape[max_ndim - i]); } - return TensorType(Array( - oshape.rbegin(), oshape.rend()), output_dtype); + return TensorType(Array(oshape.rbegin(), oshape.rend()), output_dtype); } -bool BroadcastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] @@ -112,17 +104,15 @@ bool BroadcastRel(const Array& types, if (auto* t0 = types[0].as()) { if (auto* t1 = types[1].as()) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], - ConcreteBroadcast(GetRef(t0), GetRef(t1), t0->dtype)); + reporter->Assign( + types[2], ConcreteBroadcast(GetRef(t0), GetRef(t1), t0->dtype)); return true; } } return false; } -bool BroadcastCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] @@ -130,17 +120,15 @@ bool BroadcastCompRel(const Array& types, if (auto* t0 = types[0].as()) { if (auto* t1 = types[1].as()) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], - ConcreteBroadcast(GetRef(t0), GetRef(t1), DataType::Bool())); + reporter->Assign(types[2], ConcreteBroadcast(GetRef(t0), GetRef(t1), + DataType::Bool())); return true; } } return false; } -bool IdentityCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool IdentityCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { if (auto* t0 = types[0].as()) { Type out_type = TensorType(GetRef(t0)->shape, DataType::Bool()); @@ -154,7 +142,7 @@ Array RankShape(const Array& shape) { if (shape.size() == 0) { return {}; } else { - return { tvm::Integer(shape.size()) }; + return {tvm::Integer(shape.size())}; } } diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 48a545b..acd4b2d 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -27,6 +27,7 @@ #include #include + #include namespace tvm { @@ -40,9 +41,7 @@ namespace relay { * \param reporter The reporter. * \return true whether relation has been resolved. */ -bool IdentityRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); /*! @@ -55,9 +54,7 @@ bool IdentityRel(const Array& types, * \param reporter The reporter. * \return true whether relation has been resolved. */ -bool BroadcastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); /*! @@ -74,15 +71,11 @@ bool BroadcastRel(const Array& types, * \param reporter The reporter. * \return true whether relation has been resolved. */ -bool BroadcastCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); -bool IdentityCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter); +bool IdentityCompRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter); Array RankShape(const Array& shape); diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index cafe9b6..18a2edb 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -21,46 +21,38 @@ * \file multibox_op.cc * \brief Multibox related operators */ -#include -#include #include +#include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(MultiBoxPriorAttrs); -bool MultiboxPriorRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool MultiboxPriorRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); const MultiBoxPriorAttrs* param = attrs.as(); const auto& dshape = data->shape; CHECK_EQ(dshape.size(), 4) << "Input data should be 4D: " - "[batch, channel, height, width]"; + "[batch, channel, height, width]"; IndexExpr in_height = dshape[2]; IndexExpr in_width = dshape[3]; int num_sizes = static_cast(param->sizes.size()); int num_ratios = static_cast(param->ratios.size()); // since input sizes are same in each batch, we could share MultiBoxPrior - std::vector oshape( - {1, in_height * in_width * (num_sizes + num_ratios - 1), 4}); + std::vector oshape({1, in_height * in_width * (num_sizes + num_ratios - 1), 4}); // assign output type reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } - -Expr MakeMultiBoxPrior(Expr data, - Array sizes, - Array ratios, - Array steps, - Array offsets, - bool clip) { +Expr MakeMultiBoxPrior(Expr data, Array sizes, Array ratios, + Array steps, Array offsets, bool clip) { auto attrs = make_object(); attrs->sizes = std::move(sizes); attrs->ratios = std::move(ratios); @@ -71,25 +63,20 @@ Expr MakeMultiBoxPrior(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior") -.set_body_typed(MakeMultiBoxPrior); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior").set_body_typed(MakeMultiBoxPrior); RELAY_REGISTER_OP("vision.multibox_prior") -.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." + .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." )doc" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(5) -.add_type_rel("MultiBoxPrior", MultiboxPriorRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("MultiBoxPrior", MultiboxPriorRel); TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs); -bool MultiBoxTransformLocRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool MultiBoxTransformLocRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); @@ -102,20 +89,15 @@ bool MultiBoxTransformLocRel(const Array& types, const auto& loc_shape = loc_pred->shape; const auto& anchor_shape = anchor->shape; - CHECK_EQ(cls_shape.size(), 3U) - << "The dimension of class probability should be 3, but received " - << cls_shape.size(); + CHECK_EQ(cls_shape.size(), 3U) << "The dimension of class probability should be 3, but received " + << cls_shape.size(); CHECK_EQ(loc_shape.size(), 2U) - << "The dimension of location prediction should be 2, but received " - << loc_shape.size(); + << "The dimension of location prediction should be 2, but received " << loc_shape.size(); CHECK_EQ(anchor_shape.size(), 3U) - << "The dimension of anchor should be 3, but received " - << anchor_shape.size(); + << "The dimension of anchor should be 3, but received " << anchor_shape.size(); - CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1])) - << "Number of anchors mismatch found"; - CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1])) - << "# anchors mismatch with # loc."; + CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1])) << "Number of anchors mismatch found"; + CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1])) << "# anchors mismatch with # loc."; CHECK(reporter->Assert(anchor_shape[1] > 0)) << "Number of anchors must > 0."; CHECK(reporter->AssertEQ(anchor_shape[2], 4)); @@ -130,12 +112,8 @@ bool MultiBoxTransformLocRel(const Array& types, return true; } -Expr MakeMultiBoxTransformLoc(Expr cls_prob, - Expr loc_pred, - Expr anchor, - bool clip, - double threshold, - Array variances) { +Expr MakeMultiBoxTransformLoc(Expr cls_prob, Expr loc_pred, Expr anchor, bool clip, + double threshold, Array variances) { auto attrs = make_object(); attrs->clip = std::move(clip); attrs->threshold = std::move(threshold); @@ -145,18 +123,18 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob, } TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_transform_loc") -.set_body_typed(MakeMultiBoxTransformLoc); + .set_body_typed(MakeMultiBoxTransformLoc); RELAY_REGISTER_OP("vision.multibox_transform_loc") -.describe(R"doc("Location transformation for multibox detection." + .describe(R"doc("Location transformation for multibox detection." )doc" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("cls_prob", "Tensor", "Class probabilities.") -.add_argument("loc_pred", "Tensor", "Location regression predictions.") -.add_argument("anchor", "Tensor", "Multibox prior anchor boxes") -.add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel) -.set_support_level(5); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("cls_prob", "Tensor", "Class probabilities.") + .add_argument("loc_pred", "Tensor", "Location regression predictions.") + .add_argument("anchor", "Tensor", "Multibox prior anchor boxes") + .add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel) + .set_support_level(5); } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 25743f9..b1aaaf0 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -21,17 +21,15 @@ * \file nms.cc * \brief Non-maximum suppression operators */ -#include #include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs); -bool GetValidCountRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GetValidCountRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -48,10 +46,7 @@ bool GetValidCountRel(const Array& types, return true; } -Expr MakeGetValidCounts(Expr data, - double score_threshold, - int id_index, - int score_index) { +Expr MakeGetValidCounts(Expr data, double score_threshold, int id_index, int score_index) { auto attrs = make_object(); attrs->score_threshold = score_threshold; attrs->id_index = id_index; @@ -60,33 +55,26 @@ Expr MakeGetValidCounts(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts") -.set_body_typed(MakeGetValidCounts); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts").set_body_typed(MakeGetValidCounts); RELAY_REGISTER_OP("vision.get_valid_counts") -.describe(R"doc(Get valid count of bounding boxes given + .describe(R"doc(Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. )doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "Input data.") -.set_support_level(5) -.add_type_rel("GetValidCount", GetValidCountRel); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input data.") + .set_support_level(5) + .add_type_rel("GetValidCount", GetValidCountRel); TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs); -bool NMSRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* valid_count = types[1].as(); - const NonMaximumSuppressionAttrs* param = - attrs.as(); + const NonMaximumSuppressionAttrs* param = attrs.as(); const auto& dshape = data->shape; const auto& vshape = valid_count->shape; CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; @@ -102,18 +90,9 @@ bool NMSRel(const Array& types, return true; } - -Expr MakeNMS(Expr data, - Expr valid_count, - int max_output_size, - double iou_threshold, - bool force_suppress, - int top_k, - int coord_start, - int score_index, - int id_index, - bool return_indices, - bool invalid_to_bottom) { +Expr MakeNMS(Expr data, Expr valid_count, int max_output_size, double iou_threshold, + bool force_suppress, int top_k, int coord_start, int score_index, int id_index, + bool return_indices, bool invalid_to_bottom) { auto attrs = make_object(); attrs->max_output_size = max_output_size; attrs->iou_threshold = iou_threshold; @@ -128,21 +107,18 @@ Expr MakeNMS(Expr data, return Call(op, {data, valid_count}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression") -.set_body_typed(MakeNMS); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS); RELAY_REGISTER_OP("vision.non_max_suppression") -.describe(R"doc(Non-maximum suppression. The input boxes should + .describe(R"doc(Non-maximum suppression. The input boxes should be in the format of [class_id, score, left, top, right, bottom]. Set id_index to be -1 to ignore class_id axis. )doc" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "Input data.") -.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") -.set_support_level(5) -.add_type_rel("NMS", NMSRel); + .set_num_inputs(2) + .add_argument("data", "Tensor", "Input data.") + .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") + .set_support_level(5) + .add_type_rel("NMS", NMSRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 5661ebb..efedb5e 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -21,9 +21,9 @@ * \file rcnn_op.cc * \brief Faster RCNN and Mask RCNN operators */ +#include #include #include -#include namespace tvm { namespace relay { @@ -62,8 +62,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spa return Call(op, {data, rois}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align") -.set_body_typed(MakeROIAlign); +TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align").set_body_typed(MakeROIAlign); RELAY_REGISTER_OP("vision.roi_align") .describe(R"doc(ROI Align operator. @@ -75,16 +74,16 @@ RELAY_REGISTER_OP("vision.roi_align") - **out**: This depends on the `layout` parameter. Output is 4D array of shape (num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`. )doc" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("rois", "Tensor", "The input rois") -.set_support_level(5) -.add_type_rel("ROIAlign", ROIAlignRel); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", "The input rois") + .set_support_level(5) + .add_type_rel("ROIAlign", ROIAlignRel); TVM_REGISTER_NODE_TYPE(ROIPoolAttrs); bool ROIPoolRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { auto roi_pool_attrs = attrs.as(); CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -112,8 +111,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array pooled_size, double spat return Call(op, {data, rois}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool") -.set_body_typed(MakeROIPool); +TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool").set_body_typed(MakeROIPool); RELAY_REGISTER_OP("vision.roi_pool") .describe(R"doc(ROI Pool operator. @@ -125,11 +123,11 @@ RELAY_REGISTER_OP("vision.roi_pool") - **out**: This depends on the `layout` parameter. Output is 4D array of shape (num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`. )doc" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("rois", "Tensor", "The input rois") -.set_support_level(5) -.add_type_rel("ROIPool", ROIPoolRel); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", "The input rois") + .set_support_level(5) + .add_type_rel("ROIPool", ROIPoolRel); TVM_REGISTER_NODE_TYPE(ProposalAttrs); @@ -155,16 +153,14 @@ bool ProposalRel(const Array& types, int num_inputs, const Attrs& attrs, auto batch = cls_prob->shape[0]; - std::vector oshape( - {batch * proposal_attrs->rpn_post_nms_top_n, 5}); + std::vector oshape({batch * proposal_attrs->rpn_post_nms_top_n, 5}); reporter->Assign(types[3], TensorType(oshape, cls_prob->dtype)); return true; } Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array scales, Array ratios, int feature_stride, double threshold, - int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, - bool iou_loss) { + int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, bool iou_loss) { auto attrs = make_object(); attrs->scales = scales; attrs->ratios = ratios; @@ -178,8 +174,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array return Call(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal") -.set_body_typed(MakeProposal); +TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal").set_body_typed(MakeProposal); RELAY_REGISTER_OP("vision.proposal") .describe(R"code(Generate region proposals via RPN. @@ -189,12 +184,12 @@ RELAY_REGISTER_OP("vision.proposal") - **im_info**: 2-D with shape [batch, 3]. - **out**: 2-D with shape [batch * rpn_post_nms_top_n, 5]. )code" TVM_ADD_FILELINE) -.set_num_inputs(3) -.add_argument("cls_prob", "Tensor", "Score of how likely proposal is object") -.add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals") -.add_argument("im_info", "Tensor", "Image size and scale") -.set_support_level(5) -.add_type_rel("Proposal", ProposalRel); + .set_num_inputs(3) + .add_argument("cls_prob", "Tensor", "Score of how likely proposal is object") + .add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals") + .add_argument("im_info", "Tensor", "Image size and scale") + .set_support_level(5) + .add_type_rel("Proposal", ProposalRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc index 5859677..e54473f 100644 --- a/src/relay/op/vision/yolo.cc +++ b/src/relay/op/vision/yolo.cc @@ -21,10 +21,12 @@ * \file yolo.cc * \brief Yolo related operators */ -#include -#include #include +#include +#include + #include + #include "../op_common.h" #include "../type_relations.h" @@ -34,15 +36,13 @@ namespace relay { TVM_REGISTER_NODE_TYPE(YoloReorgAttrs); /*! -* \brief YoloReorgRel Output type and shape relation evaluation function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. -*/ -bool YoloReorgRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief YoloReorgRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool YoloReorgRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -60,34 +60,29 @@ bool YoloReorgRel(const Array& types, return true; } -Expr MakeYoloReorg(Expr data, - Integer stride) { +Expr MakeYoloReorg(Expr data, Integer stride) { auto attrs = make_object(); attrs->stride = stride; static const Op& op = Op::Get("vision.yolo_reorg"); return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg") -.set_body_typed(MakeYoloReorg); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg").set_body_typed(MakeYoloReorg); RELAY_REGISTER_OP("vision.yolo_reorg") -.describe(R"doc("Yolo reorg operation. This layer reorganize the output. + .describe(R"doc("Yolo reorg operation. This layer reorganize the output. Its function is mostly shape transform.")doc" TVM_ADD_FILELINE) -.add_argument("data", "Tensor", "The input tensor.") -.set_num_inputs(1) -.set_support_level(5) -.set_attrs_type() -.add_type_rel("YoloReorg", YoloReorgRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* params = attrs.as(); - CHECK(params != nullptr); - return Array{ topi::vision::reorg(inputs[0], params->stride) }; -}); + .add_argument("data", "Tensor", "The input tensor.") + .set_num_inputs(1) + .set_support_level(5) + .set_attrs_type() + .add_type_rel("YoloReorg", YoloReorgRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* params = attrs.as(); + CHECK(params != nullptr); + return Array{topi::vision::reorg(inputs[0], params->stride)}; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc index d8752d8..b0dc3e4 100644 --- a/src/relay/qnn/op/add.cc +++ b/src/relay/qnn/op/add.cc @@ -23,6 +23,7 @@ */ #include #include + #include "op_common.h" namespace tvm { @@ -44,7 +45,6 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // Get the input dtype and shape. QnnBinaryOpTensorType input_type(arg_types, 0); - // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in // the start, we can insert requantize at the end if both input tensors have same qnn params. In // that case, we can first add the tensors, subtract the zero point, and requantize at the end. @@ -65,18 +65,14 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // Q_c = Q_a' + Q_b' - zp_c // The add op is done in int32 precision. - - // Requantize LHS if necessary. Computes Q_a' - auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale, - args.lhs_zero_point, - args.output_scale, args.output_zero_point, - input_type.shape); + auto requantized_lhs = + RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Requantize RHS if necessary. Computes Q_b' - auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale, - args.rhs_zero_point, - args.output_scale, args.output_zero_point, - input_type.shape); + auto requantized_rhs = + RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Computes Q_a' + Q_b' auto output = Add(requantized_lhs, requantized_rhs); @@ -92,9 +88,9 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // QNN Addition operator. QNN_REGISTER_BINARY_OP("add") -.describe("Elementwise add with with broadcasting for quantized tensors.") -.set_support_level(11) -.set_attr("FTVMQnnCanonicalize", QnnAddCanonicalize); + .describe("Elementwise add with with broadcasting for quantized tensors.") + .set_support_level(11) + .set_attr("FTVMQnnCanonicalize", QnnAddCanonicalize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 338e7a1..bda8cf8 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -22,13 +22,14 @@ * \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis. */ -#include #include #include #include +#include + #include "../../op/tensor/transform.h" -#include "../../transforms/pattern_util.h" #include "../../transforms/infer_layout_util.h" +#include "../../transforms/pattern_util.h" #include "../util.h" namespace tvm { @@ -42,10 +43,9 @@ bool QnnConcatenateRel(const Array& types, int num_inputs, const Attrs& at // Check the scale and zero point types const auto* input_scales_tuple = types[1].as(); if (input_scales_tuple == nullptr) { - throw Error( - ErrorBuilder() - << "qnn concatenate requires a tuple of scales as the second argument, found " - << PrettyPrint(types[1])); + throw Error(ErrorBuilder() + << "qnn concatenate requires a tuple of scales as the second argument, found " + << PrettyPrint(types[1])); } for (const auto& input_scale : input_scales_tuple->fields) { CHECK(IsScalarType(input_scale, DataType::Float(32))); // input_scales[idx] @@ -53,10 +53,9 @@ bool QnnConcatenateRel(const Array& types, int num_inputs, const Attrs& at const auto* input_zero_points_tuple = types[2].as(); if (input_zero_points_tuple == nullptr) { - throw Error( - ErrorBuilder() - << "qnn concatenate requires a tuple of zero_points as the third argument, found " - << PrettyPrint(types[2])); + throw Error(ErrorBuilder() + << "qnn concatenate requires a tuple of zero_points as the third argument, found " + << PrettyPrint(types[2])); } for (const auto& input_zero_point : input_zero_points_tuple->fields) { CHECK(IsScalarType(input_zero_point, DataType::Int(32))); // input_zero_points[idx] @@ -113,9 +112,8 @@ Expr MakeQnnConcatenate(Expr data, Expr input_scales, Expr input_zero_points, Ex auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("qnn.concatenate"); - return Call(op, - {data, input_scales, input_zero_points, output_scale, output_zero_point}, - Attrs(attrs), {}); + return Call(op, {data, input_scales, input_zero_points, output_scale, output_zero_point}, + Attrs(attrs), {}); } /* @@ -196,22 +194,23 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.concatenate") -.describe(R"code(Concatenate the quantized input tensors along the given axis. + .describe(R"code(Concatenate the quantized input tensors along the given axis. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(5) -.add_argument("data", "Tensor", "The tensor to concatenate.") -.add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.") -.add_argument("input_zero_points", "Tensor", "The quantization zero_points of the input tensors.") -.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") -.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") -.set_support_level(11) -.add_type_rel("QnnConcatenate", QnnConcatenateRel) -.set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize) -.set_attr("FInferCorrectLayout", QnnConcatenateLayout); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate") -.set_body_typed(MakeQnnConcatenate); + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "The tensor to concatenate.") + .add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.") + .add_argument("input_zero_points", "Tensor", + "The quantization zero_points of the input tensors.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("QnnConcatenate", QnnConcatenateRel) + .set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize) + .set_attr("FInferCorrectLayout", QnnConcatenateLayout); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate").set_body_typed(MakeQnnConcatenate); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 3718628..ae52a42 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -21,15 +21,16 @@ * \file src/relay/qnn/op/convolution.cc * \brief Property def of qnn convolution operator. */ -#include +#include "../../op/nn/convolution.h" + #include #include #include #include #include #include +#include -#include "../../op/nn/convolution.h" #include "../../transforms/pattern_util.h" #include "../util.h" @@ -88,9 +89,8 @@ Array> QnnConvInferCorrectLayout(const Attrs& attrs, } bool is_depthwise(const Conv2DAttrs* param) { - return param->channels.defined() && - tvm::tir::ExprDeepEqual()(param->channels, param->groups) && - param->groups != 1; + return param->channels.defined() && tvm::tir::ExprDeepEqual()(param->channels, param->groups) && + param->groups != 1; } // Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier @@ -201,8 +201,8 @@ Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2D auto pad_left_value = get_const_int(param->padding[1]); auto pad_bottom_value = get_const_int(param->padding[2]); auto pad_right_value = get_const_int(param->padding[3]); - bool do_pad = pad_top_value != 0 || pad_left_value != 0 || - pad_bottom_value != 0 || pad_right_value != 0; + bool do_pad = + pad_top_value != 0 || pad_left_value != 0 || pad_bottom_value != 0 || pad_right_value != 0; if (do_pad) { Array pad_n({0, 0}); Array pad_c({0, 0}); @@ -676,13 +676,12 @@ Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_ze attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("qnn.conv2d"); - return Call( - op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, - Attrs(attrs), {}); + return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, + Attrs(attrs), {}); } RELAY_REGISTER_OP("qnn.conv2d") -.describe(R"code(2D quantized convolution layer. + .describe(R"code(2D quantized convolution layer. This operator convolves quantized weight with quantized data. The scale of the output quantized tensor is the product of the weight_scale and input_scale of the input quantized tensors. The zero point of the output quantized tensor is @@ -694,18 +693,19 @@ operator to understand how to scale back the int32 output to (u)int8. - **out**: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(6) -.add_argument("data", "Tensor", "The quantized input data tensor.") -.add_argument("weight", "Tensor", "The quantized weight tensor.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") -.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.") -.set_support_level(11) -.add_type_rel("QnnConv2D", QnnConv2DRel) -.set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize) -.set_attr("FInferCorrectLayout", QnnConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(6) + .add_argument("data", "Tensor", "The quantized input data tensor.") + .add_argument("weight", "Tensor", "The quantized weight tensor.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") + .add_argument("weight_zero_point", "Tensor", + "The quantization zero_point of the weight tensor.") + .set_support_level(11) + .add_type_rel("QnnConv2D", QnnConv2DRel) + .set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize) + .set_attr("FInferCorrectLayout", QnnConvInferCorrectLayout); TVM_REGISTER_GLOBAL("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D); diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 7b9733c..464b3f9 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -26,6 +26,7 @@ #include #include #include + #include "../../op/nn/nn.h" #include "../../transforms/pattern_util.h" #include "../util.h" @@ -72,9 +73,8 @@ Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kern attrs->units = std::move(units); attrs->out_dtype = out_dtype; static const Op& op = Op::Get("qnn.dense"); - return Call( - op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, - Attrs(attrs), {}); + return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, + Attrs(attrs), {}); } Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel, @@ -173,25 +173,25 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.dense") -.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. + .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. - **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)` - **weight**: quantized(int8, unit8) `(units, input_dim)` - **out**: quantized(int32) `(x1, x2, ..., xn, units)`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(6) -.add_argument("data", "quantized nD Tensor", "Input data.") -.add_argument("weight", "quantized 2D Tensor", "Weight matrix.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") -.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.") -.set_support_level(11) -.add_type_rel("QDense", QnnDenseRel) -.set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense") -.set_body_typed(MakeQuantizedDense); + .set_attrs_type() + .set_num_inputs(6) + .add_argument("data", "quantized nD Tensor", "Input data.") + .add_argument("weight", "quantized 2D Tensor", "Weight matrix.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") + .add_argument("weight_zero_point", "Tensor", + "The quantization zero_point of the weight tensor.") + .set_support_level(11) + .add_type_rel("QDense", QnnDenseRel) + .set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 69389a7..7c014d7 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -26,6 +26,7 @@ #include #include #include + #include "../../transforms/pattern_util.h" #include "../util.h" @@ -33,19 +34,16 @@ namespace tvm { namespace relay { namespace qnn { -bool DequantizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); CHECK(data != nullptr); const auto input_dtype = data->dtype; - CHECK(input_dtype == DataType::Int(8) || - input_dtype == DataType::UInt(8) || + CHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) || input_dtype == DataType::Int(32)) - << "Input type should be one of the quantized types [unit8, int8, int32] but was " - << input_dtype; + << "Input type should be one of the quantized types [unit8, int8, int32] but was " + << input_dtype; // Check the types of scale and zero points. CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale @@ -83,20 +81,19 @@ Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.dequantize") -.describe(R"code(Dequantizes the input and produces float32 output. + .describe(R"code(Dequantizes the input and produces float32 output. The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point. - **data**: Quantized tensor of any shape to dequantize. The input data can be of floating point )code" TVM_ADD_FILELINE) -.set_num_inputs(3) -.add_argument("data", "Tensor", "The tensor to dequantize.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.set_support_level(11) -.add_type_rel("Dequantize", DequantizeRel) -.set_attr("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); + .set_num_inputs(3) + .add_argument("data", "Tensor", "The tensor to dequantize.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .set_support_level(11) + .add_type_rel("Dequantize", DequantizeRel) + .set_attr("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); -TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize") -.set_body_typed(MakeDequantize); +TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize").set_body_typed(MakeDequantize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/mul.cc b/src/relay/qnn/op/mul.cc index 5f9251b..ec74b79 100644 --- a/src/relay/qnn/op/mul.cc +++ b/src/relay/qnn/op/mul.cc @@ -24,6 +24,7 @@ #include #include #include + #include "../../transforms/pattern_util.h" #include "../util.h" #include "op_common.h" @@ -85,21 +86,17 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array& new_args, auto new_input_zero_point = zero_scalar; // Requantize to get Q_c - output = Requantize(output, input_type.shape, - new_input_scale, - new_input_zero_point, - args.output_scale, - args.output_zero_point, - input_type.dtype); + output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point, + args.output_scale, args.output_zero_point, input_type.dtype); return output; } // QNN Multiplication operator. QNN_REGISTER_BINARY_OP("mul") -.describe("Elementwise mul with with broadcasting for quantized tensors.") -.set_support_level(11) -.set_attr("FTVMQnnCanonicalize", QnnMulCanonicalize); + .describe("Elementwise mul with with broadcasting for quantized tensors.") + .set_support_level(11) + .set_attr("FTVMQnnCanonicalize", QnnMulCanonicalize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index f780f70..50fc0cd 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -28,7 +28,9 @@ #include #include #include + #include + #include "../../op/type_relations.h" #include "../../transforms/infer_layout_util.h" #include "../util.h" @@ -87,10 +89,9 @@ struct QnnBinaryOpArguments { */ struct QnnBinaryOpTensorType { DataType dtype; - Array shape; + Array shape; - explicit QnnBinaryOpTensorType(const Array& arg_types, - const int32_t arg_idx) { + explicit QnnBinaryOpTensorType(const Array& arg_types, const int32_t arg_idx) { CHECK_EQ(arg_types.size(), kNumQnnBinaryOpArgTypes); auto tensor_type = arg_types[arg_idx].as(); CHECK(tensor_type != nullptr); @@ -109,8 +110,7 @@ struct QnnBinaryOpTensorType { * \return New expression with target dtype and possibly lower * precision. */ -inline Expr ConvertDtype(const Expr& expr, - const DataType& target_dtype) { +inline Expr ConvertDtype(const Expr& expr, const DataType& target_dtype) { auto q_min = GetQmin(target_dtype); auto q_max = GetQmax(target_dtype); auto output = Clip(expr, q_min, q_max); @@ -134,18 +134,15 @@ inline Expr ConvertDtype(const Expr& expr, * it simply casts the given expression to Int32 as no requantization is * needed in this case. */ -inline Expr RequantizeOrUpcast(const Expr& expr, - const Expr& expr_scale, - const Expr& expr_zero_point, - const Expr& target_scale, - const Expr& target_zero_point, - const Array& expr_shape, +inline Expr RequantizeOrUpcast(const Expr& expr, const Expr& expr_scale, + const Expr& expr_zero_point, const Expr& target_scale, + const Expr& target_zero_point, const Array& expr_shape, const DataType& target_dtype = DataType::Int(32)) { auto result = expr; if (!IsEqualScalar(expr_scale, target_scale) || !IsEqualScalar(expr_zero_point, target_zero_point)) { - result = Requantize(expr, expr_shape, expr_scale, expr_zero_point, - target_scale, target_zero_point, target_dtype); + result = Requantize(expr, expr_shape, expr_scale, expr_zero_point, target_scale, + target_zero_point, target_dtype); } else { result = Cast(result, target_dtype); } @@ -153,27 +150,23 @@ inline Expr RequantizeOrUpcast(const Expr& expr, } /*! \brief Infer layout for QNN binary broadcast operators */ -inline Array > QnnBinaryBroadcastLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array& old_in_types) { +inline Array > QnnBinaryBroadcastLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // Use Relay Binary Broadcast Infer correct layout. auto layouts = BinaryBroadcastLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these // tensors can be treated as C. Layout channel_layout = Layout("C"); - Array input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout, + Array input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout, channel_layout, channel_layout, channel_layout, channel_layout}; Array output_layouts = layouts[1]; return {input_layouts, output_layouts}; } - -static inline bool QnnBroadcastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +static inline bool QnnBroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes); @@ -201,28 +194,28 @@ static inline bool QnnBroadcastRel(const Array& types, * * \param OpName the name of registry. */ -#define QNN_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ - Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ - static const Op& op = Op::Get("qnn." OpName); \ - return Call(op, {lhs, rhs, \ - lhs_scale, lhs_zero_point, \ - rhs_scale, rhs_zero_point, \ - output_scale, output_zero_point}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP("qnn." OpName) \ - .set_num_inputs(kNumQnnBinaryOpInputs) \ - .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ - .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ - .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \ - .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \ - .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \ - .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ - .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ - .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ - .set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout) +#define QNN_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ + .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ + Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ + static const Op& op = Op::Get("qnn." OpName); \ + return Call(op, \ + {lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, \ + output_zero_point}, \ + Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP("qnn." OpName) \ + .set_num_inputs(kNumQnnBinaryOpInputs) \ + .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ + .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ + .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \ + .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \ + .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \ + .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ + .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ + .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ + .set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout) } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 43ba4b6..28f0b89 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -26,6 +26,7 @@ #include #include #include + #include "../../transforms/pattern_util.h" #include "../util.h" @@ -35,24 +36,21 @@ namespace qnn { TVM_REGISTER_NODE_TYPE(QuantizeAttrs); -bool QuantizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); CHECK(data != nullptr); const auto input_dtype = data->dtype; CHECK(input_dtype == DataType::Float(32)) - << "Input type should be one of float32 but was " << input_dtype; + << "Input type should be one of float32 but was " << input_dtype; const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1: axis; + axis = (axis == -1) ? data->shape.size() - 1 : axis; CHECK_LT(axis, static_cast(data->shape.size())) << "axis " << quantize_attrs->axis << " is out of range"; - CHECK_GE(axis, 0) - << "axis " << quantize_attrs->axis << " is out of range"; + CHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; // Check and assign types for scale and zero points. AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale @@ -130,7 +128,7 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.quantize") -.describe(R"code(Quantizes the input and produces quantized output. + .describe(R"code(Quantizes the input and produces quantized output. The input can be either float or quantized(int8, unit8). If the input is float, this op takes scale and zero point and quantize the float value to quantized output, in int8 or uint8 format. If the input is quantized value, @@ -140,17 +138,17 @@ scale and zero point. - **data**: Tensor of any shape to quantize. The input data can be of floating point or quantized. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "The tensor to quantize.") -.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") -.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") -.set_support_level(11) -.add_type_rel("Quantize", QuantizeRel) -.set_attr("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize") -.set_body_typed(MakeQuantize); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The tensor to quantize.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("Quantize", QuantizeRel) + .set_attr("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize").set_body_typed(MakeQuantize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index a2a4649..79cb08d 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -25,8 +25,9 @@ #include #include #include -#include "../../transforms/pattern_util.h" + #include "../../transforms/infer_layout_util.h" +#include "../../transforms/pattern_util.h" #include "../util.h" namespace tvm { @@ -68,7 +69,7 @@ Array> RequantizeInferCorrectLayout(const Attrs& attrs, for (auto iter_var : new_in_layouts[0]->axes) { const auto& layout_axis = LayoutAxis::Get(iter_var); const std::string& layout_dim = layout_axis.name(); - if (old_dim == layout_dim) { + if (old_dim == layout_dim) { new_axis = tvm::Integer(axis_index); } // Collect only the primal axis. @@ -249,18 +250,16 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* data = types[0].as(); CHECK(data != nullptr); const auto in_dtype = data->dtype; - CHECK(in_dtype == DataType::Int(8) || - in_dtype == DataType::UInt(8) || + CHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) || in_dtype == DataType::Int(32)) << "Input type should be one of [int8, uint8, int32] but was " << in_dtype; const RequantizeAttrs* requantize_attrs = attrs.as(); int axis = requantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1: axis; + axis = (axis == -1) ? data->shape.size() - 1 : axis; CHECK_LT(axis, static_cast(data->shape.size())) << "axis " << requantize_attrs->axis << " is out of range"; - CHECK_GE(axis, 0) - << "axis " << requantize_attrs->axis << " is out of range"; + CHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range"; // Check and assign types for scale and zero points. AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale @@ -272,8 +271,7 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const Array oshape = data->shape; // assign output type auto out_dtype = requantize_attrs->out_dtype; - CHECK(out_dtype == DataType::Int(8) || - out_dtype == DataType::UInt(8) || + CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) || out_dtype == DataType::Int(32)) << "Output type should be one of [int8, uint8, int32] but was " << out_dtype; reporter->Assign(types[5], TensorType(oshape, out_dtype)); @@ -290,11 +288,11 @@ Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr out attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("qnn.requantize"); return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point}, - Attrs(attrs), {}); + Attrs(attrs), {}); } RELAY_REGISTER_OP("qnn.requantize") -.describe(R"code(Requantize operator. + .describe(R"code(Requantize operator. The requantize operator converts one quantized tensor to another quantized tensor. For the output tensor, we are provided with output scale and zero point. The computation looks like this @@ -302,20 +300,20 @@ point. The computation looks like this Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(5) -.add_argument("data", "Tensor", "The quantized input tensor.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") -.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") -.set_support_level(11) -.add_type_rel("Requantize", RequantizeRel) -.set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize) -.set_attr("FInferCorrectLayout", RequantizeInferCorrectLayout); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize") -.set_body_typed(MakeRequantize); + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "The quantized input tensor.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("Requantize", RequantizeRel) + .set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize) + .set_attr("FInferCorrectLayout", RequantizeInferCorrectLayout); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize").set_body_typed(MakeRequantize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/subtract.cc b/src/relay/qnn/op/subtract.cc index c6ce3e3..b928bd5 100644 --- a/src/relay/qnn/op/subtract.cc +++ b/src/relay/qnn/op/subtract.cc @@ -23,6 +23,7 @@ */ #include #include + #include "op_common.h" namespace tvm { @@ -36,8 +37,7 @@ namespace qnn { * \param arg_types The types of input and output. * \return The sequence of Relay ops for add op. */ -Expr QnnSubtractCanonicalize(const Attrs& attrs, - const Array& new_args, +Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array& new_args, const Array& arg_types) { // Get the args. QnnBinaryOpArguments args(new_args); @@ -66,17 +66,13 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, // The subtract op is done in int32 precision. // Requantize LHS if necessary. Computes Q_a' - auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale, - args.lhs_zero_point, - args.output_scale, - args.output_zero_point, - input_type.shape); + auto requantized_lhs = + RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Requantize RHS if necessary. Computes Q_b' - auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale, - args.rhs_zero_point, - args.output_scale, - args.output_zero_point, - input_type.shape); + auto requantized_rhs = + RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Computes Q_a' - Q_b' auto output = Subtract(requantized_lhs, requantized_rhs); @@ -93,10 +89,9 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, // QNN Subtraction operator. QNN_REGISTER_BINARY_OP("subtract") -.describe("Elementwise subtract with with broadcasting for quantized tensors.") -.set_support_level(11) -.set_attr("FTVMQnnCanonicalize", QnnSubtractCanonicalize); - + .describe("Elementwise subtract with with broadcasting for quantized tensors.") + .set_support_level(11) + .set_attr("FTVMQnnCanonicalize", QnnSubtractCanonicalize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index 91fe3ca..7171ded 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -23,6 +23,7 @@ */ #include "util.h" + #include "../transforms/pattern_util.h" namespace tvm { @@ -48,8 +49,7 @@ namespace qnn { * * Credit to TFLite reference implementation. */ -std::pair GetFixedPointMultiplierShift( - double double_multiplier) { +std::pair GetFixedPointMultiplierShift(double double_multiplier) { int32_t significand, exponent; if (double_multiplier == 0.) { significand = 0; @@ -84,8 +84,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& // 1) Calculating the integer multiplier and integer shift int32_t fixed_point_multiplier, shift; - std::tie(fixed_point_multiplier, shift) = - GetFixedPointMultiplierShift(multiplier); + std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(multiplier); int left_shift = shift > 0 ? shift : 0; int right_shift = shift > 0 ? 0 : -shift; @@ -119,8 +118,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); auto zero_t = Zeros(input_shape, hp_dtype); - round_scalar = - Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); } else { LOG(FATAL) << "Rounding mode " << rounding << " not supported."; } @@ -128,8 +126,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& tensor = Add(tensor, round_scalar); // 5) Simply right shift the result to get the final output. - tensor = - RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); + tensor = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. return Cast(tensor, DataType::Int(32)); diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index d4046ae..736b736 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -25,14 +25,15 @@ #ifndef TVM_RELAY_QNN_UTIL_H_ #define TVM_RELAY_QNN_UTIL_H_ -#include -#include #include #include +#include +#include + #include #include -#include #include +#include namespace tvm { namespace relay { @@ -46,8 +47,7 @@ static inline Array get_shape(const Type& type) { } static inline int32_t GetQmin(const DataType& dtype) { - CHECK_LE(dtype.bits(), 32) - << "QNN ops support int32 or lower precision"; + CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { auto* min_value = tir::as_const_int(tvm::min_value(dtype)); CHECK(min_value != nullptr); @@ -59,8 +59,7 @@ static inline int32_t GetQmin(const DataType& dtype) { } static inline int32_t GetQmax(const DataType& dtype) { - CHECK_LE(dtype.bits(), 32) - << "QNN ops support int32 or lower precision"; + CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { auto* max_value = tir::as_const_int(tvm::max_value(dtype)); CHECK(max_value != nullptr); @@ -171,8 +170,7 @@ static inline void AssignType(const Type& expr_type, const DataType& dtype, cons const TypeReporter& reporter) { // Scale/Zero_points can be either const scalar or a vector with C axis num elems. const auto* tensor_type = expr_type.as(); - CHECK(tensor_type) << "Can assign type to Tensor type only. But got " - << AsText(expr_type, false); + CHECK(tensor_type) << "Can assign type to Tensor type only. But got " << AsText(expr_type, false); const auto tensor_dtype = tensor_type->dtype; CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype; if (tensor_type->shape.size() != 0) { diff --git a/src/relay/quantize/annotate.cc b/src/relay/quantize/annotate.cc index 4492ed5..8ae7df9 100644 --- a/src/relay/quantize/annotate.cc +++ b/src/relay/quantize/annotate.cc @@ -24,8 +24,9 @@ * \brief Annotating the graph with simulated quantize operators. */ -#include #include +#include + #include "./quantize.h" namespace tvm { @@ -63,10 +64,7 @@ class QAnnotateExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode); }; - -Expr QAnnotateExprNode::Realize() const { - return expr; -} +Expr QAnnotateExprNode::Realize() const { return expr; } QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) { auto rnode = make_object(); @@ -75,12 +73,10 @@ QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) { data_ = std::move(rnode); } -TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr") -.set_body_typed([](Expr expr, int kind) { +TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr").set_body_typed([](Expr expr, int kind) { return QAnnotateExpr(expr, static_cast(kind)); }); - Pass QuantizeAnnotate() { // TODO(tvm-teams): since partition has added cast_hint in different // branches, try to remove this in the future. @@ -88,8 +84,7 @@ Pass QuantizeAnnotate() { if (e->IsInstance()) { const auto* n = e.as(); CHECK(n); - const PackedFunc* f = - runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); + const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); Expr ret = (*f)(n->expr, static_cast(kQInput)); return static_cast(QAnnotateExpr(ret, kQInput)); } @@ -97,23 +92,18 @@ Pass QuantizeAnnotate() { }; runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); - auto new_params = func->params; - for (const auto& x : FreeVars(func)) { - new_params.push_back(x); - } - return Function(new_params, - func->body, - func->ret_type, - func->type_params, - func->attrs); - }; + [=](Function f, IRModule m, PassContext pc) { + auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); + auto new_params = func->params; + for (const auto& x : FreeVars(func)) { + new_params.push_back(x); + } + return Function(new_params, func->body, func->ret_type, func->type_params, func->attrs); + }; return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); } -TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate") -.set_body_typed(QuantizeAnnotate); +TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate").set_body_typed(QuantizeAnnotate); TVM_REGISTER_NODE_TYPE(QAnnotateExprNode); diff --git a/src/relay/quantize/calibrate.cc b/src/relay/quantize/calibrate.cc index 7b1e909..ea42a19 100644 --- a/src/relay/quantize/calibrate.cc +++ b/src/relay/quantize/calibrate.cc @@ -26,7 +26,9 @@ #include #include #include + #include + #include "./quantize.h" namespace tvm { @@ -65,8 +67,8 @@ static std::vector SmoothDistribution(const std::vector& p, } static float ComputeEntropy(float* p, float* q, size_t size) { - float p_sum = std::accumulate(p, p+size, 0.f); - float q_sum = std::accumulate(q, q+size, 0.f); + float p_sum = std::accumulate(p, p + size, 0.f); + float q_sum = std::accumulate(q, q + size, 0.f); float ret = 0; for (size_t i = 0; i < size; i++) { CHECK(p[i] > 0 && q[i] > 0); @@ -77,9 +79,8 @@ static float ComputeEntropy(float* p, float* q, size_t size) { return ret; } -float MinimizeKL(const std::vector& hist, - const std::vector& hist_edges, - int num_bins, int num_quantized_bins) { +float MinimizeKL(const std::vector& hist, const std::vector& hist_edges, int num_bins, + int num_quantized_bins) { const int zero_bin_idx = num_bins / 2; const int num_half_quantized_bins = num_quantized_bins / 2; std::vector thresholds(num_bins / 2 + 1 - num_quantized_bins / 2, 0.f); @@ -137,9 +138,9 @@ float MinimizeKL(const std::vector& hist, divergence[i - num_half_quantized_bins] = ComputeEntropy(p.data(), q.data(), p.size()); } } - auto min_divergence_idx = std::distance(divergence.begin(), - std::min_element(divergence.begin(), divergence.end())); - return thresholds[min_divergence_idx];; + auto min_divergence_idx = + std::distance(divergence.begin(), std::min_element(divergence.begin(), divergence.end())); + return thresholds[min_divergence_idx]; } class StatsCollector : private ExprMutator { @@ -152,7 +153,7 @@ class StatsCollector : private ExprMutator { CHECK(func) << "Input shoule be Function"; Expr new_body = Tuple(std::move(profile_data_)); return Function(FreeVars(new_body), new_body, NullValue(), func->type_params, - func->attrs); + func->attrs); } private: @@ -167,7 +168,7 @@ class StatsCollector : private ExprMutator { auto attrs = new_call->attrs.as(); // rewrite the annotation auto new_attrs = make_object(); - const Expr& quantize_input = new_call->args[0]; // expression being quantized + const Expr& quantize_input = new_call->args[0]; // expression being quantized auto placeholder = MakeConstantScalar(DataType::Float(32), 0.); // unused argument Array new_args{quantize_input, placeholder, placeholder, placeholder}; new_attrs->kind = QAnnotateKind::kQIdentity; @@ -198,24 +199,20 @@ class StatsCollector : private ExprMutator { * \param expr The simulation graph after annotation. * \return The profile graph. */ -Expr CreateStatsCollector(const Expr& expr) { - return StatsCollector().Collect(expr); -} - -TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector") -.set_body_typed(CreateStatsCollector); +Expr CreateStatsCollector(const Expr& expr) { return StatsCollector().Collect(expr); } +TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector").set_body_typed(CreateStatsCollector); TVM_REGISTER_GLOBAL("relay._quantize.FindScaleByKLMinimization") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int* hist_ptr = static_cast(static_cast(args[0])); - float* hist_edges_ptr = static_cast(static_cast(args[1])); - int num_bins = args[2]; - int num_quantized_bins = args[3]; - std::vector hist(hist_ptr, hist_ptr + num_bins); - std::vector hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1); - ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + int* hist_ptr = static_cast(static_cast(args[0])); + float* hist_edges_ptr = static_cast(static_cast(args[1])); + int num_bins = args[2]; + int num_quantized_bins = args[3]; + std::vector hist(hist_ptr, hist_ptr + num_bins); + std::vector hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1); + ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins); + }); } // namespace quantize } // namespace relay diff --git a/src/relay/quantize/partition.cc b/src/relay/quantize/partition.cc index 39de0bc..14b420d 100644 --- a/src/relay/quantize/partition.cc +++ b/src/relay/quantize/partition.cc @@ -25,6 +25,7 @@ */ #include + #include "../transforms/pattern_util.h" #include "./quantize.h" @@ -34,16 +35,13 @@ namespace quantize { using namespace relay::transform; - class QPartitionExpr; class QPartitionExprNode : public TempExprNode { public: /*! \brief The original expression */ Expr expr; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("expr", &expr); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } Expr Realize() const final; @@ -62,7 +60,6 @@ class QPartitionExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode); }; - Expr QPartitionExprNode::Realize() const { // insert cast hint and stop fusion const QConfig& cfg = QConfig::Current(); @@ -76,23 +73,20 @@ QPartitionExpr::QPartitionExpr(Expr expr) { data_ = std::move(rnode); } -TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr") -.set_body_typed([](Expr expr) { +TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr").set_body_typed([](Expr expr) { return QPartitionExpr(expr); }); Pass QuantizePartition() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - auto ret = Downcast( - ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); - return ret; - }; + [=](Function f, IRModule m, PassContext pc) { + auto ret = Downcast(ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); + return ret; + }; return CreateFunctionPass(pass_func, 1, "QuantizePartition", {}); } -TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition") -.set_body_typed(QuantizePartition); +TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition").set_body_typed(QuantizePartition); TVM_REGISTER_NODE_TYPE(QPartitionExprNode); diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index 431e18b..d197458 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -23,12 +23,13 @@ * \brief transform a graph to a low-bit graph * for compression and acceleration. */ +#include "./quantize.h" + #include #include #include -#include -#include "./quantize.h" +#include namespace tvm { namespace relay { @@ -36,9 +37,7 @@ namespace quantize { TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); -bool SimulatedQuantizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SimulatedQuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 5); const auto param = attrs.as(); @@ -48,36 +47,34 @@ bool SimulatedQuantizeRel(const Array& types, CHECK(data != nullptr); CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; - reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale - reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min - reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max - reporter->Assign(types[4], types[0]); // output + reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale + reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min + reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max + reporter->Assign(types[4], types[0]); // output return true; } RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") -.describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE) -.set_num_inputs(4) -.add_argument("data", "Tensor", "The input data.") -.add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar") -.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar") -.add_argument("clip_max", "Tensor", "upper bound. It should be a scalar") -.set_attrs_type() -.set_support_level(11) -.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); + .describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE) + .set_num_inputs(4) + .add_argument("data", "Tensor", "The input data.") + .add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar") + .add_argument("clip_min", "Tensor", "lower bound. It should be a scalar") + .add_argument("clip_max", "Tensor", "upper bound. It should be a scalar") + .set_attrs_type() + .set_support_level(11) + .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize") -.set_body_typed( - [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, - int kind, bool sign, std::string rounding) { - auto attrs = make_object(); - attrs->kind = kind; - attrs->sign = sign; - attrs->rounding = rounding; - static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); - return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); - }); - + .set_body_typed([](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign, + std::string rounding) { + auto attrs = make_object(); + attrs->kind = kind; + attrs->sign = sign; + attrs->rounding = rounding; + static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); + return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); + }); /*! \brief Entry to hold the BuildConfig context stack. */ struct TVMQConfigThreadLocalEntry { @@ -87,26 +84,24 @@ struct TVMQConfigThreadLocalEntry { /*! \brief The current build config context */ std::stack context_stack; - TVMQConfigThreadLocalEntry() : - default_config(make_object()) { - } + TVMQConfigThreadLocalEntry() : default_config(make_object()) {} }; /*! \brief Thread local store to hold the BuildConfig context stack. */ typedef dmlc::ThreadLocalStore TVMQConfigThreadLocalStore; void QConfig::EnterQConfigScope(const QConfig& build_config) { - TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get(); entry->context_stack.push(build_config); } void QConfig::ExitQConfigScope() { - TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get(); entry->context_stack.pop(); } QConfig& QConfig::Current() { - TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); } @@ -117,33 +112,31 @@ QConfig& QConfig::Current() { TVM_REGISTER_NODE_TYPE(QConfigNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* op = static_cast(ref.get()); - p->stream << "qconfig("; - p->stream << "nbit_input=" << op->nbit_input << ", "; - p->stream << "nbit_weight=" << op->nbit_weight << ", "; - p->stream << "nbit_activation=" << op->nbit_activation << ", "; - p->stream << "calibrate_mode=" << op->calibrate_mode << ", "; - p->stream << "global_scale=" << op->global_scale << ", "; - p->stream << "weight_scale=" << op->weight_scale << ", "; - p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; - p->stream << "do_simulation==" << op->do_simulation << ", "; - p->stream << "round_for_shift==" << op->round_for_shift << ", "; - p->stream << "debug_enabled_ops==" << op->debug_enabled_ops <<", "; - p->stream << "rounding==" << op->rounding; - p->stream << ")"; -}); - -TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig") -.set_body_typed([]() -> QConfig { + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* op = static_cast(ref.get()); + p->stream << "qconfig("; + p->stream << "nbit_input=" << op->nbit_input << ", "; + p->stream << "nbit_weight=" << op->nbit_weight << ", "; + p->stream << "nbit_activation=" << op->nbit_activation << ", "; + p->stream << "calibrate_mode=" << op->calibrate_mode << ", "; + p->stream << "global_scale=" << op->global_scale << ", "; + p->stream << "weight_scale=" << op->weight_scale << ", "; + p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; + p->stream << "do_simulation==" << op->do_simulation << ", "; + p->stream << "round_for_shift==" << op->round_for_shift << ", "; + p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; + p->stream << "rounding==" << op->rounding; + p->stream << ")"; + }); + +TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig").set_body_typed([]() -> QConfig { return QConfig::Current(); }); TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope") -.set_body_typed(QConfig::EnterQConfigScope); + .set_body_typed(QConfig::EnterQConfigScope); -TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope") -.set_body_typed(QConfig::ExitQConfigScope); +TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope").set_body_typed(QConfig::ExitQConfigScope); } // namespace quantize } // namespace relay diff --git a/src/relay/quantize/quantize.h b/src/relay/quantize/quantize.h index 563f47f..a883cb1 100644 --- a/src/relay/quantize/quantize.h +++ b/src/relay/quantize/quantize.h @@ -24,9 +24,11 @@ #ifndef TVM_RELAY_QUANTIZE_QUANTIZE_H_ #define TVM_RELAY_QUANTIZE_QUANTIZE_H_ -#include #include +#include + #include + #include "../transforms/pattern_util.h" namespace tvm { @@ -34,12 +36,7 @@ namespace relay { namespace quantize { /*! \brief Kind of annotate field */ -enum QAnnotateKind : int { - kQIdentity = 0, - kQInput = 1, - kQWeight = 2, - kQActivation = 3 -}; +enum QAnnotateKind : int { kQIdentity = 0, kQInput = 1, kQWeight = 2, kQActivation = 3 }; /*! \brief Attribute for simulated quantize operator */ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { @@ -48,20 +45,17 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { std::string rounding; TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { - TVM_ATTR_FIELD(kind) - .describe("kind of field, hint for nbit/dtype configuration."); - TVM_ATTR_FIELD(sign).set_default(true) - .describe("whether to use signed data type."); - TVM_ATTR_FIELD(rounding).set_default("round") - .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); + TVM_ATTR_FIELD(kind).describe("kind of field, hint for nbit/dtype configuration."); + TVM_ATTR_FIELD(sign).set_default(true).describe("whether to use signed data type."); + TVM_ATTR_FIELD(rounding).set_default("round").describe( + "rounding mode. Can be 'floor', 'ceil', 'round'"); } }; - class QConfig; /*! -* \brief Container for build configuration options -*/ + * \brief Container for build configuration options + */ class QConfigNode : public Object { public: int nbit_input = 8; @@ -103,20 +97,16 @@ class QConfigNode : public Object { }; /*! -* \brief Container for build configuration options -*/ + * \brief Container for build configuration options + */ class QConfig : public ObjectRef { public: QConfig() {} explicit QConfig(ObjectPtr n) : ObjectRef(n) {} - const QConfigNode* operator->() const { - return static_cast(get()); - } + const QConfigNode* operator->() const { return static_cast(get()); } - QConfigNode* operator->() { - return static_cast(get_mutable()); - } + QConfigNode* operator->() { return static_cast(get_mutable()); } /*! * \brief Push a new BuildConfig context onto the thread local stack. @@ -150,14 +140,10 @@ struct QConfigContext { * context. When the BuildConfigContext is destructed, the previous context is restored. * \param build_config The BuildConfig to set as the new current context. */ - explicit QConfigContext(const QConfig& qconfig) { - QConfig::EnterQConfigScope(qconfig); - } + explicit QConfigContext(const QConfig& qconfig) { QConfig::EnterQConfigScope(qconfig); } /*! \brief Destructor. Pops the context off the thread local stack. */ - ~QConfigContext() { - QConfig::ExitQConfigScope(); - } + ~QConfigContext() { QConfig::ExitQConfigScope(); } }; } // namespace quantize diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 6d56e19..49d1e52 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -25,12 +25,13 @@ * graph. */ -#include #include #include -#include "./quantize.h" -#include "../transforms/pattern_util.h" +#include + #include "../qnn/util.h" +#include "../transforms/pattern_util.h" +#include "./quantize.h" namespace tvm { namespace relay { @@ -53,7 +54,6 @@ class QRealizeExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(QRealizeExpr, TempExpr, QRealizeExprNode); }; - class QRealizeIntExprNode : public QRealizeExprNode { public: Expr dom_scale; @@ -67,7 +67,7 @@ class QRealizeIntExprNode : public QRealizeExprNode { Expr Realize() const final; - static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; + static constexpr const char* _type_key = "relay.quantize.QRealizeIntExpr"; TVM_DECLARE_FINAL_OBJECT_INFO(QRealizeIntExprNode, QRealizeExprNode); }; @@ -78,7 +78,6 @@ class QRealizeIntExpr : public QRealizeExpr { TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode); }; - Expr QRealizeIntExprNode::Realize() const { Expr data = this->data; // dequantize @@ -95,15 +94,13 @@ QRealizeIntExpr::QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype) { data_ = std::move(n); } - inline Expr ForwardOp(const Call& ref_call, const Array& args) { return Call(ref_call->op, args, ref_call->attrs, ref_call->type_args); } - /* calculate `data * s1 / s2`, use shift if possible */ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, - const Array &data_shape) { + const Array& data_shape) { const QConfig& cfg = QConfig::Current(); // here we assume the dtype of data is dtype activation if (s1 == s2) return data; @@ -112,8 +109,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, float shift_factor = std::log2(factor); CHECK_GT(shift_factor, 0); if (static_cast(shift_factor) == shift_factor) { - return LeftShift(data, MakeConstantScalar(dtype, - static_cast(shift_factor))); + return LeftShift(data, MakeConstantScalar(dtype, static_cast(shift_factor))); } else if (static_cast(factor) == factor) { return Multiply(data, MakeConstantScalar(dtype, factor)); } else { @@ -122,9 +118,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, } } -Expr QuantizeRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); // do not handle data type cast const auto param = ref_call->attrs.as(); @@ -158,22 +152,20 @@ Expr QuantizeRealize(const Call& ref_call, // use right shift if (cfg->round_for_shift) { float round_bias = std::pow(2.0, shift_nbit - 1); - data = Add(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(round_bias))); + data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias))); } - data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); + data = RightShift(data, + MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); } else { - data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); + data = LeftShift(data, + MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); } data = Clip(data, clip_min_imm, clip_max_imm); return QRealizeIntExpr(data, dom_scale, n->dtype); } else { data = Cast(data, DataType::Int(64)); data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm, - ref_call->type_as()->shape, - cfg->rounding); + ref_call->type_as()->shape, cfg->rounding); data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype); return QRealizeIntExpr(data, dom_scale, n->dtype); } @@ -195,12 +187,9 @@ Expr FoldConstantOpt(const Expr& expr) { } RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") -.set_attr("FQRealizeRewrite", QuantizeRealize); - + .set_attr("FQRealizeRewrite", QuantizeRealize); -Expr Conv2dRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr Conv2dRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() && !new_args[1]->IsInstance()) { @@ -223,20 +212,15 @@ Expr Conv2dRealize(const Call& ref_call, DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; - Expr ret = Call(ref_call->op, - {ldata, rdata}, Attrs(attrs), ref_call->type_args); + Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FQRealizeRewrite", Conv2dRealize); - +RELAY_REGISTER_OP("nn.conv2d").set_attr("FQRealizeRewrite", Conv2dRealize); -Expr DenseRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr DenseRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() || !new_args[1]->IsInstance()) { @@ -257,20 +241,15 @@ Expr DenseRealize(const Call& ref_call, DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; - Expr ret = Call(ref_call->op, - {ldata, rdata}, Attrs(attrs), ref_call->type_args); + Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } -RELAY_REGISTER_OP("nn.dense") -.set_attr("FQRealizeRewrite", DenseRealize); +RELAY_REGISTER_OP("nn.dense").set_attr("FQRealizeRewrite", DenseRealize); - -Expr MulRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr MulRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { @@ -297,9 +276,7 @@ Expr MulRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("multiply") -.set_attr("FQRealizeRewrite", MulRealize); - +RELAY_REGISTER_OP("multiply").set_attr("FQRealizeRewrite", MulRealize); float ChooseDomScale(const std::vector& nptrs) { if (nptrs.size() == 2) { @@ -316,7 +293,6 @@ float ChooseDomScale(const std::vector& nptrs) { } } - /* \brief Unify the dom scale of arguments */ Array UnifyDTypeScale(const Array& ref_args, const Array& args, DataType* dtype_ptr, Expr* scale_ptr) { @@ -366,9 +342,7 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args return ret; } -Expr AddRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr AddRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { DataType dtype; @@ -382,12 +356,9 @@ Expr AddRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("add") -.set_attr("FQRealizeRewrite", AddRealize); +RELAY_REGISTER_OP("add").set_attr("FQRealizeRewrite", AddRealize); -Expr ClipRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr ClipRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { const auto ref_attrs = ref_call->attrs.as(); @@ -396,21 +367,16 @@ Expr ClipRealize(const Call& ref_call, attrs->a_min = ref_attrs->a_min / dom_scale; attrs->a_max = ref_attrs->a_max / dom_scale; - Expr ret = Call(ref_call->op, - {n->data}, Attrs(attrs), ref_call->type_args); + Expr ret = Call(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args); return QRealizeIntExpr(ret, n->dom_scale, n->dtype); } CHECK(!new_args[0]->IsInstance()); return Expr(nullptr); } -RELAY_REGISTER_OP("clip") -.set_attr("FQRealizeRewrite", ClipRealize); - +RELAY_REGISTER_OP("clip").set_attr("FQRealizeRewrite", ClipRealize); -Expr ConcatenateRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr ConcatenateRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); CHECK_EQ(ref_call->args.size(), 1); @@ -435,14 +401,10 @@ Expr ConcatenateRealize(const Call& ref_call, } } -RELAY_REGISTER_OP("concatenate") -.set_attr("FQRealizeRewrite", ConcatenateRealize); - +RELAY_REGISTER_OP("concatenate").set_attr("FQRealizeRewrite", ConcatenateRealize); /* \brief forward the original operator */ -Expr IdentityRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr IdentityRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = ForwardOp(ref_call, {n->data}); @@ -452,18 +414,15 @@ Expr IdentityRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("nn.relu") -.set_attr("FQRealizeRewrite", IdentityRealize); +RELAY_REGISTER_OP("nn.relu").set_attr("FQRealizeRewrite", IdentityRealize); -RELAY_REGISTER_OP("strided_slice") -.set_attr("FQRealizeRewrite", IdentityRealize); +RELAY_REGISTER_OP("strided_slice").set_attr("FQRealizeRewrite", IdentityRealize); RELAY_REGISTER_OP("annotation.stop_fusion") -.set_attr("FQRealizeRewrite", IdentityRealize); + .set_attr("FQRealizeRewrite", IdentityRealize); /* \brief for unary operators which requantize its input to dtype_nbit */ -Expr CastDtypeInputRealize(const Call& ref_call, - const Array& new_args, +Expr CastDtypeInputRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); @@ -477,12 +436,9 @@ Expr CastDtypeInputRealize(const Call& ref_call, } RELAY_REGISTER_OP("nn.max_pool2d") -.set_attr("FQRealizeRewrite", CastDtypeInputRealize); - + .set_attr("FQRealizeRewrite", CastDtypeInputRealize); -Expr AvgPoolRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr AvgPoolRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -497,15 +453,12 @@ Expr AvgPoolRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("nn.avg_pool2d") -.set_attr("FQRealizeRewrite", AvgPoolRealize); +RELAY_REGISTER_OP("nn.avg_pool2d").set_attr("FQRealizeRewrite", AvgPoolRealize); RELAY_REGISTER_OP("nn.global_avg_pool2d") -.set_attr("FQRealizeRewrite", AvgPoolRealize); + .set_attr("FQRealizeRewrite", AvgPoolRealize); -Expr CastHintRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr CastHintRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const auto param = ref_call->attrs.as(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -517,19 +470,17 @@ Expr CastHintRealize(const Call& ref_call, } RELAY_REGISTER_OP("annotation.cast_hint") -.set_attr("FQRealizeRewrite", CastHintRealize); + .set_attr("FQRealizeRewrite", CastHintRealize); Pass QuantizeRealizePass() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast( - ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); + }; return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {}); } -TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize") -.set_body_typed(QuantizeRealizePass); +TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize").set_body_typed(QuantizeRealizePass); } // namespace quantize } // namespace relay diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index aab0b3a..7b91e8c 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -24,20 +24,20 @@ custom layouts or other general weight pre-transformation. */ #include -#include -#include #include +#include #include #include -#include -#include + #include #include -#include +#include #include +#include +#include -#include "transform_layout.h" #include "pattern_util.h" +#include "transform_layout.h" namespace tvm { namespace relay { @@ -85,8 +85,8 @@ class AlterTransformMemorizer : public TransformMemorizer { } // TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes. // Probably we need to disable the AlterOpLayout when compiling dynamic models. - Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos, - ref_call->checked_type()); + Expr altered_value = + falter_layout[op](ref_call->attrs, new_args, tinfos, ref_call->checked_type()); if (altered_value.defined()) { new_e = altered_value; modified = true; @@ -122,14 +122,13 @@ namespace transform { Pass AlterOpLayout() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::alter_op_layout::AlterOpLayout(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::alter_op_layout::AlterOpLayout(f)); + }; return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout") -.set_body_typed(AlterOpLayout); +TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout").set_body_typed(AlterOpLayout); } // namespace transform diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 4caac04..3635947 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -299,8 +299,7 @@ Pass AnnotateTarget(const Array& targets) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::annotate_target::AnnotateTarget(f, targets)); }; - auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", - {"InferType"}); + auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"}); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); } diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc index ebcbd57..f478107 100644 --- a/src/relay/transforms/canonicalize_cast.cc +++ b/src/relay/transforms/canonicalize_cast.cc @@ -22,9 +22,10 @@ * \brief Canonicalize cast expressions to make operator fusion more efficient. */ #include -#include #include +#include #include + #include "pass_util.h" #include "pattern_util.h" @@ -112,8 +113,7 @@ class CastCanonicalizer : public ExprMutator { const CallNode* new_call = new_expr.as(); CHECK(new_call); CHECK(new_call->op == cast_op_); - return Call(new_call->op, new_call->args, new_call->attrs, - new_call->type_args); + return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args); } } } @@ -122,22 +122,19 @@ class CastCanonicalizer : public ExprMutator { } }; -Expr CanonicalizeCast(const Expr& e) { - return CastCanonicalizer().Mutate(e); -} +Expr CanonicalizeCast(const Expr& e) { return CastCanonicalizer().Mutate(e); } namespace transform { Pass CanonicalizeCast() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CanonicalizeCast(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeCast(f)); + }; return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast") -.set_body_typed(CanonicalizeCast); +TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast").set_body_typed(CanonicalizeCast); } // namespace transform diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc index 1d3111b..fec757e 100644 --- a/src/relay/transforms/canonicalize_ops.cc +++ b/src/relay/transforms/canonicalize_ops.cc @@ -23,10 +23,11 @@ This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.) */ #include +#include #include #include -#include #include + #include "pattern_util.h" namespace tvm { @@ -71,14 +72,13 @@ namespace transform { Pass CanonicalizeOps() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CanonicalizeOps(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeOps(f)); + }; return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps") -.set_body_typed(CanonicalizeOps); +TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps").set_body_typed(CanonicalizeOps); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index af6b135..1990414 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -33,15 +33,17 @@ */ #include -#include #include #include +#include #include #include + #include #include -#include "./expr_subst.h" + #include "./combine_parallel_op.h" +#include "./expr_subst.h" #include "pattern_util.h" namespace tvm { @@ -50,13 +52,10 @@ namespace relay { class ParallelConv2DCombiner : public ParallelOpCombiner { public: explicit ParallelConv2DCombiner(uint64_t min_num_branches) - : ParallelOpCombiner("nn.conv2d", min_num_branches) { - } + : ParallelOpCombiner("nn.conv2d", min_num_branches) {} protected: - bool IsSupportedOp(const CallNode* n) { - return n->attrs.as()->groups == 1; - } + bool IsSupportedOp(const CallNode* n) { return n->attrs.as()->groups == 1; } bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { StructuralEqual eq; @@ -67,10 +66,10 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { CHECK(attrs_b); const auto* tweight_a = a->args[1]->type_as(); const auto* tweight_b = b->args[1]->type_as(); - const auto shape_a = tir::BijectiveLayout( - Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape); - const auto shape_b = tir::BijectiveLayout( - Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape); + const auto shape_a = + tir::BijectiveLayout(Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape); + const auto shape_b = + tir::BijectiveLayout(Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape); return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && @@ -118,8 +117,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { auto toutput_a = a->type_as(); auto toutput_b = b->type_as(); - if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) - return false; + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false; // Position of the 'C' dimension in the argument size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size(); @@ -132,15 +130,12 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { for (size_t i = 0; i < ta->shape.size(); i++) { if (i == arg_channel_pos) continue; - if (!eq(ta->shape[i], tb->shape[i])) - return false; + if (!eq(ta->shape[i], tb->shape[i])) return false; } return true; } - Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, + Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; @@ -166,9 +161,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { return Call(call->op, new_args, call->attrs, {}); } - void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int64_t index = 0; for (const auto& branch : branches) { @@ -217,14 +210,13 @@ namespace transform { Pass CombineParallelConv2D(uint64_t min_num_branches) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CombineParallelConv2D(f, min_num_branches)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CombineParallelConv2D(f, min_num_branches)); + }; return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D") -.set_body_typed(CombineParallelConv2D); +TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D").set_body_typed(CombineParallelConv2D); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 1278020..8613dbe 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -32,16 +32,18 @@ */ #include -#include #include #include +#include #include #include + #include #include + +#include "./combine_parallel_op_batch.h" #include "./expr_subst.h" #include "pattern_util.h" -#include "./combine_parallel_op_batch.h" namespace tvm { namespace relay { @@ -49,8 +51,7 @@ namespace relay { class ParallelDenseCombiner : public ParallelOpBatchCombiner { public: explicit ParallelDenseCombiner(uint64_t min_num_branches) - : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) { - } + : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {} protected: virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { @@ -63,8 +64,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner { const auto* weight_b = b->args[1]->type_as(); return eq(attrs_a->out_dtype, attrs_b->out_dtype) && - eq(weight_a->shape[0], weight_b->shape[0]) && - eq(weight_a->shape[1], weight_b->shape[1]); + eq(weight_a->shape[0], weight_b->shape[0]) && eq(weight_a->shape[1], weight_b->shape[1]); } }; @@ -77,14 +77,13 @@ namespace transform { Pass CombineParallelDense(uint64_t min_num_branches) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CombineParallelDense(f, min_num_branches)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CombineParallelDense(f, min_num_branches)); + }; return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense") -.set_body_typed(CombineParallelDense); +TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense").set_body_typed(CombineParallelDense); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_op.cc b/src/relay/transforms/combine_parallel_op.cc index a7f7af2..854a1ae 100644 --- a/src/relay/transforms/combine_parallel_op.cc +++ b/src/relay/transforms/combine_parallel_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -23,33 +23,33 @@ * \brief Abstract class to combine parallel ops and their successive element-wise ops. */ +#include "combine_parallel_op.h" + #include #include -#include #include #include +#include #include #include #include + #include -#include #include #include +#include + #include "expr_subst.h" #include "pattern_util.h" -#include "combine_parallel_op.h" - namespace tvm { namespace relay { -BranchGroupFinder::BranchGroupFinder(const Op& op, - FIsSupportedOp fis_supported_op, +BranchGroupFinder::BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops) - : cached_op_(op), - fis_supported_op_(fis_supported_op), - fare_compatible_ops_(fare_compatible_ops) { -} + : cached_op_(op), + fis_supported_op_(fis_supported_op), + fare_compatible_ops_(fare_compatible_ops) {} std::vector BranchGroupFinder::Find(const Expr& expr) { this->VisitExpr(expr); @@ -111,18 +111,13 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) { } ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches) - : cached_op_(Op::Get(op_name)), - min_num_branches_(min_num_branches) { -} + : cached_op_(Op::Get(op_name)), min_num_branches_(min_num_branches) {} Expr ParallelOpCombiner::Combine(const Expr& expr) { - auto groups = BranchGroupFinder(cached_op_, - [&](const CallNode* n) { - return IsSupportedOp(n); - }, - [&](const CallNode* a, const CallNode* b) { - return CanOpsBeCombined(a, b); - }).Find(expr); + auto groups = BranchGroupFinder( + cached_op_, [&](const CallNode* n) { return IsSupportedOp(n); }, + [&](const CallNode* a, const CallNode* b) { return CanOpsBeCombined(a, b); }) + .Find(expr); for (const Group& group : groups) { if (group.size() < min_num_branches_) { continue; @@ -135,10 +130,9 @@ Expr ParallelOpCombiner::Combine(const Expr& expr) { void ParallelOpCombiner::CombineBranches(const Group& branches) { Call combined = MakeCombinedOp(branches); auto it = std::min_element(branches.begin(), branches.end(), - [](const Branch& branch_a, - const Branch& branch_b) { - return branch_a.size() < branch_b.size(); - }); + [](const Branch& branch_a, const Branch& branch_b) { + return branch_a.size() < branch_b.size(); + }); size_t depth = it->size(); size_t i; // starting from 1 to skip the op @@ -155,32 +149,30 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) { } bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) { - const CallNode* call = branches[0][depth]; - tvm::StructuralEqual attrs_equal; - // check if all branches in current depth can be combined - for (auto it = branches.begin() + 1; it != branches.end(); it++) { - const Branch& branch = *it; - if (!branch[depth]->op.same_as(call->op) || - !attrs_equal(branch[depth]->attrs, call->attrs) || - branch[depth]->args.size() != call->args.size()) { - return false; - } + const CallNode* call = branches[0][depth]; + tvm::StructuralEqual attrs_equal; + // check if all branches in current depth can be combined + for (auto it = branches.begin() + 1; it != branches.end(); it++) { + const Branch& branch = *it; + if (!branch[depth]->op.same_as(call->op) || !attrs_equal(branch[depth]->attrs, call->attrs) || + branch[depth]->args.size() != call->args.size()) { + return false; + } - if (branch[depth]->args[parent_index].get() != branch[depth - 1]) - return false; + if (branch[depth]->args[parent_index].get() != branch[depth - 1]) return false; - // Check args - for (size_t i = 0; i < call->args.size(); i++) { - if (i == parent_index) continue; + // Check args + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) continue; - if (!IsArgCompatible(call, branch[depth], i) || - !attrs_equal(call->attrs, branch[depth]->attrs)) { - return false; - } + if (!IsArgCompatible(call, branch[depth], i) || + !attrs_equal(call->attrs, branch[depth]->attrs)) { + return false; } } - return true; } + return true; +} } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/combine_parallel_op.h b/src/relay/transforms/combine_parallel_op.h index 0097e29..23fe347 100644 --- a/src/relay/transforms/combine_parallel_op.h +++ b/src/relay/transforms/combine_parallel_op.h @@ -26,26 +26,27 @@ #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_ #include -#include #include #include +#include #include #include + +#include #include #include #include -#include + #include "./expr_subst.h" #include "pattern_util.h" - namespace tvm { namespace relay { using Branch = std::vector; using Group = std::vector; -using FIsSupportedOp = std::function; -using FAreCompatibleOps = std::function; +using FIsSupportedOp = std::function; +using FAreCompatibleOps = std::function; using ExprSubstMap = std::unordered_map; /* @@ -74,8 +75,7 @@ class BranchGroupFinder : private ExprVisitor { * \param fare_compatible_ops function that returns true if * two ops are compatible for combining */ - BranchGroupFinder(const Op& op, - FIsSupportedOp fis_supported_op, + BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops); /* @@ -188,10 +188,8 @@ class ParallelOpCombiner { * all combined ops * \return new combined call */ - virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, - size_t parent_index) = 0; + virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, + size_t depth, size_t parent_index) = 0; /* * \brief Updates map of expr to substitute with combined expr. This usually involves @@ -201,9 +199,7 @@ class ParallelOpCombiner { * \param depth depth at which to substitute * \param subst_map map of Expr to replace with Expr to replace it with */ - virtual void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, + virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) = 0; private: diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index 361565e..5cd287c 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -44,17 +44,20 @@ * */ +#include "./combine_parallel_op_batch.h" + #include -#include #include #include +#include #include #include + #include #include -#include "./expr_subst.h" + #include "./combine_parallel_op.h" -#include "./combine_parallel_op_batch.h" +#include "./expr_subst.h" #include "pattern_util.h" namespace tvm { @@ -63,13 +66,9 @@ namespace relay { ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches) - : ParallelOpCombiner(op_name, min_num_branches), - batch_op_name_(batch_op_name) { -} + : ParallelOpCombiner(op_name, min_num_branches), batch_op_name_(batch_op_name) {} -bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { - return true; -} +bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { return true; } bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) { if (a->args.size() != b->args.size()) { @@ -116,19 +115,16 @@ bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* auto ta = a->args[index]->type_as(); auto tb = b->args[index]->type_as(); - if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) - return false; + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false; for (size_t i = 0; i < ta->shape.size(); i++) { - if (!eq(ta->shape[i], tb->shape[i])) - return false; + if (!eq(ta->shape[i], tb->shape[i])) return false; } return true; } Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, + const Group& branches, size_t depth, size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; @@ -160,10 +156,8 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, return Call(call->op, new_args, call->attrs, {}); } -void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, - ExprSubstMap* subst_map) { +void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches, + size_t depth, ExprSubstMap* subst_map) { int index = 0; auto split = MakeSplit(data, Integer(branches.size()), 0); for (const auto& branch : branches) { @@ -174,30 +168,25 @@ void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, } /*! \brief Combine parallel op into batched op if number of branches >= min_num_branches */ -Expr CombineParallelOpBatch(const Expr& expr, - const std::string& op_name, - const std::string& batch_op_name, - uint64_t min_num_branches) { +Expr CombineParallelOpBatch(const Expr& expr, const std::string& op_name, + const std::string& batch_op_name, uint64_t min_num_branches) { return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr); } namespace transform { -Pass CombineParallelOpBatch(const std::string& op_name, - const std::string& batch_op_name, +Pass CombineParallelOpBatch(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CombineParallelOpBatch(f, - op_name, - batch_op_name, - min_num_branches)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast( + CombineParallelOpBatch(f, op_name, batch_op_name, min_num_branches)); + }; return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch") -.set_body_typed(CombineParallelOpBatch); + .set_body_typed(CombineParallelOpBatch); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_op_batch.h b/src/relay/transforms/combine_parallel_op_batch.h index 6876604..9f87d9d 100644 --- a/src/relay/transforms/combine_parallel_op_batch.h +++ b/src/relay/transforms/combine_parallel_op_batch.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,16 +25,18 @@ #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_ #include -#include #include #include +#include #include #include + +#include #include #include -#include -#include "./expr_subst.h" + #include "./combine_parallel_op.h" +#include "./expr_subst.h" #include "pattern_util.h" namespace tvm { @@ -68,8 +70,7 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { * \param min_num_branches min number of parallel branches beginning with op * to start combining */ - ParallelOpBatchCombiner(const std::string& op_name, - const std::string& batch_op_name, + ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches); protected: @@ -116,9 +117,7 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { * all combined ops * \return new combined call as batch op by stacking args */ - Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, + Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, size_t parent_index) final; /* @@ -129,15 +128,13 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { * \param depth depth at which to substitute * \param subst_map map of Expr to replace with Expr to replace it with */ - void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) final; private: /* \brief name of op to replace combined ops with. for example, * for combining parallel dense, this will will be set to - * nn.batch_matmul + * nn.batch_matmul */ std::string batch_op_name_; }; diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index dbb2c38..f43c8f6 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -24,20 +24,20 @@ custom layouts or other general weight pre-transformation. */ #include -#include -#include #include +#include #include #include -#include -#include + #include #include -#include +#include #include +#include +#include -#include "transform_layout.h" #include "pattern_util.h" +#include "transform_layout.h" namespace tvm { namespace relay { @@ -132,8 +132,7 @@ Pass ConvertLayout(const std::string& desired_layout) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::convert_op_layout::ConvertLayout(f, desired_layout)); }; - return CreateFunctionPass( - pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"}); + return CreateFunctionPass(pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"}); } TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); diff --git a/src/relay/transforms/convert_sparse_dense.cc b/src/relay/transforms/convert_sparse_dense.cc index 1b83e71..36aaa47 100644 --- a/src/relay/transforms/convert_sparse_dense.cc +++ b/src/relay/transforms/convert_sparse_dense.cc @@ -133,20 +133,12 @@ Pass DenseToSparse(const Array& weight_name, // Remove FreeVar warnings auto f0 = Downcast(DenseToSparse(f, weight_name, weight_shape)); Array sparse_params = FreeVars(f0); - auto f1 = Function(sparse_params, - f0->body, - f0->ret_type, - f0->type_params, - f0->attrs); + auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); for (const auto& var : sparse_params) { params.push_back(var); } - return Function(params, - f1->body, - f1->ret_type, - f1->type_params, - f1->attrs); + return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); }; return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 48b8666..1c250f1 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -23,17 +23,15 @@ * \brief Use a fresh Id for every Var to make the result well-formed. */ #include -#include #include +#include #include namespace tvm { namespace relay { Expr DeDup(const Expr& e) { - class DeDupMutator : public TypeMutator, - public ExprMutator, - public PatternMutator { + class DeDupMutator : public TypeMutator, public ExprMutator, public PatternMutator { public: TypeVar Fresh(const TypeVar& tv) { TypeVar ret = TypeVar(tv->name_hint, tv->kind); @@ -65,9 +63,7 @@ Expr DeDup(const Expr& e) { return Let(v, VisitExpr(op->value), VisitExpr(op->body)); } - Type VisitType(const Type& t) final { - return t.defined() ? TypeMutator::VisitType(t) : t; - } + Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } Expr VisitExpr_(const FunctionNode* op) final { tvm::Array type_params; @@ -78,29 +74,19 @@ Expr DeDup(const Expr& e) { for (const Var& param : op->params) { params.push_back(Fresh(param)); } - return Function(params, - VisitExpr(op->body), - VisitType(op->ret_type), - type_params, - op->attrs); + return Function(params, VisitExpr(op->body), VisitType(op->ret_type), type_params, op->attrs); } - Pattern VisitPattern(const Pattern& p) final { - return PatternFunctor::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } - Pattern VisitPattern_(const PatternVarNode* op) final { - return PatternVar(Fresh(op->var)); - } + Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(Fresh(op->var)); } Type VisitType_(const TypeVarNode* op) final { TypeVar v = GetRef(op); return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; } - Var VisitVar(const Var& v) final { - return Fresh(v); - } + Var VisitVar(const Var& v) final { return Fresh(v); } private: std::unordered_map rename_; @@ -113,8 +99,7 @@ Expr DeDup(const Expr& e) { return ret; } -TVM_REGISTER_GLOBAL("relay._transform.dedup") -.set_body_typed(DeDup); +TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index a0d093f..9aa0f49 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -30,12 +30,13 @@ #include #include #include + #include "let_list.h" namespace tvm { namespace relay { -template +template using VarMap = std::unordered_map; using VarSet = std::unordered_set; @@ -59,20 +60,18 @@ class Eliminator : private ExprMutator { VarMap expr_map_; VarMap use_map_; bool inline_once_; - explicit Eliminator(const VarMap& expr_map, - const VarMap& use_map, - bool inline_once) : - expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { } + explicit Eliminator(const VarMap& expr_map, const VarMap& use_map, bool inline_once) + : expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) {} friend CalcDep; bool HasLet(const Var& v) { switch (use_map_[v]) { - case 0: - return false; - case 1: - return !inline_once_; - default: - return true; + case 0: + return false; + case 1: + return !inline_once_; + default: + return true; } } @@ -104,8 +103,7 @@ class CalcDep : protected MixedModeVisitor { } private: - explicit CalcDep(const VarMap& expr_map) - : MixedModeVisitor(2), expr_map_(expr_map) {} + explicit CalcDep(const VarMap& expr_map) : MixedModeVisitor(2), expr_map_(expr_map) {} VarMap expr_map_; VarMap use_map_; @@ -123,9 +121,7 @@ class CalcDep : protected MixedModeVisitor { } } - void VisitExpr_(const LetNode* l) final { - VisitExpr(l->body); - } + void VisitExpr_(const LetNode* l) final { VisitExpr(l->body); } void VisitExpr_(const VarNode* v) final { Var var = GetRef(v); @@ -144,14 +140,13 @@ namespace transform { Pass DeadCodeElimination(bool inline_once) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(DeadCodeElimination(f, inline_once)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(DeadCodeElimination(f, inline_once)); + }; return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); } -TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination") -.set_body_typed(DeadCodeElimination); +TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); } // namespace transform diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index d5e1d2e..39cf563 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -28,12 +28,12 @@ * 3. Collect the device allocation of each expression. */ -#include -#include #include +#include #include #include #include +#include #include #include @@ -103,8 +103,7 @@ class ValidateAnnotation : private ExprVisitor { * \return The device type. */ int GetDeviceId(const CallNode* call_node) { - CHECK(IsOnDeviceNode(call_node)) - << "The input call node must be on_device node."; + CHECK(IsOnDeviceNode(call_node)) << "The input call node must be on_device node."; const OnDeviceAttrs* on_device_attr = call_node->attrs.as(); return on_device_attr->device_type; } @@ -160,8 +159,7 @@ class RewriteAnnotation : public ExprMutator { Expr VisitExpr_(const TupleGetItemNode* op) final { Expr tuple = op->tuple; if (NeedDeviceCopy(tuple.operator->(), op)) { - Expr new_expr = - TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); + Expr new_expr = TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); UpdateAnnotationMap(op, new_expr.operator->()); return this->VisitExpr(new_expr); } else { @@ -201,8 +199,7 @@ class RewriteAnnotation : public ExprMutator { } if (annotated) { - Call new_call = Call(call_node->op, new_args, call_node->attrs, - call_node->type_args); + Call new_call = Call(call_node->op, new_args, call_node->attrs, call_node->type_args); UpdateAnnotationMap(call_node, new_call.operator->()); return this->VisitExpr(new_call); @@ -235,8 +232,7 @@ class RewriteAnnotation : public ExprMutator { return CreateDeviceCopy(src, fallback_device_, dit->second); } else { const auto dit = annotation_map_.find(dst); - int dst_dev_type = - dit == annotation_map_.end() ? fallback_device_ : dit->second; + int dst_dev_type = dit == annotation_map_.end() ? fallback_device_ : dit->second; return CreateDeviceCopy(src, sit->second, dst_dev_type); } } @@ -301,6 +297,7 @@ class AnnotatationVisitor : private ExprVisitor { visitor(expr); return visitor.annotations_; } + private: void VisitExpr_(const CallNode* call_node) { if (IsOnDeviceNode(call_node)) { @@ -414,9 +411,7 @@ class DeviceInfo { // TODO(zhiics) Skip annotation of tuple node for now. } - void VisitExpr_(const TupleGetItemNode* op) final { - ExprVisitor::VisitExpr_(op); - } + void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); } void VisitExpr_(const VarNode* vn) final { post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); @@ -432,7 +427,6 @@ class DeviceInfo { post_dfs_order_.push_back(std::make_pair(in, has_copy_)); } - int num_device_copy_ops_{0}; bool has_copy_ = false; std::vector> post_dfs_order_; @@ -479,25 +473,23 @@ class DeviceInfo { const auto* attrs = last_copy_node->attrs.as(); cur_dev_type = attrs->src_dev_type; if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type; - if (it->second) device_map_.Set(GetRef(it->first), - attrs->dst_dev_type); + if (it->second) device_map_.Set(GetRef(it->first), attrs->dst_dev_type); } else if (last_copy_node) { Expr expr = GetRef(it->first); CHECK_EQ(device_map_.count(expr), 0U); if (it->second) device_map_.Set(expr, cur_dev_type); } } - return out_dev_type; + return out_dev_type; } void FillPropagation(int out_dev_type) { for (const auto& it : post_visitor_.post_dfs_order_) { - Expr expr = GetRef(it.first); - if (!it.second) device_map_.Set(expr, out_dev_type); + Expr expr = GetRef(it.first); + if (!it.second) device_map_.Set(expr, out_dev_type); } } - PostDfsOrderVisitor post_visitor_; Map device_map_; }; @@ -521,14 +513,12 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { } CHECK_GT(new_body.size(), 0U); if (new_body.size() == 1) { - return Function(params, new_body[0], Type(nullptr), - fn->type_params, fn->attrs); + return Function(params, new_body[0], Type(nullptr), fn->type_params, fn->attrs); } else if (tuple->fields.size() == new_body.size()) { - return new_expr; + return new_expr; } else { Tuple tuple_body = Tuple(new_body); - return Function(params, tuple_body, Type(nullptr), - fn->type_params, fn->attrs); + return Function(params, tuple_body, Type(nullptr), fn->type_params, fn->attrs); } } else { return new_expr; @@ -544,40 +534,35 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { if (tuple->fields.size() == new_fields.size()) { return new_fields.size() == 1 ? new_fields[0] : new_expr; } else { - return new_fields.size() == 1 ? new_fields[0] - : Tuple(new_fields); + return new_fields.size() == 1 ? new_fields[0] : Tuple(new_fields); } } else { return new_expr; } } -Map CollectDeviceInfo(const Expr& expr) { - return DeviceInfo::GetDeviceMap(expr); -} +Map CollectDeviceInfo(const Expr& expr) { return DeviceInfo::GetDeviceMap(expr); } Map CollectDeviceAnnotationOps(const Expr& expr) { return AnnotatationVisitor::GetAnnotations(expr); } -TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo") -.set_body_typed(CollectDeviceInfo); +TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo").set_body_typed(CollectDeviceInfo); TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps") -.set_body_typed(CollectDeviceAnnotationOps); + .set_body_typed(CollectDeviceAnnotationOps); namespace transform { Pass RewriteAnnotatedOps(int fallback_device) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); + }; return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation") -.set_body_typed(RewriteAnnotatedOps); +TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation").set_body_typed(RewriteAnnotatedOps); } // namespace transform diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index 68c59f5..2861f32 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -29,7 +29,9 @@ #include #include #include + #include + #include "pattern_util.h" namespace tvm { @@ -37,7 +39,7 @@ namespace relay { class CommonSubexprEliminator : public ExprMutator { public: - explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip): fskip_(fskip) {} + explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip) : fskip_(fskip) {} Expr VisitExpr_(const CallNode* call) final { static auto op_stateful = Op::GetAttr("TOpIsStateful"); @@ -88,14 +90,14 @@ namespace transform { Pass EliminateCommonSubexpr(PackedFunc fskip) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(EliminateCommonSubexpr(f, fskip)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(EliminateCommonSubexpr(f, fskip)); + }; return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr") -.set_body_typed(EliminateCommonSubexpr); + .set_body_typed(EliminateCommonSubexpr); } // namespace transform diff --git a/src/relay/transforms/eta_expand.cc b/src/relay/transforms/eta_expand.cc index c720bdf..5b43d07 100644 --- a/src/relay/transforms/eta_expand.cc +++ b/src/relay/transforms/eta_expand.cc @@ -24,9 +24,9 @@ * */ #include +#include #include #include -#include namespace tvm { namespace relay { @@ -62,16 +62,14 @@ class EtaExpander : public ExprMutator { type_var_replacer_(TypeVarReplacer()), expand_constructor_(expand_constructor), expand_global_var_(expand_global_var) { - CHECK(expand_constructor || expand_global_var) - << "must expand at least one language feature"; + CHECK(expand_constructor || expand_global_var) << "must expand at least one language feature"; } IRModule Expand() { for (GlobalVar global_var : mod_->GetGlobalVars()) { const BaseFunc base_func = mod_->Lookup(global_var); if (auto* n = base_func.as()) { - const Function new_func = Downcast( - VisitExpr(GetRef(n))); + const Function new_func = Downcast(VisitExpr(GetRef(n))); mod_->Update(global_var, new_func); } } @@ -111,11 +109,8 @@ class EtaExpander : public ExprMutator { Expr body = Call(cons, params, Attrs()); Type ret_type = TypeCall(cons->belong_to, type_params); - return Function( - Downcast>(params), - body, - ret_type, - Downcast>(type_params)); + return Function(Downcast>(params), body, ret_type, + Downcast>(type_params)); } Expr VisitExpr_(const GlobalVarNode* gvar_node) final { @@ -124,7 +119,7 @@ class EtaExpander : public ExprMutator { return std::move(gvar); } const auto base_func = mod_->Lookup(gvar); - if (auto *ptr = base_func.as()) { + if (auto* ptr = base_func.as()) { // handle relay function, skip external functions. auto func = GetRef(ptr); tvm::Array params; @@ -135,11 +130,7 @@ class EtaExpander : public ExprMutator { args.push_back(var); } - return Function( - args, - Call(gvar, params), - func->ret_type, - func->type_params); + return Function(args, Call(gvar, params), func->ret_type, func->type_params); } else { return std::move(gvar); } @@ -161,15 +152,14 @@ class EtaExpander : public ExprMutator { namespace transform { Pass EtaExpand(bool expand_constructor, bool expand_global_var) { - runtime::TypedPackedFunc pass_func = - [=](IRModule mod, PassContext pc) { + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext pc) { return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand(); }; return CreateModulePass(pass_func, 1, "EtaExpand", {}); } -TVM_REGISTER_GLOBAL("relay._transform.EtaExpand") -.set_body_typed(EtaExpand); +TVM_REGISTER_GLOBAL("relay._transform.EtaExpand").set_body_typed(EtaExpand); } // namespace transform diff --git a/src/relay/transforms/expr_subst.cc b/src/relay/transforms/expr_subst.cc index d3e6aa8..54731ed 100644 --- a/src/relay/transforms/expr_subst.cc +++ b/src/relay/transforms/expr_subst.cc @@ -22,9 +22,10 @@ * \brief Utility functions for substituting expressions. */ -#include #include "./expr_subst.h" +#include + namespace tvm { namespace relay { diff --git a/src/relay/transforms/expr_subst.h b/src/relay/transforms/expr_subst.h index 849ffc2..e82e3e6 100644 --- a/src/relay/transforms/expr_subst.h +++ b/src/relay/transforms/expr_subst.h @@ -24,13 +24,13 @@ #ifndef TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_ #define TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_ #include + #include namespace tvm { namespace relay { -Expr ExprSubst(const Expr& expr, - std::unordered_map subst_map); +Expr ExprSubst(const Expr& expr, std::unordered_map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc index 8234dea..3c8d8db 100644 --- a/src/relay/transforms/fast_math.cc +++ b/src/relay/transforms/fast_math.cc @@ -22,10 +22,11 @@ * \brief Replaces non linear activation functions with their fast but approximate counterparts. */ #include -#include #include -#include +#include #include +#include + #include "pattern_util.h" namespace tvm { @@ -33,10 +34,7 @@ namespace relay { class FastMathMutator : public ExprRewriter { public: - FastMathMutator() - : exp_op_(Op::Get("exp")), - erf_op_(Op::Get("erf")), - tanh_op_(Op::Get("tanh")) {} + FastMathMutator() : exp_op_(Op::Get("exp")), erf_op_(Op::Get("erf")), tanh_op_(Op::Get("tanh")) {} Expr Rewrite_(const CallNode* pre, const Expr& post) override { if (pre->op == exp_op_) { @@ -67,14 +65,11 @@ namespace transform { Pass FastMath() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(FastMath(f)); - }; + [=](Function f, IRModule m, PassContext pc) { return Downcast(FastMath(f)); }; return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.FastMath") -.set_body_typed(FastMath); +TVM_REGISTER_GLOBAL("relay._transform.FastMath").set_body_typed(FastMath); } // namespace transform diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index fab184c..70df0ed 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -21,15 +21,16 @@ * \file constant_folding.cc */ #include +#include #include +#include #include #include -#include -#include #include -#include -#include #include +#include +#include + #include "pattern_util.h" namespace tvm { @@ -48,8 +49,7 @@ class ConstantChecker : private ExprVisitor { return true; } const auto it = memo_.find(expr); - if (it != memo_.end()) - return it->second; + if (it != memo_.end()) return it->second; VisitExpr(expr); return memo_[expr]; // return memoized result or the default value false } @@ -69,12 +69,9 @@ class ConstantChecker : private ExprVisitor { } }; -bool ConstantCheck(const Expr& e) { - return ConstantChecker().Check(e); -} +bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); } -TVM_REGISTER_GLOBAL("relay.analysis.check_constant") -.set_body_typed(ConstantCheck); +TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantCheck); // TODO(tvm-team) consider combine dead-code with constant folder. // or make a more powerful partial evaluator. @@ -98,9 +95,7 @@ class ConstantFolder : public ExprMutator { } else { Var var = Downcast(this->Mutate(op->var)); Expr body = this->Mutate(op->body); - if (var.same_as(op->var) && - value.same_as(op->value) && - body.same_as(op->body)) { + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return Let(var, value, body); @@ -123,7 +118,7 @@ class ConstantFolder : public ExprMutator { const OpNode* op = call->op.as(); if (op == nullptr) return res; if (skip_list.count(op->name)) { - return res; + return res; } // skip stateful ops. if (op_stateful.get(GetRef(op), false)) return res; @@ -133,9 +128,7 @@ class ConstantFolder : public ExprMutator { } // We should think about potentially constant evaluation over these ops too. - if (call->op == invoke_tvm_op_ || - call->op == shape_func_op_ || - call->op == alloc_tensor_op_ || + if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ || call->op == alloc_storage_op_) { return GetRef(call); } @@ -184,8 +177,7 @@ class ConstantFolder : public ExprMutator { if (value->IsInstance()) { auto nd_array = Downcast(value); for (auto dim : nd_array.Shape()) { - CHECK_GT(dim, 0) - << "invalid dimension after constant eval"; + CHECK_GT(dim, 0) << "invalid dimension after constant eval"; } return Constant(nd_array); } else if (const auto* val = value.as()) { @@ -202,8 +194,7 @@ class ConstantFolder : public ExprMutator { } // Constant evaluate a expression. Expr ConstEvaluate(Expr expr) { - std::vector passes = {transform::FuseOps(0), - transform::ToANormalForm(), + std::vector passes = {transform::FuseOps(0), transform::ToANormalForm(), transform::InferType()}; Function func; if (expr.as()) { @@ -212,10 +203,7 @@ class ConstantFolder : public ExprMutator { // TODO(@jroesch): fix this func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {}); } - auto mod = IRModule( - {}, - module_->type_definitions, - module_->Imports()); + auto mod = IRModule({}, module_->type_definitions, module_->Imports()); auto global = GlobalVar("main"); mod->Add(global, func); auto seq = transform::Sequential(passes); @@ -251,7 +239,7 @@ class ConstantFolder : public ExprMutator { value = runtime::NDArray::Empty({}, cdtype, ctx); } else { CHECK_NE(ishape.size(), 0); - std::vector cshape = { static_cast(ishape.size()) }; + std::vector cshape = {static_cast(ishape.size())}; value = runtime::NDArray::Empty(cshape, cdtype, ctx); int32_t* dims = static_cast(value->data); using ::tvm::tir::IntImmNode; @@ -274,12 +262,11 @@ class ConstantFolder : public ExprMutator { // Cast the constant into correct dtype auto cast_attrs = make_object(); cast_attrs->dtype = param->dtype; - Expr ret = Call(cast_op_, { shape }, Attrs(cast_attrs), {}); + Expr ret = Call(cast_op_, {shape}, Attrs(cast_attrs), {}); return ConstEvaluate(ret); } }; - Expr FoldConstant(const Expr& expr, const IRModule& mod) { DLContext ctx; ctx.device_type = kDLCPU; @@ -296,14 +283,13 @@ namespace transform { Pass FoldConstant() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(FoldConstant(f, m)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(FoldConstant(f, m)); + }; return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); } -TVM_REGISTER_GLOBAL("relay._transform.FoldConstant") -.set_body_typed(FoldConstant); +TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant); } // namespace transform diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index cfe74bf..57e3d69 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -23,14 +23,14 @@ * \brief Fold axis scaling into weights of * conv/dense operators. */ -#include #include #include #include #include -#include "pattern_util.h" -#include "pass_util.h" +#include +#include "pass_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -43,7 +43,6 @@ namespace fold_scale_axis { using runtime::TypedPackedFunc; - // FoldScaleAxis algorithm: // // The general idea is to transform Expr to tuple of @@ -109,7 +108,7 @@ class Message : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode); }; -Message::Message(const AxesSet& axes, bool require_positive) { +Message::Message(const AxesSet& axes, bool require_positive) { auto n = make_object(); n->axes = axes; n->require_positive = require_positive; @@ -139,7 +138,8 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { ++j; } else { ret.push_back(lhs[i]); - ++i; ++j; + ++i; + ++j; } } return ret; @@ -166,8 +166,8 @@ Message Intersect(const Message& lhs, const Message& rhs) { * positive scale is required. * \return The message containing the result scaling on axes of the input. */ -using FForwardPrep = runtime::TypedPackedFunc< - Array (const Call& call, const Message& out_message)>; +using FForwardPrep = + runtime::TypedPackedFunc(const Call& call, const Message& out_message)>; /*! \brief Axis scale tuple. */ class ScaledExprNode : public TempExprNode { @@ -180,8 +180,7 @@ class ScaledExprNode : public TempExprNode { Expr scale = NullValue(); Expr Realize() const final { - CHECK(!axes.defined()) - << "outstanding scale"; + CHECK(!axes.defined()) << "outstanding scale"; return value; } @@ -195,18 +194,15 @@ class ScaledExprNode : public TempExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode); }; -using FForwardRewrite = TypedPackedFunc< - Expr(const Call& ref_call, - const Array& new_args, - const Message& message)>; +using FForwardRewrite = TypedPackedFunc& new_args, + const Message& message)>; //---------------------------------------------- // Generic Visitors for FScaleAxisForward //---------------------------------------------- class ForwardPrep : private ExprVisitor { public: - std::unordered_map - Prepare(const Expr& body) { + std::unordered_map Prepare(const Expr& body) { this->Update(body, NullValue()); this->VisitExpr(body); // flist is added in the Post-DFS order @@ -222,7 +218,7 @@ class ForwardPrep : private ExprVisitor { private: // The invoke list - std::vector > flist_; + std::vector> flist_; // The message on each node. std::unordered_map message_; // Update the message stored at node. @@ -245,15 +241,11 @@ class ForwardPrep : private ExprVisitor { } } // Visitor pattern override. - void VisitExpr_(const LetNode* call) { - LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; - } + void VisitExpr_(const LetNode* call) { LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; } void VisitExpr_(const FunctionNode* op) { ExprVisitor::VisitExpr_(op); - auto flazy = [this, op] { - this->Update(op->body, NullValue()); - }; + auto flazy = [this, op] { this->Update(op->body, NullValue()); }; flist_.push_back(flazy); } @@ -261,8 +253,7 @@ class ForwardPrep : private ExprVisitor { ExprVisitor::VisitExpr_(call); // function to be lazily invoked auto flazy = [this, call]() { - static const auto& fprep = - Op::GetAttr("FScaleAxisForwardPrep"); + static const auto& fprep = Op::GetAttr("FScaleAxisForwardPrep"); // find the message send to this node. auto it = message_.find(call); Message out_message; @@ -326,31 +317,26 @@ Array ReluForwardPrep(const Call& call, const Message& out_message) { return {out_message}; } -Expr ReluForwardRewrite(const Call& ref_call, - const Array& new_args, - const Message& message) { +Expr ReluForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { const auto* input = new_args[0].as(); if (input == nullptr) return Expr(nullptr); // return transformed conv2d auto rnode = make_object(); - rnode->value = Call( - ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); + rnode->value = Call(ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); rnode->scale = input->scale; rnode->axes = input->axes; return Expr(rnode); } -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisForwardPrep", ReluForwardPrep); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisForwardPrep", ReluForwardPrep); -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisForwardRewrite", + ReluForwardRewrite); -RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisForwardPrep", ReluForwardPrep); +RELAY_REGISTER_OP("nn.leaky_relu").set_attr("FScaleAxisForwardPrep", ReluForwardPrep); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); + .set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); // AddSub Array AddSubForwardPrep(const Call& call, const Message& out_message) { @@ -367,8 +353,7 @@ Array AddSubForwardPrep(const Call& call, const Message& out_message) { return {none, none}; } -Expr AddSubForwardRewrite(const Call& ref_call, - const Array& new_args, +Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { const auto* slhs = new_args[0].as(); const auto* srhs = new_args[1].as(); @@ -380,43 +365,36 @@ Expr AddSubForwardRewrite(const Call& ref_call, if (slhs != nullptr) { CHECK(srhs == nullptr); CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes)); - Expr scale = ExpandBiasToMatchAxis( - slhs->scale, tlhs->shape.size(), slhs->axes); + Expr scale = ExpandBiasToMatchAxis(slhs->scale, tlhs->shape.size(), slhs->axes); Expr rhs = Divide(new_args[1], scale); - rnode->value = Call(ref_call->op, {slhs->value, rhs}, - ref_call->attrs, ref_call->type_args); + rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args); rnode->scale = slhs->scale; rnode->axes = slhs->axes; } else { CHECK(srhs != nullptr); CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); - Expr scale = ExpandBiasToMatchAxis( - srhs->scale, trhs->shape.size(), srhs->axes); + Expr scale = ExpandBiasToMatchAxis(srhs->scale, trhs->shape.size(), srhs->axes); Expr lhs = Divide(new_args[0], scale); - rnode->value = Call(ref_call->op, {lhs, srhs->value}, - ref_call->attrs, ref_call->type_args); + rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args); rnode->scale = srhs->scale; rnode->axes = srhs->axes; } return Expr(rnode); } -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisForwardRewrite", + AddSubForwardRewrite); -RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); +RELAY_REGISTER_OP("subtract").set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); + .set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); // Producer operators // Multiply produces the scale-axis pair. -Expr MultiplyForwardRewrite(const Call& ref_call, - const Array& new_args, +Expr MultiplyForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { if (!message.defined()) return Expr(); const auto& expected_out_axes = message->axes; @@ -451,7 +429,7 @@ Expr MultiplyForwardRewrite(const Call& ref_call, } RELAY_REGISTER_OP("multiply") -.set_attr("FScaleAxisForwardRewrite", MultiplyForwardRewrite); + .set_attr("FScaleAxisForwardRewrite", MultiplyForwardRewrite); // Consumer operators // Conv2D send out requirement of axis folding. @@ -476,8 +454,7 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); - if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && - c_small_axis < 0 && + if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { data_axes = {c_big_axis}; } @@ -488,8 +465,7 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { } // Conv2D consumes the scale axis during transformation. -Expr Conv2DForwardRewrite(const Call& ref_call, - const Array& new_args, +Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { // if data do not have scale, normal transform path. const auto* sdata = new_args[0].as(); @@ -505,8 +481,7 @@ Expr Conv2DForwardRewrite(const Call& ref_call, // For now, we only support simple pattern (no folded weight/data) // TODO(tvm-team) support general data layout CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1); - CHECK(sdata->axes.size() == 1 && - c_big_axis == sdata->axes[0]->value); + CHECK(sdata->axes.size() == 1 && c_big_axis == sdata->axes[0]->value); int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); @@ -518,29 +493,24 @@ Expr Conv2DForwardRewrite(const Call& ref_call, // match the ic_axis if (is_depthwise_conv2d) { - Expr scale = ExpandBiasToMatchAxis( - sdata->scale, kernel_layout.ndim(), {big_oc_axis}); + Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_oc_axis}); weight = Multiply(weight, scale); } else { - Expr scale = ExpandBiasToMatchAxis( - sdata->scale, kernel_layout.ndim(), {big_ic_axis}); + Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ic_axis}); weight = Multiply(weight, scale); } // return transformed conv2d - return Call( - ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); + return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisForwardPrep", Conv2DForwardPrep); +RELAY_REGISTER_OP("nn.conv2d").set_attr("FScaleAxisForwardPrep", Conv2DForwardPrep); RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); - + .set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); - auto fcontext = [&](const Call& call) -> ObjectRef{ + auto fcontext = [&](const Call& call) -> ObjectRef { auto it = message.find(call.get()); if (it != message.end()) { return it->second; @@ -548,8 +518,7 @@ Expr ForwardFoldScaleAxis(const Expr& data) { return ObjectRef(nullptr); } }; - return ForwardRewrite( - data, "FScaleAxisForwardRewrite", fcontext); + return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext); } //---------------------------------------- @@ -564,14 +533,11 @@ class BackwardTransformer; * positive scale is required. * \return Message containing the result scaling on axes of the input. */ -using FBackwardPrep = TypedPackedFunc< - Message(const Call& call, const Array& in_messages)>; +using FBackwardPrep = TypedPackedFunc& in_messages)>; -using FBackwardTransform = TypedPackedFunc< - Expr(const Call& call, - const Message& message, - const Expr& scale, - const BackwardTransformer& transformer)>; +using FBackwardTransform = + TypedPackedFunc; //---------------------------------------------- // Generic Visitors for FScaleAxisBackward @@ -580,8 +546,7 @@ using FBackwardTransform = TypedPackedFunc< class BackwardPrep : private ExprVisitor { public: // The message on each node. - std::unordered_map - Prepare(const Expr& body) { + std::unordered_map Prepare(const Expr& body) { ref_counter_ = GetExprRefCount(body); this->VisitExpr(body); return std::move(message_); @@ -595,8 +560,7 @@ class BackwardPrep : private ExprVisitor { // Visit the expression. void VisitExpr_(const CallNode* call) { ExprVisitor::VisitExpr_(call); - static const auto& fprep = - Op::GetAttr("FScaleAxisBackwardPrep"); + static const auto& fprep = Op::GetAttr("FScaleAxisBackwardPrep"); auto f = fprep.get(call->op, nullptr); if (f == nullptr) return; auto rit = ref_counter_.find(call); @@ -620,9 +584,7 @@ class BackwardPrep : private ExprVisitor { } }; -class BackwardTransformerNode : - public Object, - private ExprMutator { +class BackwardTransformerNode : public Object, private ExprMutator { public: // Run forward transform. Expr Fold(Expr expr) { @@ -692,19 +654,15 @@ class BackwardTransformerNode : class BackwardTransformer : public ObjectRef { public: BackwardTransformer() {} - explicit BackwardTransformer( - ::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { - } + explicit BackwardTransformer(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {} BackwardTransformerNode* operator->() const { return static_cast(get_mutable()); } using ContainerType = BackwardTransformerNode; }; -Expr BackwardTransformerNode::Transform( - const CallNode* call_node, Message message, Expr scale) { - static const auto& ftransform = - Op::GetAttr("FScaleAxisBackwardTransform"); +Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) { + static const auto& ftransform = Op::GetAttr("FScaleAxisBackwardTransform"); auto f = ftransform.get(call_node->op, nullptr); if (f != nullptr) { const Call call = GetRef(call_node); @@ -712,10 +670,7 @@ Expr BackwardTransformerNode::Transform( if (it != memo_.end()) { return it->second; } - Expr new_expr = f(GetRef(call_node), - message, - scale, - GetRef(this)); + Expr new_expr = f(GetRef(call_node), message, scale, GetRef(this)); memo_[call] = new_expr; return new_expr; } else { @@ -724,7 +679,6 @@ Expr BackwardTransformerNode::Transform( } } - //---------------------------------------------- // Per operator defs for FScaleAxisForward //---------------------------------------------- @@ -737,45 +691,38 @@ Message ReluBackwardPrep(const Call& call, const Array& in_messages) { return in_messages[0]; } -Expr ReluBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr ReluBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } - Expr input = transformer->Transform( - call->args[0], message, scale); + Expr input = transformer->Transform(call->args[0], message, scale); return Call(call->op, {input}, call->attrs, call->type_args); } -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisBackwardTransform", + ReluBackwardTransform); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); + .set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); // AddSub Message AddSubBackwardPrep(const Call& call, const Array& in_messages) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); StructuralEqual equal; - if (in_messages[0].defined() && - MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { + if (in_messages[0].defined() && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { return in_messages[0]; } else if (in_messages[1].defined() && MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) { return in_messages[1]; - } else if (in_messages[0].defined() && - in_messages[1].defined() && - equal(in_messages[0]->axes, in_messages[1]->axes) && - equal(tlhs->shape, trhs->shape)) { + } else if (in_messages[0].defined() && in_messages[1].defined() && + equal(in_messages[0]->axes, in_messages[1]->axes) && equal(tlhs->shape, trhs->shape)) { // add of two elements. return in_messages[0]; } else { @@ -784,9 +731,7 @@ Message AddSubBackwardPrep(const Call& call, const Array& in_messages) } } -Expr AddSubBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr AddSubBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); @@ -806,19 +751,15 @@ Expr AddSubBackwardTransform(const Call& call, } else if (lhs_message.defined()) { CHECK(equal(message->axes, lhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], message, scale); - Expr rhs = transformer->Transform( - call->args[1], NullValue(), NullValue()); - Expr rhs_scale = ExpandBiasToMatchAxis( - scale, tlhs->shape.size(), message->axes); + Expr rhs = transformer->Transform(call->args[1], NullValue(), NullValue()); + Expr rhs_scale = ExpandBiasToMatchAxis(scale, tlhs->shape.size(), message->axes); rhs = Multiply(rhs, rhs_scale); return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else if (rhs_message.defined()) { CHECK(equal(message->axes, rhs_message->axes)); - Expr lhs = transformer->Transform( - call->args[0], NullValue(), NullValue()); + Expr lhs = transformer->Transform(call->args[0], NullValue(), NullValue()); Expr rhs = transformer->Transform(call->args[1], message, scale); - Expr lhs_scale = ExpandBiasToMatchAxis( - scale, trhs->shape.size(), message->axes); + Expr lhs_scale = ExpandBiasToMatchAxis(scale, trhs->shape.size(), message->axes); lhs = Multiply(lhs, lhs_scale); return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else { @@ -827,23 +768,19 @@ Expr AddSubBackwardTransform(const Call& call, } } -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisBackwardTransform", AddSubBackwardTransform); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisBackwardTransform", + AddSubBackwardTransform); -RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); +RELAY_REGISTER_OP("subtract").set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisBackwardTransform", AddSubBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", AddSubBackwardTransform); // Producer operators // Multiply produces the scale-axis pair. -Expr MultiplyBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr MultiplyBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { CHECK(!message.defined()) << "outstanding scale"; const auto* tlhs = call->args[0]->type_as(); @@ -871,7 +808,7 @@ Expr MultiplyBackwardTransform(const Call& call, } RELAY_REGISTER_OP("multiply") -.set_attr("FScaleAxisBackwardTransform", MultiplyBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", MultiplyBackwardTransform); // Consumer operators // Conv2D send out requirement of axis folding. @@ -893,8 +830,7 @@ Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 && - kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && - c_small_axis < 0 && + kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { return Message({c_big_axis}, false); } else { @@ -903,9 +839,7 @@ Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) } // Conv2D consumes the scale axis during transformation. -Expr Conv2DBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); @@ -920,31 +854,26 @@ Expr Conv2DBackwardTransform(const Call& call, // TODO(tvm-team) support general data layout CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1); CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1); - CHECK(message->axes.size() == 1 && - c_big_axis == message->axes[0]->value); + CHECK(message->axes.size() == 1 && c_big_axis == message->axes[0]->value); int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); // Check it must be depthwise or full conv2d. bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); CHECK(param->groups == 1 || is_depthwise_conv2d); - Expr data = transformer->Transform( - call->args[0], NullValue(), NullValue()); - Expr weight = transformer->Transform( - call->args[1], NullValue(), NullValue()); + Expr data = transformer->Transform(call->args[0], NullValue(), NullValue()); + Expr weight = transformer->Transform(call->args[1], NullValue(), NullValue()); // scale on input for deptwise. - Expr wscale = ExpandBiasToMatchAxis( - scale, kernel_layout.ndim(), {big_oc_axis}); + Expr wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_oc_axis}); weight = Multiply(weight, wscale); - return Call( - call->op, {data, weight}, call->attrs, call->type_args); + return Call(call->op, {data, weight}, call->attrs, call->type_args); } RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisBackwardPrep", Conv2DBackwardPrep); + .set_attr("FScaleAxisBackwardPrep", Conv2DBackwardPrep); RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); Expr BackwardFoldScaleAxis(const Expr& data) { return make_object()->Fold(data); @@ -956,39 +885,33 @@ namespace transform { Pass ForwardFoldScaleAxis() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast( - relay::fold_scale_axis::ForwardFoldScaleAxis(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::fold_scale_axis::ForwardFoldScaleAxis(f)); + }; return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") -.set_body_typed(ForwardFoldScaleAxis); +TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis").set_body_typed(ForwardFoldScaleAxis); Pass BackwardFoldScaleAxis() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast( - relay::fold_scale_axis::BackwardFoldScaleAxis(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::fold_scale_axis::BackwardFoldScaleAxis(f)); + }; return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis") -.set_body_typed(BackwardFoldScaleAxis); +TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis").set_body_typed(BackwardFoldScaleAxis); Pass FoldScaleAxis() { // FoldScaleAxis pass contains the following three passes. Therefore, we can // register it as a sequential pass. - Pass pass = Sequential( - {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, - "FoldScaleAxis"); + Pass pass = Sequential({BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, + "FoldScaleAxis"); return pass; } -TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis") -.set_body_typed(FoldScaleAxis); +TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis").set_body_typed(FoldScaleAxis); } // namespace transform diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc index f01c4fa..226b338 100644 --- a/src/relay/transforms/forward_rewrite.cc +++ b/src/relay/transforms/forward_rewrite.cc @@ -26,6 +26,7 @@ #include #include #include + #include "pass_util.h" namespace tvm { @@ -36,9 +37,7 @@ namespace relay { // so that calling realize repeatively won't hurt perf. class TempRealizer : private MixedModeMutator { public: - Expr Realize(Expr expr) { - return Mutate(expr); - } + Expr Realize(Expr expr) { return Mutate(expr); } private: Expr DispatchVisitExpr(const Expr& expr) final { @@ -57,17 +56,12 @@ class ForwardRewriter : private MixedModeMutator { ForwardRewriter(const OpMap* rewrite_map, std::function fcontext, std::function fmulti_ref_trigger) - : rewrite_map_(rewrite_map), - fcontext_(fcontext), - fmulti_ref_trigger_(fmulti_ref_trigger) {} + : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} ForwardRewriter(const FForwardRewrite* rewrite_func, std::function fcontext, std::function fmulti_ref_trigger) - : rewrite_func_(rewrite_func), - fcontext_(fcontext), - fmulti_ref_trigger_(fmulti_ref_trigger) {} - + : rewrite_func_(rewrite_func), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} // Transform expression. Expr Rewrite(const Expr& expr) { @@ -91,7 +85,7 @@ class ForwardRewriter : private MixedModeMutator { TempRealizer realizer_; // Visit and allow non-realized version. - Expr GetTempExpr(const Expr& expr, const Expr& post) { + Expr GetTempExpr(const Expr& expr, const Expr& post) { if (fmulti_ref_trigger_ != nullptr) { Expr ret = post; auto it = ref_counter_.find(expr.get()); @@ -160,9 +154,8 @@ class ForwardRewriter : private MixedModeMutator { } // try to rewrite. if (frewrite != nullptr) { - Expr res = frewrite( - ref_call, call_args, - fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); + Expr res = frewrite(ref_call, call_args, + fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); if (res.defined()) return res; // abort, use old rule for (size_t i = 0; i < call_args.size(); ++i) { @@ -175,21 +168,18 @@ class ForwardRewriter : private MixedModeMutator { } } if (unchanged) return ref_call; - return Call( - new_op, call_args, call_node->attrs, call_node->type_args); + return Call(new_op, call_args, call_node->attrs, call_node->type_args); } }; -Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_name, +Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_name, std::function fcontext, std::function fmulti_ref_trigger) { auto rewrite_map = Op::GetAttr(rewrite_map_name); return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); } -Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, +Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, std::function fcontext, std::function fmulti_ref_trigger) { return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index e37b44c..0ca8d7c 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -24,13 +24,14 @@ * \brief This is a backend-aware optimization pass. * Fuse necessary ops into a single one. */ -#include #include #include #include #include -#include "pattern_util.h" +#include + #include "../../support/arena.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -53,8 +54,9 @@ namespace relay { However, at the point of conv2d we do not necessarily know that all the future paths will merge at the elemwise add. The fusion algorithm applies post-dominator analysis. - The immediate post-dominator of a node defined by the closest node where all the future path goes into. - In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm is as follows: + The immediate post-dominator of a node defined by the closest node where all the future path goes + into. In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm + is as follows: - Construct a DAG of dataflow graph for dominator analysis - Construct a post-dominator tree which gives immediate post dominator of each node. @@ -73,8 +75,8 @@ namespace relay { - CommitFuse: mark all the nodes between source and post-dominator as the same group. - We use an Union-Find data structure to manage the groups. */ -using support::LinkNode; using support::LinkedList; +using support::LinkNode; constexpr uint32_t kMaxFusedOps = 256; @@ -123,9 +125,7 @@ class IndexedForwardGraph { std::ostringstream os; for (size_t i = 0; i < post_dfs_order.size(); ++i) { Node* node = post_dfs_order[i]; - os << "node[" << i << "], " - << GetRef(node->ref) - << " outputs=["; + os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; for (auto* link = node->outputs.head; link != nullptr; link = link->next) { os << link->value.node->index << ", "; } @@ -147,8 +147,7 @@ class IndexedForwardGraph { // Creator of post dominator tree of the dataflow class IndexedForwardGraph::Creator : private ExprVisitor { public: - explicit Creator(support::Arena* arena) - : arena_(arena) {} + explicit Creator(support::Arena* arena) : arena_(arena) {} IndexedForwardGraph Prepare(const Expr& body) { this->Update(body, nullptr, kOpaque); @@ -164,9 +163,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // attribute equal comparator StructuralEqual attr_equal_; // Update the message stored at the node. - void Update(const Expr& node, - IndexedForwardGraph::Node* parent, - OpPatternKind pattern) { + void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) { const tvm::Object* key = node.get(); IndexedForwardGraph::Node* current; auto it = graph_.node_map.find(key); @@ -188,8 +185,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void AddNode(const tvm::Object* key) { auto it = graph_.node_map.find(key); - CHECK(it != graph_.node_map.end()) - << "Cannot find node " << GetRef(key); + CHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef(key); IndexedForwardGraph::Node* node = it->second; CHECK(node->ref == nullptr); node->ref = key; @@ -214,12 +210,9 @@ class IndexedForwardGraph::Creator : private ExprVisitor { Node* node = graph_.node_map.at(op); DataType dtype = DataType(op->data->dtype); // This rule must be consistent with code generator. - bool is_simple_const = ( - dtype == DataType::Int(32) || - dtype == DataType::Int(64) || - dtype == DataType::Float(32) || - dtype == DataType::Float(64) || - dtype == DataType::Bool()); + bool is_simple_const = + (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) || + dtype == DataType::Float(64) || dtype == DataType::Bool()); if (op->is_scalar() && is_simple_const) { node->pattern = kElemWise; } else { @@ -232,8 +225,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const CallNode* call) final { CHECK(graph_.node_map.count(call)); Node* node = graph_.node_map.at(call); - static auto fpattern = - Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttr("TOpPattern"); // Now we set the pattern of this call. // // If we see a call mentioning an operator we should mark it with its @@ -255,13 +247,10 @@ class IndexedForwardGraph::Creator : private ExprVisitor { const auto* rtype = call->checked_type().as(); // pass the analysis back to all the children it references. for (size_t i = 0; i < call->args.size(); ++i) { - const auto* arg_type = - call->args[i]->checked_type().as(); + const auto* arg_type = call->args[i]->checked_type().as(); // specifically check if result type is the same as arguments type OpPatternKind edge_pattern = op_pattern; - if (edge_pattern == kBroadcast && - arg_type != nullptr && - rtype != nullptr && + if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr && attr_equal_(rtype->shape, arg_type->shape)) { edge_pattern = kElemWise; } @@ -313,9 +302,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const VarNode* op) final { - this->AddNode(op); - } + void VisitExpr_(const VarNode* op) final { this->AddNode(op); } void VisitExpr_(const LetNode* op) final { // do not fuse through let. @@ -364,8 +351,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } }; -IndexedForwardGraph IndexedForwardGraph::Create( - support::Arena* arena, const Expr& body) { +IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) { return Creator(arena).Prepare(body); } @@ -398,13 +384,11 @@ class DominatorTree { * \note This algorithm makes use of the fact that graph is DAG, * and runs a single pass algorithm via LCA (Least Common Ancestor) */ - static DominatorTree PostDom(support::Arena* arena, - const IndexedForwardGraph& graph); + static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); private: // Combine pattern together. - static OpPatternKind CombinePattern( - OpPatternKind lhs, OpPatternKind rhs) { + static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { if (lhs > rhs) return lhs; return rhs; } @@ -416,26 +400,19 @@ class DominatorTree { * The combined edge pattern across all the parents. * \return The least common ancestor of the two. */ - static Node* LeastCommonAncestor( - Node* lhs, - Node* rhs, - OpPatternKind* edge_pattern) { + static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) { while (lhs != rhs) { if (lhs == nullptr) return nullptr; if (rhs == nullptr) return nullptr; if (lhs->depth < rhs->depth) { - edge_pattern[0] = CombinePattern( - edge_pattern[0], rhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); rhs = rhs->parent; } else if (rhs->depth < lhs->depth) { - edge_pattern[0] = CombinePattern( - edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); lhs = lhs->parent; } else { - edge_pattern[0] = CombinePattern( - edge_pattern[0], lhs->pattern); - edge_pattern[0] = CombinePattern( - edge_pattern[0], rhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); lhs = lhs->parent; rhs = rhs->parent; } @@ -496,9 +473,7 @@ class DominatorTree { } }; - -DominatorTree DominatorTree::PostDom(support::Arena* arena, - const IndexedForwardGraph& graph) { +DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { DominatorTree tree; tree.nodes.resize(graph.post_dfs_order.size(), nullptr); // reverse topo order @@ -572,13 +547,11 @@ class GraphPartitioner { /*! \brief internal field used for deduplication */ std::unordered_set visited_; // Internal implelementation of CheckPath - template - bool CheckPath_(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink, - F fcond) { + template + bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { if (visited_.count(src)) return true; visited_.insert(src); - Group* gnode = groups_[src->index]; + Group* gnode = groups_[src->index]; CHECK(gnode != nullptr); gnode = gnode->FindRoot(); if (!fcond(gnode->pattern, src == sink)) return false; @@ -600,10 +573,8 @@ class GraphPartitioner { * \tparam F the condition function, with signature * \note sink must be a post-dominator of src. */ - template - bool CheckPath(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink, - F fcond) { + template + bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { CHECK(!src->extern_ref); visited_.clear(); CHECK(src != sink); @@ -613,8 +584,7 @@ class GraphPartitioner { return true; } // Combine two patterns together. - static OpPatternKind CombinePattern( - OpPatternKind lhs, OpPatternKind rhs) { + static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { if (lhs > kBroadcast && rhs > kBroadcast) { LOG(FATAL) << "Cannot merge two complex group together"; } @@ -637,14 +607,11 @@ class GraphPartitioner { if (child->master_ref != nullptr) { CHECK(parent->master_ref == nullptr); parent->master_ref = child->master_ref; - parent->pattern = CombinePattern( - child->pattern, parent->pattern); + parent->pattern = CombinePattern(child->pattern, parent->pattern); } } // Internal implelementation of CommitFuse - void CommitFuse_(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink, - Group* target) { + void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) { if (src == sink) return; if (visited_.count(src)) return; visited_.insert(src); @@ -662,8 +629,7 @@ class GraphPartitioner { * \param sink The termination node. * \note sink must be a post-dominator of src. */ - void CommitFuse(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink) { + void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { Group* target = groups_[sink->index]; visited_.clear(); CHECK(src != sink); @@ -687,9 +653,7 @@ class GraphPartitioner { } // execute the fusion algorithm. - void RunFuse(const IndexedForwardGraph& graph, - const DominatorTree& post_dom_tree, - int phase) { + void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) { for (size_t nid = 0; nid < groups_.size(); ++nid) { // the group of current node has been specified already. auto* graph_node = graph.post_dfs_order[nid]; @@ -704,8 +668,7 @@ class GraphPartitioner { size_t dom_parent_gindex = dom_node->parent->gnode->index; // refuse the fusion if too many ops are going to be fused together - if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) - continue; + if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) continue; if (phase == 2) { // Fuse injective ops into intermediate tuples, if any @@ -716,9 +679,7 @@ class GraphPartitioner { if (dom_root_group->pattern == kTuple) continue; if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { // Now we know the tuple has been fused into subsequent injective ops - auto fcond = [](OpPatternKind kind, bool is_sink) { - return kind <= kInjective; - }; + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; // dom_root_group can also be tuple, as in inception layers // CheckPath is needed to avoid fusing two intermediate tuples if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { @@ -743,9 +704,7 @@ class GraphPartitioner { if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { CHECK(dom_node->parent->gnode != nullptr); // The fuse can be executed if all the intermediate ops are still broadcast. - auto fcond = [](OpPatternKind kind, bool is_sink) { - return kind <= kBroadcast; - }; + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); } @@ -753,8 +712,7 @@ class GraphPartitioner { } else if (group_node->pattern <= kBroadcast) { // Pre-condition: can only be fused to parent which is injective or reduction. if (dom_node->parent != nullptr && - (dom_node->pattern <= kInjective || - dom_node->pattern == kCommReduce)) { + (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { // Check if all the intermediate ops are still broadcast. // The final terminal node can already be fused to a OutEWiseFusable group. auto fcond = [](OpPatternKind kind, bool is_sink) { @@ -763,9 +721,7 @@ class GraphPartitioner { // are allowed be fused to the elemwise/broadcast master. return kind <= kInjective; } else { - return (kind <= kBroadcast || - kind == kCommReduce || - kind == kInjective || + return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || kind == kOutEWiseFusable); } }; @@ -778,9 +734,7 @@ class GraphPartitioner { // so conv2d always finishes fusing. if (phase != 1) continue; // Check if all path are injective. - auto fcond = [](OpPatternKind kind, bool is_sink) { - return kind <= kInjective; - }; + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); } @@ -792,8 +746,8 @@ class GraphPartitioner { } }; -std::vector -GraphPartitioner::Partition(const IndexedForwardGraph& graph) { +std::vector GraphPartitioner::Partition( + const IndexedForwardGraph& graph) { this->InitGroups(graph); if (opt_level_ == 0) return std::move(groups_); // get post dominator tree @@ -811,8 +765,7 @@ class FuseMutator : private ExprMutator { Expr Transform(const Expr& body, int fuse_opt_level) { // setup the group map. auto graph = IndexedForwardGraph::Create(&arena_, body); - auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition( - graph); + auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph); for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { CHECK(graph.post_dfs_order[nid]->ref != nullptr); gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; @@ -822,7 +775,6 @@ class FuseMutator : private ExprMutator { return this->Mutate(body); } - private: /*! \brief Temporary information from each group. */ struct GroupInfo { @@ -865,8 +817,7 @@ class FuseMutator : private ExprMutator { // Transform calls. Expr VisitExpr_(const CallNode* call) { if (call->op.as()) { - static auto fnoncomputational = - Op::GetAttr("TNonComputational"); + static auto fnoncomputational = Op::GetAttr("TNonComputational"); if (fnoncomputational.get(Downcast(call->op), false)) { return ExprMutator::VisitExpr_(call); @@ -881,8 +832,7 @@ class FuseMutator : private ExprMutator { auto* ret_group = gmap_.at(call)->FindRoot(); Array new_args = GetNewArguments(call->args, ret_group); - auto new_call = Call( - call->op, new_args, call->attrs, call->type_args); + auto new_call = Call(call->op, new_args, call->attrs, call->type_args); if (ret_group->root_ref == call) { // This is the root of the group @@ -929,9 +879,7 @@ class FuseMutator : private ExprMutator { // If the function has no call, it is not a primitive function. struct HasCallVisitor : ExprVisitor { bool has_call = false; - void VisitExpr_(const CallNode* op) final { - has_call = true; - } + void VisitExpr_(const CallNode* op) final { has_call = true; } } visitor; visitor(body); const GroupInfo& ginfo = ginfo_[group]; @@ -960,13 +908,13 @@ class FuseMutator : private ExprMutator { // Debug function, dump the group assignment in text. void DebugDumpGroup(const Expr& body) { std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string { - auto it = gmap_.find(expr.get()); - if (it == gmap_.end()) return ""; - std::ostringstream os; - auto *group = it->second->FindRoot(); - os << " /* group=" << group << " */"; - return os.str(); - }); + auto it = gmap_.find(expr.get()); + if (it == gmap_.end()) return ""; + std::ostringstream os; + auto* group = it->second->FindRoot(); + os << " /* group=" << group << " */"; + return os.str(); + }); LOG(INFO) << "Dump of group info:\n" << text; } }; @@ -979,15 +927,14 @@ namespace transform { Pass FuseOps(int fuse_opt_level) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; - return Downcast(FuseOps(f, opt_level, m)); - }; + [=](Function f, IRModule m, PassContext pc) { + int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; + return Downcast(FuseOps(f, opt_level, m)); + }; return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.FuseOps") -.set_body_typed(FuseOps); +TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps); } // namespace transform diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index d0ff169..67c62f3 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -22,13 +22,14 @@ * \brief API for Automatic Differentiation for the Relay IR. */ #include -#include -#include #include +#include #include -#include "pattern_util.h" -#include "pass_util.h" +#include + #include "let_list.h" +#include "pass_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -42,12 +43,14 @@ using namespace tvm::runtime; * Formally speaking, such requirement mean that the input function is a closed expression - * that is, it only refer to local variable that is it's parameter, or defined inside it. * Every top level definition satisfy this criteria. - * AD can also be run-time, which mean it is merely a function term of AD : (Float[] -> Float[]) -> (Float[] -> Float[]). - * In relay we currently only support compile-time AD, but it should be enough for a lot of use case. + * AD can also be run-time, which mean it is merely a function term of AD : (Float[] -> Float[]) -> + * (Float[] -> Float[]). In relay we currently only support compile-time AD, but it should be enough + * for a lot of use case. * - * In deep learning, the most common way to train a deep neural network is by gradient descent or some of it's variant. - * Such optimization method require us to input the gradient of neural network, which can be obtained easily using AD. - * In fact, back propagation is essentially reverse-mode automatic differentiation, a kind of AD! + * In deep learning, the most common way to train a deep neural network is by gradient descent or + * some of it's variant. Such optimization method require us to input the gradient of neural + * network, which can be obtained easily using AD. In fact, back propagation is essentially + * reverse-mode automatic differentiation, a kind of AD! */ /*! In relay, automatic differentiation(AD) is a macro, @@ -55,10 +58,10 @@ using namespace tvm::runtime; * (x0, x1, x2, ...) -> Float[] to * (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)), * When x0, x1, x2... are Float of different shape. - * the return value is a pair, with left hand side as the original value, and right hand side as gradient of the input. - * WithGradientType will take the type of input, and produce the type of output. - * There are multiple implementation of AD in relay, with different characteristic. - * However, they all transform the input expr according to WithGradientType. + * the return value is a pair, with left hand side as the original value, and right hand side as + * gradient of the input. WithGradientType will take the type of input, and produce the type of + * output. There are multiple implementation of AD in relay, with different characteristic. However, + * they all transform the input expr according to WithGradientType. */ Type WithGradientType(const Type&); @@ -71,10 +74,7 @@ Type WithGradientType(const Type& t) { // TODO(M.K.): stricter checking auto ty = t.as(); CHECK(ty) << "input should be a function"; - return FuncType(ty->arg_types, - TupleType({ - ty->ret_type, - TupleType(ty->arg_types)}), {}, {}); + return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {}); } //! \brief if the expression is a GlobalVar, transform to it's expression. @@ -95,7 +95,7 @@ Expr DeGlobal(const IRModule& mod, const Expr& e) { * pass. */ struct ADValueNode { - virtual ~ADValueNode() { } + virtual ~ADValueNode() {} template T& get() { auto ret = dynamic_cast(this); @@ -110,8 +110,8 @@ using ADValue = std::shared_ptr; struct ADTensor : ADValueNode { Expr forward; mutable Expr reverse; // must be a variable to avoid duplication - ADTensor(LetList* ll, const Expr& forward) : - forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { + ADTensor(LetList* ll, const Expr& forward) + : forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { this->forward->checked_type_ = forward->checked_type(); } }; @@ -121,51 +121,46 @@ struct ADTensor : ADValueNode { * can compute away this function to obtain a reverse mode program. */ struct ADFunction : ADValueNode { - std::function&, - const Attrs&, - const tvm::Array&)> func; - explicit ADFunction(const std::function&, - const Attrs&, - const tvm::Array&)>& func) : - func(func) { } + std::function&, const Attrs&, + const tvm::Array&)> + func; + explicit ADFunction(const std::function&, + const Attrs&, const tvm::Array&)>& func) + : func(func) {} }; -struct FirstOrderReverseAD : ExprFunctor { +struct FirstOrderReverseAD : ExprFunctor { const OpMap rev_map = Op::GetAttr("FPrimalGradient"); std::vector> backprop_actions; // we assume no closure so no need for lexical scoping std::unordered_map env; LetList* ll; - FirstOrderReverseAD(LetList* ll) : ll(ll) { } + FirstOrderReverseAD(LetList* ll) : ll(ll) {} ADValue VisitExpr_(const OpNode* op) final { Op op_ref = GetRef(op); - CHECK(rev_map.count(op_ref)) - << op->name << " does not have reverse mode defined"; - return std::make_shared([this, op_ref](const Type& orig_type, - const std::vector& args, - const Attrs& attrs, - const tvm::Array& type_args) { - std::vector call_args; - for (const ADValue& adval : args) { - call_args.push_back(adval->get().forward); - } - auto orig = Call(op_ref, call_args, attrs, type_args); - orig->checked_type_ = orig_type; - auto ret = std::make_shared(ll, orig); - backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { - tvm::Array rev = rev_map[op_ref](orig, ret->reverse); - CHECK(args.size() == rev.size()); - for (size_t i = 0; i < args.size(); ++i) { - args[i]->get().reverse = - ll->Push(Add(args[i]->get().reverse, rev[i])); - } - }); - return ret; - }); + CHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined"; + return std::make_shared( + [this, op_ref](const Type& orig_type, const std::vector& args, const Attrs& attrs, + const tvm::Array& type_args) { + std::vector call_args; + for (const ADValue& adval : args) { + call_args.push_back(adval->get().forward); + } + auto orig = Call(op_ref, call_args, attrs, type_args); + orig->checked_type_ = orig_type; + auto ret = std::make_shared(ll, orig); + backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { + tvm::Array rev = rev_map[op_ref](orig, ret->reverse); + CHECK(args.size() == rev.size()); + for (size_t i = 0; i < args.size(); ++i) { + args[i]->get().reverse = + ll->Push(Add(args[i]->get().reverse, rev[i])); + } + }); + return ret; + }); } ADValue VisitExpr_(const ConstantNode* op) final { @@ -185,16 +180,15 @@ struct FirstOrderReverseAD : ExprFunctor { ADValue VisitExpr_(const FunctionNode* op) final { Function f = GetRef(op); // todo: assert no closure - return std::make_shared([this, f](const Type& orig_type, - const std::vector& args, - const Attrs& attrs, - const tvm::Array& type_args) { - CHECK_EQ(f->params.size(), args.size()); - for (size_t i = 0; i < f->params.size(); ++i) { - env[f->params[i]] = args[i]; - } - return VisitExpr(f->body); - }); + return std::make_shared( + [this, f](const Type& orig_type, const std::vector& args, const Attrs& attrs, + const tvm::Array& type_args) { + CHECK_EQ(f->params.size(), args.size()); + for (size_t i = 0; i < f->params.size(); ++i) { + env[f->params[i]] = args[i]; + } + return VisitExpr(f->body); + }); } ADValue VisitExpr_(const VarNode* op) final { @@ -240,8 +234,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { const auto& res = c->get(); Expr grad = LetList::With([&](LetList* ll) { res.reverse = OnesLike(res.forward); - for (auto it = reverse_ad.backprop_actions.rbegin(); - it != reverse_ad.backprop_actions.rend(); + for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend(); ++it) { (*it)(ll); } @@ -257,8 +250,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { return Function(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient") -.set_body_typed(FirstOrderGradient); +TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient); struct ReverseADType : TypeMutator { Type VisitType_(const TensorTypeNode* ttn) final { @@ -267,17 +259,13 @@ struct ReverseADType : TypeMutator { } }; -Type ReverseType(const Type& t) { - return ReverseADType()(t); -} +Type ReverseType(const Type& t) { return ReverseADType()(t); } /*! \brief Lift a function that transform Tensor to a function that also transform more type * by doing a structure preserving map. */ Expr LiftTensor(const std::function& f, - const std::function& tf, - const Type& forward_type, - const Expr& e, + const std::function& tf, const Type& forward_type, const Expr& e, LetList* ll) { CHECK(IsAtomic(e)) << e; if (forward_type.as()) { @@ -288,11 +276,7 @@ Expr LiftTensor(const std::function& f, tvm::Array fields; tvm::Array types; for (size_t i = 0; i < tt->fields.size(); ++i) { - auto field = LiftTensor(f, - tf, - tt->fields[i], - ll->Push(GetField(e, i)), - ll); + auto field = LiftTensor(f, tf, tt->fields[i], ll->Push(GetField(e, i)), ll); fields.push_back(field); types.push_back(field->checked_type_); } @@ -308,10 +292,7 @@ Expr LiftTensor(const std::function& f, /*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr, * by stitching the references in the AD values. */ -void TransferGrads(const Type& forward_type, - const Expr& from, - const Expr& to, - LetList* ll) { +void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, LetList* ll) { CHECK(IsAtomic(from)) << from; CHECK(IsAtomic(to)) << to; if (forward_type.as()) { @@ -320,9 +301,7 @@ void TransferGrads(const Type& forward_type, ll->Push(RefWrite(to_ref, RefRead(from_ref))); } else if (auto* tt = forward_type.as()) { for (size_t i = 0; i < tt->fields.size(); ++i) { - TransferGrads(tt->fields[i], - ll->Push(TupleGetItem(from, i)), - ll->Push(TupleGetItem(to, i)), + TransferGrads(tt->fields[i], ll->Push(TupleGetItem(from, i)), ll->Push(TupleGetItem(to, i)), ll); } } else { @@ -333,48 +312,31 @@ void TransferGrads(const Type& forward_type, /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */ Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) { - auto rev = [&](const Expr& e) { - return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); - }; - auto rev_type = [&](const Type& forward_type) { - return ReverseType(forward_type); - }; + auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); }; + auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); }; return LiftTensor(rev, rev_type, forward_type, e, ll); } /*! \brief ReverseType(t) -> t. Get the original value. */ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) { - auto val = [&](const Expr& e) { - return GetField(e, 0); - }; - auto val_type = [&](const Type& forward_type) { - return forward_type; - }; + auto val = [&](const Expr& e) { return GetField(e, 0); }; + auto val_type = [&](const Type& forward_type) { return forward_type; }; return LiftTensor(val, val_type, forward_type, e, ll); } /*! \brief ReverseType(t) -> t. Get the gradient. */ Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) { - auto grad = [&](const Expr& e) { - return ll->Push(RefRead(GetField(e, 1))); - }; - auto grad_type = [&](const Type& forward_type) { - return forward_type; - }; + auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); }; + auto grad_type = [&](const Type& forward_type) { return forward_type; }; return LiftTensor(grad, grad_type, forward_type, e, ll); } void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { if (t.as()) { - ll->Push(RefWrite(GetField(arg, 1), - Add(ll->Push(RefRead(GetField(arg, 1))), - grad))); + ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad))); } else if (auto* tt = t.as()) { for (size_t i = 0; i < tt->fields.size(); ++i) { - UpdateGrad(tt->fields[i], - ll->Push(GetField(arg, i)), - ll->Push(GetField(grad, i)), - ll); + UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll); } } else { LOG(FATAL) << "unsupported arg type of operator: " << t; @@ -394,15 +356,14 @@ struct ReverseAD : ExprMutator { std::shared_ptr ad_vars; const OpMap rev_map = Op::GetAttr("FPrimalGradient"); - explicit ReverseAD(const Var& bp, std::shared_ptr ad_vars) - : bp(bp), ad_vars(ad_vars) { } + explicit ReverseAD(const Var& bp, std::shared_ptr ad_vars) : bp(bp), ad_vars(ad_vars) {} Expr VisitExpr_(const OpNode* op) final { LOG(FATAL) << "op should only be inside call"; throw; } - Expr VisitCheckpoint(const CallNode *call) { + Expr VisitCheckpoint(const CallNode* call) { const OpNode* op_node = call->op.as(); CHECK(op_node) << "expected op in call"; Op op_ref = GetRef(op_node); @@ -412,20 +373,17 @@ struct ReverseAD : ExprMutator { auto x_var = ll->Push(x); auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll)); auto bpv = ll->Push(RefRead(bp)); - Expr nbp = Function( - {}, - LetList::With([&](LetList* ll) { - // we need a new ReverseAD visitor to avoid clobbering the bp local var - auto dup_bp = ll->Push(BPEmpty()); - ReverseAD dup_diff(dup_bp, ad_vars); - auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x))); - - TransferGrads(call->checked_type(), ret, dup_ad, ll); - ll->Push(Call(RefRead(dup_bp), {})); - return Call(bpv, {}); - }), - TupleType::Empty(), - {}); + Expr nbp = Function({}, LetList::With([&](LetList* ll) { + // we need a new ReverseAD visitor to avoid clobbering the bp local var + auto dup_bp = ll->Push(BPEmpty()); + ReverseAD dup_diff(dup_bp, ad_vars); + auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x))); + + TransferGrads(call->checked_type(), ret, dup_ad, ll); + ll->Push(Call(RefRead(dup_bp), {})); + return Call(bpv, {}); + }), + TupleType::Empty(), {}); ll->Push(RefWrite(bp, nbp)); return ret; }); @@ -439,8 +397,7 @@ struct ReverseAD : ExprMutator { return VisitCheckpoint(call); } - CHECK(rev_map.count(op_ref)) - << op_node->name << " does not have reverse mode defined"; + CHECK(rev_map.count(op_ref)) << op_node->name << " does not have reverse mode defined"; return LetList::With([&](LetList* ll) { std::vector args; for (const auto& arg : call->args) { @@ -456,18 +413,16 @@ struct ReverseAD : ExprMutator { orig_var->checked_type_ = call->checked_type(); auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll)); auto bpv = ll->Push(RefRead(bp)); - Expr nbp = Function( - {}, - LetList::With([&](LetList* ll) { - tvm::Array rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); - CHECK(args.size() == rev.size()); - for (size_t i = 0; i < args.size(); ++i) { - UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); - } - return Call(bpv, {}); - }), - TupleType::Empty(), - {}); + Expr nbp = Function({}, LetList::With([&](LetList* ll) { + tvm::Array rev = + rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); + CHECK(args.size() == rev.size()); + for (size_t i = 0; i < args.size(); ++i) { + UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); + } + return Call(bpv, {}); + }), + TupleType::Empty(), {}); ll->Push(RefWrite(bp, nbp)); return ret; }); @@ -481,9 +436,8 @@ struct ReverseAD : ExprMutator { } Expr VisitExpr_(const IfNode* op) final { - return If(TupleGetItem(VisitExpr(op->cond), 0), - VisitExpr(op->true_branch), - VisitExpr(op->false_branch)); + return If(TupleGetItem(VisitExpr(op->cond), 0), VisitExpr(op->true_branch), + VisitExpr(op->false_branch)); } Expr VisitExpr_(const VarNode* var) final { @@ -497,9 +451,7 @@ struct ReverseAD : ExprMutator { return ad_vars->at(var_ref); } - Type VisitType(const Type& t) final { - return t.defined() ? ReverseType(t) : t; - } + Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } }; bool MissingGrad(const Expr& e) { @@ -585,8 +537,7 @@ Expr Gradient(const Expr& re, const IRModule& mod) { return Function(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_GLOBAL("relay._transform.gradient") -.set_body_typed(Gradient); +TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/infer_layout_util.h b/src/relay/transforms/infer_layout_util.h index ca73003..e4df647 100644 --- a/src/relay/transforms/infer_layout_util.h +++ b/src/relay/transforms/infer_layout_util.h @@ -27,11 +27,13 @@ #ifndef TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_ #define TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_ -#include #include #include +#include + #include #include + #include "pattern_util.h" namespace tvm { @@ -94,17 +96,15 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o * \return infered_layout An array of two elements that are inferred input layouts and * inferred output layouts. */ -using FInferCorrectLayout = runtime::TypedPackedFunc< - Array>(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types)>; +using FInferCorrectLayout = runtime::TypedPackedFunc>( + const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, + const Array& old_in_types)>; /*! \brief take arbitrary input layout and copy to output */ -inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +inline Array> ElemwiseArbitraryLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { Layout ret; if (new_in_layouts.defined()) { @@ -119,14 +119,14 @@ inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, } } - return Array >{Array(old_in_layouts.size(), ret), {ret}}; + return Array>{Array(old_in_layouts.size(), ret), {ret}}; } /*! \brief Infer layout for binary broadcast operators */ -inline Array > BinaryBroadcastLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +inline Array> BinaryBroadcastLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { Array layouts; Array> old_in_shapes; for (auto old_in_t : old_in_types) { @@ -142,28 +142,27 @@ inline Array > BinaryBroadcastLayout(const Attrs& attrs, if (!layouts[0].defined() && !layouts[1].defined()) { // both undefined, infer fails - return Array > {{Layout::Undef()}, {Layout::Undef()}}; + return Array>{{Layout::Undef()}, {Layout::Undef()}}; } else if (!layouts[0].defined() || !layouts[1].defined()) { // only one is defined, use shape information to help infer int defined_idx = layouts[0].defined() ? 0 : 1; int undef_idx = 1 - defined_idx; if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) { - layouts.Set(undef_idx, - layouts[defined_idx].SubLayout( - old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(), - old_in_shapes[undef_idx].size())); - return Array >{layouts, {layouts[defined_idx]}}; + layouts.Set(undef_idx, layouts[defined_idx].SubLayout(old_in_shapes[defined_idx].size() - + old_in_shapes[undef_idx].size(), + old_in_shapes[undef_idx].size())); + return Array>{layouts, {layouts[defined_idx]}}; } else { // only know the tensor with smaller dimensions, // so we cannot infer the final broadcasted output. // fails in this case. - return Array >{{Layout::Undef()}, {Layout::Undef()}}; + return Array>{{Layout::Undef()}, {Layout::Undef()}}; } } else if (layouts[0].defined() && layouts[1].defined() && - (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) { + (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) { int scalar = layouts[0].ndim() == 0 ? 0 : 1; - return Array >{layouts, {layouts[1-scalar]}}; + return Array>{layouts, {layouts[1 - scalar]}}; } else { // Set the layout of the larger dimension. If one dimension size is lower, we call expand dims // while transforming layout. @@ -217,8 +216,7 @@ static inline std::tuple, Array, bool> InferCorrectLayouts Op op = Downcast(call->op); if (finfer_layout.count(op)) { Array> inferred_layouts; - inferred_layouts = - finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types); + inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types); CHECK_EQ(inferred_layouts.size(), 2) << "FInferCorrectLayout should return an array with size of 2"; for (auto x : inferred_layouts) { diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index ba0f568..c9a0de4 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -35,8 +35,9 @@ #include #include -#include #include +#include + #include #include @@ -83,11 +84,8 @@ class Inliner : ExprMutator { } Function Inline(const Function& func) { - return Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, - func->attrs); + return Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, + func->attrs); } private: @@ -115,20 +113,13 @@ class Inliner : ExprMutator { } // Make a new Relay expression to replace the callee. - Expr MakeNewExpr(const GlobalVar& global, - const Array& args, - const Expr& callee) { - CHECK(callee->IsInstance() || - callee->IsInstance()); + Expr MakeNewExpr(const GlobalVar& global, const Array& args, const Expr& callee) { + CHECK(callee->IsInstance() || callee->IsInstance()); auto base_func = call_graph_->GetGlobalFunction(global); const auto* fn = base_func.as(); CHECK(fn) << "Expected to work on a Relay function."; - auto func = Function(fn->params, - fn->body, - fn->ret_type, - fn->type_params, - fn->attrs); + auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. if (!func->GetAttr(attr::kCompiler).defined()) { @@ -144,14 +135,13 @@ class Inliner : ExprMutator { // Cannot replace TensorType/TensorTupleType with FuncType. Therefore, // we simply inline the function as a closure instead of directly using // its body when the global var returns FuncType. - return ret_type->IsInstance() ? std::move(func) - : func->body; + return ret_type->IsInstance() ? std::move(func) : func->body; } else { CHECK(callee->IsInstance()); return Bind(func->body, bind_map); } } else if (const auto* call_node = callee.as()) { - return Call(func, args, call_node->attrs, call_node->type_args); + return Call(func, args, call_node->attrs, call_node->type_args); } else { return std::move(func); } @@ -214,14 +204,11 @@ namespace transform { Pass Inline() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::Inline(m); - }; + [=](IRModule m, PassContext pc) { return relay::Inline(m); }; return CreateModulePass(pass_func, 1, "InlineGlobals", {}); } -TVM_REGISTER_GLOBAL("relay._transform.Inline") -.set_body_typed(Inline); +TVM_REGISTER_GLOBAL("relay._transform.Inline").set_body_typed(Inline); } // namespace transform diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index e6248f1..3cd29d6 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -24,21 +24,21 @@ * \brief Lazily instantiate 0-filled or 1-filled tensors. * This pass should be used after reverse-mode ad so that gradient tensors * are not instantiated until after the forward pass. - * - * This pass delays or removes memory allocation by converting tensors into + * + * This pass delays or removes memory allocation by converting tensors into * GradCell, an algebraic data type defined in gradient.rly. - * + * * This will delay or decrease memory usage. All calls to * ones, ones_like, zeros, zeros_like will call the One or Zero constructor * of GradCell, which will not instantiate in memory until needed. All other cases result * in using the Raw constructor which means the tensor is instantiated in memory. - * + * * It also overloads + and * operation which can increase performance when doing * operations involving tensors with values of only 0 or 1. - * + * * Note: this pass can only be used with functions where the input/output types are * a combination of TupleTypes and TensorTypes - * + * * This pass optimizes 6 ops: * - add * - multiply @@ -46,39 +46,40 @@ * - ones_like * - zeros * - zeros_like - * + * * This pass makes use of three visitor. The most important one visits the entire function, * one is used for wrap inputs and one to unwrap outputs. - * + * * For example: * fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32] - * + * * After this pass * fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]] - * + * * Thus, it is necessary to wrap this outer function so that the input/output types remain the same */ +#include #include #include #include -#include #include + #include "let_list.h" namespace tvm { namespace relay { /*! -* \brief Visitor appropriately wraps tensors with Raw constructor -* -* Recursively looks at the type of the expression (TensorType or TupleType are only supported for now) -* and either call the GradCell constructor if TensorType -* or unfold and recursively visit if TupleType -*/ -class InputVisitor: public ExprFunctor { + * \brief Visitor appropriately wraps tensors with Raw constructor + * + * Recursively looks at the type of the expression (TensorType or TupleType are only supported for + * now) and either call the GradCell constructor if TensorType or unfold and recursively visit if + * TupleType + */ +class InputVisitor : public ExprFunctor { public: - explicit InputVisitor(IRModule module): module_(module) {} + explicit InputVisitor(IRModule module) : module_(module) {} Expr VisitExpr_(const VarNode* op, const Type& t) final { std::cout << op->type_annotation << std::endl; @@ -88,13 +89,13 @@ class InputVisitor: public ExprFunctor { Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { return WrapExpr(GetRef(op), t); } + private: IRModule module_; Expr WrapExpr(const Expr expr, const Type& type) { if (type.as()) { - return Call(module_->GetConstructor("GradCell", "Raw"), - {expr}, Attrs(), {type}); + return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); } else if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { @@ -110,15 +111,15 @@ class InputVisitor: public ExprFunctor { }; /*! -* \brief Visitor appropriately unwraps expressions with GradCell type into Tensors -* -* Recursively looks at the type of the expression -* and either use the FromGradCell function if TypeCall to GradCell -* or unfold and recursively visit if TupleType -*/ -class OutputVisitor: public ExprFunctor { + * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors + * + * Recursively looks at the type of the expression + * and either use the FromGradCell function if TypeCall to GradCell + * or unfold and recursively visit if TupleType + */ +class OutputVisitor : public ExprFunctor { public: - explicit OutputVisitor(IRModule module): module_(module) {} + explicit OutputVisitor(IRModule module) : module_(module) {} Expr VisitExpr_(const CallNode* op, const Type& t) final { return UnwrapExpr(GetRef(op), t); @@ -127,6 +128,7 @@ class OutputVisitor: public ExprFunctor { Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { return UnwrapExpr(GetRef(op), t); } + private: IRModule module_; @@ -150,19 +152,18 @@ class OutputVisitor: public ExprFunctor { } }; -class LazyGradientInitializer: public ExprMutator, public TypeMutator { +class LazyGradientInitializer : public ExprMutator, public TypeMutator { public: - explicit LazyGradientInitializer(IRModule module): - module_(module) { - module_->ImportFromStd("gradient.rly"); - } + explicit LazyGradientInitializer(IRModule module) : module_(module) { + module_->ImportFromStd("gradient.rly"); + } /*! - * \brief apply LazyGradientInit transformation and wrap function - * so that function type stays the same - * - * input/output types should only be a combination of TupleTypes and TensorTypes - */ + * \brief apply LazyGradientInit transformation and wrap function + * so that function type stays the same + * + * input/output types should only be a combination of TupleTypes and TensorTypes + */ Expr Transform(const Expr& e) { auto* f = (e).as(); auto* transformed = this->Mutate(e).as(); @@ -185,8 +186,8 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { } Expr VisitExpr_(const ConstantNode* op) final { - return Call(module_->GetConstructor("GradCell", "Raw"), - {GetRef(op)}, Attrs(), {op->checked_type()}); + return Call(module_->GetConstructor("GradCell", "Raw"), {GetRef(op)}, Attrs(), + {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { @@ -203,12 +204,12 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) { // fn() -> T, function returns result of the operation - Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, - {call_node->checked_type()}, {}); + Expr func = + Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); // call appropriate GradCell constructor std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero"; - return Call(module_->GetConstructor("GradCell", constructor_name), - {func}, Attrs(), {call_node->checked_type()}); + return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(), + {call_node->checked_type()}); } if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) { @@ -218,23 +219,21 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { Expr func = Function({}, result, {call_node->checked_type()}, Array()); // call appropriate GradCell constructor std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero"; - return Call(module_->GetConstructor("GradCell", "One"), - {func}, Attrs(), {call_node->checked_type()}); + return Call(module_->GetConstructor("GradCell", "One"), {func}, Attrs(), + {call_node->checked_type()}); } // handle all other ops Expr result = CallPrimitiveOp(call_node); // wrap result with Raw constructor - return Call(module_->GetConstructor("GradCell", "Raw"), {result}, - Attrs(), {call_node->checked_type()}); + return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(), + {call_node->checked_type()}); } // not an op return ExprMutator::VisitExpr_(call_node); } - Type VisitType(const Type& t) final { - return TypeMutator::VisitType(t); - } + Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); } Type VisitType_(const TensorTypeNode* op) { GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); @@ -248,23 +247,22 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { IRModule module_; /*! - * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type - */ + * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type + */ Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) { // can only use overloaded functions if 2 arguments of same type if (call_node->args.size() != 2 || !tvm::StructuralEqual()(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { Expr result = CallPrimitiveOp(call_node); - return Call(module_->GetConstructor("GradCell", "Raw"), {result}, - Attrs(), {call_node->checked_type()}); + return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(), + {call_node->checked_type()}); } tvm::Array args; // create "fallback" function for overloaded function Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {Var("lhs", paramType), - Var("rhs", paramType)}; + tvm::Array params = {Var("lhs", paramType), Var("rhs", paramType)}; // use primitive op in this case Expr callOp = Call(call_node->op, {params[0], params[1]}); Expr func = Function(params, callOp, paramType, Array()); @@ -279,16 +277,15 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { } /*! - * \brief Convert calls to other ops by converting args into TensorType - * \return call expr returning result of op - */ + * \brief Convert calls to other ops by converting args into TensorType + * \return call expr returning result of op + */ Expr CallPrimitiveOp(const CallNode* call_node) { const auto fromFunc = module_->GetGlobalVar("FromGradCell"); tvm::Array args; // use FromGradCell to convert args to Tensor for (Expr expr : call_node->args) { - args.push_back(Call(fromFunc, - {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } // result of operation return Call(call_node->op, args); @@ -302,14 +299,13 @@ Expr LazyGradientInit(const Expr& e, IRModule mod) { namespace transform { Pass LazyGradientInit() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(LazyGradientInit(f, m)); - }; - return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {}); + [=](Function f, IRModule m, PassContext pc) { + return Downcast(LazyGradientInit(f, m)); + }; + return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {}); } -TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit") -.set_body_typed(LazyGradientInit); +TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit); } // namespace transform diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc index 0b5c671..25919b4 100644 --- a/src/relay/transforms/legalize.cc +++ b/src/relay/transforms/legalize.cc @@ -23,10 +23,10 @@ * shape, dtype or layout to another op or a sequence of ops. */ -#include #include #include #include +#include namespace tvm { namespace relay { diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index f195c30..c0e0b3a 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -29,12 +29,14 @@ #ifndef TVM_RELAY_TRANSFORMS_LET_LIST_H_ #define TVM_RELAY_TRANSFORMS_LET_LIST_H_ -#include #include +#include + +#include +#include #include #include -#include -#include + #include "tvm/relay/type.h" namespace tvm { @@ -77,9 +79,7 @@ class LetList { * * \return a Var that hold the inserted expr. */ - Var Push(Expr expr, Type ty) { - return Push(Var("x", ty), expr); - } + Var Push(Expr expr, Type ty) { return Push(Var("x", ty), expr); } /*! * \brief insert a binding. @@ -88,9 +88,7 @@ class LetList { * * \return a Var that hold the inserted expr. */ - Var Push(Expr expr) { - return Push(expr, Type()); - } + Var Push(Expr expr) { return Push(expr, Type()); } /*! * \brief wrap an expr around the LetList. @@ -130,16 +128,14 @@ class LetList { * * \return the wrapped Expr. */ - template + template static Expr With(F&& f) { LetList ll; return ll.Get(f(&ll)); } static Expr LetBind(const Expr& e, const std::function& f) { - return With([&](LetList* ll) { - return f(ll->Push(e)); - }); + return With([&](LetList* ll) { return f(ll->Push(e)); }); } private: diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 75d95f0..46fdae0 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -234,8 +234,7 @@ Pass MergeComposite(const tvm::Array& pattern_names, return func_pass; } -TVM_REGISTER_GLOBAL("relay._transform.MergeComposite") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) { tvm::Array pattern_names = args[0]; tvm::Array patterns = args[1]; std::vector checks; diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index cd1f40c..a27cb79 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -91,12 +91,13 @@ */ #include #include -#include #include -#include #include -#include "pass_util.h" +#include +#include + #include "let_list.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -109,9 +110,7 @@ using namespace runtime; * Use VarHash to hash Var by id. */ struct VarHash { - size_t operator()(const Var& v) const { - return ObjectHash()(v->vid); - } + size_t operator()(const Var& v) const { return ObjectHash()(v->vid); } }; /*! \brief Compare Var by it's id. @@ -119,9 +118,7 @@ struct VarHash { * Use VarEqual to compare Var by id. */ struct VarEqual { - bool operator()(const Var& l, const Var& r) const { - return l->vid.get() == r->vid.get(); - } + bool operator()(const Var& l, const Var& r) const { return l->vid.get() == r->vid.get(); } }; Expr PostProcess(const Expr&); @@ -137,9 +134,7 @@ class Static : public ObjectRef { public: Static() {} explicit Static(ObjectPtr n) : ObjectRef(n) {} - const StaticNode* operator->() const { - return static_cast(get()); - } + const StaticNode* operator->() const { return static_cast(get()); } using ContainerType = StaticNode; }; @@ -156,9 +151,9 @@ struct PStaticNode : Object { Static pstatic; // may be null Expr dynamic; Time created_time; - PStaticNode(const Static& pstatic, const Expr& dynamic) : - pstatic(pstatic), dynamic(dynamic), created_time(time()) { } - explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } + PStaticNode(const Static& pstatic, const Expr& dynamic) + : pstatic(pstatic), dynamic(dynamic), created_time(time()) {} + explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) {} static constexpr const char* _type_key = "relay.PStatic"; TVM_DECLARE_FINAL_OBJECT_INFO(PStaticNode, Object); }; @@ -170,7 +165,7 @@ class PStatic : public ObjectRef { struct STupleNode : StaticNode { std::vector fields; - explicit STupleNode(const std::vector& fields) : fields(fields) { } + explicit STupleNode(const std::vector& fields) : fields(fields) {} static constexpr const char* _type_key = "relay.STuple"; TVM_DECLARE_FINAL_OBJECT_INFO(STupleNode, StaticNode); }; @@ -186,7 +181,7 @@ Static MkSTuple(const std::vector& fields) { struct STensorNode : StaticNode { runtime::NDArray data; - explicit STensorNode(const NDArray& data) : data(data) { } + explicit STensorNode(const NDArray& data) : data(data) {} static constexpr const char* _type_key = "relay.STensor"; TVM_DECLARE_FINAL_OBJECT_INFO(STensorNode, StaticNode); }; @@ -196,15 +191,13 @@ class STensor : public Static { TVM_DEFINE_OBJECT_REF_METHODS(STensor, Static, STensorNode); }; -Static MkSTensor(const NDArray& data) { - return Static(make_object(data)); -} +Static MkSTensor(const NDArray& data) { return Static(make_object(data)); } struct SConstructorNode : StaticNode { Constructor constructor; std::vector fields; - SConstructorNode(const Constructor& constructor, const std::vector& fields) : - constructor(constructor), fields(fields) { } + SConstructorNode(const Constructor& constructor, const std::vector& fields) + : constructor(constructor), fields(fields) {} static constexpr const char* _type_key = "relay.SConstructor"; TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode); }; @@ -229,19 +222,14 @@ class SRef : public Static { TVM_DEFINE_OBJECT_REF_METHODS(SRef, Static, SRefNode); }; -Static MkSRef() { - return Static(make_object()); -} +Static MkSRef() { return Static(make_object()); } -using Func = std::function&, - const Attrs&, - const Array&, - LetList*)>; +using Func = std::function&, const Attrs&, + const Array&, LetList*)>; struct SFuncNode : StaticNode { Func func; - explicit SFuncNode(const Func& func) : func(func) { } + explicit SFuncNode(const Func& func) : func(func) {} static constexpr const char* _type_key = "relay.SFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(SFuncNode, StaticNode); }; @@ -251,15 +239,13 @@ class SFunc : public Static { TVM_DEFINE_OBJECT_REF_METHODS(SFunc, Static, SFuncNode); }; -Static MkSFunc(const Func& func) { - return Static(make_object(func)); -} - +Static MkSFunc(const Func& func) { return Static(make_object(func)); } class FuelNode; /*! \brief A meet-semilattice with finite descending chain. * It means that we can meet two element to get an element, - * and for every element, there is only a finite amount of meet before getting back the same element. + * and for every element, there is only a finite amount of meet before getting back the same + * element. * * Every time we recurse, we do a meet and require that progress must be made. * This ensures we do not recurse infinitely in the Partial Evaluator. @@ -301,9 +287,7 @@ class FuelNode : public RelayNode { TVM_DECLARE_BASE_OBJECT_INFO(FuelNode, RelayNode); }; -const FuelNode* Fuel::operator->() const { - return static_cast(get()); -} +const FuelNode* Fuel::operator->() const { return static_cast(get()); } Fuel MkFSeq(const std::vector& fuels); struct FSeqNode : FuelNode { @@ -318,7 +302,7 @@ struct FSeqNode : FuelNode { } return MkFSeq(new_fuels); } - explicit FSeqNode(const std::vector& fuels) : fuels(fuels) { } + explicit FSeqNode(const std::vector& fuels) : fuels(fuels) {} static constexpr const char* _type_key = "relay.FSeq"; TVM_DECLARE_FINAL_OBJECT_INFO(FSeqNode, FuelNode); }; @@ -328,9 +312,7 @@ class FSeq : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FSeq, Fuel, FSeqNode); }; -Fuel MkFSeq(const std::vector& fuels) { - return Fuel(make_object(fuels)); -} +Fuel MkFSeq(const std::vector& fuels) { return Fuel(make_object(fuels)); } Fuel MkFTime(Time time); struct FTimeNode : FuelNode { @@ -341,7 +323,7 @@ struct FTimeNode : FuelNode { Time new_time = std::min(time, x->time); return std::make_tuple(MkFTime(new_time), new_time < time); } - explicit FTimeNode(Time time) : time(time) { } + explicit FTimeNode(Time time) : time(time) {} static constexpr const char* _type_key = "relay.FTime"; TVM_DECLARE_FINAL_OBJECT_INFO(FTimeNode, FuelNode); }; @@ -351,9 +333,7 @@ class FTime : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FTime, Fuel, FTimeNode); }; -Fuel MkFTime(Time time) { - return Fuel(make_object(time)); -} +Fuel MkFTime(Time time) { return Fuel(make_object(time)); } Fuel MkFTValue(size_t tvalue); /*! \brief If the pstatic is hold a positive integer scalar, that number, else 0. */ @@ -365,7 +345,7 @@ struct FTValueNode : FuelNode { size_t new_tvalue = std::min(tvalue, x->tvalue); return std::make_tuple(MkFTValue(new_tvalue), new_tvalue < tvalue); } - explicit FTValueNode(size_t tvalue) : tvalue(tvalue) { } + explicit FTValueNode(size_t tvalue) : tvalue(tvalue) {} static constexpr const char* _type_key = "relay.FTValue"; TVM_DECLARE_FINAL_OBJECT_INFO(FTValueNode, FuelNode); }; @@ -375,9 +355,7 @@ class FTValue : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FTValue, Fuel, FTValueNode); }; -Fuel MkFTValue(size_t tvalue) { - return Fuel(make_object(tvalue)); -} +Fuel MkFTValue(size_t tvalue) { return Fuel(make_object(tvalue)); } /*! \brief Initially every element has Fuel of FTop. It is the largest element. * @@ -397,9 +375,7 @@ class FTop : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FTop, Fuel, FTopNode); }; -Fuel MkFTop() { - return Fuel(make_object()); -} +Fuel MkFTop() { return Fuel(make_object()); } /*! * \brief A stack frame in the Relay interpreter. @@ -414,10 +390,10 @@ struct Frame { class Environment { public: - Environment() : env_({Frame()}) { } + Environment() : env_({Frame()}) {} Environment(const Environment&) = delete; - template + template T Extend(const std::function& body) { FrameContext fc(this); return body(); @@ -447,12 +423,8 @@ class Environment { struct FrameContext { Environment* env_; - explicit FrameContext(Environment* env) : env_(env) { - env_->env_.push_back(Frame()); - } - ~FrameContext() { - env_->env_.pop_back(); - } + explicit FrameContext(Environment* env) : env_(env) { env_->env_.push_back(Frame()); } + ~FrameContext() { env_->env_.pop_back(); } }; }; @@ -470,16 +442,16 @@ struct StoreFrame { * It only outdate the frame above it, but not the current frame. */ bool history_valid = true; - explicit StoreFrame(const std::unordered_map& store) : store(store) { } + explicit StoreFrame(const std::unordered_map& store) : store(store) {} StoreFrame() = default; }; class Store { public: - Store() : store_({StoreFrame()}) { } + Store() : store_({StoreFrame()}) {} Store(const Store&) = delete; - template + template T Extend(const std::function& body) { StoreFrameContext sfc(this); return body(); @@ -534,13 +506,9 @@ PStatic HasStatic(const Static& stat, const Expr& dynamic) { return PStatic(make_object(stat, dynamic)); } -PStatic NoStatic(const Expr& dynamic) { - return PStatic(make_object(dynamic)); -} +PStatic NoStatic(const Expr& dynamic) { return PStatic(make_object(dynamic)); } -enum struct MatchStatus { - Match, NoMatch, Unknown -}; +enum struct MatchStatus { Match, NoMatch, Unknown }; bool StatefulOp(const Expr& e) { static auto op_stateful = Op::GetAttr("TOpIsStateful"); @@ -582,20 +550,16 @@ struct WithFuncIdAttrs : public tvm::AttrsNode { FuncId fid; TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") { - TVM_ATTR_FIELD(fid) - .describe("The FuncId that an function is annotated with.") - .set_default(-1); + TVM_ATTR_FIELD(fid).describe("The FuncId that an function is annotated with.").set_default(-1); } }; TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); - RELAY_REGISTER_OP("annotation.with_funcid") -.describe(R"code(Annotate a function with a funcid.)code" -TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("func", "Function", "The input data."); + .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("func", "Function", "The input data."); // Cache with_funcid op to reduce lookup overhead during traversal. static const Op& with_funcid_op = Op::Get("annotation.with_funcid"); @@ -624,7 +588,7 @@ Function AsFunc(const Expr& e) { class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const IRModule& mod) : mod_(mod) { } + PartialEvaluator(const IRModule& mod) : mod_(mod) {} PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic ret = ExprFunctor::VisitExpr(e, ll); @@ -639,9 +603,8 @@ class PartialEvaluator : public ExprFunctor return VisitExpr(c->args[0], ll, name); } } - PStatic ret = e.as() ? - VisitFunc(Downcast(e), ll, name) : - VisitExpr(e, ll); + PStatic ret = + e.as() ? VisitFunc(Downcast(e), ll, name) : VisitExpr(e, ll); CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; return ret; } @@ -670,9 +633,7 @@ class PartialEvaluator : public ExprFunctor } } - PStatic VisitExpr_(const VarNode* op, LetList* ll) final { - return env_.Lookup(GetRef(op)); - } + PStatic VisitExpr_(const VarNode* op, LetList* ll) final { return env_.Lookup(GetRef(op)); } PStatic VisitGlobalVar(const GlobalVar& gv) { CHECK(mod_.defined()); @@ -714,15 +675,11 @@ class PartialEvaluator : public ExprFunctor } } else { Expr t = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - return VisitExpr(op->true_branch, ll)->dynamic; - }); - }); + return LetList::With([&](LetList* ll) { return VisitExpr(op->true_branch, ll)->dynamic; }); + }); Expr f = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - return VisitExpr(op->false_branch, ll)->dynamic; - }); - }); + return LetList::With([&](LetList* ll) { return VisitExpr(op->false_branch, ll)->dynamic; }); + }); store_.Invalidate(); return NoStatic(ll->Push(If(c->dynamic, t, f))); } @@ -782,16 +739,12 @@ class PartialEvaluator : public ExprFunctor PartialEvaluator* pe_; FuncId fid_; Fuel old_fuel; - FuelFrame(PartialEvaluator* pe, - FuncId fid, - const Fuel& new_fuel) : pe_(pe), fid_(fid) { + FuelFrame(PartialEvaluator* pe, FuncId fid, const Fuel& new_fuel) : pe_(pe), fid_(fid) { CHECK_GT(pe_->fuel_map_.count(fid_), 0); old_fuel = pe_->fuel_map_[fid_]; pe_->fuel_map_[fid_] = new_fuel; } - ~FuelFrame() { - pe_->fuel_map_[fid_] = old_fuel; - } + ~FuelFrame() { pe_->fuel_map_[fid_] = old_fuel; } }; size_t GetFTValue(const PStatic& ps) { @@ -829,82 +782,76 @@ class PartialEvaluator : public ExprFunctor free_vars.push_back(std::pair(v, env_.Lookup(v))); } } - return [=](const PStatic& self, - const std::vector& pv, - const Attrs& attrs, - const tvm::Array& type_args, - LetList* ll) { + return [=](const PStatic& self, const std::vector& pv, const Attrs& attrs, + const tvm::Array& type_args, LetList* ll) { return env_.Extend([&]() { - CHECK_EQ(pv.size(), func->params.size()); - CHECK_GT(func_map_.count(func), 0); - FuncId fid = func_map_.at(func); - if (fuel_map_.count(fid) == 0) { - fuel_map_.insert({fid, MkFTop()}); + CHECK_EQ(pv.size(), func->params.size()); + CHECK_GT(func_map_.count(func), 0); + FuncId fid = func_map_.at(func); + if (fuel_map_.count(fid) == 0) { + fuel_map_.insert({fid, MkFTop()}); + } + std::vector args_fuel; + for (const auto& v : pv) { + args_fuel.push_back(GetFuel(v)); + } + auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel)); + if (std::get<1>(meet_res)) { + FuelFrame tf(this, fid, std::get<0>(meet_res)); + Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func))); + Function func = AsFunc(dedup_func); + if (var.as()) { + env_.Insert(Downcast(var), self); } - std::vector args_fuel; - for (const auto& v : pv) { - args_fuel.push_back(GetFuel(v)); + for (size_t i = 0; i < pv.size(); ++i) { + env_.Insert(func->params[i], pv[i]); + } + for (const auto& p : free_vars) { + env_.Insert(p.first, p.second); + } + tvm::Map subst; + for (size_t i = 0; i < type_args.size(); ++i) { + subst.Set(func->type_params[i], type_args[i]); } - auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel)); - if (std::get<1>(meet_res)) { - FuelFrame tf(this, fid, std::get<0>(meet_res)); - Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func))); - Function func = AsFunc(dedup_func); - if (var.as()) { - env_.Insert(Downcast(var), self); - } - for (size_t i = 0; i < pv.size(); ++i) { - env_.Insert(func->params[i], pv[i]); - } - for (const auto& p : free_vars) { - env_.Insert(p.first, p.second); - } - tvm::Map subst; - for (size_t i = 0; i < type_args.size(); ++i) { - subst.Set(func->type_params[i], type_args[i]); - } - for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { - subst.Set(func->type_params[i], IncompleteType(kType)); - } - return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll); - } else { - std::vector dyn; - for (const auto& v : pv) { - dyn.push_back(v->dynamic); - } - return NoStatic(ll->Push(Call(var, dyn, attrs, type_args))); + for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { + subst.Set(func->type_params[i], IncompleteType(kType)); } - }); + return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll); + } else { + std::vector dyn; + for (const auto& v : pv) { + dyn.push_back(v->dynamic); + } + return NoStatic(ll->Push(Call(var, dyn, attrs, type_args))); + } + }); }; } Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) { return store_.Extend([&]() { store_.Invalidate(); - return Function(func->params, - LetList::With([&](LetList* ll) { - std::vector pv; - for (const auto& v : func->params) { - pv.push_back(NoStatic(v)); - } - tvm::Array type_args; - for (const auto& tp : func->type_params) { - type_args.push_back(tp); - } - return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; - }), func->ret_type, func->type_params, func->attrs); + return Function(func->params, LetList::With([&](LetList* ll) { + std::vector pv; + for (const auto& v : func->params) { + pv.push_back(NoStatic(v)); + } + tvm::Array type_args; + for (const auto& tp : func->type_params) { + type_args.push_back(tp); + } + return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; + }), + func->ret_type, func->type_params, func->attrs); }); } - PStatic VisitFunc(const Function& func, - LetList* ll, - const Var& name = Var("x", Type())) { + PStatic VisitFunc(const Function& func, LetList* ll, const Var& name = Var("x", Type())) { Func f = VisitFuncStatic(func, name); Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func)))); // TODO(@M.K.): we seems to reduce landin knot into letrec. // restore letrec support across whole relay. - return HasStatic(MkSFunc(f), - ll->Push(name, VisitFuncDynamic(u_func, f, name))); + return HasStatic(MkSFunc(f), ll->Push(name, VisitFuncDynamic(u_func, f, name))); } PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { @@ -912,7 +859,7 @@ class PartialEvaluator : public ExprFunctor } struct ReflectError : dmlc::Error { - ReflectError() : dmlc::Error("static value not found") { } + ReflectError() : dmlc::Error("static value not found") {} }; Expr Reflect(const PStatic& st) { @@ -954,31 +901,24 @@ class PartialEvaluator : public ExprFunctor // Constant evaluate a expression. PStatic ConstEvaluate(const Expr& expr, LetList* ll) { - std::vector passes = {transform::FuseOps(0), - transform::InferType()}; + std::vector passes = {transform::FuseOps(0), transform::InferType()}; auto mod = IRModule::FromExpr(expr); auto seq = transform::Sequential(passes); mod = seq(mod); auto entry_func = Downcast(mod->Lookup("main")); - auto fused_infered = - expr.as() == nullptr ? entry_func->body : entry_func; + auto fused_infered = expr.as() == nullptr ? entry_func->body : entry_func; return Reify(executor_(fused_infered), ll); } Func ConstEvaluateFunc(const Expr& expr) { CHECK_EQ(FreeVars(expr).size(), 0); - return [=](const PStatic& self, - const std::vector& pv, - const Attrs& attrs, - const tvm::Array& type_args, - LetList* ll) { + return [=](const PStatic& self, const std::vector& pv, const Attrs& attrs, + const tvm::Array& type_args, LetList* ll) { tvm::Array ns_args; for (const PStatic& ps : pv) { ns_args.push_back(ps->dynamic); } - auto ns = [&]() { - return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); - }; + auto ns = [&]() { return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); }; if (StatefulOp(expr)) { return ns(); } @@ -988,8 +928,7 @@ class PartialEvaluator : public ExprFunctor args.push_back(Reflect(ps)); } return ConstEvaluate(Call(expr, args, attrs, type_args), ll); - } - catch (const ReflectError&) { + } catch (const ReflectError&) { return ns(); } }; @@ -1001,11 +940,8 @@ class PartialEvaluator : public ExprFunctor PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final { Constructor c = GetRef(op); - Func f = [=](const PStatic& self, - const std::vector& pv, - const Attrs& attrs, - const tvm::Array& type_args, - LetList* ll) { + Func f = [=](const PStatic& self, const std::vector& pv, const Attrs& attrs, + const tvm::Array& type_args, LetList* ll) { tvm::Array dyn; for (const PStatic& ps : pv) { dyn.push_back(ps->dynamic); @@ -1020,30 +956,30 @@ class PartialEvaluator : public ExprFunctor return env_.Extend([&]() { for (const Clause& c : op->clauses) { switch (VisitPattern(c->lhs, ps)) { - case MatchStatus::Match: - return VisitExpr(c->rhs, ll); - case MatchStatus::NoMatch: - continue; - case MatchStatus::Unknown: - return [&]() { - tvm::Array clauses; - for (const Clause& c : op->clauses) { - Expr expr = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - for (const Var& v : BoundVars(c->lhs)) { - env_.Insert(v, NoStatic(v)); - } - return VisitExpr(c->rhs, ll)->dynamic; + case MatchStatus::Match: + return VisitExpr(c->rhs, ll); + case MatchStatus::NoMatch: + continue; + case MatchStatus::Unknown: + return [&]() { + tvm::Array clauses; + for (const Clause& c : op->clauses) { + Expr expr = store_.Extend([&]() { + return LetList::With([&](LetList* ll) { + for (const Var& v : BoundVars(c->lhs)) { + env_.Insert(v, NoStatic(v)); + } + return VisitExpr(c->rhs, ll)->dynamic; + }); }); - }); - clauses.push_back(Clause(c->lhs, expr)); - } - store_.Invalidate(); - return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete))); - }(); - default: - LOG(FATAL) << "Unknown MatchStatus"; - throw; + clauses.push_back(Clause(c->lhs, expr)); + } + store_.Invalidate(); + return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete))); + }(); + default: + LOG(FATAL) << "Unknown MatchStatus"; + throw; } } LOG(FATAL) << "No case Match"; @@ -1071,12 +1007,12 @@ class PartialEvaluator : public ExprFunctor for (size_t i = 0; i < op->patterns.size(); ++i) { MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]); switch (ms) { - case MatchStatus::Match: - continue; - case MatchStatus::NoMatch: - return MatchStatus::NoMatch; - case MatchStatus::Unknown: - current_match_status = MatchStatus::Unknown; + case MatchStatus::Match: + continue; + case MatchStatus::NoMatch: + return MatchStatus::NoMatch; + case MatchStatus::Unknown: + current_match_status = MatchStatus::Unknown; } } return current_match_status; @@ -1095,12 +1031,12 @@ class PartialEvaluator : public ExprFunctor for (size_t i = 0; i < op->patterns.size(); ++i) { MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]); switch (ms) { - case MatchStatus::Match: - continue; - case MatchStatus::NoMatch: - return MatchStatus::NoMatch; - case MatchStatus::Unknown: - current_match_status = MatchStatus::Unknown; + case MatchStatus::Match: + continue; + case MatchStatus::NoMatch: + return MatchStatus::NoMatch; + case MatchStatus::Unknown: + current_match_status = MatchStatus::Unknown; } } return current_match_status; @@ -1112,7 +1048,7 @@ class PartialEvaluator : public ExprFunctor void InitializeFuncId(const Expr& e) { struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor { PartialEvaluator* pe; - explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { } + explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {} void VisitExpr_(const FunctionNode* op) final { Function f = GetRef(op); @@ -1121,9 +1057,7 @@ class PartialEvaluator : public ExprFunctor VisitExpr(f->body); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } }; InitializeFuncIdVisitor(this).VisitExpr(e); } @@ -1131,7 +1065,7 @@ class PartialEvaluator : public ExprFunctor Expr RegisterFuncId(const Expr& e) { struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor { PartialEvaluator* pe; - explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { } + explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {} void VisitExpr_(const CallNode* op) final { if (op->op == with_funcid_op) { @@ -1154,9 +1088,7 @@ class PartialEvaluator : public ExprFunctor ExprVisitor::VisitExpr_(op); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } }; RegisterFuncIdVisitor(this).VisitExpr(e); return e; @@ -1165,7 +1097,7 @@ class PartialEvaluator : public ExprFunctor Expr AnnotateFuncId(const Expr& e) { struct AnnotateFuncIdMutator : ExprMutator, PatternMutator { PartialEvaluator* pe; - explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { } + explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) {} Expr VisitExpr_(const FunctionNode* op) final { Function f = GetRef(op); @@ -1173,13 +1105,9 @@ class PartialEvaluator : public ExprFunctor return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f)); } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } - Var VisitVar(const Var& v) final { - return v; - } + Var VisitVar(const Var& v) final { return v; } }; return AnnotateFuncIdMutator(this).VisitExpr(e); } @@ -1199,7 +1127,8 @@ class PartialEvaluator : public ExprFunctor * If no progress is made, we do not inline. * In both case, we remap the mapping to the new Fuel * when we PE inside the Function body. - * Termination is guaranteed because Fuel is finitely descending - there can only be so many meet. + * Termination is guaranteed because Fuel is finitely descending - there can only be so many + * meet. */ std::unordered_map func_map_; std::unordered_map fuel_map_; @@ -1219,9 +1148,7 @@ Expr Remap(const Expr& e) { return remap_.at(v); } - Var VisitVar(const Var& v) final { - return Downcast(VisitExpr(v)); - } + Var VisitVar(const Var& v) final { return Downcast(VisitExpr(v)); } private: std::unordered_map remap_; @@ -1240,20 +1167,14 @@ Expr StripWithFuncId(const Expr& e) { } } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } - Var VisitVar(const Var& v) final { - return v; - } + Var VisitVar(const Var& v) final { return v; } }; return StripWithFuncIdMutator().VisitExpr(e); } -Expr PostProcess(const Expr& e) { - return StripWithFuncId(DeDup(Remap(e))); -} +Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); } } // namespace partial_eval @@ -1273,14 +1194,11 @@ namespace transform { Pass PartialEval() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::PartialEval(m); - }; + [=](IRModule m, PassContext pc) { return relay::PartialEval(m); }; return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); } -TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate") -.set_body_typed(PartialEval); +TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval); } // namespace transform diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index 56b0645..32ee09f 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -25,9 +25,10 @@ #ifndef TVM_RELAY_TRANSFORMS_PASS_UTIL_H_ #define TVM_RELAY_TRANSFORMS_PASS_UTIL_H_ -#include -#include #include +#include +#include + #include #include @@ -100,41 +101,37 @@ inline bool IsAtomic(const Expr& e) { return e.as() || e.as() || e.as() || e.as(); } -template +template struct TreeNode { typedef std::shared_ptr> pointer; virtual ~TreeNode() {} }; -template +template struct TreeLeafNode : TreeNode { using TreeObjectPtr = typename TreeNode::pointer; Expr body; - explicit TreeLeafNode(Expr body): body(body) {} + explicit TreeLeafNode(Expr body) : body(body) {} - static TreeObjectPtr Make(Expr body) { - return std::make_shared(body); - } + static TreeObjectPtr Make(Expr body) { return std::make_shared(body); } ~TreeLeafNode() {} }; -template +template struct TreeLeafFatalNode : TreeNode { using TreeObjectPtr = typename TreeNode::pointer; TreeLeafFatalNode() = default; - static TreeObjectPtr Make() { - return std::make_shared(); - } + static TreeObjectPtr Make() { return std::make_shared(); } ~TreeLeafFatalNode() {} }; -template +template struct TreeBranchNode : TreeNode { using TreeObjectPtr = typename TreeNode::pointer; @@ -142,15 +139,11 @@ struct TreeBranchNode : TreeNode { TreeObjectPtr then_branch; TreeObjectPtr else_branch; - TreeBranchNode(ConditionObjectPtr cond, - TreeObjectPtr then_branch, - TreeObjectPtr else_branch) - : cond(cond), then_branch(then_branch), else_branch(else_branch) {} - + TreeBranchNode(ConditionObjectPtr cond, TreeObjectPtr then_branch, TreeObjectPtr else_branch) + : cond(cond), then_branch(then_branch), else_branch(else_branch) {} - static TreeObjectPtr Make(ConditionObjectPtr cond, - TreeObjectPtr then_branch, - TreeObjectPtr else_branch) { + static TreeObjectPtr Make(ConditionObjectPtr cond, TreeObjectPtr then_branch, + TreeObjectPtr else_branch) { return std::make_shared(cond, then_branch, else_branch); } diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index cd2af9f..edb6a65 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -28,19 +28,18 @@ #include #include -#include -#include -#include #include #include -#include #include +#include +#include +#include #include +#include #include -#include #include - +#include namespace tvm { namespace relay { @@ -49,42 +48,42 @@ namespace relay { * \brief Dispatch DataType to the C++ data type * during runtime. */ -#define TVM_DTYPE_DISPATCH(type, DType, ...) \ - if (type == DataType::Float(64)) { \ - typedef double DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(32)) { \ - typedef float DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(64)) { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(32)) { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(16)) { \ - typedef int16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(8)) { \ - typedef int8_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(64)) { \ - typedef uint64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(32)) { \ - typedef uint32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(8)) { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } else { \ - LOG(FATAL) << "unknown data type " << type; \ +#define TVM_DTYPE_DISPATCH(type, DType, ...) \ + if (type == DataType::Float(64)) { \ + typedef double DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(32)) { \ + typedef float DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(64)) { \ + typedef int64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(32)) { \ + typedef int32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(16)) { \ + typedef int16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(8)) { \ + typedef int8_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(64)) { \ + typedef uint64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(32)) { \ + typedef uint32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(8)) { \ + typedef uint8_t DType; \ + { __VA_ARGS__ } \ + } else { \ + LOG(FATAL) << "unknown data type " << type; \ } /*! @@ -99,10 +98,8 @@ namespace relay { * \param rhs_value A squeezed version of rhs which only contains matched dimension. * \return Whether match is successful. */ -inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, - const TensorTypeNode* trhs, - const Array& lhs_axes, - Expr* rhs_value = nullptr) { +inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, const TensorTypeNode* trhs, + const Array& lhs_axes, Expr* rhs_value = nullptr) { if (tlhs->shape.size() < trhs->shape.size()) return false; StructuralEqual equal; size_t base = tlhs->shape.size() - trhs->shape.size(); @@ -145,9 +142,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, * \param target_ndim Target dimension. * \param axes The axis on the output we want to match on. */ -inline Expr ExpandBiasToMatchAxis(Expr bias, - int target_ndim, - const Array& axes) { +inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array& axes) { static const Op& expand_dims = Op::Get("expand_dims"); for (size_t i = axes.size(); i != 0; --i) { if (i == axes.size()) { @@ -179,14 +174,12 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, * \param param The conv2d attributes. * \return Whether it is depthwise_conv2d. */ -inline bool IsDepthwiseConv2D(const Call& call, - const Conv2DAttrs* param, +inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param, const Layout& kernel_layout) { static const Layout kOIHW("OIHW"); const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW); auto wshape = bilayout.ForwardShape(call->args[1]->type_as()->shape); - return tir::is_const_int(wshape[0], param->groups) && - tir::is_const_int(wshape[1], 1); + return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1); } /*! @@ -195,12 +188,12 @@ inline bool IsDepthwiseConv2D(const Call& call, * \return Super-dimension size of output channels of conv2d. */ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { - auto param = call->attrs.as(); - auto tweight = call->args[1]->type_as(); - auto index = param->kernel_layout.find('O'); - CHECK_NE(index, std::string::npos); - auto channels = tir::as_const_int(tweight->shape[index]); - return *channels; + auto param = call->attrs.as(); + auto tweight = call->args[1]->type_as(); + auto index = param->kernel_layout.find('O'); + CHECK_NE(index, std::string::npos); + auto channels = tir::as_const_int(tweight->shape[index]); + return *channels; } /*! @@ -304,13 +297,9 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { return tvm::StructuralEqual()(a, b); } -inline Expr GetField(Expr t, size_t i) { - return TupleGetItem(t, i); -} +inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); } -inline Expr Pair(Expr l, Expr r) { - return Tuple({l, r}); -} +inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); } inline Expr Exp(Expr e) { static const Op& op = Op::Get("exp"); @@ -362,25 +351,21 @@ inline Expr Negative(Expr x) { return Call(op, {x}, Attrs(), {}); } - inline Expr Sqrt(Expr x) { static const Op& op = Op::Get("sqrt"); return Call(op, {x}, Attrs(), {}); } - inline Expr Relu(Expr x) { static const Op& op = Op::Get("nn.relu"); return Call(op, {x}, Attrs(), {}); } - inline Expr Round(Expr x) { static const Op& op = Op::Get("round"); return Call(op, {x}, Attrs(), {}); } - inline Expr Clip(Expr x, double a_min, double a_max) { static const Op& op = Op::Get("clip"); auto attrs = make_object(); @@ -389,25 +374,21 @@ inline Expr Clip(Expr x, double a_min, double a_max) { return Call(op, {x}, Attrs(attrs), {}); } - inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Subtract(Expr lhs, Expr rhs) { static const Op& op = Op::Get("subtract"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Multiply(Expr lhs, Expr rhs) { static const Op& op = Op::Get("multiply"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Divide(Expr lhs, Expr rhs) { static const Op& op = Op::Get("divide"); return Call(op, {lhs, rhs}, Attrs(), {}); @@ -446,31 +427,26 @@ inline Expr Power(Expr lhs, Expr rhs) { return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr RightShift(Expr x, Expr nbit) { static const Op& op = Op::Get("right_shift"); return Call(op, {x, nbit}, Attrs(), {}); } - inline Expr LeftShift(Expr x, Expr nbit) { static const Op& op = Op::Get("left_shift"); return Call(op, {x, nbit}, Attrs(), {}); } - inline Expr ReshapeLike(Expr lhs, Expr rhs) { static const Op& op = Op::Get("reshape_like"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Copy(Expr data) { static const Op& op = Op::Get("copy"); return Call(op, {data}, Attrs(), {}); } - inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { auto attrs = make_object(); attrs->axis = std::move(axis); @@ -489,7 +465,6 @@ inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, b return Call(op, {data, mean}, Attrs(attrs), {}); } - static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { static const Op& op = Op::Get("where"); return Call(op, {condition, x, y}); @@ -500,9 +475,7 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { return Call(op, {lhs, rhs}, Attrs(), {}); } -static inline Expr Full(Expr fill_value, - Array shape, - DataType dtype) { +static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); @@ -529,10 +502,7 @@ static inline Expr Conv2D(Expr data, Expr weight, Array strides, return Call(op, {data, weight}, Attrs(attrs), {}); } -static inline Expr Dense(Expr data, - Expr weight, - IndexExpr units, - DataType out_dtype) { +static inline Expr Dense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; diff --git a/src/relay/transforms/simplify_fc_transpose.cc b/src/relay/transforms/simplify_fc_transpose.cc index 6cd77f4..99ded0b 100644 --- a/src/relay/transforms/simplify_fc_transpose.cc +++ b/src/relay/transforms/simplify_fc_transpose.cc @@ -128,20 +128,12 @@ Pass SimplifyFCTranspose(const Array& target_weights) { // Remove FreeVar warning auto f0 = Downcast(SimplifyFCTranspose(f, target_weights)); Array wt_params = FreeVars(f0); - auto f1 = Function(wt_params, - f0->body, - f0->ret_type, - f0->type_params, - f0->attrs); + auto f1 = Function(wt_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); for (const auto& var : wt_params) { params.push_back(var); } - return Function(params, - f1->body, - f1->ret_type, - f1->type_params, - f1->attrs); + return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); }; return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index a9ceec2..7c33947 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -21,22 +21,18 @@ * \file simplify_inference.cc */ #include -#include #include -#include +#include #include +#include + #include "pattern_util.h" namespace tvm { namespace relay { -Expr BatchNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Expr moving_mean, - Expr moving_var, - Type tdata) { +Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean, + Expr moving_var, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -64,12 +60,7 @@ Expr BatchNormToInferUnpack(const Attrs attrs, return out; } - -Expr GroupNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Type tdata) { +Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -88,20 +79,20 @@ Expr GroupNormToInferUnpack(const Attrs attrs, // new shape = N, num_groups, C/num_groups, H, W // reduce_axes = axis of (C/num_groups, H, W) for (int i = 0; i < ndim; ++i) { - auto val = ttype->shape[i].as()->value; - - // Save the old shape to reshape later - old_shape.push_back(val); - if (i == axis) { - new_shape.push_back(num_groups); - new_shape.push_back(channel / num_groups); - reduced_axes.push_back(i + 1); - continue; - } - if (i >= axis) { - reduced_axes.push_back(i + 1); - } - new_shape.push_back(val); + auto val = ttype->shape[i].as()->value; + + // Save the old shape to reshape later + old_shape.push_back(val); + if (i == axis) { + new_shape.push_back(num_groups); + new_shape.push_back(channel / num_groups); + reduced_axes.push_back(i + 1); + continue; + } + if (i >= axis) { + reduced_axes.push_back(i + 1); + } + new_shape.push_back(val); } data = Reshape(data, new_shape); @@ -124,11 +115,7 @@ Expr GroupNormToInferUnpack(const Attrs attrs, return out; } -Expr LayerNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Type tdata) { +Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -151,11 +138,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs, return out; } -Expr InstanceNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Type tdata) { +Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -165,8 +148,7 @@ Expr InstanceNormToInferUnpack(const Attrs attrs, int axis = (param->axis < 0) ? param->axis + ndim : param->axis; Array reduced_axes; for (int i = 1; i < ndim; ++i) { - if (i != axis) - reduced_axes.push_back(i); + if (i != axis) reduced_axes.push_back(i); } Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast(param->epsilon)); @@ -259,22 +241,19 @@ class InferenceSimplifier : public ExprMutator { std::unordered_map ty_map_; }; -Expr SimplifyInference(const Expr& e) { - return InferenceSimplifier().Mutate(e); -} +Expr SimplifyInference(const Expr& e) { return InferenceSimplifier().Mutate(e); } namespace transform { Pass SimplifyInference() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(SimplifyInference(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(SimplifyInference(f)); + }; return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference") -.set_body_typed(SimplifyInference); +TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference").set_body_typed(SimplifyInference); } // namespace transform diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 21c5162..c0c9286 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -26,12 +26,12 @@ #include #include #include -#include #include -#include "let_list.h" -#include "pass_util.h" + #include "../../support/arena.h" #include "../analysis/dependency_graph.h" +#include "let_list.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -47,13 +47,11 @@ struct ScopeNode { size_t level; Scope parent; std::shared_ptr ll = std::make_shared(); - explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) { } - ScopeNode() : level(0) { } + explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {} + ScopeNode() : level(0) {} }; -Scope ChildScope(const Scope& s) { - return std::make_shared(s); -} +Scope ChildScope(const Scope& s) { return std::make_shared(s); } Scope LCA(Scope lhs, Scope rhs) { while (lhs != rhs) { @@ -100,8 +98,7 @@ std::unordered_map CalcScope(const DependencyGrap */ class Fill : ExprFunctor { public: - static Expr ToANormalForm(const Expr& e, - const DependencyGraph& dg, + static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, std::unordered_map* node_scope) { Fill fi(dg, node_scope); return fi.GetScope(e)->ll->Get(fi.VisitExpr(e)); @@ -112,14 +109,10 @@ class Fill : ExprFunctor { std::unordered_map* node_scope_; std::unordered_map memo; - Fill(const DependencyGraph& dg, - std::unordered_map* node_scope) : - dg_(dg), - node_scope_(node_scope) { } + Fill(const DependencyGraph& dg, std::unordered_map* node_scope) + : dg_(dg), node_scope_(node_scope) {} - Scope GetScope(const Expr& e) { - return node_scope_->at(dg_.expr_node.at(e)); - } + Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); } Scope GetSubScope(const Expr& e, size_t i) { DependencyGraph::Node* n = dg_.expr_node.at(e); @@ -144,18 +137,12 @@ class Fill : ExprFunctor { return ret; } - Expr VisitExpr(const Expr& e) { - return this->VisitExpr(e, Var()); - } + Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } - Expr Atomic(const Expr& e, const Var& v) { - return v.defined() ? GetScope(e)->ll->Push(v, e) : e; - } + Expr Atomic(const Expr& e, const Var& v) { return v.defined() ? GetScope(e)->ll->Push(v, e) : e; } Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - Var var = v.defined() ? - v : - Var(std::string("x"), Type()); + Var var = v.defined() ? v : Var(std::string("x"), Type()); return GetScope(orig)->ll->Push(var, now); } @@ -199,9 +186,8 @@ class Fill : ExprFunctor { Expr VisitExpr_(const IfNode* i, const Var& v) final { Expr e = GetRef(i); - Expr ret = If(VisitExpr(i->cond), - GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)), - GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch))); + Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)), + GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch))); return Compound(e, ret, v); } @@ -211,11 +197,8 @@ class Fill : ExprFunctor { if (f->HasNonzeroAttr(attr::kPrimitive)) { ret = e; } else { - ret = Function(f->params, - GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), - f->ret_type, - f->type_params, - f->attrs); + ret = Function(f->params, GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), f->ret_type, + f->type_params, f->attrs); } return Compound(e, ret, v); } @@ -257,9 +240,8 @@ class Fill : ExprFunctor { Expr data = VisitExpr(m->data); std::vector clauses; for (const Clause& c : m->clauses) { - clauses.push_back(Clause( - c->lhs, - GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); + clauses.push_back( + Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); } return Compound(e, Match(data, clauses, m->complete), v); } @@ -301,14 +283,9 @@ IRModule ToANormalForm(const IRModule& m) { if (const auto* n = it.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; } - Expr ret = - TransformF([&](const Expr& e) { - return ToANormalFormAux(e); - }, it.second); + Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second); CHECK_EQ(FreeVars(ret).size(), 0) - << AsText(ret) - << "should not has free vars: " - << FreeVars(ret); + << AsText(ret) << "should not has free vars: " << FreeVars(ret); updates.Set(it.first, Downcast(ret)); } @@ -325,14 +302,11 @@ namespace transform { Pass ToANormalForm() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::ToANormalForm(m); - }; + [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); }; return CreateModulePass(pass_func, 1, "ToANormalForm", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm") -.set_body_typed(ToANormalForm); +TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed(ToANormalForm); } // namespace transform diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index e6c8392..81545b6 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -51,9 +51,10 @@ * wheter directly invoking it, or indirectly by recursion. */ #include -#include #include #include +#include + #include "let_list.h" #include "pass_util.h" @@ -62,9 +63,7 @@ namespace relay { // we assume the data type has no closure - no idea how to look into datatype right now. -Type Arrow(const Type& l, const Type& r) { - return FuncType({l}, r, {}, {}); -} +Type Arrow(const Type& l, const Type& r) { return FuncType({l}, r, {}, {}); } Type CPSType(const Type& t, const TypeVar& answer); @@ -79,7 +78,7 @@ FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) { Type CPSType(const Type& t, const TypeVar& answer) { struct CPSTypeMutator : TypeMutator { - explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) { } + explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) {} TypeVar answer; Type VisitType_(const FuncTypeNode* t) final { return CPSFuncType(GetRef(t), answer); @@ -113,22 +112,15 @@ using MCont = std::function; Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm); -Function ToCPS(const Function& f, - const IRModule& m, - CPSMap* cm, - VarMap* vm, +Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, const TypeVar& answer) { - std::function remap = [&](const Var& v) { - return vm->count(v) == 0 ? v : vm->at(v); - }; + std::function remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); }; auto function_type = Downcast(f->checked_type()); // Each MCont can be used at most once. struct CPSFunctor : ExprFunctor, PatternMutator { - CPSFunctor(const std::function& remap, - const TypeVar& answer, - const IRModule& m, - VarMap* vm, - CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { } + CPSFunctor(const std::function& remap, const TypeVar& answer, const IRModule& m, + VarMap* vm, CPSMap* cm) + : remap(remap), answer(answer), m(m), vm(vm), cm(cm) {} const std::function& remap; TypeVar answer; IRModule m; @@ -136,9 +128,8 @@ Function ToCPS(const Function& f, CPSMap* cm; Expr VisitExpr_(const LetNode* op, const MCont& k) final { - return VisitExpr(op->value, [&](const Expr& v) { - return Let(remap(op->var), v, VisitExpr(op->body, k)); - }); + return VisitExpr( + op->value, [&](const Expr& v) { return Let(remap(op->var), v, VisitExpr(op->body, k)); }); } Expr VisitExpr_(const FunctionNode* op, const MCont& k) final { @@ -150,13 +141,9 @@ Function ToCPS(const Function& f, return k(GetRef(op)); } - Expr VisitExpr_(const VarNode* op, const MCont& k) final { - return k(remap(GetRef(op))); - } + Expr VisitExpr_(const VarNode* op, const MCont& k) final { return k(remap(GetRef(op))); } - Pattern VisitPattern_(const PatternVarNode* op) final { - return PatternVar(remap(op->var)); - } + Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(remap(op->var)); } Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { auto gv = GetRef(op); @@ -186,16 +173,14 @@ Function ToCPS(const Function& f, } Expr reify(const MCont& k, const std::function& cont) { - return LetList::LetBind(reify(k), - [&](const Var& f) { + return LetList::LetBind(reify(k), [&](const Var& f) { return cont([&](const Expr& e) { return Call(f, {e}); }); }); } Expr VisitExpr_(const IfNode* op, const MCont& k) final { return reify(k, [&](const MCont& kf) { - return VisitExpr(op->cond, - [&](const Expr& v) { + return VisitExpr(op->cond, [&](const Expr& v) { return If(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf)); }); }); @@ -214,19 +199,13 @@ Function ToCPS(const Function& f, } Expr VisitExpr_(const RefReadNode* op, const MCont& k) final { - return VisitExpr(op->ref, - [&](const Expr& r) { - return LetList::LetBind(RefRead(r), k); - }); + return VisitExpr(op->ref, [&](const Expr& r) { return LetList::LetBind(RefRead(r), k); }); } Expr VisitExpr_(const RefWriteNode* op, const MCont& k) final { - return VisitExpr(op->ref, - [&](const Expr& r) { + return VisitExpr(op->ref, [&](const Expr& r) { return VisitExpr(op->value, - [&](const Expr& v) { - return LetList::LetBind(RefWrite(r, v), k); - }); + [&](const Expr& v) { return LetList::LetBind(RefWrite(r, v), k); }); }); } @@ -234,20 +213,18 @@ Function ToCPS(const Function& f, tvm::Array fields; std::function next; next = [&]() { - return (fields.size() == op->fields.size()) ? - k(Tuple(fields)) : - VisitExpr(op->fields[fields.size()], [&](const Expr& v) { - fields.push_back(v); - return next(); - }); + return (fields.size() == op->fields.size()) + ? k(Tuple(fields)) + : VisitExpr(op->fields[fields.size()], [&](const Expr& v) { + fields.push_back(v); + return next(); + }); }; return next(); } Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final { - return VisitExpr(op->tuple, [&](const Expr& v) { - return k(TupleGetItem(v, op->index)); - }); + return VisitExpr(op->tuple, [&](const Expr& v) { return k(TupleGetItem(v, op->index)); }); } Expr VisitExpr_(const CallNode* op, const MCont& k) final { @@ -259,9 +236,9 @@ Function ToCPS(const Function& f, return LetList::LetBind(Call(op->op, args, op->attrs, op->type_args), k); } else { return VisitExpr(op->args[args.size()], [&](const Expr& v) { - args.push_back(v); - return next(); - }); + args.push_back(v); + return next(); + }); } }; return next(); @@ -279,7 +256,7 @@ Function ToCPS(const Function& f, return next(); }); } - }; + }; return VisitExpr(op->op, [&](const Expr& v) { f = v; return next(); @@ -293,19 +270,15 @@ Function ToCPS(const Function& f, new_params.push_back(remap(v)); } new_params.push_back(k); - return Function(new_params, - mut.VisitExpr(f->body, - [&](const Expr& e) { return Call(k, {e}); }), - answer, - f->type_params, - f->attrs); + return Function(new_params, mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), + answer, f->type_params, f->attrs); } Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { TypeVar answer = TypeVar("answer", kType); VarMap var; struct Remapper : ExprVisitor, PatternVisitor { - Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { } + Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) {} TypeVar answer; VarMap* vm; void VisitExpr_(const VarNode* vn) final { @@ -316,13 +289,9 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { } } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } - void VisitPattern_(const PatternVarNode* op) final { - VisitExpr(op->var); - } + void VisitPattern_(const PatternVarNode* op) final { VisitExpr(op->var); } } remap(answer, &var); remap.VisitExpr(f); Function ret = ToCPS(f, m, cm, &var, answer); @@ -366,43 +335,32 @@ Function UnCPS(const Function& f) { type_args.push_back(tp); } type_args.push_back(new_ret_type); - return Function(new_params, - Call(f, args, {}, type_args), - new_ret_type, - new_type_params, - f->attrs); + return Function(new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params, + f->attrs); } TVM_REGISTER_GLOBAL("relay._transform.to_cps") -.set_body_typed(static_cast(ToCPS)); + .set_body_typed(static_cast(ToCPS)); -TVM_REGISTER_GLOBAL("relay._transform.un_cps") -.set_body_typed(UnCPS); +TVM_REGISTER_GLOBAL("relay._transform.un_cps").set_body_typed(UnCPS); namespace transform { Pass ToCPS() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Function(ToCPS(f, m)); - }; + [=](Function f, IRModule m, PassContext pc) { return Function(ToCPS(f, m)); }; return CreateFunctionPass(pass_func, 1, "ToCPS", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ToCPS") -.set_body_typed(ToCPS); - +TVM_REGISTER_GLOBAL("relay._transform.ToCPS").set_body_typed(ToCPS); Pass UnCPS() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Function(UnCPS(f)); - }; + [=](Function f, IRModule m, PassContext pc) { return Function(UnCPS(f)); }; return CreateFunctionPass(pass_func, 1, "UnCPS", {}); } -TVM_REGISTER_GLOBAL("relay._transform.UnCPS") -.set_body_typed(UnCPS); +TVM_REGISTER_GLOBAL("relay._transform.UnCPS").set_body_typed(UnCPS); } // namespace transform diff --git a/src/relay/transforms/to_graph_normal_form.cc b/src/relay/transforms/to_graph_normal_form.cc index 8bf41a4..2e6c545 100644 --- a/src/relay/transforms/to_graph_normal_form.cc +++ b/src/relay/transforms/to_graph_normal_form.cc @@ -26,6 +26,7 @@ #include #include #include + #include "let_list.h" namespace tvm { @@ -33,7 +34,7 @@ namespace relay { class UseVarVisitor : public ExprVisitor { public: - explicit UseVarVisitor(const Var& v) : v(v) { } + explicit UseVarVisitor(const Var& v) : v(v) {} static bool UseVar(const Var& v, const Expr& e) { UseVarVisitor uv(v); @@ -45,9 +46,7 @@ class UseVarVisitor : public ExprVisitor { bool use_var = false; Var v; - void VisitExpr_(const VarNode* vn) override { - use_var = use_var || (v == GetRef(vn)); - } + void VisitExpr_(const VarNode* vn) override { use_var = use_var || (v == GetRef(vn)); } }; class GNF : public ExprMutator { @@ -58,9 +57,7 @@ class GNF : public ExprMutator { return var_map_.count(v) == 0 ? v : var_map_.at(v); } - static bool UseVar(const Var& v, const Expr& e) { - return UseVarVisitor::UseVar(v, e); - } + static bool UseVar(const Var& v, const Expr& e) { return UseVarVisitor::UseVar(v, e); } static Expr WrapRec(const Var& var, const Expr& val) { return UseVar(var, val) ? Let(var, val, var) : val; @@ -72,22 +69,19 @@ class GNF : public ExprMutator { } }; -Expr ToGraphNormalForm(const Expr& e) { - return GNF()(e); -} +Expr ToGraphNormalForm(const Expr& e) { return GNF()(e); } namespace transform { Pass ToGraphNormalForm() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(ToGraphNormalForm(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(ToGraphNormalForm(f)); + }; return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ToGraphNormalForm") -.set_body_typed(ToGraphNormalForm); +TVM_REGISTER_GLOBAL("relay._transform.ToGraphNormalForm").set_body_typed(ToGraphNormalForm); } // namespace transform diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index b6e75ae..19632de 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -26,14 +26,16 @@ #ifndef TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_ #define TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_ -#include #include +#include + #include -#include #include +#include #include -#include "pattern_util.h" + #include "infer_layout_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -49,8 +51,8 @@ class TransformMemorizerNode : public Object { struct key_hash : public std::function { std::size_t operator()(const TransformKey& k) const { return dmlc::HashCombine( - dmlc::HashCombine( - std::hash()(std::get<0>(k)), std::get<1>(k)), + dmlc::HashCombine(std::hash()(std::get<0>(k)), + std::get<1>(k)), (std::get<2>(k))); } }; @@ -300,8 +302,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj // new_in2, new_out = op.infer(new_in) if (new_call->op->IsInstance()) { success = false; - std::tie(new_in2, new_out, success) = - InferCorrectLayouts(new_call, new_in, old_in, types); + std::tie(new_in2, new_out, success) = InferCorrectLayouts(new_call, new_in, old_in, types); if (!success) { return Expr(nullptr); } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 3a16d8f..0782484 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -37,14 +37,15 @@ * If we can not infer a type or there are conflicting typing * constraints we will trigger an error. */ -#include #include +#include +#include #include #include -#include #include -#include "pass_util.h" + #include "../analysis/type_solver.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -53,21 +54,16 @@ namespace relay { struct TupleGetItemAttrs : public tvm::AttrsNode { int index; - TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { - TVM_ATTR_FIELD(index); - } + TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { TVM_ATTR_FIELD(index); } }; -bool TupleGetItemRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TupleGetItemRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); if (types[0].as()) return false; const auto* data = types[0].as(); - CHECK(data != nullptr) - << "TupleGetItem expect input type to be TupleType " - << " get " << types[0] << " instead"; + CHECK(data != nullptr) << "TupleGetItem expect input type to be TupleType " + << " get " << types[0] << " instead"; const auto* param = attrs.as(); CHECK(param != nullptr); CHECK_GE(param->index, 0); @@ -77,9 +73,7 @@ bool TupleGetItemRel(const Array& types, } TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); -TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem") -.set_body_typed( - TupleGetItemRel); +TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem").set_body_typed(TupleGetItemRel); struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array type_args) @@ -105,8 +99,10 @@ class TypeInferencer : private ExprFunctor, // constructors explicit TypeInferencer(IRModule mod, GlobalVar current_func) - : mod_(mod), current_func_(current_func), - err_reporter(), solver_(current_func, mod, &this->err_reporter) { + : mod_(mod), + current_func_(current_func), + err_reporter(), + solver_(current_func, mod, &this->err_reporter) { CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer"; } @@ -140,22 +136,16 @@ class TypeInferencer : private ExprFunctor, Type Unify(const Type& t1, const Type& t2, const ObjectRef& expr) { try { return solver_.Unify(t1, t2, expr); - } catch (const dmlc::Error &e) { + } catch (const dmlc::Error& e) { this->ReportFatalError( - expr, - ErrorBuilder() - << "Error unifying `" - << t1 - << "` and `" - << t2 - << "`: " << e.what()); + expr, ErrorBuilder() << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what()); return Type(); } } // Lazily get type for expr // expression, we will populate it now, and return the result. - Type GetType(const Expr &expr) { + Type GetType(const Expr& expr) { auto it = type_map_.find(expr); if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; @@ -186,19 +176,15 @@ class TypeInferencer : private ExprFunctor, Type VisitExpr_(const GlobalVarNode* op) final { GlobalVar var = GetRef(op); if (!mod_.defined()) { - this->ReportFatalError( - GetRef(op), - ErrorBuilder() << - "Cannot do type inference on global variables " \ - "without a module"); + this->ReportFatalError(GetRef(op), + ErrorBuilder() << "Cannot do type inference on global variables " + "without a module"); } Expr e = mod_->Lookup(var); return e->checked_type(); } - Type VisitExpr_(const ConstantNode* op) final { - return op->tensor_type(); - } + Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); } Type VisitExpr_(const TupleNode* op) final { Array types; @@ -209,23 +195,22 @@ class TypeInferencer : private ExprFunctor, } Type VisitExpr_(const TupleGetItemNode* op) final { - if (!tuple_getitem_rel_.defined()) { - tuple_getitem_rel_ = Downcast( - EnvFunc::Get("tvm.relay.type_relation.TupleGetItem")); + if (!tuple_getitem_rel_.defined()) { + tuple_getitem_rel_ = + Downcast(EnvFunc::Get("tvm.relay.type_relation.TupleGetItem")); } Type tuple_type = GetType(op->tuple); Type rtype = IncompleteType(Kind::kType); auto attrs = make_object(); attrs->index = op->index; - solver_.AddConstraint(TypeRelation( - tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef(op)); + solver_.AddConstraint(TypeRelation(tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), + GetRef(op)); return rtype; } void VisitPattern_(const PatternConstructorNode* con, const Type& t) { - CHECK(mod_.defined()) - << "Cannot do type inference without a environment:" - << con->constructor->name_hint; + CHECK(mod_.defined()) << "Cannot do type inference without a environment:" + << con->constructor->name_hint; TypeData td = mod_->type_definitions.at(con->constructor->belong_to); auto pc = GetRef(con); @@ -242,15 +227,14 @@ class TypeInferencer : private ExprFunctor, this->ReportFatalError(pc, ErrorBuilder() << "Expected a type call, got " << unified); } if (td->header != tc->func) { - this->ReportFatalError(pc, - ErrorBuilder() << "ADT headers must match, but we have " - << td->header << " and " << tc->func); + this->ReportFatalError(pc, ErrorBuilder() << "ADT headers must match, but we have " + << td->header << " and " << tc->func); } if (td->type_vars.size() != tc->args.size()) { - this->ReportFatalError(pc, - ErrorBuilder() << "The number of type args must match" - << "the number of type vars in the type data: " - << td->type_vars.size() << " != " << tc->args.size()); + this->ReportFatalError( + pc, ErrorBuilder() << "The number of type args must match" + << "the number of type vars in the type data: " << td->type_vars.size() + << " != " << tc->args.size()); } std::unordered_map type_var_map_; for (size_t i = 0; i < td->type_vars.size(); ++i) { @@ -258,10 +242,9 @@ class TypeInferencer : private ExprFunctor, } CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern"; if (con->constructor->inputs.size() != con->patterns.size()) { - this->ReportFatalError(pc, - ErrorBuilder() << "Not enough inputs for the constructor; " - << "expected " << con->constructor->inputs.size() - << ", got " << con->patterns.size()); + this->ReportFatalError(pc, ErrorBuilder() << "Not enough inputs for the constructor; " + << "expected " << con->constructor->inputs.size() + << ", got " << con->patterns.size()); } for (size_t i = 0; i < con->constructor->inputs.size(); ++i) { VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_)); @@ -294,7 +277,7 @@ class TypeInferencer : private ExprFunctor, Unify(vt, t, pv->span); } - void VisitPattern_(const PatternWildcardNode* wc, const Type& t) { } + void VisitPattern_(const PatternWildcardNode* wc, const Type& t) {} Type VisitExpr_(const MatchNode* op) final { Type dtype = GetType(op->data); @@ -303,9 +286,7 @@ class TypeInferencer : private ExprFunctor, } Type rtype = IncompleteType(Kind::kType); for (const auto& c : op->clauses) { - rtype = this->Unify(rtype, - GetType(c->rhs), - op->span); + rtype = this->Unify(rtype, GetType(c->rhs), op->span); } if (op->complete) { @@ -319,18 +300,14 @@ class TypeInferencer : private ExprFunctor, for (auto cs : unmatched_cases) { ss << "case " << i++ << ": \n" << PrettyPrint(cs); } - this->ReportFatalError( - match, - ss); + this->ReportFatalError(match, ss); } } return rtype; } - Type VisitExpr_(const OpNode* op) final { - return op->op_type; - } + Type VisitExpr_(const OpNode* op) final { return op->op_type; } Type VisitExpr_(const LetNode* let) final { // if the definition is a function literal, permit recursion @@ -342,7 +319,6 @@ class TypeInferencer : private ExprFunctor, type_map_[let->var].checked_type = let_type; } - if (let->var->type_annotation.defined()) { let_type = Unify(let_type, let->var->type_annotation, GetRef(let)); } @@ -360,9 +336,7 @@ class TypeInferencer : private ExprFunctor, // Ensure the type of the guard is of Tensor[Bool, ()], // that is a rank-0 boolean tensor. Type cond_type = this->GetType(ite->cond); - this->Unify(cond_type, - TensorType::Scalar(tvm::DataType::Bool()), - ite->cond); + this->Unify(cond_type, TensorType::Scalar(tvm::DataType::Bool()), ite->cond); Type checked_true = this->GetType(ite->true_branch); Type checked_false = this->GetType(ite->false_branch); return this->Unify(checked_true, checked_false, GetRef(ite)); @@ -372,9 +346,7 @@ class TypeInferencer : private ExprFunctor, // which are registered in the style defined in src/relay/op/*. // // The result will be the return type of the operator. - Type PrimitiveCall(const FuncTypeNode* op, - Array arg_types, - const Attrs& attrs, + Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, const Attrs& attrs, const ObjectRef& loc) { if (op->type_params.size() != arg_types.size() + 1) return Type(); if (op->type_constraints.size() != 1) return Type(); @@ -387,8 +359,7 @@ class TypeInferencer : private ExprFunctor, Type rtype = IncompleteType(Kind::kType); arg_types.push_back(rtype); // we can do simple replacement here - solver_.AddConstraint(TypeRelation( - rel->func, arg_types, arg_types.size() - 1, attrs), loc); + solver_.AddConstraint(TypeRelation(rel->func, arg_types, arg_types.size() - 1, attrs), loc); return rtype; } @@ -417,9 +388,7 @@ class TypeInferencer : private ExprFunctor, ret_type = IncompleteType(Kind::kType); } - Type inst_ty = FuncType(fn_ty->arg_types, - ret_type, {}, - fn_ty->type_constraints); + Type inst_ty = FuncType(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints); inst_ty = Bind(inst_ty, subst_map); return Downcast(inst_ty); } @@ -437,7 +406,6 @@ class TypeInferencer : private ExprFunctor, return InstantiateFuncType(fn_ty, type_args); } - void AddTypeArgs(const Expr& expr, Array type_args) { auto type_info = type_map_.find(expr); if (type_info == type_map_.end()) { @@ -456,10 +424,8 @@ class TypeInferencer : private ExprFunctor, if (fn_ty_node == nullptr && inc_ty_node == nullptr) { this->ReportFatalError( - GetRef(call), - ErrorBuilder() - << "only expressions with function types can be called, found " - << ftype); + GetRef(call), + ErrorBuilder() << "only expressions with function types can be called, found " << ftype); } // incomplete type => it must be a function taking the arg types @@ -474,12 +440,10 @@ class TypeInferencer : private ExprFunctor, Array type_args = call->type_args; if (type_args.size() > fn_ty_node->type_params.size()) { this->ReportFatalError(GetRef(call), - ErrorBuilder() - << "Incorrect number of type args in " - << call->span << ": " - << "Expected " - << fn_ty_node->type_params.size() - << "but got " << type_args.size()); + ErrorBuilder() + << "Incorrect number of type args in " << call->span << ": " + << "Expected " << fn_ty_node->type_params.size() << "but got " + << type_args.size()); } FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); @@ -491,17 +455,15 @@ class TypeInferencer : private ExprFunctor, if (type_arity != number_of_args) { if (type_arity < number_of_args) { - this->ReportFatalError( - GetRef(call), - ErrorBuilder() - << "the function is provided too many arguments " - << "expected " << type_arity << ", found " << number_of_args); + this->ReportFatalError(GetRef(call), + ErrorBuilder() + << "the function is provided too many arguments " + << "expected " << type_arity << ", found " << number_of_args); } else { - this->ReportFatalError( - GetRef(call), - ErrorBuilder() - << "the function is provided too few arguments " - << "expected " << type_arity << ", found " << number_of_args); + this->ReportFatalError(GetRef(call), + ErrorBuilder() + << "the function is provided too few arguments " + << "expected " << type_arity << ", found " << number_of_args); } } @@ -511,9 +473,8 @@ class TypeInferencer : private ExprFunctor, for (auto cs : fn_ty->type_constraints) { if (const auto* tr = cs.as()) { - solver_.AddConstraint( - TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), - GetRef(call)); + solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), + GetRef(call)); } else { solver_.AddConstraint(cs, GetRef(call)); } @@ -529,9 +490,7 @@ class TypeInferencer : private ExprFunctor, } if (const OpNode* opnode = call->op.as()) { - Type rtype = PrimitiveCall(opnode->op_type.as(), - arg_types, - call->attrs, + Type rtype = PrimitiveCall(opnode->op_type.as(), arg_types, call->attrs, GetRef(call)); if (rtype.defined()) { AddTypeArgs(GetRef(call), arg_types); @@ -560,9 +519,7 @@ class TypeInferencer : private ExprFunctor, return solver_.Resolve(ret); } - Type VisitExpr_(const RefCreateNode* op) final { - return RelayRefType(GetType(op->value)); - } + Type VisitExpr_(const RefCreateNode* op) final { return RelayRefType(GetType(op->value)); } Type VisitExpr_(const RefReadNode* op) final { Type it = IncompleteType(Kind::kType); @@ -578,16 +535,13 @@ class TypeInferencer : private ExprFunctor, } Type VisitExpr_(const ConstructorNode* c) final { - CHECK(mod_.defined()) - << "Cannot do type inference without a environment:" - << c->name_hint; + CHECK(mod_.defined()) << "Cannot do type inference without a environment:" << c->name_hint; TypeData td = mod_->LookupTypeDef(c->belong_to); std::vector types; - for (const auto & t : td->type_vars) { + for (const auto& t : td->type_vars) { types.push_back(t); } - return FuncType(c->inputs, TypeCall(c->belong_to, types), - td->type_vars, {}); + return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {}); } void Solve() { @@ -603,72 +557,39 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) - : tmap_(tmap), solver_(solver) { - } + : tmap_(tmap), solver_(solver) {} - Expr VisitExpr_(const VarNode* op) final { - return VisitVar(GetRef(op)); - } + Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef(op)); } - Expr VisitExpr_(const ConstantNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const ConstantNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const GlobalVarNode* op) final { - return GetRef(op); - } + Expr VisitExpr_(const GlobalVarNode* op) final { return GetRef(op); } - Expr VisitExpr_(const OpNode* op) final { - return ExprMutator::VisitExpr_(op); - } + Expr VisitExpr_(const OpNode* op) final { return ExprMutator::VisitExpr_(op); } - Expr VisitExpr_(const TupleNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const TupleNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const TupleGetItemNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const TupleGetItemNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const FunctionNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const CallNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const CallNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const LetNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const IfNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const RefCreateNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const RefCreateNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const RefReadNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const RefReadNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const RefWriteNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const RefWriteNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const ConstructorNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const ConstructorNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const MatchNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const MatchNode* op) final { return AttachCheckedType(op); } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Var VisitVar(const Var& v) final { if (vmap_.count(v) == 0) { @@ -678,7 +599,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } // attach checked type to the mutated node. - template + template Expr AttachCheckedType(const T* op) { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); @@ -687,42 +608,34 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { // TODO(@jroesch): it would be nice if we would report resolution // errors directly on the program. CHECK(checked_type.as() == nullptr) - << "Cannot resolve type of " << GetRef(op) - << " at " << op->span; + << "Cannot resolve type of " << GetRef(op) << " at " << op->span; Expr new_e = ExprMutator::VisitExpr_(op); // new_call and new_var's code is only going to be valid for VarNode/CallNode. // Compiler optimization will likely fold these away for other nodes. - CallNode* new_call =( - std::is_base_of::value ? - const_cast(static_cast(new_e.get())) : nullptr); - VarNode* new_var =( - std::is_base_of::value ? - const_cast(static_cast(new_e.get())) : nullptr); - FunctionNode* new_fn =( - std::is_base_of::value ? - const_cast(static_cast(new_e.get())) : nullptr); + CallNode* new_call = (std::is_base_of::value + ? const_cast(static_cast(new_e.get())) + : nullptr); + VarNode* new_var = (std::is_base_of::value + ? const_cast(static_cast(new_e.get())) + : nullptr); + FunctionNode* new_fn = + (std::is_base_of::value + ? const_cast(static_cast(new_e.get())) + : nullptr); // check if we need update the new_e bool need_update_type = !checked_type.same_as(new_e->checked_type_); - bool need_update_call = ( - std::is_base_of::value && - it->second.type_args.defined() && - !it->second.type_args.same_as(new_call->type_args)); - bool need_update_var = ( - std::is_base_of::value && - update_missing_type_annotation_ && - !new_var->type_annotation.defined()); - - bool need_update_fn =( - std::is_base_of::value && - update_missing_type_annotation_ && - !new_fn->ret_type.defined()); - - if (!need_update_type && - !need_update_var && - !need_update_call && - !need_update_fn) { + bool need_update_call = + (std::is_base_of::value && it->second.type_args.defined() && + !it->second.type_args.same_as(new_call->type_args)); + bool need_update_var = (std::is_base_of::value && update_missing_type_annotation_ && + !new_var->type_annotation.defined()); + + bool need_update_fn = (std::is_base_of::value && + update_missing_type_annotation_ && !new_fn->ret_type.defined()); + + if (!need_update_type && !need_update_var && !need_update_call && !need_update_fn) { return new_e; } @@ -732,15 +645,11 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { // we make a copy mutating an existing reference. ObjectPtr ptr = make_object(*new_e.as()); new_e = Expr(ptr); - new_call = ( - std::is_base_of::value ? - static_cast(ptr.get()) : nullptr); - new_var = ( - std::is_base_of::value ? - static_cast(ptr.get()) : nullptr); - new_fn = ( - std::is_base_of::value ? - static_cast(ptr.get()) : nullptr); + new_call = + (std::is_base_of::value ? static_cast(ptr.get()) : nullptr); + new_var = (std::is_base_of::value ? static_cast(ptr.get()) : nullptr); + new_fn = (std::is_base_of::value ? static_cast(ptr.get()) + : nullptr); } // attach the information. @@ -765,9 +674,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { return new_e; } - Type VisitType(const Type &t) final { - return solver_->Resolve(t); - } + Type VisitType(const Type& t) final { return solver_->Resolve(t); } private: std::unordered_map vmap_; @@ -793,17 +700,21 @@ Expr TypeInferencer::Infer(Expr expr) { struct AllCheckTypePopulated : ExprVisitor { void VisitExpr(const Expr& e) { - if (e.as()) { return; } - if (e.as()) { return; } - if (e.as()) { return; } + if (e.as()) { + return; + } + if (e.as()) { + return; + } + if (e.as()) { + return; + } CHECK(e->checked_type_.defined()) << "Expression: " << e; return ExprVisitor::VisitExpr(e); } }; -void EnsureCheckedType(const Expr& e) { - AllCheckTypePopulated().VisitExpr(e); -} +void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } Expr InferType(const Expr& expr, const IRModule& mod) { auto main = mod->GetGlobalVar("main"); @@ -811,15 +722,12 @@ Expr InferType(const Expr& expr, const IRModule& mod) { auto e = inferencer.Infer(expr); CHECK(WellFormed(e)); auto free_tvars = FreeTypeVars(e, mod); - CHECK(free_tvars.size() == 0) - << "Found unbound type variables in " << e << ": " << free_tvars; + CHECK(free_tvars.size() == 0) << "Found unbound type variables in " << e << ": " << free_tvars; EnsureCheckedType(e); return e; } -Function InferType(const Function& func, - const IRModule& mod, - const GlobalVar& var) { +Function InferType(const Function& func, const IRModule& mod, const GlobalVar& var) { CHECK(mod.defined()) << "internal error: module must be set for type inference"; Function func_copy = Function(make_object(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); @@ -828,11 +736,9 @@ Function InferType(const Function& func, mod->Remove(var); CHECK(WellFormed(func_ret)); auto free_tvars = FreeTypeVars(func_ret, mod); - CHECK(free_tvars.size() == 0) - << "Found unbound type variables in: " - << std::endl - << AsText(func, true) - << std::endl << free_tvars; + CHECK(free_tvars.size() == 0) << "Found unbound type variables in: " << std::endl + << AsText(func, true) << std::endl + << free_tvars; return Downcast(func_ret); } @@ -840,16 +746,11 @@ namespace transform { Pass InferType() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(InferType(f, m)); - }; + [=](Function f, IRModule m, PassContext pc) { return Downcast(InferType(f, m)); }; return CreateFunctionPass(pass_func, 0, "InferType", {}); } -TVM_REGISTER_GLOBAL("relay._transform.InferType") -.set_body_typed([]() { - return InferType(); -}); +TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); }); } // namespace transform diff --git a/src/runtime/builtin_fp16.cc b/src/runtime/builtin_fp16.cc index 60dc55d..d229491 100644 --- a/src/runtime/builtin_fp16.cc +++ b/src/runtime/builtin_fp16.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,7 +20,7 @@ /*! * \file builtin_fp16.cc * \brief Functions for conversion between fp32 and fp16 -*/ + */ #include #include diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 32b3381..0164b1b 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -22,20 +22,22 @@ * \brief Device specific implementations */ #include -#include #include -#include +#include +#include #include +#include #include -#include -#include -#include + #include -#include -#include +#include #include -#include "runtime_base.h" +#include +#include +#include + #include "object_internal.h" +#include "runtime_base.h" namespace tvm { namespace runtime { @@ -90,9 +92,7 @@ class DeviceAPIManager { public: static const int kMaxDeviceAPI = 32; // Get API - static DeviceAPI* Get(const TVMContext& ctx) { - return Get(ctx.device_type); - } + static DeviceAPI* Get(const TVMContext& ctx) { return Get(ctx.device_type); } static DeviceAPI* Get(int dev_type, bool allow_missing = false) { return Global()->GetAPI(dev_type, allow_missing); } @@ -102,9 +102,7 @@ class DeviceAPIManager { DeviceAPI* rpc_api_{nullptr}; std::mutex mutex_; // constructor - DeviceAPIManager() { - std::fill(api_.begin(), api_.end(), nullptr); - } + DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } // Global static variable. static DeviceAPIManager* Global() { static DeviceAPIManager inst; @@ -130,8 +128,7 @@ class DeviceAPIManager { std::string factory = "device_api." + name; auto* f = Registry::Get(factory); if (f == nullptr) { - CHECK(allow_missing) - << "Device API " << name << " is not enabled."; + CHECK(allow_missing) << "Device API " << name << " is not enabled."; return nullptr; } void* ptr = (*f)(); @@ -140,19 +137,14 @@ class DeviceAPIManager { }; DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) { - return DeviceAPIManager::Get( - static_cast(ctx.device_type), allow_missing); + return DeviceAPIManager::Get(static_cast(ctx.device_type), allow_missing); } -void* DeviceAPI::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { +void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint); } -void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { - FreeDataSpace(ctx, ptr); -} +void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { FreeDataSpace(ctx, ptr); } TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) { LOG(FATAL) << "Device does not support stream api."; @@ -163,8 +155,7 @@ void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) { LOG(FATAL) << "Device does not support stream api."; } -void DeviceAPI::SyncStreamFromTo(TVMContext ctx, - TVMStreamHandle event_src, +void DeviceAPI::SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) { LOG(FATAL) << "Device does not support stream api."; } @@ -256,7 +247,8 @@ std::string NormalizeError(std::string err_msg) { // Parse error type. { size_t start_pos = 0, end_pos; - for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {} + for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) { + } for (end_pos = start_pos; end_pos < line.length(); ++end_pos) { char ch = line[end_pos]; if (ch == ':') { @@ -268,8 +260,9 @@ std::string NormalizeError(std::string err_msg) { } if (error_type.length() != 0) { // if we successfully detected error_type: trim the following space. - for (start_pos = end_pos + 1; - start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {} + for (start_pos = end_pos + 1; start_pos < line.length() && line[start_pos] == ' '; + ++start_pos) { + } line = line.substr(start_pos); } else { // did not detect error_type, use default value. @@ -345,22 +338,16 @@ struct TVMRuntimeEntry { typedef dmlc::ThreadLocalStore TVMAPIRuntimeStore; -const char *TVMGetLastError() { - return TVMAPIRuntimeStore::Get()->last_error.c_str(); -} +const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); } -int TVMAPIHandleException(const std::runtime_error &e) { +int TVMAPIHandleException(const std::runtime_error& e) { TVMAPISetLastError(NormalizeError(e.what()).c_str()); return -1; } -void TVMAPISetLastError(const char* msg) { - TVMAPIRuntimeStore::Get()->last_error = msg; -} +void TVMAPISetLastError(const char* msg) { TVMAPIRuntimeStore::Get()->last_error = msg; } -int TVMModLoadFromFile(const char* file_name, - const char* format, - TVMModuleHandle* out) { +int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) { API_BEGIN(); TVMRetValue ret; ret = Module::LoadFromFile(file_name, format); @@ -371,21 +358,16 @@ int TVMModLoadFromFile(const char* file_name, API_END(); } -int TVMModImport(TVMModuleHandle mod, - TVMModuleHandle dep) { +int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) { API_BEGIN(); - ObjectInternal::GetModuleNode(mod)->Import( - GetRef(ObjectInternal::GetModuleNode(dep))); + ObjectInternal::GetModuleNode(mod)->Import(GetRef(ObjectInternal::GetModuleNode(dep))); API_END(); } -int TVMModGetFunction(TVMModuleHandle mod, - const char* func_name, - int query_imports, - TVMFunctionHandle *func) { +int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, + TVMFunctionHandle* func) { API_BEGIN(); - PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction( - func_name, query_imports != 0); + PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); if (pf != nullptr) { *func = new PackedFunc(pf); } else { @@ -394,23 +376,15 @@ int TVMModGetFunction(TVMModuleHandle mod, API_END(); } -int TVMModFree(TVMModuleHandle mod) { - return TVMObjectFree(mod); -} +int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); } -int TVMBackendGetFuncFromEnv(void* mod_node, - const char* func_name, - TVMFunctionHandle *func) { +int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) { API_BEGIN(); - *func = (TVMFunctionHandle)( - static_cast(mod_node)->GetFuncFromEnv(func_name)); + *func = (TVMFunctionHandle)(static_cast(mod_node)->GetFuncFromEnv(func_name)); API_END(); } -void* TVMBackendAllocWorkspace(int device_type, - int device_id, - uint64_t size, - int dtype_code_hint, +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, int dtype_bits_hint) { TVMContext ctx; ctx.device_type = static_cast(device_type); @@ -421,14 +395,10 @@ void* TVMBackendAllocWorkspace(int device_type, type_hint.bits = static_cast(dtype_bits_hint); type_hint.lanes = 1; - return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, - static_cast(size), - type_hint); + return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast(size), type_hint); } -int TVMBackendFreeWorkspace(int device_type, - int device_id, - void* ptr) { +int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { TVMContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; @@ -436,10 +406,7 @@ int TVMBackendFreeWorkspace(int device_type, return 0; } -int TVMBackendRunOnce(void** handle, - int (*f)(void*), - void* cdata, - int nbytes) { +int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { if (*handle == nullptr) { *handle = reinterpret_cast(1); return (*f)(cdata); @@ -453,21 +420,14 @@ int TVMFuncFree(TVMFunctionHandle func) { API_END(); } -int TVMFuncCall(TVMFunctionHandle func, - TVMValue* args, - int* arg_type_codes, - int num_args, - TVMValue* ret_val, - int* ret_type_code) { +int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args, + TVMValue* ret_val, int* ret_type_code) { API_BEGIN(); TVMRetValue rv; - (*static_cast(func)).CallPacked( - TVMArgs(args, arg_type_codes, num_args), &rv); + (*static_cast(func)).CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); // handle return string. - if (rv.type_code() == kTVMStr || - rv.type_code() == kTVMDataType || - rv.type_code() == kTVMBytes) { + if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); if (rv.type_code() != kTVMDataType) { e->ret_str = *rv.ptr(); @@ -489,10 +449,7 @@ int TVMFuncCall(TVMFunctionHandle func, API_END(); } -int TVMCFuncSetReturn(TVMRetValueHandle ret, - TVMValue* value, - int* type_code, - int num_ret) { +int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) { API_BEGIN(); CHECK_EQ(num_ret, 1); TVMRetValue* rv = static_cast(ret); @@ -500,32 +457,28 @@ int TVMCFuncSetReturn(TVMRetValueHandle ret, API_END(); } -int TVMFuncCreateFromCFunc(TVMPackedCFunc func, - void* resource_handle, - TVMPackedCFuncFinalizer fin, - TVMFunctionHandle *out) { +int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin, + TVMFunctionHandle* out) { API_BEGIN(); if (fin == nullptr) { - *out = new PackedFunc( - [func, resource_handle](TVMArgs args, TVMRetValue* rv) { - int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) - args.num_args, rv, resource_handle); - if (ret != 0) { - throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); - } - }); + *out = new PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) { + int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) + args.num_args, rv, resource_handle); + if (ret != 0) { + throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); + } + }); } else { // wrap it in a shared_ptr, with fin as deleter. // so fin will be called when the lambda went out of scope. std::shared_ptr rpack(resource_handle, fin); - *out = new PackedFunc( - [func, rpack](TVMArgs args, TVMRetValue* rv) { - int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) - args.num_args, rv, rpack.get()); - if (ret != 0) { - throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); - } - }); + *out = new PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) { + int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) + args.num_args, rv, rpack.get()); + if (ret != 0) { + throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); + } + }); } API_END(); } @@ -566,9 +519,7 @@ int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { API_END(); } -int TVMStreamStreamSynchronize(int device_type, - int device_id, - TVMStreamHandle src, +int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, TVMStreamHandle dst) { API_BEGIN(); TVMContext ctx; @@ -586,15 +537,10 @@ int TVMCbArgToReturn(TVMValue* value, int* code) { API_END(); } - -int TVMDeviceAllocDataSpace(DLContext ctx, - size_t nbytes, - size_t alignment, - DLDataType type_hint, +int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint, void** out_data) { API_BEGIN(); - out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace( - ctx, nbytes, alignment, type_hint); + out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace(ctx, nbytes, alignment, type_hint); API_END(); } @@ -604,53 +550,42 @@ int TVMDeviceFreeDataSpace(DLContext ctx, void* ptr) { API_END(); } -int TVMDeviceCopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t num_bytes, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, - TVMStreamHandle stream) { +int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) { API_BEGIN(); TVMContext ctx = ctx_from.device_type != kDLCPU ? ctx_from : ctx_to; - DeviceAPIManager::Get(ctx)->CopyDataFromTo( - from, from_offset, - to, to_offset, - num_bytes, ctx_from, ctx_to, type_hint, stream); + DeviceAPIManager::Get(ctx)->CopyDataFromTo(from, from_offset, to, to_offset, num_bytes, ctx_from, + ctx_to, type_hint, stream); API_END(); } // set device api TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) -.set_body([](TVMArgs args, TVMRetValue *ret) { - TVMContext ctx; - ctx.device_type = static_cast(args[0].operator int()); - ctx.device_id = args[1]; - DeviceAPIManager::Get(ctx)->SetDevice(ctx); - }); + .set_body([](TVMArgs args, TVMRetValue* ret) { + TVMContext ctx; + ctx.device_type = static_cast(args[0].operator int()); + ctx.device_id = args[1]; + DeviceAPIManager::Get(ctx)->SetDevice(ctx); + }); // set device api -TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr") -.set_body([](TVMArgs args, TVMRetValue *ret) { - TVMContext ctx; - ctx.device_type = static_cast(args[0].operator int()); - ctx.device_id = args[1]; - - DeviceAttrKind kind = static_cast(args[2].operator int()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true); - if (api != nullptr) { - api->GetAttr(ctx, kind, ret); - } else { - *ret = 0; - } +TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr").set_body([](TVMArgs args, TVMRetValue* ret) { + TVMContext ctx; + ctx.device_type = static_cast(args[0].operator int()); + ctx.device_id = args[1]; + + DeviceAttrKind kind = static_cast(args[2].operator int()); + if (kind == kExist) { + DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true); + if (api != nullptr) { + api->GetAttr(ctx, kind, ret); } else { - DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret); + *ret = 0; } - }); - + } else { + DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret); + } +}); -TVM_REGISTER_GLOBAL("runtime.TVMSetStream") -.set_body_typed(TVMSetStream); +TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream); diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 6145926..62220a8 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -24,31 +24,27 @@ #include #include #include -#include #include +#include namespace tvm { namespace runtime { using namespace vm; -TVM_REGISTER_GLOBAL("runtime.GetADTTag") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetADTTag").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); *rv = static_cast(adt.tag()); }); -TVM_REGISTER_GLOBAL("runtime.GetADTSize") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetADTSize").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); *rv = static_cast(adt.size()); }); - -TVM_REGISTER_GLOBAL("runtime.GetADTFields") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetADTFields").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; int idx = args[1]; const auto& adt = Downcast(obj); @@ -56,8 +52,7 @@ TVM_REGISTER_GLOBAL("runtime.GetADTFields") *rv = adt[idx]; }); -TVM_REGISTER_GLOBAL("runtime.Tuple") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.Tuple").set_body([](TVMArgs args, TVMRetValue* rv) { std::vector fields; for (auto i = 0; i < args.size(); ++i) { fields.push_back(args[i]); @@ -65,8 +60,7 @@ TVM_REGISTER_GLOBAL("runtime.Tuple") *rv = ADT::Tuple(fields); }); -TVM_REGISTER_GLOBAL("runtime.ADT") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) { int itag = args[0]; size_t tag = static_cast(itag); std::vector fields; @@ -76,13 +70,11 @@ TVM_REGISTER_GLOBAL("runtime.ADT") *rv = ADT(tag, fields); }); -TVM_REGISTER_GLOBAL("runtime.String") -.set_body_typed([](std::string str) { +TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) { return String(std::move(str)); }); -TVM_REGISTER_GLOBAL("runtime.GetFFIString") -.set_body_typed([](String str) { +TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) { return std::string(str); }); diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index d4959be..0cf4c69 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -21,8 +21,9 @@ * \file Use external cblas library call. */ #include -#include #include +#include + #include "gemm_common.h" extern "C" { @@ -50,8 +51,8 @@ struct CblasSgemmOp { void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, int ldb, float beta, float* C, int ldc) { #if USE_DNNL == 1 - dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B, - ldb, A, lda, beta, C, ldc); + dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B, ldb, A, + lda, beta, C, ldc); #else cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); @@ -159,8 +160,7 @@ struct CblasDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -170,8 +170,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") CallGemm(args, ret, CblasDgemmOp()); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 32)) { @@ -182,14 +181,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") }); TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") -.set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; - CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); - } else { - CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); - } -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); + } + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index b73abab..96d6322 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -23,15 +23,16 @@ */ #pragma once -#include #include +#include + #include namespace tvm { namespace contrib { using namespace runtime; -inline int ColumnStride(DLTensor *tensor) { +inline int ColumnStride(DLTensor* tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. @@ -42,7 +43,7 @@ inline int ColumnStride(DLTensor *tensor) { } } -inline int ElementStride(DLTensor *tensor) { +inline int ElementStride(DLTensor* tensor) { if (tensor->strides) { return std::min(tensor->strides[0], tensor->strides[1]); } else { @@ -51,25 +52,21 @@ inline int ElementStride(DLTensor *tensor) { } // Reversed strides indicates an in-place transpose operation. -inline bool IsInPlaceTransposed(DLTensor *tensor) { +inline bool IsInPlaceTransposed(DLTensor* tensor) { return tensor->strides && (tensor->strides[1] > tensor->strides[0]); } -inline int RowCount(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 1 : 0]; -} +inline int RowCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; } -inline int ColumnCount(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 0 : 1]; -} +inline int ColumnCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 0 : 1]; } // Call a column major blas. Note that data is stored in tvm as row // major, so this we switch the arguments. template -inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8; @@ -92,20 +89,17 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), - ColumnCount(A, transa), static_cast(alpha), - reinterpret_cast( - static_cast(B->data) + B->byte_offset), + op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), + static_cast(alpha), + reinterpret_cast(static_cast(B->data) + B->byte_offset), ColumnStride(B), - reinterpret_cast( - static_cast(A->data) + A->byte_offset), + reinterpret_cast(static_cast(A->data) + A->byte_offset), ColumnStride(A), static_cast(beta), - reinterpret_cast( - static_cast(C->data) + C->byte_offset), + reinterpret_cast(static_cast(C->data) + C->byte_offset), ColumnStride(C)); } -inline int ColumnStride3D(DLTensor *tensor) { +inline int ColumnStride3D(DLTensor* tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. @@ -115,7 +109,7 @@ inline int ColumnStride3D(DLTensor *tensor) { return tensor->shape[2]; } } -inline int ElementStride3D(DLTensor *tensor) { +inline int ElementStride3D(DLTensor* tensor) { if (tensor->strides) { return std::min(tensor->strides[1], tensor->strides[2]); } else { @@ -123,22 +117,18 @@ inline int ElementStride3D(DLTensor *tensor) { } } // Reversed strides indicates an in-place transpose operation. -inline bool IsInPlaceTransposed3D(DLTensor *tensor) { +inline bool IsInPlaceTransposed3D(DLTensor* tensor) { return tensor->strides && (tensor->strides[2] > tensor->strides[1]); } -inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; } -inline int RowCount3D(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 2 : 1]; -} -inline int ColumnCount3D(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 1 : 2]; -} +inline int BatchCount3D(DLTensor* tensor) { return tensor->shape[0]; } +inline int RowCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 2 : 1]; } +inline int ColumnCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 2]; } template -inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { +inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { using DType = typename TBatchGemmOp::TDatatype; - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; int bit_depth = sizeof(DType) * 8; @@ -163,16 +153,15 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { const int A_size = A->shape[1] * A->shape[2]; const int B_size = B->shape[1] * B->shape[2]; const int C_size = C->shape[1] * C->shape[2]; - DType *A_data = reinterpret_cast( - static_cast(A->data) + A->byte_offset); - DType *B_data = reinterpret_cast( - static_cast(B->data) + B->byte_offset); - DType *C_data = reinterpret_cast( - static_cast(C->data) + C->byte_offset); - op(batch_size, transb, transa, ColumnCount3D(B, transb), - RowCount3D(A, transa), ColumnCount3D(A, transa), - static_cast(alpha), - B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), + DType* A_data = reinterpret_cast(static_cast(A->data) + + A->byte_offset); + DType* B_data = reinterpret_cast(static_cast(B->data) + + B->byte_offset); + DType* C_data = reinterpret_cast(static_cast(C->data) + + C->byte_offset); + op(batch_size, transb, transa, ColumnCount3D(B, transb), RowCount3D(A, transa), + ColumnCount3D(A, transa), static_cast(alpha), B_data, B_size, + ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), static_cast(beta), C_data, C_size, ColumnStride3D(C)); } diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index fada800..404afa2 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -25,16 +25,16 @@ #ifndef TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_ -#import #import +#import #include #include #include -#include -#include #include +#include +#include namespace tvm { namespace runtime { @@ -53,15 +53,12 @@ class CoreMLRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. */ - const char* type_key() const { - return "CoreMLRuntime"; - } + const char* type_key() const { return "CoreMLRuntime"; } /*! * \brief Invoke the coreml prediction. @@ -74,9 +71,8 @@ class CoreMLRuntime : public ModuleNode { * \param ctx The context where the coreml model will be executed on. * \param output_names The output names of the model. */ - void Init(const std::string& model_path, - TVMContext ctx, - const std::vector& output_names); + void Init(const std::string& model_path, TVMContext ctx, + const std::vector& output_names); /*! * \brief set input to the model. @@ -99,13 +95,13 @@ class CoreMLRuntime : public ModuleNode { int GetNumOutputs() const; // CoreML model - MLModel *model_; + MLModel* model_; // CoreML model input dictionary - NSMutableDictionary *input_dict_; + NSMutableDictionary* input_dict_; // CoreML model output id output_; // List of output names - std::vector output_names_; + std::vector output_names_; // TVM context TVMContext ctx_; }; diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 614842b..1ce84a0 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -27,28 +27,27 @@ namespace tvm { namespace runtime { -MLModel *load_coreml_model(const std::string& model_path) { +MLModel* load_coreml_model(const std::string& model_path) { NSBundle* bundle = [NSBundle mainBundle]; NSString* base = [bundle privateFrameworksPath]; NSString* fname = [NSString stringWithUTF8String:("tvm/" + model_path).c_str()]; - NSString* assetPath = [base stringByAppendingPathComponent: fname]; + NSString* assetPath = [base stringByAppendingPathComponent:fname]; if (![[NSFileManager defaultManager] fileExistsAtPath:assetPath]) { - assetPath = [NSString stringWithCString: model_path.c_str() encoding:NSUTF8StringEncoding]; + assetPath = [NSString stringWithCString:model_path.c_str() encoding:NSUTF8StringEncoding]; } - NSURL *url = [NSURL fileURLWithPath:assetPath]; + NSURL* url = [NSURL fileURLWithPath:assetPath]; - MLModel *model = [MLModel modelWithContentsOfURL:url error:nil]; + MLModel* model = [MLModel modelWithContentsOfURL:url error:nil]; if (model == nil) { NSLog(@"modelc %@ not found", url); } return model; } -void CoreMLRuntime::Init(const std::string& model_path, - TVMContext ctx, - const std::vector& output_names) { +void CoreMLRuntime::Init(const std::string& model_path, TVMContext ctx, + const std::vector& output_names) { model_ = load_coreml_model(model_path); ctx_ = ctx; input_dict_ = [NSMutableDictionary dictionary]; @@ -56,13 +55,14 @@ void CoreMLRuntime::Init(const std::string& model_path, } void CoreMLRuntime::Invoke() { - id input = [[MLDictionaryFeatureProvider alloc] initWithDictionary:input_dict_ error:nil]; + id input = [[MLDictionaryFeatureProvider alloc] initWithDictionary:input_dict_ + error:nil]; output_ = [model_ predictionFromFeatures:input error:nil]; } void CoreMLRuntime::SetInput(const std::string& key, DLTensor* data_in) { int64_t size = 1; - NSMutableArray *shape = [[NSMutableArray alloc] init]; + NSMutableArray* shape = [[NSMutableArray alloc] init]; for (int64_t i = 0; i < data_in->ndim; ++i) { size *= data_in->shape[i]; [shape addObject:[NSNumber numberWithInteger:data_in->shape[i]]]; @@ -81,21 +81,20 @@ void CoreMLRuntime::SetInput(const std::string& key, DLTensor* data_in) { return; } - MLMultiArray *dest = [[MLMultiArray alloc] initWithShape:shape - dataType:dataType error:nil]; + MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:nil]; CHECK(data_in->strides == NULL); memcpy(dest.dataPointer, data_in->data, size); - NSString *nsKey = [NSString stringWithUTF8String:key.c_str()]; + NSString* nsKey = [NSString stringWithUTF8String:key.c_str()]; [input_dict_ setObject:dest forKey:nsKey]; } NDArray CoreMLRuntime::GetOutput(int index) const { - NSString *name = output_names_[index]; - MLModelDescription *model_desc = model_.modelDescription; - MLFeatureDescription *output_desc = model_desc.outputDescriptionsByName[name]; - MLMultiArrayConstraint *data_desc = output_desc.multiArrayConstraint; + NSString* name = output_names_[index]; + MLModelDescription* model_desc = model_.modelDescription; + MLFeatureDescription* output_desc = model_desc.outputDescriptionsByName[name]; + MLMultiArrayConstraint* data_desc = output_desc.multiArrayConstraint; std::vector shape; int64_t size = 1; for (int64_t i = 0; i < data_desc.shape.count; ++i) { @@ -114,59 +113,50 @@ NDArray CoreMLRuntime::GetOutput(int index) const { } else { LOG(FATAL) << "unexpected data type " << data_desc.dataType; } - MLMultiArray *src = [output_ featureValueForName:name].multiArrayValue; + MLMultiArray* src = [output_ featureValueForName:name].multiArrayValue; NDArray ret = NDArray::Empty(shape, dtype, ctx_); ret.CopyFromBytes(src.dataPointer, size); return ret; } -int CoreMLRuntime::GetNumOutputs() const { - return output_names_.size(); -} +int CoreMLRuntime::GetNumOutputs() const { return output_names_.size(); } -PackedFunc CoreMLRuntime::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc CoreMLRuntime::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "invoke") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Invoke(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); }); } else if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - const auto& input_name = args[0].operator std::string(); - this->SetInput(input_name, args[1]); - }); + const auto& input_name = args[0].operator std::string(); + this->SetInput(input_name, args[1]); + }); } else if (name == "get_output") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetOutput(args[0]); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); }); } else if (name == "get_num_outputs") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetNumOutputs(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetNumOutputs(); }); } else { return PackedFunc(); } } -Module CoreMLRuntimeCreate(const std::string& model_path, - TVMContext ctx, - const std::vector& output_names) { +Module CoreMLRuntimeCreate(const std::string& model_path, TVMContext ctx, + const std::vector& output_names) { auto exec = make_object(); exec->Init(model_path, ctx, output_names); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - std::vector output_names; - for (size_t i = 2; i < args.size(); i++) { - const std::string& name = args[i]; - output_names.push_back([NSString stringWithUTF8String:name.c_str()]); - } - *rv = CoreMLRuntimeCreate(args[0], args[1], output_names); - }); +TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector output_names; + for (size_t i = 2; i < args.size(); i++) { + const std::string& name = args[i]; + output_names.push_back([NSString stringWithUTF8String:name.c_str()]); + } + *rv = CoreMLRuntimeCreate(args[0], args[1], output_names); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 5424f4c..ff20445 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -20,152 +20,98 @@ /*! * \file Use external cblas library call. */ -#include -#include #include +#include +#include + #include "../cblas/gemm_common.h" #include "cublas_utils.h" - namespace tvm { namespace contrib { using namespace runtime; -inline cublasOperation_t BooleanToTranspose(bool item) { - return item ? CUBLAS_OP_T : CUBLAS_OP_N; -} +inline cublasOperation_t BooleanToTranspose(bool item) { return item ? CUBLAS_OP_T : CUBLAS_OP_N; } inline void TryEnableTensorCore(cublasHandle_t hdl) { // TensorCores are only supported in cublas 9.0 or higher int version; CHECK_CUBLAS_ERROR(cublasGetVersion(hdl, &version)); - if (version >= 9000) - CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH)); + if (version >= 9000) CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH)); } struct CublasHgemmOp { typedef half TDatatype; cublasHandle_t handle; - explicit CublasHgemmOp(cublasHandle_t hdl) - : handle(hdl) {} - - void operator()(bool ta, bool tb, - int M, int N, int K, - half alpha, half* A, int lda, - half* B, int ldb, - half beta, half* C, int ldc) { - CHECK_CUBLAS_ERROR(cublasHgemm(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, A, lda, - B, ldb, - &beta, C, ldc)); + explicit CublasHgemmOp(cublasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, half alpha, half* A, int lda, half* B, + int ldb, half beta, half* C, int ldc) { + CHECK_CUBLAS_ERROR(cublasHgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, + &alpha, A, lda, B, ldb, &beta, C, ldc)); } }; struct CublasSgemmOp { typedef float TDatatype; cublasHandle_t handle; - explicit CublasSgemmOp(cublasHandle_t hdl) - : handle(hdl) {} - - void operator()(bool ta, bool tb, - int M, int N, int K, - float alpha, float* A, int lda, - float* B, int ldb, - float beta, float* C, int ldc) { - CHECK_CUBLAS_ERROR(cublasSgemm(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, A, lda, - B, ldb, - &beta, C, ldc)); + explicit CublasSgemmOp(cublasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + CHECK_CUBLAS_ERROR(cublasSgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, + &alpha, A, lda, B, ldb, &beta, C, ldc)); } }; struct CublasDgemmOp { typedef double TDatatype; cublasHandle_t handle; - explicit CublasDgemmOp(cublasHandle_t hdl) - : handle(hdl) {} - void operator()(bool ta, bool tb, - int M, int N, int K, - double alpha, double* A, int lda, - double* B, int ldb, - double beta, double* C, int ldc) { - CHECK_CUBLAS_ERROR(cublasDgemm(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, A, lda, - B, ldb, - &beta, C, ldc)); + explicit CublasDgemmOp(cublasHandle_t hdl) : handle(hdl) {} + void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda, + double* B, int ldb, double beta, double* C, int ldc) { + CHECK_CUBLAS_ERROR(cublasDgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, + &alpha, A, lda, B, ldb, &beta, C, ldc)); } }; struct CublasHgemmBatchOp { typedef half TDatatype; cublasHandle_t handle; - explicit CublasHgemmBatchOp(cublasHandle_t hdl) - : handle(hdl) {} + explicit CublasHgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {} void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, half alpha, half* A, int a_stride, int lda, half* B, int b_stride, int ldb, half beta, half* C, int c_stride, int ldc) { - CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, - A, lda, a_stride, - B, ldb, b_stride, - &beta, - C, ldc, c_stride, - batch_size)); + CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched( + handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride, + B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); } }; struct CublasSgemmBatchOp { typedef float TDatatype; cublasHandle_t handle; - explicit CublasSgemmBatchOp(cublasHandle_t hdl) - : handle(hdl) {} + explicit CublasSgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {} void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, int c_stride, int ldc) { - CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, - A, lda, a_stride, - B, ldb, b_stride, - &beta, - C, ldc, c_stride, - batch_size)); + CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched( + handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride, + B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); } }; struct CublasDgemmBatchOp { typedef double TDatatype; cublasHandle_t handle; - explicit CublasDgemmBatchOp(cublasHandle_t hdl) - : handle(hdl) {} + explicit CublasDgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {} void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, int c_stride, int ldc) { - CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, - A, lda, a_stride, - B, ldb, b_stride, - &beta, - C, ldc, c_stride, - batch_size)); + CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched( + handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride, + B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); } }; @@ -174,22 +120,19 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s if (int_support && TypeMatch(out_dtype, kDLInt, 32)) { return TypeMatch(in_dtype, kDLInt, 8); } else if (TypeMatch(out_dtype, kDLFloat, 32)) { - return TypeMatch(in_dtype, kDLInt, 8) || - TypeMatch(in_dtype, kDLFloat, 16); + return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16); } else { return false; } } -int roundoff(int v, int d) { - return (v + d - 1) / d * d; -} +int roundoff(int v, int d) { return (v + d - 1) / d * d; } #if CUDART_VERSION >= 10010 -inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; // Reversed strides indicates an in-place transpose operation. @@ -230,53 +173,37 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C; cublasLtMatmulDesc_t operationDesc = nullptr; CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I)); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(opTranspose))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTranspose, sizeof(opTranspose))); cublasOperation_t opTransA = BooleanToTranspose(transa); cublasOperation_t opTransB = BooleanToTranspose(transb); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA))); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTransA, sizeof(opTransA))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTransB, sizeof(opTransB))); // Create descriptors for the original matrices - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( - &Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k , - opTransA == CUBLAS_OP_N ? k : m, lda)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( - &Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n , - opTransB == CUBLAS_OP_N ? n : k, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k, + opTransA == CUBLAS_OP_N ? k : m, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n, + opTransB == CUBLAS_OP_N ? n : k, ldb)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32))); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C))); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); - - CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, - operationDesc, - &alpha, - B_data, - Adesc, - A_data, - Bdesc, - &beta, - C_data, - Cdesc, - C_data, - Cdesc, - NULL, - NULL, - 0, - 0)); + Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32))); + + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, &alpha, B_data, Adesc, A_data, Bdesc, &beta, + C_data, Cdesc, C_data, Cdesc, NULL, NULL, 0, 0)); } #endif -inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; CHECK_EQ(A->ndim, 2); @@ -297,10 +224,10 @@ inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { transb = IsInPlaceTransposed(B) ? !transb : transb; CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; - CHECK(!TypeMatch(A->dtype, kDLInt, 8) || - ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - CHECK(!TypeMatch(B->dtype, kDLInt, 8) || - ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; @@ -320,28 +247,21 @@ inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { beta_ptr = &beta_float; } - auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); - auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); - auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); - - CHECK_CUBLAS_ERROR(cublasGemmEx(hdl, - BooleanToTranspose(transb), - BooleanToTranspose(transa), - ColumnCount(B, transb), - RowCount(A, transa), - ColumnCount(A, transa), - alpha_ptr, - B_data, cuda_in_type, ColumnStride(B), - A_data, cuda_in_type, ColumnStride(A), - beta_ptr, - C_data, cuda_out_type, ColumnStride(C), - cuda_out_type, algo)); + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + CHECK_CUBLAS_ERROR(cublasGemmEx(hdl, BooleanToTranspose(transb), BooleanToTranspose(transa), + ColumnCount(B, transb), RowCount(A, transa), + ColumnCount(A, transa), alpha_ptr, B_data, cuda_in_type, + ColumnStride(B), A_data, cuda_in_type, ColumnStride(A), beta_ptr, + C_data, cuda_out_type, ColumnStride(C), cuda_out_type, algo)); } -inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; CHECK_EQ(A->ndim, 3); @@ -364,10 +284,10 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) transb = IsInPlaceTransposed(B) ? !transb : transb; CHECK(CheckMixPrecisionType(A->dtype, C->dtype, false)) << "Unsupported data type"; - CHECK(!TypeMatch(A->dtype, kDLInt, 8) || - ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - CHECK(!TypeMatch(B->dtype, kDLInt, 8) || - ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; @@ -391,88 +311,76 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) beta_ptr = &beta_float; } - auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); - auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); - auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); - CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx(hdl, - BooleanToTranspose(transb), - BooleanToTranspose(transa), - ColumnCount3D(B, transb), - RowCount3D(A, transa), - ColumnCount3D(A, transa), - alpha_ptr, - B_data, cuda_in_type, ColumnStride3D(B), B_size, - A_data, cuda_in_type, ColumnStride3D(A), A_size, - beta_ptr, - C_data, cuda_out_type, ColumnStride3D(C), C_size, - batch_size, cuda_out_type, algo)); + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx( + hdl, BooleanToTranspose(transb), BooleanToTranspose(transa), ColumnCount3D(B, transb), + RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data, cuda_in_type, + ColumnStride3D(B), B_size, A_data, cuda_in_type, ColumnStride3D(A), A_size, beta_ptr, C_data, + cuda_out_type, ColumnStride3D(C), C_size, batch_size, cuda_out_type, algo)); } // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* A = args[0]; - DLTensor* C = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - TryEnableTensorCore(entry_ptr->handle); + TryEnableTensorCore(entry_ptr->handle); - if (TypeEqual(A->dtype, C->dtype)) { - CHECK(TypeMatch(A->dtype, kDLFloat, 16) || - TypeMatch(A->dtype, kDLFloat, 32) || + if (TypeEqual(A->dtype, C->dtype)) { + CHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 16)) - CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); - else - CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); - } else { - CallGemmEx(args, ret, entry_ptr->handle); - } + if (TypeMatch(A->dtype, kDLFloat, 16)) + CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); + else if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); + else + CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); + } else { + CallGemmEx(args, ret, entry_ptr->handle); + } }); #if CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; +TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - TryEnableTensorCore(entry_ptr->handle); + TryEnableTensorCore(entry_ptr->handle); - CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; - cublasLtHandle_t ltHandle; - CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); - CallLtIgemm(args, ret, ltHandle); - CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); + CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; + cublasLtHandle_t ltHandle; + CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); + CallLtIgemm(args, ret, ltHandle); + CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); }); #endif // CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; - DLTensor* C = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - TryEnableTensorCore(entry_ptr->handle); - if (TypeEqual(A->dtype, C->dtype)) { - CHECK(TypeMatch(A->dtype, kDLFloat, 16) || - TypeMatch(A->dtype, kDLFloat, 32) || + TryEnableTensorCore(entry_ptr->handle); + if (TypeEqual(A->dtype, C->dtype)) { + CHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 16)) - CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) - CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); - else - CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); - } else { - CallBatchGemmEx(args, ret, entry_ptr->handle); - } + if (TypeMatch(A->dtype, kDLFloat, 16)) + CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); + else if (TypeMatch(A->dtype, kDLFloat, 32)) + CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); + else + CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); + } else { + CallBatchGemmEx(args, ret, entry_ptr->handle); + } }); } // namespace contrib diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 9953cda..d4ec087 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -21,18 +21,16 @@ * \file Use external cudnn utils function */ #include "cublas_utils.h" + #include #include + #include "../../cuda/cuda_common.h" namespace tvm { namespace contrib { - -CuBlasThreadEntry::CuBlasThreadEntry() { - CHECK_CUBLAS_ERROR(cublasCreate(&handle)); -} - +CuBlasThreadEntry::CuBlasThreadEntry() { CHECK_CUBLAS_ERROR(cublasCreate(&handle)); } CuBlasThreadEntry::~CuBlasThreadEntry() { if (handle) { @@ -41,10 +39,8 @@ CuBlasThreadEntry::~CuBlasThreadEntry() { } } - typedef dmlc::ThreadLocalStore CuBlasThreadStore; - CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() { auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream; CuBlasThreadEntry* retval = CuBlasThreadStore::Get(); @@ -52,6 +48,5 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() { return retval; } - } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 2e553e2..5189c4f 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -24,11 +24,12 @@ #ifndef TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_ #define TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_ -#include -#include #include #include #include +#include +#include + #include #if CUDART_VERSION >= 10010 #include @@ -39,27 +40,35 @@ namespace contrib { inline const char* GetCublasErrorString(int error) { switch (error) { - case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; } return "Unrecognized error"; } #ifndef CHECK_CUBLAS_ERROR -#define CHECK_CUBLAS_ERROR(fn) \ - do { \ - int error = static_cast(fn); \ +#define CHECK_CUBLAS_ERROR(fn) \ + do { \ + int error = static_cast(fn); \ CHECK_EQ(error, CUBLAS_STATUS_SUCCESS) << "CUBLAS: " << GetCublasErrorString(error); \ } while (0) // ; intentionally left off. -#endif // CHECK_CUBLAS_ERROR - +#endif // CHECK_CUBLAS_ERROR struct CuBlasThreadEntry { CuBlasThreadEntry(); @@ -71,19 +80,26 @@ struct CuBlasThreadEntry { inline cudaDataType_t GetCudaDataType(DLDataType type) { if (type.code == kDLInt) { switch (type.bits) { - case 8: return CUDA_R_8I; - case 32: return CUDA_R_32I; + case 8: + return CUDA_R_8I; + case 32: + return CUDA_R_32I; } } else if (type.code == kDLUInt) { switch (type.bits) { - case 8: return CUDA_R_8U; - case 32: return CUDA_R_32U; + case 8: + return CUDA_R_8U; + case 32: + return CUDA_R_32U; } } else if (type.code == kDLFloat) { switch (type.bits) { - case 16: return CUDA_R_16F; - case 32: return CUDA_R_32F; - case 64: return CUDA_R_64F; + case 16: + return CUDA_R_16F; + case 32: + return CUDA_R_32F; + case 64: + return CUDA_R_64F; } } LOG(FATAL) << "Unsupported cuda type"; diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index c4c05d8..223a5b4 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -20,9 +20,10 @@ /*! * \file Use external cudnn utils function */ -#include #include #include +#include + #include "cudnn_utils.h" namespace tvm { @@ -30,19 +31,9 @@ namespace contrib { using namespace runtime; -void ConvolutionForward( - int mode, - int format, - int algo, - int dims, - int groups, - const int pad[], - const int stride[], - const int dilation[], - DLTensor* x, - DLTensor* w, - DLTensor* y, - const std::string& conv_dtype) { +void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[], + const int stride[], const int dilation[], DLTensor* x, DLTensor* w, + DLTensor* y, const std::string& conv_dtype) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); @@ -67,17 +58,11 @@ void ConvolutionForward( CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); if (dims == 2) { // Set Desc - CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, - pad[0], - pad[1], - stride[0], - stride[1], - dilation[0], - dilation[1], - entry_ptr->conv_entry.mode, - entry_ptr->conv_entry.data_type)); + CUDNN_CALL(cudnnSetConvolution2dDescriptor( + entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0], + dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type)); int ni, ci, hi, wi; - if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { ni = 0; ci = 3; hi = 1; @@ -90,67 +75,46 @@ void ConvolutionForward( } // Set Filter - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - static_cast(w->shape[ni]), - static_cast(w->shape[ci]), - static_cast(w->shape[hi]), - static_cast(w->shape[wi]))); + CUDNN_CALL(cudnnSetFilter4dDescriptor( + entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format, + static_cast(w->shape[ni]), static_cast(w->shape[ci]), + static_cast(w->shape[hi]), static_cast(w->shape[wi]))); // Set Input - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - static_cast(x->shape[ni]), - static_cast(x->shape[ci]), - static_cast(x->shape[hi]), - static_cast(x->shape[wi]))); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type, + static_cast(x->shape[ni]), static_cast(x->shape[ci]), + static_cast(x->shape[hi]), static_cast(x->shape[wi]))); // Set Output - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - static_cast(y->shape[ni]), - static_cast(y->shape[ci]), - static_cast(y->shape[hi]), - static_cast(y->shape[wi]))); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type, + static_cast(y->shape[ni]), static_cast(y->shape[ci]), + static_cast(y->shape[hi]), static_cast(y->shape[wi]))); } else { - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, - dims, - pad, - stride, - dilation, - entry_ptr->conv_entry.mode, + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type)); // Set Filter for (int i = 0; i < full_dims; i++) { dim[i] = static_cast(w->shape[i]); } - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - full_dims, + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, full_dims, dim.data())); // Set Input for (int i = 0; i < full_dims; i++) { dim[i] = static_cast(x->shape[i]); } GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, - data_type, - full_dims, - dim.data(), - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, + dim.data(), tensor_stride.data())); // Set Output for (int i = 0; i < full_dims; i++) { dim[i] = static_cast(y->shape[i]); } GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, - data_type, - full_dims, - dim.data(), - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims, + dim.data(), tensor_stride.data())); } if (cudnnGetVersion() > 7000) { @@ -159,42 +123,23 @@ void ConvolutionForward( // Set workspace size_t workspace_size = 0; - CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.fwd_algo, - &workspace_size)); + CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.fwd_algo, &workspace_size)); entry_ptr->conv_entry.UpdateWorkspace(workspace_size); - CUDNN_CALL(cudnnConvolutionForward(entry_ptr->handle, - CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), - entry_ptr->conv_entry.input_desc, - x->data, - entry_ptr->conv_entry.filter_desc, - w->data, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.fwd_algo, - entry_ptr->conv_entry.workspace, - workspace_size, - CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), - entry_ptr->conv_entry.output_desc, - y->data)); + CUDNN_CALL(cudnnConvolutionForward( + entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo, + entry_ptr->conv_entry.workspace, workspace_size, + CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.output_desc, y->data)); } - -void OutputShape( - int format, - int dims, - int groups, - const int pad[], - const int stride[], - const int dilation[], - const int x_dim[], - const int w_dim[], - void *out_shape, - const std::string& data_dtype, - const std::string& conv_dtype) { +void OutputShape(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int x_dim[], const int w_dim[], void* out_shape, + const std::string& data_dtype, const std::string& conv_dtype) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Data Type @@ -207,79 +152,46 @@ void OutputShape( // conv desc CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, - dims, - pad, - stride, - dilation, - CUDNN_CROSS_CORRELATION, + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); - if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { // Set Input CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - x_dim[0], - x_dim[3], - x_dim[1], - x_dim[2])); + entry_ptr->conv_entry.tensor_format, data_type, x_dim[0], + x_dim[3], x_dim[1], x_dim[2])); // filter desc - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - w_dim[0], - w_dim[3], - w_dim[1], - w_dim[2])); - - CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - static_cast(out_shape), - static_cast(out_shape) + 3, - static_cast(out_shape) + 1, - static_cast(out_shape) + 2)); + CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, w_dim[0], w_dim[3], + w_dim[1], w_dim[2])); + + CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, static_cast(out_shape), + static_cast(out_shape) + 3, static_cast(out_shape) + 1, + static_cast(out_shape) + 2)); } else { // Set Input std::vector tensor_stride(full_dims); GetCudnnStride(full_dims, x_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, - data_type, - full_dims, - x_dim, - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, + x_dim, tensor_stride.data())); // filter desc - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - full_dims, - w_dim)); - - CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - full_dims, - static_cast(out_shape))); + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, full_dims, w_dim)); + + CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, full_dims, static_cast(out_shape))); } } - -void FindAlgo( - int format, - int dims, - int groups, - const int pad[], - const int stride[], - const int dilation[], - const int x_dim[], - const int w_dim[], - const int y_dim[], - const std::string& data_dtype, - const std::string& conv_dtype, - TVMRetValue *ret) { +void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[], + const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Data Type @@ -292,65 +204,46 @@ void FindAlgo( // conv desc CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, - dims, - pad, - stride, - dilation, - CUDNN_CROSS_CORRELATION, + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); std::vector tensor_stride(full_dims); // input desc GetCudnnStride(full_dims, x_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, - data_type, - full_dims, - x_dim, - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, + x_dim, tensor_stride.data())); // filter desc - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - full_dims, - w_dim)); + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, full_dims, w_dim)); // output desc GetCudnnStride(full_dims, y_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, - data_type, - full_dims, - y_dim, - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims, + y_dim, tensor_stride.data())); if (cudnnGetVersion() > 7000) { CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH)) } int returned_algo_count = 0; cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT]; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - CUDNN_CONVOLUTION_FWD_ALGO_COUNT, - &returned_algo_count, - perf_results)); - - const std::vector fwd_algo_names{ - "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM", - "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM", - "CUDNN_CONVOLUTION_FWD_ALGO_GEMM", - "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT", - "CUDNN_CONVOLUTION_FWD_ALGO_FFT", - "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING", - "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD", - "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED" - }; + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, + CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results)); + + const std::vector fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM", + "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM", + "CUDNN_CONVOLUTION_FWD_ALGO_GEMM", + "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT", + "CUDNN_CONVOLUTION_FWD_ALGO_FFT", + "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING", + "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD", + "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"}; auto best_algo = perf_results[0].algo; - LOG(INFO) << "\tCUDNN Found " << returned_algo_count - << " fwd algorithms, choosing " << fwd_algo_names[best_algo]; + LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing " + << fwd_algo_names[best_algo]; for (int i = 0; i < returned_algo_count; ++i) { LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo] << " - time: " << perf_results[i].time << " ms" @@ -360,87 +253,83 @@ void FindAlgo( ret[0] = best_algo; } - TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int mode = args[0]; - int format = args[1]; - int algo = args[2]; - int pad_v[2], stride_v[2], dilation_v[2]; - for (int i = 0; i < 2; i++) { - pad_v[i] = args[3 + i]; - stride_v[i] = args[5 + i]; - dilation_v[i] = args[7 + i]; - } - DLTensor* x = args[9]; - DLTensor* w = args[10]; - DLTensor* y = args[11]; - std::string conv_dtype = args[12]; - int groups = args[13]; - - ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, - dilation_v, x, w, y, conv_dtype); -}); - + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[5 + i]; + dilation_v[i] = args[7 + i]; + } + DLTensor* x = args[9]; + DLTensor* w = args[10]; + DLTensor* y = args[11]; + std::string conv_dtype = args[12]; + int groups = args[13]; + + ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y, + conv_dtype); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int mode = args[0]; - int format = args[1]; - int algo = args[2]; - int pad_v[3], stride_v[3], dilation_v[3]; - for (int i = 0; i < 3; i++) { - pad_v[i] = args[3 + i]; - stride_v[i] = args[6 + i]; - dilation_v[i] = args[9 + i]; - } - DLTensor *x = args[12]; - DLTensor *w = args[13]; - DLTensor *y = args[14]; - std::string conv_dtype = args[15]; - int groups = args[16]; - - ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, - dilation_v, x, w, y, conv_dtype); -}); - + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[3], stride_v[3], dilation_v[3]; + for (int i = 0; i < 3; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[6 + i]; + dilation_v[i] = args[9 + i]; + } + DLTensor* x = args[12]; + DLTensor* w = args[13]; + DLTensor* y = args[14]; + std::string conv_dtype = args[15]; + int groups = args[16]; + + ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, dilation_v, x, w, y, + conv_dtype); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int format = args[0]; - int dims = args[1]; - int* pad = static_cast(static_cast(args[2])); - int* stride = static_cast(static_cast(args[3])); - int* dilation = static_cast(static_cast(args[4])); - int* x_dim = static_cast(static_cast(args[5])); - int* w_dim = static_cast(static_cast(args[6])); - void* out_shape = args[7]; - std::string data_dtype = args[8]; - std::string conv_dtype = args[9]; - int groups = args[10]; - - OutputShape(format, dims, groups, pad, stride, dilation, x_dim, - w_dim, out_shape, data_dtype, conv_dtype); -}); - + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* x_dim = static_cast(static_cast(args[5])); + int* w_dim = static_cast(static_cast(args[6])); + void* out_shape = args[7]; + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype, + conv_dtype); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int format = args[0]; - int dims = args[1]; - int* pad = static_cast(static_cast(args[2])); - int* stride = static_cast(static_cast(args[3])); - int* dilation = static_cast(static_cast(args[4])); - int* x_dim = static_cast(static_cast(args[5])); - int* w_dim = static_cast(static_cast(args[6])); - int* y_dim = static_cast(static_cast(args[7])); - std::string data_dtype = args[8]; - std::string conv_dtype = args[9]; - int groups = args[10]; - - FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, - w_dim, y_dim, data_dtype, conv_dtype, ret); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* x_dim = static_cast(static_cast(args[5])); + int* w_dim = static_cast(static_cast(args[6])); + int* y_dim = static_cast(static_cast(args[7])); + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, data_dtype, + conv_dtype, ret); + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index 9c895c5..cd934bc 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -21,38 +21,44 @@ * \file Use external cudnn utils function */ #include "cudnn_utils.h" + #include #include - namespace tvm { namespace contrib { // CuDNN Data Type -cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType &dtype) { +cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType& dtype) { switch (dtype.code) { - case kDLInt: - if (dtype.bits == 8 && dtype.lanes == 1) return CUDNN_DATA_INT8; - else if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_INT32; - else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4; - else - LOG(FATAL) << "Unsupported type"; - break; - case kDLUInt: + case kDLInt: + if (dtype.bits == 8 && dtype.lanes == 1) + return CUDNN_DATA_INT8; + else if (dtype.bits == 32 && dtype.lanes == 1) + return CUDNN_DATA_INT32; + else if (dtype.bits == 8 && dtype.lanes == 4) + return CUDNN_DATA_INT8x4; + else LOG(FATAL) << "Unsupported type"; - break; - case kDLFloat: - if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_FLOAT; - else if (dtype.bits == 64 && dtype.lanes == 1) return CUDNN_DATA_DOUBLE; - else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF; - else - LOG(FATAL) << "Unsupported type"; - break; - } - return CUDNN_DATA_FLOAT; + break; + case kDLUInt: + LOG(FATAL) << "Unsupported type"; + break; + case kDLFloat: + if (dtype.bits == 32 && dtype.lanes == 1) + return CUDNN_DATA_FLOAT; + else if (dtype.bits == 64 && dtype.lanes == 1) + return CUDNN_DATA_DOUBLE; + else if (dtype.bits == 16 && dtype.lanes == 1) + return CUDNN_DATA_HALF; + else + LOG(FATAL) << "Unsupported type"; + break; + } + return CUDNN_DATA_FLOAT; } -template<> +template <> const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) { static const int int_v = 0; static const float float_v = 0; @@ -69,7 +75,7 @@ const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) { return nullptr; } -template<> +template <> const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) { static const int int_v = 1; static const float float_v = 1.f; @@ -91,22 +97,18 @@ const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) { CuDNNThreadEntry::CuDNNThreadEntry() { auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream; auto func = runtime::Registry::Get("device_api.gpu"); - void *ret = (*func)(); + void* ret = (*func)(); cuda_api = static_cast(ret); CUDNN_CALL(cudnnCreate(&handle)); CUDNN_CALL(cudnnSetStream(handle, stream)); conv_entry.cuda_api = cuda_api; } -CuDNNThreadEntry::~CuDNNThreadEntry() { - CUDNN_CALL(cudnnDestroy(handle)); -} +CuDNNThreadEntry::~CuDNNThreadEntry() { CUDNN_CALL(cudnnDestroy(handle)); } typedef dmlc::ThreadLocalStore CuDNNThreadStore; -CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() { - return CuDNNThreadStore::Get(); -} +CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() { return CuDNNThreadStore::Get(); } // ConvEntry @@ -142,13 +144,9 @@ void ConvEntry::CleanWorkspace() { // SoftmaxEntry -SoftmaxEntry::SoftmaxEntry() { - CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); -} +SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); } -SoftmaxEntry::~SoftmaxEntry() { - CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); -} +SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); } } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index c2000d0..1b4eb40 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -24,11 +24,11 @@ #ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ #define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ -#include #include +#include #include -#include "../../cuda/cuda_common.h" +#include "../../cuda/cuda_common.h" namespace tvm { namespace contrib { @@ -41,24 +41,22 @@ namespace contrib { /*! breif Convert DLTensor type to CuDNN type */ struct CuDNNDataType { - static cudnnDataType_t DLTypeToCuDNNType(const DLDataType &dtype); - template + static cudnnDataType_t DLTypeToCuDNNType(const DLDataType& dtype); + template static const void* GetConst(cudnnDataType_t type); }; // struct CuDNNDataType -inline void GetStride(int nbdim, const int *dims, int *strides) { +inline void GetStride(int nbdim, const int* dims, int* strides) { int mul = 1; - for (int i = nbdim - 1; i >=0; --i) { + for (int i = nbdim - 1; i >= 0; --i) { mul *= dims[i]; strides[i] = mul; } } -inline void GetCudnnStride(int nbdim, - const int* dims, - int* strides) { +inline void GetCudnnStride(int nbdim, const int* dims, int* strides) { int mul = 1; - for (int i = nbdim - 1; i >=0; --i) { + for (int i = nbdim - 1; i >= 0; --i) { strides[i] = mul; mul *= dims[i]; } @@ -75,8 +73,8 @@ struct ConvEntry { cudnnConvolutionFwdAlgo_t fwd_algo; // cudnnMathType_t math_type; TVMContext ctx; - runtime::DeviceAPI *cuda_api; - void *workspace{nullptr}; + runtime::DeviceAPI* cuda_api; + void* workspace{nullptr}; size_t workspace_size{0}; ConvEntry(); ~ConvEntry(); @@ -98,7 +96,7 @@ struct CuDNNThreadEntry { cudnnHandle_t handle{nullptr}; ConvEntry conv_entry; SoftmaxEntry softmax_entry; - runtime::DeviceAPI *cuda_api{nullptr}; + runtime::DeviceAPI* cuda_api{nullptr}; static CuDNNThreadEntry* ThreadLocal(); }; // CuDNNThreadEntry diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index fb6d8a6..ff6d6a1 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -21,8 +21,9 @@ * \file src/runtime/contrib/cudnn/softmax.cc * \brief Use external cudnn softmax function */ -#include #include +#include + #include "cudnn_utils.h" namespace tvm { @@ -31,64 +32,53 @@ namespace contrib { using namespace runtime; TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* x = args[0]; - DLTensor* y = args[1]; - int axis = args[2]; - int ndim = x->ndim; - int64_t* shape = x->shape; - if (axis < 0) axis += ndim; - CHECK(axis >= 0 && axis < ndim); + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* x = args[0]; + DLTensor* y = args[1]; + int axis = args[2]; + int ndim = x->ndim; + int64_t* shape = x->shape; + if (axis < 0) axis += ndim; + CHECK(axis >= 0 && axis < ndim); - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); - entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype); - // Set mode and shape descriptor - if (axis == ndim - 1) { - int64_t N = 1; - for (int i = 0; i < ndim - 1; ++i) { - N *= shape[i]; - } - entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE; - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, - CUDNN_TENSOR_NCHW, - entry_ptr->softmax_entry.data_type, - static_cast(N), - static_cast(shape[ndim - 1]), - 1, - 1)); - } else { - int64_t pre_axis_dim = 1; - int64_t post_axis_dim = 1; - for (int i = 0; i < ndim; ++i) { - if (i < axis) { - pre_axis_dim *= shape[i]; - } else if (i > axis) { - post_axis_dim *= shape[i]; + // Set mode and shape descriptor + if (axis == ndim - 1) { + int64_t N = 1; + for (int i = 0; i < ndim - 1; ++i) { + N *= shape[i]; + } + entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE; + CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, + CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type, + static_cast(N), + static_cast(shape[ndim - 1]), 1, 1)); + } else { + int64_t pre_axis_dim = 1; + int64_t post_axis_dim = 1; + for (int i = 0; i < ndim; ++i) { + if (i < axis) { + pre_axis_dim *= shape[i]; + } else if (i > axis) { + post_axis_dim *= shape[i]; + } + } + entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL; + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW, + entry_ptr->softmax_entry.data_type, static_cast(pre_axis_dim), + static_cast(shape[axis]), static_cast(post_axis_dim), 1)); } - } - entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL; - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, - CUDNN_TENSOR_NCHW, - entry_ptr->softmax_entry.data_type, - static_cast(pre_axis_dim), - static_cast(shape[axis]), - static_cast(post_axis_dim), - 1)); - } - auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type); - auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type); - CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, - CUDNN_SOFTMAX_ACCURATE, - entry_ptr->softmax_entry.mode, - alpha, - entry_ptr->softmax_entry.shape_desc, - x->data, - beta, - entry_ptr->softmax_entry.shape_desc, - y->data)); -}); + auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type); + auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type); + CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, CUDNN_SOFTMAX_ACCURATE, + entry_ptr->softmax_entry.mode, alpha, + entry_ptr->softmax_entry.shape_desc, x->data, beta, + entry_ptr->softmax_entry.shape_desc, y->data)); + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 0922ac1..5b9f5e1 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -22,8 +22,6 @@ * \brief TVM compatible wrappers for dnnl kernels. */ -#include "dnnl_kernel.h" - #include #include #include @@ -34,6 +32,8 @@ #include #include +#include "dnnl_kernel.h" + namespace tvm { namespace runtime { namespace contrib { @@ -133,8 +133,7 @@ extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op()); } -extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, - int p_I_, int p_O_) { +extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) { using tag = memory::format_tag; using dt = memory::data_type; @@ -157,8 +156,8 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, auto bias_memory = memory(bias_md, eng, bias.data()); auto dst_memory = memory(dst_md, eng); - auto dense_desc = inner_product_forward::desc( - prop_kind::forward_inference, data_md, weight_md, bias_md, dst_md); + auto dense_desc = inner_product_forward::desc(prop_kind::forward_inference, data_md, weight_md, + bias_md, dst_md); auto dense_prim_desc = inner_product_forward::primitive_desc(dense_desc, eng); assert(dst_md == dense_prim_desc.dst_desc()); @@ -171,8 +170,7 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, read_from_dnnl_memory(out, dst_memory); } -extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, - int p_W_) { +extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_) { using tag = memory::format_tag; using dt = memory::data_type; @@ -186,8 +184,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, auto data_memory = memory(data_md, eng, data); auto dst_memory = memory(data_md, eng); - auto relu_desc = eltwise_forward::desc(prop_kind::forward_inference, - algorithm::eltwise_relu, data_md, 0); + auto relu_desc = + eltwise_forward::desc(prop_kind::forward_inference, algorithm::eltwise_relu, data_md, 0); auto relu_prim_desc = eltwise_forward::primitive_desc(relu_desc, eng); assert(data_md == relu_prim_desc.dst_desc()); @@ -215,8 +213,7 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo auto bn_desc = batch_normalization_forward::desc( prop_kind::forward_inference, data_md, p_E_, - normalization_flags::use_global_stats | - normalization_flags::use_scale_shift); + normalization_flags::use_global_stats | normalization_flags::use_scale_shift); auto bn_prim_desc = batch_normalization_forward::primitive_desc(bn_desc, eng); assert(data_md == bn_prim_desc.dst_desc()); @@ -239,8 +236,8 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo free(weight); } -extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, - int p_C_, int p_H_, int p_W_) { +extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, int p_C_, int p_H_, + int p_W_) { using tag = memory::format_tag; using dt = memory::data_type; @@ -257,15 +254,14 @@ extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, auto weight_memory = memory(weight_md, eng, weight); auto dst_memory = memory(dst_md, eng); - auto add_desc = - binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); + auto add_desc = binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); auto add_prim_desc = binary::primitive_desc(add_desc, eng); assert(dst_md == add_prim_desc.dst_desc()); auto add = binary(add_prim_desc); - add.execute(s, {{DNNL_ARG_SRC_0, data_memory}, - {DNNL_ARG_SRC_1, weight_memory}, - {DNNL_ARG_DST, dst_memory}}); + add.execute( + s, + {{DNNL_ARG_SRC_0, data_memory}, {DNNL_ARG_SRC_1, weight_memory}, {DNNL_ARG_DST, dst_memory}}); s.wait(); read_from_dnnl_memory(out, dst_memory); } diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index f92d767..dbc064a 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -26,6 +26,7 @@ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #include + #include "dnnl.hpp" namespace tvm { diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index 4823ef7..13b3c34 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -20,25 +20,23 @@ /*! * \file edgetpu_runtime.cc */ -#include +#include "edgetpu_runtime.h" + +#include #include #include #include -#include - - -#include "edgetpu_runtime.h" +#include namespace tvm { namespace runtime { -void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, - TVMContext ctx) { +void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { const char* buffer = tflite_model_bytes.c_str(); size_t buffer_size = tflite_model_bytes.size(); // Load compiled model as a FlatBufferModel std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); + tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); // Build resolver tflite::ops::builtin::BuiltinOpResolver resolver; // Init EdgeTPUContext object @@ -58,16 +56,14 @@ void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, ctx_ = ctx; } -Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, - TVMContext ctx) { +Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx) { auto exec = make_object(); exec->Init(tflite_model_bytes, ctx); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = EdgeTPURuntimeCreate(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = EdgeTPURuntimeCreate(args[0], args[1]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.h b/src/runtime/contrib/edgetpu/edgetpu_runtime.h index 78730d5..af3517b 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.h +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.h @@ -25,8 +25,8 @@ #ifndef TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_ -#include #include +#include #include "../tflite/tflite_runtime.h" @@ -44,17 +44,14 @@ class EdgeTPURuntime : public TFLiteRuntime { /*! * \return The type key of the executor. */ - const char* type_key() const final { - return "EdgeTPURuntime"; - } + const char* type_key() const final { return "EdgeTPURuntime"; } /*! * \brief Initialize the edge TPU tflite runtime with tflite model and context. * \param tflite_model_bytes The tflite model. * \param ctx The context where the tflite model will be executed on. */ - void Init(const std::string& tflite_model_bytes, - TVMContext ctx); + void Init(const std::string& tflite_model_bytes, TVMContext ctx); private: std::shared_ptr edgetpu_context_; diff --git a/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc b/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc index 98078b6..1a63ede 100644 --- a/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc +++ b/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc @@ -42,8 +42,8 @@ #include #include -#include #include +#include #include #include #include @@ -76,9 +76,8 @@ int Add(TVMValue* value, int* type_code, int nargs) { DLTensor* arg0 = static_cast(value[0].v_handle); DLTensor* arg1 = static_cast(value[1].v_handle); DLTensor* out = static_cast(value[2].v_handle); - Add_(static_cast(arg0->data), arg0->shape[0], - static_cast(arg1->data), arg1->shape[0], - static_cast(out->data)); + Add_(static_cast(arg0->data), arg0->shape[0], static_cast(arg1->data), + arg1->shape[0], static_cast(out->data)); return 0; } @@ -93,9 +92,8 @@ int Sub(TVMValue* value, int* type_code, int nargs) { DLTensor* arg0 = static_cast(value[0].v_handle); DLTensor* arg1 = static_cast(value[1].v_handle); DLTensor* out = static_cast(value[2].v_handle); - Sub_(static_cast(arg0->data), arg0->shape[0], - static_cast(arg1->data), arg1->shape[0], - static_cast(out->data)); + Sub_(static_cast(arg0->data), arg0->shape[0], static_cast(arg1->data), + arg1->shape[0], static_cast(out->data)); return 0; } @@ -110,9 +108,8 @@ int Mul(TVMValue* value, int* type_code, int nargs) { DLTensor* arg0 = static_cast(value[0].v_handle); DLTensor* arg1 = static_cast(value[1].v_handle); DLTensor* out = static_cast(value[2].v_handle); - Mul_(static_cast(arg0->data), arg0->shape[0], - static_cast(arg1->data), arg1->shape[0], - static_cast(out->data)); + Mul_(static_cast(arg0->data), arg0->shape[0], static_cast(arg1->data), + arg1->shape[0], static_cast(out->data)); return 0; } @@ -136,8 +133,7 @@ class ExampleJsonModule : public ModuleNode { * * \return The function pointer when it is found, otherwise, PackedFunc(nullptr). */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (this->graph_.find(name) != this->graph_.end()) { this->curr_subgraph_ = name; return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -215,9 +211,7 @@ class ExampleJsonModule : public ModuleNode { * * \param stream. The stream to save the binary. */ - void SaveToBinary(dmlc::Stream* stream) final { - stream->Write(this->graph_json_); - } + void SaveToBinary(dmlc::Stream* stream) final { stream->Write(this->graph_json_); } /*! * \brief Parse the example json string. @@ -333,12 +327,10 @@ class ExampleJsonModule : public ModuleNode { }; TVM_REGISTER_GLOBAL("runtime.module.loadfile_examplejson") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = ExampleJsonModule::Create(args[0]); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = ExampleJsonModule::Create(args[0]); }); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_examplejson") -.set_body_typed(ExampleJsonModule::LoadFromBinary); + .set_body_typed(ExampleJsonModule::LoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index d457548..1353e2f 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -20,9 +20,10 @@ /*! * \file Use external miopen utils function */ -#include #include #include +#include + #include "miopen_utils.h" namespace tvm { @@ -31,8 +32,7 @@ namespace miopen { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup").set_body([](TVMArgs args, TVMRetValue* ret) { const int mode = args[0]; const int dtype = args[1]; const int pad_h = args[2]; @@ -50,72 +50,52 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") const int w_dim2 = args[14]; const int w_dim3 = args[15]; const int n_group = args[16]; - void *out_shape = args[17]; + void* out_shape = args[17]; MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); assert(n_group > 0 && "Group Size > 0 is expected"); - if (n_group > 1) - assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); + if (n_group > 1) assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); // Set Ctx entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0}; // Set Data Type - entry_ptr->conv_entry.data_type = static_cast( - dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at - // this moment. + entry_ptr->conv_entry.data_type = + static_cast(dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), + // int32, int8 at this moment. // Set Desc MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.mode, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w)); + entry_ptr->conv_entry.mode, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w)); if (n_group > 1) MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group)); // Set Filter MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.data_type, - w_dim0, - w_dim1/n_group, - w_dim2, - w_dim3)); + entry_ptr->conv_entry.data_type, w_dim0, w_dim1 / n_group, + w_dim2, w_dim3)); // Set Input MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.data_type, - x_dim0, - x_dim1, - x_dim2, + entry_ptr->conv_entry.data_type, x_dim0, x_dim1, x_dim2, x_dim3)); // Set Output shape - MIOPEN_CALL(miopenGetConvolutionForwardOutputDim(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - static_cast(out_shape), - static_cast(out_shape) + 1, - static_cast(out_shape) + 2, - static_cast(out_shape) + 3)); - - const int *oshape = static_cast(out_shape); + MIOPEN_CALL(miopenGetConvolutionForwardOutputDim( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, static_cast(out_shape), + static_cast(out_shape) + 1, static_cast(out_shape) + 2, + static_cast(out_shape) + 3)); + + const int* oshape = static_cast(out_shape); // Set Output MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.data_type, - oshape[0], - oshape[1], - oshape[2], - oshape[3])); + entry_ptr->conv_entry.data_type, oshape[0], oshape[1], + oshape[2], oshape[3])); // Set workspace size_t workspace_size = 0; - MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize(entry_ptr->handle, - entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - &workspace_size)); + MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize( + entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, &workspace_size)); entry_ptr->conv_entry.UpdateWorkspace(workspace_size); const size_t input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3; @@ -123,12 +103,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") const size_t output_size = oshape[0] * oshape[1] * oshape[2] * oshape[3]; runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry.rocm_api; - float* input_buf = static_cast(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, - input_size * sizeof(float))); - float* filter_buf = static_cast(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, - filter_size * sizeof(float))); - float* output_buf = static_cast(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, - output_size * sizeof(float))); + float* input_buf = static_cast( + rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, input_size * sizeof(float))); + float* filter_buf = static_cast( + rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, filter_size * sizeof(float))); + float* output_buf = static_cast( + rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, output_size * sizeof(float))); const int request_algo_count = 4; const bool exhaustive_search = false; @@ -137,20 +117,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") int returned_algo_count = 0; miopenConvAlgoPerf_t perfs[4]; - MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm(entry_ptr->handle, - entry_ptr->conv_entry.input_desc, - input_buf, - entry_ptr->conv_entry.filter_desc, - filter_buf, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - output_buf, - request_algo_count, - &returned_algo_count, - perfs, - workspace, - workspace_size, - exhaustive_search)); + MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, input_buf, + entry_ptr->conv_entry.filter_desc, filter_buf, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.output_desc, output_buf, request_algo_count, &returned_algo_count, + perfs, workspace, workspace_size, exhaustive_search)); rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, input_buf); rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, filter_buf); @@ -163,8 +134,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") "miopenConvolutionFwdAlgoWinograd", }; const auto best_algo = perfs[0].fwd_algo; - LOG(INFO) << "\tMIOpen Found " << returned_algo_count - << " fwd algorithms, choosing " << fwd_algo_names[best_algo]; + LOG(INFO) << "\tMIOpen Found " << returned_algo_count << " fwd algorithms, choosing " + << fwd_algo_names[best_algo]; for (int i = 0; i < returned_algo_count; ++i) { LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perfs[i].fwd_algo] << " - time: " << perfs[i].time << " ms" @@ -174,79 +145,56 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") ret[0] = static_cast(best_algo); }); - TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - const int mode = args[0]; - const int dtype = args[1]; - const int pad_h = args[2]; - const int pad_w = args[3]; - const int stride_h = args[4]; - const int stride_w = args[5]; - const int dilation_h = args[6]; - const int dilation_w = args[7]; - const int algo = args[8]; - const DLTensor *x = args[9]; - const DLTensor *w = args[10]; - const DLTensor *y = args[11]; - - MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); - entry_ptr->conv_entry.fwd_algo = static_cast(algo); - // Set Mode - entry_ptr->conv_entry.mode = static_cast(mode); - // Set Ctx - entry_ptr->conv_entry.ctx = x->ctx; - // Set Data Type - entry_ptr->conv_entry.data_type = static_cast( - dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at - // this moment. - // Set Desc - MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.mode, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w)); - // Set Filter - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.data_type, - w->shape[0], - w->shape[1], - w->shape[2], - w->shape[3])); - // Set Input - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.data_type, - x->shape[0], - x->shape[1], - x->shape[2], - x->shape[3])); - // Set Output - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.data_type, - y->shape[0], - y->shape[1], - y->shape[2], - y->shape[3])); - - const float alpha = 1.f; - const float beta = 0.f; - MIOPEN_CALL(miopenConvolutionForward(entry_ptr->handle, - &alpha, - entry_ptr->conv_entry.input_desc, - x->data, - entry_ptr->conv_entry.filter_desc, - w->data, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.fwd_algo, - &beta, - entry_ptr->conv_entry.output_desc, - y->data, - entry_ptr->conv_entry.workspace, - entry_ptr->conv_entry.workspace_size)); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + const int mode = args[0]; + const int dtype = args[1]; + const int pad_h = args[2]; + const int pad_w = args[3]; + const int stride_h = args[4]; + const int stride_w = args[5]; + const int dilation_h = args[6]; + const int dilation_w = args[7]; + const int algo = args[8]; + const DLTensor* x = args[9]; + const DLTensor* w = args[10]; + const DLTensor* y = args[11]; + + MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + entry_ptr->conv_entry.fwd_algo = static_cast(algo); + // Set Mode + entry_ptr->conv_entry.mode = static_cast(mode); + // Set Ctx + entry_ptr->conv_entry.ctx = x->ctx; + // Set Data Type + entry_ptr->conv_entry.data_type = + static_cast(dtype); // MIOpen supports fp32(miopenFloat), + // fp16(miopenHalf) at this moment. + // Set Desc + MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.mode, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w)); + // Set Filter + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.data_type, w->shape[0], + w->shape[1], w->shape[2], w->shape[3])); + // Set Input + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.data_type, x->shape[0], + x->shape[1], x->shape[2], x->shape[3])); + // Set Output + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.data_type, y->shape[0], + y->shape[1], y->shape[2], y->shape[3])); + + const float alpha = 1.f; + const float beta = 0.f; + MIOPEN_CALL(miopenConvolutionForward( + entry_ptr->handle, &alpha, entry_ptr->conv_entry.input_desc, x->data, + entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.fwd_algo, &beta, entry_ptr->conv_entry.output_desc, y->data, + entry_ptr->conv_entry.workspace, entry_ptr->conv_entry.workspace_size)); + }); } // namespace miopen } // namespace contrib diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index 330ccdd..a579180 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -21,20 +21,22 @@ * \file Use external miopen utils function */ #include "miopen_utils.h" + #include #include -#include + #include +#include namespace tvm { namespace contrib { namespace miopen { std::string miopenGetErrorString(int error_code) { - const std::vector mio_err{ - "StatusSuccess ", "StatusNotInitialized ", "StatusInvalidValue ", - "StatusBadParm ", "StatusAllocFailed ", "StatusInternalError ", - "StatusNotImplemented ", "StatusUnknownError "}; + const std::vector mio_err{"StatusSuccess ", "StatusNotInitialized ", + "StatusInvalidValue ", "StatusBadParm ", + "StatusAllocFailed ", "StatusInternalError ", + "StatusNotImplemented ", "StatusUnknownError "}; return mio_err[error_code]; } @@ -42,22 +44,18 @@ std::string miopenGetErrorString(int error_code) { MIOpenThreadEntry::MIOpenThreadEntry() { auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream; auto func = runtime::Registry::Get("device_api.rocm"); - void *ret = (*func)(); + void* ret = (*func)(); rocm_api = static_cast(ret); MIOPEN_CALL(miopenCreate(&handle)); MIOPEN_CALL(miopenSetStream(handle, stream)); conv_entry.rocm_api = rocm_api; } -MIOpenThreadEntry::~MIOpenThreadEntry() { - MIOPEN_CALL(miopenDestroy(handle)); -} +MIOpenThreadEntry::~MIOpenThreadEntry() { MIOPEN_CALL(miopenDestroy(handle)); } typedef dmlc::ThreadLocalStore MIOpenThreadStore; -MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() { - return MIOpenThreadStore::Get(); -} +MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() { return MIOpenThreadStore::Get(); } // ConvEntry diff --git a/src/runtime/contrib/miopen/miopen_utils.h b/src/runtime/contrib/miopen/miopen_utils.h index 8831e4f..4dec2ad 100644 --- a/src/runtime/contrib/miopen/miopen_utils.h +++ b/src/runtime/contrib/miopen/miopen_utils.h @@ -27,7 +27,9 @@ #include #include #include + #include + #include "../../rocm/rocm_common.h" namespace tvm { @@ -36,11 +38,10 @@ namespace miopen { std::string miopenGetErrorString(int error_code); -#define MIOPEN_CALL(func) \ - { \ - miopenStatus_t e = (func); \ - CHECK_EQ(e, miopenStatusSuccess) \ - << "miopen error: " << miopenGetErrorString(e); \ +#define MIOPEN_CALL(func) \ + { \ + miopenStatus_t e = (func); \ + CHECK_EQ(e, miopenStatusSuccess) << "miopen error: " << miopenGetErrorString(e); \ } struct ConvEntry { @@ -52,8 +53,8 @@ struct ConvEntry { miopenTensorDescriptor_t output_desc; miopenConvFwdAlgorithm_t fwd_algo; TVMContext ctx; - runtime::DeviceAPI *rocm_api; - void *workspace{nullptr}; + runtime::DeviceAPI* rocm_api; + void* workspace{nullptr}; size_t workspace_size{0}; ConvEntry(); ~ConvEntry(); @@ -66,8 +67,8 @@ struct MIOpenThreadEntry { ~MIOpenThreadEntry(); miopenHandle_t handle{nullptr}; ConvEntry conv_entry; - runtime::DeviceAPI *rocm_api{nullptr}; - static MIOpenThreadEntry *ThreadLocal(); + runtime::DeviceAPI* rocm_api{nullptr}; + static MIOpenThreadEntry* ThreadLocal(); }; // MIOpenThreadEntry } // namespace miopen diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index 064e6d5..b598014 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -24,69 +24,59 @@ namespace contrib { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *buf = args[0]; - DLTensor *img = args[1]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* buf = args[0]; + DLTensor* img = args[1]; // copy to temp id mtlbuf = (__bridge id)(buf->data); - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry *rt = - runtime::metal::MetalThreadEntry::ThreadLocal(); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); id dev = entry_ptr->metal_api->GetDevice(buf->ctx); id temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]); - entry_ptr->metal_api->CopyDataFromTo( - (__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length], - buf->ctx, buf->ctx, nullptr - ); + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0, + [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); - MPSImageDescriptor *desc = [MPSImageDescriptor - imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 - width:buf->shape[2] - height:buf->shape[1] - featureChannels:buf->shape[3]]; + MPSImageDescriptor* desc = + [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 + width:buf->shape[2] + height:buf->shape[1] + featureChannels:buf->shape[3]]; - MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc); + MPSImage* mpsimg = entry_ptr->AllocMPSImage(dev, desc); [mpsimg writeBytes:[temp contents] dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels imageIndex:0]; - img->data = (__bridge void *)mpsimg; + img->data = (__bridge void*)mpsimg; [mpsimg readBytes:[temp contents] dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels imageIndex:0]; +}); - }); - -TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *img = args[0]; - DLTensor *buf = args[1]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* img = args[0]; + DLTensor* buf = args[1]; id mtlbuf = (__bridge id)(buf->data); - MPSImage *mpsimg = (__bridge MPSImage *)(img->data); - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry *rt = - runtime::metal::MetalThreadEntry::ThreadLocal(); + MPSImage* mpsimg = (__bridge MPSImage*)(img->data); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); id temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]); [mpsimg readBytes:[temp contents] dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels imageIndex:0]; - entry_ptr->metal_api->CopyDataFromTo( - (__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length], - buf->ctx, buf->ctx, nullptr); - - }); + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0, + [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); +}); -TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) { // MPS-NHWC - DLTensor *data = args[0]; - DLTensor *weight = args[1]; - DLTensor *output = args[2]; + DLTensor* data = args[0]; + DLTensor* weight = args[1]; + DLTensor* output = args[2]; int pad = args[3]; int stride = args[4]; @@ -108,54 +98,48 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img"); auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer"); // Get Metal device API - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry *rt = - runtime::metal::MetalThreadEntry::ThreadLocal(); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); id dev = entry_ptr->metal_api->GetDevice(data->ctx); - id queue = - entry_ptr->metal_api->GetCommandQueue(data->ctx); + id queue = entry_ptr->metal_api->GetCommandQueue(data->ctx); id cb = [queue commandBuffer]; // data to MPSImage DLTensor tmp_in; (*f_buf2img)(data, &tmp_in); - MPSImage *tempA = (__bridge MPSImage *)tmp_in.data; + MPSImage* tempA = (__bridge MPSImage*)tmp_in.data; // weight to temp memory id bufB = (__bridge id)(weight->data); id tempB = rt->GetTempBuffer(weight->ctx, [bufB length]); - entry_ptr->metal_api->CopyDataFromTo( - (__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length], - weight->ctx, weight->ctx, nullptr); - float *ptr_w = (float *)[tempB contents]; + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0, + [bufB length], weight -> ctx, weight -> ctx, nullptr); + float* ptr_w = (float*)[tempB contents]; // output to MPSImage DLTensor tmp_out; (*f_buf2img)(output, &tmp_out); - MPSImage *tempC = (__bridge MPSImage *)tmp_out.data; + MPSImage* tempC = (__bridge MPSImage*)tmp_out.data; // conv desc - MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor - cnnConvolutionDescriptorWithKernelWidth:kW - kernelHeight:kH - inputFeatureChannels:iCh - outputFeatureChannels:oCh]; + MPSCNNConvolutionDescriptor* conv_desc = + [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW + kernelHeight:kH + inputFeatureChannels:iCh + outputFeatureChannels:oCh]; [conv_desc setStrideInPixelsX:stride]; [conv_desc setStrideInPixelsY:stride]; - MPSCNNConvolution *conv = - [[MPSCNNConvolution alloc] initWithDevice:dev - convolutionDescriptor:conv_desc - kernelWeights:ptr_w - biasTerms:nil - flags:MPSCNNConvolutionFlagsNone]; + MPSCNNConvolution* conv = [[MPSCNNConvolution alloc] initWithDevice:dev + convolutionDescriptor:conv_desc + kernelWeights:ptr_w + biasTerms:nil + flags:MPSCNNConvolutionFlagsNone]; if (pad == 0) { - conv.padding = [MPSNNDefaultPadding - paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | - MPSNNPaddingMethodSizeSame]; + conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | + MPSNNPaddingMethodSizeSame]; } else if (pad == 1) { - conv.padding = [MPSNNDefaultPadding - paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | - MPSNNPaddingMethodSizeValidOnly]; + conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | + MPSNNPaddingMethodSizeValidOnly]; } [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC]; @@ -166,8 +150,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") [cb waitUntilCompleted]; (*f_img2buf)(&tmp_out, output); +}); - }); - -} // namespace contrib -} // namespace tvm +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm index bc12167..109c952 100644 --- a/src/runtime/contrib/mps/gemm.mm +++ b/src/runtime/contrib/mps/gemm.mm @@ -24,11 +24,10 @@ namespace contrib { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; // call gemm for simple compact code. @@ -42,7 +41,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") CHECK(TypeMatch(B->dtype, kDLFloat, 32)); CHECK(TypeMatch(C->dtype, kDLFloat, 32)); // Get Metal device API - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); // CHECK_EQ(A->ctx, B->ctx); // CHECK_EQ(A->ctx, C->ctx); id dev = entry_ptr->metal_api->GetDevice(A->ctx); @@ -55,36 +54,31 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") CHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K); // mps a MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype); - MPSMatrixDescriptor *descA = [MPSMatrixDescriptor - matrixDescriptorWithDimensions:M - columns:K - rowBytes:K * sizeof(MPSDataTypeFloat32) - dataType:MPSDataTypeFloat32]; + MPSMatrixDescriptor* descA = + [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:K + rowBytes:K * sizeof(MPSDataTypeFloat32) + dataType:MPSDataTypeFloat32]; id bufA = (__bridge id)(A->data); - MPSMatrix *matrixA = - [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; + MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; // mps b - MPSMatrixDescriptor *descB = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:K - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; + MPSMatrixDescriptor* descB = [MPSMatrixDescriptor matrixDescriptorWithDimensions:K + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; id bufB = (__bridge id)(B->data); - MPSMatrix *matrixB = - [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; + MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; // mps c - MPSMatrixDescriptor *descC = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:M - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; + MPSMatrixDescriptor* descC = [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; id bufC = (__bridge id)(C->data); - MPSMatrix *matrixC = - [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; + MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; // kernel - MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init]; - MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev + MPSMatrixMultiplication* mul_obj = [[MPSMatrixMultiplication alloc] init]; + MPSMatrixMultiplication* sgemm = [mul_obj initWithDevice:dev transposeLeft:transa transposeRight:transb resultRows:M @@ -93,13 +87,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") alpha:1.0f beta:0.0f]; CHECK(sgemm != nil); - [sgemm encodeToCommandBuffer:cb - leftMatrix:matrixA - rightMatrix:matrixB - resultMatrix:matrixC]; + [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC]; [cb commit]; +}); - }); - -} // namespace contrib -} // namespace tvm +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/mps/mps_utils.h b/src/runtime/contrib/mps/mps_utils.h index f1fff95..170451e 100644 --- a/src/runtime/contrib/mps/mps_utils.h +++ b/src/runtime/contrib/mps/mps_utils.h @@ -27,10 +27,12 @@ #import #include #include +#include #include #include -#include + #include + #include "../../metal/metal_common.h" namespace tvm { @@ -38,18 +40,17 @@ namespace contrib { /*! breif Convert DLTensor type to MPS type */ struct MPSType { - static MPSDataType DLTypeToMPSType(const DLDataType &dtype); + static MPSDataType DLTypeToMPSType(const DLDataType& dtype); }; // struct MPSType struct MetalThreadEntry { MetalThreadEntry(); ~MetalThreadEntry(); - MPSImage *AllocMPSImage(id dev, MPSImageDescriptor *desc); - MPSTemporaryImage *AllocTempImage(id cb, - MPSImageDescriptor *desc); - runtime::metal::MetalWorkspace *metal_api{nullptr}; - static MetalThreadEntry *ThreadLocal(); - std::vector img_table; + MPSImage* AllocMPSImage(id dev, MPSImageDescriptor* desc); + MPSTemporaryImage* AllocTempImage(id cb, MPSImageDescriptor* desc); + runtime::metal::MetalWorkspace* metal_api{nullptr}; + static MetalThreadEntry* ThreadLocal(); + std::vector img_table; }; // MetalThreadEntry } // namespace contrib diff --git a/src/runtime/contrib/mps/mps_utils.mm b/src/runtime/contrib/mps/mps_utils.mm index b3d4070..f9f8043 100644 --- a/src/runtime/contrib/mps/mps_utils.mm +++ b/src/runtime/contrib/mps/mps_utils.mm @@ -23,60 +23,58 @@ namespace tvm { namespace contrib { // MPS Data Type -MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) { +MPSDataType MPSType::DLTypeToMPSType(const DLDataType& dtype) { switch (dtype.code) { - case kDLInt: - if (dtype.bits == 8 && dtype.lanes == 1) - return MPSDataTypeInt8; - else if (dtype.bits == 16 && dtype.lanes == 1) - return MPSDataTypeInt16; - else + case kDLInt: + if (dtype.bits == 8 && dtype.lanes == 1) + return MPSDataTypeInt8; + else if (dtype.bits == 16 && dtype.lanes == 1) + return MPSDataTypeInt16; + else + LOG(FATAL) << "Unsupported type"; + break; + case kDLUInt: + if (dtype.bits == 8 && dtype.lanes == 1) + return MPSDataTypeUInt8; + else if (dtype.bits == 16 && dtype.lanes == 1) + return MPSDataTypeUInt16; + else if (dtype.bits == 32 && dtype.lanes == 1) + return MPSDataTypeUInt32; LOG(FATAL) << "Unsupported type"; - break; - case kDLUInt: - if (dtype.bits == 8 && dtype.lanes == 1) - return MPSDataTypeUInt8; - else if (dtype.bits == 16 && dtype.lanes == 1) - return MPSDataTypeUInt16; - else if (dtype.bits == 32 && dtype.lanes == 1) - return MPSDataTypeUInt32; - LOG(FATAL) << "Unsupported type"; - break; - case kDLFloat: - if (dtype.bits == 16 && dtype.lanes == 1) - return MPSDataTypeFloat16; - else if (dtype.bits == 32 && dtype.lanes == 1) - return MPSDataTypeFloat32; - else + break; + case kDLFloat: + if (dtype.bits == 16 && dtype.lanes == 1) + return MPSDataTypeFloat16; + else if (dtype.bits == 32 && dtype.lanes == 1) + return MPSDataTypeFloat32; + else + LOG(FATAL) << "Unsupported type"; + break; + default: LOG(FATAL) << "Unsupported type"; - break; - default: - LOG(FATAL) << "Unsupported type"; } return MPSDataTypeFloat32; } // MetalThreadEntry -MPSImage *MetalThreadEntry::AllocMPSImage(id dev, - MPSImageDescriptor *desc) { - MPSImage *mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc]; +MPSImage* MetalThreadEntry::AllocMPSImage(id dev, MPSImageDescriptor* desc) { + MPSImage* mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc]; img_table.push_back(mpsimg); return mpsimg; } -MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id cb, - MPSImageDescriptor *desc) { - MPSTemporaryImage *mpsimg = - [MPSTemporaryImage temporaryImageWithCommandBuffer:cb - imageDescriptor:desc]; +MPSTemporaryImage* MetalThreadEntry::AllocTempImage(id cb, + MPSImageDescriptor* desc) { + MPSTemporaryImage* mpsimg = [MPSTemporaryImage temporaryImageWithCommandBuffer:cb + imageDescriptor:desc]; return mpsimg; } MetalThreadEntry::MetalThreadEntry() { auto func = runtime::Registry::Get("device_api.metal"); - void *ret = (*func)(); - metal_api = static_cast(ret); + void* ret = (*func)(); + metal_api = static_cast(ret); } MetalThreadEntry::~MetalThreadEntry() { @@ -87,9 +85,7 @@ MetalThreadEntry::~MetalThreadEntry() { typedef dmlc::ThreadLocalStore MetalThreadStore; -MetalThreadEntry *MetalThreadEntry::ThreadLocal() { - return MetalThreadStore::Get(); -} +MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -} // namespace contrib -} // namespace tvm +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/nnpack/convolution.cc b/src/runtime/contrib/nnpack/convolution.cc index 79ea191..54c9ea4 100644 --- a/src/runtime/contrib/nnpack/convolution.cc +++ b/src/runtime/contrib/nnpack/convolution.cc @@ -20,11 +20,12 @@ /*! * \file Use external nnpack library call. */ -#include -#include -#include #include #include +#include +#include +#include + #include "nnpack_utils.h" namespace tvm { @@ -32,28 +33,25 @@ namespace contrib { using namespace runtime; TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") - .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); static std::once_flag flag; - std::call_once(flag, - []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); - DLTensor *input = args[0]; - DLTensor *kernel = args[1]; - DLTensor *bias = nullptr; + std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); + DLTensor* input = args[0]; + DLTensor* kernel = args[1]; + DLTensor* bias = nullptr; if (args[2].type_code() == kTVMDLTensorHandle) { bias = args[2]; } - DLTensor *output = args[3]; - uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], - pad_left = args[7]; + DLTensor* output = args[3]; + uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7]; nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left}; uint64_t stride_width = args[8], stride_height = args[9]; nnp_size stride_size{stride_width, stride_height}; NNPackConfig(args[10]); uint64_t algo_ = args[11]; - nnp_convolution_algorithm algo = - static_cast(algo_); + nnp_convolution_algorithm algo = static_cast(algo_); CHECK_EQ(input->ndim, 4); CHECK_EQ(kernel->ndim, 4); if (bias) { @@ -93,10 +91,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") size_t workspace_size = 0; nnp_status status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_compute, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_compute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr, + nullptr, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, nullptr); CHECK_EQ(status, nnp_status_success); // Division with rounding up, in case size is not multiple of sizeof(float) @@ -107,24 +104,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") DeviceAPI* cpu_api = DeviceAPI::Get(ctx); void* workspace_buffer = - cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); + cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); CHECK(workspace_buffer != nullptr); for (auto n = 0; n < input->shape[0]; ++n) { nnp_status status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_compute, input_channels, - output_channels, input_size, input_padding, kernel_size, - stride_size, - static_cast(input->data) + n * input->shape[1] * - input->shape[2] * - input->shape[3], - static_cast(kernel->data), - bias ? static_cast(bias->data) : zero_bias->data(), - static_cast(output->data) + n * output->shape[1] * - output->shape[2] * - output->shape[3], - workspace_buffer, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_compute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, + static_cast(input->data) + + n * input->shape[1] * input->shape[2] * input->shape[3], + static_cast(kernel->data), + bias ? static_cast(bias->data) : zero_bias->data(), + static_cast(output->data) + + n * output->shape[1] * output->shape[2] * output->shape[3], + workspace_buffer, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, + nullptr); CHECK_EQ(status, nnp_status_success); } @@ -132,28 +126,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") }); TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform") - .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); static std::once_flag flag; - std::call_once(flag, - []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); - DLTensor *input = args[0]; - DLTensor *transformed_kernel = args[1]; - DLTensor *bias = nullptr; + std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); + DLTensor* input = args[0]; + DLTensor* transformed_kernel = args[1]; + DLTensor* bias = nullptr; if (args[2].type_code() == kTVMDLTensorHandle) { bias = args[2]; } - DLTensor *output = args[3]; - uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], - pad_left = args[7]; + DLTensor* output = args[3]; + uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7]; nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left}; uint64_t stride_width = args[8], stride_height = args[9]; nnp_size stride_size{stride_width, stride_height}; NNPackConfig(args[10]); uint64_t algo_ = args[11]; - nnp_convolution_algorithm algo = - static_cast(algo_); + nnp_convolution_algorithm algo = static_cast(algo_); CHECK_EQ(input->ndim, 4); if (bias) { CHECK_EQ(bias->ndim, 1); @@ -189,10 +180,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra size_t workspace_size = 0; nnp_status status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_reuse, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr, + nullptr, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, nullptr); CHECK_EQ(status, nnp_status_success); // Division with rounding up, in case size is not multiple of sizeof(float) @@ -203,38 +193,34 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra DeviceAPI* cpu_api = DeviceAPI::Get(ctx); void* workspace_buffer = - cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); + cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); CHECK(workspace_buffer != nullptr); for (auto n = 0; n < input->shape[0]; ++n) { nnp_status status = nnp_convolution_inference( algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels, input_size, input_padding, kernel_size, stride_size, - static_cast(input->data) + n * input->shape[1] * - input->shape[2] * - input->shape[3], - static_cast(transformed_kernel->data), - bias ? static_cast(bias->data) : zero_bias->data(), - static_cast(output->data) + n * output->shape[1] * - output->shape[2] * - output->shape[3], - workspace_buffer, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + static_cast(input->data) + + n * input->shape[1] * input->shape[2] * input->shape[3], + static_cast(transformed_kernel->data), + bias ? static_cast(bias->data) : zero_bias->data(), + static_cast(output->data) + + n * output->shape[1] * output->shape[2] * output->shape[3], + workspace_buffer, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, + nullptr); CHECK_EQ(status, nnp_status_success); } cpu_api->FreeWorkspace(ctx, workspace_buffer); }); -TVM_REGISTER_GLOBAL( - "tvm.contrib.nnpack.convolution_inference_weight_transform") - .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); +TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_weight_transform") + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); static std::once_flag flag; - std::call_once(flag, - []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); - DLTensor *kernel = args[0]; - DLTensor *transformed_kernel = args[1]; + std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); + DLTensor* kernel = args[0]; + DLTensor* transformed_kernel = args[1]; // Dummy sizes nnp_padding input_padding{1, 1, 1, 1}; nnp_size stride_size{1, 1}; @@ -244,8 +230,7 @@ TVM_REGISTER_GLOBAL( NNPackConfig(args[2]); uint64_t algo_ = args[3]; - nnp_convolution_algorithm algo = - static_cast(algo_); + nnp_convolution_algorithm algo = static_cast(algo_); CHECK_EQ(kernel->ndim, 4); size_t input_channels = kernel->shape[1]; size_t output_channels = kernel->shape[0]; @@ -259,21 +244,20 @@ TVM_REGISTER_GLOBAL( size_t transformed_kernel_size = 0; nnp_status status; status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_precompute, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, nullptr, nullptr, nullptr, nullptr, &transformed_kernel_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_precompute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr, + nullptr, &transformed_kernel_size, nnp_activation_identity, nullptr, entry->threadpool, + nullptr); CHECK_EQ(status, nnp_status_success); CHECK_LE(transformed_kernel_size, GetDataSize(*transformed_kernel)); status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_precompute, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, static_cast(kernel->data), nullptr, nullptr, - static_cast(transformed_kernel->data), - &transformed_kernel_size, nnp_activation_identity, nullptr, - entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_precompute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, + static_cast(kernel->data), nullptr, nullptr, + static_cast(transformed_kernel->data), &transformed_kernel_size, + nnp_activation_identity, nullptr, entry->threadpool, nullptr); CHECK_EQ(status, nnp_status_success); }); } // namespace contrib diff --git a/src/runtime/contrib/nnpack/fully_connected.cc b/src/runtime/contrib/nnpack/fully_connected.cc index 5f111ef..543d239 100644 --- a/src/runtime/contrib/nnpack/fully_connected.cc +++ b/src/runtime/contrib/nnpack/fully_connected.cc @@ -20,10 +20,11 @@ /*! * \file Use external nnpack library call. */ -#include -#include #include #include +#include +#include + #include "nnpack_utils.h" namespace tvm { @@ -33,33 +34,30 @@ using namespace runtime; // matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference") -.set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); - nnp_initialize(); - DLTensor* A = args[0]; - DLTensor* B = args[1]; - DLTensor* C = args[2]; - NNPackConfig(args[3]); + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); + nnp_initialize(); + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + NNPackConfig(args[3]); - CHECK_EQ(A->ndim, 1); - CHECK_EQ(B->ndim, 2); - CHECK_EQ(C->ndim, 1); - CHECK_EQ(B->shape[0], C->shape[0]); - CHECK_EQ(B->shape[1], A->shape[0]); - CHECK(C->strides == nullptr); - CHECK(B->strides == nullptr); - CHECK(A->strides == nullptr); - CHECK(TypeMatch(A->dtype, kDLFloat, 32)); - CHECK(TypeMatch(B->dtype, kDLFloat, 32)); - CHECK(TypeMatch(C->dtype, kDLFloat, 32)); + CHECK_EQ(A->ndim, 1); + CHECK_EQ(B->ndim, 2); + CHECK_EQ(C->ndim, 1); + CHECK_EQ(B->shape[0], C->shape[0]); + CHECK_EQ(B->shape[1], A->shape[0]); + CHECK(C->strides == nullptr); + CHECK(B->strides == nullptr); + CHECK(A->strides == nullptr); + CHECK(TypeMatch(A->dtype, kDLFloat, 32)); + CHECK(TypeMatch(B->dtype, kDLFloat, 32)); + CHECK(TypeMatch(C->dtype, kDLFloat, 32)); - nnp_fully_connected_inference(B->shape[1], - B->shape[0], - static_cast(A->data), - static_cast(B->data), - static_cast(C->data), - entry->threadpool); - }); + nnp_fully_connected_inference(B->shape[1], B->shape[0], static_cast(A->data), + static_cast(B->data), static_cast(C->data), + entry->threadpool); + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/nnpack/nnpack_utils.cc b/src/runtime/contrib/nnpack/nnpack_utils.cc index f01ad85..91cf865 100644 --- a/src/runtime/contrib/nnpack/nnpack_utils.cc +++ b/src/runtime/contrib/nnpack/nnpack_utils.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,13 +28,12 @@ using namespace runtime; typedef dmlc::ThreadLocalStore NNPackThreadLocalStore; - NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() { return NNPackThreadLocalStore::Get(); } bool NNPackConfig(uint64_t nthreads) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); if (entry->threadpool && pthreadpool_get_threads_count(entry->threadpool) == nthreads) { CHECK_NE(nthreads, 1); return true; @@ -55,11 +54,9 @@ bool NNPackConfig(uint64_t nthreads) { return true; } - -TVM_REGISTER_GLOBAL("contrib.nnpack._initialize") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = nnp_initialize(); - }); +TVM_REGISTER_GLOBAL("contrib.nnpack._initialize").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = nnp_initialize(); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/nnpack/nnpack_utils.h b/src/runtime/contrib/nnpack/nnpack_utils.h index 4ba586f..bbb0d16 100644 --- a/src/runtime/contrib/nnpack/nnpack_utils.h +++ b/src/runtime/contrib/nnpack/nnpack_utils.h @@ -22,11 +22,11 @@ */ #ifndef TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ #define TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ -#include -#include -#include #include +#include #include +#include +#include namespace tvm { namespace contrib { diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 37166e2..c628e32 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,7 @@ * \brief mt19937 random engine */ #include + #include #include #include @@ -34,45 +35,37 @@ namespace contrib { */ class RandomEngine { public: - /*! - * \brief Creates a RandomEngine using a default seed. - */ - RandomEngine() { - this->Seed(time(0)); - } - - /*! - * \brief Creates a RandomEngine, suggesting the use of a provided seed. - */ - explicit RandomEngine(unsigned seed) { - this->Seed(seed); - } - - /*! - * \brief Seeds the underlying RNG, if possible. - */ + /*! + * \brief Creates a RandomEngine using a default seed. + */ + RandomEngine() { this->Seed(time(0)); } + + /*! + * \brief Creates a RandomEngine, suggesting the use of a provided seed. + */ + explicit RandomEngine(unsigned seed) { this->Seed(seed); } + + /*! + * \brief Seeds the underlying RNG, if possible. + */ inline void Seed(unsigned seed) { rnd_engine_.seed(seed); this->rseed_ = static_cast(seed); } - /*! - * \return the seed associated with the underlying RNG. - */ - inline unsigned GetSeed() const { - return rseed_; - } + /*! + * \return the seed associated with the underlying RNG. + */ + inline unsigned GetSeed() const { return rseed_; } - /*! - * \return a random integer sampled from the RNG. - */ - inline unsigned GetRandInt() { - return rnd_engine_(); - } + /*! + * \return a random integer sampled from the RNG. + */ + inline unsigned GetRandInt() { return rnd_engine_(); } - /*! - * \brief Fills a tensor with values drawn from Unif(low, high) - */ + /*! + * \brief Fills a tensor with values drawn from Unif(low, high) + */ void SampleUniform(DLTensor* data, float low, float high) { CHECK_GT(high, low) << "high must be bigger than low"; CHECK(data->strides == nullptr); @@ -87,17 +80,16 @@ class RandomEngine { if (data->ctx.device_type == kDLCPU) { std::uniform_real_distribution uniform_dist(low, high); - std::generate_n(static_cast(data->data), size, [&] () { - return uniform_dist(rnd_engine_); - }); + std::generate_n(static_cast(data->data), size, + [&]() { return uniform_dist(rnd_engine_); }); } else { LOG(FATAL) << "Do not support random.uniform on this device yet"; } } - /*! - * \brief Fills a tensor with values drawn from Normal(loc, scale**2) - */ + /*! + * \brief Fills a tensor with values drawn from Normal(loc, scale**2) + */ void SampleNormal(DLTensor* data, float loc, float scale) { CHECK_GT(scale, 0) << "standard deviation must be positive"; CHECK(data->strides == nullptr); @@ -112,9 +104,8 @@ class RandomEngine { if (data->ctx.device_type == kDLCPU) { std::normal_distribution normal_dist(loc, scale); - std::generate_n(static_cast(data->data), size, [&] () { - return normal_dist(rnd_engine_); - }); + std::generate_n(static_cast(data->data), size, + [&]() { return normal_dist(rnd_engine_); }); } else { LOG(FATAL) << "Do not support random.normal on this device yet"; } diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 8ae1f86..acba193 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -20,32 +20,34 @@ /*! * \file External random functions for tensor. */ -#include -#include #include #include +#include +#include + #include + #include "mt_random_engine.cc" #define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \ if (type.code == kDLInt && type.bits == 32) { \ typedef int32_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLInt && type.bits == 16) { \ typedef int16_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLInt && type.bits == 8) { \ typedef int8_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLUInt && type.bits == 32) { \ typedef uint32_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLUInt && type.bits == 16) { \ typedef uint16_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLUInt && type.bits == 8) { \ typedef uint8_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else { \ LOG(FATAL) << "unknown data type"; \ } @@ -66,61 +68,54 @@ RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() { return RandomThreadLocalStore::Get(); } +TVM_REGISTER_GLOBAL("tvm.contrib.random.randint").set_body([](TVMArgs args, TVMRetValue* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + int64_t low = args[0]; + int64_t high = args[1]; + DLTensor* out = args[2]; + CHECK_GT(high, low) << "high must be bigger than low"; + CHECK(out->strides == nullptr); + + DLDataType dtype = out->dtype; + int64_t size = 1; + for (int i = 0; i < out->ndim; ++i) { + size *= out->shape[i]; + } -TVM_REGISTER_GLOBAL("tvm.contrib.random.randint") -.set_body([](TVMArgs args, TVMRetValue *ret) { - RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal(); - int64_t low = args[0]; - int64_t high = args[1]; - DLTensor* out = args[2]; - CHECK_GT(high, low) << "high must be bigger than low"; - CHECK(out->strides == nullptr); - - DLDataType dtype = out->dtype; - int64_t size = 1; - for (int i = 0; i < out->ndim; ++i) { - size *= out->shape[i]; + DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, { + int64_t numeric_low = std::numeric_limits::min(); + int64_t numeric_high = std::numeric_limits::max(); + numeric_high += 1; // exclusive upper bound + low = std::max(low, numeric_low); + high = std::min(high, numeric_high); + + if (out->ctx.device_type == kDLCPU) { + // file the data with random byte + std::generate_n(static_cast(out->data), size, [&]() { + unsigned rint = entry->random_engine.GetRandInt(); + return low + rint % (high - low); + }); + } else { + LOG(FATAL) << "Do not support random.randint on this device yet"; } - - DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, { - int64_t numeric_low = std::numeric_limits::min(); - int64_t numeric_high = std::numeric_limits::max(); - numeric_high += 1; // exclusive upper bound - low = std::max(low, numeric_low); - high = std::min(high, numeric_high); - - if (out->ctx.device_type == kDLCPU) { - // file the data with random byte - std::generate_n(static_cast(out->data), size, [&] () { - unsigned rint = entry->random_engine.GetRandInt(); - return low + rint % (high - low); - }); - } else { - LOG(FATAL) << "Do not support random.randint on this device yet"; - } - }) - }); - - -TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") -.set_body([](TVMArgs args, TVMRetValue *ret) { - RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal(); - double low = args[0]; - double high = args[1]; - DLTensor* out = args[2]; - entry->random_engine.SampleUniform(out, low, high); - }); - - -TVM_REGISTER_GLOBAL("tvm.contrib.random.normal") -.set_body([](TVMArgs args, TVMRetValue *ret) { - RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal(); - double loc = args[0]; - double scale = args[1]; - DLTensor* out = args[2]; - entry->random_engine.SampleNormal(out, loc, scale); - }); - + }) +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform").set_body([](TVMArgs args, TVMRetValue* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + double low = args[0]; + double high = args[1]; + DLTensor* out = args[2]; + entry->random_engine.SampleUniform(out, low, high); +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRetValue* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + double loc = args[0]; + double scale = args[1]; + DLTensor* out = args[2]; + entry->random_engine.SampleNormal(out, loc, scale); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index dda4ee3..0e6f4bd 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -20,75 +20,68 @@ /*! * \file Use external rocblas library call. */ -#include -#include -#include #include "rocblas.h" +#include +#include +#include + namespace tvm { namespace contrib { using namespace runtime; #ifndef CHECK_ROCBLAS_ERROR -#define CHECK_ROCBLAS_ERROR(error) \ -if (error != rocblas_status_success) { \ - fprintf(stderr, "rocBLAS error: "); \ - if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \ - if (error == rocblas_status_not_implemented) fprintf(stderr, " rocblas_status_not_implemented"); \ - if (error == rocblas_status_invalid_pointer) fprintf(stderr, "rocblas_status_invalid_pointer"); \ - if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \ - if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \ - if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \ - fprintf(stderr, "\n"); \ - exit(EXIT_FAILURE); \ -} +#define CHECK_ROCBLAS_ERROR(error) \ + if (error != rocblas_status_success) { \ + fprintf(stderr, "rocBLAS error: "); \ + if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \ + if (error == rocblas_status_not_implemented) \ + fprintf(stderr, " rocblas_status_not_implemented"); \ + if (error == rocblas_status_invalid_pointer) \ + fprintf(stderr, "rocblas_status_invalid_pointer"); \ + if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \ + if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \ + if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \ + fprintf(stderr, "\n"); \ + exit(EXIT_FAILURE); \ + } #endif - // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* A = args[0]; - DLTensor* B = args[1]; - DLTensor* C = args[2]; - bool transa = args[3]; - bool transb = args[4]; - // call gemm for simple compact code. - CHECK_EQ(A->ndim, 2); - CHECK_EQ(B->ndim, 2); - CHECK_EQ(C->ndim, 2); - CHECK(C->strides == nullptr); - CHECK(B->strides == nullptr); - CHECK(A->strides == nullptr); - CHECK(TypeMatch(A->dtype, kDLFloat, 32)); - CHECK(TypeMatch(B->dtype, kDLFloat, 32)); - CHECK(TypeMatch(C->dtype, kDLFloat, 32)); +TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + // call gemm for simple compact code. + CHECK_EQ(A->ndim, 2); + CHECK_EQ(B->ndim, 2); + CHECK_EQ(C->ndim, 2); + CHECK(C->strides == nullptr); + CHECK(B->strides == nullptr); + CHECK(A->strides == nullptr); + CHECK(TypeMatch(A->dtype, kDLFloat, 32)); + CHECK(TypeMatch(B->dtype, kDLFloat, 32)); + CHECK(TypeMatch(C->dtype, kDLFloat, 32)); - rocblas_handle handle; - CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); - float alpha = 1.0; - float beta = 0.0; - float *A_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); - float *B_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); - float *C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); + rocblas_handle handle; + CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); + float alpha = 1.0; + float beta = 0.0; + float* A_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); + float* B_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); + float* C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); - CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle, - transb ? rocblas_operation_transpose : rocblas_operation_none, - transa ? rocblas_operation_transpose : rocblas_operation_none, - transb ? B->shape[0] : B->shape[1], - transa ? A->shape[1] : A->shape[0], - transb ? B->shape[1] : B->shape[0], - &alpha, - A_ptr, - B->shape[1], - B_ptr, - A->shape[1], - &beta, - C_ptr, - C->shape[1])); + CHECK_ROCBLAS_ERROR( + rocblas_sgemm(handle, transb ? rocblas_operation_transpose : rocblas_operation_none, + transa ? rocblas_operation_transpose : rocblas_operation_none, + transb ? B->shape[0] : B->shape[1], transa ? A->shape[1] : A->shape[0], + transb ? B->shape[1] : B->shape[0], &alpha, A_ptr, B->shape[1], B_ptr, + A->shape[1], &beta, C_ptr, C->shape[1])); - CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); + CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 0c9c575..9543e4b 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -21,8 +21,9 @@ * \file Use standard C library call. */ -#include #include +#include + #include #include @@ -31,19 +32,16 @@ namespace contrib { using namespace runtime; -template -bool CompareAscend(const std::pair& lhs, - const std::pair& rhs) { +template +bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { return lhs.second < rhs.second; } -template -bool CompareDescend(const std::pair& lhs, - const std::pair& rhs) { +template +bool CompareDescend(const std::pair& lhs, const std::pair& rhs) { return lhs.second > rhs.second; } - // Argsort implemented C library sort for nms. // Return indices of sorted tensor. // By default, the last axis will be used to sort. @@ -51,17 +49,16 @@ bool CompareDescend(const std::pair& lhs, // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *input = args[0]; - DLTensor *sort_num = args[1]; - DLTensor *output = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* input = args[0]; + DLTensor* sort_num = args[1]; + DLTensor* output = args[2]; int32_t axis = args[3]; bool is_ascend = args[4]; auto dtype = input->dtype; - auto data_ptr = static_cast(input->data); - auto sort_num_ptr = static_cast(sort_num->data); + auto data_ptr = static_cast(input->data); + auto sort_num_ptr = static_cast(sort_num->data); std::vector> sorter; int64_t axis_mul_before = 1; int64_t axis_mul_after = 1; @@ -72,13 +69,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") // Currently only supports input dtype to be float32. CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " - "to be float."; + "to be float."; #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1) CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " - "to be float32."; + "to be float32."; #endif CHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " << input->ndim; + "input ndim " + << input->ndim; for (int i = 0; i < input->ndim; ++i) { if (i < axis) { @@ -88,8 +86,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") } } - for (int64_t i = 0 ; i < axis_mul_before; ++i) { - for (int64_t j = 0 ; j < axis_mul_after; ++j) { + for (int64_t i = 0; i < axis_mul_before; ++i) { + for (int64_t j = 0; j < axis_mul_after; ++j) { sorter.clear(); int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j); int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; @@ -103,7 +101,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>); } else { #endif - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif @@ -113,24 +111,24 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>); } else { #endif - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif } for (int32_t k = 0; k < input->shape[axis]; ++k) { - *(static_cast(output->data) + base_idx + k * axis_mul_after) - = k < static_cast(sorter.size()) ? sorter[k].first : k; + *(static_cast(output->data) + base_idx + k * axis_mul_after) = + k < static_cast(sorter.size()) ? sorter[k].first : k; } } } }); -template +template void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { - auto data_ptr = static_cast(input->data); - auto out_ptr = static_cast(output->data); - std::vector > sorter; + auto data_ptr = static_cast(input->data); + auto out_ptr = static_cast(output->data); + std::vector> sorter; int axis_mul_before = 1; int axis_mul_after = 1; @@ -142,8 +140,8 @@ void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { } } - for (int i = 0 ; i < axis_mul_before; ++i) { - for (int j = 0 ; j < axis_mul_after; ++j) { + for (int i = 0; i < axis_mul_before; ++i) { + for (int j = 0; j < axis_mul_after; ++j) { sorter.clear(); int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; for (int64_t k = 0; k < input->shape[axis]; ++k) { @@ -169,17 +167,17 @@ void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *input = args[0]; - DLTensor *output = args[1]; +TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* input = args[0]; + DLTensor* output = args[1]; int32_t axis = args[2]; bool is_ascend = args[3]; if (axis < 0) { axis = input->ndim + axis; } CHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " << input->ndim; + "input ndim " + << input->ndim; auto data_dtype = DLDataType2String(input->dtype); auto out_dtype = DLDataType2String(output->dtype); @@ -228,7 +226,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - } else if (data_dtype == "int64") { + } else if (data_dtype == "int64") { if (out_dtype == "int32") { argsort(input, output, axis, is_ascend); } else if (out_dtype == "int64") { @@ -245,19 +243,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") } }); -template -void topk(DLTensor* input, - DLTensor* out_values, - DLTensor* out_indices, - int k, - int axis, +template +void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis, bool is_ascend) { - DataType* data_ptr = static_cast(input->data); - DataType* values_ptr = (out_values == nullptr) ? nullptr : - static_cast(out_values->data); - IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr : - static_cast(out_indices->data); - std::vector > sorter; + DataType* data_ptr = static_cast(input->data); + DataType* values_ptr = + (out_values == nullptr) ? nullptr : static_cast(out_values->data); + IndicesType* indices_ptr = + (out_indices == nullptr) ? nullptr : static_cast(out_indices->data); + std::vector> sorter; int axis_mul_before = 1; int axis_mul_after = 1; @@ -272,8 +266,8 @@ void topk(DLTensor* input, k = input->shape[axis]; } - for (int i = 0 ; i < axis_mul_before; ++i) { - for (int j = 0 ; j < axis_mul_after; ++j) { + for (int i = 0; i < axis_mul_before; ++i) { + for (int j = 0; j < axis_mul_after; ++j) { sorter.clear(); int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; int64_t dst_base_idx = i * k * axis_mul_after + j; @@ -290,11 +284,10 @@ void topk(DLTensor* input, for (int64_t kk = 0; kk < cnt; ++kk) { if (indices_ptr != nullptr) { indices_ptr[dst_base_idx + kk * axis_mul_after] = - static_cast(sorter[kk].first); + static_cast(sorter[kk].first); } if (values_ptr != nullptr) { - values_ptr[dst_base_idx + kk * axis_mul_after] = - static_cast(sorter[kk].second); + values_ptr[dst_base_idx + kk * axis_mul_after] = static_cast(sorter[kk].second); } } } @@ -308,8 +301,7 @@ void topk(DLTensor* input, // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* input = args[0]; DLTensor* values_out = nullptr; DLTensor* indices_out = nullptr; @@ -371,7 +363,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - } else if (data_dtype == "int64") { + } else if (data_dtype == "int64") { if (out_dtype == "int32") { topk(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "int64") { diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 56d3ce9..53d7754 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -20,53 +20,52 @@ /*! * \file tflite_runtime.cc */ -#include +#include "tflite_runtime.h" + #include #include #include - - -#include "tflite_runtime.h" +#include namespace tvm { namespace runtime { -#define TVM_DTYPE_DISPATCH(type, DType, ...) \ - if (type == DataType::Float(64)) { \ - typedef double DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(32)) { \ - typedef float DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(64)) { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(32)) { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(16)) { \ - typedef int16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(8)) { \ - typedef int8_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(64)) { \ - typedef uint64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(32)) { \ - typedef uint32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(8)) { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } else { \ - LOG(FATAL) << "unknown data type " << type; \ +#define TVM_DTYPE_DISPATCH(type, DType, ...) \ + if (type == DataType::Float(64)) { \ + typedef double DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(32)) { \ + typedef float DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(64)) { \ + typedef int64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(32)) { \ + typedef int32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(16)) { \ + typedef int16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(8)) { \ + typedef int8_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(64)) { \ + typedef uint64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(32)) { \ + typedef uint32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(8)) { \ + typedef uint8_t DType; \ + { __VA_ARGS__ } \ + } else { \ + LOG(FATAL) << "unknown data type " << type; \ } DataType TfLiteDType2TVMDType(TfLiteType dtype) { @@ -91,12 +90,11 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { } } -void TFLiteRuntime::Init(const std::string& tflite_model_bytes, - TVMContext ctx) { +void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { const char* buffer = tflite_model_bytes.c_str(); size_t buffer_size = tflite_model_bytes.size(); std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); + tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); tflite::ops::builtin::BuiltinOpResolver resolver; // Build interpreter TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_); @@ -108,24 +106,22 @@ void TFLiteRuntime::Init(const std::string& tflite_model_bytes, ctx_ = ctx; } -void TFLiteRuntime::Invoke() { - interpreter_->Invoke(); -} +void TFLiteRuntime::Invoke() { interpreter_->Invoke(); } void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { DataType dtype(data_in->dtype); TVM_DTYPE_DISPATCH(dtype, DType, { - DType* dest = interpreter_->typed_input_tensor(index); - DType* src = static_cast(data_in->data); - CHECK(data_in->strides == NULL); - int64_t size = 1; - for (int64_t i = 0; i < data_in->ndim; ++i) { - size *= data_in->shape[i]; - } - for (int64_t i = 0; i < size; ++i) { - dest[i] = src[i]; - } - }); + DType* dest = interpreter_->typed_input_tensor(index); + DType* src = static_cast(data_in->data); + CHECK(data_in->strides == NULL); + int64_t size = 1; + for (int64_t i = 0; i < data_in->ndim; ++i) { + size *= data_in->shape[i]; + } + for (int64_t i = 0; i < size; ++i) { + dest[i] = src[i]; + } + }); } NDArray TFLiteRuntime::GetOutput(int index) const { @@ -140,48 +136,42 @@ NDArray TFLiteRuntime::GetOutput(int index) const { } NDArray ret = NDArray::Empty(shape, dtype, ctx_); TVM_DTYPE_DISPATCH(dtype, DType, { - DType* dest = static_cast(ret->data); - DType* src = interpreter_->typed_output_tensor(index); - for (int64_t i = 0; i < size; ++i) { - dest[i] = src[i]; - } - }); + DType* dest = static_cast(ret->data); + DType* src = interpreter_->typed_output_tensor(index); + for (int64_t i = 0; i < size; ++i) { + dest[i] = src[i]; + } + }); return ret; } -PackedFunc TFLiteRuntime::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc TFLiteRuntime::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - int in_idx = args[0]; - CHECK_GE(in_idx, 0); - this->SetInput(in_idx, args[1]); - }); + int in_idx = args[0]; + CHECK_GE(in_idx, 0); + this->SetInput(in_idx, args[1]); + }); } else if (name == "get_output") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetOutput(args[0]); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); }); } else if (name == "invoke") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Invoke(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); }); } else { return PackedFunc(); } } -Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, - TVMContext ctx) { +Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx) { auto exec = make_object(); exec->Init(tflite_model_bytes, ctx); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = TFLiteRuntimeCreate(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = TFLiteRuntimeCreate(args[0], args[1]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index d823690..f61f6ee 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -29,9 +29,9 @@ #include #include -#include -#include #include +#include +#include namespace tvm { namespace runtime { @@ -52,18 +52,15 @@ class TFLiteRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. */ - const char* type_key() const { - return "TFLiteRuntime"; - } + const char* type_key() const { return "TFLiteRuntime"; } /*! - * \brief Invoke the internal tflite interpreter and run the whole model in + * \brief Invoke the internal tflite interpreter and run the whole model in * dependency order. */ void Invoke(); @@ -73,8 +70,7 @@ class TFLiteRuntime : public ModuleNode { * \param tflite_model_bytes The tflite model. * \param ctx The context where the tflite model will be executed on. */ - void Init(const std::string& tflite_model_bytes, - TVMContext ctx); + void Init(const std::string& tflite_model_bytes, TVMContext ctx); /*! * \brief set index-th input to the model. diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 920bdae..c70a4f2 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -22,10 +22,12 @@ */ #include #include -#include #include +#include + #include #include + #include "workspace_pool.h" #ifdef __ANDROID__ @@ -42,9 +44,7 @@ class CPUDeviceAPI final : public DeviceAPI { *rv = 1; } } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { void* ptr; #if _MSC_VER @@ -69,53 +69,38 @@ class CPUDeviceAPI final : public DeviceAPI { #endif } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { - memcpy(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); + memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); } - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {} void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; void FreeWorkspace(TVMContext ctx, void* data) final; static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } }; struct CPUWorkspacePool : public WorkspacePool { - CPUWorkspacePool() : - WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {} + CPUWorkspacePool() : WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {} }; -void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { - return dmlc::ThreadLocalStore::Get() - ->AllocWorkspace(ctx, size); +void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { + return dmlc::ThreadLocalStore::Get()->AllocWorkspace(ctx, size); } void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { dmlc::ThreadLocalStore::Get()->FreeWorkspace(ctx, data); } -TVM_REGISTER_GLOBAL("device_api.cpu") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CPUDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.cpu").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = CPUDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/crt/graph_runtime.h b/src/runtime/crt/graph_runtime.h index 3cb8ba9..fd3b146 100644 --- a/src/runtime/crt/graph_runtime.h +++ b/src/runtime/crt/graph_runtime.h @@ -27,9 +27,9 @@ #include #include "load_json.h" +#include "module.h" #include "ndarray.h" #include "packed_func.h" -#include "module.h" /*! \brief operator attributes about tvm op */ typedef struct TVMOpParam { @@ -51,7 +51,7 @@ typedef struct TVMGraphRuntimeNodeEntry { uint32_t index; uint32_t version; // JSON Loader - void (*Load)(JSONReader *reader); + void (*Load)(JSONReader* reader); } TVMGraphRuntimeNodeEntry; // Node @@ -63,26 +63,26 @@ typedef struct TVMGraphRuntimeNode { // parameters TVMOpParam param; // inputs - TVMGraphRuntimeNodeEntry * inputs; + TVMGraphRuntimeNodeEntry* inputs; // number of inputs size_t inputs_count; // control deps uint32_t control_deps[20]; // JSON Loader - void (*LoadAttrs)(struct TVMGraphRuntimeNode * node, JSONReader *reader, TVMOpParam* param); + void (*LoadAttrs)(struct TVMGraphRuntimeNode* node, JSONReader* reader, TVMOpParam* param); // JSON Loader - int (*Load)(struct TVMGraphRuntimeNode * node, JSONReader *reader); + int (*Load)(struct TVMGraphRuntimeNode* node, JSONReader* reader); } TVMGraphRuntimeNode; // Graph attribute typedef struct TVMGraphRuntimeGraphAttr { uint32_t storage_num_not_alloctaed; - uint32_t * storage_id; - uint32_t * device_index; - char * dltype; // "int8", "int16", "float32" + uint32_t* storage_id; + uint32_t* device_index; + char* dltype; // "int8", "int16", "float32" uint32_t dltype_count; - int64_t * shape; - uint32_t * ndim; + int64_t* shape; + uint32_t* ndim; uint32_t shape_count; } TVMGraphRuntimeGraphAttr; @@ -96,7 +96,7 @@ typedef DLTensor* DLTensorPtr; */ /* class GraphRuntime : public ModuleNode { */ typedef struct TVMGraphRuntime { - void (*Run)(struct TVMGraphRuntime * runtime); + void (*Run)(struct TVMGraphRuntime* runtime); /*! * \brief Initialize the graph executor with graph and context. @@ -107,10 +107,8 @@ typedef struct TVMGraphRuntime { * \param ctxs The context of the host and devices where graph nodes will be * executed on. */ - void (*Init)(struct TVMGraphRuntime * runtime, - const char * graph_json, - const TVMModule * module, - const TVMContext * ctxs); + void (*Init)(struct TVMGraphRuntime* runtime, const char* graph_json, const TVMModule* module, + const TVMContext* ctxs); /*! * \brief Get the input index given the name of input. @@ -118,7 +116,7 @@ typedef struct TVMGraphRuntime { * \param name The name of the input. * \return The index of input. */ - int (*GetInputIndex)(struct TVMGraphRuntime * runtime, const char * name); + int (*GetInputIndex)(struct TVMGraphRuntime* runtime, const char* name); /*! * \brief set input to the graph based on name. @@ -126,7 +124,7 @@ typedef struct TVMGraphRuntime { * \param name The name of the input. * \param data_in The input data. */ - void (*SetInput)(struct TVMGraphRuntime * runtime, const char * name, DLTensor* data_in); + void (*SetInput)(struct TVMGraphRuntime* runtime, const char* name, DLTensor* data_in); /*! * \brief Return NDArray for given output index. @@ -135,7 +133,7 @@ typedef struct TVMGraphRuntime { * \param out The DLTensor corresponding to given output node index. * \return The result of this function execution. */ - int (*GetOutput)(struct TVMGraphRuntime * runtime, const int32_t index, DLTensor * out); + int (*GetOutput)(struct TVMGraphRuntime* runtime, const int32_t index, DLTensor* out); /*! * \brief Load parameters from parameter blob. * \param runtime The graph runtime. @@ -143,15 +141,15 @@ typedef struct TVMGraphRuntime { * \param param_size The parameter size. * \return The result of this function execution. */ - int (*LoadParams)(struct TVMGraphRuntime * runtime, const char * param_blob, + int (*LoadParams)(struct TVMGraphRuntime* runtime, const char* param_blob, const uint32_t param_size); // The graph attribute fields. - int (*Load)(struct TVMGraphRuntime * runtime, JSONReader *reader); + int (*Load)(struct TVMGraphRuntime* runtime, JSONReader* reader); /*! \brief Setup the temporal storage */ - void (*SetupStorage)(struct TVMGraphRuntime * runtime); + void (*SetupStorage)(struct TVMGraphRuntime* runtime); /*! \brief Setup the executors. */ - int (*SetupOpExecs)(struct TVMGraphRuntime * runtime); + int (*SetupOpExecs)(struct TVMGraphRuntime* runtime); /*! * \brief Create an execution function given input. @@ -163,25 +161,25 @@ typedef struct TVMGraphRuntime { * \param pf The created executor. * \return The result of this function execution. */ - int32_t (*CreateTVMOp)(struct TVMGraphRuntime * runtime, const TVMOpParam * attrs, - DLTensorPtr * args, const uint32_t args_count, - uint32_t num_inputs, TVMPackedFunc * pf); + int32_t (*CreateTVMOp)(struct TVMGraphRuntime* runtime, const TVMOpParam* attrs, + DLTensorPtr* args, const uint32_t args_count, uint32_t num_inputs, + TVMPackedFunc* pf); // Get node entry index. - uint32_t (*GetEntryId)(struct TVMGraphRuntime * runtime, uint32_t nid, uint32_t index); + uint32_t (*GetEntryId)(struct TVMGraphRuntime* runtime, uint32_t nid, uint32_t index); /*! \brief The graph nodes. */ - TVMGraphRuntimeNode * nodes; + TVMGraphRuntimeNode* nodes; /*! \brief The graph nodes counter. */ uint32_t nodes_count; /*! \brief The argument nodes. */ - uint32_t * input_nodes; + uint32_t* input_nodes; uint32_t input_nodes_count; /*! \brief Used for quick entry indexing. */ - uint32_t * node_row_ptr; + uint32_t* node_row_ptr; uint32_t node_row_ptr_count; /*! \brief Output entries. */ - TVMGraphRuntimeNodeEntry * outputs; + TVMGraphRuntimeNodeEntry* outputs; /*! \brief Output entries counter. */ uint32_t outputs_count; /*! \brief Additional graph attributes. */ @@ -190,28 +188,28 @@ typedef struct TVMGraphRuntime { TVMModule module; /*! \brief Execution context of all devices including the host. */ TVMContext ctxs[1]; - uint32_t ctxs_count; + uint32_t ctxs_count; /*! \brief Common storage pool for all devices. */ - TVMNDArray * storage_pool; + TVMNDArray* storage_pool; uint32_t storage_pool_count; /*! \brief Data entry of each node. */ - TVMNDArray * data_entry; + TVMNDArray* data_entry; uint32_t data_entry_count; /*! \brief Operator on each node. */ - TVMPackedFunc * op_execs; + TVMPackedFunc* op_execs; uint32_t op_execs_count; } TVMGraphRuntime; // public functions -TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, const TVMModule * m, - const TVMContext * ctxs); -void TVMGraphRuntimeRelease(TVMGraphRuntime ** runtime); +TVMGraphRuntime* TVMGraphRuntimeCreate(const char* sym_json, const TVMModule* m, + const TVMContext* ctxs); +void TVMGraphRuntimeRelease(TVMGraphRuntime** runtime); // private functions -void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in); -int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob, +void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in); +int TVMGraphRuntime_LoadParams(TVMGraphRuntime* runtime, const char* param_blob, const uint32_t param_size); -void TVMGraphRuntime_Run(TVMGraphRuntime * runtime); -int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTensor * out); +void TVMGraphRuntime_Run(TVMGraphRuntime* runtime); +int TVMGraphRuntime_GetOutput(TVMGraphRuntime* runtime, const int32_t idx, DLTensor* out); #endif // TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_ diff --git a/src/runtime/crt/load_json.h b/src/runtime/crt/load_json.h index a5df7a0..0c93247 100644 --- a/src/runtime/crt/load_json.h +++ b/src/runtime/crt/load_json.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_CRT_LOAD_JSON_H_ #define TVM_RUNTIME_CRT_LOAD_JSON_H_ -#include #include +#include enum { JSON_READ_TYPE_U8 = 1, @@ -42,12 +42,12 @@ enum { }; typedef struct Seq { - uint32_t * data; + uint32_t* data; uint64_t allocated; uint32_t size; - void (*push_back)(struct Seq * seq, uint32_t src); - uint32_t * (*back)(struct Seq * seq); - void (*pop_back)(struct Seq * seq); + void (*push_back)(struct Seq* seq, uint32_t src); + uint32_t* (*back)(struct Seq* seq); + void (*pop_back)(struct Seq* seq); } Seq; /*! @@ -56,8 +56,8 @@ typedef struct Seq { */ typedef struct JSONReader { /*! \brief internal reader string */ - char * is_; - char * isptr; + char* is_; + char* isptr; /*! \brief "\\r" counter */ size_t line_count_r_; /*! \brief "\\n" counter */ @@ -66,27 +66,27 @@ typedef struct JSONReader { * \brief record how many element processed in * current array/object scope. */ - Seq * scope_counter_; + Seq* scope_counter_; - char (*NextChar)(struct JSONReader * reader); - char (*NextNonSpace)(struct JSONReader * reader); - char (*PeekNextChar)(struct JSONReader * reader); - char (*PeekNextNonSpace)(struct JSONReader * reader); - int (*ReadUnsignedInteger)(struct JSONReader * reader, unsigned int * out_value); - int (*ReadInteger)(struct JSONReader * reader, int64_t * out_value); - int (*ReadString)(struct JSONReader * reader, char * out_value); - void (*BeginArray)(struct JSONReader * reader); - void (*BeginObject)(struct JSONReader * reader); - uint8_t (*NextObjectItem)(struct JSONReader * reader, char * out_key); - uint8_t (*NextArrayItem)(struct JSONReader * reader); + char (*NextChar)(struct JSONReader* reader); + char (*NextNonSpace)(struct JSONReader* reader); + char (*PeekNextChar)(struct JSONReader* reader); + char (*PeekNextNonSpace)(struct JSONReader* reader); + int (*ReadUnsignedInteger)(struct JSONReader* reader, unsigned int* out_value); + int (*ReadInteger)(struct JSONReader* reader, int64_t* out_value); + int (*ReadString)(struct JSONReader* reader, char* out_value); + void (*BeginArray)(struct JSONReader* reader); + void (*BeginObject)(struct JSONReader* reader); + uint8_t (*NextObjectItem)(struct JSONReader* reader, char* out_key); + uint8_t (*NextArrayItem)(struct JSONReader* reader); } JSONReader; /*! * \brief Constructor of JSONReader class * \param is the input source. */ -JSONReader JSONReader_Create(const char * is); +JSONReader JSONReader_Create(const char* is); -void JSONReader_Release(JSONReader * reader); +void JSONReader_Release(JSONReader* reader); #endif // TVM_RUNTIME_CRT_LOAD_JSON_H_ diff --git a/src/runtime/crt/logging.h b/src/runtime/crt/logging.h index 2c58834..c711b3a 100644 --- a/src/runtime/crt/logging.h +++ b/src/runtime/crt/logging.h @@ -27,31 +27,31 @@ #define TVM_RUNTIME_CRT_LOGGING_H_ #ifndef CHECK -#define CHECK(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "Check failed: %s\n", #x); \ - exit(-1); \ - } \ - }while(0) +#define CHECK(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "Check failed: %s\n", #x); \ + exit(-1); \ + } \ + } while (0) #endif #ifndef CHECK_BINARY_OP -#define CHECK_BINARY_OP(op, x, y, fmt, ...) \ - do { \ - if (!(x op y)) { \ +#define CHECK_BINARY_OP(op, x, y, fmt, ...) \ + do { \ + if (!(x op y)) { \ fprintf(stderr, "Check failed: %s %s %s: " fmt "\n", #x, #op, #y, ##__VA_ARGS__); \ - exit(-1); \ - } \ - }while(0) + exit(-1); \ + } \ + } while (0) #endif #ifndef CHECK_LT -#define CHECK_LT(x, y, fmt, ...) CHECK_BINARY_OP(<, x, y, fmt, ##__VA_ARGS__) +#define CHECK_LT(x, y, fmt, ...) CHECK_BINARY_OP(<, x, y, fmt, ##__VA_ARGS__) #endif #ifndef CHECK_GT -#define CHECK_GT(x, y, fmt, ...) CHECK_BINARY_OP(>, x, y, fmt, ##__VA_ARGS__) +#define CHECK_GT(x, y, fmt, ...) CHECK_BINARY_OP(>, x, y, fmt, ##__VA_ARGS__) #endif #ifndef CHECK_LE diff --git a/src/runtime/crt/module.h b/src/runtime/crt/module.h index 9ef287d..57f8dd7 100644 --- a/src/runtime/crt/module.h +++ b/src/runtime/crt/module.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_CRT_MODULE_H_ #define TVM_RUNTIME_CRT_MODULE_H_ -#include #include +#include struct TVMPackedFunc; @@ -41,7 +41,7 @@ typedef struct TVMModule { * * This function will return PackedFunc(nullptr) if function do not exist. */ - void (*GetFunction)(struct TVMModule * mod, const char * name, struct TVMPackedFunc * pf); + void (*GetFunction)(struct TVMModule* mod, const char* name, struct TVMPackedFunc* pf); } TVMModule; #endif // TVM_RUNTIME_CRT_MODULE_H_ diff --git a/src/runtime/crt/ndarray.h b/src/runtime/crt/ndarray.h index dde23ca..ae76726 100644 --- a/src/runtime/crt/ndarray.h +++ b/src/runtime/crt/ndarray.h @@ -24,13 +24,12 @@ #ifndef TVM_RUNTIME_CRT_NDARRAY_H_ #define TVM_RUNTIME_CRT_NDARRAY_H_ -#include -#include #include - -#include #include #include +#include +#include +#include /*! \brief Magic number for NDArray file */ static const uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; @@ -42,17 +41,17 @@ typedef struct TVMNDArray { DLTensor dl_tensor; } TVMNDArray; -TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t * shape, - DLDataType dtype, DLContext ctx); +TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype, + DLContext ctx); -TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape, - DLDataType dtype, DLContext ctx); +TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype, + DLContext ctx); -int TVMNDArray_Load(TVMNDArray * ret, const char ** strm); +int TVMNDArray_Load(TVMNDArray* ret, const char** strm); -TVMNDArray TVMNDArray_CreateView(TVMNDArray * arr, const tvm_index_t * shape, - uint32_t ndim, DLDataType dtype); +TVMNDArray TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, uint32_t ndim, + DLDataType dtype); -int TVMNDArray_Release(TVMNDArray * arr); +int TVMNDArray_Release(TVMNDArray* arr); #endif // TVM_RUNTIME_CRT_NDARRAY_H_ diff --git a/src/runtime/crt/packed_func.h b/src/runtime/crt/packed_func.h index 93898a4..d4597e6 100644 --- a/src/runtime/crt/packed_func.h +++ b/src/runtime/crt/packed_func.h @@ -24,29 +24,34 @@ #ifndef TVM_RUNTIME_CRT_PACKED_FUNC_H_ #define TVM_RUNTIME_CRT_PACKED_FUNC_H_ -#include - +#include #include #include -#include +#include #include "module.h" -static inline DLDataType String2DLDataType(const char * s) { +static inline DLDataType String2DLDataType(const char* s) { DLDataType t; // handle None type if (strlen(s) == 0) { - t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle; + t.bits = 0; + t.lanes = 0; + t.code = kTVMOpaqueHandle; return t; } - t.bits = 32; t.lanes = 1; + t.bits = 32; + t.lanes = 1; const char* scan; if (!strncmp(s, "int", 3)) { - t.code = kDLInt; scan = s + 3; + t.code = kDLInt; + scan = s + 3; } else if (!strncmp(s, "uint", 4)) { - t.code = kDLUInt; scan = s + 4; + t.code = kDLUInt; + scan = s + 4; } else if (!strncmp(s, "float", 5)) { - t.code = kDLFloat; scan = s + 5; + t.code = kDLFloat; + scan = s + 5; } else if (!strncmp(s, "handle", 6)) { t.code = kTVMOpaqueHandle; t.bits = 64; // handle uses 64 bit by default. @@ -75,11 +80,11 @@ static inline DLDataType String2DLDataType(const char * s) { typedef struct TVMArgs { TVMValue values[TVM_CRT_MAX_ARGS]; - int tcodes[TVM_CRT_MAX_ARGS]; /* Data type should be identical to type_codes in TVMPackedCFunc */ + int tcodes[TVM_CRT_MAX_ARGS]; /* Data type should be identical to type_codes in TVMPackedCFunc */ uint32_t values_count; } TVMArgs; -static inline TVMArgs TVMArgs_Create(TVMValue * values, uint32_t * tcodes, uint32_t values_count) { +static inline TVMArgs TVMArgs_Create(TVMValue* values, uint32_t* tcodes, uint32_t values_count) { uint32_t idx; TVMArgs args; memset(&args, 0, sizeof(args)); @@ -91,8 +96,8 @@ static inline TVMArgs TVMArgs_Create(TVMValue * values, uint32_t * tcodes, uint3 return args; } -static inline int TVMNoOperation(TVMValue * args, int * type_codes, int num_args, - TVMRetValueHandle ret, void * res) { +static inline int TVMNoOperation(TVMValue* args, int* type_codes, int num_args, + TVMRetValueHandle ret, void* res) { return 0; } @@ -100,24 +105,24 @@ typedef struct TVMPackedFunc { char name[200]; TVMPackedCFunc fexec; TVMArgs args; - void (*Call)(struct TVMPackedFunc * pf); - void (*SetArgs)(struct TVMPackedFunc * pf, const struct TVMArgs * args); + void (*Call)(struct TVMPackedFunc* pf); + void (*SetArgs)(struct TVMPackedFunc* pf, const struct TVMArgs* args); } TVMPackedFunc; -static inline void TVMPackedFunc_Call(TVMPackedFunc * pf) { +static inline void TVMPackedFunc_Call(TVMPackedFunc* pf) { pf->fexec(pf->args.values, pf->args.tcodes, pf->args.values_count, 0, 0); } -static inline void TVMPackedFunc_SetArgs(TVMPackedFunc * pf, const TVMArgs * args) { +static inline void TVMPackedFunc_SetArgs(TVMPackedFunc* pf, const TVMArgs* args) { memcpy(&(pf->args), args, sizeof(TVMArgs)); } -TVMPackedFunc * g_fexecs = 0; +TVMPackedFunc* g_fexecs = 0; uint32_t g_fexecs_count = 0; // Implement TVMModule::GetFunction // Put implementation in this file so we have seen the TVMPackedFunc -static inline void TVMModule_GetFunction(TVMModule * mod, const char * name, TVMPackedFunc * pf) { +static inline void TVMModule_GetFunction(TVMModule* mod, const char* name, TVMPackedFunc* pf) { int idx; memset(pf, 0, sizeof(TVMPackedFunc)); assert(strlen(name) <= sizeof(pf->name)); diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h index b7d9ecb..25ff28a 100644 --- a/src/runtime/cuda/cuda_common.h +++ b/src/runtime/cuda/cuda_common.h @@ -26,7 +26,9 @@ #include #include + #include + #include "../workspace_pool.h" namespace tvm { @@ -36,18 +38,16 @@ namespace runtime { { \ CUresult result = x; \ if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \ - const char *msg; \ + const char* msg; \ cuGetErrorName(result, &msg); \ - LOG(FATAL) \ - << "CUDAError: " #x " failed with error: " << msg; \ + LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \ } \ } -#define CUDA_CALL(func) \ - { \ - cudaError_t e = (func); \ - CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ - << "CUDA: " << cudaGetErrorString(e); \ +#define CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) << "CUDA: " << cudaGetErrorString(e); \ } /*! \brief Thread local workspace */ diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index d9f03e7..a6d4a54 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -21,13 +21,14 @@ * \file cuda_device_api.cc * \brief GPU specific API */ -#include - -#include -#include #include #include +#include +#include +#include + #include + #include "cuda_common.h" namespace tvm { @@ -35,40 +36,32 @@ namespace runtime { class CUDADeviceAPI final : public DeviceAPI { public: - void SetDevice(TVMContext ctx) final { - CUDA_CALL(cudaSetDevice(ctx.device_id)); - } + void SetDevice(TVMContext ctx) final { CUDA_CALL(cudaSetDevice(ctx.device_id)); } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { int value = 0; switch (kind) { case kExist: - value = ( - cudaDeviceGetAttribute( - &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) - == cudaSuccess); + value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) == + cudaSuccess); break; case kMaxThreadsPerBlock: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)); break; } case kWarpSize: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrWarpSize, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, ctx.device_id)); break; } case kMaxSharedMemoryPerBlock: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id)); + CUDA_CALL( + cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id)); break; } case kComputeVersion: { std::ostringstream os; - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMajor, ctx.device_id)); os << value << "."; - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMinor, ctx.device_id)); os << value; *rv = os.str(); return; @@ -81,40 +74,33 @@ class CUDADeviceAPI final : public DeviceAPI { return; } case kMaxClockRate: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrClockRate, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrClockRate, ctx.device_id)); break; } case kMultiProcessorCount: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrMultiProcessorCount, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMultiProcessorCount, ctx.device_id)); break; } case kMaxThreadDimensions: { int dims[3]; - CUDA_CALL(cudaDeviceGetAttribute( - &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id)); - CUDA_CALL(cudaDeviceGetAttribute( - &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id)); - CUDA_CALL(cudaDeviceGetAttribute( - &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id)); std::stringstream ss; // use json string to return multiple int values; - ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; *rv = ss.str(); return; } - case kGcnArch: return; + case kGcnArch: + return; } *rv = value; } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { - CHECK_EQ(256 % alignment, 0U) - << "CUDA space is aligned at 256 bytes"; - void *ret; + CHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes"; + void* ret; if (ctx.device_type == kDLCPUPinned) { CUDA_CALL(cudaMallocHost(&ret, nbytes)); } else { @@ -133,14 +119,8 @@ class CUDADeviceAPI final : public DeviceAPI { } } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { cudaStream_t cu_stream = static_cast(stream); from = static_cast(from) + from_offset; @@ -156,8 +136,8 @@ class CUDADeviceAPI final : public DeviceAPI { // In case there is a copy from host mem to host mem */ if (ctx_to.device_type == kDLCPU && ctx_from.device_type == kDLCPU) { - memcpy(to, from, size); - return; + memcpy(to, from, size); + return; } if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) { @@ -165,9 +145,7 @@ class CUDADeviceAPI final : public DeviceAPI { if (ctx_from.device_id == ctx_to.device_id) { GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream); } else { - cudaMemcpyPeerAsync(to, ctx_to.device_id, - from, ctx_from.device_id, - size, cu_stream); + cudaMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, cu_stream); } } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) { CUDA_CALL(cudaSetDevice(ctx_from.device_id)); @@ -210,8 +188,7 @@ class CUDADeviceAPI final : public DeviceAPI { } void SetStream(TVMContext ctx, TVMStreamHandle stream) final { - CUDAThreadEntry::ThreadLocal() - ->stream = static_cast(stream); + CUDAThreadEntry::ThreadLocal()->stream = static_cast(stream); } void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { @@ -223,16 +200,12 @@ class CUDADeviceAPI final : public DeviceAPI { } static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } private: - static void GPUCopy(const void* from, - void* to, - size_t size, - cudaMemcpyKind kind, + static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind kind, cudaStream_t stream) { if (stream != 0) { CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream)); @@ -244,25 +217,19 @@ class CUDADeviceAPI final : public DeviceAPI { typedef dmlc::ThreadLocalStore CUDAThreadStore; -CUDAThreadEntry::CUDAThreadEntry() - : pool(kDLGPU, CUDADeviceAPI::Global()) { -} +CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {} -CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { - return CUDAThreadStore::Get(); -} +CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.gpu") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); -TVM_REGISTER_GLOBAL("device_api.cpu_pinned") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 0550712..498a9b7 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -22,19 +22,21 @@ */ #include "cuda_module.h" -#include #include #include -#include +#include + #include -#include #include +#include #include -#include "cuda_common.h" +#include + +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "cuda_common.h" namespace tvm { namespace runtime { @@ -45,8 +47,7 @@ namespace runtime { // The modules will be lazily loaded class CUDAModuleNode : public runtime::ModuleNode { public: - explicit CUDAModuleNode(std::string data, - std::string fmt, + explicit CUDAModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string cuda_source) : data_(data), fmt_(fmt), fmap_(fmap), cuda_source_(cuda_source) { @@ -62,16 +63,11 @@ class CUDAModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { - return "cuda"; - } + const char* type_key() const final { return "cuda"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cu") { @@ -79,8 +75,7 @@ class CUDAModuleNode : public runtime::ModuleNode { SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, cuda_source_); } else { - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); } @@ -112,18 +107,14 @@ class CUDAModuleNode : public runtime::ModuleNode { CUfunction func; CUresult result = cuModuleGetFunction(&func, module_[device_id], func_name.c_str()); if (result != CUDA_SUCCESS) { - const char *msg; + const char* msg; cuGetErrorName(result, &msg); - LOG(FATAL) - << "CUDAError: cuModuleGetFunction " << func_name - << " failed with error: " << msg; + LOG(FATAL) << "CUDAError: cuModuleGetFunction " << func_name << " failed with error: " << msg; } return func; } // get a global var from primary context in device_id - CUdeviceptr GetGlobal(int device_id, - const std::string& global_name, - size_t expect_nbytes) { + CUdeviceptr GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) { std::lock_guard lock(mutex_); // must recheck under the lock scope if (module_[device_id] == nullptr) { @@ -132,15 +123,12 @@ class CUDAModuleNode : public runtime::ModuleNode { CUdeviceptr global; size_t nbytes; - CUresult result = cuModuleGetGlobal(&global, &nbytes, - module_[device_id], global_name.c_str()); + CUresult result = cuModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str()); CHECK_EQ(nbytes, expect_nbytes); if (result != CUDA_SUCCESS) { - const char *msg; + const char* msg; cuGetErrorName(result, &msg); - LOG(FATAL) - << "CUDAError: cuModuleGetGlobal " << global_name - << " failed with error: " << msg; + LOG(FATAL) << "CUDAError: cuModuleGetGlobal " << global_name << " failed with error: " << msg; } return global; } @@ -164,11 +152,8 @@ class CUDAModuleNode : public runtime::ModuleNode { class CUDAWrappedFunc { public: // initialize the CUDA function. - void Init(CUDAModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, - size_t num_void_args, - const std::vector& thread_axis_tags) { + void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_void_args, const std::vector& thread_axis_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; @@ -176,9 +161,7 @@ class CUDAWrappedFunc { thread_axis_cfg_.Init(num_void_args, thread_axis_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - void** void_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); if (fcache_[device_id] == nullptr) { @@ -186,24 +169,17 @@ class CUDAWrappedFunc { } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - CUresult result = cuLaunchKernel( - fcache_[device_id], - wl.grid_dim(0), - wl.grid_dim(1), - wl.grid_dim(2), - wl.block_dim(0), - wl.block_dim(1), - wl.block_dim(2), - 0, strm, void_args, 0); + CUresult result = + cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), + wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), 0, strm, void_args, 0); if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { - const char *msg; + const char* msg; cuGetErrorName(result, &msg); std::ostringstream os; os << "CUDALaunch Error: " << msg << "\n" - << " grid=(" << wl.grid_dim(0) << "," - << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " - << " block=(" << wl.block_dim(0) << "," - << wl.block_dim(1) << "," << wl.block_dim(2) << ")\n"; + << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " + << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2) + << ")\n"; std::string cuda = m_->GetSource(""); if (cuda.length() != 0) { os << "// func_name=" << func_name_ << "\n" @@ -231,9 +207,7 @@ class CUDAWrappedFunc { class CUDAPrepGlobalBarrier { public: - CUDAPrepGlobalBarrier(CUDAModuleNode* m, - ObjectPtr sptr) - : m_(m), sptr_(sptr) { + CUDAPrepGlobalBarrier(CUDAModuleNode* m, ObjectPtr sptr) : m_(m), sptr_(sptr) { std::fill(pcache_.begin(), pcache_.end(), 0); } @@ -241,8 +215,8 @@ class CUDAPrepGlobalBarrier { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); if (pcache_[device_id] == 0) { - pcache_[device_id] = m_->GetGlobal( - device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned)); + pcache_[device_id] = + m_->GetGlobal(device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned)); } CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1)); } @@ -256,12 +230,10 @@ class CUDAPrepGlobalBarrier { mutable std::array pcache_; }; -PackedFunc CUDAModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc CUDAModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; if (name == symbol::tvm_prepare_global_barrier) { return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self)); } @@ -273,18 +245,15 @@ PackedFunc CUDAModuleNode::GetFunction( return PackFuncVoidAddr(f, info.arg_types); } -Module CUDAModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string cuda_source) { +Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source) { auto n = make_object(data, fmt, fmap, cuda_source); return Module(n); } // Load module from module. -Module CUDAModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module CUDAModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -305,13 +274,10 @@ Module CUDAModuleLoadBinary(void* strm) { return CUDAModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin") -.set_body_typed(CUDAModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx") -.set_body_typed(CUDAModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda") -.set_body_typed(CUDAModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.h b/src/runtime/cuda/cuda_module.h index bce0d63..e65c5fe 100644 --- a/src/runtime/cuda/cuda_module.h +++ b/src/runtime/cuda/cuda_module.h @@ -25,10 +25,12 @@ #define TVM_RUNTIME_CUDA_CUDA_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -45,11 +47,9 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmap The map function information map of each function. * \param cuda_source Optional, cuda source file */ -Module CUDAModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string cuda_source); +Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_ diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 378f976..6d3eec4 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -21,10 +21,11 @@ * \file dso_libary.cc * \brief Create library module to load from dynamic shared library. */ -#include #include -#include +#include #include +#include + #include "library_module.h" #if defined(_WIN32) @@ -43,13 +44,9 @@ class DSOLibrary final : public Library { ~DSOLibrary() { if (lib_handle_) Unload(); } - void Init(const std::string& name) { - Load(name); - } + void Init(const std::string& name) { Load(name); } - void* GetSymbol(const char* name) final { - return GetSymbol_(name); - } + void* GetSymbol(const char* name) final { return GetSymbol_(name); } private: // Platform dependent handling. @@ -58,8 +55,7 @@ class DSOLibrary final : public Library { HMODULE lib_handle_{nullptr}; void* GetSymbol_(const char* name) { - return reinterpret_cast( - GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) + return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) } // Load the library @@ -67,8 +63,7 @@ class DSOLibrary final : public Library { // use wstring version that is needed by LLVM. std::wstring wname(name.begin(), name.end()); lib_handle_ = LoadLibraryW(wname.c_str()); - CHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name; + CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; } void Unload() { @@ -81,14 +76,11 @@ class DSOLibrary final : public Library { // load the library void Load(const std::string& name) { lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - CHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name - << " " << dlerror(); + CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " + << dlerror(); } - void* GetSymbol_(const char* name) { - return dlsym(lib_handle_, name); - } + void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } void Unload() { dlclose(lib_handle_); @@ -97,11 +89,10 @@ class DSOLibrary final : public Library { #endif }; -TVM_REGISTER_GLOBAL("runtime.module.loadfile_so") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->Init(args[0]); - *rv = CreateModuleFromLibrary(n); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + n->Init(args[0]); + *rv = CreateModuleFromLibrary(n); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_util.cc b/src/runtime/file_util.cc index f94b2d3..68d174e 100644 --- a/src/runtime/file_util.cc +++ b/src/runtime/file_util.cc @@ -20,13 +20,15 @@ /*! * \file file_util.cc */ +#include "file_util.h" + #include #include #include + #include -#include #include -#include "file_util.h" +#include namespace tvm { namespace runtime { @@ -69,8 +71,7 @@ bool FunctionInfo::Load(dmlc::Stream* reader) { return true; } -std::string GetFileFormat(const std::string& file_name, - const std::string& format) { +std::string GetFileFormat(const std::string& file_name, const std::string& format) { std::string fmt = format; if (fmt.length() == 0) { size_t pos = file_name.find_last_of("."); @@ -103,7 +104,7 @@ std::string GetFileBasename(const std::string& file_name) { } std::string GetMetaFilePath(const std::string& file_name) { - size_t pos = file_name.find_last_of("."); + size_t pos = file_name.find_last_of("."); if (pos != std::string::npos) { return file_name.substr(0, pos) + ".tvm_meta.json"; } else { @@ -111,8 +112,7 @@ std::string GetMetaFilePath(const std::string& file_name) { } } -void LoadBinaryFromFile(const std::string& file_name, - std::string* data) { +void LoadBinaryFromFile(const std::string& file_name, std::string* data) { std::ifstream fs(file_name, std::ios::in | std::ios::binary); CHECK(!fs.fail()) << "Cannot open " << file_name; // get its size: @@ -123,17 +123,14 @@ void LoadBinaryFromFile(const std::string& file_name, fs.read(&(*data)[0], size); } -void SaveBinaryToFile( - const std::string& file_name, - const std::string& data) { +void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); CHECK(!fs.fail()) << "Cannot open " << file_name; fs.write(&data[0], data.length()); } -void SaveMetaDataToFile( - const std::string& file_name, - const std::unordered_map& fmap) { +void SaveMetaDataToFile(const std::string& file_name, + const std::unordered_map& fmap) { std::string version = "0.1.0"; std::ofstream fs(file_name.c_str()); CHECK(!fs.fail()) << "Cannot open file " << file_name; @@ -145,9 +142,8 @@ void SaveMetaDataToFile( fs.close(); } -void LoadMetaDataFromFile( - const std::string& file_name, - std::unordered_map* fmap) { +void LoadMetaDataFromFile(const std::string& file_name, + std::unordered_map* fmap) { std::ifstream fs(file_name.c_str()); CHECK(!fs.fail()) << "Cannot open file " << file_name; std::string version; @@ -159,9 +155,7 @@ void LoadMetaDataFromFile( fs.close(); } -void RemoveFile(const std::string& file_name) { - std::remove(file_name.c_str()); -} +void RemoveFile(const std::string& file_name) { std::remove(file_name.c_str()); } } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_util.h b/src/runtime/file_util.h index dfbaa16..1c35035 100644 --- a/src/runtime/file_util.h +++ b/src/runtime/file_util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,6 +26,7 @@ #include #include + #include "meta_data.h" namespace tvm { @@ -35,8 +36,7 @@ namespace runtime { * \param file_name The name of the file. * \param format The format of the file. */ -std::string GetFileFormat(const std::string& file_name, - const std::string& format); +std::string GetFileFormat(const std::string& file_name, const std::string& format); /*! * \return the directory in which TVM stores cached files. @@ -62,34 +62,30 @@ std::string GetFileBasename(const std::string& file_name); * \param file_name The name of the file. * \param data The data to be loaded. */ -void LoadBinaryFromFile(const std::string& file_name, - std::string* data); +void LoadBinaryFromFile(const std::string& file_name, std::string* data); /*! * \brief Load binary file into a in-memory buffer. * \param file_name The name of the file. * \param data The binary data to be saved. */ -void SaveBinaryToFile(const std::string& file_name, - const std::string& data); +void SaveBinaryToFile(const std::string& file_name, const std::string& data); /*! * \brief Save meta data to file. * \param file_name The name of the file. * \param fmap The function info map. */ -void SaveMetaDataToFile( - const std::string& file_name, - const std::unordered_map& fmap); +void SaveMetaDataToFile(const std::string& file_name, + const std::unordered_map& fmap); /*! * \brief Load meta data to file. * \param file_name The name of the file. * \param fmap The function info map. */ -void LoadMetaDataFromFile( - const std::string& file_name, - std::unordered_map* fmap); +void LoadMetaDataFromFile(const std::string& file_name, + std::unordered_map* fmap); /*! * \brief Remove (unlink) a file. diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 1c85de8..9f206fd 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -20,12 +20,13 @@ /*! * \file graph_runtime_debug.cc */ +#include #include #include -#include #include #include + #include "../graph_runtime.h" namespace tvm { @@ -59,15 +60,14 @@ class GraphRuntimeDebug : public GraphRuntime { std::ostringstream os; std::vector time_per_op(op_execs_.size(), 0); for (int i = 0; i < repeat; ++i) { - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; + std::chrono::time_point tbegin, + tend; double duration_ms = 0.0; do { std::fill(time_per_op.begin(), time_per_op.end(), 0); if (duration_ms > 0.0) { - number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random + number = static_cast(std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random } tbegin = std::chrono::high_resolution_clock::now(); for (int k = 0; k < number; k++) { @@ -78,15 +78,17 @@ class GraphRuntimeDebug : public GraphRuntime { op_execs_[index](); TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); auto op_tend = std::chrono::high_resolution_clock::now(); - double op_duration = std::chrono::duration_cast< - std::chrono::duration >(op_tend - op_tbegin).count(); + double op_duration = + std::chrono::duration_cast >(op_tend - op_tbegin) + .count(); time_per_op[index] += op_duration * 1e6; // us } } } tend = std::chrono::high_resolution_clock::now(); - duration_ms = std::chrono::duration_cast > - (tend - tbegin).count() * 1000; + duration_ms = + std::chrono::duration_cast >(tend - tbegin).count() * + 1000; } while (duration_ms < min_repeat_ms); LOG(INFO) << "Iteration: " << i; @@ -94,8 +96,8 @@ class GraphRuntimeDebug : public GraphRuntime { for (size_t index = 0; index < time_per_op.size(); index++) { if (op_execs_[index]) { time_per_op[index] /= number; - LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": " - << time_per_op[index] << " us/iter"; + LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": " << time_per_op[index] + << " us/iter"; } } } @@ -110,17 +112,14 @@ class GraphRuntimeDebug : public GraphRuntime { * \param index The index of op which needs to be returned. * \param eid The Entry id of the op. */ - NDArray GetOutputByLayer(int index, int eid) { - return data_entry_[entry_id(index, eid)]; - } + NDArray GetOutputByLayer(int index, int eid) { return data_entry_[entry_id(index, eid)]; } /*! * \brief GetFunction Get the function based on input. * \param name The function which needs to be invoked. * \param sptr_to_self Packed function pointer. */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \brief Get the node index given the name of node. @@ -135,53 +134,51 @@ class GraphRuntimeDebug : public GraphRuntime { } LOG(FATAL) << "cannot find " << name << " among nodex"; return -1; -} + } -/*! - * \brief Copy index-th node to data_out. - * - * This method will do a partial run of the the graph - * from begining upto the index-th node and return output of index-th node. - * This is costly operation and suggest to use only for debug porpose. - * - * \param index: The index of the node. - * \param data_out the node data. - */ -void DebugGetNodeOutput(int index, DLTensor* data_out) { - CHECK_LT(static_cast(index), op_execs_.size()); - uint32_t eid = index; + /*! + * \brief Copy index-th node to data_out. + * + * This method will do a partial run of the the graph + * from begining upto the index-th node and return output of index-th node. + * This is costly operation and suggest to use only for debug porpose. + * + * \param index: The index of the node. + * \param data_out the node data. + */ + void DebugGetNodeOutput(int index, DLTensor* data_out) { + CHECK_LT(static_cast(index), op_execs_.size()); + uint32_t eid = index; - for (size_t i = 0; i < op_execs_.size(); ++i) { - if (op_execs_[i]) op_execs_[i](); - if (static_cast(i) == index) break; - } + for (size_t i = 0; i < op_execs_.size(); ++i) { + if (op_execs_[i]) op_execs_[i](); + if (static_cast(i) == index) break; + } - data_entry_[eid].CopyTo(data_out); -} + data_entry_[eid].CopyTo(data_out); + } }; - /*! * \brief GetFunction Get the function based on input. * \param name The function which needs to be invoked. * \param sptr_to_self Packed function pointer. */ -PackedFunc GraphRuntimeDebug::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc GraphRuntimeDebug::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // return member functions during query. if (name == "get_output_by_layer") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetOutputByLayer(args[0], args[1]); - }); + *rv = this->GetOutputByLayer(args[0], args[1]); + }); } else if (name == "debug_get_output") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { - this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]); - } else { - this->DebugGetNodeOutput(args[0], args[1]); - } - }); + if (args[0].type_code() == kTVMStr) { + this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]); + } else { + this->DebugGetNodeOutput(args[0], args[1]); + } + }); } else if (name == "run_individual") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { int number = args[0]; @@ -203,21 +200,18 @@ PackedFunc GraphRuntimeDebug::GetFunction( * \param m Compiled module which will be loaded. * \param ctxs All devices contexts. */ -Module GraphRuntimeDebugCreate(const std::string& sym_json, - const tvm::runtime::Module& m, +Module GraphRuntimeDebugCreate(const std::string& sym_json, const tvm::runtime::Module& m, const std::vector& ctxs) { auto exec = make_object(); exec->Init(sym_json, m, ctxs); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create") -.set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) - << "The expected number of arguments for graph_runtime.create is " - "at least 4, but it has " - << args.num_args; - *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args)); - }); +TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args)); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index f3bcedf..239e43d 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -20,6 +20,8 @@ /*! * \file graph_runtime.cc */ +#include "graph_runtime.h" + #include #include #include @@ -35,8 +37,6 @@ #include #include -#include "graph_runtime.h" - namespace tvm { namespace runtime { namespace details { @@ -64,8 +64,7 @@ void GraphRuntime::Run() { * \param ctxs The context of the host and devices where graph nodes will be * executed on. */ -void GraphRuntime::Init(const std::string& graph_json, - tvm::runtime::Module module, +void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module, const std::vector& ctxs) { std::istringstream is(graph_json); dmlc::JSONReader reader(&is); @@ -133,9 +132,7 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) { * * \return The number of outputs from graph. */ -int GraphRuntime::NumOutputs() const { - return outputs_.size(); -} +int GraphRuntime::NumOutputs() const { return outputs_.size(); } /*! * \brief Return NDArray for given input index. * \param index The input index. @@ -188,21 +185,16 @@ void GraphRuntime::LoadParams(const std::string& param_blob) { void GraphRuntime::LoadParams(dmlc::Stream* strm) { uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid parameters file format"; - CHECK(header == kTVMNDArrayListMagic) - << "Invalid parameters file format"; - CHECK(strm->Read(&reserved)) - << "Invalid parameters file format"; + CHECK(strm->Read(&header)) << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; std::vector names; - CHECK(strm->Read(&names)) - << "Invalid parameters file format"; + CHECK(strm->Read(&names)) << "Invalid parameters file format"; uint64_t sz; strm->Read(&sz); size_t size = static_cast(sz); - CHECK(size == names.size()) - << "Invalid parameters file format"; + CHECK(size == names.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { int in_idx = GetInputIndex(names[i]); CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i]; @@ -217,13 +209,10 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { } void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { - uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid parameters file format"; - CHECK(header == kTVMNDArrayListMagic) - << "Invalid parameters file format"; - CHECK(strm->Read(&reserved)) - << "Invalid parameters file format"; + uint64_t header, reserved; + CHECK(strm->Read(&header)) << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; std::vector names; CHECK(strm->Read(&names)) << "Invalid parameters file format"; uint64_t sz; @@ -268,15 +257,14 @@ void GraphRuntime::SetupStorage() { CHECK_GE(storage_id, 0) << "Do not support runtime shape op"; DLDataType t = vtype[i]; size_t bits = t.bits * t.lanes; - CHECK(bits % 8U == 0U || bits ==1U); + CHECK(bits % 8U == 0U || bits == 1U); size_t bytes = ((bits + 7U) / 8U) * size; uint32_t sid = static_cast(storage_id); if (sid >= pool_entry.size()) { pool_entry.resize(sid + 1, {0, -1}); } else { - CHECK(pool_entry[sid].device_type == -1 || - pool_entry[sid].device_type == device_type) + CHECK(pool_entry[sid].device_type == -1 || pool_entry[sid].device_type == device_type) << "The same pool entry cannot be assigned to multiple devices"; } pool_entry[sid].size = std::max(pool_entry[sid].size, bytes); @@ -288,14 +276,12 @@ void GraphRuntime::SetupStorage() { std::vector shape; // This for loop is very fast since there are usually only a couple of // devices available on the same hardware. - const auto& cit = - std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) { - return pit.device_type == static_cast(c.device_type); - }); + const auto& cit = std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) { + return pit.device_type == static_cast(c.device_type); + }); TVMContext ctx = cit == ctxs_.end() ? ctxs_[0] : *cit; shape.push_back(static_cast(pit.size + 3) / 4); - storage_pool_.push_back( - NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx)); + storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx)); } // Assign the pooled entries. A unified memory pool is used to simplifiy @@ -306,8 +292,7 @@ void GraphRuntime::SetupStorage() { for (size_t i = 0; i < data_entry_.size(); ++i) { int storage_id = attrs_.storage_id[i]; CHECK_LT(static_cast(storage_id), storage_pool_.size()); - data_entry_[i] = - storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]); + data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]); const DLTensor* tmp = data_entry_[i].operator->(); data_alignment_[i] = details::GetDataAlignment(*tmp); } @@ -338,24 +323,20 @@ void GraphRuntime::SetupOpExecs() { CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op"; std::shared_ptr op_args = nullptr; - std::tie(op_execs_[nid], op_args) = - CreateTVMOp(inode.param, args, inode.inputs.size()); + std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args, inode.inputs.size()); for (size_t i = 0; i < inode.inputs.size(); i++) { uint32_t eid = this->entry_id(inode.inputs[i]); // check if op input is model input if (input_node_eids.count(eid) > 0) { - input_dltensors_[eid].push_back( - static_cast(op_args->arg_values[i].v_handle)); + input_dltensors_[eid].push_back(static_cast(op_args->arg_values[i].v_handle)); } } } } std::pair, std::shared_ptr > GraphRuntime::CreateTVMOp( - const TVMOpParam& param, - const std::vector& args, - size_t num_inputs) { + const TVMOpParam& param, const std::vector& args, size_t num_inputs) { std::shared_ptr arg_ptr = std::make_shared(); // setup address. arg_ptr->args = args; @@ -369,15 +350,15 @@ std::pair, std::shared_ptr > GraphRu arg_ptr->arg_values.push_back(v); arg_ptr->arg_tcodes.push_back(kTVMDLTensorHandle); if (param.flatten_data) { - arg_ptr->shape_data[i] = std::accumulate( - t->shape, t->shape + t->ndim, 1, std::multiplies()); + arg_ptr->shape_data[i] = + std::accumulate(t->shape, t->shape + t->ndim, 1, std::multiplies()); t->ndim = 1; t->shape = &(arg_ptr->shape_data[i]); } } if (param.func_name == "__nop") { - return {[](){}, arg_ptr}; + return {[]() {}, arg_ptr}; } else if (param.func_name == "__copy") { // Perform cross device data copy. // Directly copy data from the input to the output. @@ -396,27 +377,25 @@ std::pair, std::shared_ptr > GraphRu auto fexec = [arg_ptr, pf]() { TVMRetValue rv; - TVMArgs targs(arg_ptr->arg_values.data(), - arg_ptr->arg_tcodes.data(), + TVMArgs targs(arg_ptr->arg_values.data(), arg_ptr->arg_tcodes.data(), static_cast(arg_ptr->arg_values.size())); pf.CallPacked(targs, &rv); }; return {fexec, arg_ptr}; } -PackedFunc GraphRuntime::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc GraphRuntime::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { - int in_idx = this->GetInputIndex(args[0]); - if (in_idx >= 0) this->SetInput(in_idx, args[1]); - } else { - this->SetInput(args[0], args[1]); - } - }); + if (args[0].type_code() == kTVMStr) { + int in_idx = this->GetInputIndex(args[0]); + if (in_idx >= 0) this->SetInput(in_idx, args[1]); + } else { + this->SetInput(args[0], args[1]); + } + }); } else if (name == "set_input_zero_copy") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { if (args[0].type_code() == kTVMStr) { @@ -436,42 +415,38 @@ PackedFunc GraphRuntime::GetFunction( }); } else if (name == "get_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - int in_idx = 0; - if (args[0].type_code() == kTVMStr) { - in_idx = this->GetInputIndex(args[0]); - } else { - in_idx = args[0]; - } - CHECK_GE(in_idx, 0); - *rv = this->GetInput(in_idx); - }); + int in_idx = 0; + if (args[0].type_code() == kTVMStr) { + in_idx = this->GetInputIndex(args[0]); + } else { + in_idx = args[0]; + } + CHECK_GE(in_idx, 0); + *rv = this->GetInput(in_idx); + }); } else if (name == "get_num_outputs") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->NumOutputs(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); } else if (name == "run") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Run(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); } else if (name == "load_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->LoadParams(args[0].operator std::string()); - }); + this->LoadParams(args[0].operator std::string()); + }); } else if (name == "share_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - const auto& module = args[0].operator Module(); - CHECK_EQ(module.operator->()->type_key(), "GraphRuntime"); - const auto& param_blob = args[1].operator std::string(); - dmlc::MemoryStringStream strm(const_cast(¶m_blob)); - this->ShareParams(dynamic_cast(*module.operator->()), &strm); - }); + const auto& module = args[0].operator Module(); + CHECK_EQ(module.operator->()->type_key(), "GraphRuntime"); + const auto& param_blob = args[1].operator std::string(); + dmlc::MemoryStringStream strm(const_cast(¶m_blob)); + this->ShareParams(dynamic_cast(*module.operator->()), &strm); + }); } else { return PackedFunc(); } } -Module GraphRuntimeCreate(const std::string& sym_json, - const tvm::runtime::Module& m, +Module GraphRuntimeCreate(const std::string& sym_json, const tvm::runtime::Module& m, const std::vector& ctxs) { auto exec = make_object(); exec->Init(sym_json, m, ctxs); @@ -497,14 +472,12 @@ std::vector GetAllContext(const TVMArgs& args) { // execution support yet. For heterogenenous execution, at least 5 arguments will // be passed in. The third one is the number of devices. // Eventually, we will only probably pass TVMContext for all the languages. -TVM_REGISTER_GLOBAL("tvm.graph_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) - << "The expected number of arguments for graph_runtime.create is " - "at least 4, but it has " - << args.num_args; - const auto& contexts = GetAllContext(args); - *rv = GraphRuntimeCreate(args[0], args[1], contexts); - }); +TVM_REGISTER_GLOBAL("tvm.graph_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + const auto& contexts = GetAllContext(args); + *rv = GraphRuntimeCreate(args[0], args[1], contexts); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index b787c0a..d0c9822 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -26,26 +26,25 @@ #define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_ #include -#include #include +#include #include #include #include +#include #include #include #include -#include namespace tvm { namespace runtime { /*! \brief macro to do C API call */ -#define TVM_CCALL(func) \ - { \ - int ret = (func); \ - CHECK_EQ(ret, 0) \ - << TVMGetLastError(); \ +#define TVM_CCALL(func) \ + { \ + int ret = (func); \ + CHECK_EQ(ret, 0) << TVMGetLastError(); \ } /*! \brief Magic number for NDArray list file */ @@ -80,15 +79,12 @@ class TVM_DLL GraphRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. */ - const char* type_key() const final { - return "GraphRuntime"; - } + const char* type_key() const final { return "GraphRuntime"; } void Run(); /*! @@ -100,8 +96,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { * executed on. */ - void Init(const std::string& graph_json, - tvm::runtime::Module module, + void Init(const std::string& graph_json, tvm::runtime::Module module, const std::vector& ctxs); /*! @@ -172,14 +167,9 @@ class TVM_DLL GraphRuntime : public ModuleNode { * \brief Get total number of nodes. * \return Total number of nodes. */ - uint32_t GetNumOfNodes() const { - return static_cast(nodes_.size()); - } - - std::string GetNodeName(uint32_t nid) const { - return nodes_[nid].name; - } + uint32_t GetNumOfNodes() const { return static_cast(nodes_.size()); } + std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; } protected: // Memory pool entry. @@ -194,7 +184,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { uint32_t index; uint32_t version; // JSON Loader - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginArray(); CHECK(reader->NextArrayItem()) << "invalid json format"; reader->Read(&node_id); @@ -221,7 +211,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { // control deps std::vector control_deps; // JSON Loader - void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) { + void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) { int bitmask = 0; std::string key, value; reader->BeginObject(); @@ -241,10 +231,10 @@ class TVM_DLL GraphRuntime : public ModuleNode { bitmask |= 8; } } - CHECK_EQ(bitmask, 1|2|4|8) << "invalid format"; + CHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "invalid format"; } // JSON Loader - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginObject(); int bitmask = 0; std::string key; @@ -266,7 +256,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { LOG(FATAL) << "do not support key " << key; } } - CHECK_EQ(bitmask, 1|2|4) << "invalid format"; + CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format"; } }; struct GraphAttr { @@ -274,9 +264,9 @@ class TVM_DLL GraphRuntime : public ModuleNode { std::vector storage_id; std::vector device_index; std::vector dltype; - std::vector > shape; + std::vector> shape; // The graph attribute fields. - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginObject(); int bitmask = 0; std::string key, type; @@ -334,37 +324,37 @@ class TVM_DLL GraphRuntime : public ModuleNode { CHECK(!reader->NextArrayItem()); } } - CHECK_EQ(bitmask, 1|2|4) << "invalid format"; + CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format"; } }; // The graph attribute fields. - void Load(dmlc::JSONReader *reader) { - reader->BeginObject(); - int bitmask = 0; - std::string key; - while (reader->NextObjectItem(&key)) { - if (key == "nodes") { - reader->Read(&nodes_); - bitmask |= 1; - } else if (key == "arg_nodes") { - reader->Read(&input_nodes_); - bitmask |= 2; - } else if (key == "node_row_ptr") { - reader->Read(&node_row_ptr_); - bitmask |= 4; - } else if (key == "heads") { - reader->Read(&outputs_); - bitmask |= 8; - } else if (key == "attrs") { - reader->Read(&attrs_); - bitmask |= 16; - } else if (key == "metadata") { - break; - } else { - LOG(FATAL) << "key " << key << " is not supported"; - } + void Load(dmlc::JSONReader* reader) { + reader->BeginObject(); + int bitmask = 0; + std::string key; + while (reader->NextObjectItem(&key)) { + if (key == "nodes") { + reader->Read(&nodes_); + bitmask |= 1; + } else if (key == "arg_nodes") { + reader->Read(&input_nodes_); + bitmask |= 2; + } else if (key == "node_row_ptr") { + reader->Read(&node_row_ptr_); + bitmask |= 4; + } else if (key == "heads") { + reader->Read(&outputs_); + bitmask |= 8; + } else if (key == "attrs") { + reader->Read(&attrs_); + bitmask |= 16; + } else if (key == "metadata") { + break; + } else { + LOG(FATAL) << "key " << key << " is not supported"; } - CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format"; + } + CHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format"; } /*! \brief Setup the temporal storage */ void SetupStorage(); @@ -377,21 +367,14 @@ class TVM_DLL GraphRuntime : public ModuleNode { * \param num_inputs Number of inputs. * \return The created executor. */ - std::pair, std::shared_ptr > CreateTVMOp( - const TVMOpParam& attrs, const std::vector& args, - size_t num_inputs); + std::pair, std::shared_ptr> CreateTVMOp( + const TVMOpParam& attrs, const std::vector& args, size_t num_inputs); // Get node entry index. - uint32_t entry_id(uint32_t nid, uint32_t index) const { - return node_row_ptr_[nid] + index; - } + uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; } // Get node entry index. - uint32_t entry_id(const NodeEntry& e) const { - return entry_id(e.node_id, e.index); - } + uint32_t entry_id(const NodeEntry& e) const { return entry_id(e.node_id, e.index); } // Number of node entries. - uint32_t num_node_entries() const { - return node_row_ptr_.back(); - } + uint32_t num_node_entries() const { return node_row_ptr_.back(); } /*! \brief The graph nodes. */ std::vector nodes_; /*! \brief The argument nodes. */ @@ -417,7 +400,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { /*! \brief Data alignment of each node. */ std::vector data_alignment_; /*! \brief Operator on each node. */ - std::vector > op_execs_; + std::vector> op_execs_; }; std::vector GetAllContext(const TVMArgs& args); diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index d88e6d7..fd6f323 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -33,21 +33,17 @@ class HexagonDeviceAPI : public DeviceAPI { public: void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t num_bytes, TVMContext ctx_from, - TVMContext ctx_to, DLDataType type_hint, - TVMStreamHandle stream) final; + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; - void* AllocWorkspace(TVMContext ctx, size_t nbytes, - DLDataType type_hint = {}) final; + void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}) final; void FreeWorkspace(TVMContext ctx, void* ptr) final; static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } }; @@ -56,13 +52,11 @@ class HexagonDeviceAPI : public DeviceAPI { inline void HexagonDeviceAPI::SetDevice(TVMContext ctx) {} -inline void HexagonDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, - TVMRetValue* rv) { +inline void HexagonDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { if (kind == kExist) *rv = 1; } -inline void* HexagonDeviceAPI::AllocDataSpace(TVMContext ctx, size_t nbytes, - size_t alignment, +inline void* HexagonDeviceAPI::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) { CHECK(hexagon::Device::ValidateDeviceId(ctx.device_id)); return hexagon::Device::Global()->Alloc(nbytes, alignment); @@ -73,10 +67,10 @@ inline void HexagonDeviceAPI::FreeDataSpace(TVMContext ctx, void* ptr) { hexagon::Device::Global()->Free(ptr); } -inline void HexagonDeviceAPI::CopyDataFromTo( - const void* from, size_t from_offset, void* to, size_t to_offset, - size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, - DLDataType type_hint, TVMStreamHandle stream) { +inline void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t num_bytes, + TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) { const char* src = static_cast(from) + from_offset; char* dst = static_cast(to) + to_offset; @@ -110,11 +104,9 @@ inline void HexagonDeviceAPI::CopyDataFromTo( } } -inline void HexagonDeviceAPI::StreamSync(TVMContext ctx, - TVMStreamHandle stream) {} +inline void HexagonDeviceAPI::StreamSync(TVMContext ctx, TVMStreamHandle stream) {} -inline void* HexagonDeviceAPI::AllocWorkspace(TVMContext ctx, size_t nbytes, - DLDataType type_hint) { +inline void* HexagonDeviceAPI::AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint) { CHECK(hexagon::Device::ValidateDeviceId(ctx.device_id)); if (type_hint.code == 100) { size_t align = std::min(nbytes, 2048lu); @@ -128,11 +120,10 @@ inline void HexagonDeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { DeviceAPI::FreeWorkspace(ctx, ptr); } -TVM_REGISTER_GLOBAL("device_api.hexagon") - .set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = HexagonDeviceAPI::Global().get(); - *rv = ptr; - }); +TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = HexagonDeviceAPI::Global().get(); + *rv = ptr; +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index e148436..f76ac16 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -176,8 +176,7 @@ void ArgLayout::Push(uint32_t* v, unsigned t_size, unsigned t_align) { if (!InReg) { // Allocate on stack. - CHECK_EQ((t_align & (t_align - 1)), 0) - << "Alignment should be a power of 2"; + CHECK_EQ((t_align & (t_align - 1)), 0) << "Alignment should be a power of 2"; CHECK_GE(t_align, 4) << "Alignment should be at least 4"; // Round t_size up to a multiple of 4. unsigned s_size = Stack.size(); @@ -193,9 +192,8 @@ void ArgLayout::Push(uint32_t* v, unsigned t_size, unsigned t_align) { class HexagonModuleNode final : public runtime::ModuleNode { public: HexagonModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi) : hexagon_device_(hexagon::Device::Global()), data_(data), @@ -214,13 +212,11 @@ class HexagonModuleNode final : public runtime::ModuleNode { } } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; const char* type_key() const final { return "hexagon"; } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = runtime::GetFileFormat(file_name, format); if (fmt == "so" || fmt == "dll" || fmt == "hexagon") { std::string meta_file = GetMetaFilePath(file_name); @@ -240,8 +236,7 @@ class HexagonModuleNode final : public runtime::ModuleNode { CHECK(!bc_.empty()) << "LLVM IR bitcode not available"; SaveBinaryToFile(file_name, bc_); } else { - LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt - << "'"; + LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt << "'"; } } void SaveToBinary(dmlc::Stream* stream) final { @@ -251,10 +246,8 @@ class HexagonModuleNode final : public runtime::ModuleNode { } private: - void CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, - TVMRetValue* rv) const; - void CallRemoteDirect(void* func_ptr, const TVMArgs& args, - TVMRetValue* rv) const; + void CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const; + void CallRemoteDirect(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const; void RemapArgs(const TVMArgs& args, std::vector& values, // NOLINT(*) std::vector& type_codes, // NOLINT(*) @@ -274,8 +267,7 @@ class HexagonModuleNode final : public runtime::ModuleNode { std::set packed_c_abi_funcs_; }; -void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, - const TVMArgs& args, +void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const { // Remap all arguments, creating remote DLTensors. std::vector values; @@ -297,8 +289,8 @@ void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, int num_args = args.size(); int values_size = num_args * sizeof(TVMValue); int codes_size = num_args * sizeof(int); - void* remote = hexagon_device_->Alloc( - values_size + sizeof(TVMValue) + codes_size + sizeof(int), 8); + void* remote = + hexagon_device_->Alloc(values_size + sizeof(TVMValue) + codes_size + sizeof(int), 8); // Copy all argument TVMValues to the remote space. void* remote_values = remote; @@ -316,12 +308,12 @@ void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, temp_values[2].v_int64 = num_args; temp_values[3].v_handle = remote_ret_value; temp_values[4].v_handle = remote_ret_code; - int temp_codes[5] = {kTVMOpaqueHandle, kTVMOpaqueHandle, kDLInt, - kTVMOpaqueHandle, kTVMOpaqueHandle}; + int temp_codes[5] = {kTVMOpaqueHandle, kTVMOpaqueHandle, kDLInt, kTVMOpaqueHandle, + kTVMOpaqueHandle}; TVMArgs temp_args(temp_values, temp_codes, 5); hexagon::ArgLayout as = BuildArgLayout(temp_args); - hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), - as.Stack.data(), as.Stack.size()); + hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(), + as.Stack.size()); // TODO(kparzysz-quic): copy return value back std::for_each(remote_tensors.begin(), remote_tensors.end(), @@ -332,12 +324,12 @@ void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, void HexagonModuleNode::CallRemoteDirect(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const { hexagon::ArgLayout as = BuildArgLayout(args); - hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), - as.Stack.data(), as.Stack.size()); + hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(), + as.Stack.size()); } -PackedFunc HexagonModuleNode::GetFunction( - const std::string& name, const ObjectPtr& sptr_to_self) { +PackedFunc HexagonModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { auto f = fmap_.find(name); if (f == fmap_.end()) return PackedFunc(nullptr); @@ -363,8 +355,7 @@ PackedFunc HexagonModuleNode::GetFunction( } } -void HexagonModuleNode::RemapArgs(const TVMArgs& args, - std::vector& values, +void HexagonModuleNode::RemapArgs(const TVMArgs& args, std::vector& values, std::vector& type_codes, std::vector& remote_tensors) const { for (unsigned i = 0, e = args.size(); i != e; ++i) { @@ -437,18 +428,17 @@ void* HexagonModuleNode::CreateRemoteTensor(const DLTensor* t) const { uint32_t remote_as_int = reinterpret_cast(remote); void* remote_ss = reinterpret_cast(remote_as_int + size_ht); - HexagonDLTensor local = { - .data = static_cast(reinterpret_cast(t->data)), - .ctx_device_type = uint8_t(t->ctx.device_type), - .pad0 = {0, 0, 0}, - .ctx_device_id = t->ctx.device_id, - .ndim = t->ndim, - .dtype_code = t->dtype.code, - .dtype_bits = t->dtype.bits, - .dtype_lanes = t->dtype.lanes, - .shape = remote_as_int + size_ht, - .strides = t->strides ? remote_as_int + size_ht + size_s : 0u, - .byte_offset = t->byte_offset}; + HexagonDLTensor local = {.data = static_cast(reinterpret_cast(t->data)), + .ctx_device_type = uint8_t(t->ctx.device_type), + .pad0 = {0, 0, 0}, + .ctx_device_id = t->ctx.device_id, + .ndim = t->ndim, + .dtype_code = t->dtype.code, + .dtype_bits = t->dtype.bits, + .dtype_lanes = t->dtype.lanes, + .shape = remote_as_int + size_ht, + .strides = t->strides ? remote_as_int + size_ht + size_s : 0u, + .byte_offset = t->byte_offset}; std::vector local_ss(size_ss / 8); for (int i = 0; i != ndim; ++i) local_ss[i] = t->shape[i]; @@ -505,18 +495,16 @@ hexagon::ArgLayout HexagonModuleNode::BuildArgLayout(const TVMArgs& As) const { } Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi) { - auto n = make_object(data, fmt, fmap, asm_str, obj_str, - ir_str, bc_str, packed_c_abi); + auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str, + packed_c_abi); return Module(n); } // Load module from file. -Module HexagonModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module HexagonModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data = file_name; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -552,10 +540,9 @@ std::shared_ptr Device::Global() { } // namespace hexagon -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = HexagonModuleLoadFile(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = HexagonModuleLoadFile(args[0], args[1]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index c9e23a7..b922b16 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -47,9 +47,8 @@ namespace runtime { * convention. */ Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi); namespace hexagon { @@ -91,24 +90,21 @@ class Device { * \param src Pointer (local to device) of the source buffer. * \param len Number of bytes to copy. */ - virtual void CopyDeviceToDevice(void* dst, const void* src, - unsigned len) = 0; + virtual void CopyDeviceToDevice(void* dst, const void* src, unsigned len) = 0; /*! * \brief Copy a block of data from device to host. * \param host_dst Pointer (local to host) to the destination buffer. * \param src Pointer (local to device) to the source buffer. * \param len Number of bytes to copy. */ - virtual void CopyDeviceToHost(void* host_dst, const void* src, - unsigned len) = 0; + virtual void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) = 0; /*! * \brief Copy a block of data from host to device. * \param dst Pointer (local to device) to the destination buffer. * \param host_src Pointer (local to host) to the source buffer. * \param len Number of bytes to copy. */ - virtual void CopyHostToDevice(void* dst, const void* host_src, - unsigned len) = 0; + virtual void CopyHostToDevice(void* dst, const void* host_src, unsigned len) = 0; /*! * \brief Load a module (typically a shared library) into device. * \param data Name of the shared library. @@ -141,8 +137,8 @@ class Device { * for padding. * \param st_num Number of values in the "stack" array. */ - virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, - uint32_t* stack, unsigned st_num) = 0; + virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, + unsigned st_num) = 0; virtual ~Device() = 0; diff --git a/src/runtime/hexagon/hexagon_posix.cc b/src/runtime/hexagon/hexagon_posix.cc index 627963f..e98fefd 100644 --- a/src/runtime/hexagon/hexagon_posix.cc +++ b/src/runtime/hexagon/hexagon_posix.cc @@ -23,12 +23,10 @@ #include extern "C" { -int posix_memalign(void** memptr, size_t alignment, size_t size) - __attribute__((nothrow)); +int posix_memalign(void** memptr, size_t alignment, size_t size) __attribute__((nothrow)); } -__attribute__((nothrow)) int posix_memalign(void** memptr, size_t alignment, - size_t size) { +__attribute__((nothrow)) int posix_memalign(void** memptr, size_t alignment, size_t size) { if (void* p = memalign(alignment, size)) { *memptr = p; return 0; diff --git a/src/runtime/hexagon/sim/hexagon_device_sim.cc b/src/runtime/hexagon/sim/hexagon_device_sim.cc index b58377b..477da09 100644 --- a/src/runtime/hexagon/sim/hexagon_device_sim.cc +++ b/src/runtime/hexagon/sim/hexagon_device_sim.cc @@ -41,8 +41,7 @@ namespace tvm { namespace runtime { namespace hexagon { -static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t), - "Hexagon VA must be uint32"); +static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t), "Hexagon VA must be uint32"); template struct unalign { @@ -89,8 +88,7 @@ std::unique_ptr make_unique(size_t size) { // user from memory reallocation and copying. struct non_const_str { non_const_str() {} - explicit non_const_str(const std::string& str) - : non_const_str(std::vector{str}) {} + explicit non_const_str(const std::string& str) : non_const_str(std::vector{str}) {} explicit non_const_str(const std::vector& vec) { for (const std::string& s : vec) { auto c = detail::make_unique(s.size() + 1); @@ -220,8 +218,7 @@ class HexagonSimulator final : public tvm::runtime::hexagon::Device { void* Load(const std::string& data, const std::string& fmt) final; void Unload(void* mod) final; void* Resolve(const std::string& sym) final; - void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, - unsigned st_num) final; + void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, unsigned st_num) final; static std::string to_string(HEXAPI_Status status); @@ -312,10 +309,8 @@ class HexagonSimulator final : public tvm::runtime::hexagon::Device { bool should_parse_next(const string_list& rest); llvm::Optional to_interval(const detail::MaybeString& str); - llvm::Optional to_timingmode( - const detail::MaybeString& str); - llvm::Optional to_verbosemode( - const detail::MaybeString& str); + llvm::Optional to_timingmode(const detail::MaybeString& str); + llvm::Optional to_verbosemode(const detail::MaybeString& str); llvm::Optional to_nullptr(const detail::MaybeString& str); MaybeUIntRange ahb_, axi2_; @@ -399,12 +394,11 @@ decltype(HexagonSimulator::opt_map_) HexagonSimulator::opt_map_ = { {"--verbose", &HexagonSimulator::HandleVerbose}, }; -#define CHECKED_CALL(func, ...) \ - do { \ - HEXAPI_Status s = sim_->func(__VA_ARGS__); \ - CHECK_EQ(s, HEX_STAT_SUCCESS) \ - << "HexagonSimulator: " #func " failed with code " \ - << HexagonSimulator::to_string(s); \ +#define CHECKED_CALL(func, ...) \ + do { \ + HEXAPI_Status s = sim_->func(__VA_ARGS__); \ + CHECK_EQ(s, HEX_STAT_SUCCESS) << "HexagonSimulator: " #func " failed with code " \ + << HexagonSimulator::to_string(s); \ } while (false) inline HEX_VA_t HexagonSimulator::p2va(const void* p) { @@ -444,8 +438,7 @@ void HexagonSimulator::CopyNFromV(void* host_dst, HEX_VA_t src) { pd->value = v; } -void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src, - unsigned len) { +void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src, unsigned len) { const uint8_t* src = static_cast(host_src); while (len >= 8) { @@ -556,18 +549,15 @@ HexagonSimulator::HexagonSimulator(bool enable_queuing) { using iterator = std::istream_iterator; auto sim_args = string_list(iterator(sim_args_iss), iterator()); - std::string target_str = - !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66"); + std::string target_str = !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66"); arch_ = target_str; - sim_ = - detail::make_unique(detail::non_const_str(target_str)); + sim_ = detail::make_unique(detail::non_const_str(target_str)); LOG(INFO) << "HexagonSimulator: Core version: " << arch_; // Locate the sim_dev binary in PATH, or in the current working directory. llvm::StringRef sim_dev = "sim_dev"; - detail::MaybeString path_sim_dev = - llvm::sys::Process::FindInEnvPath("PATH", sim_dev); + detail::MaybeString path_sim_dev = llvm::sys::Process::FindInEnvPath("PATH", sim_dev); if (!path_sim_dev) { if (!llvm::sys::fs::exists(sim_dev)) { LOG(FATAL) << "Cannot find sim_dev in PATH."; @@ -615,8 +605,7 @@ HexagonSimulator::HexagonSimulator(bool enable_queuing) { } void* HexagonSimulator::Alloc(unsigned size, unsigned align) { - LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align - << ')'; + LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align << ')'; Message m = {kAlloc, sizeof(MsgAlloc), 0u}; MsgAlloc ma = {size, align}; SendMsg(m, &ma, true); @@ -631,8 +620,7 @@ void* HexagonSimulator::Alloc(unsigned size, unsigned align) { } void HexagonSimulator::Free(void* ptr) { - LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec - << ')'; + LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec << ')'; if (task_queuing_) { Message mf = {kFlush, 0, 0}; SendMsg(mf, 0, true); @@ -643,8 +631,7 @@ void HexagonSimulator::Free(void* ptr) { } void* HexagonSimulator::AllocVtcm(unsigned size, unsigned align) { - LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size - << ", align=" << align << ')'; + LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size << ", align=" << align << ')'; Message m = {kAllocVtcm, sizeof(MsgAlloc), 0u}; MsgAlloc ma = {size, align}; SendMsg(m, &ma, true); @@ -653,28 +640,25 @@ void* HexagonSimulator::AllocVtcm(unsigned size, unsigned align) { MsgPointer mp; CopyFromV(&mp, m.va, m.len); - LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va - << std::dec; + LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va << std::dec; CHECK_NE(mp.va, 0); return va2p(mp.va); } void HexagonSimulator::FreeVtcm(void* ptr) {} -void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src, - unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst - << ", src=" << src << ", len=" << std::dec << len << ')'; +void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst << ", src=" << src + << ", len=" << std::dec << len << ')'; CHECK(dst != nullptr && src != nullptr); Message m = {kCopy, sizeof(MsgCopy), 0u}; MsgCopy mc = {p2va(dst), p2va(src), len}; SendMsg(m, &mc, true); } -void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, - unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst - << ", src=" << src << ", len=" << len << ')'; +void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst << ", src=" << src + << ", len=" << len << ')'; if (task_queuing_) { Message mf = {kFlush, 0, 0}; SendMsg(mf, 0, true); @@ -682,10 +666,9 @@ void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, CopyFromV(host_dst, p2va(src), len); } -void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src, - unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst - << ", host_src=" << host_src << ", len=" << len << ')'; +void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst << ", host_src=" << host_src + << ", len=" << len << ')'; CopyToV(p2va(dst), host_src, len); } @@ -717,19 +700,17 @@ void* HexagonSimulator::Resolve(const std::string& sym) { MsgPointer mp; CopyFromV(&mp, m.va, sizeof(mp)); - LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va - << std::dec; + LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va << std::dec; return va2p(mp.va); } -void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, - uint32_t* stack, unsigned st_num) { - LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func - << ", scalar=" << scalar << ", sc_num=" << std::dec +void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, + unsigned st_num) { + LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func << ", scalar=" << scalar + << ", sc_num=" << std::dec << sc_num // NOLINTNEXTLINE(build/include_what_you_use) - << ", stack=" << std::hex << stack << ", st_num=" << std::dec - << st_num; + << ", stack=" << std::hex << stack << ", st_num=" << std::dec << st_num; std::vector data; @@ -753,8 +734,7 @@ void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, log_data << std::dec << " }" << std::flush; LOG(INFO) << log_data.str(); - Message m = {kCall, static_cast(data.size() * sizeof(uint32_t)), - 0u}; + Message m = {kCall, static_cast(data.size() * sizeof(uint32_t)), 0u}; SendMsg(m, data.data(), true); if (!task_queuing_) { @@ -768,8 +748,7 @@ void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, std::ostringstream log_rv; log_rv << "HexagonSimulator::Call -> {" << std::hex; for (unsigned i = 0, e = std::min(rv.size(), 4u); i != e; ++i) { - log_rv << ' ' << std::setw(2) << std::setfill('0') - << static_cast(rv[i]); + log_rv << ' ' << std::setw(2) << std::setfill('0') << static_cast(rv[i]); } if (rv.size() > 4) log_rv << "..."; log_rv << std::dec << " }"; @@ -1059,8 +1038,7 @@ bool HexagonSimulator::HandlePacketAnalyze(string_list& rest) { } bool HexagonSimulator::HandlePCFilter(string_list& rest) { - auto range = - detail::to_range(detail::pop_front(rest)); + auto range = detail::to_range(detail::pop_front(rest)); if (range) { CHECKED_CALL(ConfigurePCRangeFilter, range->first, range->second); } @@ -1222,11 +1200,9 @@ bool HexagonSimulator::HandleTCMLowAddr(string_list& rest) { } bool HexagonSimulator::HandleTimeFilterNS(string_list& rest) { - auto range = - detail::to_range(detail::pop_front(rest)); + auto range = detail::to_range(detail::pop_front(rest)); if (range) { - CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC, - range->second, HEX_NANOSEC); + CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC, range->second, HEX_NANOSEC); } return static_cast(range); } @@ -1284,8 +1260,7 @@ bool HexagonSimulator::should_parse_next(const string_list& rest) { return false; } -llvm::Optional HexagonSimulator::to_interval( - const detail::MaybeString& str) { +llvm::Optional HexagonSimulator::to_interval(const detail::MaybeString& str) { auto none = llvm::Optional(); if (!str) return none; @@ -1309,8 +1284,7 @@ llvm::Optional HexagonSimulator::to_interval( .Default(none); } -llvm::Optional HexagonSimulator::to_timingmode( - const detail::MaybeString& str) { +llvm::Optional HexagonSimulator::to_timingmode(const detail::MaybeString& str) { auto none = llvm::Optional(); if (!str) return none; @@ -1357,8 +1331,7 @@ llvm::Optional HexagonSimulator::to_verbosemode( .Default(none); } -llvm::Optional HexagonSimulator::to_nullptr( - const detail::MaybeString& str) { +llvm::Optional HexagonSimulator::to_nullptr(const detail::MaybeString& str) { auto none = llvm::Optional(); if (!str) return none; diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc index 7f3c503..c9e3332 100644 --- a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc @@ -35,8 +35,8 @@ // Stub functions for targets that don't support VTCM. static void* HAP_request_VTCM(int a, int b) { return 0; } static int HAP_release_VTCM(void* a) { return 0; } -static int HAP_query_avail_VTCM(unsigned* avail_block_size, - unsigned* max_page_size, unsigned* num_pages) { +static int HAP_query_avail_VTCM(unsigned* avail_block_size, unsigned* max_page_size, + unsigned* num_pages) { FARF(ALWAYS, "%s: running on architecture V62 or less", __func__); return AEE_ENOMEMORY; } @@ -62,8 +62,7 @@ int tvm_remote_open(const char* uri, remote_handle64* handle_ptr) { return rc; } - *handle_ptr = - static_cast(reinterpret_cast(malloc(1))); + *handle_ptr = static_cast(reinterpret_cast(malloc(1))); if (!*handle_ptr) { FARF(ERROR, "%s: cannot allocate memory", __func__); return AEE_ENOMEMORY; @@ -98,9 +97,7 @@ int tvm_remote_close(remote_handle64 handle) { * This function is present as a workaround. See comment at the call site * in hexagon_device_target.cc. */ -int tvm_remote_call_mmap64(remote_handle64 handle) { - return AEE_SUCCESS; -} +int tvm_remote_call_mmap64(remote_handle64 handle) { return AEE_SUCCESS; } /*! * \brief Load a shared library. @@ -112,8 +109,8 @@ int tvm_remote_call_mmap64(remote_handle64 handle) { * * \return 0 on success, negative value on error. */ -int tvm_remote_load_library(remote_handle64 handle, const char* soname, - int soname_len, tvm_remote_handle_t* lib_ptr) { +int tvm_remote_load_library(remote_handle64 handle, const char* soname, int soname_len, + tvm_remote_handle_t* lib_ptr) { return tvm_remote_nd_load_library(soname, soname_len, lib_ptr); } @@ -128,9 +125,8 @@ int tvm_remote_load_library(remote_handle64 handle, const char* soname, * * \return 0 on success, negative value on error. */ -int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib, - const char* name, int name_len, - tvm_remote_handle_t* sym_ptr) { +int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib, const char* name, + int name_len, tvm_remote_handle_t* sym_ptr) { return tvm_remote_nd_get_symbol(lib, name, name_len, sym_ptr); } @@ -163,24 +159,20 @@ int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib, * The 8 "octet" arguments in this function are used for cache operations * only. They are not used for procesing. */ -int tvm_remote_kernel( - remote_handle64 handle, tvm_remote_handle_t lib, - tvm_remote_handle_t symbol, const int* scalar, int scalar_len, - const int* stack, int stack_len, const tvm_remote_buffer* scalar_in_octet, - int scalar_in_octet_len, tvm_remote_buffer* scalar_out_octet, - int scalar_out_octet_len, const tvm_remote_buffer* stack_in_octet, - int stack_in_octet_len, tvm_remote_buffer* stack_out_octet, - int stack_out_octet_len, uint64* pcycles, uint64* time_usec) { +int tvm_remote_kernel(remote_handle64 handle, tvm_remote_handle_t lib, tvm_remote_handle_t symbol, + const int* scalar, int scalar_len, const int* stack, int stack_len, + const tvm_remote_buffer* scalar_in_octet, int scalar_in_octet_len, + tvm_remote_buffer* scalar_out_octet, int scalar_out_octet_len, + const tvm_remote_buffer* stack_in_octet, int stack_in_octet_len, + tvm_remote_buffer* stack_out_octet, int stack_out_octet_len, uint64* pcycles, + uint64* time_usec) { return tvm_remote_nd_kernel( lib, symbol, scalar, scalar_len, stack, stack_len, - reinterpret_cast(scalar_in_octet), - scalar_in_octet_len, - reinterpret_cast(scalar_out_octet), - scalar_out_octet_len, - reinterpret_cast(stack_in_octet), - stack_in_octet_len, - reinterpret_cast(stack_out_octet), - stack_out_octet_len, pcycles, time_usec); + reinterpret_cast(scalar_in_octet), scalar_in_octet_len, + reinterpret_cast(scalar_out_octet), scalar_out_octet_len, + reinterpret_cast(stack_in_octet), stack_in_octet_len, + reinterpret_cast(stack_out_octet), stack_out_octet_len, pcycles, + time_usec); } /*! @@ -191,8 +183,7 @@ int tvm_remote_kernel( * * \return 0 on success, negative value on error. */ -int tvm_remote_release_library(remote_handle64 handle, - tvm_remote_handle_t lib) { +int tvm_remote_release_library(remote_handle64 handle, tvm_remote_handle_t lib) { // FARF(ALWAYS, "tvm_remote_release_library begin "); return tvm_remote_nd_release_library(lib); } @@ -208,8 +199,7 @@ int tvm_remote_release_library(remote_handle64 handle, * * \return 0 on success, negative value on error. */ -int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size, - unsigned align, unsigned* dsp_va) { +int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size, unsigned align, unsigned* dsp_va) { FARF(ALWAYS, "%s: size=%u, align=%u", __func__, size, align); unsigned avail_block_size, max_page_size, num_pages; int rc = HAP_query_avail_VTCM(&avail_block_size, &max_page_size, &num_pages); @@ -217,12 +207,11 @@ int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size, FARF(ERROR, "%s: HAP_query_avail_VTCM failed, rc=%08x", __func__, rc); return rc; } - FARF(ALWAYS, "%s: avail_block_size=%u, max_page_size=%u, num_pages=%u", - __func__, avail_block_size, max_page_size, num_pages); + FARF(ALWAYS, "%s: avail_block_size=%u, max_page_size=%u, num_pages=%u", __func__, + avail_block_size, max_page_size, num_pages); if (max_page_size < MIN_VTCM_SZ) { - FARF(ERROR, "%s: available VTCM size less than %d KB, aborting", __func__, - MIN_VTCM_SZ / 1024); + FARF(ERROR, "%s: available VTCM size less than %d KB, aborting", __func__, MIN_VTCM_SZ / 1024); return AEE_ENOMEMORY; } diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc index dce2e03..c0f6f22 100644 --- a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc @@ -41,8 +41,7 @@ struct msg_call { uint32_t data[]; } __attribute__((packed)); -__attribute__((naked)) uint32_t launcher(volatile msg_call* mc, - uint64_t* pcc) { +__attribute__((naked)) uint32_t launcher(volatile msg_call* mc, uint64_t* pcc) { __asm__( "// This function is intentionally written to be readable, \n" "// rather than fast. \n" @@ -114,8 +113,7 @@ __attribute__((naked)) uint32_t launcher(volatile msg_call* mc, extern "C" { #pragma weak __wrap_pthread_create -int __wrap_pthread_create(pthread_t* restrict thread, - const pthread_attr_t* restrict attr, +int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, void* (*start)(void*), void* restrict arg) { FARF(ERROR, "Wrong %s called", __func__); abort(); @@ -133,15 +131,13 @@ static void* lib_thread = nullptr; int tvm_remote_nd_open() { lib_thread = dlopen("libtvm_wrap_pthread.so", RTLD_NOW | RTLD_GLOBAL); if (lib_thread == nullptr) { - FARF(ERROR, "%s: dlopen failed for libtvm_wrap_pthread.so: %s", __func__, - dlerror()); + FARF(ERROR, "%s: dlopen failed for libtvm_wrap_pthread.so: %s", __func__, dlerror()); return AEE_EUNABLETOLOAD; } lib_rt = dlopen("libtvm_runtime.so", RTLD_NOW | RTLD_GLOBAL); if (lib_rt == nullptr) { - FARF(ERROR, "%s: dlopen failed for libtvm_runtime.so: %s", __func__, - dlerror()); + FARF(ERROR, "%s: dlopen failed for libtvm_runtime.so: %s", __func__, dlerror()); return AEE_EUNABLETOLOAD; } return AEE_SUCCESS; @@ -174,9 +170,7 @@ int tvm_remote_nd_close() { * This function is present as a workaround. See comment at the call site * in hexagon_device_target.cc. */ -int tvm_remote_nd_call_mmap64() { - return AEE_SUCCESS; -} +int tvm_remote_nd_call_mmap64() { return AEE_SUCCESS; } /*! * \brief Load a shared library. @@ -210,8 +204,8 @@ int tvm_remote_nd_load_library(const char* soname, int soname_len, * * \return 0 on success, negative value on error. */ -int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name, - int name_len, tvm_remote_nd_handle_t* sym_ptr) { +int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name, int name_len, + tvm_remote_nd_handle_t* sym_ptr) { FARF(ALWAYS, "%s: name=%s", __func__, name); if (void* p = dlsym(reinterpret_cast(lib), name)) { *sym_ptr = reinterpret_cast(p); @@ -223,8 +217,8 @@ int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name, } static void print_msg_call(const msg_call& mc) { - FARF(ALWAYS, "device: launching %x scalar_num:%d stack_num:%d", mc.func_va, - mc.scalar_num, mc.stack_num); + FARF(ALWAYS, "device: launching %x scalar_num:%d stack_num:%d", mc.func_va, mc.scalar_num, + mc.stack_num); for (unsigned i = 0; i != mc.scalar_num; ++i) { FARF(ALWAYS, "scalar_data[%d] %x", i, mc.data[i]); } @@ -261,14 +255,13 @@ static void print_msg_call(const msg_call& mc) { * The 8 "octet" arguments in this function are used for cache operations * only. They are not used for procesing. */ -int tvm_remote_nd_kernel( - tvm_remote_nd_handle_t lib, tvm_remote_nd_handle_t symbol, - const int* scalar, int scalar_len, const int* stack, int stack_len, - const tvm_remote_nd_buffer* scalar_in_octet, int scalar_in_octet_len, - tvm_remote_nd_buffer* scalar_out_octet, int scalar_out_octet_len, - const tvm_remote_nd_buffer* stack_in_octet, int stack_in_octet_len, - tvm_remote_nd_buffer* stack_out_octet, int stack_out_octet_len, - uint64* pcycles, uint64* time_usec) { +int tvm_remote_nd_kernel(tvm_remote_nd_handle_t lib, tvm_remote_nd_handle_t symbol, + const int* scalar, int scalar_len, const int* stack, int stack_len, + const tvm_remote_nd_buffer* scalar_in_octet, int scalar_in_octet_len, + tvm_remote_nd_buffer* scalar_out_octet, int scalar_out_octet_len, + const tvm_remote_nd_buffer* stack_in_octet, int stack_in_octet_len, + tvm_remote_nd_buffer* stack_out_octet, int stack_out_octet_len, + uint64* pcycles, uint64* time_usec) { hvx::config_t hvx_info = {0}; hvx::prepare_mt_job(&hvx_info); @@ -277,18 +270,16 @@ int tvm_remote_nd_kernel( if (hvx_info.num_reserved > 0) { lock_result = hvx::lock(hvx::MODE_128B); if (lock_result < 0) { - FARF(ERROR, "%s: HVX locking failed lock_result=%d num_reserved=%d", - __func__, lock_result, hvx_info.num_reserved); + FARF(ERROR, "%s: HVX locking failed lock_result=%d num_reserved=%d", __func__, lock_result, + hvx_info.num_reserved); } else { - FARF(ALWAYS, "%s: HVX lock successful lock_result=%d", __func__, - lock_result); + FARF(ALWAYS, "%s: HVX lock successful lock_result=%d", __func__, lock_result); } } else { FARF(ERROR, "%s: there are no HVX units available", __func__); } - struct msg_call* mc = (struct msg_call*)malloc(sizeof(uint32_t) * - (3 + scalar_len + stack_len)); + struct msg_call* mc = (struct msg_call*)malloc(sizeof(uint32_t) * (3 + scalar_len + stack_len)); if (mc == nullptr) { FARF(ERROR, "%s: failed to allocate memory for mc", __func__); return AEE_ENOMEMORY; @@ -312,8 +303,7 @@ int tvm_remote_nd_kernel( uint64_t start_time = HAP_perf_get_time_us(); int result = launcher(mc, pcycles); *time_usec = HAP_perf_get_time_us() - start_time; - FARF(ALWAYS, "kernel execution: %llu pcycles %llu usec", *pcycles, - *time_usec); + FARF(ALWAYS, "kernel execution: %llu pcycles %llu usec", *pcycles, *time_usec); if (lock_result > 0) hvx::unlock(); hvx::cleanup_mt_job(&hvx_info); if (mc) free(mc); diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc index 1192f7a..d26073a 100644 --- a/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc @@ -44,13 +44,11 @@ static constexpr size_t kThreadStackSize = 128 * 1024; // 128kB // Make sure the function has C linkage. extern "C" { -int __wrap_pthread_create(pthread_t* restrict thread, - const pthread_attr_t* restrict attr, +int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, void* (*start)(void*), void* restrict arg); } -int __wrap_pthread_create(pthread_t* restrict thread, - const pthread_attr_t* restrict attr, +int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, void* (*start)(void*), void* restrict arg) { pthread_attr_t def_attr; if (attr == nullptr) { @@ -72,8 +70,7 @@ int __wrap_pthread_create(pthread_t* restrict thread, FARF(ALWAYS, "launching thread with stack_size=%zu", stack_size); int t = pthread_create(thread, attr, start, arg); if (int rc = pthread_attr_destroy(&def_attr)) { - FARF(ERROR, "pthread_attr_destroy failed (after pthread_create): rc=%08x", - rc); + FARF(ERROR, "pthread_attr_destroy failed (after pthread_create): rc=%08x", rc); } return t; } diff --git a/src/runtime/hexagon/target/hexagon_device_target.cc b/src/runtime/hexagon/target/hexagon_device_target.cc index a62aa47..ee326ca 100644 --- a/src/runtime/hexagon/target/hexagon_device_target.cc +++ b/src/runtime/hexagon/target/hexagon_device_target.cc @@ -45,10 +45,8 @@ // The downside is that the format string must be given as a string literal, // but it seems to be a minor issue. #define VA_EXPANDER(...) , ##__VA_ARGS__ -#define TVM_LOGD_HT(fmt, ...) \ - TVM_LOGD("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) -#define TVM_LOGE_HT(fmt, ...) \ - TVM_LOGE("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) +#define TVM_LOGD_HT(fmt, ...) TVM_LOGD("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) +#define TVM_LOGE_HT(fmt, ...) TVM_LOGE("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) namespace tvm { namespace runtime { @@ -74,8 +72,7 @@ class HexagonTarget : public tvm::runtime::hexagon::Device { unsigned stack_num) final; private: - std::pair AddAddrMapping(const void* dsp_addr, - void* apps_addr, size_t size); + std::pair AddAddrMapping(const void* dsp_addr, void* apps_addr, size_t size); std::pair GetAppsAddr(const void* dsp_addr, bool exact) const; void RemoveAddrMapping(const void* dsp_addr); int OpenDomainChannel(bool set_unsigned_pd); @@ -102,24 +99,19 @@ class HexagonTarget : public tvm::runtime::hexagon::Device { void* const HexagonTarget::vtcm_mark_ = reinterpret_cast(~0); -std::shared_ptr CreateHexagonTarget() { - return std::make_shared(); -} +std::shared_ptr CreateHexagonTarget() { return std::make_shared(); } -std::pair HexagonTarget::AddAddrMapping(const void* dsp_addr, - void* apps_addr, +std::pair HexagonTarget::AddAddrMapping(const void* dsp_addr, void* apps_addr, size_t size) { crit_section_.lock(); auto p = dsp_to_apps_.insert({dsp_addr, {apps_addr, size}}); crit_section_.unlock(); if (!p.second) { - TVM_LOGE_HT( - "failed to insert address mapping: dsp:%p -> apps:%p, size:%zu", - dsp_addr, apps_addr, size); + TVM_LOGE_HT("failed to insert address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, + apps_addr, size); return std::make_pair(nullptr, 0); } - TVM_LOGD_HT("added address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, - apps_addr, size); + TVM_LOGD_HT("added address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, apps_addr, size); return p.first->second; } @@ -135,8 +127,7 @@ void HexagonTarget::RemoveAddrMapping(const void* dsp_addr) { crit_section_.unlock(); } -std::pair HexagonTarget::GetAppsAddr(const void* dsp_addr, - bool exact) const { +std::pair HexagonTarget::GetAppsAddr(const void* dsp_addr, bool exact) const { struct AutoUnlock { explicit AutoUnlock(std::mutex& m) : m(m) {} ~AutoUnlock() { m.unlock(); } @@ -192,16 +183,14 @@ int HexagonTarget::OpenDomainChannel(bool use_unsigned_pd) { data.domain = CDSP_DOMAIN_ID; int rc = rsc_ptr(DSPRPC_CONTROL_UNSIGNED_MODULE, &data, sizeof(data)); if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("remote_session_control failed rc=%08x for unsigned PD", - rc); + TVM_LOGE_HT("remote_session_control failed rc=%08x for unsigned PD", rc); } } } else { TVM_LOGD_HT("remote_session_control not available"); } - int rc = stub_api->tvm_remote_open(tvm_remote_URI "&_dom=cdsp", - &domain_channel_handle_); + int rc = stub_api->tvm_remote_open(tvm_remote_URI "&_dom=cdsp", &domain_channel_handle_); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to open channel rc=0x%x", rc); } else { @@ -231,8 +220,7 @@ void HexagonTarget::ReleaseLibrary() { crit_section_.lock(); if (module_pointer_ != AEE_EUNKNOWN) { const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_remote_release_library(domain_channel_handle_, - module_pointer_); + int rc = stub_api->tvm_remote_release_library(domain_channel_handle_, module_pointer_); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to unload device library rc=0x%x", rc); } else { @@ -267,24 +255,20 @@ void* HexagonTarget::Alloc(unsigned size, unsigned align) { // thread then remote_mmap64 fails. FastRPC expects one call to be made to // DSP before calling remote_map64. Hence this call is needed for now untill // FastRPC comes up with a fix. - int rc_call_mmap_64 = - stub_api->tvm_remote_call_mmap64(domain_channel_handle_); + int rc_call_mmap_64 = stub_api->tvm_remote_call_mmap64(domain_channel_handle_); if (rc_call_mmap_64 != AEE_SUCCESS) { - TVM_LOGE_HT("mmap64 failed for domain channel %lu", - domain_channel_handle_); + TVM_LOGE_HT("mmap64 failed for domain channel %lu", domain_channel_handle_); return nullptr; } - void* mem = - stub_api->rpcmem_alloc_ptr()(RPCMEM_HEAP, RPCMEM_DEFAULT_FLAGS, size); + void* mem = stub_api->rpcmem_alloc_ptr()(RPCMEM_HEAP, RPCMEM_DEFAULT_FLAGS, size); if (mem == nullptr) { TVM_LOGE_HT("mem alloc failed for size=0x%x alignment=0x%x", size, align); return nullptr; } int mem_fd = stub_api->rpcmem_to_fd_ptr()(mem); uintptr_t dsp_va = 0; - int rc = dsp_api->remote_mmap64_ptr()( - mem_fd, 0, reinterpret_cast(mem), size, &dsp_va); + int rc = dsp_api->remote_mmap64_ptr()(mem_fd, 0, reinterpret_cast(mem), size, &dsp_va); if (rc != AEE_SUCCESS) { TVM_LOGE_HT( "buffer mapping failed for remote_map64 fd=0x%x rc=0x%x " @@ -313,8 +297,7 @@ void HexagonTarget::Free(void* ptr) { auto aa = GetAppsAddr(ptr, true); if (aa.first == nullptr) return; - int rc = dsp_api->remote_munmap64_ptr()(reinterpret_cast(ptr), - aa.second); + int rc = dsp_api->remote_munmap64_ptr()(reinterpret_cast(ptr), aa.second); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("buffer unmapping failed rc=0x%x", rc); } @@ -326,8 +309,7 @@ void* HexagonTarget::AllocVtcm(unsigned size, unsigned align) { const StubAPI* stub_api = StubAPI::Global(); unsigned int dsp_va = 0; - int rc = stub_api->tvm_remote_alloc_vtcm(domain_channel_handle_, size, align, - &dsp_va); + int rc = stub_api->tvm_remote_alloc_vtcm(domain_channel_handle_, size, align, &dsp_va); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("VTCM allocation failed size=%u, align=%u", size, align); return nullptr; @@ -350,8 +332,7 @@ void HexagonTarget::FreeVtcm(void* ptr) { TVM_LOGD_HT("Done VTCM free from HexagonTarget::FreeVtcm"); } -void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, - unsigned len) { +void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, unsigned len) { auto aa_src = GetAppsAddr(src, false); auto aa_dst = GetAppsAddr(dst, false); if (aa_src.first == vtcm_mark_ || aa_dst.first == vtcm_mark_) { @@ -375,13 +356,12 @@ void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, len, aa_dst.second); } len = std::min({size_t(len), aa_src.second, aa_dst.second}); - TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> dsp:%p(apps:%p), len:%u", src, - aa_src.first, dst, aa_dst.first, len); + TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> dsp:%p(apps:%p), len:%u", src, aa_src.first, dst, + aa_dst.first, len); std::memcpy(aa_dst.first, aa_src.first, len); } -void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, - unsigned len) { +void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) { auto aa = GetAppsAddr(src, false); if (aa.first == vtcm_mark_) { TVM_LOGE_HT("VTCM address. Copy operation not supported"); @@ -392,18 +372,14 @@ void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, return; } if (aa.second < len) { - TVM_LOGD_HT( - "specified length:%u larger than buffer size:%zu, copy truncated", len, - aa.second); + TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second); len = aa.second; } - TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> apps:%p, len:%u", src, aa.first, - host_dst, len); + TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> apps:%p, len:%u", src, aa.first, host_dst, len); std::memcpy(host_dst, aa.first, len); } -void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, - unsigned len) { +void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, unsigned len) { auto aa = GetAppsAddr(dst, false); if (aa.first == vtcm_mark_) { TVM_LOGE_HT("VTCM address. Copy operation not supported"); @@ -414,13 +390,10 @@ void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, return; } if (aa.second < len) { - TVM_LOGD_HT( - "specified length:%u larger than buffer size:%zu, copy truncated", len, - aa.second); + TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second); len = aa.second; } - TVM_LOGD_HT("copy, dsp:%p(apps:%p) <- apps:%p, len:%u", dst, aa.first, - host_src, len); + TVM_LOGD_HT("copy, dsp:%p(apps:%p) <- apps:%p, len:%u", dst, aa.first, host_src, len); std::memcpy(aa.first, host_src, len); } @@ -429,8 +402,7 @@ void* HexagonTarget::Load(const std::string& data, const std::string& fmt) { int rc_oc = OpenDomainChannel(/*use_unsigned_pd*/ unsigned_pd); crit_section_.unlock(); if (rc_oc != AEE_SUCCESS) { - TVM_LOGE_HT("loading of %s failed: unable to open domain channel", - data.c_str()); + TVM_LOGE_HT("loading of %s failed: unable to open domain channel", data.c_str()); return nullptr; } @@ -440,8 +412,8 @@ void* HexagonTarget::Load(const std::string& data, const std::string& fmt) { crit_section_.lock(); TVM_LOGD_HT("loading library %s ", data.c_str()); const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_remote_load_library( - domain_channel_handle_, data.c_str(), data.size() + 1, &module_pointer_); + int rc = stub_api->tvm_remote_load_library(domain_channel_handle_, data.c_str(), data.size() + 1, + &module_pointer_); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to load device library rc=0x%x", rc); } @@ -473,9 +445,8 @@ void* HexagonTarget::Resolve(const std::string& sym) { tvm_remote_handle_t pf; TVM_LOGD_HT("resolving symbol %s", sym.c_str()); - int rc = - stub_api->tvm_remote_get_symbol(domain_channel_handle_, module_pointer_, - sym.c_str(), sym.size() + 1, &pf); + int rc = stub_api->tvm_remote_get_symbol(domain_channel_handle_, module_pointer_, sym.c_str(), + sym.size() + 1, &pf); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to get symbol from CDSP rc=0x%x", rc); return nullptr; @@ -485,13 +456,11 @@ void* HexagonTarget::Resolve(const std::string& sym) { return addr; } -void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, - uint32_t* stack, unsigned stack_num) { +void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, uint32_t* stack, + unsigned stack_num) { uint64 pcycles = 0, execution_time_usec = 0; - auto scalar_octet = - std::unique_ptr(new tvm_remote_buffer[scalar_num]); - auto stack_octet = - std::unique_ptr(new tvm_remote_buffer[stack_num]); + auto scalar_octet = std::unique_ptr(new tvm_remote_buffer[scalar_num]); + auto stack_octet = std::unique_ptr(new tvm_remote_buffer[stack_num]); TVM_LOGD_HT("scalars=%p, stack=%p", scalar, stack); if (scalar_octet == nullptr || stack_octet == nullptr) { @@ -501,8 +470,7 @@ void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, std::memset(scalar_octet.get(), 0, scalar_num * sizeof(tvm_remote_buffer)); std::memset(stack_octet.get(), 0, stack_num * sizeof(tvm_remote_buffer)); - auto ProcessInputs = [this](uint32_t* inputs, tvm_remote_buffer* buffers, - unsigned num) { + auto ProcessInputs = [this](uint32_t* inputs, tvm_remote_buffer* buffers, unsigned num) { for (unsigned i = 0; i != num; ++i) { void* ptr = reinterpret_cast(static_cast(inputs[i])); auto aa = GetAppsAddr(ptr, false); @@ -534,16 +502,15 @@ void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, int rc = stub_api->tvm_remote_kernel( domain_channel_handle_, module_pointer_, static_cast(reinterpret_cast(func)), - reinterpret_cast(scalar), scalar_num, - reinterpret_cast(stack), stack_num, scalar_octet.get(), scalar_num, - scalar_octet.get(), scalar_num, stack_octet.get(), stack_num, + reinterpret_cast(scalar), scalar_num, reinterpret_cast(stack), stack_num, + scalar_octet.get(), scalar_num, scalar_octet.get(), scalar_num, stack_octet.get(), stack_num, stack_octet.get(), stack_num, &pcycles, &execution_time_usec); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to run kernel on CDSP rc=0x%x", rc); } else { - TVM_LOGD_HT("kernel execution: %llu pcycles, %llu usec, scalar_num=%d", - pcycles, execution_time_usec, scalar_num); + TVM_LOGD_HT("kernel execution: %llu pcycles, %llu usec, scalar_num=%d", pcycles, + execution_time_usec, scalar_num); } } diff --git a/src/runtime/hexagon/target/hexagon_stubapi.cc b/src/runtime/hexagon/target/hexagon_stubapi.cc index 939c382..2ed3347 100644 --- a/src/runtime/hexagon/target/hexagon_stubapi.cc +++ b/src/runtime/hexagon/target/hexagon_stubapi.cc @@ -44,8 +44,7 @@ StubAPI::StubAPI() { constexpr auto domain_lib_name = "libtvm_remote_stub.so"; constexpr auto nondomain_lib_name = "libtvm_remote_nd_stub.so"; - const char* lib_name = - enable_domains_ ? domain_lib_name : nondomain_lib_name; + const char* lib_name = enable_domains_ ? domain_lib_name : nondomain_lib_name; CHECK(lib_handle_ = dlopen(lib_name, RTLD_LAZY | RTLD_LOCAL)); #define RESOLVE(fn) p##fn##_ = GetSymbol(#fn) diff --git a/src/runtime/hexagon/target/hexagon_stubapi.h b/src/runtime/hexagon/target/hexagon_stubapi.h index 6f3828f..5213b6d 100644 --- a/src/runtime/hexagon/target/hexagon_stubapi.h +++ b/src/runtime/hexagon/target/hexagon_stubapi.h @@ -162,8 +162,7 @@ class StubAPI { // two types identical in the function types created below. // For example, int foo(tvm_remote_buffer*) and // int bar(tvm_remote_nd_buffer*) should both have the same type. -#define MAPTYPE(fn, ty) \ - using fn##_t = typename map_func_type::type; +#define MAPTYPE(fn, ty) using fn##_t = typename map_func_type::type; MAPTYPE(tvm_remote_load_library, tvm_remote_buffer) MAPTYPE(tvm_remote_release_library, tvm_remote_buffer) MAPTYPE(tvm_remote_get_symbol, tvm_remote_buffer) @@ -196,8 +195,7 @@ class StubAPI { public: template - int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle, - Ts... args) const { + int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle, Ts... args) const { if (enable_domains_) { return func_d(handle, args...); } @@ -219,11 +217,10 @@ class StubAPI { #define FUNC_ND(name) CONCAT_STR(tvm_remote_nd_, name) #define PTRNAME(fn) CONCAT_STR(p, CONCAT_STR(fn, _)) -#define DECLFUNC(name) \ - template \ - int FUNC(name)(remote_handle64 handle, Ts... args) const { \ - return invoke(PTRNAME(FUNC_D(name)), PTRNAME(FUNC_ND(name)), handle, \ - args...); \ +#define DECLFUNC(name) \ + template \ + int FUNC(name)(remote_handle64 handle, Ts... args) const { \ + return invoke(PTRNAME(FUNC_D(name)), PTRNAME(FUNC_ND(name)), handle, args...); \ } #define DECLFUNC_D(name) \ diff --git a/src/runtime/hexagon/target/hexagon_target_log.h b/src/runtime/hexagon/target/hexagon_target_log.h index ae09503..c7684fc 100644 --- a/src/runtime/hexagon/target/hexagon_target_log.h +++ b/src/runtime/hexagon/target/hexagon_target_log.h @@ -23,18 +23,12 @@ #include -#define TVM_LOGV(...) \ - __android_log_print(ANDROID_LOG_VERBOSE, "TVM", ##__VA_ARGS__) -#define TVM_LOGD(...) \ - __android_log_print(ANDROID_LOG_DEBUG, "TVM", ##__VA_ARGS__) -#define TVM_LOGI(...) \ - __android_log_print(ANDROID_LOG_INFO, "TVM", ##__VA_ARGS__) -#define TVM_LOGW(...) \ - __android_log_print(ANDROID_LOG_WARN, "TVM", ##__VA_ARGS__) -#define TVM_LOGE(...) \ - __android_log_print(ANDROID_LOG_ERROR, "TVM", ##__VA_ARGS__) -#define TVM_LOGF(...) \ - __android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__) +#define TVM_LOGV(...) __android_log_print(ANDROID_LOG_VERBOSE, "TVM", ##__VA_ARGS__) +#define TVM_LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, "TVM", ##__VA_ARGS__) +#define TVM_LOGI(...) __android_log_print(ANDROID_LOG_INFO, "TVM", ##__VA_ARGS__) +#define TVM_LOGW(...) __android_log_print(ANDROID_LOG_WARN, "TVM", ##__VA_ARGS__) +#define TVM_LOGE(...) __android_log_print(ANDROID_LOG_ERROR, "TVM", ##__VA_ARGS__) +#define TVM_LOGF(...) __android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__) #endif // __ANDROID__ #endif // TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_TARGET_LOG_H_ diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 306a7e9..7c3323c 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -21,13 +21,15 @@ * \file module_util.cc * \brief Utilities for module. */ +#include "library_module.h" + #include #include #include + #include -#include #include -#include "library_module.h" +#include namespace tvm { namespace runtime { @@ -35,22 +37,16 @@ namespace runtime { // Library module that exposes symbols from a library. class LibraryModuleNode final : public ModuleNode { public: - explicit LibraryModuleNode(ObjectPtr lib) - : lib_(lib) { - } + explicit LibraryModuleNode(ObjectPtr lib) : lib_(lib) {} - const char* type_key() const final { - return "library"; - } + const char* type_key() const final { return "library"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = reinterpret_cast( - lib_->GetSymbol(runtime::symbol::tvm_module_main)); - CHECK(entry_name!= nullptr) + const char* entry_name = + reinterpret_cast(lib_->GetSymbol(runtime::symbol::tvm_module_main)); + CHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; faddr = reinterpret_cast(lib_->GetSymbol(entry_name)); } else { @@ -70,35 +66,27 @@ class LibraryModuleNode final : public ModuleNode { class ModuleInternal { public: // Get mutable reference of imports. - static std::vector* GetImportsAddr(ModuleNode* node) { - return &(node->imports_); - } + static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } }; -PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, - const ObjectPtr& sptr_to_self) { +PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - TVMValue ret_value; - int ret_type_code = kTVMNullptr; - int ret = (*faddr)( - const_cast(args.values), - const_cast(args.type_codes), - args.num_args, - &ret_value, - &ret_type_code); - CHECK_EQ(ret, 0) << TVMGetLastError(); - if (ret_type_code != kTVMNullptr) { - *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); - } - }); + TVMValue ret_value; + int ret_type_code = kTVMNullptr; + int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), + args.num_args, &ret_value, &ret_type_code); + CHECK_EQ(ret, 0) << TVMGetLastError(); + if (ret_type_code != kTVMNullptr) { + *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); + } + }); } void InitContextFunctions(std::function fgetsymbol) { - #define TVM_INIT_CONTEXT_FUNC(FuncName) \ - if (auto *fp = reinterpret_cast \ - (fgetsymbol("__" #FuncName))) { \ - *fp = FuncName; \ - } +#define TVM_INIT_CONTEXT_FUNC(FuncName) \ + if (auto* fp = reinterpret_cast(fgetsymbol("__" #FuncName))) { \ + *fp = FuncName; \ + } // Initialize the functions TVM_INIT_CONTEXT_FUNC(TVMFuncCall); TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError); @@ -108,7 +96,7 @@ void InitContextFunctions(std::function fgetsymbol) { TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); - #undef TVM_INIT_CONTEXT_FUNC +#undef TVM_INIT_CONTEXT_FUNC } /*! @@ -123,10 +111,10 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { uint64_t nbytes = 0; for (size_t i = 0; i < sizeof(nbytes); ++i) { uint64_t c = mblob[i]; - nbytes |= (c & 0xffUL) << (i * 8); + nbytes |= (c & 0xffUL) << (i * 8); } - dmlc::MemoryFixedSizeStream fs( - const_cast(mblob + sizeof(nbytes)), static_cast(nbytes)); + dmlc::MemoryFixedSizeStream fs(const_cast(mblob + sizeof(nbytes)), + static_cast(nbytes)); dmlc::Stream* stream = &fs; uint64_t size; CHECK(stream->Read(&size)); @@ -147,9 +135,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { } else { std::string fkey = "runtime.module.loadbinary_" + tkey; const PackedFunc* f = Registry::Get(fkey); - CHECK(f != nullptr) - << "Loader of " << tkey << "(" - << fkey << ") is not presented."; + CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented."; Module m = (*f)(static_cast(stream)); modules.emplace_back(m); } @@ -180,14 +166,11 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { } Module CreateModuleFromLibrary(ObjectPtr lib) { - InitContextFunctions([lib](const char* fname) { - return lib->GetSymbol(fname); - }); + InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); auto n = make_object(lib); // Load the imported modules const char* dev_mblob = - reinterpret_cast( - lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); + reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); Module root_mod; if (dev_mblob != nullptr) { root_mod = ProcessModuleBlob(dev_mblob, lib); @@ -197,8 +180,7 @@ Module CreateModuleFromLibrary(ObjectPtr lib) { } // allow lookup of symbol from root (so all symbols are visible). - if (auto *ctx_addr = - reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { + if (auto* ctx_addr = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { *ctx_addr = root_mod.operator->(); } diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 61e6266..91918c1 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -24,9 +24,10 @@ #ifndef TVM_RUNTIME_LIBRARY_MODULE_H_ #define TVM_RUNTIME_LIBRARY_MODULE_H_ -#include -#include #include +#include +#include + #include namespace tvm { @@ -47,7 +48,7 @@ class Library : public Object { * \param name The name of the symbol. * \return The symbol. */ - virtual void *GetSymbol(const char* name) = 0; + virtual void* GetSymbol(const char* name) = 0; // NOTE: we do not explicitly create an type index and type_key here for libary. // This is because we do not need dynamic type downcasting. }; @@ -77,4 +78,4 @@ void InitContextFunctions(std::function fgetsymbol); Module CreateModuleFromLibrary(ObjectPtr lib); } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_LIBRARY_MODULE_H_ +#endif // TVM_RUNTIME_LIBRARY_MODULE_H_ diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 22f2e9a..451c0e8 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -24,11 +24,13 @@ #ifndef TVM_RUNTIME_META_DATA_H_ #define TVM_RUNTIME_META_DATA_H_ -#include #include +#include #include + #include #include + #include "runtime_base.h" namespace tvm { @@ -40,10 +42,10 @@ struct FunctionInfo { std::vector arg_types; std::vector thread_axis_tags; - void Save(dmlc::JSONWriter *writer) const; - void Load(dmlc::JSONReader *reader); - void Save(dmlc::Stream *writer) const; - bool Load(dmlc::Stream *reader); + void Save(dmlc::JSONWriter* writer) const; + void Load(dmlc::JSONReader* reader); + void Save(dmlc::Stream* writer) const; + bool Load(dmlc::Stream* reader); }; } // namespace runtime } // namespace tvm diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 8a7c9fe..ca369d4 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -24,21 +24,22 @@ #ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_ #define TVM_RUNTIME_METAL_METAL_COMMON_H_ +#import #import -#import #import -#import +#import #import #import - +#include #include -#include #include -#include +#include + +#include #include #include #include -#include + #include "../workspace_pool.h" namespace tvm { @@ -64,14 +65,14 @@ class MetalWorkspace final : public DeviceAPI { // Get command queue for given context. id GetCommandQueue(TVMContext ctx) { CHECK_EQ(ctx.device_type, kDLMetal); - CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) << "Invalid Metal device_id=" << ctx.device_id; return queues[ctx.device_id]; } // Get device for given context id GetDevice(TVMContext ctx) { CHECK_EQ(ctx.device_type, kDLMetal); - CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < devices.size()) + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < devices.size()) << "Invalid Metal device_id=" << ctx.device_id; return devices[ctx.device_id]; } @@ -81,19 +82,10 @@ class MetalWorkspace final : public DeviceAPI { // override device API void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, - size_t from_size, - void* to, - size_t to_size, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; @@ -112,8 +104,7 @@ class MetalThreadEntry { /*! \brief workspace pool */ WorkspacePool pool; // constructor - MetalThreadEntry() - : pool(static_cast(kDLMetal), MetalWorkspace::Global()) { + MetalThreadEntry() : pool(static_cast(kDLMetal), MetalWorkspace::Global()) { context.device_id = 0; context.device_type = static_cast(kDLMetal); } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index a49f8a5..3bad2c3 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -20,8 +20,8 @@ /*! * \file metal_device_api.mm */ -#include #include +#include #include "metal_common.h" namespace tvm { @@ -29,25 +29,21 @@ namespace runtime { namespace metal { const std::shared_ptr& MetalWorkspace::Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } -void MetalWorkspace::GetAttr( - TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { +void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { this->Init(); size_t index = static_cast(ctx.device_id); if (kind == kExist) { - *rv = int(index< devices.size()); + *rv = int(index < devices.size()); return; } - CHECK_LT(index, devices.size()) - << "Invalid device id " << index; + CHECK_LT(index, devices.size()) << "Invalid device id " << index; switch (kind) { case kMaxThreadsPerBlock: { - *rv = static_cast( - [devices[ctx.device_id] maxThreadsPerThreadgroup].width); + *rv = static_cast([devices[ctx.device_id] maxThreadsPerThreadgroup].width); break; } case kWarpSize: { @@ -55,14 +51,22 @@ void MetalWorkspace::GetAttr( *rv = 1; break; } - case kMaxSharedMemoryPerBlock: return; - case kComputeVersion: return; - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; - case kExist: break; - case kGcnArch: return; + case kMaxSharedMemoryPerBlock: + return; + case kComputeVersion: + return; + case kDeviceName: + return; + case kMaxClockRate: + return; + case kMultiProcessorCount: + return; + case kMaxThreadDimensions: + return; + case kExist: + break; + case kGcnArch: + return; } } @@ -87,22 +91,13 @@ kernel void CopyKernel( // But we keep this code. int GetWarpSize(id dev) { NSError* error_msg = nil; - id lib = - [dev - newLibraryWithSource: - [NSString stringWithUTF8String:kDummyKernel] - options:nil - error:&error_msg]; + id lib = [dev newLibraryWithSource:[NSString stringWithUTF8String:kDummyKernel] + options:nil + error:&error_msg]; CHECK(lib != nil) << [[error_msg localizedDescription] UTF8String]; - id f = - [lib - newFunctionWithName: - [NSString stringWithUTF8String:"CopyKernel"]]; - CHECK(f!= nil); - id state = - [dev - newComputePipelineStateWithFunction:f - error:&error_msg]; + id f = [lib newFunctionWithName:[NSString stringWithUTF8String:"CopyKernel"]]; + CHECK(f != nil); + id state = [dev newComputePipelineStateWithFunction:f error:&error_msg]; CHECK(state != nil) << [[error_msg localizedDescription] UTF8String]; return static_cast(state.threadExecutionWidth); } @@ -123,20 +118,19 @@ void MetalWorkspace::Init() { initialized_ = true; if (devices.size() != 0) return; #if TARGET_OS_IPHONE - // on iPhone - id d = MTLCreateSystemDefaultDevice(); + // on iPhone + id d = MTLCreateSystemDefaultDevice(); + devices.push_back([d retain]); + queues.push_back([[d newCommandQueue] retain]); +#else + NSArray >* devs = MTLCopyAllDevices(); + for (size_t i = 0; i < devs.count; ++i) { + id d = [devs objectAtIndex:i]; devices.push_back([d retain]); queues.push_back([[d newCommandQueue] retain]); -#else - NSArray>* devs = MTLCopyAllDevices(); - for (size_t i = 0; i < devs.count; ++i) { - id d = [devs objectAtIndex:i]; - devices.push_back([d retain]); - queues.push_back([[d newCommandQueue] retain]); - LOG(INFO) << "Intializing Metal device " << i - << ", name=" << [d.name UTF8String]; - warp_size.push_back(GetWarpSize(d)); - } + LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; + warp_size.push_back(GetWarpSize(d)); + } #endif } @@ -144,8 +138,8 @@ void MetalWorkspace::SetDevice(TVMContext ctx) { MetalThreadEntry::ThreadLocal()->context.device_id = ctx.device_id; } -void* MetalWorkspace::AllocDataSpace( - TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) { +void* MetalWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint) { this->Init(); id dev = GetDevice(ctx); // GPU memory only @@ -157,9 +151,7 @@ void* MetalWorkspace::AllocDataSpace( storage_mode = MTLResourceStorageModeManaged; #endif */ - id buf = [ - dev newBufferWithLength:nbytes - options:storage_mode]; + id buf = [dev newBufferWithLength:nbytes options:storage_mode]; CHECK(buf != nil); return (__bridge void*)([buf retain]); } @@ -169,14 +161,9 @@ void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { CFRelease(ptr); } -void MetalWorkspace::CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, +void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) { this->Init(); CHECK(stream == nullptr); @@ -188,65 +175,54 @@ void MetalWorkspace::CopyDataFromTo(const void* from, int to_dev_type = static_cast(ctx_to.device_type); if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) { - CHECK_EQ(ctx_from.device_id, ctx_to.device_id) - << "Metal disallow cross device copy."; + CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross device copy."; id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:(__bridge id)(from) - sourceOffset:from_offset - toBuffer:(__bridge id)(to) - destinationOffset:to_offset - size:size]; + sourceOffset:from_offset + toBuffer:(__bridge id)(to)destinationOffset:to_offset + size:size]; [encoder endEncoding]; [cb commit]; } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) { // copy to a local buffer before get into global buffer. id from_buf = (__bridge id)(from); if (from_buf.storageMode != MTLStorageModeShared) { - id temp = MetalThreadEntry::ThreadLocal() - ->GetTempBuffer(ctx_from, size); + id temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size); id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:from_buf - sourceOffset:from_offset - toBuffer:temp - destinationOffset:0 - size:size]; + sourceOffset:from_offset + toBuffer:temp + destinationOffset:0 + size:size]; [encoder endEncoding]; [cb commit]; [cb waitUntilCompleted]; - memcpy(static_cast(to) + to_offset, - static_cast([temp contents]), - size); + memcpy(static_cast(to) + to_offset, static_cast([temp contents]), size); } else { memcpy(static_cast(to) + to_offset, - static_cast([from_buf contents]) + from_offset, - size); + static_cast([from_buf contents]) + from_offset, size); } } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) { id to_buf = (__bridge id)(to); if (to_buf.storageMode != MTLStorageModeShared) { - id temp = MetalThreadEntry::ThreadLocal() - ->GetTempBuffer(ctx_to, size); - memcpy([temp contents], - static_cast(from) + from_offset, - size); + id temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size); + memcpy([temp contents], static_cast(from) + from_offset, size); id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:temp - sourceOffset:0 - toBuffer:to_buf - destinationOffset:to_offset - size:size]; + sourceOffset:0 + toBuffer:to_buf + destinationOffset:to_offset + size:size]; [encoder endEncoding]; [cb commit]; [cb waitUntilCompleted]; } else { memcpy(static_cast([to_buf contents]) + to_offset, - static_cast(from) + from_offset, - size); + static_cast(from) + from_offset, size); } } else { LOG(FATAL) << "Expect copy from/to Metal or between Metal" - << ", from=" << from_dev_type - << ", to=" << to_dev_type; + << ", from=" << from_dev_type << ", to=" << to_dev_type; } } @@ -259,9 +235,7 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { [cb waitUntilCompleted]; } -void* MetalWorkspace::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { +void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); } @@ -279,30 +253,25 @@ id MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) { if (temp_buffer_.size() <= static_cast(ctx.device_id)) { temp_buffer_.resize(ctx.device_id + 1, nil); } - if (temp_buffer_[ctx.device_id] == nil || - temp_buffer_[ctx.device_id].length < size) { + if (temp_buffer_[ctx.device_id] == nil || temp_buffer_[ctx.device_id].length < size) { id dev = MetalWorkspace::Global()->GetDevice(ctx); if (temp_buffer_[ctx.device_id] != nil) { [temp_buffer_[ctx.device_id] release]; } - temp_buffer_[ctx.device_id] = [ - [dev newBufferWithLength:size - options:MTLStorageModeShared] retain]; + temp_buffer_[ctx.device_id] = [[dev newBufferWithLength:size + options:MTLStorageModeShared] retain]; } return temp_buffer_[ctx.device_id]; } typedef dmlc::ThreadLocalStore MetalThreadStore; -MetalThreadEntry* MetalThreadEntry::ThreadLocal() { - return MetalThreadStore::Get(); -} +MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.metal") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = MetalWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = MetalWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace metal } // namespace runtime diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index 0d2d429..77cdf64 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,10 +25,12 @@ #define TVM_RUNTIME_METAL_METAL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -44,11 +46,8 @@ static constexpr const int kMetalMaxNumDevice = 32; * \param fmap The map function information map of each function. * \param source Optional, source file */ -Module MetalModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module MetalModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_METAL_METAL_MODULE_H_ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 41269b9..9bdebf3 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -20,18 +20,18 @@ /*! * \file metal_module.cc */ +#include "metal_module.h" #include -#include #include +#include #include -#include #include -#include "metal_module.h" -#include "metal_common.h" +#include +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "metal_common.h" namespace tvm { namespace runtime { @@ -39,27 +39,18 @@ namespace runtime { // Module to support thread-safe multi-GPU execution. // The runtime will contain a per-device module table // The modules will be lazily loaded -class MetalModuleNode final :public runtime::ModuleNode { +class MetalModuleNode final : public runtime::ModuleNode { public: - explicit MetalModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) - : data_(data), fmt_(fmt), fmap_(fmap), source_(source) { - } - const char* type_key() const final { - return "metal"; - } + explicit MetalModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) + : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} + const char* type_key() const final { return "metal"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -81,8 +72,7 @@ class MetalModuleNode final :public runtime::ModuleNode { } } // get a from primary context in device_id - id GetPipelineState( - size_t device_id, const std::string& func_name) { + id GetPipelineState(size_t device_id, const std::string& func_name) { metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get(); CHECK_LT(device_id, w->devices.size()); // start lock scope. @@ -97,53 +87,43 @@ class MetalModuleNode final :public runtime::ModuleNode { NSError* err_msg = nil; if (e.lib == nil) { if (fmt_ == "metal") { - MTLCompileOptions *opts = [MTLCompileOptions alloc]; + MTLCompileOptions* opts = [MTLCompileOptions alloc]; // Use the Metal 1.2 for now. opts.languageVersion = MTLLanguageVersion1_2; opts.fastMathEnabled = YES; // opts = nil; - e.lib = [ - w->devices[device_id] - newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] - options:opts - error:&err_msg]; + e.lib = [w->devices[device_id] + newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] + options:opts + error:&err_msg]; [opts dealloc]; if (e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" - << [[err_msg localizedDescription] UTF8String]; + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; } if (err_msg != nil) { - LOG(INFO) << "Warning: " - << [[err_msg localizedDescription] UTF8String]; + LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String]; } } else { // Build from library. auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); - auto data = dispatch_data_create( - data_.c_str(), data_.length(), q, ^{}); - e.lib = [ - w->devices[device_id] - newLibraryWithData:data - error:&err_msg]; + auto data = dispatch_data_create(data_.c_str(), data_.length(), q, + ^{ + }); + e.lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg]; if (err_msg != nil || e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" - << [[err_msg localizedDescription] UTF8String]; + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; } } [e.lib retain]; } - id f = [ - e.lib - newFunctionWithName: - [NSString stringWithUTF8String:func_name.c_str()]]; + id f = + [e.lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; CHECK(f != nil) << "cannot find function " << func_name; id state = - [w->devices[device_id] - newComputePipelineStateWithFunction:f - error:&err_msg]; - CHECK(state != nil) - << "cannot get state:" << " for function " << func_name - << [[err_msg localizedDescription] UTF8String]; + [w->devices[device_id] newComputePipelineStateWithFunction:f error:&err_msg]; + CHECK(state != nil) << "cannot get state:" + << " for function " << func_name + << [[err_msg localizedDescription] UTF8String]; // The state.threadExecutionWidth can change dynamically according // to the resource constraint in kernel, so it is not strictly hold // Turn of warp aware optimziation for now. @@ -162,7 +142,7 @@ class MetalModuleNode final :public runtime::ModuleNode { ~DeviceEntry() { if (lib != nil) [lib release]; - for (auto &&kv : smap) { + for (auto&& kv : smap) { [kv.second release]; } } @@ -185,11 +165,8 @@ class MetalModuleNode final :public runtime::ModuleNode { class MetalWrappedFunc { public: // initialize the METAL function. - void Init(MetalModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, - size_t num_buffer_args, - size_t num_pack_args, + void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_buffer_args, size_t num_pack_args, const std::vector& thread_axis_tags) { w_ = metal::MetalWorkspace::Global().get(); m_ = m; @@ -204,9 +181,7 @@ class MetalWrappedFunc { scache_[dev_id] = m->GetPipelineState(dev_id, func_name); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - const ArgUnion* pack_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->context.device_id; if (scache_[device_id] == nil) { @@ -223,16 +198,13 @@ class MetalWrappedFunc { } if (num_pack_args_ != 0) { [encoder setBytes:pack_args - length:num_pack_args_ * sizeof(ArgUnion) - atIndex:num_buffer_args_]; + length:num_pack_args_ * sizeof(ArgUnion) + atIndex:num_buffer_args_]; } // launch - MTLSize dimGrid = MTLSizeMake( - wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); - MTLSize dimBlock = MTLSizeMake( - wl.block_dim(0), wl.block_dim(1), wl.block_dim(2)); - [encoder dispatchThreadgroups: dimGrid - threadsPerThreadgroup: dimBlock]; + MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); + MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2)); + [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock]; [encoder endEncoding]; [cb commit]; } @@ -257,36 +229,29 @@ class MetalWrappedFunc { ThreadAxisConfig thread_axis_cfg_; }; -PackedFunc MetalModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc MetalModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; MetalWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); - f.Init(this, sptr_to_self, name, - num_buffer_args, info.arg_types.size() - num_buffer_args, + f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, info.thread_axis_tags); return PackFuncNonBufferArg(f, info.arg_types); } -Module MetalModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module MetalModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { metal::MetalWorkspace::Global()->Init(); auto n = make_object(data, fmt, fmap, source); return Module(n); } // Load module from module. -Module MetalModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module MetalModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -307,10 +272,8 @@ Module MetalModuleLoadBinary(void* strm) { return MetalModuleCreate(data, fmt, fmap, ""); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal") -.set_body_typed(MetalModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal").set_body_typed(MetalModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal") -.set_body_typed(MetalModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/host_driven/utvm_runtime.h b/src/runtime/micro/host_driven/utvm_runtime.h index fc11b70..1a4486c 100644 --- a/src/runtime/micro/host_driven/utvm_runtime.h +++ b/src/runtime/micro/host_driven/utvm_runtime.h @@ -29,8 +29,8 @@ extern "C" { #endif #include -#include #include +#include /*! * \brief TODO @@ -98,9 +98,9 @@ void UTVMDone(); // GCC -O3 begins to inject memset and memmove calls, so we provide impls in // the runtime for this case and for general usage. -void *memset(void *s, int c, size_t n); +void* memset(void* s, int c, size_t n); -void *memmove(void *to, const void *from, size_t n); +void* memmove(void* to, const void* from, size_t n); #ifdef __cplusplus } // TVM_EXTERN_C diff --git a/src/runtime/micro/host_low_level_device.cc b/src/runtime/micro/host_low_level_device.cc index da4ade4..7c3e7a2 100644 --- a/src/runtime/micro/host_low_level_device.cc +++ b/src/runtime/micro/host_low_level_device.cc @@ -23,10 +23,12 @@ */ #include + #include #include -#include "micro_common.h" + #include "low_level_device.h" +#include "micro_common.h" namespace tvm { namespace runtime { @@ -50,16 +52,14 @@ class HostLowLevelDevice final : public LowLevelDevice { int mmap_prot = PROT_READ | PROT_WRITE | PROT_EXEC; int mmap_flags = MAP_ANONYMOUS | MAP_PRIVATE; base_addr_ = mmap(nullptr, size_in_pages * kPageSize, mmap_prot, mmap_flags, -1, 0); - *base_addr = TargetPtr(TargetWordSize(sizeof(size_t) * 8), - reinterpret_cast(base_addr_)); + *base_addr = + TargetPtr(TargetWordSize(sizeof(size_t) * 8), reinterpret_cast(base_addr_)); } /*! * \brief destructor to deallocate on-host device region */ - virtual ~HostLowLevelDevice() { - munmap(base_addr_, size_); - } + virtual ~HostLowLevelDevice() { munmap(base_addr_, size_); } void Read(TargetPtr addr, void* buf, size_t num_bytes) { std::memcpy(buf, addr.cast_to(), num_bytes); @@ -73,9 +73,7 @@ class HostLowLevelDevice final : public LowLevelDevice { reinterpret_cast(func_addr.value().uint64())(); } - const char* device_type() const final { - return "host"; - } + const char* device_type() const final { return "host"; } private: /*! \brief base address of the micro device memory region */ diff --git a/src/runtime/micro/low_level_device.h b/src/runtime/micro/low_level_device.h index c5b5f3d..6cc0e1d 100644 --- a/src/runtime/micro/low_level_device.h +++ b/src/runtime/micro/low_level_device.h @@ -45,9 +45,7 @@ class LowLevelDevice { * \param buffer on-host buffer to be read into * \param num_bytes number of bytes to read */ - virtual void Read(TargetPtr addr, - void* buffer, - size_t num_bytes) = 0; + virtual void Read(TargetPtr addr, void* buffer, size_t num_bytes) = 0; /*! * \brief writes num_bytes from buffer to device memory at addr @@ -55,9 +53,7 @@ class LowLevelDevice { * \param buffer host buffer to write from * \param num_bytes number of bytes to write */ - virtual void Write(TargetPtr addr, - const void* buffer, - size_t num_bytes) = 0; + virtual void Write(TargetPtr addr, const void* buffer, size_t num_bytes) = 0; /*! * \brief starts execution of device at func_addr diff --git a/src/runtime/micro/micro_common.cc b/src/runtime/micro/micro_common.cc index c544fcd..020df62 100644 --- a/src/runtime/micro/micro_common.cc +++ b/src/runtime/micro/micro_common.cc @@ -22,65 +22,65 @@ * \brief common utilties for uTVM */ +#include "micro_common.h" + #include #include + +#include #include -#include #include -#include -#include "micro_session.h" -#include "micro_common.h" +#include + #include "low_level_device.h" +#include "micro_session.h" namespace tvm { namespace runtime { const char* SectionToString(SectionKind section) { switch (section) { - case SectionKind::kText: return "text"; - case SectionKind::kRodata: return "rodata"; - case SectionKind::kData: return "data"; - case SectionKind::kBss: return "bss"; - case SectionKind::kArgs: return "args"; - case SectionKind::kHeap: return "heap"; - case SectionKind::kWorkspace: return "workspace"; - case SectionKind::kStack: return "stack"; - default: return ""; + case SectionKind::kText: + return "text"; + case SectionKind::kRodata: + return "rodata"; + case SectionKind::kData: + return "data"; + case SectionKind::kBss: + return "bss"; + case SectionKind::kArgs: + return "args"; + case SectionKind::kHeap: + return "heap"; + case SectionKind::kWorkspace: + return "workspace"; + case SectionKind::kStack: + return "stack"; + default: + return ""; } } -std::string RelocateBinarySections( - const std::string& binary_path, - TargetWordSize word_size, - TargetPtr text_start, - TargetPtr rodata_start, - TargetPtr data_start, - TargetPtr bss_start, - TargetPtr stack_end, - const std::string& toolchain_prefix) { +std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size, + TargetPtr text_start, TargetPtr rodata_start, + TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end, + const std::string& toolchain_prefix) { const auto* f = Registry::Get("tvm_callback_relocate_binary"); - CHECK(f != nullptr) - << "Require tvm_callback_relocate_binary to exist in registry"; - std::string relocated_bin = (*f)(binary_path, - word_size.bytes(), - text_start.cast_to(), - rodata_start.cast_to(), - data_start.cast_to(), - bss_start.cast_to(), - stack_end.cast_to(), - toolchain_prefix); + CHECK(f != nullptr) << "Require tvm_callback_relocate_binary to exist in registry"; + std::string relocated_bin = + (*f)(binary_path, word_size.bytes(), text_start.cast_to(), + rodata_start.cast_to(), data_start.cast_to(), + bss_start.cast_to(), stack_end.cast_to(), toolchain_prefix); return relocated_bin; } -std::string ReadSection(const std::string& binary, - SectionKind section, +std::string ReadSection(const std::string& binary, SectionKind section, const std::string& toolchain_prefix) { CHECK(section == SectionKind::kText || section == SectionKind::kRodata || section == SectionKind::kData || section == SectionKind::kBss) << "ReadSection requires section to be one of text, rodata, data, or bss."; const auto* f = Registry::Get("tvm_callback_read_binary_section"); - CHECK(f != nullptr) - << "Require tvm_callback_read_binary_section to exist in registry"; + CHECK(f != nullptr) << "Require tvm_callback_read_binary_section to exist in registry"; TVMByteArray arr; arr.data = &binary[0]; arr.size = binary.length(); @@ -88,16 +88,13 @@ std::string ReadSection(const std::string& binary, return section_contents; } -size_t GetSectionSize(const std::string& binary_path, - SectionKind section, - const std::string& toolchain_prefix, - TargetWordSize word_size) { +size_t GetSectionSize(const std::string& binary_path, SectionKind section, + const std::string& toolchain_prefix, TargetWordSize word_size) { CHECK(section == SectionKind::kText || section == SectionKind::kRodata || section == SectionKind::kData || section == SectionKind::kBss) << "GetSectionSize requires section to be one of text, rodata, data, or bss."; const auto* f = Registry::Get("tvm_callback_get_section_size"); - CHECK(f != nullptr) - << "Require tvm_callback_get_section_size to exist in registry"; + CHECK(f != nullptr) << "Require tvm_callback_get_section_size to exist in registry"; int size = (*f)(binary_path, SectionToString(section), toolchain_prefix); return UpperAlignValue(size, word_size.bytes()); } diff --git a/src/runtime/micro/micro_common.h b/src/runtime/micro/micro_common.h index 2d74bc3..4375791 100644 --- a/src/runtime/micro/micro_common.h +++ b/src/runtime/micro/micro_common.h @@ -24,7 +24,6 @@ #define TVM_RUNTIME_MICRO_MICRO_COMMON_H_ #include - #include #include @@ -58,22 +57,17 @@ class TargetWordSize { public: explicit TargetWordSize(size_t word_size_bits) : word_size_bits_{word_size_bits} { CHECK(word_size_bits == 32 || word_size_bits == 64) - << "only 32-bit and 64-bit are supported now"; + << "only 32-bit and 64-bit are supported now"; } - size_t bytes() const { - return word_size_bits_ / 8; - } + size_t bytes() const { return word_size_bits_ / 8; } - size_t bits() const { - return word_size_bits_; - } + size_t bits() const { return word_size_bits_; } private: size_t word_size_bits_; }; - /*! \brief class for storing values on varying target word sizes */ class TargetVal { private: @@ -82,7 +76,7 @@ class TargetVal { public: /*! \brief construct a TargetVal matching the size of the given integral argument */ - template::value, T>::type> + template ::value, T>::type> explicit constexpr TargetVal(T value) : TargetVal(sizeof(T) * 8, value) {} /*! \brief construct an uninitialized value */ @@ -90,10 +84,8 @@ class TargetVal { /*! \brief construct a TargetVal with explicit size and value */ TargetVal(size_t width_bits, uint64_t value) : width_bits_{width_bits} { - CHECK(width_bits >= 8 && - width_bits <= 64 && - (width_bits & (width_bits - 1)) == 0) - << "width_bits must be a power of 2 in [8, 64], got " << width_bits; + CHECK(width_bits >= 8 && width_bits <= 64 && (width_bits & (width_bits - 1)) == 0) + << "width_bits must be a power of 2 in [8, 64], got " << width_bits; value_ = value & Bitmask(); } @@ -134,8 +126,8 @@ class TargetVal { } CHECK(width_bits_ >= other.width_bits_) - << "Cannot assign TargetVal with width " << other.width_bits_ - << "bits to TargetVal with width " << width_bits_ << "bits"; + << "Cannot assign TargetVal with width " << other.width_bits_ + << "bits to TargetVal with width " << width_bits_ << "bits"; value_ = other.value_ & Bitmask(); return *this; @@ -147,12 +139,12 @@ class TargetVal { class TargetPtr { public: /*! \brief construct a device address with variable-length value `value` */ - TargetPtr(TargetWordSize word_size, std::uint64_t value) : - value_(TargetVal(word_size.bits(), value)) {} + TargetPtr(TargetWordSize word_size, std::uint64_t value) + : value_(TargetVal(word_size.bits(), value)) {} /*! \brief construct a null address */ - TargetPtr(TargetWordSize word_size, std::nullptr_t value) : - value_{TargetVal(word_size.bits(), 0)} {} + TargetPtr(TargetWordSize word_size, std::nullptr_t value) + : value_{TargetVal(word_size.bits(), 0)} {} /*! \brief construct an uninitialized pointer whose word_size can be changed once */ TargetPtr() = default; @@ -174,7 +166,9 @@ class TargetPtr { * \return casted result */ template - T cast_to() const { return reinterpret_cast(value_.uint64()); } + T cast_to() const { + return reinterpret_cast(value_.uint64()); + } /*! \brief check if location is null */ bool operator==(std::nullptr_t) const { return value_.uint64() == 0; } @@ -224,8 +218,7 @@ class SymbolMap { * \param binary contents of binary object file * \param toolchain_prefix prefix of compiler toolchain to use */ - SymbolMap(const std::string& binary, - const std::string& toolchain_prefix, + SymbolMap(const std::string& binary, const std::string& toolchain_prefix, TargetWordSize word_size) { const auto* f = Registry::Get("tvm_callback_get_symbol_map"); CHECK(f != nullptr) << "require tvm_callback_get_symbol_map to exist in registry"; @@ -258,9 +251,7 @@ class SymbolMap { return result->second; } - bool HasSymbol(const std::string& name) const { - return map_.find(name) != map_.end(); - } + bool HasSymbol(const std::string& name) const { return map_.find(name) != map_.end(); } void Dump(std::ostream& stream) const { for (auto e : map_) { @@ -332,15 +323,10 @@ const char* SectionToString(SectionKind section); * \param toolchain_prefix prefix of compiler toolchain to use * \return relocated binary file contents */ -std::string RelocateBinarySections( - const std::string& binary_path, - TargetWordSize word_size, - TargetPtr text_start, - TargetPtr rodata_start, - TargetPtr data_start, - TargetPtr bss_start, - TargetPtr stack_end, - const std::string& toolchain_prefix); +std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size, + TargetPtr text_start, TargetPtr rodata_start, + TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end, + const std::string& toolchain_prefix); /*! * \brief reads section from binary @@ -349,8 +335,7 @@ std::string RelocateBinarySections( * \param toolchain_prefix prefix of compiler toolchain to use * \return contents of the section */ -std::string ReadSection(const std::string& binary, - SectionKind section, +std::string ReadSection(const std::string& binary, SectionKind section, const std::string& toolchain_prefix); /*! @@ -361,10 +346,8 @@ std::string ReadSection(const std::string& binary, * \param word_size word size of the target, for alignment * \return size of the section if it exists, 0 otherwise */ -size_t GetSectionSize(const std::string& binary_name, - SectionKind section, - const std::string& toolchain_prefix, - TargetWordSize word_size); +size_t GetSectionSize(const std::string& binary_name, SectionKind section, + const std::string& toolchain_prefix, TargetWordSize word_size); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_device_api.cc b/src/runtime/micro/micro_device_api.cc index 77ad865..6848078 100644 --- a/src/runtime/micro/micro_device_api.cc +++ b/src/runtime/micro/micro_device_api.cc @@ -21,9 +21,10 @@ * \file micro_device_api.cc */ -#include -#include #include +#include +#include + #include "../workspace_pool.h" #include "micro_session.h" @@ -35,7 +36,7 @@ namespace runtime { class MicroDeviceAPI final : public DeviceAPI { public: /*! \brief constructor */ - MicroDeviceAPI() { } + MicroDeviceAPI() {} void SetDevice(TVMContext ctx) final {} @@ -45,9 +46,7 @@ class MicroDeviceAPI final : public DeviceAPI { } } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { ObjectPtr& session = MicroSession::Current(); TargetPtr data = session->AllocateInSection(SectionKind::kHeap, nbytes); @@ -61,14 +60,8 @@ class MicroDeviceAPI final : public DeviceAPI { delete dev_space; } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { std::tuple type_from_to(ctx_from.device_type, ctx_to.device_type); if (type_from_to == std::make_tuple(kDLMicroDev, kDLMicroDev)) { @@ -76,11 +69,10 @@ class MicroDeviceAPI final : public DeviceAPI { MicroDevSpace* from_space = static_cast(const_cast(from)); MicroDevSpace* to_space = static_cast(const_cast(to)); CHECK(from_space->session == to_space->session) - << "attempt to copy data between different micro sessions (" - << from_space->session.get() + << "attempt to copy data between different micro sessions (" << from_space->session.get() << " != " << to_space->session.get() << ")"; CHECK(ctx_from.device_id == ctx_to.device_id) - << "can only copy between the same micro device"; + << "can only copy between the same micro device"; ObjectPtr& session = from_space->session; // flush all pending tasks to ensure data is consistent session->FlushTaskQueue(); @@ -132,7 +124,7 @@ class MicroDeviceAPI final : public DeviceAPI { TargetPtr data = session->AllocateInSection(SectionKind::kWorkspace, size); CHECK(data.value().uint64() != 0) - << "unable to allocate " << size << " bytes on device workspace"; + << "unable to allocate " << size << " bytes on device workspace"; return static_cast(new MicroDevSpace{data, session}); } @@ -154,9 +146,7 @@ class MicroDeviceAPI final : public DeviceAPI { } private: - TargetPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) { - return dev_space->data + offset; - } + TargetPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) { return dev_space->data + offset; } void* GetHostLoc(const void* ptr, size_t offset) { return reinterpret_cast(reinterpret_cast(ptr) + offset); @@ -164,10 +154,9 @@ class MicroDeviceAPI final : public DeviceAPI { }; // register device that can be obtained from Python frontend -TVM_REGISTER_GLOBAL("device_api.micro_dev") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = MicroDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = MicroDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_module.cc b/src/runtime/micro/micro_module.cc index 01056de..b4770ec 100644 --- a/src/runtime/micro/micro_module.cc +++ b/src/runtime/micro/micro_module.cc @@ -21,15 +21,17 @@ * \file micro_module.cc */ -#include #include #include -#include +#include + #include -#include "micro_session.h" +#include + +#include "../pack_args.h" #include "low_level_device.h" #include "micro_common.h" -#include "../pack_args.h" +#include "micro_session.h" namespace tvm { namespace runtime { @@ -42,12 +44,9 @@ class MicroModuleNode final : public ModuleNode { ~MicroModuleNode() {} - const char* type_key() const final { - return "micro"; - } + const char* type_key() const final { return "micro"; } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; /*! * \brief initializes module by establishing device connection and loads binary @@ -68,8 +67,7 @@ class MicroModuleNode final : public ModuleNode { class MicroWrappedFunc { public: - MicroWrappedFunc(ObjectPtr session, - TargetPtr func_ptr) { + MicroWrappedFunc(ObjectPtr session, TargetPtr func_ptr) { session_ = session; func_ptr_ = func_ptr; } @@ -85,9 +83,8 @@ class MicroWrappedFunc { TargetPtr func_ptr_; }; -PackedFunc MicroModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc MicroModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { TargetPtr func_ptr; if (name == tvm::runtime::symbol::tvm_module_main) { if (symbol_map_.HasSymbol(tvm::runtime::symbol::tvm_module_main)) { @@ -104,10 +101,10 @@ PackedFunc MicroModuleNode::GetFunction( // register loadfile function to load module from Python frontend TVM_REGISTER_GLOBAL("runtime.module.loadfile_micro_dev") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->InitMicroModule(args[0]); - *rv = runtime::Module(n); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + n->InitMicroModule(args[0]); + *rv = runtime::Module(n); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_section_allocator.h b/src/runtime/micro/micro_section_allocator.h index 2067794..5cafb41 100644 --- a/src/runtime/micro/micro_section_allocator.h +++ b/src/runtime/micro/micro_section_allocator.h @@ -25,6 +25,7 @@ #include #include + #include "micro_common.h" namespace tvm { @@ -39,19 +40,18 @@ class MicroSectionAllocator { * \brief constructor that specifies section boundaries * \param region location and size of the section on the device */ - explicit MicroSectionAllocator(std::string section_name, - DevMemRegion region, + explicit MicroSectionAllocator(std::string section_name, DevMemRegion region, TargetWordSize word_size) - : section_name_(section_name), - start_addr_(region.start), - size_(0), - capacity_(region.size), - word_size_(word_size) { - CHECK_EQ(start_addr_.value().uint64() % word_size.bytes(), 0) + : section_name_(section_name), + start_addr_(region.start), + size_(0), + capacity_(region.size), + word_size_(word_size) { + CHECK_EQ(start_addr_.value().uint64() % word_size.bytes(), 0) << "micro section start not aligned to " << word_size.bytes() << " bytes"; - CHECK_EQ(capacity_ % word_size.bytes(), 0) + CHECK_EQ(capacity_ % word_size.bytes(), 0) << "micro section end not aligned to " << word_size.bytes() << " bytes"; - } + } /*! * \brief destructor @@ -66,9 +66,9 @@ class MicroSectionAllocator { TargetPtr Allocate(size_t size) { size_ = UpperAlignValue(size_, word_size_.bytes()); CHECK(size_ + size < capacity_) - << "cannot alloc " << size << " bytes in section \"" - << section_name_ << "\" (start_addr=" << start_addr_.cast_to() - << ", used=" << size_ << ", capacity=" << capacity_ << ")"; + << "cannot alloc " << size << " bytes in section \"" << section_name_ + << "\" (start_addr=" << start_addr_.cast_to() << ", used=" << size_ + << ", capacity=" << capacity_ << ")"; TargetPtr alloc_addr = start_addr_ + size_; size_ += size; alloc_map_[alloc_addr.value().uint64()] = size; @@ -82,7 +82,7 @@ class MicroSectionAllocator { */ void Free(TargetPtr addr) { CHECK(alloc_map_.find(addr.value().uint64()) != alloc_map_.end()) - << "freed pointer was never allocated"; + << "freed pointer was never allocated"; alloc_map_.erase(addr.value().uint64()); if (alloc_map_.empty()) { size_ = 0; diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index 0e8e169..a9efa0f 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -21,16 +21,19 @@ * \file micro_session.cc */ +#include "micro_session.h" + #include -#include #include +#include + #include -#include #include +#include #include #include #include -#include "micro_session.h" + #include "low_level_device.h" #include "target_data_layout_encoder.h" @@ -44,47 +47,31 @@ struct TVMMicroSessionThreadLocalEntry { typedef dmlc::ThreadLocalStore TVMMicroSessionThreadLocalStore; ObjectPtr& MicroSession::Current() { - TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); + TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); CHECK_GT(entry->session_stack.size(), 0) << "No current session"; return entry->session_stack.top(); } void MicroSession::EnterWithScope(ObjectPtr session) { - TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); + TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); entry->session_stack.push(session); } void MicroSession::ExitWithScope() { - TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); + TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); CHECK(!entry->session_stack.empty()); entry->session_stack.pop(); } -MicroSession::MicroSession( - const std::string& comms_method, - const std::string& binary_path, - const std::string& toolchain_prefix, - uint64_t text_start, - size_t text_size, - uint64_t rodata_start, - size_t rodata_size, - uint64_t data_start, - size_t data_size, - uint64_t bss_start, - size_t bss_size, - uint64_t args_start, - size_t args_size, - uint64_t heap_start, - size_t heap_size, - uint64_t workspace_start, - size_t workspace_size, - uint64_t stack_start, - size_t stack_size, - TargetWordSize word_size, - bool thumb_mode, - bool use_device_timer, - const std::string& server_addr, - int port) +MicroSession::MicroSession(const std::string& comms_method, const std::string& binary_path, + const std::string& toolchain_prefix, uint64_t text_start, + size_t text_size, uint64_t rodata_start, size_t rodata_size, + uint64_t data_start, size_t data_size, uint64_t bss_start, + size_t bss_size, uint64_t args_start, size_t args_size, + uint64_t heap_start, size_t heap_size, uint64_t workspace_start, + size_t workspace_size, uint64_t stack_start, size_t stack_size, + TargetWordSize word_size, bool thumb_mode, bool use_device_timer, + const std::string& server_addr, int port) : toolchain_prefix_(toolchain_prefix), word_size_(word_size), thumb_mode_(thumb_mode), @@ -92,130 +79,131 @@ MicroSession::MicroSession( batch_args_encoder_(args_size, word_size) { if (comms_method == "host") { // TODO(weberlo): move checks to python - CHECK( - text_start == 0 && - rodata_start == 0 && - data_start == 0 && - bss_start == 0 && - args_start == 0 && - heap_start == 0 && - workspace_start == 0 && - stack_start == 0) << "unable to specify section addresses for host device"; - size_t memory_size = - text_size + rodata_size + data_size + bss_size + - args_size + heap_size + workspace_size + stack_size; + CHECK(text_start == 0 && rodata_start == 0 && data_start == 0 && bss_start == 0 && + args_start == 0 && heap_start == 0 && workspace_start == 0 && stack_start == 0) + << "unable to specify section addresses for host device"; + size_t memory_size = text_size + rodata_size + data_size + bss_size + args_size + heap_size + + workspace_size + stack_size; TargetPtr base_addr; low_level_device_ = HostLowLevelDeviceCreate(memory_size, &base_addr); CHECK_EQ(base_addr.value().uint64() % word_size.bytes(), 0) - << "base address not aligned to " << word_size.bytes() << " bytes"; + << "base address not aligned to " << word_size.bytes() << " bytes"; TargetPtr curr_addr = base_addr; - section_allocators_[0] = std::make_shared( - "text", - DevMemRegion { - .start = curr_addr, - .size = text_size, - }, word_size_); + section_allocators_[0] = std::make_shared("text", + DevMemRegion{ + .start = curr_addr, + .size = text_size, + }, + word_size_); curr_addr += text_size; - section_allocators_[1] = std::make_shared( - "rodata", - DevMemRegion { - .start = curr_addr, - .size = rodata_size, - }, word_size_); + section_allocators_[1] = std::make_shared("rodata", + DevMemRegion{ + .start = curr_addr, + .size = rodata_size, + }, + word_size_); curr_addr += rodata_size; - section_allocators_[2] = std::make_shared( - "data", - DevMemRegion { - .start = curr_addr, - .size = data_size, - }, word_size_); + section_allocators_[2] = std::make_shared("data", + DevMemRegion{ + .start = curr_addr, + .size = data_size, + }, + word_size_); curr_addr += data_size; - section_allocators_[3] = std::make_shared( - "bss", - DevMemRegion { - .start = curr_addr, - .size = bss_size, - }, word_size_); + section_allocators_[3] = std::make_shared("bss", + DevMemRegion{ + .start = curr_addr, + .size = bss_size, + }, + word_size_); curr_addr += bss_size; - section_allocators_[4] = std::make_shared( - "args", - DevMemRegion { - .start = curr_addr, - .size = args_size, - }, word_size_); + section_allocators_[4] = std::make_shared("args", + DevMemRegion{ + .start = curr_addr, + .size = args_size, + }, + word_size_); curr_addr += args_size; - section_allocators_[5] = std::make_shared( - "heap", - DevMemRegion { - .start = curr_addr, - .size = heap_size, - }, word_size_); + section_allocators_[5] = std::make_shared("heap", + DevMemRegion{ + .start = curr_addr, + .size = heap_size, + }, + word_size_); curr_addr += heap_size; - section_allocators_[6] = std::make_shared( - "workspace", - DevMemRegion { - .start = curr_addr, - .size = workspace_size, - }, word_size_); + section_allocators_[6] = std::make_shared("workspace", + DevMemRegion{ + .start = curr_addr, + .size = workspace_size, + }, + word_size_); curr_addr += workspace_size; - section_allocators_[7] = std::make_shared( - "stack", - DevMemRegion { - .start = curr_addr, - .size = stack_size, - }, word_size_); + section_allocators_[7] = std::make_shared("stack", + DevMemRegion{ + .start = curr_addr, + .size = stack_size, + }, + word_size_); curr_addr += stack_size; } else if (comms_method == "openocd") { low_level_device_ = OpenOCDLowLevelDeviceCreate(server_addr, port); - section_allocators_[0] = std::make_shared( - "text", - DevMemRegion { - .start = TargetPtr(word_size_, text_start), - .size = text_size, - }, word_size_); - section_allocators_[1] = std::make_shared( - "rodata", - DevMemRegion { - .start = TargetPtr(word_size_, rodata_start), - .size = rodata_size, - }, word_size_); - section_allocators_[2] = std::make_shared( - "data", - DevMemRegion { - .start = TargetPtr(word_size_, data_start), - .size = data_size, - }, word_size_); - section_allocators_[3] = std::make_shared( - "bss", - DevMemRegion { - .start = TargetPtr(word_size_, bss_start), - .size = bss_size, - }, word_size_); - section_allocators_[4] = std::make_shared( - "args", - DevMemRegion { - .start = TargetPtr(word_size_, args_start), - .size = args_size, - }, word_size_); - section_allocators_[5] = std::make_shared( - "heap", - DevMemRegion { - .start = TargetPtr(word_size_, heap_start), - .size = heap_size, - }, word_size_); - section_allocators_[6] = std::make_shared( - "workspace", - DevMemRegion { - .start = TargetPtr(word_size_, workspace_start), - .size = workspace_size, - }, word_size_); - section_allocators_[7] = std::make_shared( - "stack", - DevMemRegion { - .start = TargetPtr(word_size_, stack_start), - .size = stack_size, - }, word_size_); + section_allocators_[0] = + std::make_shared("text", + DevMemRegion{ + .start = TargetPtr(word_size_, text_start), + .size = text_size, + }, + word_size_); + section_allocators_[1] = + std::make_shared("rodata", + DevMemRegion{ + .start = TargetPtr(word_size_, rodata_start), + .size = rodata_size, + }, + word_size_); + section_allocators_[2] = + std::make_shared("data", + DevMemRegion{ + .start = TargetPtr(word_size_, data_start), + .size = data_size, + }, + word_size_); + section_allocators_[3] = + std::make_shared("bss", + DevMemRegion{ + .start = TargetPtr(word_size_, bss_start), + .size = bss_size, + }, + word_size_); + section_allocators_[4] = + std::make_shared("args", + DevMemRegion{ + .start = TargetPtr(word_size_, args_start), + .size = args_size, + }, + word_size_); + section_allocators_[5] = + std::make_shared("heap", + DevMemRegion{ + .start = TargetPtr(word_size_, heap_start), + .size = heap_size, + }, + word_size_); + section_allocators_[6] = + std::make_shared("workspace", + DevMemRegion{ + .start = TargetPtr(word_size_, workspace_start), + .size = workspace_size, + }, + word_size_); + section_allocators_[7] = + std::make_shared("stack", + DevMemRegion{ + .start = TargetPtr(word_size_, stack_start), + .size = stack_size, + }, + word_size_); } else { LOG(FATAL) << "unsupported micro low-level device"; } @@ -257,13 +245,10 @@ void MicroSession::PushToTaskQueue(TargetPtr func_ptr, const TVMArgs& args) { TargetVal arg_values_dev_addr{std::get<0>(arg_field_addrs).value()}; TargetVal arg_type_codes_dev_addr{std::get<1>(arg_field_addrs).value()}; - task_queue_.push_back( - DevTask { - .func = func_dev_addr, - .arg_values = arg_values_dev_addr, - .arg_type_codes = arg_type_codes_dev_addr, - .num_args = args.num_args - }); + task_queue_.push_back(DevTask{.func = func_dev_addr, + .arg_values = arg_values_dev_addr, + .arg_type_codes = arg_type_codes_dev_addr, + .num_args = args.num_args}); if (task_queue_.size() == MicroSession::kTaskQueueCapacity) { FlushTaskQueue(); @@ -290,17 +275,14 @@ void MicroSession::FlushTaskQueuePriv() { } // Flush `args` to device memory. - low_level_device()->Write( - batch_args_encoder_.start_addr(), - reinterpret_cast(batch_args_encoder_.data()), - batch_args_encoder_.buf_size()); + low_level_device()->Write(batch_args_encoder_.start_addr(), + reinterpret_cast(batch_args_encoder_.data()), + batch_args_encoder_.buf_size()); // Flush `tasks` to device memory. TargetPtr dev_tasks_addr = runtime_symbol_map_["utvm_tasks"]; - low_level_device()->Write( - dev_tasks_addr, - reinterpret_cast(prepped_tasks.data()), - prepped_tasks.size() * sizeof(T)); + low_level_device()->Write(dev_tasks_addr, reinterpret_cast(prepped_tasks.data()), + prepped_tasks.size() * sizeof(T)); DevSymbolWrite(runtime_symbol_map_, "utvm_num_tasks", prepped_tasks.size()); TargetPtr utvm_init_addr = runtime_symbol_map_["UTVMInit"]; @@ -310,8 +292,8 @@ void MicroSession::FlushTaskQueuePriv() { utvm_init_addr += 1; } - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; + std::chrono::time_point tbegin, + tend; tbegin = std::chrono::high_resolution_clock::now(); // std::string tmp; // while (tmp[0] != 'd' && tmp[0] != 'e') { @@ -335,8 +317,7 @@ void MicroSession::FlushTaskQueuePriv() { uint64_t sum = 0; std::vector times; times.resize(task_queue_.size()); - low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], - times.data(), + low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(), task_queue_.size() * sizeof(uint32_t)); int i = 0; for (uint32_t time : times) { @@ -345,14 +326,13 @@ void MicroSession::FlushTaskQueuePriv() { } last_batch_time_ += static_cast(sum) / 1e3; } else { - last_batch_time_ += std::chrono::duration_cast > - (tend - tbegin).count() * 1000; + last_batch_time_ += + std::chrono::duration_cast>(tend - tbegin).count() * 1000; // TODO(weberlo): Reading internal data structure is hacky. uint64_t sum = 0; std::vector times; times.resize(task_queue_.size()); - low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], - times.data(), + low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(), task_queue_.size() * sizeof(uint32_t)); for (uint32_t time : times) { sum += time; @@ -370,14 +350,13 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d DevMemRegion data_section; DevMemRegion bss_section; - text_section.size = GetSectionSize( - binary_path, SectionKind::kText, toolchain_prefix_, word_size_); - rodata_section.size = GetSectionSize( - binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_); - data_section.size = GetSectionSize( - binary_path, SectionKind::kData, toolchain_prefix_, word_size_); - bss_section.size = GetSectionSize( - binary_path, SectionKind::kBss, toolchain_prefix_, word_size_); + text_section.size = + GetSectionSize(binary_path, SectionKind::kText, toolchain_prefix_, word_size_); + rodata_section.size = + GetSectionSize(binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_); + data_section.size = + GetSectionSize(binary_path, SectionKind::kData, toolchain_prefix_, word_size_); + bss_section.size = GetSectionSize(binary_path, SectionKind::kBss, toolchain_prefix_, word_size_); text_section.start = AllocateInSection(SectionKind::kText, text_section.size); rodata_section.start = AllocateInSection(SectionKind::kRodata, rodata_section.size); @@ -385,14 +364,8 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d bss_section.start = AllocateInSection(SectionKind::kBss, bss_section.size); std::string relocated_bin = RelocateBinarySections( - binary_path, - word_size_, - text_section.start, - rodata_section.start, - data_section.start, - bss_section.start, - GetAllocator(SectionKind::kStack)->max_addr(), - toolchain_prefix_); + binary_path, word_size_, text_section.start, rodata_section.start, data_section.start, + bss_section.start, GetAllocator(SectionKind::kStack)->max_addr(), toolchain_prefix_); std::string text_contents = ReadSection(relocated_bin, SectionKind::kText, toolchain_prefix_); std::string rodata_contents = ReadSection(relocated_bin, SectionKind::kRodata, toolchain_prefix_); std::string data_contents = ReadSection(relocated_bin, SectionKind::kData, toolchain_prefix_); @@ -402,7 +375,7 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d low_level_device_->Write(rodata_section.start, &rodata_contents[0], rodata_section.size); low_level_device_->Write(data_section.start, &data_contents[0], data_section.size); low_level_device_->Write(bss_section.start, &bss_contents[0], bss_section.size); - SymbolMap symbol_map {relocated_bin, toolchain_prefix_, word_size_}; + SymbolMap symbol_map{relocated_bin, toolchain_prefix_, word_size_}; if (patch_dylib_pointers) { // Patch device lib pointers. @@ -411,7 +384,7 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d PatchImplHole(symbol_map, "TVMAPISetLastError"); } - return BinaryInfo { + return BinaryInfo{ .text_section = text_section, .rodata_section = rodata_section, .data_section = data_section, @@ -420,8 +393,8 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d }; } -std::tuple MicroSession::EncoderAppend( - TargetDataLayoutEncoder* encoder, const TVMArgs& args) { +std::tuple MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, + const TVMArgs& args) { const int* type_codes = args.type_codes; int num_args = args.num_args; @@ -485,16 +458,11 @@ TargetPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DL strides_dev_addr = stride_slot.start_addr(); } - T dev_arr( - TargetVal { word_size_.bits(), reinterpret_cast(arr.data) }, - arr.ctx, - arr.ndim, - arr.dtype, - shape_dev_addr.value(), - strides_dev_addr.value(), - TargetVal { word_size_.bits(), arr.byte_offset }); + T dev_arr(TargetVal{word_size_.bits(), reinterpret_cast(arr.data)}, arr.ctx, arr.ndim, + arr.dtype, shape_dev_addr.value(), strides_dev_addr.value(), + TargetVal{word_size_.bits(), arr.byte_offset}); CHECK(dev_arr.ctx.device_type == static_cast(kDLMicroDev)) - << "attempt to write DLTensor with non-micro device type"; + << "attempt to write DLTensor with non-micro device type"; // Update the device type to CPU, because from the microcontroller's // perspective, it is. dev_arr.ctx.device_type = DLDeviceType::kDLCPU; @@ -509,8 +477,7 @@ void MicroSession::CheckDeviceError() { if (last_error) { if (!use_device_timer_ && - (last_error == UTVM_ERR_TIMER_OVERFLOW || - last_error == UTVM_ERR_TIMER_NOT_IMPLEMENTED)) { + (last_error == UTVM_ERR_TIMER_OVERFLOW || last_error == UTVM_ERR_TIMER_NOT_IMPLEMENTED)) { // these errors don't matter if we're not using the on-device timer return; } @@ -599,8 +566,7 @@ T MicroSession::DevSymbolRead(const SymbolMap& symbol_map, const std::string& sy return result; } -void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, - const std::string& symbol, +void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const TargetPtr& ptr) { if (word_size_.bytes() == 4) { DevSymbolWrite(symbol_map, symbol, ptr.value().uint32()); @@ -612,54 +578,48 @@ void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, } template -void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, - const std::string& symbol, +void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const T& value) { TargetPtr sym_addr = symbol_map[symbol]; low_level_device()->Write(sym_addr, &value, sizeof(T)); } -PackedFunc MicroSession::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc MicroSession::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { if (name == "enter") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { MicroSession::EnterWithScope(GetObjectPtr(this)); }); } else if (name == "exit") { - return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { - MicroSession::ExitWithScope(); - }); + return PackedFunc( + [sptr_to_self](TVMArgs args, TVMRetValue* rv) { MicroSession::ExitWithScope(); }); // TODO(weberlo): add a `clear_batch_timer` func } else if (name == "get_last_batch_time") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLastBatchTime(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchTime(); }); // TODO(weberlo): remove this func } else if (name == "get_last_batch_cycles") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLastBatchCycles(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchCycles(); }); } else { return PackedFunc(); } } -TVM_REGISTER_GLOBAL("micro._GetMicroTimeEvaluator") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("micro._GetMicroTimeEvaluator").set_body([](TVMArgs args, TVMRetValue* rv) { PackedFunc pf = args[0]; TVMContext ctx = args[1]; uint64_t number = args[2]; uint64_t repeat = args[3]; - auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue *rv) mutable { + auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue* rv) mutable { TVMRetValue temp; std::ostringstream os; for (unsigned int i = 0; i < repeat; ++i) { // start timing CHECK(number < MicroSession::kTaskQueueCapacity) - << "`number` must be less than uTVM task queue capacity"; + << "`number` must be less than uTVM task queue capacity"; for (unsigned int j = 0; j < number; ++j) { pf.CallPacked(args, &temp); } @@ -678,61 +638,39 @@ TVM_REGISTER_GLOBAL("micro._GetMicroTimeEvaluator") *rv = PackedFunc(ftimer); }); - // create micro session and low-level device from Python frontend -TVM_REGISTER_GLOBAL("micro._CreateSession") -.set_body([](TVMArgs args, TVMRetValue* rv) { - const std::string& comms_method = args[0]; - const std::string& binary_path = args[1]; - const std::string& toolchain_prefix = args[2]; - uint64_t text_start = args[3]; - size_t text_size = uint64_t(args[4]); - uint64_t rodata_start = args[5]; - size_t rodata_size = uint64_t(args[6]); - uint64_t data_start = args[7]; - size_t data_size = uint64_t(args[8]); - uint64_t bss_start = args[9]; - size_t bss_size = uint64_t(args[10]); - uint64_t args_start = args[11]; - size_t args_size = uint64_t(args[12]); - uint64_t heap_start = args[13]; - size_t heap_size = uint64_t(args[14]); - uint64_t workspace_start = args[15]; - size_t workspace_size = uint64_t(args[16]); - uint64_t stack_start = args[17]; - size_t stack_size = uint64_t(args[18]); - TargetWordSize word_size{uint64_t(args[19])}; - bool thumb_mode = args[20]; - bool use_device_timer = args[21]; - const std::string& server_addr = args[22]; - int port = args[23]; - ObjectPtr session = make_object( - comms_method, - binary_path, - toolchain_prefix, - text_start, - text_size, - rodata_start, - rodata_size, - data_start, - data_size, - bss_start, - bss_size, - args_start, - args_size, - heap_start, - heap_size, - workspace_start, - workspace_size, - stack_start, - stack_size, - word_size, - thumb_mode, - use_device_timer, - server_addr, - port); - *rv = Module(session); - }); +TVM_REGISTER_GLOBAL("micro._CreateSession").set_body([](TVMArgs args, TVMRetValue* rv) { + const std::string& comms_method = args[0]; + const std::string& binary_path = args[1]; + const std::string& toolchain_prefix = args[2]; + uint64_t text_start = args[3]; + size_t text_size = uint64_t(args[4]); + uint64_t rodata_start = args[5]; + size_t rodata_size = uint64_t(args[6]); + uint64_t data_start = args[7]; + size_t data_size = uint64_t(args[8]); + uint64_t bss_start = args[9]; + size_t bss_size = uint64_t(args[10]); + uint64_t args_start = args[11]; + size_t args_size = uint64_t(args[12]); + uint64_t heap_start = args[13]; + size_t heap_size = uint64_t(args[14]); + uint64_t workspace_start = args[15]; + size_t workspace_size = uint64_t(args[16]); + uint64_t stack_start = args[17]; + size_t stack_size = uint64_t(args[18]); + TargetWordSize word_size{uint64_t(args[19])}; + bool thumb_mode = args[20]; + bool use_device_timer = args[21]; + const std::string& server_addr = args[22]; + int port = args[23]; + ObjectPtr session = make_object( + comms_method, binary_path, toolchain_prefix, text_start, text_size, rodata_start, rodata_size, + data_start, data_size, bss_start, bss_size, args_start, args_size, heap_start, heap_size, + workspace_start, workspace_size, stack_start, stack_size, word_size, thumb_mode, + use_device_timer, server_addr, port); + *rv = Module(session); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_session.h b/src/runtime/micro/micro_session.h index bf0996c..ab3afcc 100644 --- a/src/runtime/micro/micro_session.h +++ b/src/runtime/micro/micro_session.h @@ -34,19 +34,18 @@ #ifndef TVM_RUNTIME_MICRO_MICRO_SESSION_H_ #define TVM_RUNTIME_MICRO_MICRO_SESSION_H_ -#include "micro_common.h" -#include "micro_section_allocator.h" - -#include #include +#include #include #include +#include #include #include -#include #include "low_level_device.h" +#include "micro_common.h" +#include "micro_section_allocator.h" #include "target_data_layout_encoder.h" namespace tvm { @@ -65,8 +64,7 @@ class MicroSession : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); // todo having this decoupled from the value in utvm_runtime.c gives me stress dreams static const size_t kTaskQueueCapacity = 20; @@ -74,9 +72,7 @@ class MicroSession : public ModuleNode { /*! * \return The type key of the executor. */ - const char* type_key() const final { - return "MicroSession"; - } + const char* type_key() const final { return "MicroSession"; } /*! * \brief creates session by setting up a low-level device and initting allocators for it @@ -104,31 +100,14 @@ class MicroSession : public ModuleNode { * \param server_addr address of the OpenOCD server to connect to (if `comms_method == "openocd"`) * \param port port of the OpenOCD server to connect to (if `comms_method == "openocd"`) */ - MicroSession( - const std::string& comms_method, - const std::string& binary_path, - const std::string& toolchain_prefix, - uint64_t text_start, - size_t text_size, - uint64_t rodata_start, - size_t rodata_size, - uint64_t data_start, - size_t data_size, - uint64_t bss_start, - size_t bss_size, - uint64_t args_start, - size_t args_size, - uint64_t heap_start, - size_t heap_size, - uint64_t workspace_start, - size_t workspace_size, - uint64_t stack_start, - size_t stack_size, - TargetWordSize word_size, - bool thumb_mode, - bool use_device_timer, - const std::string& server_addr, - int port); + MicroSession(const std::string& comms_method, const std::string& binary_path, + const std::string& toolchain_prefix, uint64_t text_start, size_t text_size, + uint64_t rodata_start, size_t rodata_size, uint64_t data_start, size_t data_size, + uint64_t bss_start, size_t bss_size, uint64_t args_start, size_t args_size, + uint64_t heap_start, size_t heap_size, uint64_t workspace_start, + size_t workspace_size, uint64_t stack_start, size_t stack_size, + TargetWordSize word_size, bool thumb_mode, bool use_device_timer, + const std::string& server_addr, int port); /*! * \brief destructor @@ -188,29 +167,27 @@ class MicroSession : public ModuleNode { std::string ReadString(TargetPtr str_addr); /*! - * \brief read value of symbol from device memory - * \param symbol_map symbol map to read location of symbol from - * \param symbol name of symbol being read from - * \return value at symbol in memory - */ + * \brief read value of symbol from device memory + * \param symbol_map symbol map to read location of symbol from + * \param symbol name of symbol being read from + * \return value at symbol in memory + */ template T DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol); /*! * \brief write pointer value into device memory corresponding to symbol - * \param symbol_map symbol map to read location of symbol from - * \param symbol name of symbol being written to - * \param ptr pointer value to write into symbol + * \param symbol_map symbol map to read location of symbol from + * \param symbol name of symbol being written to + * \param ptr pointer value to write into symbol */ - void DevSymbolWrite(const SymbolMap& symbol_map, - const std::string& symbol, - const TargetPtr& ptr); + void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const TargetPtr& ptr); /*! - * \brief write value into device memory corresponding to symbol - * \param symbol_map symbol map to read location of symbol from - * \param symbol name of symbol being written to - * \param value value being written into symbol + * \brief write value into device memory corresponding to symbol + * \param symbol_map symbol map to read location of symbol from + * \param symbol name of symbol being written to + * \param value value being written into symbol */ template void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const T& value); @@ -307,15 +284,15 @@ class MicroSession : public ModuleNode { } /*! - * \brief Push a new session context onto the thread-local stack. - * The session on top of the stack is used as the current global session. - */ + * \brief Push a new session context onto the thread-local stack. + * The session on top of the stack is used as the current global session. + */ static void EnterWithScope(ObjectPtr session); /*! - * \brief Pop a session off the thread-local context stack, - * restoring the previous session as the current context. - */ + * \brief Pop a session off the thread-local context stack, + * restoring the previous session as the current context. + */ static void ExitWithScope(); }; @@ -336,24 +313,18 @@ struct MicroDevSpace { /*! \brief TVM array for serialization to 32-bit devices */ struct TVMArray32 { - TVMArray32( - TargetVal data, - DLContext ctx, - int32_t ndim, - DLDataType dtype, - TargetVal shape, - TargetVal strides, - TargetVal byte_offset) - : data(data.uint32()), - ctx(ctx), - ndim(ndim), - pad0(0), - dtype(dtype), - shape(shape.uint32()), - strides(strides.uint32()), - pad1(0), - byte_offset(byte_offset.uint32()), - pad2(0) { } + TVMArray32(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape, + TargetVal strides, TargetVal byte_offset) + : data(data.uint32()), + ctx(ctx), + ndim(ndim), + pad0(0), + dtype(dtype), + shape(shape.uint32()), + strides(strides.uint32()), + pad1(0), + byte_offset(byte_offset.uint32()), + pad2(0) {} /*! * \brief The opaque data pointer points to the allocated data. @@ -386,22 +357,16 @@ struct TVMArray32 { /*! \brief TVM array for serialization to 64-bit devices */ struct TVMArray64 { - TVMArray64( - TargetVal data, - DLContext ctx, - int32_t ndim, - DLDataType dtype, - TargetVal shape, - TargetVal strides, - TargetVal byte_offset) - : data(data.uint64()), - ctx(ctx), - ndim(ndim), - pad0(0), - dtype(dtype), - shape(shape.uint64()), - strides(strides.uint64()), - byte_offset(byte_offset.uint64()) { } + TVMArray64(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape, + TargetVal strides, TargetVal byte_offset) + : data(data.uint64()), + ctx(ctx), + ndim(ndim), + pad0(0), + dtype(dtype), + shape(shape.uint64()), + strides(strides.uint64()), + byte_offset(byte_offset.uint64()) {} /*! * \brief The opaque data pointer points to the allocated data. * This will be CUDA device pointer or cl_mem handle in OpenCL. @@ -442,10 +407,10 @@ struct DevTask { /*! \brief MicroTVM task for serialization to 32-bit devices */ typedef struct StructUTVMTask32 { StructUTVMTask32(DevTask task) - : func(task.func.uint32()), - arg_values(task.arg_values.uint32()), - arg_type_codes(task.arg_type_codes.uint32()), - num_args(task.num_args) { } + : func(task.func.uint32()), + arg_values(task.arg_values.uint32()), + arg_type_codes(task.arg_type_codes.uint32()), + num_args(task.num_args) {} /*! \brief Pointer to function to call for this task */ uint32_t func; @@ -460,10 +425,10 @@ typedef struct StructUTVMTask32 { /*! \brief MicroTVM task for serialization to 64-bit devices */ typedef struct StructUTVMTask64 { StructUTVMTask64(DevTask task) - : func(task.func.uint64()), - arg_values(task.arg_values.uint64()), - arg_type_codes(task.arg_type_codes.uint64()), - num_args(task.num_args) { } + : func(task.func.uint64()), + arg_values(task.arg_values.uint64()), + arg_type_codes(task.arg_type_codes.uint64()), + num_args(task.num_args) {} /*! \brief Pointer to function to call for this task */ uint64_t func; diff --git a/src/runtime/micro/openocd_low_level_device.cc b/src/runtime/micro/openocd_low_level_device.cc index 91df8b9..610ca85 100644 --- a/src/runtime/micro/openocd_low_level_device.cc +++ b/src/runtime/micro/openocd_low_level_device.cc @@ -23,8 +23,8 @@ #include #include -#include "micro_common.h" #include "low_level_device.h" +#include "micro_common.h" #include "tcl_socket.h" namespace tvm { @@ -40,8 +40,7 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { * \param server_addr address of the OpenOCD server to connect to * \param port port of the OpenOCD server to connect to */ - explicit OpenOCDLowLevelDevice(const std::string& server_addr, - int port) : socket_() { + explicit OpenOCDLowLevelDevice(const std::string& server_addr, int port) : socket_() { server_addr_ = server_addr; port_ = port; @@ -80,13 +79,12 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { socket_.cmd_builder() << "array unset output"; socket_.SendCommand(); - socket_.cmd_builder() - << "mem2array output" - << " " << std::dec << kWordSize - << " " << addr.cast_to() - // Round up any request sizes under a byte, since OpenOCD doesn't support - // sub-byte-sized transfers. - << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes); + socket_.cmd_builder() << "mem2array output" + << " " << std::dec << kWordSize << " " + << addr.cast_to() + // Round up any request sizes under a byte, since OpenOCD doesn't + // support sub-byte-sized transfers. + << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes); socket_.SendCommand(); } @@ -104,9 +102,8 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { // The response from this command pairs indices with the contents of the // memory at that index. values >> index; - CHECK(index < num_bytes) - << "index " << index << - " out of bounds (length " << num_bytes << ")"; + CHECK(index < num_bytes) << "index " << index << " out of bounds (length " << num_bytes + << ")"; // Read the value into `curr_val`, instead of reading directly into // `buf_iter`, because otherwise it's interpreted as the ASCII value and // not the integral value. @@ -165,11 +162,9 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { socket_.SendCommand(); } { - socket_.cmd_builder() - << "array2mem input" - << " " << std::dec << kWordSize - << " " << addr.cast_to() - << " " << std::dec << num_bytes; + socket_.cmd_builder() << "array2mem input" + << " " << std::dec << kWordSize << " " << addr.cast_to() << " " + << std::dec << num_bytes; socket_.SendCommand(); } } @@ -196,9 +191,7 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { socket_.SendCommand(); } - const char* device_type() const final { - return "openocd"; - } + const char* device_type() const final { return "openocd"; } private: /*! \brief socket used to communicate with the device through Tcl */ @@ -220,8 +213,7 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { const std::shared_ptr OpenOCDLowLevelDeviceCreate(const std::string& server_addr, int port) { - std::shared_ptr lld = - std::make_shared(server_addr, port); + std::shared_ptr lld = std::make_shared(server_addr, port); return lld; } diff --git a/src/runtime/micro/standalone/minimal_vector.h b/src/runtime/micro/standalone/minimal_vector.h index 4d04e52..74bea06 100644 --- a/src/runtime/micro/standalone/minimal_vector.h +++ b/src/runtime/micro/standalone/minimal_vector.h @@ -27,7 +27,6 @@ namespace tvm { namespace micro { - // A minimal wrapper, derived from https://github.com/Robbepop/dynarray/, that // supports a minimal subset of the std::vector API with a minimized code size. template diff --git a/src/runtime/micro/standalone/utvm_graph_runtime.cc b/src/runtime/micro/standalone/utvm_graph_runtime.cc index 546ed7d..db55634 100644 --- a/src/runtime/micro/standalone/utvm_graph_runtime.cc +++ b/src/runtime/micro/standalone/utvm_graph_runtime.cc @@ -20,8 +20,10 @@ #include "utvm_graph_runtime.h" #include + #include #include + #include "picojson.h" namespace tvm { diff --git a/src/runtime/micro/standalone/utvm_runtime.cc b/src/runtime/micro/standalone/utvm_runtime.cc index 4184438..73d616b 100644 --- a/src/runtime/micro/standalone/utvm_runtime.cc +++ b/src/runtime/micro/standalone/utvm_runtime.cc @@ -16,15 +16,15 @@ * specific language governing permissions and limitations * under the License. */ +#include "tvm/runtime/micro/standalone/utvm_runtime.h" + #include -#include "tvm/runtime/micro/standalone/utvm_runtime.h" #include "utvm_graph_runtime.h" void* UTVMRuntimeCreate(const char* json, size_t json_len, void* module) { - return new tvm::micro::MicroGraphRuntime( - std::string(json, json + json_len), - reinterpret_cast(module)); + return new tvm::micro::MicroGraphRuntime(std::string(json, json + json_len), + reinterpret_cast(module)); } void UTVMRuntimeDestroy(void* handle) { diff --git a/src/runtime/micro/standalone/utvm_runtime_api.cc b/src/runtime/micro/standalone/utvm_runtime_api.cc index 896ff57..a6ac420 100644 --- a/src/runtime/micro/standalone/utvm_runtime_api.cc +++ b/src/runtime/micro/standalone/utvm_runtime_api.cc @@ -20,6 +20,7 @@ #include "utvm_runtime_api.h" #include + #include #include diff --git a/src/runtime/micro/standalone/utvm_runtime_api.h b/src/runtime/micro/standalone/utvm_runtime_api.h index 1b87052..b38aa0a 100644 --- a/src/runtime/micro/standalone/utvm_runtime_api.h +++ b/src/runtime/micro/standalone/utvm_runtime_api.h @@ -21,6 +21,7 @@ #include #include + #include // The subset of the TVM runtime API that is implemented by the minimal runtime API. diff --git a/src/runtime/micro/target_data_layout_encoder.h b/src/runtime/micro/target_data_layout_encoder.h index c99d796..9778177 100644 --- a/src/runtime/micro/target_data_layout_encoder.h +++ b/src/runtime/micro/target_data_layout_encoder.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_ #include + #include "host_driven/utvm_runtime.h" namespace tvm { @@ -97,10 +98,11 @@ class TargetDataLayoutEncoder { * \param start_addr start address of the encoder in device memory */ explicit TargetDataLayoutEncoder(size_t capacity, TargetWordSize word_size) - : buf_(std::vector()), curr_offset_(0), + : buf_(std::vector()), + curr_offset_(0), start_addr_(word_size, nullptr), - capacity_(capacity), word_size_(word_size) { - } + capacity_(capacity), + word_size_(word_size) {} /*! * \brief allocates a slot for `sizeof(T) * num_elems` bytes of data @@ -129,17 +131,13 @@ class TargetDataLayoutEncoder { * \brief returns the array backing the encoder's buffer * \return array backing the encoder's buffer */ - uint8_t* data() { - return buf_.data(); - } + uint8_t* data() { return buf_.data(); } /*! * \brief returns current size of the encoder's buffer * \return buffer size */ - size_t buf_size() const { - return buf_.size(); - } + size_t buf_size() const { return buf_.size(); } TargetPtr start_addr() const { CHECK_NE(start_addr_.value().uint64(), 0) << "start addr uninitialized"; @@ -148,8 +146,8 @@ class TargetDataLayoutEncoder { void set_start_addr(TargetPtr start_addr) { CHECK_EQ(buf_.size(), 0) << "cannot change encoder start addr unless empty"; - start_addr_ = TargetPtr(word_size_, - UpperAlignValue(start_addr.value().uint64(), word_size_.bytes())); + start_addr_ = + TargetPtr(word_size_, UpperAlignValue(start_addr.value().uint64(), word_size_.bytes())); } private: @@ -166,10 +164,8 @@ class TargetDataLayoutEncoder { }; template -TargetDataLayoutEncoder::Slot::Slot(TargetDataLayoutEncoder* parent, - size_t start_offset, - size_t size, - TargetPtr start_addr) +TargetDataLayoutEncoder::Slot::Slot(TargetDataLayoutEncoder* parent, size_t start_offset, + size_t size, TargetPtr start_addr) : parent_(parent), start_offset_(start_offset), curr_offset_(0), @@ -180,8 +176,8 @@ template TargetDataLayoutEncoder::Slot::~Slot() { // TODO(weberlo, areusch): this can mask the exception thrown by slot allocation... even though // that doesn't make sense. - CHECK(curr_offset_ == size_) << "unwritten space in slot; curr_offset=" - << curr_offset_ << ", size=" << size_; + CHECK(curr_offset_ == size_) << "unwritten space in slot; curr_offset=" << curr_offset_ + << ", size=" << size_; } template diff --git a/src/runtime/micro/tcl_socket.cc b/src/runtime/micro/tcl_socket.cc index 24abe42..8f482b8 100644 --- a/src/runtime/micro/tcl_socket.cc +++ b/src/runtime/micro/tcl_socket.cc @@ -20,10 +20,10 @@ /*! * \file tcl_socket.cc */ -#include - #include "tcl_socket.h" +#include + namespace tvm { namespace runtime { @@ -33,9 +33,7 @@ TclSocket::TclSocket() { reply_buf_.reserve(kReplyBufSize); } -TclSocket::~TclSocket() { - tcp_socket_.Close(); -} +TclSocket::~TclSocket() { tcp_socket_.Close(); } void TclSocket::Connect(tvm::support::SockAddr addr) { CHECK(tcp_socket_.Connect(addr)) << "failed to connect"; @@ -46,8 +44,7 @@ void TclSocket::SendCommand() { cmd_builder_ << terminate_token; std::string full_cmd = cmd_builder_.str(); - CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1) - << "failed to send command"; + CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1) << "failed to send command"; cmd_builder_.str(std::string()); reply_builder_.str(std::string()); @@ -67,8 +64,7 @@ void TclSocket::SendCommand() { CHECK(bytes_read != -1) << "failed to read command reply"; } while (last_read != terminate_token); last_reply_ = reply_builder_.str(); - CHECK_EQ(last_reply_[last_reply_.length()-1], terminate_token) - << "missing command terminator"; + CHECK_EQ(last_reply_[last_reply_.length() - 1], terminate_token) << "missing command terminator"; } } // namespace runtime diff --git a/src/runtime/micro/tcl_socket.h b/src/runtime/micro/tcl_socket.h index 0b23e7f..4aef2ae 100644 --- a/src/runtime/micro/tcl_socket.h +++ b/src/runtime/micro/tcl_socket.h @@ -66,12 +66,12 @@ class TclSocket { /* * \return string stream for current command being built - */ + */ std::ostringstream& cmd_builder() { return cmd_builder_; } /* * \return reply from most recently sent command - */ + */ const std::string& last_reply() { return last_reply_; } private: diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 813a79d..19f1f39 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -22,10 +22,12 @@ * \brief TVM module system */ #include -#include #include -#include +#include + #include +#include + #include "file_util.h" namespace tvm { @@ -55,8 +57,7 @@ void ModuleNode::Import(Module other) { stack.push_back(next); } } - CHECK(!visited.count(this)) - << "Cyclic dependency detected during import"; + CHECK(!visited.count(this)) << "Cyclic dependency detected during import"; this->imports_.emplace_back(std::move(other)); } @@ -73,25 +74,20 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) return pf; } -Module Module::LoadFromFile(const std::string& file_name, - const std::string& format) { +Module Module::LoadFromFile(const std::string& file_name, const std::string& format) { std::string fmt = GetFileFormat(file_name, format); - CHECK(fmt.length() != 0) - << "Cannot deduce format of file " << file_name; + CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name; if (fmt == "dll" || fmt == "dylib" || fmt == "dso") { fmt = "so"; } std::string load_f_name = "runtime.module.loadfile_" + fmt; const PackedFunc* f = Registry::Get(load_f_name); - CHECK(f != nullptr) - << "Loader of " << format << "(" - << load_f_name << ") is not presented."; + CHECK(f != nullptr) << "Loader of " << format << "(" << load_f_name << ") is not presented."; Module m = (*f)(file_name, format); return m; } -void ModuleNode::SaveToFile(const std::string& file_name, - const std::string& format) { +void ModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile"; } @@ -114,9 +110,8 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { } if (pf == nullptr) { const PackedFunc* f = Registry::Get(name); - CHECK(f != nullptr) - << "Cannot find function " << name - << " in the imported modules or global registry"; + CHECK(f != nullptr) << "Cannot find function " << name + << " in the imported modules or global registry"; return f; } else { import_cache_.insert(std::make_pair(name, std::make_shared(pf))); @@ -158,36 +153,30 @@ bool RuntimeEnabled(const std::string& target) { return runtime::Registry::Get(f_name) != nullptr; } -TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled") -.set_body_typed(RuntimeEnabled); +TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled); -TVM_REGISTER_GLOBAL("runtime.ModuleGetSource") -.set_body_typed([](Module mod, std::string fmt) { +TVM_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) { return mod->GetSource(fmt); }); -TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize") -.set_body_typed([](Module mod) { +TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) { return static_cast(mod->imports().size()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetImport") -.set_body_typed([](Module mod, int index) { +TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) { return mod->imports().at(index); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey") -.set_body_typed([](Module mod) { +TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { return std::string(mod->type_key()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") -.set_body_typed(Module::LoadFromFile); +TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") -.set_body_typed([](Module mod, std::string name, std::string fmt) { - mod->SaveToFile(name, fmt); -}); + .set_body_typed([](Module mod, std::string name, std::string fmt) { + mod->SaveToFile(name, fmt); + }); TVM_REGISTER_OBJECT_TYPE(ModuleNode); } // namespace runtime diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index ac12472..d97d01b 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -22,9 +22,10 @@ * \brief NDArray container infratructure. */ #include -#include #include #include +#include + #include "runtime_base.h" extern "C" { @@ -45,9 +46,12 @@ inline void VerifyDataType(DLDataType dtype) { // allow uint1 as a special flag for bool. if (dtype.bits == 1 && dtype.code == kDLUInt) return; // allow int1/uint4/int4 - else if (dtype.bits == 1 && dtype.code == kDLInt) return; - else if (dtype.bits == 4 && dtype.code == kDLUInt) return; - else if (dtype.bits == 4 && dtype.code == kDLInt) return; + else if (dtype.bits == 1 && dtype.code == kDLInt) + return; + else if (dtype.bits == 4 && dtype.code == kDLUInt) + return; + else if (dtype.bits == 4 && dtype.code == kDLInt) + return; else CHECK_EQ(dtype.bits % 8, 0); } @@ -65,12 +69,10 @@ void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; size_t arr_size = GetDataSize(*handle); - CHECK_EQ(arr_size, nbytes) - << "ArrayCopyFromBytes: size mismatch"; - DeviceAPI::Get(handle->ctx)->CopyDataFromTo( - data, 0, - handle->data, static_cast(handle->byte_offset), - nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr); + CHECK_EQ(arr_size, nbytes) << "ArrayCopyFromBytes: size mismatch"; + DeviceAPI::Get(handle->ctx) + ->CopyDataFromTo(data, 0, handle->data, static_cast(handle->byte_offset), nbytes, + cpu_ctx, handle->ctx, handle->dtype, nullptr); } void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { @@ -78,12 +80,10 @@ void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; size_t arr_size = GetDataSize(*handle); - CHECK_EQ(arr_size, nbytes) - << "ArrayCopyToBytes: size mismatch"; - DeviceAPI::Get(handle->ctx)->CopyDataFromTo( - handle->data, static_cast(handle->byte_offset), - data, 0, - nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr); + CHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; + DeviceAPI::Get(handle->ctx) + ->CopyDataFromTo(handle->data, static_cast(handle->byte_offset), data, 0, nbytes, + handle->ctx, cpu_ctx, handle->dtype, nullptr); } struct NDArray::Internal { @@ -93,8 +93,8 @@ struct NDArray::Internal { if (ptr->manager_ctx != nullptr) { static_cast(ptr->manager_ctx)->DecRef(); } else if (ptr->dl_tensor.data != nullptr) { - tvm::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace( - ptr->dl_tensor.ctx, ptr->dl_tensor.data); + tvm::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx) + ->FreeDataSpace(ptr->dl_tensor.ctx, ptr->dl_tensor.data); } delete ptr; } @@ -113,9 +113,7 @@ struct NDArray::Internal { } // Local create function which allocates tensor metadata // but does not allocate space for the data. - static NDArray Create(std::vector shape, - DLDataType dtype, - DLContext ctx) { + static NDArray Create(std::vector shape, DLDataType dtype, DLContext ctx) { VerifyDataType(dtype); // critical zone: construct header @@ -140,13 +138,11 @@ struct NDArray::Internal { ObjectRef::FFIClearAfterMove(&arr); return handle; } - static void FFIDecRef(TVMArrayHandle tensor) { - NDArray::FFIDecRef(tensor); - } + static void FFIDecRef(TVMArrayHandle tensor) { NDArray::FFIDecRef(tensor); } // Container to DLManagedTensor static DLManagedTensor* ToDLPack(TVMArrayHandle handle) { - auto* from = static_cast( - reinterpret_cast(handle)); + auto* from = + static_cast(reinterpret_cast(handle)); return ToDLPack(from); } @@ -168,11 +164,9 @@ struct NDArray::Internal { NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { CHECK(data_ != nullptr); - CHECK(get_mutable()->dl_tensor.strides == nullptr) - << "Can only create view for compact tensor"; + CHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor"; NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx); - ret.get_mutable()->dl_tensor.byte_offset = - this->get_mutable()->dl_tensor.byte_offset; + ret.get_mutable()->dl_tensor.byte_offset = this->get_mutable()->dl_tensor.byte_offset; size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor); size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor); CHECK_LE(view_size, curr_size) @@ -184,20 +178,15 @@ NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { return ret; } -DLManagedTensor* NDArray::ToDLPack() const { - return Internal::ToDLPack(get_mutable()); -} +DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); } -NDArray NDArray::Empty(std::vector shape, - DLDataType dtype, - DLContext ctx) { +NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.get_mutable()->dl_tensor); size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor); ret.get_mutable()->dl_tensor.data = - DeviceAPI::Get(ret->ctx)->AllocDataSpace( - ret->ctx, size, alignment, ret->dtype); + DeviceAPI::Get(ret->ctx)->AllocDataSpace(ret->ctx, size, alignment, ret->dtype); return ret; } @@ -227,34 +216,26 @@ void NDArray::CopyFromBytes(const void* data, size_t nbytes) { ArrayCopyFromBytes(&get_mutable()->dl_tensor, data, nbytes); } -void NDArray::CopyFromTo(const DLTensor* from, - DLTensor* to, - TVMStreamHandle stream) { +void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) { size_t from_size = GetDataSize(*from); size_t to_size = GetDataSize(*to); - CHECK_EQ(from_size, to_size) - << "TVMArrayCopyFromTo: The size must exactly match"; + CHECK_EQ(from_size, to_size) << "TVMArrayCopyFromTo: The size must exactly match"; - CHECK(from->ctx.device_type == to->ctx.device_type - || from->ctx.device_type == kDLCPU - || to->ctx.device_type == kDLCPU - || from->ctx.device_type == kDLCPUPinned - || to->ctx.device_type == kDLCPUPinned) - << "Can not copy across different ctx types directly"; + CHECK(from->ctx.device_type == to->ctx.device_type || from->ctx.device_type == kDLCPU || + to->ctx.device_type == kDLCPU || from->ctx.device_type == kDLCPUPinned || + to->ctx.device_type == kDLCPUPinned) + << "Can not copy across different ctx types directly"; // Use the context that is *not* a cpu context to get the correct device // api manager. TVMContext ctx = from->ctx.device_type != kDLCPU ? from->ctx : to->ctx; - DeviceAPI::Get(ctx)->CopyDataFromTo( - from->data, static_cast(from->byte_offset), - to->data, static_cast(to->byte_offset), - from_size, from->ctx, to->ctx, from->dtype, stream); + DeviceAPI::Get(ctx)->CopyDataFromTo(from->data, static_cast(from->byte_offset), to->data, + static_cast(to->byte_offset), from_size, from->ctx, + to->ctx, from->dtype, stream); } -std::vector NDArray::Shape() const { - return get_mutable()->shape_; -} +std::vector NDArray::Shape() const { return get_mutable()->shape_; } TVM_REGISTER_OBJECT_TYPE(NDArray::Container); @@ -273,14 +254,8 @@ int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { API_END(); } -int TVMArrayAlloc(const tvm_index_t* shape, - int ndim, - int dtype_code, - int dtype_bits, - int dtype_lanes, - int device_type, - int device_id, - TVMArrayHandle* out) { +int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, + int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { API_BEGIN(); DLDataType dtype; dtype.code = static_cast(dtype_code); @@ -300,43 +275,33 @@ int TVMArrayFree(TVMArrayHandle handle) { API_END(); } -int TVMArrayCopyFromTo(TVMArrayHandle from, - TVMArrayHandle to, - TVMStreamHandle stream) { +int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream) { API_BEGIN(); NDArray::CopyFromTo(from, to, stream); API_END(); } -int TVMArrayFromDLPack(DLManagedTensor* from, - TVMArrayHandle* out) { +int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out) { API_BEGIN(); *out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from)); API_END(); } -int TVMArrayToDLPack(TVMArrayHandle from, - DLManagedTensor** out) { +int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out) { API_BEGIN(); *out = NDArray::Internal::ToDLPack(from); API_END(); } -void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) { - (*(dltensor->deleter))(dltensor); -} +void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) { (*(dltensor->deleter))(dltensor); } -int TVMArrayCopyFromBytes(TVMArrayHandle handle, - void* data, - size_t nbytes) { +int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) { API_BEGIN(); ArrayCopyFromBytes(handle, data, nbytes); API_END(); } -int TVMArrayCopyToBytes(TVMArrayHandle handle, - void* data, - size_t nbytes) { +int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes) { API_BEGIN(); ArrayCopyToBytes(handle, data, nbytes); API_END(); diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 0301200..c8e6671 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -21,13 +21,15 @@ * \brief Object type management system. */ #include -#include #include +#include + #include #include -#include -#include #include +#include +#include + #include "object_internal.h" #include "runtime_base.h" @@ -75,10 +77,8 @@ class TypeContext { return child_tindex == parent_tindex; } - uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, - uint32_t static_tindex, - uint32_t parent_tindex, - uint32_t num_child_slots, + uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, uint32_t static_tindex, + uint32_t parent_tindex, uint32_t num_child_slots, bool child_slots_can_overflow) { std::lock_guard lock(mutex_); auto it = type_key2index_.find(skey); @@ -105,10 +105,8 @@ class TypeContext { allocated_tindex = static_tindex; CHECK_LT(static_tindex, type_table_.size()); CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U) - << "Conflicting static index " << static_tindex - << " between " << type_table_[allocated_tindex].name - << " and " - << skey; + << "Conflicting static index " << static_tindex << " between " + << type_table_[allocated_tindex].name << " and " << skey; } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) { // allocate the slot from parent's reserved pool allocated_tindex = parent_tindex + pinfo.allocated_slots; @@ -129,8 +127,7 @@ class TypeContext { type_table_[allocated_tindex].parent_index = parent_tindex; type_table_[allocated_tindex].num_slots = num_slots; type_table_[allocated_tindex].allocated_slots = 1; - type_table_[allocated_tindex].child_slots_can_overflow = - child_slots_can_overflow; + type_table_[allocated_tindex].child_slots_can_overflow = child_slots_can_overflow; type_table_[allocated_tindex].name = skey; type_table_[allocated_tindex].name_hash = std::hash()(skey); // update the key2index mapping. @@ -140,16 +137,14 @@ class TypeContext { std::string TypeIndex2Key(uint32_t tindex) { std::lock_guard lock(mutex_); - CHECK(tindex < type_table_.size() && - type_table_[tindex].allocated_slots != 0) + CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) << "Unknown type index " << tindex; return type_table_[tindex].name; } size_t TypeIndex2KeyHash(uint32_t tindex) { std::lock_guard lock(mutex_); - CHECK(tindex < type_table_.size() && - type_table_[tindex].allocated_slots != 0) + CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) << "Unknown type index " << tindex; return type_table_[tindex].name_hash; } @@ -173,7 +168,7 @@ class TypeContext { for (const auto& info : type_table_) { if (info.index != 0 && num_children[info.index] >= min_children_count) { - std::cerr <<'[' << info.index << "] "<< info.name + std::cerr << '[' << info.index << "] " << info.name << "\tparent=" << type_table_[info.parent_index].name << "\tnum_child_slots=" << info.num_slots - 1 << "\tnum_children=" << num_children[info.index] << std::endl; @@ -198,18 +193,15 @@ class TypeContext { std::unordered_map type_key2index_; }; -uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, - uint32_t static_tindex, - uint32_t parent_tindex, - uint32_t num_child_slots, +uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, + uint32_t parent_tindex, uint32_t num_child_slots, bool child_slots_can_overflow) { return TypeContext::Global()->GetOrAllocRuntimeTypeIndex( key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow); } bool Object::DerivedFrom(uint32_t parent_tindex) const { - return TypeContext::Global()->DerivedFrom( - this->type_index_, parent_tindex); + return TypeContext::Global()->DerivedFrom(this->type_index_, parent_tindex); } std::string Object::TypeIndex2Key(uint32_t tindex) { @@ -224,14 +216,11 @@ uint32_t Object::TypeKey2Index(const std::string& key) { return TypeContext::Global()->TypeKey2Index(key); } - -TVM_REGISTER_GLOBAL("runtime.ObjectHash") -.set_body_typed([](ObjectRef obj) { +TVM_REGISTER_GLOBAL("runtime.ObjectHash").set_body_typed([](ObjectRef obj) { return static_cast(ObjectHash()(obj)); }); -TVM_REGISTER_GLOBAL("runtime.DumpTypeTable") -.set_body_typed([](int min_child_count) { +TVM_REGISTER_GLOBAL("runtime.DumpTypeTable").set_body_typed([](int min_child_count) { TypeContext::Global()->Dump(min_child_count); }); } // namespace runtime @@ -252,7 +241,6 @@ int TVMObjectFree(TVMObjectHandle obj) { int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { API_BEGIN(); - out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index( - type_key); + out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); API_END(); } diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index 7955130..d56046c 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -24,8 +24,9 @@ #ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_ #define TVM_RUNTIME_OBJECT_INTERNAL_H_ -#include #include +#include + #include namespace tvm { @@ -68,4 +69,4 @@ class ObjectInternal { } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_ +#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_ diff --git a/src/runtime/opencl/aocl/aocl_common.h b/src/runtime/opencl/aocl/aocl_common.h index d9251f8..1b98d4b 100644 --- a/src/runtime/opencl/aocl/aocl_common.h +++ b/src/runtime/opencl/aocl/aocl_common.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_OPENCL_AOCL_AOCL_COMMON_H_ #include + #include "../opencl_common.h" namespace tvm { @@ -44,7 +45,6 @@ class AOCLWorkspace final : public OpenCLWorkspace { static const std::shared_ptr& Global(); }; - /*! \brief Thread local workspace for AOCL */ class AOCLThreadEntry : public OpenCLThreadEntry { public: diff --git a/src/runtime/opencl/aocl/aocl_device_api.cc b/src/runtime/opencl/aocl/aocl_device_api.cc index 84c29ee..07057ff 100644 --- a/src/runtime/opencl/aocl/aocl_device_api.cc +++ b/src/runtime/opencl/aocl/aocl_device_api.cc @@ -20,17 +20,16 @@ /*! * \file aocl_device_api.cc */ -#include #include +#include + #include "aocl_common.h" namespace tvm { namespace runtime { namespace cl { -OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { - return AOCLThreadEntry::ThreadLocal(); -} +OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { return AOCLThreadEntry::ThreadLocal(); } const std::shared_ptr& AOCLWorkspace::Global() { static std::shared_ptr inst = std::make_shared(); @@ -47,15 +46,12 @@ bool AOCLWorkspace::IsOpenCLDevice(TVMContext ctx) { typedef dmlc::ThreadLocalStore AOCLThreadStore; -AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { - return AOCLThreadStore::Get(); -} +AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { return AOCLThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.aocl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = AOCLWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.aocl").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = AOCLWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace cl } // namespace runtime diff --git a/src/runtime/opencl/aocl/aocl_module.cc b/src/runtime/opencl/aocl/aocl_module.cc index abda5b1..747188c 100644 --- a/src/runtime/opencl/aocl/aocl_module.cc +++ b/src/runtime/opencl/aocl/aocl_module.cc @@ -20,23 +20,24 @@ /*! * \file aocl_module.cc */ +#include "aocl_module.h" + #include #include -#include + #include #include +#include + #include "aocl_common.h" -#include "aocl_module.h" namespace tvm { namespace runtime { class AOCLModuleNode : public OpenCLModuleNode { public: - explicit AOCLModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) + explicit AOCLModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) : OpenCLModuleNode(data, fmt, fmap, source) {} const std::shared_ptr& GetGlobalWorkspace() final; }; @@ -45,18 +46,14 @@ const std::shared_ptr& AOCLModuleNode::GetGlobalWorkspace() return cl::AOCLWorkspace::Global(); } -Module AOCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module AOCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } -Module AOCLModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module AOCLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -66,8 +63,7 @@ Module AOCLModuleLoadFile(const std::string& file_name, return AOCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_aocx") -.set_body_typed(AOCLModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_aocx").set_body_typed(AOCLModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/aocl/aocl_module.h b/src/runtime/opencl/aocl/aocl_module.h index 70955cc..199a94d 100644 --- a/src/runtime/opencl/aocl/aocl_module.h +++ b/src/runtime/opencl/aocl/aocl_module.h @@ -25,10 +25,12 @@ #define TVM_RUNTIME_OPENCL_AOCL_AOCL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../../meta_data.h" namespace tvm { @@ -40,11 +42,8 @@ namespace runtime { * \param fmt The format of the data, can be "aocx" * \param fmap The map function information map of each function. */ -Module AOCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module AOCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_AOCL_AOCL_MODULE_H_ diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 8f9d5d6..a892bff 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -24,10 +24,10 @@ #ifndef TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ #define TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ +#include #include -#include #include -#include +#include /* There are many OpenCL platforms that do not yet support OpenCL 2.0, * hence we use 1.2 APIs, some of which are now deprecated. In order @@ -45,73 +45,120 @@ #include #endif +#include #include #include -#include -#include #include -#include "../workspace_pool.h" +#include + +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "../workspace_pool.h" namespace tvm { namespace runtime { namespace cl { -static_assert(sizeof(cl_mem) ==sizeof(void*), - "Required to store cl_mem inside void*"); +static_assert(sizeof(cl_mem) == sizeof(void*), "Required to store cl_mem inside void*"); inline const char* CLGetErrorString(cl_int error) { switch (error) { - case CL_SUCCESS: return "CL_SUCCESS"; - case CL_DEVICE_NOT_FOUND: return "CL_DEVICE_NOT_FOUND"; - case CL_DEVICE_NOT_AVAILABLE: return "CL_DEVICE_NOT_AVAILABLE"; - case CL_COMPILER_NOT_AVAILABLE: return "CL_COMPILER_NOT_AVAILABLE"; - case CL_MEM_OBJECT_ALLOCATION_FAILURE: return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; - case CL_OUT_OF_RESOURCES: return "CL_OUT_OF_RESOURCES"; - case CL_OUT_OF_HOST_MEMORY: return "CL_OUT_OF_HOST_MEMORY"; - case CL_PROFILING_INFO_NOT_AVAILABLE: return "CL_PROFILING_INFO_NOT_AVAILABLE"; - case CL_MEM_COPY_OVERLAP: return "CL_MEM_COPY_OVERLAP"; - case CL_IMAGE_FORMAT_MISMATCH: return "CL_IMAGE_FORMAT_MISMATCH"; - case CL_IMAGE_FORMAT_NOT_SUPPORTED: return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; - case CL_BUILD_PROGRAM_FAILURE: return "CL_BUILD_PROGRAM_FAILURE"; - case CL_MAP_FAILURE: return "CL_MAP_FAILURE"; - case CL_INVALID_VALUE: return "CL_INVALID_VALUE"; - case CL_INVALID_DEVICE_TYPE: return "CL_INVALID_DEVICE_TYPE"; - case CL_INVALID_PLATFORM: return "CL_INVALID_PLATFORM"; - case CL_INVALID_DEVICE: return "CL_INVALID_DEVICE"; - case CL_INVALID_CONTEXT: return "CL_INVALID_CONTEXT"; - case CL_INVALID_QUEUE_PROPERTIES: return "CL_INVALID_QUEUE_PROPERTIES"; - case CL_INVALID_COMMAND_QUEUE: return "CL_INVALID_COMMAND_QUEUE"; - case CL_INVALID_HOST_PTR: return "CL_INVALID_HOST_PTR"; - case CL_INVALID_MEM_OBJECT: return "CL_INVALID_MEM_OBJECT"; - case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; - case CL_INVALID_IMAGE_SIZE: return "CL_INVALID_IMAGE_SIZE"; - case CL_INVALID_SAMPLER: return "CL_INVALID_SAMPLER"; - case CL_INVALID_BINARY: return "CL_INVALID_BINARY"; - case CL_INVALID_BUILD_OPTIONS: return "CL_INVALID_BUILD_OPTIONS"; - case CL_INVALID_PROGRAM: return "CL_INVALID_PROGRAM"; - case CL_INVALID_PROGRAM_EXECUTABLE: return "CL_INVALID_PROGRAM_EXECUTABLE"; - case CL_INVALID_KERNEL_NAME: return "CL_INVALID_KERNEL_NAME"; - case CL_INVALID_KERNEL_DEFINITION: return "CL_INVALID_KERNEL_DEFINITION"; - case CL_INVALID_KERNEL: return "CL_INVALID_KERNEL"; - case CL_INVALID_ARG_INDEX: return "CL_INVALID_ARG_INDEX"; - case CL_INVALID_ARG_VALUE: return "CL_INVALID_ARG_VALUE"; - case CL_INVALID_ARG_SIZE: return "CL_INVALID_ARG_SIZE"; - case CL_INVALID_KERNEL_ARGS: return "CL_INVALID_KERNEL_ARGS"; - case CL_INVALID_WORK_DIMENSION: return "CL_INVALID_WORK_DIMENSION"; - case CL_INVALID_WORK_GROUP_SIZE: return "CL_INVALID_WORK_GROUP_SIZE"; - case CL_INVALID_WORK_ITEM_SIZE: return "CL_INVALID_WORK_ITEM_SIZE"; - case CL_INVALID_GLOBAL_OFFSET: return "CL_INVALID_GLOBAL_OFFSET"; - case CL_INVALID_EVENT_WAIT_LIST: return "CL_INVALID_EVENT_WAIT_LIST"; - case CL_INVALID_EVENT: return "CL_INVALID_EVENT"; - case CL_INVALID_OPERATION: return "CL_INVALID_OPERATION"; - case CL_INVALID_GL_OBJECT: return "CL_INVALID_GL_OBJECT"; - case CL_INVALID_BUFFER_SIZE: return "CL_INVALID_BUFFER_SIZE"; - case CL_INVALID_MIP_LEVEL: return "CL_INVALID_MIP_LEVEL"; - default: return "Unknown OpenCL error code"; + case CL_SUCCESS: + return "CL_SUCCESS"; + case CL_DEVICE_NOT_FOUND: + return "CL_DEVICE_NOT_FOUND"; + case CL_DEVICE_NOT_AVAILABLE: + return "CL_DEVICE_NOT_AVAILABLE"; + case CL_COMPILER_NOT_AVAILABLE: + return "CL_COMPILER_NOT_AVAILABLE"; + case CL_MEM_OBJECT_ALLOCATION_FAILURE: + return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; + case CL_OUT_OF_RESOURCES: + return "CL_OUT_OF_RESOURCES"; + case CL_OUT_OF_HOST_MEMORY: + return "CL_OUT_OF_HOST_MEMORY"; + case CL_PROFILING_INFO_NOT_AVAILABLE: + return "CL_PROFILING_INFO_NOT_AVAILABLE"; + case CL_MEM_COPY_OVERLAP: + return "CL_MEM_COPY_OVERLAP"; + case CL_IMAGE_FORMAT_MISMATCH: + return "CL_IMAGE_FORMAT_MISMATCH"; + case CL_IMAGE_FORMAT_NOT_SUPPORTED: + return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; + case CL_BUILD_PROGRAM_FAILURE: + return "CL_BUILD_PROGRAM_FAILURE"; + case CL_MAP_FAILURE: + return "CL_MAP_FAILURE"; + case CL_INVALID_VALUE: + return "CL_INVALID_VALUE"; + case CL_INVALID_DEVICE_TYPE: + return "CL_INVALID_DEVICE_TYPE"; + case CL_INVALID_PLATFORM: + return "CL_INVALID_PLATFORM"; + case CL_INVALID_DEVICE: + return "CL_INVALID_DEVICE"; + case CL_INVALID_CONTEXT: + return "CL_INVALID_CONTEXT"; + case CL_INVALID_QUEUE_PROPERTIES: + return "CL_INVALID_QUEUE_PROPERTIES"; + case CL_INVALID_COMMAND_QUEUE: + return "CL_INVALID_COMMAND_QUEUE"; + case CL_INVALID_HOST_PTR: + return "CL_INVALID_HOST_PTR"; + case CL_INVALID_MEM_OBJECT: + return "CL_INVALID_MEM_OBJECT"; + case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: + return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; + case CL_INVALID_IMAGE_SIZE: + return "CL_INVALID_IMAGE_SIZE"; + case CL_INVALID_SAMPLER: + return "CL_INVALID_SAMPLER"; + case CL_INVALID_BINARY: + return "CL_INVALID_BINARY"; + case CL_INVALID_BUILD_OPTIONS: + return "CL_INVALID_BUILD_OPTIONS"; + case CL_INVALID_PROGRAM: + return "CL_INVALID_PROGRAM"; + case CL_INVALID_PROGRAM_EXECUTABLE: + return "CL_INVALID_PROGRAM_EXECUTABLE"; + case CL_INVALID_KERNEL_NAME: + return "CL_INVALID_KERNEL_NAME"; + case CL_INVALID_KERNEL_DEFINITION: + return "CL_INVALID_KERNEL_DEFINITION"; + case CL_INVALID_KERNEL: + return "CL_INVALID_KERNEL"; + case CL_INVALID_ARG_INDEX: + return "CL_INVALID_ARG_INDEX"; + case CL_INVALID_ARG_VALUE: + return "CL_INVALID_ARG_VALUE"; + case CL_INVALID_ARG_SIZE: + return "CL_INVALID_ARG_SIZE"; + case CL_INVALID_KERNEL_ARGS: + return "CL_INVALID_KERNEL_ARGS"; + case CL_INVALID_WORK_DIMENSION: + return "CL_INVALID_WORK_DIMENSION"; + case CL_INVALID_WORK_GROUP_SIZE: + return "CL_INVALID_WORK_GROUP_SIZE"; + case CL_INVALID_WORK_ITEM_SIZE: + return "CL_INVALID_WORK_ITEM_SIZE"; + case CL_INVALID_GLOBAL_OFFSET: + return "CL_INVALID_GLOBAL_OFFSET"; + case CL_INVALID_EVENT_WAIT_LIST: + return "CL_INVALID_EVENT_WAIT_LIST"; + case CL_INVALID_EVENT: + return "CL_INVALID_EVENT"; + case CL_INVALID_OPERATION: + return "CL_INVALID_OPERATION"; + case CL_INVALID_GL_OBJECT: + return "CL_INVALID_GL_OBJECT"; + case CL_INVALID_BUFFER_SIZE: + return "CL_INVALID_BUFFER_SIZE"; + case CL_INVALID_MIP_LEVEL: + return "CL_INVALID_MIP_LEVEL"; + default: + return "Unknown OpenCL error code"; } } @@ -119,16 +166,13 @@ inline const char* CLGetErrorString(cl_int error) { * \brief Protected OpenCL call * \param func Expression to call. */ -#define OPENCL_CHECK_ERROR(e) \ - { \ - CHECK(e == CL_SUCCESS) \ - << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); \ - } +#define OPENCL_CHECK_ERROR(e) \ + { CHECK(e == CL_SUCCESS) << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); } -#define OPENCL_CALL(func) \ - { \ - cl_int e = (func); \ - OPENCL_CHECK_ERROR(e); \ +#define OPENCL_CALL(func) \ + { \ + cl_int e = (func); \ + OPENCL_CHECK_ERROR(e); \ } class OpenCLThreadEntry; @@ -172,37 +216,24 @@ class OpenCLWorkspace : public DeviceAPI { // Initialzie the device. void Init(const std::string& type_key, const std::string& device_type, const std::string& platform_name = ""); - virtual void Init() { - Init("opencl", "gpu"); - } + virtual void Init() { Init("opencl", "gpu"); } // Check whether the context is OpenCL or not. - virtual bool IsOpenCLDevice(TVMContext ctx) { - return ctx.device_type == kDLOpenCL; - } + virtual bool IsOpenCLDevice(TVMContext ctx) { return ctx.device_type == kDLOpenCL; } // get the queue of the context cl_command_queue GetQueue(TVMContext ctx) { CHECK(IsOpenCLDevice(ctx)); this->Init(); - CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) << "Invalid OpenCL device_id=" << ctx.device_id; return queues[ctx.device_id]; } // override device API void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t size, - size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; @@ -217,7 +248,6 @@ class OpenCLWorkspace : public DeviceAPI { static const std::shared_ptr& Global(); }; - /*! \brief Thread local workspace */ class OpenCLThreadEntry { public: @@ -240,8 +270,7 @@ class OpenCLThreadEntry { context.device_id = 0; context.device_type = device_type; } - OpenCLThreadEntry() - : OpenCLThreadEntry(kDLOpenCL, OpenCLWorkspace::Global()) {} + OpenCLThreadEntry() : OpenCLThreadEntry(kDLOpenCL, OpenCLWorkspace::Global()) {} // get the global workspace static OpenCLThreadEntry* ThreadLocal(); @@ -260,10 +289,8 @@ class OpenCLModuleNode : public ModuleNode { size_t kernel_id; size_t version; }; - explicit OpenCLModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) + explicit OpenCLModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} // destructor ~OpenCLModuleNode(); @@ -275,20 +302,15 @@ class OpenCLModuleNode : public ModuleNode { const char* type_key() const final { return workspace_->type_key.c_str(); } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; std::string GetSource(const std::string& format) final; // Initialize the programs void Init(); // install a new kernel to thread local entry - cl_kernel InstallKernel(cl::OpenCLWorkspace* w, - cl::OpenCLThreadEntry* t, - const std::string& func_name, - const KTRefEntry& e); + cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e); private: // The workspace, need to keep reference to use it in destructor. diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 99d2b0c..6d9835e 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -20,17 +20,16 @@ /*! * \file opencl_device_api.cc */ -#include #include +#include + #include "opencl_common.h" namespace tvm { namespace runtime { namespace cl { -OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { - return OpenCLThreadEntry::ThreadLocal(); -} +OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { return OpenCLThreadEntry::ThreadLocal(); } const std::shared_ptr& OpenCLWorkspace::Global() { static std::shared_ptr inst = std::make_shared(); @@ -41,23 +40,21 @@ void OpenCLWorkspace::SetDevice(TVMContext ctx) { GetThreadEntry()->context.device_id = ctx.device_id; } -void OpenCLWorkspace::GetAttr( - TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { +void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { this->Init(); size_t index = static_cast(ctx.device_id); if (kind == kExist) { - *rv = static_cast(index< devices.size()); + *rv = static_cast(index < devices.size()); return; } - CHECK_LT(index, devices.size()) - << "Invalid device id " << index; + CHECK_LT(index, devices.size()) << "Invalid device id " << index; switch (kind) { - case kExist: break; + case kExist: + break; case kMaxThreadsPerBlock: { size_t value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, - sizeof(size_t), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), + &value, nullptr)); *rv = static_cast(value); break; } @@ -72,58 +69,55 @@ void OpenCLWorkspace::GetAttr( } case kMaxSharedMemoryPerBlock: { cl_ulong value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_LOCAL_MEM_SIZE, - sizeof(cl_ulong), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_LOCAL_MEM_SIZE, sizeof(cl_ulong), + &value, nullptr)); *rv = static_cast(value); break; } - case kComputeVersion: return; + case kComputeVersion: + return; case kDeviceName: { char value[128] = {0}; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_NAME, - sizeof(value) - 1, value, nullptr)); + OPENCL_CALL( + clGetDeviceInfo(devices[index], CL_DEVICE_NAME, sizeof(value) - 1, value, nullptr)); *rv = std::string(value); break; } case kMaxClockRate: { cl_uint value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY, - sizeof(cl_uint), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY, sizeof(cl_uint), + &value, nullptr)); *rv = static_cast(value); break; } case kMultiProcessorCount: { cl_uint value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_COMPUTE_UNITS, - sizeof(cl_uint), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(cl_uint), + &value, nullptr)); *rv = static_cast(value); break; } case kMaxThreadDimensions: { size_t dims[3]; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, + nullptr)); std::stringstream ss; // use json string to return multiple int values; - ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; *rv = ss.str(); break; } - case kGcnArch: return; + case kGcnArch: + return; } } -void* OpenCLWorkspace::AllocDataSpace( - TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) { +void* OpenCLWorkspace::AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, + DLDataType type_hint) { this->Init(); CHECK(context != nullptr) << "No OpenCL device"; cl_int err_code; - cl_mem mptr = clCreateBuffer( - this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); + cl_mem mptr = clCreateBuffer(this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); OPENCL_CHECK_ERROR(err_code); return mptr; } @@ -137,38 +131,27 @@ void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { OPENCL_CALL(clReleaseMemObject(mptr)); } -void OpenCLWorkspace::CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, +void OpenCLWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) { this->Init(); CHECK(stream == nullptr); if (IsOpenCLDevice(ctx_from) && IsOpenCLDevice(ctx_to)) { - OPENCL_CALL(clEnqueueCopyBuffer( - this->GetQueue(ctx_to), - static_cast((void*)from), // NOLINT(*) - static_cast(to), - from_offset, to_offset, size, 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueCopyBuffer(this->GetQueue(ctx_to), + static_cast((void*)from), // NOLINT(*) + static_cast(to), from_offset, to_offset, size, 0, + nullptr, nullptr)); } else if (IsOpenCLDevice(ctx_from) && ctx_to.device_type == kDLCPU) { - OPENCL_CALL(clEnqueueReadBuffer( - this->GetQueue(ctx_from), - static_cast((void*)from), // NOLINT(*) - CL_FALSE, from_offset, size, - static_cast(to) + to_offset, - 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueReadBuffer(this->GetQueue(ctx_from), + static_cast((void*)from), // NOLINT(*) + CL_FALSE, from_offset, size, static_cast(to) + to_offset, + 0, nullptr, nullptr)); OPENCL_CALL(clFinish(this->GetQueue(ctx_from))); } else if (ctx_from.device_type == kDLCPU && IsOpenCLDevice(ctx_to)) { - OPENCL_CALL(clEnqueueWriteBuffer( - this->GetQueue(ctx_to), - static_cast(to), - CL_FALSE, to_offset, size, - static_cast(from) + from_offset, - 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueWriteBuffer(this->GetQueue(ctx_to), static_cast(to), CL_FALSE, + to_offset, size, static_cast(from) + from_offset, + 0, nullptr, nullptr)); OPENCL_CALL(clFinish(this->GetQueue(ctx_to))); } else { LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL"; @@ -180,9 +163,7 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { OPENCL_CALL(clFinish(this->GetQueue(ctx))); } -void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { +void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return GetThreadEntry()->pool.AllocWorkspace(ctx, size); } @@ -192,12 +173,9 @@ void OpenCLWorkspace::FreeWorkspace(TVMContext ctx, void* data) { typedef dmlc::ThreadLocalStore OpenCLThreadStore; -OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { - return OpenCLThreadStore::Get(); -} +OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { return OpenCLThreadStore::Get(); } -std::string GetPlatformInfo( - cl_platform_id pid, cl_platform_info param_name) { +std::string GetPlatformInfo(cl_platform_id pid, cl_platform_info param_name) { size_t ret_size; OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &ret_size)); std::string ret; @@ -206,8 +184,7 @@ std::string GetPlatformInfo( return ret; } -std::string GetDeviceInfo( - cl_device_id pid, cl_device_info param_name) { +std::string GetDeviceInfo(cl_device_id pid, cl_device_info param_name) { size_t ret_size; OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, nullptr, &ret_size)); std::string ret; @@ -226,8 +203,7 @@ std::vector GetPlatformIDs() { return ret; } -std::vector GetDeviceIDs( - cl_platform_id pid, std::string device_type) { +std::vector GetDeviceIDs(cl_platform_id pid, std::string device_type) { cl_device_type dtype = CL_DEVICE_TYPE_ALL; if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU; if (device_type == "gpu") dtype = CL_DEVICE_TYPE_GPU; @@ -241,10 +217,7 @@ std::vector GetDeviceIDs( return ret; } -bool MatchPlatformInfo( - cl_platform_id pid, - cl_platform_info param_name, - std::string value) { +bool MatchPlatformInfo(cl_platform_id pid, cl_platform_info param_name, std::string value) { if (value.length() == 0) return true; std::string param_value = GetPlatformInfo(pid, param_name); return param_value.find(value) != std::string::npos; @@ -286,25 +259,22 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic return; } cl_int err_code; - this->context = clCreateContext( - nullptr, this->devices.size(), &(this->devices[0]), - nullptr, nullptr, &err_code); + this->context = clCreateContext(nullptr, this->devices.size(), &(this->devices[0]), nullptr, + nullptr, &err_code); OPENCL_CHECK_ERROR(err_code); CHECK_EQ(this->queues.size(), 0U); for (size_t i = 0; i < this->devices.size(); ++i) { cl_device_id did = this->devices[i]; - this->queues.push_back( - clCreateCommandQueue(this->context, did, 0, &err_code)); + this->queues.push_back(clCreateCommandQueue(this->context, did, 0, &err_code)); OPENCL_CHECK_ERROR(err_code); } initialized_ = true; } -TVM_REGISTER_GLOBAL("device_api.opencl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = OpenCLWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = OpenCLWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace cl } // namespace runtime diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index fefde72..95d0481 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -20,13 +20,16 @@ /*! * \file opencl_module.cc */ +#include "opencl_module.h" + #include #include -#include + #include #include +#include + #include "opencl_common.h" -#include "opencl_module.h" namespace tvm { namespace runtime { @@ -34,12 +37,9 @@ namespace runtime { class OpenCLWrappedFunc { public: // initialize the OpenCL function. - void Init(OpenCLModuleNode* m, - ObjectPtr sptr, - OpenCLModuleNode::KTRefEntry entry, - std::string func_name, - std::vector arg_size, - const std::vector& thread_axis_tags) { + void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, + std::string func_name, std::vector arg_size, + const std::vector& thread_axis_tags) { w_ = m->GetGlobalWorkspace().get(); m_ = m; sptr_ = sptr; @@ -49,9 +49,7 @@ class OpenCLWrappedFunc { thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - void** void_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { CHECK(w_->context != nullptr) << "No OpenCL device"; cl::OpenCLThreadEntry* t = w_->GetThreadEntry(); // get the kernel from thread local kernel table. @@ -74,11 +72,8 @@ class OpenCLWrappedFunc { wl.work_size[i] *= wl.work_size[i + 3]; } // launch kernel - OPENCL_CALL(clEnqueueNDRangeKernel( - queue, kernel, work_dim, nullptr, - wl.work_size, - wl.work_size + 3, - 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueNDRangeKernel(queue, kernel, work_dim, nullptr, wl.work_size, + wl.work_size + 3, 0, nullptr, nullptr)); } private: @@ -119,12 +114,10 @@ const std::shared_ptr& OpenCLModuleNode::GetGlobalWorkspace return cl::OpenCLWorkspace::Global(); } -PackedFunc OpenCLModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; @@ -143,16 +136,13 @@ PackedFunc OpenCLModuleNode::GetFunction( } } // initialize the wrapped func. - f.Init(this, sptr_to_self, kid_map_.at(name), - name, arg_size, info.thread_axis_tags); + f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags); return PackFuncVoidAddr(f, info.arg_types); } -void OpenCLModuleNode::SaveToFile(const std::string& file_name, - const std::string& format) { +void OpenCLModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -193,10 +183,8 @@ void OpenCLModuleNode::Init() { } } -cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, - cl::OpenCLThreadEntry* t, - const std::string& func_name, - const KTRefEntry& e) { +cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) { std::lock_guard lock(build_lock_); int device_id = t->context.device_id; if (!device_built_flag_[device_id]) { @@ -210,7 +198,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, OPENCL_CHECK_ERROR(err); } } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") { - const unsigned char* s = (const unsigned char *)data_.c_str(); + const unsigned char* s = (const unsigned char*)data_.c_str(); size_t len = data_.length(); cl_int err; cl_device_id dev = w->devices[device_id]; @@ -226,11 +214,9 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, if (err != CL_SUCCESS) { size_t len; std::string log; - clGetProgramBuildInfo( - program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len); + clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len); log.resize(len); - clGetProgramBuildInfo( - program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); + clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); LOG(FATAL) << "OpenCL build error for device=" << dev << log; } device_built_flag_[device_id] = true; @@ -245,19 +231,15 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, return kernel; } -Module OpenCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } // Load module from module. -Module OpenCLModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module OpenCLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -278,13 +260,10 @@ Module OpenCLModuleLoadBinary(void* strm) { return OpenCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl") -.set_body_typed(OpenCLModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin") -.set_body_typed(OpenCLModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl") -.set_body_typed(OpenCLModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl").set_body_typed(OpenCLModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 3b7ebb9..77f4b80 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -25,10 +25,12 @@ #define TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -40,11 +42,8 @@ namespace runtime { * \param fmt The format of the data, can be "clbin", "cl" * \param fmap The map function information map of each function. */ -Module OpenCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ diff --git a/src/runtime/opencl/sdaccel/sdaccel_common.h b/src/runtime/opencl/sdaccel/sdaccel_common.h index 2100b50..803cbe6 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_common.h +++ b/src/runtime/opencl/sdaccel/sdaccel_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,6 +25,7 @@ #define TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_COMMON_H_ #include + #include "../opencl_common.h" namespace tvm { @@ -44,7 +45,6 @@ class SDAccelWorkspace final : public OpenCLWorkspace { static const std::shared_ptr& Global(); }; - /*! \brief Thread local workspace for SDAccel*/ class SDAccelThreadEntry : public OpenCLThreadEntry { public: diff --git a/src/runtime/opencl/sdaccel/sdaccel_device_api.cc b/src/runtime/opencl/sdaccel/sdaccel_device_api.cc index 59e8a25..6bac0c9 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_device_api.cc +++ b/src/runtime/opencl/sdaccel/sdaccel_device_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,26 +20,23 @@ /*! * \file sdaccel_device_api.cc */ -#include #include +#include + #include "sdaccel_common.h" namespace tvm { namespace runtime { namespace cl { -OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { - return SDAccelThreadEntry::ThreadLocal(); -} +OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { return SDAccelThreadEntry::ThreadLocal(); } const std::shared_ptr& SDAccelWorkspace::Global() { static std::shared_ptr inst = std::make_shared(); return inst; } -void SDAccelWorkspace::Init() { - OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx"); -} +void SDAccelWorkspace::Init() { OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx"); } bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) { return ctx.device_type == static_cast(kDLSDAccel); @@ -47,15 +44,12 @@ bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) { typedef dmlc::ThreadLocalStore SDAccelThreadStore; -SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { - return SDAccelThreadStore::Get(); -} +SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { return SDAccelThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.sdaccel") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = SDAccelWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.sdaccel").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = SDAccelWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace cl } // namespace runtime diff --git a/src/runtime/opencl/sdaccel/sdaccel_module.cc b/src/runtime/opencl/sdaccel/sdaccel_module.cc index 4569ec3..b4edca3 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_module.cc +++ b/src/runtime/opencl/sdaccel/sdaccel_module.cc @@ -20,23 +20,24 @@ /*! * \file sdaccel_module.cc */ +#include "sdaccel_module.h" + #include #include -#include + #include #include +#include + #include "sdaccel_common.h" -#include "sdaccel_module.h" namespace tvm { namespace runtime { class SDAccelModuleNode : public OpenCLModuleNode { public: - explicit SDAccelModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) + explicit SDAccelModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) : OpenCLModuleNode(data, fmt, fmap, source) {} const std::shared_ptr& GetGlobalWorkspace() final; }; @@ -45,18 +46,14 @@ const std::shared_ptr& SDAccelModuleNode::GetGlobalWorkspac return cl::SDAccelWorkspace::Global(); } -Module SDAccelModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module SDAccelModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } -Module SDAccelModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module SDAccelModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -77,10 +74,8 @@ Module SDAccelModuleLoadBinary(void* strm) { return SDAccelModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_xclbin") -.set_body_typed(SDAccelModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_xclbin").set_body_typed(SDAccelModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_awsxclbin") -.set_body_typed(SDAccelModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_awsxclbin").set_body_typed(SDAccelModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/sdaccel/sdaccel_module.h b/src/runtime/opencl/sdaccel/sdaccel_module.h index e126291..322decc 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_module.h +++ b/src/runtime/opencl/sdaccel/sdaccel_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,10 +25,12 @@ #define TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../../meta_data.h" namespace tvm { @@ -40,11 +42,8 @@ namespace runtime { * \param fmt The format of the data, can be "xclbin", "awsxclbin" * \param fmap The map function information map of each function. */ -Module SDAccelModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module SDAccelModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_MODULE_H_ diff --git a/src/runtime/opengl/opengl_common.h b/src/runtime/opengl/opengl_common.h index 009ea6c..eca45d7 100644 --- a/src/runtime/opengl/opengl_common.h +++ b/src/runtime/opengl/opengl_common.h @@ -24,19 +24,20 @@ #ifndef TVM_RUNTIME_OPENGL_OPENGL_COMMON_H_ #define TVM_RUNTIME_OPENGL_OPENGL_COMMON_H_ +#include #include -#include #include -#include +#include #if defined(__APPLE__) #define GLFW_INCLUDE_GLCOREARB #endif #include + +#include #include #include #include #include -#include namespace tvm { namespace runtime { @@ -54,8 +55,7 @@ inline GLFWglproc GetProcAddress(const char* procname) { return proc; } -#define SetGLFunctionPointer(NAME) \ - NAME(decltype(NAME)(GetProcAddress("gl" #NAME))) +#define SetGLFunctionPointer(NAME) NAME(decltype(NAME)(GetProcAddress("gl" #NAME))) /*! * \brief The function pointers of all OpenGL APIs that are used. @@ -117,8 +117,7 @@ class GLFunctionPointers { void (*BindFramebuffer)(GLenum target, GLuint framebuffer); void (*BindTexture)(GLenum target, GLuint texture); void (*BindVertexArray)(GLuint array); - void (*BufferData)(GLenum target, GLsizeiptr size, const GLvoid* data, - GLenum usage); + void (*BufferData)(GLenum target, GLsizeiptr size, const GLvoid* data, GLenum usage); GLenum (*CheckFramebufferStatus)(GLenum target); void (*Clear)(GLbitfield mask); void (*CompileShader)(GLuint shader); @@ -133,8 +132,8 @@ class GLFunctionPointers { void (*DrawBuffers)(GLsizei n, const GLenum* bufs); void (*EnableVertexAttribArray)(GLuint index); void (*Finish)(); - void (*FramebufferTexture2D)(GLenum target, GLenum attachment, - GLenum textarget, GLuint texture, GLint level); + void (*FramebufferTexture2D)(GLenum target, GLenum attachment, GLenum textarget, GLuint texture, + GLint level); void (*GenBuffers)(GLsizei n, GLuint* buffers); void (*GenFramebuffers)(GLsizei n, GLuint* ids); void (*GenTextures)(GLsizei n, GLuint* textures); @@ -142,32 +141,26 @@ class GLFunctionPointers { GLint (*GetAttribLocation)(GLuint program, const GLchar* name); GLenum (*GetError)(); void (*GetIntegerv)(GLenum pname, GLint* data); - void (*GetProgramInfoLog)(GLuint program, GLsizei maxLength, GLsizei* length, - GLchar* info_log); + void (*GetProgramInfoLog)(GLuint program, GLsizei maxLength, GLsizei* length, GLchar* info_log); void (*GetProgramiv)(GLuint program, GLenum pname, GLint* params); - void (*GetShaderInfoLog)(GLuint shader, GLsizei max_length, GLsizei* length, - GLchar* info_log); + void (*GetShaderInfoLog)(GLuint shader, GLsizei max_length, GLsizei* length, GLchar* info_log); void (*GetShaderiv)(GLuint shader, GLenum pname, GLint* params); - const GLubyte *(*GetString)(GLenum name); + const GLubyte* (*GetString)(GLenum name); GLint (*GetUniformLocation)(GLuint program, const GLchar* name); void (*LinkProgram)(GLuint program); - void (*ReadPixels)(GLint x, GLint y, GLsizei width, GLsizei height, - GLenum format, GLenum type, GLvoid* data); - void (*ShaderSource)(GLuint shader, GLsizei count, const GLchar** string, - const GLint* length); - void (*TexImage2D)(GLenum target, GLint level, GLint internal_format, - GLsizei width, GLsizei height, GLint border, GLenum format, - GLenum type, const GLvoid* data); + void (*ReadPixels)(GLint x, GLint y, GLsizei width, GLsizei height, GLenum format, GLenum type, + GLvoid* data); + void (*ShaderSource)(GLuint shader, GLsizei count, const GLchar** string, const GLint* length); + void (*TexImage2D)(GLenum target, GLint level, GLint internal_format, GLsizei width, + GLsizei height, GLint border, GLenum format, GLenum type, const GLvoid* data); void (*TexParameteri)(GLenum target, GLenum pname, GLint param); - void (*TexSubImage2D)(GLenum target, GLint level, GLint xoffset, - GLint yoffset, GLsizei width, GLsizei height, - GLenum format, GLenum type, const GLvoid* data); + void (*TexSubImage2D)(GLenum target, GLint level, GLint xoffset, GLint yoffset, GLsizei width, + GLsizei height, GLenum format, GLenum type, const GLvoid* data); void (*Uniform1f)(GLint location, GLfloat v0); void (*Uniform1i)(GLint location, GLint v0); void (*UseProgram)(GLuint program); - void (*VertexAttribPointer)(GLuint index, GLint size, GLenum type, - GLboolean normalized, GLsizei stride, - const GLvoid* pointer); + void (*VertexAttribPointer)(GLuint index, GLint size, GLenum type, GLboolean normalized, + GLsizei stride, const GLvoid* pointer); void (*Viewport)(GLint x, GLint y, GLsizei width, GLsizei height); }; @@ -181,19 +174,10 @@ class OpenGLWorkspace final : public DeviceAPI { // override device API void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; @@ -225,10 +209,7 @@ class OpenGLWorkspace final : public DeviceAPI { * \param nelems The number of elements to be written to. * \param data The user data. */ - void PutTextureData(Texture* texture, - GLint begin, - GLsizei nelems, - const GLvoid* data); + void PutTextureData(Texture* texture, GLint begin, GLsizei nelems, const GLvoid* data); /*! * \brief Download a sub-region of an OpenGL texture. * \param texture The texture to download from. @@ -236,10 +217,7 @@ class OpenGLWorkspace final : public DeviceAPI { * \param nelems The number of elements to download from. * \param data The user buffer. */ - void GetTextureData(const Texture* texture, - GLint begin, - GLsizei nelems, - GLvoid* data); + void GetTextureData(const Texture* texture, GLint begin, GLsizei nelems, GLvoid* data); /*! * \brief Set currently used OpenGL program. @@ -254,10 +232,7 @@ class OpenGLWorkspace final : public DeviceAPI { * \param type The type of the uniform. * \param value The value to pass in. */ - void SetUniform(const Program& program, - const std::string& name, - DLDataType type, - void* value); + void SetUniform(const Program& program, const std::string& name, DLDataType type, void* value); /*! * \brief Set input texture for an OpenGL program. @@ -268,9 +243,7 @@ class OpenGLWorkspace final : public DeviceAPI { * different unit. * \param texture The OpenGL texture to pass in. */ - void SetInputTexture(const Program& program, - const std::string& name, - GLuint unit, + void SetInputTexture(const Program& program, const std::string& name, GLuint unit, Texture* texture); /*! @@ -354,8 +327,7 @@ class OpenGLWorkspace final : public DeviceAPI { class Program { public: // Move constructor. - Program(Program&& other) noexcept - : workspace_(other.workspace_), program_(other.program_) { + Program(Program&& other) noexcept : workspace_(other.workspace_), program_(other.program_) { other.program_ = kInvalidProgram; } @@ -406,11 +378,14 @@ struct TextureFormat { GLsizei elemsz() const { switch (type) { - case GL_BYTE: case GL_UNSIGNED_BYTE: + case GL_BYTE: + case GL_UNSIGNED_BYTE: return 1; - case GL_SHORT: case GL_UNSIGNED_SHORT: + case GL_SHORT: + case GL_UNSIGNED_SHORT: return 2; - case GL_INT: case GL_UNSIGNED_INT: + case GL_INT: + case GL_UNSIGNED_INT: return 4; case GL_FLOAT: return 4; @@ -422,7 +397,7 @@ struct TextureFormat { bool operator==(const TextureFormat& other) const { return std::make_tuple(internal_format, format, type) == - std::make_tuple(other.internal_format, other.format, other.type); + std::make_tuple(other.internal_format, other.format, other.type); } GLint internal_format; // OpenGL says this is GLint, not GLenum. @@ -439,8 +414,11 @@ class Texture { public: // Move constructor. Texture(Texture&& other) noexcept - : workspace_(other.workspace_), texture_(other.texture_), - format_(other.format_), width_(other.width_), height_(other.height_) { + : workspace_(other.workspace_), + texture_(other.texture_), + format_(other.format_), + width_(other.width_), + height_(other.height_) { other.texture_ = kInvalidTexture; } @@ -489,11 +467,9 @@ class Texture { // We enforce this to make sure OpenGL is initialized. // Always only use the first dimension of a 2D texture. // The reason is that texelFetch only supports 2D textures. - explicit Texture(OpenGLWorkspace* workspace, GLuint texture, - TextureFormat format, - GLsizei width, GLsizei height) - : workspace_(workspace), texture_(texture), format_(format), - width_(width), height_(height) {} + explicit Texture(OpenGLWorkspace* workspace, GLuint texture, TextureFormat format, GLsizei width, + GLsizei height) + : workspace_(workspace), texture_(texture), format_(format), width_(width), height_(height) {} // The internal texture ID. GLuint texture() const { return texture_; } diff --git a/src/runtime/opengl/opengl_device_api.cc b/src/runtime/opengl/opengl_device_api.cc index 0be921c..b3e4f59 100644 --- a/src/runtime/opengl/opengl_device_api.cc +++ b/src/runtime/opengl/opengl_device_api.cc @@ -21,7 +21,9 @@ * \file opengl_device_api.cc */ #include + #include + #include "opengl_common.h" #include "opengl_module.h" @@ -60,26 +62,23 @@ static const char* GLGetErrorString(GLenum error) { */ void OpenGLWorkspace::CheckOpenGLError() { GLenum err = gl->GetError(); - CHECK_EQ(err, GL_NO_ERROR) << "OpenGL error, code=" << err << ": " - << gl::GLGetErrorString(err); + CHECK_EQ(err, GL_NO_ERROR) << "OpenGL error, code=" << err << ": " << gl::GLGetErrorString(err); } /*! * \brief Protected OpenGL call. * \param func Expression to call. */ -#define OPENGL_CALL(func) \ - { \ - (func); \ - CheckOpenGLError(); \ +#define OPENGL_CALL(func) \ + { \ + (func); \ + CheckOpenGLError(); \ } /*! * \brief The error handling callback passed to GLFW. */ -void GlfwErrorCallback(int err, const char* str) { - LOG(FATAL) << "Error: [" << err << "] " << str; -} +void GlfwErrorCallback(int err, const char* str) { LOG(FATAL) << "Error: [" << err << "] " << str; } const std::shared_ptr& OpenGLWorkspace::Global() { static std::shared_ptr inst(new OpenGLWorkspace); @@ -87,13 +86,11 @@ const std::shared_ptr& OpenGLWorkspace::Global() { } void OpenGLWorkspace::SetDevice(TVMContext ctx) { - CHECK_EQ(ctx.device_type, static_cast(kOpenGL)) - << "Device type must be OpenGL."; + CHECK_EQ(ctx.device_type, static_cast(kOpenGL)) << "Device type must be OpenGL."; CHECK_EQ(ctx.device_id, 0) << "Only support 1 OpenGL \"device\"."; } -void OpenGLWorkspace::GetAttr( - TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { +void OpenGLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { switch (kind) { case kExist: { *rv = static_cast(ctx.device_id == 0); @@ -108,20 +105,26 @@ void OpenGLWorkspace::GetAttr( *rv = 1; break; } - case kMaxSharedMemoryPerBlock: return; + case kMaxSharedMemoryPerBlock: + return; case kComputeVersion: { break; } - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; - case kGcnArch: return; + case kDeviceName: + return; + case kMaxClockRate: + return; + case kMultiProcessorCount: + return; + case kMaxThreadDimensions: + return; + case kGcnArch: + return; } } -void* OpenGLWorkspace::AllocDataSpace( - TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) { +void* OpenGLWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint) { return reinterpret_cast(new Texture(CreateTexture(type_hint, nbytes))); } @@ -129,14 +132,9 @@ void OpenGLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { delete reinterpret_cast(ptr); } -void OpenGLWorkspace::CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, +void OpenGLWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) { CHECK(stream == nullptr); @@ -159,7 +157,7 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from, } else if (type_from_to == std::make_tuple(gl_devtype, kDLCPU)) { auto texture = static_cast(from); - void *data = static_cast(to) + to_offset; + void* data = static_cast(to) + to_offset; auto elemsz = texture->elemsz(); auto begin = static_cast(from_offset / elemsz); auto nelems = static_cast(size / elemsz); @@ -213,8 +211,7 @@ OpenGLWorkspace::OpenGLWorkspace() { GLuint vertex_buffer; OPENGL_CALL(gl->GenBuffers(1, &vertex_buffer)); OPENGL_CALL(gl->BindBuffer(GL_ARRAY_BUFFER, vertex_buffer)); - OPENGL_CALL(gl->BufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices, - GL_STATIC_DRAW)); + OPENGL_CALL(gl->BufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices, GL_STATIC_DRAW)); GLuint vertex_array; OPENGL_CALL(gl->GenVertexArrays(1, &vertex_array)); @@ -244,9 +241,7 @@ void OpenGLWorkspace::OnDeleteTexture(GLuint texture) { OPENGL_CALL(gl->DeleteTextures(1, &texture)); } -void OpenGLWorkspace::OnDeleteProgram(GLuint program) { - OPENGL_CALL(gl->DeleteProgram(program)); -} +void OpenGLWorkspace::OnDeleteProgram(GLuint program) { OPENGL_CALL(gl->DeleteProgram(program)); } GLuint OpenGLWorkspace::NumTextureUnits() { GLint num_units; @@ -255,28 +250,22 @@ GLuint OpenGLWorkspace::NumTextureUnits() { } const OpenGLWorkspace::Vertex OpenGLWorkspace::vertices[OpenGLWorkspace::kNumVertices] = { - {-1.f, -1.f}, - {1.0f, -1.f}, - {1.0f, 1.0f}, - {-1.f, -1.f}, - {-1.f, 1.0f}, - {1.0f, 1.0f}, + {-1.f, -1.f}, {1.0f, -1.f}, {1.0f, 1.0f}, {-1.f, -1.f}, {-1.f, 1.0f}, {1.0f, 1.0f}, }; // Don't need to change this. // The vertex shader only needs to take in the triangle points. // No need for point transformations. -const char* OpenGLWorkspace::vertex_shader_text_ = "#version 300 es\n" +const char* OpenGLWorkspace::vertex_shader_text_ = + "#version 300 es\n" "in vec2 point; // input to vertex shader\n" "void main() {\n" " gl_Position = vec4(point, 0.0, 1.0);\n" "}\n"; -Program OpenGLWorkspace::CreateProgram( - const char* fragment_shader_src) { +Program OpenGLWorkspace::CreateProgram(const char* fragment_shader_src) { // Create and compile the shaders. - GLuint fragment_shader = CreateShader(GL_FRAGMENT_SHADER, - fragment_shader_src); + GLuint fragment_shader = CreateShader(GL_FRAGMENT_SHADER, fragment_shader_src); // Link the shaders and create the program. Program program = CreateProgram(fragment_shader); @@ -286,8 +275,7 @@ Program OpenGLWorkspace::CreateProgram( return program; } -GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind, - const char* shader_src) { +GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind, const char* shader_src) { // Create the shader. GLuint shader = gl->CreateShader(shader_kind); gl->ShaderSource(shader, 1, &shader_src, nullptr); @@ -367,20 +355,14 @@ Texture OpenGLWorkspace::CreateTexture(DLDataType type, size_t nbytes) { auto nelems = static_cast(nbytes / (type.bits / 8)); auto height = (nelems + kTextureRowSize - 1) / kTextureRowSize; auto width = (height == 1) ? nelems : kTextureRowSize; - OPENGL_CALL(gl->TexImage2D(GL_TEXTURE_2D, /*level=*/0, - texture_format.internal_format, - width, height, /*border=*/0, - texture_format.format, texture_format.type, + OPENGL_CALL(gl->TexImage2D(GL_TEXTURE_2D, /*level=*/0, texture_format.internal_format, width, + height, /*border=*/0, texture_format.format, texture_format.type, /*data=*/nullptr)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)); + OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)); + OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)); + OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)); + OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)); return Texture(this, texture, texture_format, width, height); } @@ -414,8 +396,8 @@ Program OpenGLWorkspace::CreateProgram(GLuint fragment_shader) { auto point_attrib = GLuint(gl->GetAttribLocation(program, "point")); OPENGL_CALL(gl->EnableVertexAttribArray(point_attrib)); - OPENGL_CALL(gl->VertexAttribPointer(point_attrib, 2, GL_FLOAT, GL_FALSE, - sizeof(Vertex), nullptr)); + OPENGL_CALL( + gl->VertexAttribPointer(point_attrib, 2, GL_FLOAT, GL_FALSE, sizeof(Vertex), nullptr)); return Program(this, program); } @@ -465,29 +447,22 @@ static void Visit1DRange(GLint beg, GLint end, F&& on_2d_block) { on_2d_block(0, ylast, xlast + 1, 1); } -void OpenGLWorkspace::PutTextureData(Texture *texture, - GLint begin, - GLsizei nelems, +void OpenGLWorkspace::PutTextureData(Texture* texture, GLint begin, GLsizei nelems, const GLvoid* data) { // Bind to temporary unit. BindTextureUnit(NumTextureUnits() - 1, texture->texture()); - Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, - GLsizei width, GLsizei height) { + Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, GLsizei width, GLsizei height) { auto offset = (ybeg * kTextureRowSize + xbeg - begin) * texture->elemsz(); const GLvoid* ptr = static_cast(data) + offset; // Similar to cudaMemcpy. - OPENGL_CALL(gl->TexSubImage2D(GL_TEXTURE_2D, /*level=*/0, - xbeg, ybeg, width, height, - texture->format_.format, - texture->format_.type, ptr)); + OPENGL_CALL(gl->TexSubImage2D(GL_TEXTURE_2D, /*level=*/0, xbeg, ybeg, width, height, + texture->format_.format, texture->format_.type, ptr)); }); } -void OpenGLWorkspace::GetTextureData(const Texture *texture, - GLint begin, - GLsizei nelems, +void OpenGLWorkspace::GetTextureData(const Texture* texture, GLint begin, GLsizei nelems, GLvoid* data) { BindTextureUnit(NumTextureUnits() - 1, texture->texture()); @@ -497,8 +472,8 @@ void OpenGLWorkspace::GetTextureData(const Texture *texture, OPENGL_CALL(gl->BindFramebuffer(GL_FRAMEBUFFER, frame_buffer)); // Bind texture to framebuffer's attachment 0. - OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_TEXTURE_2D, texture->texture(), 0)); + OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, + texture->texture(), 0)); // Always check that our framebuffer is okay. if (gl->CheckFramebufferStatus(GL_FRAMEBUFFER) != GL_FRAMEBUFFER_COMPLETE) { @@ -521,28 +496,24 @@ void OpenGLWorkspace::GetTextureData(const Texture *texture, auto nchannels = 4; auto padded_data_size = nchannels * nelems * elemsz; auto padded_data = std::unique_ptr(new char[padded_data_size]); - Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, - GLsizei width, GLsizei height) { + Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, GLsizei width, GLsizei height) { auto data_offset = (ybeg * kTextureRowSize + xbeg - begin) * elemsz; auto padded_data_offset = data_offset * nchannels; - OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, - GL_RGBA, GL_FLOAT, + OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, GL_RGBA, GL_FLOAT, padded_data.get() + padded_data_offset)); }); for (GLsizei i = 0; i != nelems; ++i) { - auto dst = reinterpret_cast(data) + i * elemsz; + auto dst = reinterpret_cast(data) + i * elemsz; auto src = padded_data.get() + nchannels * i * elemsz; std::memcpy(dst, src, elemsz); } #else - Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, - GLsizei width, GLsizei height) { + Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, GLsizei width, GLsizei height) { auto offset = (ybeg * kTextureRowSize + xbeg - begin) * texture->elemsz(); GLvoid* ptr = static_cast(data) + offset; - OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, - texture->format_.format, texture->format_.type, - ptr)); + OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, texture->format_.format, + texture->format_.type, ptr)); }); #endif @@ -553,9 +524,7 @@ void OpenGLWorkspace::SetCurrentProgram(const Program& program) { OPENGL_CALL(gl->UseProgram(program.program())); } -void OpenGLWorkspace::SetUniform(const Program& program, - const std::string& name, - DLDataType type, +void OpenGLWorkspace::SetUniform(const Program& program, const std::string& name, DLDataType type, void* value) { GLint location = gl->GetUniformLocation(program.program(), name.c_str()); switch (type.code) { @@ -582,9 +551,7 @@ void OpenGLWorkspace::SetUniform(const Program& program, } } -void OpenGLWorkspace::SetInputTexture(const Program& program, - const std::string& name, - GLuint unit, +void OpenGLWorkspace::SetInputTexture(const Program& program, const std::string& name, GLuint unit, Texture* texture) { // We always use the last texture unit as temporary. // Therefore, we can have "NumTextureUnits() - 1" input textures. @@ -602,8 +569,8 @@ void OpenGLWorkspace::Render(Texture* output) { OPENGL_CALL(gl->BindFramebuffer(GL_FRAMEBUFFER, frame_buffer)); // Set "renderedTexture" as our colour attachement 0. - OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_TEXTURE_2D, output->texture(), 0)); + OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, + output->texture(), 0)); // Specify that we will render to color attachment 0. GLenum DrawBuffers[1] = {GL_COLOR_ATTACHMENT0}; @@ -622,8 +589,7 @@ void OpenGLWorkspace::Render(Texture* output) { OPENGL_CALL(gl->DeleteFramebuffers(1, &frame_buffer)); } -TVM_REGISTER_GLOBAL("device_api.opengl") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.opengl").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = OpenGLWorkspace::Global().get(); *rv = static_cast(ptr); }); diff --git a/src/runtime/opengl/opengl_module.cc b/src/runtime/opengl/opengl_module.cc index 6435aca..ee490f2 100644 --- a/src/runtime/opengl/opengl_module.cc +++ b/src/runtime/opengl/opengl_module.cc @@ -20,35 +20,35 @@ /*! * \file opengl_module.cc */ +#include "opengl_module.h" + #include -#include + #include -#include "opengl_common.h" -#include "opengl_module.h" +#include + +#include "../file_util.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../file_util.h" +#include "opengl_common.h" namespace tvm { namespace runtime { class OpenGLModuleNode final : public ModuleNode { public: - OpenGLModuleNode(std::unordered_map shaders, - std::string fmt, + OpenGLModuleNode(std::unordered_map shaders, std::string fmt, std::unordered_map fmap); ~OpenGLModuleNode() override = default; const char* type_key() const final { return "opengl"; } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; std::string GetSource(const std::string& format) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final; + void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; @@ -72,11 +72,8 @@ class OpenGLModuleNode final : public ModuleNode { class OpenGLWrappedFunc { public: - OpenGLWrappedFunc(OpenGLModuleNode* m, - ObjectPtr sptr, - std::string func_name, - std::vector arg_size, - const std::vector& thread_axis_tags); + OpenGLWrappedFunc(OpenGLModuleNode* m, ObjectPtr sptr, std::string func_name, + std::vector arg_size, const std::vector& thread_axis_tags); void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const; @@ -93,30 +90,32 @@ class OpenGLWrappedFunc { ThreadAxisConfig thread_axis_cfg_; }; -OpenGLModuleNode::OpenGLModuleNode( - std::unordered_map shaders, - std::string fmt, - std::unordered_map fmap) - : workspace_(gl::OpenGLWorkspace::Global()), shaders_(std::move(shaders)), - fmt_(std::move(fmt)), fmap_(std::move(fmap)), programs_() { +OpenGLModuleNode::OpenGLModuleNode(std::unordered_map shaders, + std::string fmt, + std::unordered_map fmap) + : workspace_(gl::OpenGLWorkspace::Global()), + shaders_(std::move(shaders)), + fmt_(std::move(fmt)), + fmap_(std::move(fmap)), + programs_() { CHECK_EQ(fmt_, "gl") << "Unknown OpenGL format " << fmt_; - for (auto &pair : shaders_) { - auto &func_name = pair.first; - auto &shader = pair.second; - programs_.emplace(func_name, - workspace_->CreateProgram(shader.source.c_str())); + for (auto& pair : shaders_) { + auto& func_name = pair.first; + auto& shader = pair.second; + programs_.emplace(func_name, workspace_->CreateProgram(shader.source.c_str())); } } -PackedFunc OpenGLModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc OpenGLModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto func_info_it = fmap_.find(name); - if (func_info_it == fmap_.end()) { return PackedFunc(); } - auto &func_info = func_info_it->second; + if (func_info_it == fmap_.end()) { + return PackedFunc(); + } + auto& func_info = func_info_it->second; std::vector arg_size(func_info.arg_types.size()); for (size_t i = 0; i < func_info.arg_types.size(); ++i) { @@ -128,26 +127,27 @@ PackedFunc OpenGLModuleNode::GetFunction( } // Initialize the wrapped func. - OpenGLWrappedFunc f(this, sptr_to_self, name, arg_size, - func_info.thread_axis_tags); + OpenGLWrappedFunc f(this, sptr_to_self, name, arg_size, func_info.thread_axis_tags); return PackFuncVoidAddr(f, func_info.arg_types); } std::string OpenGLModuleNode::GetSource(const std::string& format) { - if (format != fmt_ && fmt_ != "gl") { return ""; } + if (format != fmt_ && fmt_ != "gl") { + return ""; + } std::ostringstream os; - for (auto &pair : shaders_) { - auto &name = pair.first; - auto &shader = pair.second; - os << "[" << name << "]" << "\n"; - os << shader.source <<"\n"; + for (auto& pair : shaders_) { + auto& name = pair.first; + auto& shader = pair.second; + os << "[" << name << "]" + << "\n"; + os << shader.source << "\n"; } return os.str(); } -void OpenGLModuleNode::SaveToFile(const std::string& file_name, - const std::string& format) { +void OpenGLModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { std::string fmt = GetFileFormat(file_name, format); CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -161,8 +161,7 @@ void OpenGLModuleNode::SaveToBinary(dmlc::Stream* stream) { stream->Write(ToJSON(shaders_)); } -const gl::Program& OpenGLModuleNode::GetProgram( - const std::string& func_name) const { +const gl::Program& OpenGLModuleNode::GetProgram(const std::string& func_name) const { auto it = programs_.find(func_name); if (it == programs_.end()) { LOG(FATAL) << "Cannot find program"; @@ -170,8 +169,7 @@ const gl::Program& OpenGLModuleNode::GetProgram( return it->second; } -const OpenGLShader& OpenGLModuleNode::GetShader( - const std::string& func_name) const { +const OpenGLShader& OpenGLModuleNode::GetShader(const std::string& func_name) const { auto it = shaders_.find(func_name); if (it == shaders_.end()) { LOG(FATAL) << "Cannot find shader"; @@ -179,8 +177,7 @@ const OpenGLShader& OpenGLModuleNode::GetShader( return it->second; } -const FunctionInfo& OpenGLModuleNode::GetFunctionInfo( - const std::string& func_name) const { +const FunctionInfo& OpenGLModuleNode::GetFunctionInfo(const std::string& func_name) const { auto it = fmap_.find(func_name); if (it == fmap_.end()) { LOG(FATAL) << "Cannot find shader"; @@ -188,22 +185,20 @@ const FunctionInfo& OpenGLModuleNode::GetFunctionInfo( return it->second; } -OpenGLWrappedFunc::OpenGLWrappedFunc( - OpenGLModuleNode* m, - ObjectPtr sptr, - std::string func_name, - std::vector arg_size, - const std::vector& thread_axis_tags) - : m_(m), sptr_(std::move(sptr)), func_name_(std::move(func_name)), +OpenGLWrappedFunc::OpenGLWrappedFunc(OpenGLModuleNode* m, ObjectPtr sptr, + std::string func_name, std::vector arg_size, + const std::vector& thread_axis_tags) + : m_(m), + sptr_(std::move(sptr)), + func_name_(std::move(func_name)), arg_size_(std::move(arg_size)) { thread_axis_cfg_.Init(arg_size_.size(), thread_axis_tags); } -void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, - void** void_args) const { - auto &shader = m_->GetShader(func_name_); - auto &program = m_->GetProgram(func_name_); - auto &func_info = m_->GetFunctionInfo(func_name_); +void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { + auto& shader = m_->GetShader(func_name_); + auto& program = m_->GetProgram(func_name_); + auto& func_info = m_->GetFunctionInfo(func_name_); size_t nargs = shader.arg_kinds.size(); // Must call this function before setting uniforms & input textures. @@ -213,7 +208,7 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, GLuint texture_unit = 0; gl::Texture* output = nullptr; for (size_t i = 0; i != nargs; ++i) { - auto &name = shader.arg_names.at(i); + auto& name = shader.arg_names.at(i); auto kind = shader.arg_kinds.at(i); auto type = func_info.arg_types.at(i); switch (kind) { @@ -240,24 +235,19 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, // Set "thread_extent" uniform. ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); std::unique_ptr thread_extent(new GLint(wl.block_dim(0))); - m_->workspace().SetUniform(program, shader.thread_extent_var, - DLDataType{kDLInt, 32, 1}, + m_->workspace().SetUniform(program, shader.thread_extent_var, DLDataType{kDLInt, 32, 1}, static_cast(thread_extent.get())); m_->workspace().Render(output); } -Module OpenGLModuleCreate(std::unordered_map shaders, - std::string fmt, +Module OpenGLModuleCreate(std::unordered_map shaders, std::string fmt, std::unordered_map fmap) { - auto n = make_object(std::move(shaders), - std::move(fmt), - std::move(fmap)); + auto n = make_object(std::move(shaders), std::move(fmt), std::move(fmap)); return Module(n); } -Module OpenGLModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module OpenGLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -278,20 +268,17 @@ Module OpenGLModuleLoadBinary(void* strm) { return OpenGLModuleCreate(FromJSON(data), fmt, fmap); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_gl") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenGLModuleLoadFile(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_gl").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = OpenGLModuleLoadFile(args[0], args[1]); +}); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_glbin") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenGLModuleLoadFile(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_glbin").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = OpenGLModuleLoadFile(args[0], args[1]); +}); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opengl") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenGLModuleLoadBinary(args[0]); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opengl").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = OpenGLModuleLoadBinary(args[0]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opengl/opengl_module.h b/src/runtime/opengl/opengl_module.h index 4d2d1c8..27841a8 100644 --- a/src/runtime/opengl/opengl_module.h +++ b/src/runtime/opengl/opengl_module.h @@ -25,12 +25,14 @@ #define TVM_RUNTIME_OPENGL_OPENGL_MODULE_H_ #include + #include #include #include -#include -#include #include +#include +#include + #include "../meta_data.h" namespace tvm { @@ -67,11 +69,10 @@ OpenGLArgKind String2OpenGLArgKind(const std::string& str); */ struct OpenGLShader { OpenGLShader() = default; - OpenGLShader(std::string source, - std::vector arg_names, - std::vector arg_kinds, - std::string thread_extent_var) - : source(std::move(source)), arg_names(std::move(arg_names)), + OpenGLShader(std::string source, std::vector arg_names, + std::vector arg_kinds, std::string thread_extent_var) + : source(std::move(source)), + arg_names(std::move(arg_names)), arg_kinds(std::move(arg_kinds)), thread_extent_var(std::move(thread_extent_var)) { CHECK_EQ(this->arg_names.size(), this->arg_kinds.size()) << "Invalid input"; @@ -96,8 +97,7 @@ std::unordered_map FromJSON(const std::string& str); * \param fmt The format of the data, * \param fmap The map function information map of each function. */ -Module OpenGLModuleCreate(std::unordered_map shaders, - std::string fmt, +Module OpenGLModuleCreate(std::unordered_map shaders, std::string fmt, std::unordered_map fmap); inline std::string OpenGLArgKind2String(OpenGLArgKind kind) { @@ -156,8 +156,7 @@ inline void OpenGLShader::Load(dmlc::JSONReader* reader) { } } -inline std::string ToJSON( - const std::unordered_map& shaders) { +inline std::string ToJSON(const std::unordered_map& shaders) { std::ostringstream os; dmlc::JSONWriter writer(&os); writer.BeginObject(); @@ -166,8 +165,7 @@ inline std::string ToJSON( return os.str(); } -inline std::unordered_map FromJSON( - const std::string& str) { +inline std::unordered_map FromJSON(const std::string& str) { std::unordered_map shaders; std::istringstream is(str); dmlc::JSONReader reader(&is); diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 9d24ca9..ae97716 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -32,8 +32,9 @@ #define TVM_RUNTIME_PACK_ARGS_H_ #include -#include + #include +#include namespace tvm { namespace runtime { @@ -55,7 +56,7 @@ union ArgUnion { * * \return The wrapped packed function. */ -template +template inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types); /*! * \brief Create a packed function that from function only packs buffer arguments. @@ -66,7 +67,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types * * \return The wrapped packed function. */ -template +template inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types); /*! * \brief Create a packed function that from function that takes a packed arguments. @@ -77,7 +78,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_t * * \return The wrapped packed function. */ -template +template inline PackedFunc PackFuncPackedArg(F f, const std::vector& arg_types); /*! * \brief Extract number of buffer argument from the argument types. @@ -88,23 +89,21 @@ inline size_t NumBufferArgs(const std::vector& arg_types); // implementations details namespace detail { -template +template class TempArray { public: explicit TempArray(int size) {} - T* data() { - return data_; - } + T* data() { return data_; } + private: T data_[kSize]; }; -template +template class TempArray { public: explicit TempArray(int size) : data_(size) {} - T* data() { - return data_.data(); - } + T* data() { return data_.data(); } + private: std::vector data_; }; @@ -120,8 +119,7 @@ enum ArgConvertCode { }; inline ArgConvertCode GetArgConvertCode(DLDataType t) { - CHECK_EQ(t.lanes, 1U) - << "Cannot pass vector type argument to devic function for now"; + CHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now"; if (t.code == kDLInt) { if (t.bits == 64U) return INT64_TO_INT64; if (t.bits == 32U) return INT64_TO_INT32; @@ -137,7 +135,7 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) { return HANDLE_TO_HANDLE; } -template +template inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { @@ -158,7 +156,7 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code addr[i] = &(holder[i]); break; } - case INT64_TO_UINT32 : { + case INT64_TO_UINT32: { holder[i].v_uint32 = static_cast(args.values[i].v_int64); addr[i] = &(holder[i]); break; @@ -175,9 +173,8 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code return PackedFunc(ret); } -template -inline PackedFunc PackFuncNonBufferArg_( - F f, int base, const std::vector& codes) { +template +inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) { TempArray holder_(num_args); @@ -186,13 +183,14 @@ inline PackedFunc PackFuncNonBufferArg_( switch (codes[i]) { case INT64_TO_INT64: case FLOAT64_TO_FLOAT64: { - LOG(FATAL) << "Do not support 64bit argument to device function"; break; + LOG(FATAL) << "Do not support 64bit argument to device function"; + break; } case INT64_TO_INT32: { holder[i].v_int32 = static_cast(args.values[base + i].v_int64); break; } - case INT64_TO_UINT32 : { + case INT64_TO_UINT32: { holder[i].v_uint32 = static_cast(args.values[base + i].v_int64); break; } @@ -201,7 +199,8 @@ inline PackedFunc PackFuncNonBufferArg_( break; } case HANDLE_TO_HANDLE: { - LOG(FATAL) << "not reached"; break; + LOG(FATAL) << "not reached"; + break; } } } @@ -210,9 +209,8 @@ inline PackedFunc PackFuncNonBufferArg_( return PackedFunc(ret); } -template -inline PackedFunc PackFuncPackedArg_( - F f, const std::vector& codes) { +template +inline PackedFunc PackFuncPackedArg_(F f, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { TempArray pack_(num_args); @@ -238,20 +236,19 @@ inline PackedFunc PackFuncPackedArg_( ++ptr; break; } - case INT64_TO_UINT32 : { - *reinterpret_cast(ptr) = - static_cast(args.values[i].v_int64); + case INT64_TO_UINT32: { + *reinterpret_cast(ptr) = static_cast(args.values[i].v_int64); ++ptr; break; } case FLOAT64_TO_FLOAT32: { - *reinterpret_cast(ptr) = - static_cast(args.values[i].v_float64); + *reinterpret_cast(ptr) = static_cast(args.values[i].v_float64); ++ptr; break; } default: { - LOG(FATAL) << "not reached"; break; + LOG(FATAL) << "not reached"; + break; } } } @@ -261,7 +258,7 @@ inline PackedFunc PackFuncPackedArg_( } } // namespace detail -template +template inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types) { std::vector codes(arg_types.size()); for (size_t i = 0; i < arg_types.size(); ++i) { @@ -282,17 +279,17 @@ inline size_t NumBufferArgs(const std::vector& arg_types) { size_t base = arg_types.size(); for (size_t i = 0; i < arg_types.size(); ++i) { if (arg_types[i].code != kTVMOpaqueHandle) { - base = i; break; + base = i; + break; } } for (size_t i = base; i < arg_types.size(); ++i) { - CHECK(arg_types[i].code != kTVMOpaqueHandle) - << "Device function need to be organized"; + CHECK(arg_types[i].code != kTVMOpaqueHandle) << "Device function need to be organized"; } return base; } -template +template inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types) { size_t num_buffer = NumBufferArgs(arg_types); std::vector codes; @@ -309,7 +306,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_t } } -template +template inline PackedFunc PackFuncPackedArg(F f, const std::vector& arg_types) { std::vector codes; for (size_t i = 0; i < arg_types.size(); ++i) { diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 855a342..641532a 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -24,10 +24,12 @@ #include #include #include -#include -#include -#include + #include +#include +#include +#include + #include "runtime_base.h" namespace tvm { @@ -43,8 +45,7 @@ struct Registry::Manager { // mutex std::mutex mutex; - Manager() { - } + Manager() {} static Manager* Global() { // We deliberately leak the Manager instance, to avoid leak sanitizers @@ -64,8 +65,7 @@ Registry& Registry::Register(const std::string& name, bool can_override) { // N Manager* m = Manager::Global(); std::lock_guard lock(m->mutex); if (m->fmap.count(name)) { - CHECK(can_override) - << "Global PackedFunc " << name << " is already registered"; + CHECK(can_override) << "Global PackedFunc " << name << " is already registered"; } Registry* r = new Registry(); @@ -96,7 +96,7 @@ std::vector Registry::ListNames() { std::lock_guard lock(m->mutex); std::vector keys; keys.reserve(m->fmap.size()); - for (const auto &kv : m->fmap) { + for (const auto& kv : m->fmap) { keys.push_back(kv.first); } return keys; @@ -110,14 +110,13 @@ struct TVMFuncThreadLocalEntry { /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; + std::vector ret_vec_charp; }; /*! \brief Thread local store that can be used to hold return values. */ typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; -int TVMFuncRegisterGlobal( - const char* name, TVMFunctionHandle f, int override) { +int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { API_BEGIN(); tvm::runtime::Registry::Register(name, override != 0) .set_body(*static_cast(f)); @@ -126,8 +125,7 @@ int TVMFuncRegisterGlobal( int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { API_BEGIN(); - const tvm::runtime::PackedFunc* fp = - tvm::runtime::Registry::Get(name); + const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name); if (fp != nullptr) { *out = new tvm::runtime::PackedFunc(*fp); // NOLINT(*) } else { @@ -136,10 +134,9 @@ int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { API_END(); } -int TVMFuncListGlobalNames(int *out_size, - const char*** out_array) { +int TVMFuncListGlobalNames(int* out_size, const char*** out_array) { API_BEGIN(); - TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get(); + TVMFuncThreadLocalEntry* ret = TVMFuncThreadLocalStore::Get(); ret->ret_vec_str = tvm::runtime::Registry::ListNames(); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { diff --git a/src/runtime/rocm/rocm_common.h b/src/runtime/rocm/rocm_common.h index 5d0d5c9..2e637f5 100644 --- a/src/runtime/rocm/rocm_common.h +++ b/src/runtime/rocm/rocm_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,28 +24,28 @@ #ifndef TVM_RUNTIME_ROCM_ROCM_COMMON_H_ #define TVM_RUNTIME_ROCM_ROCM_COMMON_H_ -#include #include +#include + #include + #include "../workspace_pool.h" namespace tvm { namespace runtime { -#define ROCM_DRIVER_CALL(x) \ - { \ - hipError_t result = x; \ - if (result != hipSuccess && result != hipErrorDeinitialized) { \ - LOG(FATAL) \ - << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \ - } \ +#define ROCM_DRIVER_CALL(x) \ + { \ + hipError_t result = x; \ + if (result != hipSuccess && result != hipErrorDeinitialized) { \ + LOG(FATAL) << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \ + } \ } -#define ROCM_CALL(func) \ - { \ - hipError_t e = (func); \ - CHECK(e == hipSuccess) \ - << "ROCM HIP: " << hipGetErrorString(e); \ +#define ROCM_CALL(func) \ + { \ + hipError_t e = (func); \ + CHECK(e == hipSuccess) << "ROCM HIP: " << hipGetErrorString(e); \ } /*! \brief Thread local workspace */ diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 25e1ac7..475c4fb 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -35,9 +35,7 @@ namespace runtime { class ROCMDeviceAPI final : public DeviceAPI { public: - void SetDevice(TVMContext ctx) final { - ROCM_CALL(hipSetDevice(ctx.device_id)); - } + void SetDevice(TVMContext ctx) final { ROCM_CALL(hipSetDevice(ctx.device_id)); } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { int value = 0; switch (kind) { @@ -53,27 +51,26 @@ class ROCMDeviceAPI final : public DeviceAPI { break; } case kMaxThreadsPerBlock: { - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)); break; } case kWarpSize: { - ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, - ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, ctx.device_id)); break; } case kMaxSharedMemoryPerBlock: { - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeMaxSharedMemoryPerBlock, + ctx.device_id)); break; } case kComputeVersion: { std::ostringstream os; - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id)); os << value << "."; - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id)); os << value; *rv = os.str(); return; @@ -86,23 +83,19 @@ class ROCMDeviceAPI final : public DeviceAPI { return; } case kMaxClockRate: { - ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, - ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, ctx.device_id)); break; } case kMultiProcessorCount: { - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id)); break; } case kMaxThreadDimensions: { int dims[3]; - ROCM_CALL(hipDeviceGetAttribute( - &dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id)); - ROCM_CALL(hipDeviceGetAttribute( - &dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id)); - ROCM_CALL(hipDeviceGetAttribute( - &dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id)); std::stringstream ss; ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; @@ -132,9 +125,8 @@ class ROCMDeviceAPI final : public DeviceAPI { ROCM_CALL(hipFree(ptr)); } - void CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t size, TVMContext ctx_from, - TVMContext ctx_to, DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { hipStream_t hip_stream = static_cast(stream); from = static_cast(from) + from_offset; @@ -144,15 +136,12 @@ class ROCMDeviceAPI final : public DeviceAPI { if (ctx_from.device_id == ctx_to.device_id) { GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream); } else { - hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, - hip_stream); + hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, hip_stream); } - } else if (ctx_from.device_type == kDLROCM && - ctx_to.device_type == kDLCPU) { + } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) { ROCM_CALL(hipSetDevice(ctx_from.device_id)); GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream); - } else if (ctx_from.device_type == kDLCPU && - ctx_to.device_type == kDLROCM) { + } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) { ROCM_CALL(hipSetDevice(ctx_to.device_id)); GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream); } else { @@ -178,14 +167,13 @@ class ROCMDeviceAPI final : public DeviceAPI { } static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } private: - static void GPUCopy(const void* from, void* to, size_t size, - hipMemcpyKind kind, hipStream_t stream) { + static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind, + hipStream_t stream) { if (stream != 0) { ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); } else { @@ -198,14 +186,11 @@ typedef dmlc::ThreadLocalStore ROCMThreadStore; ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} -ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { - return ROCMThreadStore::Get(); -} +ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.rocm") - .set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 1f4b830..79958d2 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -20,19 +20,22 @@ /*! * \file rocm_module.cc */ -#include +#include "rocm_module.h" + #include -#include +#include + #include -#include #include +#include #include -#include "rocm_module.h" -#include "rocm_common.h" +#include + +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "rocm_common.h" namespace tvm { namespace runtime { @@ -43,12 +46,10 @@ namespace runtime { // The modules will be lazily loaded class ROCMModuleNode : public runtime::ModuleNode { public: - explicit ROCMModuleNode(std::string data, - std::string fmt, + explicit ROCMModuleNode(std::string data, std::string fmt, std::unordered_map fmap, - std::string hip_source, - std::string assembly) - : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) { + std::string hip_source, std::string assembly) + : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) { std::fill(module_.begin(), module_.end(), nullptr); } // destructor @@ -61,17 +62,11 @@ class ROCMModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { - return "hip"; - } - - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; + const char* type_key() const final { return "hip"; } + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); // note: llvm and asm formats are not laodable, so we don't save them @@ -87,9 +82,15 @@ class ROCMModuleNode : public runtime::ModuleNode { } std::string GetSource(const std::string& format) final { - if (format == fmt_) { return data_; } - if (format == "llvm" || format == "") { return hip_source_; } - if (format == "asm") { return assembly_; } + if (format == fmt_) { + return data_; + } + if (format == "llvm" || format == "") { + return hip_source_; + } + if (format == "asm") { + return assembly_; + } return ""; } @@ -104,16 +105,13 @@ class ROCMModuleNode : public runtime::ModuleNode { hipFunction_t func; hipError_t result = hipModuleGetFunction(&func, module_[device_id], func_name.c_str()); if (result != hipSuccess) { - LOG(FATAL) - << "ROCMError: hipModuleGetFunction " << func_name - << " failed with error: " << hipGetErrorString(result); + LOG(FATAL) << "ROCMError: hipModuleGetFunction " << func_name + << " failed with error: " << hipGetErrorString(result); } return func; } // get a global var from primary context in device_id - hipDeviceptr_t GetGlobal(int device_id, - const std::string& global_name, - size_t expect_nbytes) { + hipDeviceptr_t GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) { std::lock_guard lock(mutex_); // must recheck under the lock scope if (module_[device_id] == nullptr) { @@ -122,8 +120,7 @@ class ROCMModuleNode : public runtime::ModuleNode { hipDeviceptr_t global = nullptr; size_t nbytes = 0; - ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes, - module_[device_id], global_name.c_str())); + ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str())); CHECK_EQ(nbytes, expect_nbytes); return global; } @@ -149,11 +146,8 @@ class ROCMModuleNode : public runtime::ModuleNode { class ROCMWrappedFunc { public: // initialize the ROCM function. - void Init(ROCMModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, - size_t num_void_args, - const std::vector& thread_axis_tags) { + void Init(ROCMModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_void_args, const std::vector& thread_axis_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; @@ -161,10 +155,7 @@ class ROCMWrappedFunc { thread_axis_cfg_.Init(num_void_args, thread_axis_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - void* packed_args, - size_t packed_nbytes) const { + void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { int device_id; ROCM_CALL(hipGetDevice(&device_id)); if (fcache_[device_id] == nullptr) { @@ -174,22 +165,12 @@ class ROCMWrappedFunc { hipStream_t strm = static_cast(ROCMThreadEntry::ThreadLocal()->stream); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - void* config[] = { - HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes, - HIP_LAUNCH_PARAM_END - }; + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, + &packed_nbytes, HIP_LAUNCH_PARAM_END}; // HIP supports only extra_args. ROCM_DRIVER_CALL(hipModuleLaunchKernel( - fcache_[device_id], - wl.grid_dim(0), - wl.grid_dim(1), - wl.grid_dim(2), - wl.block_dim(0), - wl.block_dim(1), - wl.block_dim(2), - 0, strm, nullptr, - reinterpret_cast(&config))); + fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), + wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast(&config))); } private: @@ -206,13 +187,10 @@ class ROCMWrappedFunc { ThreadAxisConfig thread_axis_cfg_; }; - -PackedFunc ROCMModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc ROCMModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; @@ -221,18 +199,14 @@ PackedFunc ROCMModuleNode::GetFunction( return PackFuncPackedArg(f, info.arg_types); } -Module ROCMModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string hip_source, - std::string assembly) { +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string hip_source, + std::string assembly) { auto n = make_object(data, fmt, fmap, hip_source, assembly); return Module(n); } -Module ROCMModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module ROCMModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -253,19 +227,12 @@ Module ROCMModuleLoadBinary(void* strm) { return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco") -.set_body_typed(ROCMModuleLoadBinary); - - -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip") -.set_body_typed(ROCMModuleLoadBinary); - +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco") -.set_body_typed(ROCMModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip") -.set_body_typed(ROCMModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.h b/src/runtime/rocm/rocm_module.h index 7f2a0ce..c17e123 100644 --- a/src/runtime/rocm/rocm_module.h +++ b/src/runtime/rocm/rocm_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,10 +25,12 @@ #define TVM_RUNTIME_ROCM_ROCM_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -45,12 +47,9 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmap The map function information map of each function. * \param rocm_source Optional, rocm source file */ -Module ROCMModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string rocm_source, - std::string assembly); +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string rocm_source, + std::string assembly); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_ROCM_ROCM_MODULE_H_ diff --git a/src/runtime/rpc/minrpc/minrpc_server.h b/src/runtime/rpc/minrpc/minrpc_server.h index a84042e..91a900a 100644 --- a/src/runtime/rpc/minrpc/minrpc_server.h +++ b/src/runtime/rpc/minrpc/minrpc_server.h @@ -30,8 +30,9 @@ #include #include -#include "../rpc_protocol.h" + #include "../../../support/arena.h" +#include "../rpc_protocol.h" /*! \brief Whether or not to enable glog style DLOG */ #ifndef TVM_MINRPC_ENABLE_LOGGING @@ -39,7 +40,7 @@ #endif #ifndef MINRPC_CHECK -#define MINRPC_CHECK(cond) \ +#define MINRPC_CHECK(cond) \ if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError); #endif @@ -47,7 +48,6 @@ #include #endif - namespace tvm { namespace runtime { @@ -61,15 +61,14 @@ namespace runtime { * - PosixWrite, PosixRead, Close: posix style, read, write, close API. * - Exit: exit with status code. */ -template +template class MinRPCServer { public: /*! * \brief Constructor. * \param io The IO handler. */ - explicit MinRPCServer(TIOHandler io) - : io_(io), arena_(PageAllocator(io)) {} + explicit MinRPCServer(TIOHandler io) : io_(io), arena_(PageAllocator(io)) {} /*! \brief Run the server loop until shutdown signal is received. */ void ServerLoop() { @@ -135,10 +134,8 @@ class MinRPCServer { this->Read(&call_handle); RecvPackedSeq(&values, &tcodes, &num_args); - int call_ecode = TVMFuncCall( - reinterpret_cast(call_handle), - values, tcodes, num_args, - &(ret_value[1]), &(ret_tcode[1])); + int call_ecode = TVMFuncCall(reinterpret_cast(call_handle), values, tcodes, num_args, + &(ret_value[1]), &(ret_tcode[1])); if (call_ecode == 0) { // Return value encoding as in LocalSession @@ -150,8 +147,7 @@ class MinRPCServer { ret_value[2].v_handle = ret_value[1].v_handle; ret_tcode[2] = kTVMOpaqueHandle; this->ReturnPackedSeq(ret_value, ret_tcode, 3); - } else if (rv_tcode == kTVMPackedFuncHandle || - rv_tcode == kTVMModuleHandle) { + } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { ret_tcode[1] = kTVMOpaqueHandle; this->ReturnPackedSeq(ret_value, ret_tcode, 2); } else { @@ -179,15 +175,12 @@ class MinRPCServer { data_ptr = reinterpret_cast(handle) + offset; } else { data_ptr = this->ArenaAlloc(num_bytes); - call_ecode = TVMDeviceCopyDataFromTo( - reinterpret_cast(handle), offset, - data_ptr, 0, num_bytes, - ctx, DLContext{kDLCPU, 0}, - type_hint, nullptr); + call_ecode = + TVMDeviceCopyDataFromTo(reinterpret_cast(handle), offset, data_ptr, 0, num_bytes, + ctx, DLContext{kDLCPU, 0}, type_hint, nullptr); // need sync to make sure that the copy is completed. if (call_ecode == 0) { - call_ecode = TVMSynchronize( - ctx.device_type, ctx.device_id, nullptr); + call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); } } @@ -222,16 +215,12 @@ class MinRPCServer { uint8_t* temp_data = this->ArenaAlloc(num_bytes); this->ReadArray(temp_data, num_bytes); - call_ecode = TVMDeviceCopyDataFromTo( - temp_data, 0, - reinterpret_cast(handle), offset, - num_bytes, - DLContext{kDLCPU, 0}, ctx, - type_hint, nullptr); + call_ecode = + TVMDeviceCopyDataFromTo(temp_data, 0, reinterpret_cast(handle), offset, num_bytes, + DLContext{kDLCPU, 0}, ctx, type_hint, nullptr); // need sync to make sure that the copy is completed. if (call_ecode == 0) { - call_ecode = TVMSynchronize( - ctx.device_type, ctx.device_id, nullptr); + call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); } } @@ -367,10 +356,8 @@ class MinRPCServer { DLDataType type_hint = values[7].v_type; TVMStreamHandle stream = values[8].v_handle; - int call_ecode = TVMDeviceCopyDataFromTo( - from, from_offset, - to, to_offset, size, - ctx_from, ctx_to, type_hint, stream); + int call_ecode = TVMDeviceCopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, + ctx_to, type_hint, stream); if (call_ecode == 0) { this->ReturnVoid(); @@ -392,8 +379,7 @@ class MinRPCServer { DLDataType type_hint = values[3].v_type; void* handle; - int call_ecode = TVMDeviceAllocDataSpace( - ctx, nbytes, alignment, type_hint, &handle); + int call_ecode = TVMDeviceAllocDataSpace(ctx, nbytes, alignment, type_hint, &handle); if (call_ecode == 0) { this->ReturnHandle(handle); @@ -440,31 +426,31 @@ class MinRPCServer { io_.Exit(static_cast(code)); } - template + template T* ArenaAlloc(int count) { static_assert(std::is_pod::value, "need to be trival"); return arena_.template allocate_(count); } - template + template void Read(T* data) { static_assert(std::is_pod::value, "need to be trival"); this->ReadRawBytes(data, sizeof(T)); } - template + template void ReadArray(T* data, size_t count) { static_assert(std::is_pod::value, "need to be trival"); return this->ReadRawBytes(data, sizeof(T) * count); } - template + template void Write(const T& data) { static_assert(std::is_pod::value, "need to be trival"); return this->WriteRawBytes(&data, sizeof(T)); } - template + template void WriteArray(T* data, size_t count) { static_assert(std::is_pod::value, "need to be trival"); return this->WriteRawBytes(data, sizeof(T) * count); @@ -476,16 +462,14 @@ class MinRPCServer { public: using ArenaPageHeader = tvm::support::ArenaPageHeader; - explicit PageAllocator(TIOHandler io) - : io_(io) {} + explicit PageAllocator(TIOHandler io) : io_(io) {} ArenaPageHeader* allocate(size_t min_size) { size_t npages = ((min_size + kPageSize - 1) / kPageSize); void* data; - if (TVMDeviceAllocDataSpace( - DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign, - DLDataType{kDLInt, 1, 1}, &data) != 0) { + if (TVMDeviceAllocDataSpace(DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign, + DLDataType{kDLInt, 1, 1}, &data) != 0) { io_.Exit(static_cast(RPCServerStatus::kAllocError)); } @@ -508,11 +492,8 @@ class MinRPCServer { TIOHandler io_; }; - void RecvPackedSeq(TVMValue** out_values, - int** out_tcodes, - int* out_num_args) { - RPCReference::RecvPackedSeq( - out_values, out_tcodes, out_num_args, this); + void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) { + RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this); } void ReturnVoid() { @@ -520,8 +501,7 @@ class MinRPCServer { int32_t tcode = kTVMNullptr; RPCCode code = RPCCode::kReturn; - uint64_t packet_nbytes = - sizeof(code) + sizeof(num_args) + sizeof(tcode); + uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); this->Write(packet_nbytes); this->Write(code); @@ -536,8 +516,7 @@ class MinRPCServer { uint64_t encode_handle = reinterpret_cast(handle); uint64_t packet_nbytes = - sizeof(code) + sizeof(num_args) + - sizeof(tcode) + sizeof(encode_handle); + sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle); this->Write(packet_nbytes); this->Write(code); @@ -546,24 +525,18 @@ class MinRPCServer { this->Write(encode_handle); } - void ReturnException(const char* msg) { - RPCReference::ReturnException(msg, this); - } + void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); } - void ReturnPackedSeq(const TVMValue* arg_values, - const int* type_codes, - int num_args) { + void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) { RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this); } - void ReturnLastTVMError() { - this->ReturnException(TVMGetLastError()); - } + void ReturnLastTVMError() { this->ReturnException(TVMGetLastError()); } void ReadRawBytes(void* data, size_t size) { uint8_t* buf = reinterpret_cast(data); size_t ndone = 0; - while (ndone < size) { + while (ndone < size) { ssize_t ret = io_.PosixRead(buf, size - ndone); if (ret == 0) { if (allow_clean_shutdown_) { @@ -582,9 +555,9 @@ class MinRPCServer { } void WriteRawBytes(const void* data, size_t size) { - const uint8_t *buf = reinterpret_cast(data); + const uint8_t* buf = reinterpret_cast(data); size_t ndone = 0; - while (ndone < size) { + while (ndone < size) { ssize_t ret = io_.PosixWrite(buf, size - ndone); if (ret == 0 || ret == -1) { this->ThrowError(RPCServerStatus::kWriteError); diff --git a/src/runtime/rpc/minrpc/posix_popen_server.cc b/src/runtime/rpc/minrpc/posix_popen_server.cc index fdc5711..9784780 100644 --- a/src/runtime/rpc/minrpc/posix_popen_server.cc +++ b/src/runtime/rpc/minrpc/posix_popen_server.cc @@ -21,7 +21,9 @@ #define TVM_ARENA_HAS_DESTRUCTOR 0 #include + #include + #include "minrpc_server.h" namespace tvm { @@ -33,20 +35,13 @@ namespace runtime { class PosixIOHandler { public: explicit PosixIOHandler(int read_fd = 0, int write_fd = 1) - : read_fd_(read_fd), write_fd_(write_fd) { - } + : read_fd_(read_fd), write_fd_(write_fd) {} - ssize_t PosixRead(void* data, size_t size) { - return read(read_fd_, data, size); - } + ssize_t PosixRead(void* data, size_t size) { return read(read_fd_, data, size); } - ssize_t PosixWrite(const void* data, size_t size) { - return write(write_fd_, data, size); - } + ssize_t PosixWrite(const void* data, size_t size) { return write(write_fd_, data, size); } - void Exit(int code) { - exit(code); - } + void Exit(int code) { exit(code); } void Close() { if (read_fd_ != 0) close(read_fd_); diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc index f8dc6e6..eaa64e3 100644 --- a/src/runtime/rpc/rpc_channel.cc +++ b/src/runtime/rpc/rpc_channel.cc @@ -20,9 +20,10 @@ /*! * \file rpc_channel.cc */ -#include #include "rpc_channel.h" +#include + namespace tvm { namespace runtime { diff --git a/src/runtime/rpc/rpc_channel.h b/src/runtime/rpc/rpc_channel.h index be34a8b..114bc0a 100644 --- a/src/runtime/rpc/rpc_channel.h +++ b/src/runtime/rpc/rpc_channel.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_RPC_RPC_CHANNEL_H_ #include + #include namespace tvm { diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 93af4e2..196a97e 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -21,9 +21,11 @@ * \file rpc_device_api.cc */ #include -#include #include +#include + #include + #include "rpc_session.h" namespace tvm { @@ -41,14 +43,12 @@ class RPCDeviceAPI final : public DeviceAPI { GetSess(ctx)->GetDeviceAPI(remote_ctx)->GetAttr(remote_ctx, kind, rv); } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { auto sess = GetSess(ctx); auto remote_ctx = RemoveSessMask(ctx); - void *data = sess->GetDeviceAPI(remote_ctx)->AllocDataSpace( - remote_ctx, nbytes, alignment, type_hint); + void* data = + sess->GetDeviceAPI(remote_ctx)->AllocDataSpace(remote_ctx, nbytes, alignment, type_hint); RemoteSpace* space = new RemoteSpace(); space->data = data; @@ -59,49 +59,38 @@ class RPCDeviceAPI final : public DeviceAPI { RemoteSpace* space = static_cast(ptr); auto remote_ctx = RemoveSessMask(ctx); try { - GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace( - remote_ctx, space->data); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace(remote_ctx, space->data); } catch (const dmlc::Error& e) { // fault tolerance to remote close. } delete space; } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { int from_dev_type = ctx_from.device_type; int to_dev_type = ctx_to.device_type; - if (from_dev_type > kRPCSessMask && - to_dev_type > kRPCSessMask) { + if (from_dev_type > kRPCSessMask && to_dev_type > kRPCSessMask) { CHECK(ctx_from.device_type == ctx_to.device_type) << "Cannot copy across two different remote session"; auto remote_ctx_from = RemoveSessMask(ctx_from); auto remote_ctx_to = RemoveSessMask(ctx_to); auto remote_ctx = remote_ctx_from; if (remote_ctx.device_type == kDLCPU) remote_ctx = remote_ctx_to; - GetSess(ctx_from)->GetDeviceAPI(remote_ctx) + GetSess(ctx_from) + ->GetDeviceAPI(remote_ctx) ->CopyDataFromTo(static_cast(from)->data, from_offset, - static_cast(to)->data, to_offset, - size, remote_ctx_from, remote_ctx_to, type_hint, stream); - } else if (from_dev_type > kRPCSessMask && - to_dev_type == kDLCPU) { + static_cast(to)->data, to_offset, size, + remote_ctx_from, remote_ctx_to, type_hint, stream); + } else if (from_dev_type > kRPCSessMask && to_dev_type == kDLCPU) { auto remote_ctx_from = RemoveSessMask(ctx_from); - GetSess(ctx_from)->CopyFromRemote( - static_cast(from)->data, from_offset, - to, to_offset, size, remote_ctx_from, type_hint); - } else if (from_dev_type == kDLCPU && - to_dev_type > kRPCSessMask) { + GetSess(ctx_from)->CopyFromRemote(static_cast(from)->data, from_offset, + to, to_offset, size, remote_ctx_from, type_hint); + } else if (from_dev_type == kDLCPU && to_dev_type > kRPCSessMask) { auto remote_ctx_to = RemoveSessMask(ctx_to); - GetSess(ctx_to)->CopyToRemote( - const_cast(from), from_offset, - static_cast(to)->data, to_offset, - size, remote_ctx_to, type_hint); + GetSess(ctx_to)->CopyToRemote(const_cast(from), from_offset, + static_cast(to)->data, to_offset, size, + remote_ctx_to, type_hint); } else { LOG(FATAL) << "expect copy from/to remote or between remote"; } @@ -116,7 +105,7 @@ class RPCDeviceAPI final : public DeviceAPI { std::shared_ptr GetSess(TVMContext ctx) { int dev_type = ctx.device_type; CHECK_GE(dev_type, kRPCSessMask); - int tbl_index = dev_type / kRPCSessMask - 1; + int tbl_index = dev_type / kRPCSessMask - 1; return RPCSession::Get(tbl_index); } @@ -126,11 +115,10 @@ class RPCDeviceAPI final : public DeviceAPI { } }; -TVM_REGISTER_GLOBAL("device_api.rpc") -.set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCDeviceAPI inst; - DeviceAPI* ptr = &inst; - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.rpc").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCDeviceAPI inst; + DeviceAPI* ptr = &inst; + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 26f24c9..bf85dc5 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -21,25 +21,27 @@ * \file rpc_session.cc * \brief RPC session for remote function call. */ +#include "rpc_endpoint.h" + #include -#include #include +#include #include #include -#include + +#include #include -#include #include -#include -#include #include -#include +#include +#include +#include +#include -#include "rpc_endpoint.h" -#include "rpc_local_session.h" -#include "../object_internal.h" -#include "../../support/ring_buffer.h" #include "../../support/arena.h" +#include "../../support/ring_buffer.h" +#include "../object_internal.h" +#include "rpc_local_session.h" namespace tvm { namespace runtime { @@ -54,11 +56,8 @@ namespace runtime { */ class RPCEndpoint::EventHandler : public dmlc::Stream { public: - EventHandler(support::RingBuffer* reader, - support::RingBuffer* writer, - std::string name, - std::string* remote_key, - std::function flush_writer) + EventHandler(support::RingBuffer* reader, support::RingBuffer* writer, std::string name, + std::string* remote_key, std::function flush_writer) : reader_(reader), writer_(writer), name_(name), @@ -94,19 +93,13 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { } /*! \return Whether we are ready to handle next request. */ - bool Ready() const { - return reader_->bytes_available() >= pending_request_bytes_; - } + bool Ready() const { return reader_->bytes_available() >= pending_request_bytes_; } /*! \return Whether we can perform a clean shutdown */ - bool CanCleanShutdown() const { - return state_ == kRecvPacketNumBytes; - } + bool CanCleanShutdown() const { return state_ == kRecvPacketNumBytes; } /*! \brief Finish the copy ack stage. */ - void FinishCopyAck() { - this->SwitchToState(kRecvPacketNumBytes); - } + void FinishCopyAck() { this->SwitchToState(kRecvPacketNumBytes); } /*! * \brief Enter the io loop until the next event. @@ -115,19 +108,18 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { * \param setreturn The function to set the return value encoding. * \return The function to set return values when there is a return event. */ - RPCCode HandleNextEvent(bool client_mode, - bool async_server_mode, + RPCCode HandleNextEvent(bool client_mode, bool async_server_mode, RPCSession::FEncodeReturn setreturn) { std::swap(client_mode_, client_mode); std::swap(async_server_mode_, async_server_mode); RPCCode status = RPCCode::kNone; - while (status == RPCCode::kNone && - state_ != kWaitForAsyncCallback && - this->Ready()) { + while (status == RPCCode::kNone && state_ != kWaitForAsyncCallback && this->Ready()) { switch (state_) { - case kInitHeader: HandleInitHeader(); break; + case kInitHeader: + HandleInitHeader(); + break; case kRecvPacketNumBytes: { uint64_t packet_nbytes; CHECK(this->Read(&packet_nbytes)); @@ -177,16 +169,13 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { * \param arg_values The argument values. * \param type_codes The type codes. */ - void ValidateArguments(const TVMValue* arg_values, - const int* type_codes, - int num_args) { + void ValidateArguments(const TVMValue* arg_values, const int* type_codes, int num_args) { TVMArgs args(arg_values, type_codes, num_args); for (int i = 0; i < num_args; ++i) { int tcode = type_codes[i]; if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) { - LOG(FATAL) << "ValueError: Cannot pass argument " << i - << ", type " << args[i].AsObjectRef()->GetTypeKey() - << " is not supported by RPC"; + LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " + << args[i].AsObjectRef()->GetTypeKey() << " is not supported by RPC"; } else if (tcode == kTVMContext) { DLContext ctx = args[i]; CHECK_LT(static_cast(ctx.device_type), kRPCSessMask) @@ -199,26 +188,20 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code); } - uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, - const int* type_codes, - int num_args, + uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, int num_args, bool client_mode) { - return RPCReference::PackedSeqGetNumBytes( - arg_values, type_codes, num_args, client_mode, this); + return RPCReference::PackedSeqGetNumBytes(arg_values, type_codes, num_args, client_mode, this); } - void SendPackedSeq(const TVMValue* arg_values, - const int* type_codes, - int num_args, + void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, bool client_mode) { - RPCReference::SendPackedSeq( - arg_values, type_codes, num_args, client_mode, this); + RPCReference::SendPackedSeq(arg_values, type_codes, num_args, client_mode, this); } // Endian aware IO handling using Stream::Read; - using Stream::Write; using Stream::ReadArray; + using Stream::Write; using Stream::WriteArray; bool Read(RPCCode* code) { @@ -232,7 +215,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->Write(cdata); } - template + template T* ArenaAlloc(int count) { static_assert(std::is_pod::value, "need to be trival"); return arena_.template allocate_(count); @@ -263,8 +246,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { void SwitchToState(State state) { // invariant if (state != kCopyAckReceived) { - CHECK_EQ(pending_request_bytes_, 0U) - << "state=" << state; + CHECK_EQ(pending_request_bytes_, 0U) << "state=" << state; } // need to actively flush the writer // so the data get pushed out. @@ -272,8 +254,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { flush_writer_(); } state_ = state; - CHECK(state != kInitHeader) - << "cannot switch to init header"; + CHECK(state != kInitHeader) << "cannot switch to init header"; if (state == kRecvPacketNumBytes) { this->RequestBytes(sizeof(uint64_t)); // recycle arena for the next session. @@ -305,38 +286,39 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { if (code >= RPCCode::kSyscallCodeStart) { this->HandleSyscall(code); } else { - switch (code) { - case RPCCode::kInitServer: { - this->HandleInitServer(); - break; - } - case RPCCode::kCallFunc: { - this->HandleNormalCallFunc(); - break; - } - case RPCCode::kCopyFromRemote: { - this->HandleCopyFromRemote(); - break; - } - case RPCCode::kCopyToRemote: { - this->HandleCopyToRemote(); - break; - } - case RPCCode::kException: - case RPCCode::kReturn: { - this->HandleReturn(code, setreturn); - break; - } - case RPCCode::kCopyAck: { - this->SwitchToState(kCopyAckReceived); - break; - } - case RPCCode::kShutdown: { - this->SwitchToState(kShutdownReceived); - break; - } - default: LOG(FATAL) << "Unknown event " << static_cast(code); + switch (code) { + case RPCCode::kInitServer: { + this->HandleInitServer(); + break; } + case RPCCode::kCallFunc: { + this->HandleNormalCallFunc(); + break; + } + case RPCCode::kCopyFromRemote: { + this->HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + this->HandleCopyToRemote(); + break; + } + case RPCCode::kException: + case RPCCode::kReturn: { + this->HandleReturn(code, setreturn); + break; + } + case RPCCode::kCopyAck: { + this->SwitchToState(kCopyAckReceived); + break; + } + case RPCCode::kShutdown: { + this->SwitchToState(kShutdownReceived); + break; + } + default: + LOG(FATAL) << "Unknown event " << static_cast(code); + } } } @@ -357,17 +339,13 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { * \brief Return exception to the remote. * \param err_msg The error message. */ - void ReturnException(const char* err_msg) { - RPCReference::ReturnException(err_msg, this); - } + void ReturnException(const char* err_msg) { RPCReference::ReturnException(err_msg, this); } /*! * \brief Return nullptr to the remote. * \param err_msg The error message. */ - void ReturnVoid() { - RPCReference::ReturnVoid(this); - } + void ReturnVoid() { RPCReference::ReturnVoid(this); } /*! * \brief Return a packed sequence to the remote. @@ -389,7 +367,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // switch to the state before sending exception. this->SwitchToState(kRecvPacketNumBytes); std::string msg = args[0]; - LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg; + LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg; } CHECK(setreturn != nullptr) << "fsetreturn not available"; @@ -426,16 +404,14 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // When session is local, we can directly treat handle // as the cpu pointer without allocating a temp space. - if (ctx.device_type == kDLCPU && - sess->IsLocalSession() && - DMLC_IO_NO_ENDIAN_SWAP) { + if (ctx.device_type == kDLCPU && sess->IsLocalSession() && DMLC_IO_NO_ENDIAN_SWAP) { char* data_ptr = reinterpret_cast(handle) + offset; fcopyack(data_ptr, num_bytes); } else { char* data_ptr = this->ArenaAlloc(num_bytes); - auto on_copy_complete = [this, elem_bytes, num_bytes, data_ptr, fcopyack]( - RPCCode status, TVMArgs args) { + auto on_copy_complete = [this, elem_bytes, num_bytes, data_ptr, fcopyack](RPCCode status, + TVMArgs args) { if (status == RPCCode::kException) { this->ReturnException(args.values[0].v_str); this->SwitchToState(kRecvPacketNumBytes); @@ -449,11 +425,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { }; this->SwitchToState(kWaitForAsyncCallback); - sess->AsyncCopyFromRemote( - reinterpret_cast(handle), offset, - data_ptr, 0, - num_bytes, ctx, type_hint, - on_copy_complete); + sess->AsyncCopyFromRemote(reinterpret_cast(handle), offset, data_ptr, 0, num_bytes, + ctx, type_hint, on_copy_complete); } } @@ -474,14 +447,14 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // When session is local, we can directly treat handle // as the cpu pointer without allocating a temp space. if (ctx.device_type == kDLCPU && sess->IsLocalSession()) { - char* dptr = reinterpret_cast(handle) + offset; - this->ReadArray(dptr, num_bytes); - - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes); - } - this->ReturnVoid(); - this->SwitchToState(kRecvPacketNumBytes); + char* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); + + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes); + } + this->ReturnVoid(); + this->SwitchToState(kRecvPacketNumBytes); } else { char* temp_data = this->ArenaAlloc(num_bytes); this->ReadArray(temp_data, num_bytes); @@ -501,11 +474,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { }; this->SwitchToState(kWaitForAsyncCallback); - sess->AsyncCopyToRemote( - temp_data, 0, - reinterpret_cast(handle), offset, - num_bytes, ctx, type_hint, - on_copy_complete); + sess->AsyncCopyToRemote(temp_data, 0, reinterpret_cast(handle), offset, num_bytes, ctx, + type_hint, on_copy_complete); } } @@ -517,17 +487,16 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { TVMArgs args = RecvPackedSeq(); this->SwitchToState(kWaitForAsyncCallback); - GetServingSession()->AsyncCallFunc( - reinterpret_cast(call_handle), - args.values, args.type_codes, args.size(), - [this](RPCCode status, TVMArgs args) { - if (status == RPCCode::kException) { - this->ReturnException(args.values[0].v_str); - } else { - this->ReturnPackedSeq(args); - } - this->SwitchToState(kRecvPacketNumBytes); - }); + GetServingSession()->AsyncCallFunc(reinterpret_cast(call_handle), args.values, + args.type_codes, args.size(), + [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + } else { + this->ReturnPackedSeq(args); + } + this->SwitchToState(kRecvPacketNumBytes); + }); } void HandleInitServer() { @@ -541,8 +510,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { TVMArgs args = RecvPackedSeq(); try { - CHECK(serving_session_ == nullptr) - << "Server has already been initialized"; + CHECK(serving_session_ == nullptr) << "Server has already been initialized"; std::string server_protocol_ver = kRPCProtocolVer; CHECK_EQ(client_protocol_ver, server_protocol_ver) @@ -562,29 +530,26 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { } auto* fconstructor = Registry::Get(constructor_name); - CHECK(fconstructor != nullptr) - << " Cannot find session constructor " << constructor_name; + CHECK(fconstructor != nullptr) << " Cannot find session constructor " << constructor_name; TVMRetValue con_ret; try { fconstructor->CallPacked(constructor_args, &con_ret); } catch (const dmlc::Error& e) { LOG(FATAL) << "Server[" << name_ << "]:" - << " Error caught from session constructor " << constructor_name - << ":\n" << e.what(); + << " Error caught from session constructor " << constructor_name << ":\n" + << e.what(); } CHECK_EQ(con_ret.type_code(), kTVMModuleHandle) << "Server[" << name_ << "]:" - << " Constructor " << constructor_name - << " need to return an RPCModule"; + << " Constructor " << constructor_name << " need to return an RPCModule"; Module mod = con_ret; std::string tkey = mod->type_key(); - CHECK_EQ(tkey, "rpc") - << "Constructor " << constructor_name << " to return an RPCModule"; + CHECK_EQ(tkey, "rpc") << "Constructor " << constructor_name << " to return an RPCModule"; serving_session_ = RPCModuleGetSession(mod); this->ReturnVoid(); - } catch (const std::runtime_error &e) { + } catch (const std::runtime_error& e) { this->ReturnException(e.what()); } @@ -598,15 +563,14 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { TVMStreamHandle handle = args[1]; this->SwitchToState(kWaitForAsyncCallback); - GetServingSession()->AsyncStreamWait( - ctx, handle, [this](RPCCode status, TVMArgs args) { - if (status == RPCCode::kException) { - this->ReturnException(args.values[0].v_str); - } else { - this->ReturnVoid(); - } - this->SwitchToState(kRecvPacketNumBytes); - }); + GetServingSession()->AsyncStreamWait(ctx, handle, [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + } else { + this->ReturnVoid(); + } + this->SwitchToState(kRecvPacketNumBytes); + }); } catch (const std::runtime_error& e) { this->ReturnException(e.what()); this->SwitchToState(kRecvPacketNumBytes); @@ -614,7 +578,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { } // Handler for special syscalls that have a specific RPCCode. - template + template void SysCallHandler(F f) { TVMArgs args = RecvPackedSeq(); try { @@ -650,9 +614,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { return size; } // wriite the data to the channel. - void Write(const void* data, size_t size) final { - writer_->Write(data, size); - } + void Write(const void* data, size_t size) final { writer_->Write(data, size); } // Number of pending bytes requests size_t pending_request_bytes_{0}; // The ring buffer to read data from. @@ -669,23 +631,18 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { std::function flush_writer_; }; -RPCCode RPCEndpoint::HandleUntilReturnEvent( - bool client_mode, - RPCSession::FEncodeReturn setreturn) { +RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) { RPCCode code = RPCCode::kCallFunc; - while (code != RPCCode::kReturn && - code != RPCCode::kShutdown && - code != RPCCode::kCopyAck) { + while (code != RPCCode::kReturn && code != RPCCode::kShutdown && code != RPCCode::kCopyAck) { while (writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); + writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); } size_t bytes_needed = handler_->BytesNeeded(); if (bytes_needed != 0) { - size_t n = reader_.WriteWithCallback([this](void* data, size_t size) { - return channel_->Recv(data, size); - }, bytes_needed); + size_t n = reader_.WriteWithCallback( + [this](void* data, size_t size) { return channel_->Recv(data, size); }, bytes_needed); if (n == 0) { if (handler_->CanCleanShutdown()) { return RPCCode::kShutdown; @@ -703,27 +660,24 @@ void RPCEndpoint::Init() { // callback to flush the writer. auto flush_writer = [this]() { while (writer_.bytes_available() != 0) { - size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); + size_t n = writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); if (n == 0) break; } }; // Event handler - handler_ = std::make_shared( - &reader_, &writer_, name_, &remote_key_, flush_writer); + handler_ = std::make_shared(&reader_, &writer_, name_, &remote_key_, flush_writer); // Quick function to for syscall remote. syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) { std::lock_guard lock(mutex_); RPCCode code = static_cast(all_args[0].operator int()); - TVMArgs args(all_args.values + 1, all_args.type_codes +1, all_args.num_args -1); + TVMArgs args(all_args.values + 1, all_args.type_codes + 1, all_args.num_args - 1); - uint64_t packet_nbytes = - sizeof(code) + - handler_->PackedSeqGetNumBytes( - args.values, args.type_codes, args.num_args, true); + uint64_t packet_nbytes = sizeof(code) + handler_->PackedSeqGetNumBytes( + args.values, args.type_codes, args.num_args, true); // All packet begins with packet nbytes handler_->Write(packet_nbytes); @@ -738,10 +692,8 @@ void RPCEndpoint::Init() { }); } -std::shared_ptr RPCEndpoint::Create( - std::unique_ptr channel, - std::string name, - std::string remote_key) { +std::shared_ptr RPCEndpoint::Create(std::unique_ptr channel, + std::string name, std::string remote_key) { std::shared_ptr endpt = std::make_shared(); endpt->channel_ = std::move(channel); endpt->name_ = std::move(name); @@ -750,9 +702,7 @@ std::shared_ptr RPCEndpoint::Create( return endpt; } -RPCEndpoint::~RPCEndpoint() { - this->Shutdown(); -} +RPCEndpoint::~RPCEndpoint() { this->Shutdown(); } void RPCEndpoint::Shutdown() { if (channel_ != nullptr) { @@ -765,9 +715,9 @@ void RPCEndpoint::Shutdown() { // flush all writing buffer to output channel. try { while (writer_.bytes_available() != 0) { - size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); + size_t n = writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); if (n == 0) break; } } catch (const dmlc::Error& e) { @@ -795,9 +745,9 @@ int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int even code = handler_->HandleNextEvent(false, true, [](TVMArgs) {}); } if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); + writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); } CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); if (code == RPCCode::kShutdown) return 0; @@ -812,11 +762,8 @@ void RPCEndpoint::InitRemoteSession(TVMArgs args) { uint64_t length = protocol_ver.length(); uint64_t packet_nbytes = - sizeof(code) + - sizeof(length) + - length + - handler_->PackedSeqGetNumBytes( - args.values, args.type_codes, args.num_args, true); + sizeof(code) + sizeof(length) + length + + handler_->PackedSeqGetNumBytes(args.values, args.type_codes, args.num_args, true); // All packet begins with packet nbytes handler_->Write(packet_nbytes); @@ -830,10 +777,8 @@ void RPCEndpoint::InitRemoteSession(TVMArgs args) { } // Get remote function with name -void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, +void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, RPCSession::FEncodeReturn encode_return) { std::lock_guard lock(mutex_); @@ -842,42 +787,28 @@ void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, uint64_t handle = reinterpret_cast(h); uint64_t packet_nbytes = - sizeof(code) + - sizeof(handle) + - handler_->PackedSeqGetNumBytes( - arg_values, arg_type_codes, num_args, true); + sizeof(code) + sizeof(handle) + + handler_->PackedSeqGetNumBytes(arg_values, arg_type_codes, num_args, true); handler_->Write(packet_nbytes); handler_->Write(code); handler_->Write(handle); - handler_->SendPackedSeq( - arg_values, arg_type_codes, num_args, true); + handler_->SendPackedSeq(arg_values, arg_type_codes, num_args, true); code = HandleUntilReturnEvent(true, encode_return); CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); } -void RPCEndpoint::CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_to, - DLDataType type_hint) { +void RPCEndpoint::CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t data_size, TVMContext ctx_to, DLDataType type_hint) { std::lock_guard lock(mutex_); RPCCode code = RPCCode::kCopyToRemote; uint64_t handle = reinterpret_cast(to); uint64_t offset = static_cast(to_offset); uint64_t size = static_cast(data_size); - uint64_t packet_nbytes = - sizeof(code) + - sizeof(handle) + - sizeof(offset) + - sizeof(size) + - sizeof(ctx_to) + - sizeof(type_hint) + - data_size; + uint64_t packet_nbytes = sizeof(code) + sizeof(handle) + sizeof(offset) + sizeof(size) + + sizeof(ctx_to) + sizeof(type_hint) + data_size; handler_->Write(packet_nbytes); handler_->Write(code); @@ -888,29 +819,19 @@ void RPCEndpoint::CopyToRemote(void* from, handler_->Write(type_hint); handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); - CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kReturn); + CHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kReturn); } -void RPCEndpoint::CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_from, - DLDataType type_hint) { +void RPCEndpoint::CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t data_size, TVMContext ctx_from, DLDataType type_hint) { std::lock_guard lock(mutex_); RPCCode code = RPCCode::kCopyFromRemote; uint64_t handle = reinterpret_cast(from); uint64_t offset = static_cast(from_offset); uint64_t size = static_cast(data_size); - uint64_t packet_nbytes = - sizeof(code) + - sizeof(handle) + - sizeof(offset) + - sizeof(size) + - sizeof(ctx_from) + - sizeof(type_hint); + uint64_t packet_nbytes = sizeof(code) + sizeof(handle) + sizeof(offset) + sizeof(size) + + sizeof(ctx_from) + sizeof(type_hint); handler_->Write(packet_nbytes); handler_->Write(code); @@ -921,7 +842,7 @@ void RPCEndpoint::CopyFromRemote(void* from, handler_->Write(type_hint); TVMRetValue rv; - CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kCopyAck); + CHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kCopyAck); handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); handler_->FinishCopyAck(); } @@ -932,18 +853,18 @@ void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { *rv = handler->GetFunction(name); } -void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { +void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { void* handle = args[0]; int type_code = args[1]; handler->FreeHandle(handle, type_code); } -void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { +void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { TVMContext ctx = args[0]; handler->GetDeviceAPI(ctx)->SetDevice(ctx); } -void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { +void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { TVMContext ctx = args[0]; DeviceAttrKind kind = static_cast(args[1].operator int()); if (kind == kExist) { @@ -954,28 +875,26 @@ void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { *rv = 0; } } else { - handler->GetDeviceAPI(ctx)->GetAttr( - ctx, static_cast(kind), rv); + handler->GetDeviceAPI(ctx)->GetAttr(ctx, static_cast(kind), rv); } } -void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { +void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { TVMContext ctx = args[0]; uint64_t nbytes = args[1]; uint64_t alignment = args[2]; DLDataType type_hint = args[3]; - void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace( - ctx, nbytes, alignment, type_hint); + void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace(ctx, nbytes, alignment, type_hint); *rv = data; } -void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { +void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { TVMContext ctx = args[0]; void* ptr = args[1]; handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr); } -void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { +void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { void* from = args[0]; uint64_t from_offset = args[1]; void* to = args[2]; @@ -990,29 +909,43 @@ void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { if (ctx.device_type == kDLCPU) { ctx = ctx_to; } else { - CHECK(ctx_to.device_type == kDLCPU || - ctx_to.device_type == ctx_from.device_type) + CHECK(ctx_to.device_type == kDLCPU || ctx_to.device_type == ctx_from.device_type) << "Can not copy across different ctx types directly"; } - handler->GetDeviceAPI(ctx)->CopyDataFromTo( - from, from_offset, - to, to_offset, - size, ctx_from, ctx_to, type_hint, stream); + handler->GetDeviceAPI(ctx)->CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, + ctx_to, type_hint, stream); } void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { // Event handler sit at clean state at this point. switch (code) { // system functions - case RPCCode::kFreeHandle: SysCallHandler(RPCFreeHandle); break; - case RPCCode::kGetGlobalFunc: SysCallHandler(RPCGetGlobalFunc); break; - case RPCCode::kDevSetDevice: SysCallHandler(RPCDevSetDevice); break; - case RPCCode::kDevGetAttr: SysCallHandler(RPCDevGetAttr); break; - case RPCCode::kDevAllocData: SysCallHandler(RPCDevAllocData); break; - case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break; - case RPCCode::kDevStreamSync: this->HandleSyscallStreamSync(); break; - case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break; - default: LOG(FATAL) << "Unknown event " << static_cast(code); + case RPCCode::kFreeHandle: + SysCallHandler(RPCFreeHandle); + break; + case RPCCode::kGetGlobalFunc: + SysCallHandler(RPCGetGlobalFunc); + break; + case RPCCode::kDevSetDevice: + SysCallHandler(RPCDevSetDevice); + break; + case RPCCode::kDevGetAttr: + SysCallHandler(RPCDevGetAttr); + break; + case RPCCode::kDevAllocData: + SysCallHandler(RPCDevAllocData); + break; + case RPCCode::kDevFreeData: + SysCallHandler(RPCDevFreeData); + break; + case RPCCode::kDevStreamSync: + this->HandleSyscallStreamSync(); + break; + case RPCCode::kCopyAmongRemote: + SysCallHandler(RPCCopyAmongRemote); + break; + default: + LOG(FATAL) << "Unknown event " << static_cast(code); } if (state_ != kWaitForAsyncCallback) { @@ -1023,59 +956,38 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { /*! * \brief RPC client session that proxies all calls to an endpoint. */ -class RPCClientSession : public RPCSession, - public DeviceAPI { +class RPCClientSession : public RPCSession, public DeviceAPI { public: /*! * \brief param endpoint The client endpoint of the session. */ - explicit RPCClientSession(std::shared_ptr endpoint) - : endpoint_(endpoint) {} + explicit RPCClientSession(std::shared_ptr endpoint) : endpoint_(endpoint) {} // function overrides PackedFuncHandle GetFunction(const std::string& name) final { return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name); } - void CallFunc(PackedFuncHandle func, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, - const FEncodeReturn& fencode_return) final { - endpoint_->CallFunc( - func, arg_values, arg_type_codes, num_args, fencode_return); + void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, + int num_args, const FEncodeReturn& fencode_return) final { + endpoint_->CallFunc(func, arg_values, arg_type_codes, num_args, fencode_return); } - void CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_to, - DLDataType type_hint) final { - endpoint_->CopyToRemote( - from, from_offset, to, to_offset, nbytes, ctx_to, type_hint); + void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_to, DLDataType type_hint) final { + endpoint_->CopyToRemote(from, from_offset, to, to_offset, nbytes, ctx_to, type_hint); } - void CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_from, - DLDataType type_hint) final { - endpoint_->CopyFromRemote( - from, from_offset, to, to_offset, nbytes, ctx_from, type_hint); + void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_from, DLDataType type_hint) final { + endpoint_->CopyFromRemote(from, from_offset, to, to_offset, nbytes, ctx_from, type_hint); } void FreeHandle(void* handle, int type_code) final { endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code); } - - void SetDevice(TVMContext ctx) final { - endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx); - } + void SetDevice(TVMContext ctx) final { endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx); } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { if (ctx.device_type == kDLCPU && kind == kExist) { @@ -1086,54 +998,35 @@ class RPCClientSession : public RPCSession, } } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { - return endpoint_->SysCallRemote( - RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); + return endpoint_->SysCallRemote(RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); } void FreeDataSpace(TVMContext ctx, void* ptr) final { endpoint_->SysCallRemote(RPCCode::kDevFreeData, ctx, ptr); } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { - endpoint_->SysCallRemote( - RPCCode::kCopyAmongRemote, - const_cast(from), from_offset, - to, to_offset, - size, - ctx_from, ctx_to, - type_hint, stream); + endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, const_cast(from), from_offset, to, + to_offset, size, ctx_from, ctx_to, type_hint, stream); } void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { endpoint_->SysCallRemote(RPCCode::kDevStreamSync, ctx, stream); } - DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final { - return this; - } + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final { return this; } - bool IsLocalSession() const final { - return false; - } + bool IsLocalSession() const final { return false; } private: std::shared_ptr endpoint_; }; -std::shared_ptr -CreateClientSession(std::shared_ptr endpoint) { +std::shared_ptr CreateClientSession(std::shared_ptr endpoint) { return std::make_shared(endpoint); } diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index 9a6afcd..2b88cee 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -25,14 +25,16 @@ #define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ #include + +#include #include #include -#include #include -#include "rpc_session.h" + +#include "../../support/ring_buffer.h" #include "rpc_channel.h" #include "rpc_protocol.h" -#include "../../support/ring_buffer.h" +#include "rpc_session.h" namespace tvm { namespace runtime { @@ -59,7 +61,6 @@ enum class TrackerCode : int { kGetPendingMatchKeys = 7 }; - /*! * \brief Communication endpoints to connect local and remote RPC sessions. * An endpoint can either be a client or a server. @@ -122,11 +123,8 @@ class RPCEndpoint { * \param num_args Number of arguments. * \param fencode_return The function to receive return value encodings. */ - void CallFunc(RPCSession::PackedFuncHandle handle, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, - RPCSession::FEncodeReturn encode_return); + void CallFunc(RPCSession::PackedFuncHandle handle, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, RPCSession::FEncodeReturn encode_return); /*! * \brief Copy bytes into remote array content. * \param from The source host data. @@ -137,13 +135,8 @@ class RPCEndpoint { * \param ctx_to The target context. * \param type_hint Hint of content data type. */ - void CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_to, - DLDataType type_hint); + void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_to, DLDataType type_hint); /*! * \brief Copy bytes from remote array content. * \param from The source host data. @@ -154,13 +147,8 @@ class RPCEndpoint { * \param ctx_from The source context. * \param type_hint Hint of content data type. */ - void CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_from, - DLDataType type_hint); + void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_from, DLDataType type_hint); /*! * \brief Call a remote defined system function with arguments. @@ -168,8 +156,8 @@ class RPCEndpoint { * \param args The arguments * \return The returned remote value. */ - template - inline TVMRetValue SysCallRemote(RPCCode fcode, Args&& ...args); + template + inline TVMRetValue SysCallRemote(RPCCode fcode, Args&&... args); /*! * \brief Create a RPC session with given channel. * \param channel The communication channel. @@ -178,10 +166,8 @@ class RPCEndpoint { * if remote_key equals "%toinit", we need to re-intialize * it by event handler. */ - static std::shared_ptr Create( - std::unique_ptr channel, - std::string name, - std::string remote_key); + static std::shared_ptr Create(std::unique_ptr channel, std::string name, + std::string remote_key); private: class EventHandler; @@ -213,12 +199,11 @@ class RPCEndpoint { * \param endpoint The endpoint. * \return The created session. */ -std::shared_ptr -CreateClientSession(std::shared_ptr endpoint); +std::shared_ptr CreateClientSession(std::shared_ptr endpoint); // implementation of inline functions -template -inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&& ...args) { +template +inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) { return syscall_remote_(static_cast(code), std::forward(args)...); } } // namespace runtime diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 284dca5..f5b933f 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -22,31 +22,29 @@ * \brief Event driven RPC server implementation. */ #include + #include + #include "rpc_endpoint.h" #include "rpc_local_session.h" namespace tvm { namespace runtime { -PackedFunc CreateEventDrivenServer(PackedFunc fsend, - std::string name, - std::string remote_key) { +PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name, std::string remote_key) { static PackedFunc frecv([](TVMArgs args, TVMRetValue* rv) { LOG(FATAL) << "Do not allow explicit receive"; return 0; }); std::unique_ptr ch(new CallbackChannel(fsend, frecv)); - std::shared_ptr sess = - RPCEndpoint::Create(std::move(ch), name, remote_key); + std::shared_ptr sess = RPCEndpoint::Create(std::move(ch), name, remote_key); return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]); - *rv = ret; - }); + int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]); + *rv = ret; + }); } -TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer") -.set_body_typed(CreateEventDrivenServer); +TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer").set_body_typed(CreateEventDrivenServer); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 9d1fb72..b35c62d 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -21,16 +21,17 @@ * \file local_session.cc * \brief Local session that directs requests to local API. */ -#include +#include "rpc_local_session.h" + #include +#include + #include -#include "rpc_local_session.h" namespace tvm { namespace runtime { -RPCSession::PackedFuncHandle -LocalSession::GetFunction(const std::string& name) { +RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) { if (auto* fp = tvm::runtime::Registry::Get(name)) { // return raw handle because the remote need to explicitly manage it. return new PackedFunc(*fp); @@ -58,8 +59,7 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_retu ret_value_pack[2].v_handle = ret_value_pack[1].v_handle; ret_tcode_pack[2] = kTVMOpaqueHandle; encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3)); - } else if (rv_tcode == kTVMPackedFuncHandle || - rv_tcode == kTVMModuleHandle) { + } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { // MoveToCHost means rv no longer manages the object. // return handle instead. rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); @@ -78,10 +78,8 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_retu } } -void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, +void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, const FEncodeReturn& encode_return) { auto* pf = static_cast(func); TVMRetValue rv; @@ -89,40 +87,26 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, this->EncodeReturn(std::move(rv), encode_return); } -void LocalSession::CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_to, - DLDataType type_hint) { +void LocalSession::CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t nbytes, TVMContext ctx_to, DLDataType type_hint) { TVMContext cpu_ctx; cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; - this->GetDeviceAPI(ctx_to)->CopyDataFromTo( - from, from_offset, - to, to_offset, - nbytes, cpu_ctx, ctx_to, type_hint, nullptr); + this->GetDeviceAPI(ctx_to)->CopyDataFromTo(from, from_offset, to, to_offset, nbytes, cpu_ctx, + ctx_to, type_hint, nullptr); // Copy can happen asynchrously // synchronize to make sure that copy is completed this->GetDeviceAPI(ctx_to)->StreamSync(ctx_to, nullptr); } -void LocalSession::CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_from, - DLDataType type_hint) { +void LocalSession::CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t nbytes, TVMContext ctx_from, DLDataType type_hint) { TVMContext cpu_ctx; cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; - this->GetDeviceAPI(ctx_from)->CopyDataFromTo( - from, from_offset, - to, to_offset, - nbytes, ctx_from, cpu_ctx, type_hint, nullptr); + this->GetDeviceAPI(ctx_from)->CopyDataFromTo(from, from_offset, to, to_offset, nbytes, ctx_from, + cpu_ctx, type_hint, nullptr); // Copy can happen asynchrously // synchronize to make sure that copy is completed this->GetDeviceAPI(ctx_from)->StreamSync(ctx_from, nullptr); @@ -139,8 +123,7 @@ DeviceAPI* LocalSession::GetDeviceAPI(TVMContext ctx, bool allow_missing) { return DeviceAPI::Get(ctx, allow_missing); } -TVM_REGISTER_GLOBAL("rpc.LocalSession") -.set_body_typed([]() { +TVM_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() { return CreateRPCSessionModule(std::make_shared()); }); diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h index ff0caa4..7a67ce8 100644 --- a/src/runtime/rpc/rpc_local_session.h +++ b/src/runtime/rpc/rpc_local_session.h @@ -24,11 +24,13 @@ #ifndef TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ -#include #include +#include + #include #include #include + #include "rpc_session.h" namespace tvm { @@ -43,35 +45,20 @@ class LocalSession : public RPCSession { // function overrides PackedFuncHandle GetFunction(const std::string& name) override; - void CallFunc(PackedFuncHandle func, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, - const FEncodeReturn& fencode_return) override; + void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, + int num_args, const FEncodeReturn& fencode_return) override; - void CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_to, - DLDataType type_hint) override; + void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_to, DLDataType type_hint) override; - void CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_from, - DLDataType type_hint) override; + void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_from, DLDataType type_hint) override; void FreeHandle(void* handle, int type_code) override; DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) override; - bool IsLocalSession() const override { - return true; - } + bool IsLocalSession() const override { return true; } protected: /*! diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 1062304..8c46269 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -21,10 +21,12 @@ * \file rpc_module.cc * \brief RPC runtime module. */ -#include #include -#include +#include + #include +#include + #include "rpc_endpoint.h" #include "rpc_session.h" @@ -36,10 +38,7 @@ namespace runtime { */ class RPCWrappedFunc : public Object { public: - RPCWrappedFunc(void* handle, - std::shared_ptr sess) - : handle_(handle), sess_(sess) { - } + RPCWrappedFunc(void* handle, std::shared_ptr sess) : handle_(handle), sess_(sess) {} void operator()(TVMArgs args, TVMRetValue* rv) const { std::vector values(args.values, args.values + args.size()); @@ -58,8 +57,7 @@ class RPCWrappedFunc : public Object { // are compatible to each other, just need to change the index. type_codes[i] = kTVMDLTensorHandle; // translate to a remote view of DLTensor - auto dptr = std::make_unique( - *static_cast(values[i].v_handle)); + auto dptr = std::make_unique(*static_cast(values[i].v_handle)); dptr->ctx = RemoveSessMask(dptr->ctx); dptr->data = static_cast(dptr->data)->data; values[i].v_handle = dptr.get(); @@ -72,17 +70,13 @@ class RPCWrappedFunc : public Object { } case kTVMPackedFuncHandle: case kTVMModuleHandle: { - values[i].v_handle = UnwrapRemoteValueToHandle( - TVMArgValue(values[i], tcode)); + values[i].v_handle = UnwrapRemoteValueToHandle(TVMArgValue(values[i], tcode)); break; } } } - auto set_return = [this, rv](TVMArgs args) { - this->WrapRemoteReturnToValue(args, rv); - }; - sess_->CallFunc(handle_, values.data(), type_codes.data(), - args.size(), set_return); + auto set_return = [this, rv](TVMArgs args) { this->WrapRemoteReturnToValue(args, rv); }; + sess_->CallFunc(handle_, values.data(), type_codes.data(), args.size(), set_return); } ~RPCWrappedFunc() { @@ -133,8 +127,7 @@ class RPCWrappedFunc : public Object { data->dl_tensor.data = space; NDArray ret(GetObjectPtr(data)); // RAII now in effect - data->shape_ = std::vector( - tensor->shape, tensor->shape + tensor->ndim); + data->shape_ = std::vector(tensor->shape, tensor->shape + tensor->ndim); data->dl_tensor.shape = dmlc::BeginPtr(data->shape_); data->dl_tensor.ndim = static_cast(data->shape_.size()); // setup dtype @@ -142,8 +135,7 @@ class RPCWrappedFunc : public Object { // setup ctx, encode as remote session data->dl_tensor.ctx.device_id = tensor->ctx.device_id; data->dl_tensor.ctx.device_type = static_cast( - static_cast(tensor->ctx.device_type) + - kRPCSessMask * (sess_->table_index() + 1)); + static_cast(tensor->ctx.device_type) + kRPCSessMask * (sess_->table_index() + 1)); // check strides. CHECK(tensor->strides == nullptr); // setup byteoffset @@ -156,8 +148,7 @@ class RPCWrappedFunc : public Object { class RPCModuleNode final : public ModuleNode { public: RPCModuleNode(void* module_handle, std::shared_ptr sess) - : module_handle_(module_handle), sess_(sess) { - } + : module_handle_(module_handle), sess_(sess) {} ~RPCModuleNode() { if (module_handle_ != nullptr) { @@ -170,13 +161,9 @@ class RPCModuleNode final : public ModuleNode { } } - const char* type_key() const final { - return "rpc"; - } + const char* type_key() const final { return "rpc"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (module_handle_ == nullptr) { return WrapRemoteFunc(sess_->GetFunction(name)); } else { @@ -190,10 +177,7 @@ class RPCModuleNode final : public ModuleNode { return ""; } - PackedFunc GetTimeEvaluator(const std::string& name, - TVMContext ctx, - int number, - int repeat, + PackedFunc GetTimeEvaluator(const std::string& name, TVMContext ctx, int number, int repeat, int min_repeat_ms) { InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator"); // Remove session mask because we pass ctx by parts. @@ -203,15 +187,13 @@ class RPCModuleNode final : public ModuleNode { ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); if (module_handle_ != nullptr) { - return remote_get_time_evaluator_( - GetRef(this), name, - static_cast(ctx.device_type), ctx.device_id, - number, repeat, min_repeat_ms); + return remote_get_time_evaluator_(GetRef(this), name, + static_cast(ctx.device_type), ctx.device_id, number, + repeat, min_repeat_ms); } else { - return remote_get_time_evaluator_( - Optional(nullptr), name, - static_cast(ctx.device_type), ctx.device_id, - number, repeat, min_repeat_ms); + return remote_get_time_evaluator_(Optional(nullptr), name, + static_cast(ctx.device_type), ctx.device_id, number, + repeat, min_repeat_ms); } } @@ -225,16 +207,12 @@ class RPCModuleNode final : public ModuleNode { remote_import_module_(GetRef(this), other); } - const std::shared_ptr& sess() { - return sess_; - } + const std::shared_ptr& sess() { return sess_; } - void* module_handle() const { - return module_handle_; - } + void* module_handle() const { return module_handle_; } private: - template + template void InitRemoteFunc(FType* func, const std::string& name) { if (*func != nullptr) return; RPCSession::PackedFuncHandle handle = sess_->GetFunction(name); @@ -245,9 +223,7 @@ class RPCModuleNode final : public ModuleNode { PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) { if (handle == nullptr) return PackedFunc(); auto wf = std::make_shared(handle, sess_); - return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { - return wf->operator()(args, rv); - }); + return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); } // The module handle @@ -256,7 +232,7 @@ class RPCModuleNode final : public ModuleNode { std::shared_ptr sess_; // remote function to get time evaluator TypedPackedFunc, std::string, int, int, int, int, int)> - remote_get_time_evaluator_; + remote_get_time_evaluator_; // remote function getter for modules. TypedPackedFunc remote_mod_get_function_; // remote function getter for load module @@ -265,28 +241,23 @@ class RPCModuleNode final : public ModuleNode { TypedPackedFunc remote_import_module_; }; - void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const { if (arg.type_code() == kTVMModuleHandle) { Module mod = arg; std::string tkey = mod->type_key(); - CHECK_EQ(tkey, "rpc") - << "ValueError: Cannot pass a non-RPC module to remote"; + CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); CHECK(rmod->sess() == sess_) << "ValueError: Cannot pass in module into a different remote session"; return rmod->module_handle(); } else { - LOG(FATAL) << "ValueError: Cannot pass type " - << runtime::TypeCode2Str(arg.type_code()) + LOG(FATAL) << "ValueError: Cannot pass type " << runtime::TypeCode2Str(arg.type_code()) << " as an argument to the remote"; return nullptr; } } -void RPCWrappedFunc::WrapRemoteReturnToValue( - TVMArgs args, - TVMRetValue *rv) const { +void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const { int tcode = args[0]; if (tcode == kTVMNullptr) return; @@ -294,9 +265,7 @@ void RPCWrappedFunc::WrapRemoteReturnToValue( CHECK_EQ(args.size(), 2); void* handle = args[1]; auto wf = std::make_shared(handle, sess_); - *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { - return wf->operator()(args, rv); - }); + *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); } else if (tcode == kTVMModuleHandle) { CHECK_EQ(args.size(), 2); void* handle = args[1]; @@ -321,16 +290,12 @@ Module CreateRPCSessionModule(std::shared_ptr sess) { std::shared_ptr RPCModuleGetSession(Module mod) { std::string tkey = mod->type_key(); - CHECK_EQ(tkey, "rpc") - << "ValueError: Cannot pass a non-RPC module to remote"; + CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); return rmod->sess(); } -PackedFunc WrapTimeEvaluator(PackedFunc pf, - TVMContext ctx, - int number, - int repeat, +PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat, int min_repeat_ms) { CHECK(pf != nullptr); @@ -340,8 +305,7 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, return (*get_micro_time_evaluator)(pf, ctx, number, repeat); } - auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) - mutable { + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) mutable { TVMRetValue temp; std::ostringstream os; // skip first time call, to activate lazy compilation components. @@ -350,15 +314,14 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); for (int i = 0; i < repeat; ++i) { - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; + std::chrono::time_point tbegin, + tend; double duration_ms = 0.0; do { if (duration_ms > 0.0) { - number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random + number = static_cast(std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random } tbegin = std::chrono::high_resolution_clock::now(); @@ -369,12 +332,12 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); tend = std::chrono::high_resolution_clock::now(); - duration_ms = std::chrono::duration_cast > - (tend - tbegin).count() * 1000; + duration_ms = + std::chrono::duration_cast>(tend - tbegin).count() * 1000; } while (duration_ms < min_repeat_ms); - double speed = std::chrono::duration_cast >( - tend - tbegin).count() / number; + double speed = + std::chrono::duration_cast>(tend - tbegin).count() / number; os.write(reinterpret_cast(&speed), sizeof(speed)); } @@ -388,64 +351,52 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, return PackedFunc(ftimer); } - TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") -.set_body_typed([](Optional opt_mod, - std::string name, - int device_type, - int device_id, - int number, - int repeat, - int min_repeat_ms) { - TVMContext ctx; - ctx.device_type = static_cast(device_type); - ctx.device_id = device_id; - if (opt_mod.defined()) { - Module m = opt_mod.value(); - std::string tkey = m->type_key(); - if (tkey == "rpc") { - return static_cast(m.operator->()) - ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms); - } else { - return WrapTimeEvaluator( - m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); - } - } else { - auto* pf = runtime::Registry::Get(name); - CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; - return WrapTimeEvaluator( - *pf, ctx, number, repeat, min_repeat_ms); - } -}); + .set_body_typed([](Optional opt_mod, std::string name, int device_type, int device_id, + int number, int repeat, int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + if (opt_mod.defined()) { + Module m = opt_mod.value(); + std::string tkey = m->type_key(); + if (tkey == "rpc") { + return static_cast(m.operator->()) + ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms); + } else { + return WrapTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); + } + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms); + } + }); // server function registration. -TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule") -.set_body_typed([](Module parent, Module child) { +TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule").set_body_typed([](Module parent, Module child) { parent->Import(child); }); TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") -.set_body_typed([](Module parent, std::string name, bool query_imports) { - return parent->GetFunction(name, query_imports); -}); + .set_body_typed([](Module parent, std::string name, bool query_imports) { + return parent->GetFunction(name, query_imports); + }); // functions to access an RPC module. -TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule") -.set_body_typed([](Module sess, std::string name) { +TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) { std::string tkey = sess->type_key(); CHECK_EQ(tkey, "rpc"); return static_cast(sess.operator->())->LoadModule(name); }); -TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule") -.set_body_typed([](Module parent, Module child) { +TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) { std::string tkey = parent->type_key(); CHECK_EQ(tkey, "rpc"); static_cast(parent.operator->())->ImportModule(child); }); -TVM_REGISTER_GLOBAL("rpc.SessTableIndex") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body([](TVMArgs args, TVMRetValue* rv) { Module m = args[0]; std::string tkey = m->type_key(); CHECK_EQ(tkey, "rpc"); diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc index 376b8b5..2f42435 100644 --- a/src/runtime/rpc/rpc_pipe_impl.cc +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -24,18 +24,18 @@ // Linux only for now, as linux is the most common usecase. #if defined(__linux__) || defined(__ANDROID__) -#include -#include #include #include - +#include #include -#include +#include + #include +#include +#include "../../support/pipe.h" #include "rpc_endpoint.h" #include "rpc_local_session.h" -#include "../../support/pipe.h" namespace tvm { namespace runtime { @@ -43,12 +43,9 @@ namespace runtime { class PipeChannel final : public RPCChannel { public: explicit PipeChannel(int readfd, int writefd, pid_t child_pid) - : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) { - } + : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) {} - ~PipeChannel() { - Close(); - } + ~PipeChannel() { Close(); } size_t Send(const void* data, size_t size) final { ssize_t n = write(writefd_, data, size); @@ -78,7 +75,6 @@ class PipeChannel final : public RPCChannel { pid_t child_pid_; }; - Module CreatePipeClient(std::vector cmd) { int parent2child[2]; int child2parent[2]; @@ -111,15 +107,13 @@ Module CreatePipeClient(std::vector cmd) { close(child_write); auto endpt = RPCEndpoint::Create( - std::unique_ptr( - new PipeChannel(parent_read, parent_write, pid)), - "pipe", "pipe"); + std::unique_ptr(new PipeChannel(parent_read, parent_write, pid)), "pipe", + "pipe"); endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0)); return CreateRPCSessionModule(CreateClientSession(endpt)); } -TVM_REGISTER_GLOBAL("rpc.CreatePipeClient") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body([](TVMArgs args, TVMRetValue* rv) { std::vector cmd; for (int i = 0; i < args.size(); ++i) { cmd.push_back(args[i].operator std::string()); @@ -127,7 +121,6 @@ TVM_REGISTER_GLOBAL("rpc.CreatePipeClient") *rv = CreatePipeClient(cmd); }); - } // namespace runtime } // namespace tvm #endif diff --git a/src/runtime/rpc/rpc_protocol.h b/src/runtime/rpc/rpc_protocol.h index 6221bfb..3a0555d 100644 --- a/src/runtime/rpc/rpc_protocol.h +++ b/src/runtime/rpc/rpc_protocol.h @@ -79,22 +79,35 @@ enum class RPCServerStatus : int { */ inline const char* RPCServerStatusToString(RPCServerStatus status) { switch (status) { - case RPCServerStatus::kSuccess: return "kSuccess"; - case RPCServerStatus::kInvalidTypeCodeObject: return "kInvalidTypeCodeObject"; - case RPCServerStatus::kInvalidTypeCodeNDArray: return "kInvalidTypeCodeNDArray"; - case RPCServerStatus::kInvalidDLTensorFieldStride: return "kInvalidDLTensorFieldStride"; + case RPCServerStatus::kSuccess: + return "kSuccess"; + case RPCServerStatus::kInvalidTypeCodeObject: + return "kInvalidTypeCodeObject"; + case RPCServerStatus::kInvalidTypeCodeNDArray: + return "kInvalidTypeCodeNDArray"; + case RPCServerStatus::kInvalidDLTensorFieldStride: + return "kInvalidDLTensorFieldStride"; case RPCServerStatus::kInvalidDLTensorFieldByteOffset: { return "kInvalidDLTensorFieldByteOffset"; } - case RPCServerStatus::kUnknownTypeCode: return "kUnknownTypeCode"; - case RPCServerStatus::kUnknownRPCCode: return "kUnknownRPCCode"; - case RPCServerStatus::kRPCCodeNotSupported: return "RPCCodeNotSupported"; - case RPCServerStatus::kUnknownRPCSyscall: return "kUnknownRPCSyscall"; - case RPCServerStatus::kCheckError: return "kCheckError"; - case RPCServerStatus::kReadError: return "kReadError"; - case RPCServerStatus::kWriteError: return "kWriteError"; - case RPCServerStatus::kAllocError: return "kAllocError"; - default: return ""; + case RPCServerStatus::kUnknownTypeCode: + return "kUnknownTypeCode"; + case RPCServerStatus::kUnknownRPCCode: + return "kUnknownRPCCode"; + case RPCServerStatus::kRPCCodeNotSupported: + return "RPCCodeNotSupported"; + case RPCServerStatus::kUnknownRPCSyscall: + return "kUnknownRPCSyscall"; + case RPCServerStatus::kCheckError: + return "kCheckError"; + case RPCServerStatus::kReadError: + return "kReadError"; + case RPCServerStatus::kWriteError: + return "kWriteError"; + case RPCServerStatus::kAllocError: + return "kAllocError"; + default: + return ""; } } @@ -111,11 +124,10 @@ struct RPCReference { * \brief Auxiliary class to get the packed sequence. * \tparam TChannel The channel to throw errror. */ - template + template struct PackedSeqNumBytesGetter { public: - explicit PackedSeqNumBytesGetter(TChannel* channel) - : channel_(channel) {} + explicit PackedSeqNumBytesGetter(TChannel* channel) : channel_(channel) {} template void Write(const T& value) { @@ -127,13 +139,9 @@ struct RPCReference { num_bytes_ += sizeof(T) * num; } - void ThrowError(RPCServerStatus status) { - channel_->ThrowError(status); - } + void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); } - uint64_t num_bytes() const { - return num_bytes_; - } + uint64_t num_bytes() const { return num_bytes_; } private: TChannel* channel_; @@ -162,12 +170,9 @@ struct RPCReference { * \tparam TChannel The type of the communication channel. * \return The total number of bytes. */ - template - static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, - const int* type_codes, - int num_args, - bool client_mode, - TChannel* channel) { + template + static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, + int num_args, bool client_mode, TChannel* channel) { PackedSeqNumBytesGetter getter(channel); SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter); return getter.num_bytes(); @@ -196,12 +201,9 @@ struct RPCReference { * \param channel The communication channel handler. * \tparam TChannel The type of the communication channel. */ - template - static void SendPackedSeq(const TVMValue* arg_values, - const int* type_codes, - int num_args, - bool client_mode, - TChannel* channel) { + template + static void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, + bool client_mode, TChannel* channel) { channel->Write(num_args); channel->WriteArray(type_codes, num_args); @@ -270,7 +272,8 @@ struct RPCReference { } break; } - case kTVMNullptr: break; + case kTVMNullptr: + break; case kTVMStr: { const char* s = value.v_str; uint64_t len = StrLength(s); @@ -303,10 +306,8 @@ struct RPCReference { * \tparam TChannel The type of the communication channel. * \note The temporary space are populated via an arena inside channel. */ - template - static void RecvPackedSeq(TVMValue** out_values, - int** out_tcodes, - int* out_num_args, + template + static void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args, TChannel* channel) { // receive number of args int num_args; @@ -411,19 +412,14 @@ struct RPCReference { * \param channel The communication channel handler. * \tparam TChannel The type of the communication channel. */ - template + template static void ReturnException(const char* msg, TChannel* channel) { RPCCode code = RPCCode::kException; int32_t num_args = 1; int32_t tcode = kTVMStr; uint64_t len = StrLength(msg); - uint64_t packet_nbytes = - sizeof(code) + - sizeof(num_args) + - sizeof(tcode) + - sizeof(len) + - len; + uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(len) + len; channel->Write(packet_nbytes); channel->Write(code); @@ -440,22 +436,17 @@ struct RPCReference { * \param channel The communication channel handler. * \tparam TChannel The type of the communication channel. */ - template - static void ReturnPackedSeq(const TVMValue* arg_values, - const int* type_codes, - int num_args, + template + static void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, TChannel* channel) { RPCCode code = RPCCode::kReturn; uint64_t packet_nbytes = - sizeof(code) + - PackedSeqGetNumBytes( - arg_values, type_codes, num_args, false, channel); + sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args, false, channel); channel->Write(packet_nbytes); channel->Write(code); - SendPackedSeq( - arg_values, type_codes, num_args, false, channel); + SendPackedSeq(arg_values, type_codes, num_args, false, channel); } /*! @@ -464,16 +455,13 @@ struct RPCReference { * \param channel The communication channel handler. * \tparam TChannel The type of the communication channel. */ - template + template static void ReturnVoid(TChannel* channel) { int32_t num_args = 1; int32_t tcode = kTVMNullptr; RPCCode code = RPCCode::kReturn; - uint64_t packet_nbytes = - sizeof(code) + - sizeof(num_args) + - sizeof(tcode); + uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); channel->Write(packet_nbytes); channel->Write(code); diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index 612ca41..b999a48 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -22,6 +22,7 @@ * \brief Server environment of the RPC. */ #include + #include "../file_util.h" namespace tvm { @@ -29,36 +30,32 @@ namespace runtime { std::string RPCGetPath(const std::string& name) { // do live lookup everytime as workpath can change. - const PackedFunc* f = - runtime::Registry::Get("tvm.rpc.server.workpath"); + const PackedFunc* f = runtime::Registry::Get("tvm.rpc.server.workpath"); CHECK(f != nullptr) << "require tvm.rpc.server.workpath"; return (*f)(name); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.upload"). -set_body([](TVMArgs args, TVMRetValue *rv) { - std::string file_name = RPCGetPath(args[0]); - std::string data = args[1]; - SaveBinaryToFile(file_name, data); - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.upload").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string file_name = RPCGetPath(args[0]); + std::string data = args[1]; + SaveBinaryToFile(file_name, data); +}); -TVM_REGISTER_GLOBAL("tvm.rpc.server.download") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string file_name = RPCGetPath(args[0]); - std::string data; - LoadBinaryFromFile(file_name, &data); - TVMByteArray arr; - arr.data = data.c_str(); - arr.size = data.length(); - LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size; - *rv = arr; - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.download").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string file_name = RPCGetPath(args[0]); + std::string data; + LoadBinaryFromFile(file_name, &data); + TVMByteArray arr; + arr.data = data.c_str(); + arr.size = data.length(); + LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size; + *rv = arr; +}); -TVM_REGISTER_GLOBAL("tvm.rpc.server.remove") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string file_name = RPCGetPath(args[0]); - RemoveFile(file_name); - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.remove").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string file_name = RPCGetPath(args[0]); + RemoveFile(file_name); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index d07aa74..9e05e5d 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -21,18 +21,18 @@ * \file rpc_session.cc * \brief RPC session for remote function call. */ -#include +#include "rpc_session.h" + #include -#include +#include + #include -#include "rpc_session.h" +#include namespace tvm { namespace runtime { -bool RPCSession::IsAsync() const { - return false; -} +bool RPCSession::IsAsync() const { return false; } void RPCSession::SendException(FAsyncCallback callback, const char* msg) { TVMValue value; @@ -41,68 +41,50 @@ void RPCSession::SendException(FAsyncCallback callback, const char* msg) { callback(RPCCode::kException, TVMArgs(&value, &tcode, 1)); } -void RPCSession::AsyncCallFunc(PackedFuncHandle func, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, - FAsyncCallback callback) { +void RPCSession::AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, FAsyncCallback callback) { try { this->CallFunc(func, arg_values, arg_type_codes, num_args, - [&callback](TVMArgs args) { - callback(RPCCode::kReturn, args); - }); + [&callback](TVMArgs args) { callback(RPCCode::kReturn, args); }); } catch (const std::runtime_error& e) { this->SendException(callback, e.what()); } } - -void RPCSession::AsyncCopyToRemote(void* local_from, - size_t local_from_offset, - void* remote_to, - size_t remote_to_offset, - size_t nbytes, - TVMContext remote_ctx_to, - DLDataType type_hint, - RPCSession::FAsyncCallback callback) { +void RPCSession::AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, + DLDataType type_hint, RPCSession::FAsyncCallback callback) { TVMValue value; int32_t tcode = kTVMNullptr; value.v_handle = nullptr; try { - this->CopyToRemote(local_from, local_from_offset, - remote_to, remote_to_offset, - nbytes, remote_ctx_to, type_hint); + this->CopyToRemote(local_from, local_from_offset, remote_to, remote_to_offset, nbytes, + remote_ctx_to, type_hint); callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); } catch (const std::runtime_error& e) { this->SendException(callback, e.what()); } } -void RPCSession::AsyncCopyFromRemote(void* remote_from, - size_t remote_from_offset, - void* local_to, - size_t local_to_offset, - size_t nbytes, - TVMContext remote_ctx_from, - DLDataType type_hint, +void RPCSession::AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, + TVMContext remote_ctx_from, DLDataType type_hint, RPCSession::FAsyncCallback callback) { TVMValue value; int32_t tcode = kTVMNullptr; value.v_handle = nullptr; try { - this->CopyFromRemote(remote_from, remote_from_offset, - local_to, local_to_offset, - nbytes, remote_ctx_from, type_hint); + this->CopyFromRemote(remote_from, remote_from_offset, local_to, local_to_offset, nbytes, + remote_ctx_from, type_hint); callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); } catch (const std::runtime_error& e) { this->SendException(callback, e.what()); } } -void RPCSession::AsyncStreamWait(TVMContext ctx, - TVMStreamHandle stream, +void RPCSession::AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, RPCSession::FAsyncCallback callback) { TVMValue value; int32_t tcode = kTVMNullptr; @@ -116,7 +98,6 @@ void RPCSession::AsyncStreamWait(TVMContext ctx, } } - class RPCSessTable { public: static constexpr int kMaxRPCSession = 32; @@ -135,7 +116,8 @@ class RPCSessTable { std::lock_guard lock(mutex_); for (int i = 0; i < kMaxRPCSession; ++i) { if (tbl_[i].lock() == nullptr) { - tbl_[i] = ptr; return i; + tbl_[i] = ptr; + return i; } } LOG(FATAL) << "maximum number of RPC session reached"; diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 7ea1eb9..6a7e6d6 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -24,12 +24,13 @@ #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_SESSION_H_ - -#include #include +#include + #include #include #include + #include "rpc_protocol.h" namespace tvm { @@ -120,10 +121,8 @@ class RPCSession { * \param fencode_return The function to set the return value, * if not called, return value is null. */ - virtual void CallFunc(PackedFuncHandle func, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, + virtual void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, const FEncodeReturn& fencode_return) = 0; /*! @@ -136,12 +135,8 @@ class RPCSession { * \param remote_ctx_to The target context. * \param type_hint Hint of content data type. */ - virtual void CopyToRemote(void* local_from, - size_t local_from_offset, - void* remote_to, - size_t remote_to_offset, - size_t nbytes, - TVMContext remote_ctx_to, + virtual void CopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, DLDataType type_hint) = 0; /*! * \brief Copy bytes from remote array content. @@ -153,12 +148,8 @@ class RPCSession { * \param remote_ctx_from The source context in the remote. * \param type_hint Hint of content data type. */ - virtual void CopyFromRemote(void* remote_from, - size_t remote_from_offset, - void* local_to, - size_t local_to_offset, - size_t nbytes, - TVMContext remote_ctx_from, + virtual void CopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from, DLDataType type_hint) = 0; /*! @@ -226,11 +217,8 @@ class RPCSession { * * \param callback The callback to pass the return value or exception. */ - virtual void AsyncCallFunc(PackedFuncHandle func, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, - FAsyncCallback callback); + virtual void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, FAsyncCallback callback); /*! * \brief Asynchrous version of CopyToRemote. @@ -247,14 +235,9 @@ class RPCSession { * \note All the allocated memory in local_from, and remote_to * must stay alive until on_compelete is called. */ - virtual void AsyncCopyToRemote(void* local_from, - size_t local_from_offset, - void* remote_to, - size_t remote_to_offset, - size_t nbytes, - TVMContext remote_ctx_to, - DLDataType type_hint, - FAsyncCallback on_complete); + virtual void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, + DLDataType type_hint, FAsyncCallback on_complete); /*! * \brief Asynchrous version of CopyFromRemote. @@ -271,13 +254,9 @@ class RPCSession { * \note All the allocated memory in remote_from, and local_to * must stay alive until on_compelete is called. */ - virtual void AsyncCopyFromRemote(void* remote_from, - size_t remote_from_offset, - void* local_to, - size_t local_to_offset, - size_t nbytes, - TVMContext remote_ctx_from, - DLDataType type_hint, + virtual void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, + TVMContext remote_ctx_from, DLDataType type_hint, FAsyncCallback on_complete); /*! * \brief Asynchrously wait for all events in ctx, stream compeletes. @@ -285,16 +264,12 @@ class RPCSession { * \param stream The stream to wait on. * \param on_complete The callback to signal copy complete. */ - virtual void AsyncStreamWait(TVMContext ctx, - TVMStreamHandle stream, - FAsyncCallback on_compelte); + virtual void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_compelte); /*! * \return The session table index of the session. */ - int table_index() const { - return table_index_; - } + int table_index() const { return table_index_; } /*! * \brief Try get session from the global session table by table index. @@ -351,10 +326,7 @@ struct RemoteSpace { * the `number` parameter will be automatically increased. * \return f_timer A timer function. */ -PackedFunc WrapTimeEvaluator(PackedFunc f, - TVMContext ctx, - int number, - int repeat, +PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int number, int repeat, int min_repeat_ms); /*! diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index f3a30dd..77a743b 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -21,21 +21,22 @@ * \file rpc_socket_impl.cc * \brief Socket based RPC implementation. */ -#include #include +#include + #include + +#include "../../support/socket.h" #include "rpc_endpoint.h" -#include "rpc_session.h" #include "rpc_local_session.h" -#include "../../support/socket.h" +#include "rpc_session.h" namespace tvm { namespace runtime { class SockChannel final : public RPCChannel { public: - explicit SockChannel(support::TCPSocket sock) - : sock_(sock) {} + explicit SockChannel(support::TCPSocket sock) : sock_(sock) {} ~SockChannel() { try { // BadSocket can throw @@ -64,13 +65,12 @@ class SockChannel final : public RPCChannel { support::TCPSocket sock_; }; -std::shared_ptr -RPCConnect(std::string url, int port, std::string key, TVMArgs init_seq) { +std::shared_ptr RPCConnect(std::string url, int port, std::string key, + TVMArgs init_seq) { support::TCPSocket sock; support::SockAddr addr(url.c_str(), port); sock.Create(addr.ss_family()); - CHECK(sock.Connect(addr)) - << "Connect to " << addr.AsString() << " failed"; + CHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed"; // hand shake std::ostringstream os; int code = kRPCMagic; @@ -83,12 +83,10 @@ RPCConnect(std::string url, int port, std::string key, TVMArgs init_seq) { CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code)); if (code == kRPCMagic + 2) { sock.Close(); - LOG(FATAL) << "URL " << url << ":" << port - << " cannot find server that matches key=" << key; + LOG(FATAL) << "URL " << url << ":" << port << " cannot find server that matches key=" << key; } else if (code == kRPCMagic + 1) { sock.Close(); - LOG(FATAL) << "URL " << url << ":" << port - << " server already have key=" << key; + LOG(FATAL) << "URL " << url << ":" << port << " server already have key=" << key; } else if (code != kRPCMagic) { sock.Close(); LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server"; @@ -99,54 +97,44 @@ RPCConnect(std::string url, int port, std::string key, TVMArgs init_seq) { remote_key.resize(keylen); CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen); } - auto endpt = RPCEndpoint::Create( - std::unique_ptr(new SockChannel(sock)), key, remote_key); + auto endpt = + RPCEndpoint::Create(std::unique_ptr(new SockChannel(sock)), key, remote_key); endpt->InitRemoteSession(init_seq); return endpt; } -Module RPCClientConnect(std::string url, - int port, - std::string key, - TVMArgs init_seq) { +Module RPCClientConnect(std::string url, int port, std::string key, TVMArgs init_seq) { auto endpt = RPCConnect(url, port, "client:" + key, init_seq); return CreateRPCSessionModule(CreateClientSession(endpt)); } // TVM_DLL needed for MSVC TVM_DLL void RPCServerLoop(int sockfd) { - support::TCPSocket sock( - static_cast(sockfd)); - RPCEndpoint::Create( - std::unique_ptr(new SockChannel(sock)), - "SockServerLoop", "")->ServerLoop(); + support::TCPSocket sock(static_cast(sockfd)); + RPCEndpoint::Create(std::unique_ptr(new SockChannel(sock)), "SockServerLoop", "") + ->ServerLoop(); } -void RPCServerLoop(PackedFunc fsend, - PackedFunc frecv) { - RPCEndpoint::Create( - std::unique_ptr(new CallbackChannel(fsend, frecv)), - "SockServerLoop", "")->ServerLoop(); +void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) { + RPCEndpoint::Create(std::unique_ptr(new CallbackChannel(fsend, frecv)), + "SockServerLoop", "") + ->ServerLoop(); } -TVM_REGISTER_GLOBAL("rpc.Connect") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("rpc.Connect").set_body([](TVMArgs args, TVMRetValue* rv) { std::string url = args[0]; int port = args[1]; std::string key = args[2]; - *rv = RPCClientConnect( - url, port, key, - TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); + *rv = RPCClientConnect(url, port, key, + TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); }); -TVM_REGISTER_GLOBAL("rpc.ServerLoop") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[0].type_code() == kDLInt) { RPCServerLoop(args[0]); } else { - RPCServerLoop( - args[0].operator tvm::runtime::PackedFunc(), - args[1].operator tvm::runtime::PackedFunc()); + RPCServerLoop(args[0].operator tvm::runtime::PackedFunc(), + args[1].operator tvm::runtime::PackedFunc()); } }); diff --git a/src/runtime/runtime_base.h b/src/runtime/runtime_base.h index 84fc3c4..21601df 100644 --- a/src/runtime/runtime_base.h +++ b/src/runtime/runtime_base.h @@ -25,25 +25,37 @@ #define TVM_RUNTIME_RUNTIME_BASE_H_ #include + #include /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { /*! \brief every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR */ -#define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) +#define API_END() \ + } \ + catch (std::runtime_error & _except_) { \ + return TVMAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! * \brief every function starts with API_BEGIN(); * and finishes with API_END() or API_END_HANDLE_ERROR * The finally clause contains procedure to cleanup states when an error happens. */ -#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) +#define API_END_HANDLE_ERROR(Finalize) \ + } \ + catch (std::runtime_error & _except_) { \ + Finalize; \ + return TVMAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! * \brief handle exception throwed out * \param e the exception * \return the return value of API after exception is handled */ -int TVMAPIHandleException(const std::runtime_error &e); +int TVMAPIHandleException(const std::runtime_error& e); #endif // TVM_RUNTIME_RUNTIME_BASE_H_ diff --git a/src/runtime/stackvm/stackvm.cc b/src/runtime/stackvm/stackvm.cc index 0f17f9e..042815b 100644 --- a/src/runtime/stackvm/stackvm.cc +++ b/src/runtime/stackvm/stackvm.cc @@ -21,87 +21,88 @@ * Implementation stack VM. * \file stackvm.cc */ +#include "stackvm.h" + #include #include + #include -#include "stackvm.h" namespace tvm { namespace runtime { typedef dmlc::ThreadLocalStore StackVMStateStore; -StackVM::State* StackVM::ThreadLocalState() { - return StackVMStateStore::Get(); -} +StackVM::State* StackVM::ThreadLocalState() { return StackVMStateStore::Get(); } #define STACK_VM_BINOP(OP, FIELD) \ { \ stack[sp - 1].FIELD = stack[sp - 1].FIELD OP stack[sp].FIELD; \ - sp -= 1; pc += 1; \ + sp -= 1; \ + pc += 1; \ } #define STACK_VM_CMPOP(OP, FIELD) \ { \ stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \ - sp -= 1; pc += 1; \ + sp -= 1; \ + pc += 1; \ } -#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \ - { \ - int index = code[pc + 1].v_int; \ - stack[sp]FIELD = static_cast( \ - static_cast(stack[sp].v_handle)[index]); \ - pc += 2; \ +#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \ + { \ + int index = code[pc + 1].v_int; \ + stack[sp] FIELD = static_cast(static_cast(stack[sp].v_handle)[index]); \ + pc += 2; \ } -#define STACK_VM_STORE(FIELD, DST_TYPE) \ - { \ - int index = code[pc + 1].v_int; \ - static_cast(stack[sp - 1].v_handle)[index] = \ - static_cast(stack[sp]FIELD); \ - sp -= 2; pc += 2; \ +#define STACK_VM_STORE(FIELD, DST_TYPE) \ + { \ + int index = code[pc + 1].v_int; \ + static_cast(stack[sp - 1].v_handle)[index] = \ + static_cast(stack[sp] FIELD); \ + sp -= 2; \ + pc += 2; \ } -#define STACK_VM_PRINT_CODE0(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE << std::endl; return pc + 1; \ +#define STACK_VM_PRINT_CODE0(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << std::endl; \ + return pc + 1; \ } -#define STACK_VM_PRINT_CODE1(CODE) \ - case CODE: { \ +#define STACK_VM_PRINT_CODE1(CODE) \ + case CODE: { \ os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \ - << "[" << pc + 1 << "]" << std::endl; \ - return pc + 2; \ + << "[" << pc + 1 << "]" << std::endl; \ + return pc + 2; \ } -#define STACK_VM_PRINT_CODE2(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE \ - << " " << code[pc + 1].v_int \ - << " " << code[pc + 2].v_int << "\n" \ - << "[" << pc + 1 << "]" << std::endl \ - << "[" << pc + 2 << "]" << std::endl; \ - return pc + 3; \ +#define STACK_VM_PRINT_CODE2(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " << code[pc + 2].v_int \ + << "\n" \ + << "[" << pc + 1 << "]" << std::endl \ + << "[" << pc + 2 << "]" << std::endl; \ + return pc + 3; \ } -#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int \ - << " " << heap_id_name[code[pc + 1].v_int] << "\n" \ - << "[" << pc + 1 << "]" << std::endl; \ - return pc + 2; \ +#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " \ + << heap_id_name[code[pc + 1].v_int] << "\n" \ + << "[" << pc + 1 << "]" << std::endl; \ + return pc + 2; \ } -#define STACK_VM_PRINT_JUMP(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int \ - << " to " << pc + code[pc + 1].v_int << '\n' \ - << "[" << pc + 1 << "]" << std::endl; \ - return pc + 2; \ +#define STACK_VM_PRINT_JUMP(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int << " to " \ + << pc + code[pc + 1].v_int << '\n' \ + << "[" << pc + 1 << "]" << std::endl; \ + return pc + 2; \ } - int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { switch (code[pc].op_code) { // int @@ -164,9 +165,7 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { int begin = code[pc + 2].v_int; int end = code[pc + 3].v_int; os << "[" << pc << "]\tCALL_PACKED_FUNC " - << " fid=" << call_fid - << " begin=" << begin - << " end=" << end; + << " fid=" << call_fid << " begin=" << begin << " end=" << end; os << '\n'; for (int i = 0; i < 3; ++i) { os << "[" << pc + 1 + i << "]" << std::endl; @@ -181,8 +180,7 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*) int64_t pc = 0; const int64_t code_size = static_cast(vm.code.size()); - os << "Program dump: code-size=" << code_size << '\n' - << "----------begin-----------------\n"; + os << "Program dump: code-size=" << code_size << '\n' << "----------begin-----------------\n"; while (pc < code_size) { pc = vm.PrintCode(os, pc); } @@ -190,8 +188,7 @@ std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*) return os; } -void StackVM::Run(const runtime::TVMArgs& args, - runtime::ModuleNode* mod_ctx) const { +void StackVM::Run(const runtime::TVMArgs& args, runtime::ModuleNode* mod_ctx) const { StackVM::State* s = StackVM::ThreadLocalState(); if (s->heap.size() < heap_size) { s->heap.resize(heap_size); @@ -199,7 +196,7 @@ void StackVM::Run(const runtime::TVMArgs& args, s->sp = 0; s->pc = 0; s->mod_ctx = mod_ctx; - s->heap[0].v_handle = (void*)args.values; // NOLINT(*) + s->heap[0].v_handle = (void*)args.values; // NOLINT(*) s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*) s->heap[2].v_int64 = args.num_args; this->Run(s); @@ -207,16 +204,13 @@ void StackVM::Run(const runtime::TVMArgs& args, void StackVM::InitCache() { extern_func_cache_.clear(); - extern_func_cache_.resize( - extern_func_name.size(), PackedFunc(nullptr)); + extern_func_cache_.resize(extern_func_name.size(), PackedFunc(nullptr)); } void StackVM::Save(dmlc::Stream* strm) const { // to be endian invariant. std::vector code_copy(code.size()); - std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) { - return c.v_int; - }); + std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) { return c.v_int; }); strm->Write(code_copy); strm->Write(str_data); strm->Write(extern_func_name); @@ -225,14 +219,16 @@ void StackVM::Save(dmlc::Stream* strm) const { strm->Write(stack_size); } -bool StackVM::Load(dmlc::Stream* strm) { +bool StackVM::Load(dmlc::Stream* strm) { // to be endian invariant. std::vector code_copy; if (!strm->Read(&code_copy)) return false; code.resize(code_copy.size()); std::transform(code_copy.begin(), code_copy.end(), code.begin(), [](int v) { - Code code; code.v_int = v; return code; - }); + Code code; + code.v_int = v; + return code; + }); if (!strm->Read(&str_data)) return false; if (!strm->Read(&extern_func_name)) return false; if (!strm->Read(&heap_id_name)) return false; @@ -258,36 +254,92 @@ void StackVM::Run(State* s) const { const int64_t code_size = static_cast(code.size()); while (pc < code_size) { switch (code[pc].op_code) { - case ADD_I64: STACK_VM_BINOP(+, v_int64); break; - case SUB_I64: STACK_VM_BINOP(-, v_int64); break; - case MUL_I64: STACK_VM_BINOP(*, v_int64); break; - case DIV_I64: STACK_VM_BINOP(/, v_int64); break; - case MOD_I64: STACK_VM_BINOP(%, v_int64); break; - case EQ_I64: STACK_VM_CMPOP(==, v_int64); break; - case LT_I64: STACK_VM_CMPOP(<, v_int64); break; - case LE_I64: STACK_VM_CMPOP(<=, v_int64); break; - case ADD_F64: STACK_VM_BINOP(+, v_float64); break; - case SUB_F64: STACK_VM_BINOP(-, v_float64); break; - case MUL_F64: STACK_VM_BINOP(*, v_float64); break; - case DIV_F64: STACK_VM_BINOP(/, v_float64); break; - case EQ_F64: STACK_VM_CMPOP(==, v_float64); break; - case LT_F64: STACK_VM_CMPOP(<, v_float64); break; - case LE_F64: STACK_VM_CMPOP(<=, v_float64); break; - case EQ_HANDLE: STACK_VM_CMPOP(==, v_handle); break; + case ADD_I64: + STACK_VM_BINOP(+, v_int64); + break; + case SUB_I64: + STACK_VM_BINOP(-, v_int64); + break; + case MUL_I64: + STACK_VM_BINOP(*, v_int64); + break; + case DIV_I64: + STACK_VM_BINOP(/, v_int64); + break; + case MOD_I64: + STACK_VM_BINOP(%, v_int64); + break; + case EQ_I64: + STACK_VM_CMPOP(==, v_int64); + break; + case LT_I64: + STACK_VM_CMPOP(<, v_int64); + break; + case LE_I64: + STACK_VM_CMPOP(<=, v_int64); + break; + case ADD_F64: + STACK_VM_BINOP(+, v_float64); + break; + case SUB_F64: + STACK_VM_BINOP(-, v_float64); + break; + case MUL_F64: + STACK_VM_BINOP(*, v_float64); + break; + case DIV_F64: + STACK_VM_BINOP(/, v_float64); + break; + case EQ_F64: + STACK_VM_CMPOP(==, v_float64); + break; + case LT_F64: + STACK_VM_CMPOP(<, v_float64); + break; + case LE_F64: + STACK_VM_CMPOP(<=, v_float64); + break; + case EQ_HANDLE: + STACK_VM_CMPOP(==, v_handle); + break; // addressing - case ARRAY_LOAD_UINT32: STACK_VM_LOAD(.v_int64, int64_t, uint32_t); break; - case ARRAY_LOAD_INT32: STACK_VM_LOAD(.v_int64, int64_t, int32_t); break; - case ARRAY_LOAD_INT64: STACK_VM_LOAD(.v_int64, int64_t, int64_t); break; - case ARRAY_LOAD_FP64: STACK_VM_LOAD(.v_float64, double, double); break; - case ARRAY_LOAD_HANDLE: STACK_VM_LOAD(.v_handle, void*, void*); break; - case ARRAY_LOAD_TVMVALUE: STACK_VM_LOAD(, TVMValue, TVMValue); break; + case ARRAY_LOAD_UINT32: + STACK_VM_LOAD(.v_int64, int64_t, uint32_t); + break; + case ARRAY_LOAD_INT32: + STACK_VM_LOAD(.v_int64, int64_t, int32_t); + break; + case ARRAY_LOAD_INT64: + STACK_VM_LOAD(.v_int64, int64_t, int64_t); + break; + case ARRAY_LOAD_FP64: + STACK_VM_LOAD(.v_float64, double, double); + break; + case ARRAY_LOAD_HANDLE: + STACK_VM_LOAD(.v_handle, void*, void*); + break; + case ARRAY_LOAD_TVMVALUE: + STACK_VM_LOAD(, TVMValue, TVMValue); + break; // store - case ARRAY_STORE_UINT32: STACK_VM_STORE(.v_int64, uint32_t); break; - case ARRAY_STORE_INT32: STACK_VM_STORE(.v_int64, int32_t); break; - case ARRAY_STORE_INT64: STACK_VM_STORE(.v_int64, int64_t); break; - case ARRAY_STORE_FP64: STACK_VM_STORE(.v_float64, double); break; - case ARRAY_STORE_HANDLE: STACK_VM_STORE(.v_handle, void*); break; - case ARRAY_STORE_TVMVALUE: STACK_VM_STORE(, TVMValue); break; + case ARRAY_STORE_UINT32: + STACK_VM_STORE(.v_int64, uint32_t); + break; + case ARRAY_STORE_INT32: + STACK_VM_STORE(.v_int64, int32_t); + break; + case ARRAY_STORE_INT64: + STACK_VM_STORE(.v_int64, int64_t); + break; + case ARRAY_STORE_FP64: + STACK_VM_STORE(.v_float64, double); + break; + case ARRAY_STORE_HANDLE: + STACK_VM_STORE(.v_handle, void*); + break; + case ARRAY_STORE_TVMVALUE: + STACK_VM_STORE(, TVMValue); + break; // add case ADDR_ADD: { stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64; // NOLINT(*) @@ -365,9 +417,8 @@ void StackVM::Run(State* s) const { } case ASSERT_SP: { int64_t expected = code[pc + 1].v_int; - CHECK_EQ(sp, expected) - << "sp assertion failed, expected=" - << expected << " now=" << sp << ", pc=" << pc; + CHECK_EQ(sp, expected) << "sp assertion failed, expected=" << expected << " now=" << sp + << ", pc=" << pc; pc += 2; break; } @@ -379,11 +430,10 @@ void StackVM::Run(State* s) const { int begin = code[pc + 2].v_int; int end = code[pc + 3].v_int; int num_args = end - begin; - static_assert(sizeof(Code) == sizeof(int) && - alignof(Code) == alignof(int), "asusmption"); + static_assert(sizeof(Code) == sizeof(int) && alignof(Code) == alignof(int), "asusmption"); runtime::TVMRetValue rv; - GetExtern(s, call_fid).CallPacked( - runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv); + GetExtern(s, call_fid) + .CallPacked(runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv); sp = sp - 1; stack[sp] = rv.value(); pc += 4; @@ -396,47 +446,55 @@ void StackVM::Run(State* s) const { DLTensor* arr = static_cast(stack[sp].v_handle); switch (kind) { case StackVM::kArrData: { - stack[sp].v_handle = arr[index].data; break; + stack[sp].v_handle = arr[index].data; + break; } case StackVM::kArrShape: { - stack[sp].v_handle = arr[index].shape; break; + stack[sp].v_handle = arr[index].shape; + break; } case StackVM::kArrStrides: { - stack[sp].v_handle = arr[index].strides; break; + stack[sp].v_handle = arr[index].strides; + break; } case StackVM::kArrNDim: { - stack[sp].v_int64 = arr[index].ndim; break; + stack[sp].v_int64 = arr[index].ndim; + break; } case StackVM::kArrTypeCode: { - stack[sp].v_int64 = static_cast( - arr[index].dtype.code); break; + stack[sp].v_int64 = static_cast(arr[index].dtype.code); + break; } case StackVM::kArrTypeBits: { - stack[sp].v_int64 = static_cast( - arr[index].dtype.bits); break; + stack[sp].v_int64 = static_cast(arr[index].dtype.bits); + break; } case StackVM::kArrTypeLanes: { - stack[sp].v_int64 = static_cast( - arr[index].dtype.lanes); break; + stack[sp].v_int64 = static_cast(arr[index].dtype.lanes); + break; } case StackVM::kArrByteOffset: { - stack[sp].v_int64 = static_cast( - arr[index].byte_offset); break; + stack[sp].v_int64 = static_cast(arr[index].byte_offset); + break; } case StackVM::kArrDeviceId: { - stack[sp].v_int64 = arr[index].ctx.device_id; break; + stack[sp].v_int64 = arr[index].ctx.device_id; + break; } case StackVM::kArrDeviceType: { - stack[sp].v_int64 = static_cast( - arr[index].ctx.device_type); break; + stack[sp].v_int64 = static_cast(arr[index].ctx.device_type); + break; } case StackVM::kArrAddr: { - stack[sp].v_handle = arr + index; break; + stack[sp].v_handle = arr + index; + break; } case StackVM::kTVMValueContent: { - stack[sp] = static_cast(stack[sp].v_handle)[index]; break; + stack[sp] = static_cast(stack[sp].v_handle)[index]; + break; } - default: LOG(FATAL) << "unhandled get " << kind; + default: + LOG(FATAL) << "unhandled get " << kind; } pc = pc + 3; break; @@ -447,7 +505,8 @@ void StackVM::Run(State* s) const { DLTensor* arr = static_cast(stack[sp - 1].v_handle); switch (kind) { case StackVM::kArrData: { - arr[index].data = stack[sp].v_handle; break; + arr[index].data = stack[sp].v_handle; + break; } case StackVM::kArrShape: { arr[index].shape = static_cast(stack[sp].v_handle); @@ -486,9 +545,11 @@ void StackVM::Run(State* s) const { break; } case StackVM::kTVMValueContent: { - static_cast(stack[sp - 1].v_handle)[index] = stack[sp]; break; + static_cast(stack[sp - 1].v_handle)[index] = stack[sp]; + break; } - default: LOG(FATAL) << "unhandled tvm_struct_set " << kind; + default: + LOG(FATAL) << "unhandled tvm_struct_set " << kind; } sp -= 2; pc += 3; @@ -511,8 +572,8 @@ void StackVM::Run(State* s) const { size_t nbytes = static_cast(stack[sp - 2].v_int64); int dtype_code_hint = static_cast(stack[sp - 1].v_int64); int dtype_bits_hint = static_cast(stack[sp].v_int64); - void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes, - dtype_code_hint, dtype_bits_hint); + void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, + dtype_bits_hint); stack[sp - 4].v_handle = ptr; sp = sp - 4; pc = pc + 1; @@ -543,8 +604,7 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const { // allow race write in this, since write is idempotent PackedFunc& f = extern_func_cache_[fid]; if (f == nullptr) { - CHECK(s->mod_ctx != nullptr) - << "No local context is set in stackvm"; + CHECK(s->mod_ctx != nullptr) << "No local context is set in stackvm"; const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]); CHECK(pf != nullptr); f = *pf; diff --git a/src/runtime/stackvm/stackvm.h b/src/runtime/stackvm/stackvm.h index f36e171..09581a6 100644 --- a/src/runtime/stackvm/stackvm.h +++ b/src/runtime/stackvm/stackvm.h @@ -29,8 +29,9 @@ #define TVM_RUNTIME_STACKVM_STACKVM_H_ #include -#include #include +#include + #include #include @@ -339,7 +340,7 @@ class StackVM { * \param pc The pc * \return the pc to next instruction. */ - int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*) + int64_t PrintCode(std::ostream& os, int64_t pc) const; // NOLINT(*) /*! \brief Get thread local state of the stack VM */ static State* ThreadLocalState(); // The code below are programs @@ -362,15 +363,26 @@ class StackVM { */ static OpCode CodeI64ToF64(OpCode code) { switch (code) { - case ADD_I64: return ADD_F64; - case SUB_I64: return SUB_F64; - case MUL_I64: return MUL_F64; - case DIV_I64: return DIV_F64; - case EQ_I64: return EQ_F64; - case LT_I64: return LT_F64; - case LE_I64: return LE_F64; - case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64; - default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64; + case ADD_I64: + return ADD_F64; + case SUB_I64: + return SUB_F64; + case MUL_I64: + return MUL_F64; + case DIV_I64: + return DIV_F64; + case EQ_I64: + return EQ_F64; + case LT_I64: + return LT_F64; + case LE_I64: + return LE_F64; + case MOD_I64: + LOG(FATAL) << "cannot handle mod for float"; + return ADD_F64; + default: + LOG(FATAL) << "cannot handle op " << code; + return ADD_F64; } } /*! @@ -383,16 +395,20 @@ class StackVM { if (t.code == kTVMOpaqueHandle) return ARRAY_LOAD_HANDLE; if (t.code == kDLInt) { switch (t.bits) { - case 32 : return ARRAY_LOAD_INT32; - case 64 : return ARRAY_LOAD_INT64; + case 32: + return ARRAY_LOAD_INT32; + case 64: + return ARRAY_LOAD_INT64; } } else if (t.code == kDLUInt) { switch (t.bits) { - case 32 : return ARRAY_LOAD_UINT32; + case 32: + return ARRAY_LOAD_UINT32; } } else if (t.code == kDLFloat) { switch (t.bits) { - case 64 : return ARRAY_LOAD_FP64; + case 64: + return ARRAY_LOAD_FP64; } } LOG(FATAL) << "Cannot load type " << t; @@ -408,16 +424,20 @@ class StackVM { if (t.code == kTVMOpaqueHandle) return ARRAY_STORE_HANDLE; if (t.code == kDLInt) { switch (t.bits) { - case 32 : return ARRAY_STORE_INT32; - case 64 : return ARRAY_STORE_INT64; + case 32: + return ARRAY_STORE_INT32; + case 64: + return ARRAY_STORE_INT64; } } else if (t.code == kDLUInt) { switch (t.bits) { - case 32 : return ARRAY_STORE_UINT32; + case 32: + return ARRAY_STORE_UINT32; } } else if (t.code == kDLFloat) { switch (t.bits) { - case 64 : return ARRAY_STORE_FP64; + case 64: + return ARRAY_STORE_FP64; } } LOG(FATAL) << "Cannot store type " << t; diff --git a/src/runtime/stackvm/stackvm_module.cc b/src/runtime/stackvm/stackvm_module.cc index 8b30b75..9e1f1f5 100644 --- a/src/runtime/stackvm/stackvm_module.cc +++ b/src/runtime/stackvm/stackvm_module.cc @@ -20,13 +20,16 @@ /*! * \file stackvm_module.cc */ -#include -#include +#include "stackvm_module.h" + #include +#include +#include + #include -#include #include -#include "stackvm_module.h" +#include + #include "../file_util.h" namespace tvm { @@ -34,13 +37,9 @@ namespace runtime { class StackVMModuleNode : public runtime::ModuleNode { public: - const char* type_key() const { - return "stackvm"; - } + const char* type_key() const { return "stackvm"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == runtime::symbol::tvm_module_main) { return GetFunction(entry_func_, sptr_to_self); } @@ -48,9 +47,8 @@ class StackVMModuleNode : public runtime::ModuleNode { if (it == fmap_.end()) return PackedFunc(); const StackVM& vm = it->second; // capture sptr_to_self to keep module node alive. - return PackedFunc([vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - vm.Run(args, this); - }); + return PackedFunc( + [vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { vm.Run(args, this); }); } std::string GetSource(const std::string& format) final { @@ -62,8 +60,7 @@ class StackVMModuleNode : public runtime::ModuleNode { return os.str(); } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string data, mblob; dmlc::MemoryStringStream writer(&data); dmlc::Stream* strm = &writer; @@ -74,8 +71,7 @@ class StackVMModuleNode : public runtime::ModuleNode { strm->Write(num_imports); for (runtime::Module im : imports_) { - CHECK_EQ(im->imports().size(), 0U) - << "Only support simply one-level hierarchy"; + CHECK_EQ(im->imports().size(), 0U) << "Only support simply one-level hierarchy"; std::string tkey = im->type_key(); strm->Write(tkey); LOG(INFO) << "save " << tkey; @@ -85,8 +81,7 @@ class StackVMModuleNode : public runtime::ModuleNode { SaveBinaryToFile(file_name, data); } - static Module Create(std::unordered_map fmap, - std::string entry_func) { + static Module Create(std::unordered_map fmap, std::string entry_func) { auto n = make_object(); n->fmap_ = std::move(fmap); n->entry_func_ = std::move(entry_func); @@ -108,17 +103,14 @@ class StackVMModuleNode : public runtime::ModuleNode { CHECK(strm->Read(&tkey)); std::string fkey = "runtime.module.loadbinary_" + tkey; const PackedFunc* f = Registry::Get(fkey); - CHECK(f != nullptr) - << "Loader of " << tkey << "(" - << fkey << ") is not presented."; + CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented."; Module m = (*f)(static_cast(strm)); n->imports_.emplace_back(std::move(m)); } return Module(n); } - static Module LoadFromFile(std::string file_name, - std::string format) { + static Module LoadFromFile(std::string file_name, std::string format) { std::string data; LoadBinaryFromFile(file_name, &data); dmlc::MemoryStringStream reader(&data); @@ -132,13 +124,12 @@ class StackVMModuleNode : public runtime::ModuleNode { std::string entry_func_; }; -Module StackVMModuleCreate(std::unordered_map fmap, - std::string entry_func) { +Module StackVMModuleCreate(std::unordered_map fmap, std::string entry_func) { return StackVMModuleNode::Create(fmap, entry_func); } TVM_REGISTER_GLOBAL("runtime.module.loadfile_stackvm") -.set_body_typed(StackVMModuleNode::LoadFromFile); + .set_body_typed(StackVMModuleNode::LoadFromFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/stackvm/stackvm_module.h b/src/runtime/stackvm/stackvm_module.h index c84eb6f..6ae4ae4 100644 --- a/src/runtime/stackvm/stackvm_module.h +++ b/src/runtime/stackvm/stackvm_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,8 +25,10 @@ #define TVM_RUNTIME_STACKVM_STACKVM_MODULE_H_ #include + #include #include + #include "stackvm.h" namespace tvm { @@ -38,8 +40,7 @@ namespace runtime { * \param entry_func The entry function name. * \return The created module */ -Module StackVMModuleCreate(std::unordered_map fmap, - std::string entry_func); +Module StackVMModuleCreate(std::unordered_map fmap, std::string entry_func); } // namespace runtime } // namespace tvm diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc index 3eb7b1c..fe29146 100644 --- a/src/runtime/system_library.cc +++ b/src/runtime/system_library.cc @@ -21,10 +21,12 @@ * \file system_library.cc * \brief Create library module that directly get symbol from the system lib. */ -#include -#include #include +#include +#include + #include + #include "library_module.h" namespace tvm { @@ -48,10 +50,8 @@ class SystemLibrary : public Library { std::lock_guard lock(mutex_); auto it = tbl_.find(name); if (it != tbl_.end() && ptr != it->second) { - LOG(WARNING) - << "SystemLib symbol " << name - << " get overriden to a different address " - << ptr << "->" << it->second; + LOG(WARNING) << "SystemLib symbol " << name << " get overriden to a different address " << ptr + << "->" << it->second; } tbl_[name] = ptr; } @@ -68,11 +68,9 @@ class SystemLibrary : public Library { std::unordered_map tbl_; }; -TVM_REGISTER_GLOBAL("runtime.SystemLib") -.set_body_typed([]() { - static auto mod = CreateModuleFromLibrary( - SystemLibrary::Global()); - return mod; +TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_typed([]() { + static auto mod = CreateModuleFromLibrary(SystemLibrary::Global()); + return mod; }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 00f089b..0cc881c 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -21,26 +21,26 @@ * \file thread_pool.cc * \brief Threadpool for multi-threading runtime. */ -#include +#include +#include #include -#include +#include #include +#include #include -#include -#include #if TVM_THREADPOOL_USE_OPENMP #include #endif -#include -#include -#include -#include #include -#include -#include +#include +#include #include #include +#include #include +#include +#include +#include const constexpr int kL1CacheBytes = 64; @@ -69,10 +69,7 @@ constexpr int kSyncStride = 64 / sizeof(std::atomic); class ParallelLauncher { public: // Reset the the task request. - void Init(FTVMParallelLambda flambda, - void* cdata, - int num_task, - bool need_sync) { + void Init(FTVMParallelLambda flambda, void* cdata, int num_task, bool need_sync) { num_pending_.store(num_task); this->cdata = cdata; this->flambda = flambda; @@ -88,17 +85,14 @@ class ParallelLauncher { } if (need_sync) { for (int i = 0; i < num_task; ++i) { - sync_counter_[i * kSyncStride].store( - 0, std::memory_order_relaxed); + sync_counter_[i * kSyncStride].store(0, std::memory_order_relaxed); } this->env.sync_handle = sync_counter_; } else { this->env.sync_handle = nullptr; } } - ~ParallelLauncher() { - delete[] sync_counter_; - } + ~ParallelLauncher() { delete[] sync_counter_; } // Wait n jobs to finish int WaitForJobs() { while (num_pending_.load() != 0) { @@ -122,13 +116,9 @@ class ParallelLauncher { has_error_.store(true); } // Signal that one job has finished. - void SignalJobFinish() { - num_pending_.fetch_sub(1); - } + void SignalJobFinish() { num_pending_.fetch_sub(1); } // Get thread local version of the store. - static ParallelLauncher* ThreadLocal() { - return dmlc::ThreadLocalStore::Get(); - } + static ParallelLauncher* ThreadLocal() { return dmlc::ThreadLocalStore::Get(); } // The parallel lambda FTVMParallelLambda flambda; // The closure data @@ -159,15 +149,9 @@ class SpscTaskQueue { int32_t task_id; }; - SpscTaskQueue() : - buffer_(new Task[kRingSize]), - head_(0), - tail_(0) { - } + SpscTaskQueue() : buffer_(new Task[kRingSize]), head_(0), tail_(0) {} - ~SpscTaskQueue() { - delete[] buffer_; - } + ~SpscTaskQueue() { delete[] buffer_; } /*! * \brief Push a task into the queue and notify the comsumer if it is on wait. @@ -198,9 +182,7 @@ class SpscTaskQueue { } if (pending_.fetch_sub(1) == 0) { std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { - return pending_.load() >= 0 || exit_now_.load(); - }); + cv_.wait(lock, [this] { return pending_.load() >= 0 || exit_now_.load(); }); } if (exit_now_.load(std::memory_order_relaxed)) { return false; @@ -275,7 +257,7 @@ class SpscTaskQueue { // The thread pool class ThreadPool { public: - ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) { + ThreadPool() : num_workers_(tvm::runtime::threading::MaxConcurrency()) { for (int i = 0; i < num_workers_; ++i) { // The SpscTaskQueue only hosts ONE item at a time queues_.emplace_back(std::unique_ptr(new SpscTaskQueue())); @@ -286,8 +268,8 @@ class ThreadPool { } threads_ = std::unique_ptr( new tvm::runtime::threading::ThreadGroup( - num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, - exclude_worker0_ /* include_main_thread */)); + num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, + exclude_worker0_ /* include_main_thread */)); num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_); } ~ThreadPool() { @@ -296,10 +278,7 @@ class ThreadPool { } threads_.reset(); } - int Launch(FTVMParallelLambda flambda, - void* cdata, - int num_task, - int need_sync) { + int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, int need_sync) { ParallelLauncher* launcher = ParallelLauncher::ThreadLocal(); CHECK(!launcher->is_worker) << "Cannot launch parallel job inside worker, consider fuse then parallel"; @@ -332,15 +311,12 @@ class ThreadPool { return res; } - static ThreadPool* ThreadLocal() { - return dmlc::ThreadLocalStore::Get(); - } + static ThreadPool* ThreadLocal() { return dmlc::ThreadLocalStore::Get(); } void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) { // this will also reset the affinity of the ThreadGroup // may use less than the MaxConcurrency number of workers - num_workers_used_ = threads_->Configure(mode, nthreads, - exclude_worker0_); + num_workers_used_ = threads_->Configure(mode, nthreads, exclude_worker0_); // if MaxConcurrency restricted the number of workers (e.g., due to // hyperthreading), respect the restriction num_workers_used_ = std::min(num_workers_, num_workers_used_); @@ -376,33 +352,25 @@ class ThreadPool { std::unique_ptr threads_; }; -TVM_REGISTER_GLOBAL("runtime.config_threadpool") -.set_body([](TVMArgs args, TVMRetValue* rv) { - threading::ThreadGroup::AffinityMode mode =\ - static_cast(\ - static_cast(args[0])); - int nthreads = args[1]; - ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads); +TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRetValue* rv) { + threading::ThreadGroup::AffinityMode mode = + static_cast(static_cast(args[0])); + int nthreads = args[1]; + ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads); }); - } // namespace runtime } // namespace tvm - -int TVMBackendParallelLaunch( - FTVMParallelLambda flambda, - void* cdata, - int num_task) { +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { #if !TVM_THREADPOOL_USE_OPENMP - int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch( - flambda, cdata, num_task, 1); + int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1); return res; #else int num_workers = tvm::runtime::threading::MaxConcurrency(); if (num_task == 0) num_task = num_workers; omp_set_num_threads(num_workers); - #pragma omp parallel num_threads(num_workers) +#pragma omp parallel num_threads(num_workers) { TVMParallelGroupEnv env; env.num_task = num_task; @@ -414,18 +382,15 @@ int TVMBackendParallelLaunch( int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { #if TVM_THREADPOOL_USE_OPENMP - #pragma omp barrier +#pragma omp barrier #else using tvm::runtime::kSyncStride; int num_task = penv->num_task; - std::atomic* sync_counter = - reinterpret_cast*>(penv->sync_handle); - int old_counter = sync_counter[task_id * kSyncStride].fetch_add( - 1, std::memory_order_release); + std::atomic* sync_counter = reinterpret_cast*>(penv->sync_handle); + int old_counter = sync_counter[task_id * kSyncStride].fetch_add(1, std::memory_order_release); for (int i = 0; i < num_task; ++i) { if (i != task_id) { - while (sync_counter[i * kSyncStride].load( - std::memory_order_relaxed) <= old_counter) { + while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <= old_counter) { tvm::runtime::threading::Yield(); } } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 3e6fd78..92e12b5 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,6 +25,7 @@ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #include + #include #include @@ -64,9 +65,12 @@ enum class StorageRank { */ inline StorageRank DefaultStorageRank(int thread_scope_rank) { switch (thread_scope_rank) { - case -1: return StorageRank::kGlobal; - case 0: return StorageRank::kShared; - case 1: return StorageRank::kLocal; + case -1: + return StorageRank::kGlobal; + case 0: + return StorageRank::kShared; + case 1: + return StorageRank::kLocal; default: { LOG(FATAL) << "unknown rank"; return StorageRank::kGlobal; @@ -84,20 +88,27 @@ struct StorageScope { inline bool operator==(const StorageScope& other) const { return rank == other.rank && tag == other.tag; } - inline bool operator!=(const StorageScope& other) const { - return !(*this == other); - } + inline bool operator!=(const StorageScope& other) const { return !(*this == other); } inline std::string to_string() const { std::string ret; switch (rank) { - case StorageRank::kGlobal: return "global" + tag; - case StorageRank::kShared: return "shared" + tag; - case StorageRank::kWarp: return "warp" + tag; - case StorageRank::kLocal: return "local" + tag; - case StorageRank::kWMMAMatrixA: return "wmma.matrix_a" + tag; - case StorageRank::kWMMAMatrixB: return "wmma.matrix_b" + tag; - case StorageRank::kWMMAAccumulator: return "wmma.accumulator" + tag; - default: LOG(FATAL) << "unknown storage scope"; return ""; + case StorageRank::kGlobal: + return "global" + tag; + case StorageRank::kShared: + return "shared" + tag; + case StorageRank::kWarp: + return "warp" + tag; + case StorageRank::kLocal: + return "local" + tag; + case StorageRank::kWMMAMatrixA: + return "wmma.matrix_a" + tag; + case StorageRank::kWMMAMatrixB: + return "wmma.matrix_b" + tag; + case StorageRank::kWMMAAccumulator: + return "wmma.accumulator" + tag; + default: + LOG(FATAL) << "unknown storage scope"; + return ""; } } /*! @@ -107,7 +118,7 @@ struct StorageScope { */ static StorageScope make(const std::string& s) { StorageScope r; - if (s.compare(0, 6, "global") == 0) { + if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; r.tag = s.substr(6, std::string::npos); } else if (s.compare(0, 6, "shared") == 0) { @@ -165,7 +176,6 @@ struct ThreadScope { } }; - /*! \brief workload specification */ struct ThreadWorkLoad { // array, first three are thread configuration. @@ -174,22 +184,17 @@ struct ThreadWorkLoad { * \param i The block dimension. * \return i-th block dim */ - inline size_t block_dim(size_t i) const { - return work_size[i + 3]; - } + inline size_t block_dim(size_t i) const { return work_size[i + 3]; } /*! * \param i The grid dimension. * \return i-th grid dim */ - inline size_t grid_dim(size_t i) const { - return work_size[i]; - } + inline size_t grid_dim(size_t i) const { return work_size[i]; } }; /*! \brief Thread axis configuration */ class ThreadAxisConfig { public: - void Init(size_t base, - const std::vector& thread_axis_tags) { + void Init(size_t base, const std::vector& thread_axis_tags) { base_ = base; std::vector filled(6, false); for (size_t i = 0; i < thread_axis_tags.size(); ++i) { @@ -210,15 +215,12 @@ class ThreadAxisConfig { ThreadWorkLoad w; std::fill(w.work_size, w.work_size + 6, 1); for (size_t i = 0; i < arg_index_map_.size(); ++i) { - w.work_size[arg_index_map_[i]] = - static_cast(x.values[base_ + i].v_int64); + w.work_size[arg_index_map_[i]] = static_cast(x.values[base_ + i].v_int64); } return w; } // return the work dim - size_t work_dim() const { - return work_dim_; - } + size_t work_dim() const { return work_dim_; } private: /*! \brief base axis */ diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index 0a2a60c..2e781ea 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,10 +21,11 @@ * \file threading_backend.cc * \brief Native threading backend */ -#include #include -#include +#include + #include +#include #if defined(__linux__) || defined(__ANDROID__) #include #include @@ -40,12 +41,9 @@ namespace threading { class ThreadGroup::Impl { public: - Impl(int num_workers, - std::function worker_callback, - bool exclude_worker0) + Impl(int num_workers, std::function worker_callback, bool exclude_worker0) : num_workers_(num_workers) { - CHECK_GE(num_workers, 1) - << "Requested a non-positive number of worker threads."; + CHECK_GE(num_workers, 1) << "Requested a non-positive number of worker threads."; for (int i = exclude_worker0; i < num_workers_; ++i) { threads_.emplace_back([worker_callback, i] { worker_callback(i); }); } @@ -79,15 +77,14 @@ class ThreadGroup::Impl { // ones. num_workers_used = std::min(num_workers_, num_workers_used); - const char *val = getenv("TVM_BIND_THREADS"); + const char* val = getenv("TVM_BIND_THREADS"); if (val == nullptr || atoi(val) == 1) { // Do not set affinity if there are more workers than found cores if (sorted_order_.size() >= static_cast(num_workers_)) { - SetAffinity(exclude_worker0, mode == kLittle); + SetAffinity(exclude_worker0, mode == kLittle); } else { - LOG(WARNING) - << "The thread affinity cannot be set when the number of workers" - << "is larger than the number of available cores in the system."; + LOG(WARNING) << "The thread affinity cannot be set when the number of workers" + << "is larger than the number of available cores in the system."; } } return num_workers_used; @@ -101,15 +98,14 @@ class ThreadGroup::Impl { #if defined(__ANDROID__) #ifndef CPU_SET #define CPU_SETSIZE 1024 -#define __NCPUBITS (8 * sizeof (uint64_t)) +#define __NCPUBITS (8 * sizeof(uint64_t)) typedef struct { uint64_t __bits[CPU_SETSIZE / __NCPUBITS]; } cpu_set_t; #define CPU_SET(cpu, cpusetp) \ - ((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) -#define CPU_ZERO(cpusetp) \ - memset((cpusetp), 0, sizeof(cpu_set_t)) + ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) +#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t)) #endif #endif #if defined(__linux__) || defined(__ANDROID__) @@ -128,8 +124,7 @@ class ThreadGroup::Impl { #if defined(__ANDROID__) sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset); #else - pthread_setaffinity_np(threads_[i].native_handle(), - sizeof(cpu_set_t), &cpuset); + pthread_setaffinity_np(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset); #endif } if (exclude_worker0) { // master thread run task @@ -182,27 +177,27 @@ class ThreadGroup::Impl { void InitSortedOrder() { unsigned int threads = std::thread::hardware_concurrency(); - std::vector > max_freqs; + std::vector > max_freqs; for (unsigned int i = 0; i < threads; ++i) { int64_t cur_freq = 0; - #if defined(__linux__) || defined(__ANDROID__) - std::ostringstream filepath; - filepath << "/sys/devices/system/cpu/cpu" << i << "/cpufreq/cpuinfo_max_freq"; - std::ifstream ifs(filepath.str()); - if (!ifs.fail()) { - if (!(ifs >> cur_freq)) { - cur_freq = -1; - } - ifs.close(); +#if defined(__linux__) || defined(__ANDROID__) + std::ostringstream filepath; + filepath << "/sys/devices/system/cpu/cpu" << i << "/cpufreq/cpuinfo_max_freq"; + std::ifstream ifs(filepath.str()); + if (!ifs.fail()) { + if (!(ifs >> cur_freq)) { + cur_freq = -1; } - #endif + ifs.close(); + } +#endif max_freqs.push_back(std::make_pair(i, cur_freq)); } - auto fcmpbyfreq = [] (const std::pair &a, - const std::pair &b) { - return a.second == b.second ? a.first < b.first : a.second > b.second; + auto fcmpbyfreq = [](const std::pair& a, + const std::pair& b) { + return a.second == b.second ? a.first < b.first : a.second > b.second; }; std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq); int64_t big_freq = max_freqs.begin()->second; @@ -228,10 +223,9 @@ class ThreadGroup::Impl { int little_count_ = 0; }; -ThreadGroup::ThreadGroup(int num_workers, - std::function worker_callback, +ThreadGroup::ThreadGroup(int num_workers, std::function worker_callback, bool exclude_worker0) - : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {} + : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {} ThreadGroup::~ThreadGroup() { delete impl_; } void ThreadGroup::Join() { impl_->Join(); } @@ -239,13 +233,11 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0 return impl_->Configure(mode, nthreads, exclude_worker0); } -void Yield() { - std::this_thread::yield(); -} +void Yield() { std::this_thread::yield(); } int MaxConcurrency() { int max_concurrency = 1; - const char *val = getenv("TVM_NUM_THREADS"); + const char* val = getenv("TVM_NUM_THREADS"); if (val == nullptr) { val = getenv("OMP_NUM_THREADS"); } @@ -271,7 +263,6 @@ int MaxConcurrency() { return std::max(max_concurrency, 1); } - } // namespace threading } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index c2036da..c72e70f 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -28,9 +28,9 @@ #include #include -#include -#include #include +#include +#include #include #include #include @@ -50,24 +50,17 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr); // Helper to deserialize a serialized vm instruction. Instruction DeserializeInstruction(const VMInstructionSerializer& instr); -PackedFunc Executable::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "get_lib") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLib(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); }); } else if (name == "get_bytecode") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetBytecode(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetBytecode(); }); } else if (name == "get_stats") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Stats(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); } else if (name == "save") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Save(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Save(); }); } else if (name == "get_function_arity") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0]; @@ -172,7 +165,8 @@ std::string Executable::Stats() const { // Get the number of globals and the name of each of them. oss << " Globals (#" << global_map.size() << "): ["; for (const auto& it : global_map) { - oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; + oss << "(\"" << it.first << "\", " << it.second << ")" + << ", "; } if (!global_map.empty()) oss.seekp(-2, oss.cur); oss << "]" << std::endl; @@ -232,8 +226,7 @@ TVMByteArray Executable::Save() { void Executable::SaveGlobalSection(dmlc::Stream* strm) { std::vector > globals(this->global_map.begin(), this->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { + auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; std::sort(globals.begin(), globals.end(), comp); @@ -364,8 +357,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); // Save the fields. - fields.insert(fields.end(), instr.datatype_fields, - instr.datatype_fields + instr.num_fields); + fields.insert(fields.end(), instr.datatype_fields, instr.datatype_fields + instr.num_fields); break; } case Opcode::AllocClosure: { @@ -373,15 +365,12 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); // Save the free vars. - fields.insert(fields.end(), instr.free_vars, - instr.free_vars + instr.num_freevar); + fields.insert(fields.end(), instr.free_vars, instr.free_vars + instr.num_freevar); break; } case Opcode::If: { // Number of fields = 4 - fields.assign({instr.if_op.test, - instr.if_op.target, - instr.if_op.true_offset, + fields.assign({instr.if_op.test, instr.if_op.target, instr.if_op.true_offset, instr.if_op.false_offset}); break; } @@ -399,8 +388,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.closure, instr.num_closure_args, instr.dst}); // Save the args. - fields.insert(fields.end(), instr.closure_args, - instr.closure_args + instr.num_closure_args); + fields.insert(fields.end(), instr.closure_args, instr.closure_args + instr.num_closure_args); break; } case Opcode::LoadConst: { @@ -441,9 +429,7 @@ void Executable::SaveCodeSection(dmlc::Stream* strm) { strm->Write(static_cast(this->functions.size())); for (const auto& func : this->functions) { // Save the function info. - VMFunctionSerializer func_format(func.name, - func.register_file_size, - func.instructions.size(), + VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), func.params); func_format.Save(strm); @@ -523,8 +509,7 @@ void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { // Extract the `cnt` number of fields started at `start` from the list // `instr_fields`. -inline std::vector ExtractFields(const std::vector& instr_fields, - Index start, +inline std::vector ExtractFields(const std::vector& instr_fields, Index start, Index cnt) { CHECK_LE(static_cast(start + cnt), instr_fields.size()); std::vector ret; @@ -634,11 +619,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { RegName dst = instr.fields[5]; - return Instruction::AllocStorage( - allocation_size, - alignment, - dtype, - dst); + return Instruction::AllocStorage(allocation_size, alignment, dtype, dst); } case Opcode::If: { // Number of fields = 4 @@ -727,9 +708,7 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } // Create the VM function. - VMFunction vm_func = VMFunction(loaded_func.name, - loaded_func.params, - instructions, + VMFunction vm_func = VMFunction(loaded_func.name, loaded_func.params, instructions, loaded_func.register_file_size); auto it = this->global_map.find(loaded_func.name); CHECK(it != this->global_map.end()); @@ -738,24 +717,21 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } } -TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); *rv = static_cast(exec->global_map.size()); }); -TVM_REGISTER_GLOBAL("runtime.GetGlobalFields") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetGlobalFields").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); int idx = args[1]; std::vector > globals(exec->global_map.begin(), exec->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { + auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; std::sort(globals.begin(), globals.end(), comp); @@ -763,17 +739,14 @@ TVM_REGISTER_GLOBAL("runtime.GetGlobalFields") *rv = globals[idx].first; }); -TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); *rv = static_cast(exec->primitive_map.size()); }); - -TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); @@ -790,11 +763,9 @@ TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields") }); TVM_REGISTER_GLOBAL("runtime.Load_Executable") -.set_body_typed([]( - std::string code, - runtime::Module lib) { - return Executable::Load(code, lib); -}); + .set_body_typed([](std::string code, runtime::Module lib) { + return Executable::Load(code, lib); + }); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index 3e6140e..c0fd441 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -21,9 +21,11 @@ * \file tvm/runtime/vm/memory_manager.cc * \brief Allocate and manage memory for the runtime. */ -#include -#include #include "memory_manager.h" + +#include +#include + #include "naive_allocator.h" #include "pooled_allocator.h" @@ -35,8 +37,7 @@ static void BufferDeleter(Object* obj) { auto* ptr = static_cast(obj); CHECK(ptr->manager_ctx != nullptr); Buffer* buffer = reinterpret_cast(ptr->manager_ctx); - MemoryManager::Global()->GetAllocator(buffer->ctx)-> - Free(*(buffer)); + MemoryManager::Global()->GetAllocator(buffer->ctx)->Free(*(buffer)); delete buffer; delete ptr; } @@ -93,7 +94,7 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDa // RAII in effect, now run the check. // TODO(@jroesch): generalize later to non-overlapping allocations. CHECK(needed_size == this->buffer.size) - << "size mistmatch required " << needed_size << " found " << this->buffer.size; + << "size mistmatch required " << needed_size << " found " << this->buffer.size; return ret; } @@ -106,8 +107,8 @@ MemoryManager* MemoryManager::Global() { Allocator* MemoryManager::GetAllocator(TVMContext ctx) { std::lock_guard lock(mu_); if (allocators_.find(ctx) == allocators_.end()) { - DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" - << ctx.device_id << ")"; + DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" << ctx.device_id + << ")"; std::unique_ptr alloc(new NaiveAllocator(ctx)); allocators_.emplace(ctx, std::move(alloc)); } @@ -120,7 +121,7 @@ NDArray Allocator::Empty(std::vector shape, DLDataType dtype, DLContext container->SetDeleter(BufferDeleter); size_t size = GetDataSize(container->dl_tensor); size_t alignment = GetDataAlignment(container->dl_tensor); - Buffer *buffer = new Buffer; + Buffer* buffer = new Buffer; *buffer = this->Alloc(size, alignment, dtype); container->manager_ctx = reinterpret_cast(buffer); container->dl_tensor.data = buffer->data; diff --git a/src/runtime/vm/memory_manager.h b/src/runtime/vm/memory_manager.h index b445352..f59d584 100644 --- a/src/runtime/vm/memory_manager.h +++ b/src/runtime/vm/memory_manager.h @@ -27,6 +27,7 @@ #include #include #include + #include #include #include @@ -73,15 +74,13 @@ class Allocator { * \param ctx The context where the array is allocated. * \return The empty NDArray. */ - NDArray Empty(std::vector shape, - DLDataType dtype, - DLContext ctx); + NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); /*! \brief Allocate a buffer given a size, alignment and type. * \param nbytes The size of the buffer. * \param alignment The alignment of the buffer. * \param type_hint A type hint to the allocator. * \return A sized allocation in the form of a buffer. - */ + */ virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0; /*! \brief Free a buffer allocated by the allocator. * \param buffer The buffer to free. @@ -115,9 +114,7 @@ class StorageObj : public Object { Buffer buffer; /*! \brief Allocate an NDArray from a given piece of storage. */ - NDArray AllocNDArray(size_t offset, - std::vector shape, - DLDataType dtype); + NDArray AllocNDArray(size_t offset, std::vector shape, DLDataType dtype); /*! \brief The deleter for an NDArray when allocated from underlying storage. */ static void Deleter(Object* ptr); diff --git a/src/runtime/vm/naive_allocator.h b/src/runtime/vm/naive_allocator.h index db47a62..5ac2ca6 100644 --- a/src/runtime/vm/naive_allocator.h +++ b/src/runtime/vm/naive_allocator.h @@ -24,6 +24,7 @@ #define TVM_RUNTIME_VM_NAIVE_ALLOCATOR_H_ #include + #include #include "memory_manager.h" @@ -52,9 +53,7 @@ class NaiveAllocator final : public Allocator { DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B"; } - size_t UsedMemory() const override { - return used_memory_.load(std::memory_order_relaxed); - } + size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); } private: std::atomic used_memory_; diff --git a/src/runtime/vm/pooled_allocator.h b/src/runtime/vm/pooled_allocator.h index 5965a4e..e09628f 100644 --- a/src/runtime/vm/pooled_allocator.h +++ b/src/runtime/vm/pooled_allocator.h @@ -24,6 +24,7 @@ #define TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_ #include + #include #include #include diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 4dac66e..6e4682d 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -22,6 +22,8 @@ * \brief The Relay debug virtual machine. */ +#include "vm.h" + #include #include @@ -34,27 +36,24 @@ #include #include -#include "vm.h" - namespace tvm { namespace runtime { namespace vm { -PackedFunc VirtualMachineDebug::GetFunction( - const std::string& name, const ObjectPtr& sptr_to_self) { +PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { if (name == "get_stat") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.size(), 1U); std::vector> op_acc_time; for (auto kv : op_durations_) { - auto val = std::make_pair( - kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0)); + auto val = + std::make_pair(kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0)); op_acc_time.push_back(val); } bool sort_by_time = args[0]; if (sort_by_time) { - auto comp = [](const std::pair& lhs, - const std::pair& rhs) { + auto comp = [](const std::pair& lhs, const std::pair& rhs) { return lhs.second > rhs.second; }; std::sort(op_acc_time.begin(), op_acc_time.end(), comp); @@ -74,9 +73,9 @@ PackedFunc VirtualMachineDebug::GetFunction( auto min_value = *std::min_element(vals.begin(), vals.end()); auto max_value = *std::max_element(vals.begin(), vals.end()); - os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t" - << std::setw(10) << std::left << op_invokes_[kv.first] << "\t" - << sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl; + os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t" << std::setw(10) + << std::left << op_invokes_[kv.first] << "\t" << sum << "/" << mean << "/" << min_value + << "/" << max_value << std::endl; total_duration += sum; total_packed_funcs += op_invokes_[kv.first]; @@ -104,10 +103,8 @@ void VirtualMachineDebug::LoadExecutable(const Executable* exec) { } } -void VirtualMachineDebug::InvokePacked(Index packed_index, - const PackedFunc& func, Index arg_count, - Index output_size, - const std::vector& args) { +void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, + Index output_size, const std::vector& args) { CHECK(exec_); auto ctx = this->GetParamsContext(); // warmup @@ -119,9 +116,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); auto op_end = std::chrono::high_resolution_clock::now(); double op_duration = - std::chrono::duration_cast >(op_end - - op_begin) - .count(); + std::chrono::duration_cast>(op_end - op_begin).count(); op_durations_[packed_index].push_back(op_duration * 1e6); op_invokes_[packed_index] += 1; @@ -133,8 +128,7 @@ runtime::Module CreateVirtualMachineDebug(const Executable* exec) { return runtime::Module(vm); } -TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec) << "Virtual machine has not been defined yet." diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index f0a407f..c286828 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -40,16 +40,15 @@ class VirtualMachineDebug : public VirtualMachine { public: VirtualMachineDebug() : VirtualMachine() {} - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; void LoadExecutable(const Executable* exec) final; ~VirtualMachineDebug() {} private: - void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, - Index output_size, const std::vector& args) final; + void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, + const std::vector& args) final; std::unordered_map packed_index_map_; std::unordered_map> op_durations_; diff --git a/src/runtime/vm/serialize_util.h b/src/runtime/vm/serialize_util.h index 3423f7a..8bd1f86 100644 --- a/src/runtime/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -60,9 +60,7 @@ struct VMFunctionSerializer { VMFunctionSerializer() = default; - VMFunctionSerializer(const std::string& name, - Index register_file_size, - size_t num_instructions, + VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, const std::vector& params) : name(name), register_file_size(register_file_size), @@ -87,7 +85,7 @@ struct VMFunctionSerializer { } /*! - * \brief Save the VM function header into the serialized form. + * \brief Save the VM function header into the serialized form. * \param strm The stream used to save data. */ void Save(dmlc::Stream* strm) const { @@ -108,11 +106,11 @@ struct VMInstructionSerializer { VMInstructionSerializer() = default; - VMInstructionSerializer(Index opcode, const std::vector& fields) : - opcode(opcode), fields(fields) {} + VMInstructionSerializer(Index opcode, const std::vector& fields) + : opcode(opcode), fields(fields) {} /*! - * \brief Compute the hash of the serialized instruction. + * \brief Compute the hash of the serialized instruction. * \return The hash that combines the opcode and all fields of the VM * instruction. */ @@ -139,13 +137,12 @@ struct VMInstructionSerializer { } Index hash = Hash(); - CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " - << opcode << "\n"; + CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " << opcode << "\n"; return true; } /*! - * \brief Save the instruction into the serialized form. + * \brief Save the instruction into the serialized form. * \param strm The stream used to save data. */ void Save(dmlc::Stream* strm) const { diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index fedbbe9..0714709 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -23,11 +23,11 @@ */ #include -#include #include -#include #include #include +#include +#include #include #include @@ -56,8 +56,7 @@ inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint // We could put cache in here, from ctx to storage allocator. auto storage_obj = SimpleObjAllocator().make_object(); auto alloc = MemoryManager::Global()->GetAllocator(ctx); - DCHECK(alloc != nullptr) - << "allocator must not null"; + DCHECK(alloc != nullptr) << "allocator must not null"; storage_obj->buffer = alloc->Alloc(size, alignment, dtype_hint); return Storage(storage_obj); } @@ -87,8 +86,8 @@ Instruction::Instruction(const Instruction& instr) { case Opcode::AllocTensor: this->alloc_tensor.storage = instr.alloc_tensor.storage; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; - this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, - instr.alloc_tensor.ndim); + this->alloc_tensor.shape = + Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); this->alloc_tensor.dtype = instr.alloc_tensor.dtype; return; case Opcode::AllocTensorReg: @@ -151,7 +150,7 @@ Instruction::Instruction(const Instruction& instr) { } } -template +template static inline void FreeIf(T* t) { if (t != nullptr) { delete t; @@ -177,8 +176,8 @@ Instruction& Instruction::operator=(const Instruction& instr) { case Opcode::AllocTensor: this->alloc_tensor.storage = instr.alloc_tensor.storage; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; - this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, - instr.alloc_tensor.ndim); + this->alloc_tensor.shape = + Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); this->alloc_tensor.dtype = instr.alloc_tensor.dtype; return *this; case Opcode::AllocTensorReg: @@ -294,9 +293,7 @@ Instruction Instruction::Fatal() { return instr; } -Instruction Instruction::InvokePacked(Index packed_index, - Index arity, - Index output_size, +Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size, const std::vector& args) { Instruction instr; instr.op = Opcode::InvokePacked; @@ -310,10 +307,8 @@ Instruction Instruction::InvokePacked(Index packed_index, return instr; } -Instruction Instruction::AllocTensor( - RegName storage, - const std::vector& shape, - DLDataType dtype, Index dst) { +Instruction Instruction::AllocTensor(RegName storage, const std::vector& shape, + DLDataType dtype, Index dst) { Instruction instr; instr.op = Opcode::AllocTensor; instr.dst = dst; @@ -327,10 +322,8 @@ Instruction Instruction::AllocTensor( return instr; } -Instruction Instruction::AllocTensorReg( - RegName storage, - RegName shape_register, - DLDataType dtype, Index dst) { +Instruction Instruction::AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, + Index dst) { Instruction instr; instr.op = Opcode::AllocTensorReg; instr.dst = dst; @@ -340,9 +333,7 @@ Instruction Instruction::AllocTensorReg( return instr; } -Instruction Instruction::AllocStorage(RegName size, - Index alignment, - DLDataType dtype_hint, +Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, Index dst) { Instruction instr; instr.op = Opcode::AllocStorage; @@ -354,7 +345,7 @@ Instruction Instruction::AllocStorage(RegName size, } Instruction Instruction::AllocADT(Index tag, Index num_fields, - const std::vector& datatype_fields, Index dst) { + const std::vector& datatype_fields, Index dst) { Instruction instr; instr.op = Opcode::AllocADT; instr.dst = dst; @@ -486,7 +477,7 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) { } } -template +template std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") { if (cnt == 0) { return ""; @@ -515,26 +506,21 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { } case Opcode::InvokePacked: { os << "invoke_packed PackedFunc[" << instr.packed_index << "] (in: $" - << StrJoin(instr.packed_args, 0, - instr.arity - instr.output_size, ", $") + << StrJoin(instr.packed_args, 0, instr.arity - instr.output_size, ", $") << ", out: $" - << StrJoin(instr.packed_args, instr.arity - instr.output_size, - instr.output_size, ", $") + << StrJoin(instr.packed_args, instr.arity - instr.output_size, instr.output_size, + ", $") << ")"; break; } case Opcode::AllocTensor: { - os << "alloc_tensor $" << instr.dst << " $" - << instr.alloc_tensor.storage << " [" - << StrJoin(instr.alloc_tensor.shape, 0, - instr.alloc_tensor.ndim) - << "] "; + os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " [" + << StrJoin(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) << "] "; DLDatatypePrint(os, instr.alloc_tensor.dtype); break; } case Opcode::AllocTensorReg: { - os << "alloc_tensor_reg $" << instr.dst << " $" - << instr.alloc_tensor_reg.storage << " $" + os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $" << instr.alloc_tensor_reg.shape_register << " "; DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; @@ -545,26 +531,24 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::AllocClosure: { - os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index - << "]($" << StrJoin(instr.free_vars, 0, instr.num_freevar, ",$") - << ")"; + os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index << "]($" + << StrJoin(instr.free_vars, 0, instr.num_freevar, ",$") << ")"; break; } case Opcode::If: { - os << "if " << "$" << instr.if_op.test << " $" << instr.if_op.target << " " - << instr.if_op.true_offset << " " << instr.if_op.false_offset; + os << "if " + << "$" << instr.if_op.test << " $" << instr.if_op.target << " " << instr.if_op.true_offset + << " " << instr.if_op.false_offset; break; } case Opcode::Invoke: { os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($" - << StrJoin(instr.invoke_args_registers, 0, instr.num_args, ",$") - << ")"; + << StrJoin(instr.invoke_args_registers, 0, instr.num_args, ",$") << ")"; break; } case Opcode::InvokeClosure: { os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($" - << StrJoin(instr.closure_args, 0, instr.num_closure_args, ",$") - << ")"; + << StrJoin(instr.closure_args, 0, instr.num_closure_args, ",$") << ")"; break; } case Opcode::LoadConst: { @@ -576,8 +560,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::GetField: { - os << "get_field $" << instr.dst << " $" << instr.object << "[" - << instr.field_index << "]"; + os << "get_field $" << instr.dst << " $" << instr.object << "[" << instr.field_index << "]"; break; } case Opcode::GetTag: { @@ -589,11 +572,9 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::AllocStorage: { - os << "alloc_storage $" << - instr.dst << " $" << - instr.alloc_storage.allocation_size << " $" << - instr.alloc_storage.alignment << " " << - DLDataType2String(instr.alloc_storage.dtype_hint); + os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " $" + << instr.alloc_storage.alignment << " " + << DLDataType2String(instr.alloc_storage.dtype_hint); break; } default: @@ -637,14 +618,14 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, std::string func_name = args[0]; auto git = exec_->global_map.find(func_name); CHECK(git != exec_->global_map.end()) - << "Cannot find function " << func_name << " in the executable"; + << "Cannot find function " << func_name << " in the executable"; auto func = exec_->functions[git->second]; if (func.params.empty()) { *rv = Invoke(func, {}); } else { auto it = inputs_.find(func_name); CHECK(it != inputs_.end()) << "Input has not been set for function " << func_name; - const std::vector &func_args = it->second; + const std::vector& func_args = it->second; *rv = Invoke(func, func_args); } }); @@ -672,8 +653,8 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, const auto& param_names = vm_func.params; // TODO(icemelon9): For heterogeneous execution, get input device information TVMContext ctx = ctxs_[0]; - CHECK_EQ(args.size() - 1, param_names.size()) << - "The number of provided parameters doesn't match the number of arguments"; + CHECK_EQ(args.size() - 1, param_names.size()) + << "The number of provided parameters doesn't match the number of arguments"; std::vector func_args(param_names.size()); for (int i = 1; i < args.size(); ++i) { ObjectRef obj = CopyTo(args[i], ctx); @@ -745,16 +726,14 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector& args) { CHECK(exec_) << "The executable has not been created yet."; auto it = exec_->global_map.find(name); - CHECK(it != exec_->global_map.end()) - << "Cannot find function " << name << " in the executable"; + CHECK(it != exec_->global_map.end()) << "Cannot find function " << name << " in the executable"; auto func_index_ = it->second; DLOG(INFO) << "Invoke Global " << name << " at index " << func_index_; return Invoke(exec_->functions[func_index_], args); } -void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, - Index arg_count, Index output_size, - const std::vector& args) { +void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, + Index output_size, const std::vector& args) { size_t arity = 0; for (Index i = 0; i < arg_count; i++) { if (const auto* obj = args[i].as()) { @@ -806,10 +785,7 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { } } - -void VirtualMachine::Init(const std::vector& ctxs) { - ctxs_ = ctxs; -} +void VirtualMachine::Init(const std::vector& ctxs) { ctxs_ = ctxs; } inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames_.back().register_file[r] = val; @@ -893,13 +869,13 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::InvokePacked: { - DLOG(INFO) << "InvokedPacked " << "arity=" << instr.arity; + DLOG(INFO) << "InvokedPacked " + << "arity=" << instr.arity; const auto& func = packed_funcs_[instr.packed_index]; const auto& arity = instr.arity; std::vector args; for (Index i = 0; i < arity; ++i) { - DLOG(INFO) << - "arg" << i << " $" << instr.packed_args[i]; + DLOG(INFO) << "arg" << i << " $" << instr.packed_args[i]; auto arg = ReadRegister(instr.packed_args[i]); args.push_back(arg); } @@ -1022,10 +998,8 @@ void VirtualMachine::RunLoop() { auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = LoadScalarInt(instr.alloc_storage.alignment); - DLOG(INFO) << - "AllocStorage: allocation_size=" << size << - "alignment=" << alignment << - "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint); + DLOG(INFO) << "AllocStorage: allocation_size=" << size << "alignment=" << alignment + << "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint); auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]); WriteRegister(instr.dst, storage); @@ -1057,8 +1031,7 @@ runtime::Module CreateVirtualMachine(const Executable* exec) { return runtime::Module(vm); } -TVM_REGISTER_GLOBAL("runtime._VirtualMachine") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime._VirtualMachine").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec) << "The virtual machine executable has not been defined yet."; diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 48fbdc7..207a86a 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -17,21 +17,19 @@ * under the License. */ -#include #include #include #include #include +#include #include #include - #include "../file_util.h" #include "../pack_args.h" #include "../thread_storage_scope.h" #include "../workspace_pool.h" - #include "vulkan_common.h" #include "vulkan_module.h" #include "vulkan_shader.h" @@ -117,9 +115,7 @@ class VulkanDeviceAPI final : public DeviceAPI { } void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { const auto& vctx = context(ctx.device_id); VkBufferCreateInfo info; @@ -628,9 +624,8 @@ VulkanDeviceAPI::VulkanDeviceAPI() { #ifdef USE_VULKAN_IMMEDIATE_MODE if (has_extension("VK_KHR_push_descriptor") && has_extension("VK_KHR_descriptor_update_template")) { - ctx.descriptor_template_khr_functions = - std::unique_ptr( - new VulkanDescriptorTemplateKHRFunctions()); + ctx.descriptor_template_khr_functions = std::unique_ptr( + new VulkanDescriptorTemplateKHRFunctions()); ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR = CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr( ctx.device, "vkCreateDescriptorUpdateTemplateKHR")); @@ -672,9 +667,7 @@ class VulkanModuleNode; // a wrapped function class to get packed func. class VulkanWrappedFunc { public: - void Init(VulkanModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, + void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, const std::vector& thread_axis_tags) { m_ = m; @@ -710,13 +703,12 @@ class VulkanWrappedFunc { class VulkanModuleNode final : public runtime::ModuleNode { public: explicit VulkanModuleNode(std::unordered_map smap, - std::unordered_map fmap, std::string source) + std::unordered_map fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) {} const char* type_key() const final { return "vulkan"; } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); @@ -750,10 +742,8 @@ class VulkanModuleNode final : public runtime::ModuleNode { } } - std::shared_ptr GetPipeline( - size_t device_id, - const std::string& func_name, - size_t num_pack_args) { + std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, + size_t num_pack_args) { const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); std::lock_guard lock(mutex_); const auto& cp = ecache_[device_id][func_name]; @@ -1022,8 +1012,7 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { return streams_[device_id].get(); } -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, - const ArgUnion* pack_args) const { +void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id; CHECK_LT(device_id, kVulkanMaxNumDevice); const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 9242d3d..780b111 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -22,8 +22,8 @@ #include #include #include - #include + #include #include #include @@ -140,7 +140,6 @@ struct VulkanContext { bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; } }; - } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_shader.h b/src/runtime/vulkan/vulkan_shader.h index 1b2e454..d56ca61 100644 --- a/src/runtime/vulkan/vulkan_shader.h +++ b/src/runtime/vulkan/vulkan_shader.h @@ -18,7 +18,6 @@ */ #pragma once - #include #include #include diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index 1a24d28..388cacc 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -20,12 +20,11 @@ #include #include -#include #include +#include #include "vulkan_common.h" - namespace tvm { namespace runtime { namespace vulkan { @@ -44,8 +43,7 @@ struct VulkanStreamToken { class VulkanStream { public: - explicit VulkanStream(const VulkanContext* vctx) - : vctx_(vctx), state_(new VulkanStreamState()) { + explicit VulkanStream(const VulkanContext* vctx) : vctx_(vctx), state_(new VulkanStreamState()) { // create command pool VkCommandPoolCreateInfo cmd_pool_cinfo; cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; diff --git a/src/runtime/workspace_pool.cc b/src/runtime/workspace_pool.cc index fc316cd..8ee905e 100644 --- a/src/runtime/workspace_pool.cc +++ b/src/runtime/workspace_pool.cc @@ -21,9 +21,10 @@ * \file workspace_pool.h * \brief Workspace pool utility. */ -#include #include "workspace_pool.h" +#include + namespace tvm { namespace runtime { @@ -67,7 +68,8 @@ class WorkspacePool::Pool { if (free_list_.back().size >= nbytes) { // find smallest fit auto it = free_list_.end() - 2; - for (; it->size >= nbytes; --it) {} + for (; it->size >= nbytes; --it) { + } e = *(it + 1); free_list_.erase(it + 1); } else { @@ -91,7 +93,8 @@ class WorkspacePool::Pool { allocated_.pop_back(); } else { int index = static_cast(allocated_.size()) - 2; - for (; index > 0 && allocated_[index].data != data; --index) {} + for (; index > 0 && allocated_[index].data != data; --index) { + } CHECK_GT(index, 0) << "trying to free things that has not been allocated"; e = allocated_[index]; allocated_.erase(allocated_.begin() + index); @@ -132,8 +135,7 @@ class WorkspacePool::Pool { }; WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr device) - : device_type_(device_type), device_(device) { -} + : device_type_(device_type), device_(device) {} WorkspacePool::~WorkspacePool() { for (size_t i = 0; i < array_.size(); ++i) { @@ -158,8 +160,7 @@ void* WorkspacePool::AllocWorkspace(TVMContext ctx, size_t size) { } void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) { - CHECK(static_cast(ctx.device_id) < array_.size() && - array_[ctx.device_id] != nullptr); + CHECK(static_cast(ctx.device_id) < array_.size() && array_[ctx.device_id] != nullptr); array_[ctx.device_id]->Free(ptr); } diff --git a/src/runtime/workspace_pool.h b/src/runtime/workspace_pool.h index 72613ca..288da7d 100644 --- a/src/runtime/workspace_pool.h +++ b/src/runtime/workspace_pool.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,8 +25,9 @@ #define TVM_RUNTIME_WORKSPACE_POOL_H_ #include -#include + #include +#include namespace tvm { namespace runtime { diff --git a/src/support/arena.h b/src/support/arena.h index b062276..cb08db9 100644 --- a/src/support/arena.h +++ b/src/support/arena.h @@ -31,9 +31,8 @@ #endif #include -#include #include - +#include namespace tvm { namespace support { @@ -74,9 +73,7 @@ class SimplePageAllocator { * \brief De-allocate an allocate page. * \param page The page to be de-allocated. */ - void deallocate(ArenaPageHeader* page) { - delete [] reinterpret_cast(page); - } + void deallocate(ArenaPageHeader* page) { delete[] reinterpret_cast(page); } static const constexpr int kPageSize = 16 << 10; static const constexpr int kPageAlign = 1024; @@ -91,20 +88,17 @@ class SimplePageAllocator { * \brief Arena allocator that allocates memory from continuous * chunk and frees them all only during destruction. */ -template +template class GenericArena { public: - explicit GenericArena(PageAllocator alloc = PageAllocator()) - : alloc_(alloc) { + explicit GenericArena(PageAllocator alloc = PageAllocator()) : alloc_(alloc) { // eagerly allocate the first page. head_ = tail_ = alloc_.allocate(1); head_->next = nullptr; } #if TVM_ARENA_HAS_DESTRUCTOR - ~GenericArena() { - this->FreeAll(); - } + ~GenericArena() { this->FreeAll(); } #endif /*! \brief Free all pages. */ @@ -129,10 +123,9 @@ class GenericArena { * \param count Numberof elements * \note The space of T is not initialized. */ - template + template T* allocate_(int count = 1) { - static_assert(PageAllocator::kPageAlign % alignof(T) == 0, - "To large alignment"); + static_assert(PageAllocator::kPageAlign % alignof(T) == 0, "To large alignment"); return static_cast(Alloc(sizeof(T) * count, alignof(T))); } /*! @@ -146,7 +139,7 @@ class GenericArena { * memory allocated from the same arena. * Otherwise the destructor needs to be called explicitly. */ - template + template T* make(Args&&... args) { T* ptr = allocate_(); new (ptr) T(std::forward(args)...); @@ -183,7 +176,7 @@ class GenericArena { } else { ArenaPageHeader* new_head; offset = UpperAlign(sizeof(ArenaPageHeader), align); - if (free_list_ != nullptr && offset + size <= free_list_-> size) { + if (free_list_ != nullptr && offset + size <= free_list_->size) { new_head = free_list_; free_list_ = free_list_->next; } else { @@ -215,7 +208,7 @@ using Arena = GenericArena; * \brief Link list node * \tparam T the content data type */ -template +template struct LinkNode { /*! \brief The content value */ T value; @@ -228,7 +221,7 @@ struct LinkNode { * \note This is a simple data structure that can be used together with the arena. * \sa LinkNode */ -template +template struct LinkedList { /*! \brief Head pointer */ LinkNode* head{nullptr}; diff --git a/src/support/base64.h b/src/support/base64.h index c85b268..9849542 100644 --- a/src/support/base64.h +++ b/src/support/base64.h @@ -27,7 +27,7 @@ #define TVM_SUPPORT_BASE64_H_ #include -#include + #include #include #include @@ -38,18 +38,16 @@ namespace support { namespace base64 { // decoding table const char DecodeTable[] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 62, // '+' - 0, 0, 0, - 63, // '/' - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' - 0, 0, 0, 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' - 0, 0, 0, 0, 0, 0, - 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, - 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' }; // encoding table static const char EncodeTable[] = @@ -62,14 +60,12 @@ static const char EncodeTable[] = */ class StreamBufferReader { public: - explicit StreamBufferReader(size_t buffer_size) { - buffer_.resize(buffer_size); - } + explicit StreamBufferReader(size_t buffer_size) { buffer_.resize(buffer_size); } /*! * \brief set input stream * \param stream The stream to be set */ - void set_stream(dmlc::Stream *stream) { + void set_stream(dmlc::Stream* stream) { stream_ = stream; read_len_ = read_ptr_ = 1; } @@ -88,13 +84,11 @@ class StreamBufferReader { } } /*! \return whether we are reaching the end of file */ - bool AtEnd() const { - return read_len_ == 0; - } + bool AtEnd() const { return read_len_ == 0; } private: /*! \brief the underlying stream */ - dmlc::Stream *stream_{nullptr}; + dmlc::Stream* stream_{nullptr}; /*! \brief buffer to hold data */ std::string buffer_; /*! \brief length of valid data in buffer */ @@ -106,11 +100,9 @@ class StreamBufferReader { /*! * \brief Input stream from base64 encoding */ -class Base64InStream: public dmlc::Stream { +class Base64InStream : public dmlc::Stream { public: - explicit Base64InStream(dmlc::Stream *fs) : reader_(256) { - reader_.set_stream(fs); - } + explicit Base64InStream(dmlc::Stream* fs) : reader_(256) { reader_.set_stream(fs); } /*! * \brief initialize the stream position to beginning of next base64 stream * \note call this function before actually start read @@ -122,16 +114,14 @@ class Base64InStream: public dmlc::Stream { } while (isspace(temp_ch_)); } /*! \brief whether current position is end of a base64 stream */ - bool IsEOF(void) const { - return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); - } + bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); } // override read function. - virtual size_t Read(void *ptr, size_t size) { + virtual size_t Read(void* ptr, size_t size) { using base64::DecodeTable; if (size == 0) return 0; // use tlen to record left size size_t tlen = size; - unsigned char *cptr = static_cast(ptr); + unsigned char* cptr = static_cast(ptr); // if anything left, load from previous buffered result if (num_prev_ != 0) { if (num_prev_ == 2) { @@ -142,13 +132,16 @@ class Base64InStream: public dmlc::Stream { num_prev_ = 0; } else { // assert tlen == 1 - *cptr++ = buf_prev[0]; --tlen; + *cptr++ = buf_prev[0]; + --tlen; buf_prev[0] = buf_prev[1]; num_prev_ = 1; } } else { // assert num_prev_ == 1 - *cptr++ = buf_prev[0]; --tlen; num_prev_ = 0; + *cptr++ = buf_prev[0]; + --tlen; + num_prev_ = 0; } } if (tlen == 0) return size; @@ -163,8 +156,9 @@ class Base64InStream: public dmlc::Stream { temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; nvalue |= DecodeTable[temp_ch_] << 12; - *cptr++ = (nvalue >> 16) & 0xFF; --tlen; - } + *cptr++ = (nvalue >> 16) & 0xFF; + --tlen; + } { // third byte temp_ch_ = reader_.GetChar(); @@ -174,13 +168,13 @@ class Base64InStream: public dmlc::Stream { temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ == '=') << "invalid base64 format"; temp_ch_ = reader_.GetChar(); - CHECK(temp_ch_ == EOF || isspace(temp_ch_)) - << "invalid base64 format"; + CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_] << 6; if (tlen) { - *cptr++ = (nvalue >> 8) & 0xFF; --tlen; + *cptr++ = (nvalue >> 8) & 0xFF; + --tlen; } else { buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF; } @@ -188,19 +182,18 @@ class Base64InStream: public dmlc::Stream { { // fourth byte temp_ch_ = reader_.GetChar(); - CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) - << "invalid base64 format"; + CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; if (temp_ch_ == '=') { temp_ch_ = reader_.GetChar(); - CHECK(temp_ch_ == EOF || isspace(temp_ch_)) - << "invalid base64 format"; + CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_]; if (tlen) { - *cptr++ = nvalue & 0xFF; --tlen; + *cptr++ = nvalue & 0xFF; + --tlen; } else { - buf_prev[num_prev_ ++] = nvalue & 0xFF; + buf_prev[num_prev_++] = nvalue & 0xFF; } } // get next char @@ -211,7 +204,7 @@ class Base64InStream: public dmlc::Stream { } return size - tlen; } - virtual void Write(const void *ptr, size_t size) { + virtual void Write(const void* ptr, size_t size) { LOG(FATAL) << "Base64InStream do not support write"; } @@ -228,17 +221,17 @@ class Base64InStream: public dmlc::Stream { /*! * \brief Stream to write to base64 format. */ -class Base64OutStream: public dmlc::Stream { +class Base64OutStream : public dmlc::Stream { public: - explicit Base64OutStream(dmlc::Stream *fp) : fp_(fp) { - } - virtual void Write(const void *ptr, size_t size) { + explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {} + virtual void Write(const void* ptr, size_t size) { using base64::EncodeTable; size_t tlen = size; - const unsigned char *cptr = static_cast(ptr); + const unsigned char* cptr = static_cast(ptr); while (tlen) { - while (buf__top_ < 3 && tlen != 0) { - buf_[++buf__top_] = *cptr++; --tlen; + while (buf__top_ < 3 && tlen != 0) { + buf_[++buf__top_] = *cptr++; + --tlen; } if (buf__top_ == 3) { // flush 4 bytes out @@ -250,7 +243,7 @@ class Base64OutStream: public dmlc::Stream { } } } - virtual size_t Read(void *ptr, size_t size) { + virtual size_t Read(void* ptr, size_t size) { LOG(FATAL) << "Base64OutStream do not support read"; return 0; } @@ -280,12 +273,11 @@ class Base64OutStream: public dmlc::Stream { private: static constexpr size_t kBufferSize = 256; - dmlc::Stream *fp_{nullptr}; + dmlc::Stream* fp_{nullptr}; int buf__top_{0}; unsigned char buf_[4]; std::string out_buf_; - void PutChar(char ch) { out_buf_ += ch; if (out_buf_.length() >= kBufferSize) Flush(); diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 622e28e..2b944cb 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -17,15 +17,15 @@ * under the License. */ - /*! +/*! * FFI registration code used for frontend testing purposes. * \file ffi_testing.cc */ -#include -#include -#include #include #include +#include +#include +#include namespace tvm { // Attrs used to python API @@ -36,16 +36,10 @@ struct TestAttrs : public AttrsNode { TypedEnvFunc func; TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") { - TVM_ATTR_FIELD(axis) - .set_default(10) - .set_lower_bound(1) - .set_upper_bound(10) - .describe("axis field"); - TVM_ATTR_FIELD(name) - .describe("name"); - TVM_ATTR_FIELD(padding) - .describe("padding of input") - .set_default(Array({0, 0})); + TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe( + "axis field"); + TVM_ATTR_FIELD(name).describe("name"); + TVM_ATTR_FIELD(padding).describe("padding of input").set_default(Array({0, 0})); TVM_ATTR_FIELD(func) .describe("some random env function") .set_default(TypedEnvFunc(nullptr)); @@ -54,49 +48,37 @@ struct TestAttrs : public AttrsNode { TVM_REGISTER_NODE_TYPE(TestAttrs); -TVM_REGISTER_GLOBAL("testing.nop") -.set_body([](TVMArgs args, TVMRetValue *ret) { - }); +TVM_REGISTER_GLOBAL("testing.nop").set_body([](TVMArgs args, TVMRetValue* ret) {}); -TVM_REGISTER_GLOBAL("testing.echo") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0]; - }); +}); -TVM_REGISTER_GLOBAL("testing.test_wrap_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PackedFunc pf = args[0]; - *ret = runtime::TypedPackedFunc([pf](){ - pf(); - }); - }); +TVM_REGISTER_GLOBAL("testing.test_wrap_callback").set_body([](TVMArgs args, TVMRetValue* ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf]() { pf(); }); +}); TVM_REGISTER_GLOBAL("testing.test_raise_error_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string msg = args[0]; - *ret = runtime::TypedPackedFunc([msg](){ - LOG(FATAL) << msg; - }); - }); - -TVM_REGISTER_GLOBAL("testing.test_check_eq_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string msg = args[0]; - *ret = runtime::TypedPackedFunc([msg](int x, int y){ - CHECK_EQ(x, y) << msg; - }); - }); + .set_body([](TVMArgs args, TVMRetValue* ret) { + std::string msg = args[0]; + *ret = runtime::TypedPackedFunc([msg]() { LOG(FATAL) << msg; }); + }); -TVM_REGISTER_GLOBAL("testing.context_test") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLContext ctx = args[0]; - int dtype = args[1]; - int did = args[2]; - CHECK_EQ(static_cast(ctx.device_type), dtype); - CHECK_EQ(static_cast(ctx.device_id), did); - *ret = ctx; - }); +TVM_REGISTER_GLOBAL("testing.test_check_eq_callback").set_body([](TVMArgs args, TVMRetValue* ret) { + std::string msg = args[0]; + *ret = + runtime::TypedPackedFunc([msg](int x, int y) { CHECK_EQ(x, y) << msg; }); +}); +TVM_REGISTER_GLOBAL("testing.context_test").set_body([](TVMArgs args, TVMRetValue* ret) { + DLContext ctx = args[0]; + int dtype = args[1]; + int did = args[2]; + CHECK_EQ(static_cast(ctx.device_type), dtype); + CHECK_EQ(static_cast(ctx.device_id), did); + *ret = ctx; +}); // in src/api_test.cc void ErrorTest(int x, int y) { @@ -108,15 +90,13 @@ void ErrorTest(int x, int y) { } } -TVM_REGISTER_GLOBAL("testing.ErrorTest") -.set_body_typed(ErrorTest); +TVM_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest); // internal function used for debug and testing purposes -TVM_REGISTER_GLOBAL("testing.object_use_count") -.set_body([](TVMArgs args, TVMRetValue *ret) { - runtime::ObjectRef obj = args[0]; - // substract the current one because we always copy - // and get another value. - *ret = (obj.use_count() - 1); - }); +TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRetValue* ret) { + runtime::ObjectRef obj = args[0]; + // substract the current one because we always copy + // and get another value. + *ret = (obj.use_count() - 1); +}); } // namespace tvm diff --git a/src/support/pipe.h b/src/support/pipe.h index 120bbdb..dcebd0d 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -24,16 +24,17 @@ #ifndef TVM_SUPPORT_PIPE_H_ #define TVM_SUPPORT_PIPE_H_ -#include #include +#include #ifdef _WIN32 #include #else -#include #include -#include +#include + #include +#include #endif namespace tvm { @@ -48,12 +49,9 @@ class Pipe : public dmlc::Stream { using PipeHandle = int; #endif /*! \brief Construct a pipe from system handle. */ - explicit Pipe(int64_t handle) - : handle_(static_cast(handle)) {} + explicit Pipe(int64_t handle) : handle_(static_cast(handle)) {} /*! \brief destructor */ - ~Pipe() { - Flush(); - } + ~Pipe() { Flush(); } using Stream::Read; using Stream::Write; /*! @@ -62,18 +60,16 @@ class Pipe : public dmlc::Stream { * \param size block size * \return the size of data read */ - size_t Read(void *ptr, size_t size) final { + size_t Read(void* ptr, size_t size) final { if (size == 0) return 0; #ifdef _WIN32 DWORD nread; - CHECK(ReadFile(handle_, static_cast(ptr), - &nread, nullptr)) + CHECK(ReadFile(handle_, static_cast(ptr), &nread, nullptr)) << "Read Error: " << GetLastError(); #else ssize_t nread; nread = read(handle_, ptr, size); - CHECK_GE(nread, 0) - << "Write Error: " << strerror(errno); + CHECK_GE(nread, 0) << "Write Error: " << strerror(errno); #endif return static_cast(nread); } @@ -83,19 +79,17 @@ class Pipe : public dmlc::Stream { * \param size block size * \return the size of data read */ - void Write(const void *ptr, size_t size) final { + void Write(const void* ptr, size_t size) final { if (size == 0) return; #ifdef _WIN32 DWORD nwrite; - CHECK(WriteFile(handle_, static_cast(ptr), - &nwrite, nullptr) && + CHECK(WriteFile(handle_, static_cast(ptr), &nwrite, nullptr) && static_cast(nwrite) == size) << "Write Error: " << GetLastError(); #else ssize_t nwrite; nwrite = write(handle_, ptr, size); - CHECK_EQ(static_cast(nwrite), size) - << "Write Error: " << strerror(errno); + CHECK_EQ(static_cast(nwrite), size) << "Write Error: " << strerror(errno); #endif } /*! diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index d3227ad..a393849 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -24,9 +24,9 @@ #ifndef TVM_SUPPORT_RING_BUFFER_H_ #define TVM_SUPPORT_RING_BUFFER_H_ -#include -#include #include +#include +#include namespace tvm { namespace support { @@ -41,13 +41,9 @@ class RingBuffer { /*! \brief constructor */ RingBuffer() : ring_(kInitCapacity) {} /*! \return number of bytes available in buffer. */ - size_t bytes_available() const { - return bytes_available_; - } + size_t bytes_available() const { return bytes_available_; } /*! \return Current capacity of buffer. */ - size_t capacity() const { - return ring_.size(); - } + size_t capacity() const { return ring_.size(); } /*! * Reserve capacity to be at least n. * Will only increase capacity if n is bigger than current capacity. @@ -59,16 +55,15 @@ class RingBuffer { */ void Reserve(size_t n) { if (ring_.size() < n) { - size_t old_size = ring_.size(); - size_t new_size = static_cast(n * 1.2); - ring_.resize(new_size); - if (head_ptr_ + bytes_available_ > old_size) { - // copy the ring overflow part into the tail. - size_t ncopy = head_ptr_ + bytes_available_ - old_size; - memcpy(&ring_[0] + old_size, &ring_[0], ncopy); - } - } else if (ring_.size() > n * 8 && - ring_.size() > kInitCapacity) { + size_t old_size = ring_.size(); + size_t new_size = static_cast(n * 1.2); + ring_.resize(new_size); + if (head_ptr_ + bytes_available_ > old_size) { + // copy the ring overflow part into the tail. + size_t ncopy = head_ptr_ + bytes_available_ - old_size; + memcpy(&ring_[0] + old_size, &ring_[0], ncopy); + } + } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) { // shrink too large temporary buffer to // avoid out of memory on some embedded devices if (bytes_available_ != 0) { @@ -81,7 +76,7 @@ class RingBuffer { bytes_available_ = old_bytes; } // shrink the ring. - size_t new_size = kInitCapacity; + size_t new_size = kInitCapacity; new_size = std::max(new_size, n); new_size = std::max(new_size, bytes_available_); @@ -102,8 +97,7 @@ class RingBuffer { size_t ncopy = std::min(size, ring_.size() - head_ptr_); memcpy(data, &ring_[0] + head_ptr_, ncopy); if (ncopy < size) { - memcpy(reinterpret_cast(data) + ncopy, - &ring_[0], size - ncopy); + memcpy(reinterpret_cast(data) + ncopy, &ring_[0], size - ncopy); } head_ptr_ = (head_ptr_ + size) % ring_.size(); bytes_available_ -= size; @@ -115,7 +109,7 @@ class RingBuffer { * \param max_nbytes Maximum number of bytes can to read. * \tparam FSend A non-blocking function with signature size_t (const void* data, size_t size); */ - template + template size_t ReadWithCallback(FSend fsend, size_t max_nbytes) { size_t size = std::min(max_nbytes, bytes_available_); CHECK_NE(size, 0U); @@ -155,7 +149,7 @@ class RingBuffer { * \param max_nbytes Maximum number of bytes can write. * \tparam FRecv A non-blocking function with signature size_t (void* data, size_t size); */ - template + template size_t WriteWithCallback(FRecv frecv, size_t max_nbytes) { this->Reserve(bytes_available_ + max_nbytes); size_t nbytes = max_nbytes; diff --git a/src/support/socket.h b/src/support/socket.h index aeb4626..3ccfaaa 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -35,26 +35,27 @@ using ssize_t = int; #pragma comment(lib, "Ws2_32.lib") #endif #else +#include +#include #include #include -#include -#include -#include #include -#include -#include #include +#include +#include +#include #endif #include -#include + #include -#include +#include #include +#include + #include "../support/util.h" #if defined(_WIN32) -static inline int poll(struct pollfd *pfd, int nfds, - int timeout) { +static inline int poll(struct pollfd* pfd, int nfds, int timeout) { return WSAPoll(pfd, nfds, timeout); } #else @@ -68,7 +69,8 @@ namespace support { * \return The hostname. */ inline std::string GetHostName() { - std::string buf; buf.resize(256); + std::string buf; + buf.resize(256); CHECK_NE(gethostname(&buf[0], 256), -1); return std::string(buf.c_str()); } @@ -100,16 +102,14 @@ struct SockAddr { * \param url The url of the address * \param port The port of the address. */ - SockAddr(const char *url, int port) { - this->Set(url, port); - } + SockAddr(const char* url, int port) { this->Set(url, port); } /*! - * \brief SockAddr Get the socket address from tracker. - * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) - * \return SockAddr parsed from url. - */ - explicit SockAddr(const std::string &url) { + * \brief SockAddr Get the socket address from tracker. + * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) + * \return SockAddr parsed from url. + */ + explicit SockAddr(const std::string& url) { size_t sep = url.find(","); std::string host = url.substr(2, sep - 3); std::string port = url.substr(sep + 1, url.length() - 1); @@ -125,31 +125,28 @@ struct SockAddr { * \param host the url of the address * \param port the port of address */ - void Set(const char *host, int port) { + void Set(const char* host, int port) { addrinfo hints; memset(&hints, 0, sizeof(hints)); hints.ai_family = PF_UNSPEC; hints.ai_flags = AI_PASSIVE; hints.ai_socktype = SOCK_STREAM; - addrinfo *res = NULL; + addrinfo* res = NULL; int sig = getaddrinfo(host, NULL, &hints, &res); - CHECK(sig == 0 && res != NULL) - << "cannot obtain address of " << host; + CHECK(sig == 0 && res != NULL) << "cannot obtain address of " << host; switch (res->ai_family) { case AF_INET: { - sockaddr_in *addr4 = reinterpret_cast(&addr); - memcpy(addr4, res->ai_addr, res->ai_addrlen); - addr4->sin_port = htons(port); - addr4->sin_family = AF_INET; - } - break; + sockaddr_in* addr4 = reinterpret_cast(&addr); + memcpy(addr4, res->ai_addr, res->ai_addrlen); + addr4->sin_port = htons(port); + addr4->sin_family = AF_INET; + } break; case AF_INET6: { - sockaddr_in6 *addr6 = reinterpret_cast(&addr); - memcpy(addr6, res->ai_addr, res->ai_addrlen); - addr6->sin6_port = htons(port); - addr6->sin6_family = AF_INET6; - } - break; + sockaddr_in6* addr6 = reinterpret_cast(&addr); + memcpy(addr6, res->ai_addr, res->ai_addrlen); + addr6->sin6_port = htons(port); + addr6->sin6_family = AF_INET6; + } break; default: CHECK(false) << "cannot decode address"; } @@ -157,35 +154,34 @@ struct SockAddr { } /*! \brief return port of the address */ int port() const { - return ntohs((addr.ss_family == AF_INET6)? \ - reinterpret_cast(&addr)->sin6_port : \ - reinterpret_cast(&addr)->sin_port); + return ntohs((addr.ss_family == AF_INET6) + ? reinterpret_cast(&addr)->sin6_port + : reinterpret_cast(&addr)->sin_port); } /*! \brief return the ip address family */ - int ss_family() const { - return addr.ss_family; - } + int ss_family() const { return addr.ss_family; } /*! \return a string representation of the address */ std::string AsString() const { - std::string buf; buf.resize(256); + std::string buf; + buf.resize(256); - const void *sinx_addr = nullptr; - if (addr.ss_family == AF_INET6) { - const in6_addr& addr6 = reinterpret_cast(&addr)->sin6_addr; - sinx_addr = reinterpret_cast(&addr6); - } else if (addr.ss_family == AF_INET) { - const in_addr& addr4 = reinterpret_cast(&addr)->sin_addr; - sinx_addr = reinterpret_cast(&addr4); - } else { - CHECK(false) << "illegal address"; - } + const void* sinx_addr = nullptr; + if (addr.ss_family == AF_INET6) { + const in6_addr& addr6 = reinterpret_cast(&addr)->sin6_addr; + sinx_addr = reinterpret_cast(&addr6); + } else if (addr.ss_family == AF_INET) { + const in_addr& addr4 = reinterpret_cast(&addr)->sin_addr; + sinx_addr = reinterpret_cast(&addr4); + } else { + CHECK(false) << "illegal address"; + } #ifdef _WIN32 - const char *s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*) + const char* s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*) &buf[0], buf.length()); #else - const char *s = inet_ntop(addr.ss_family, sinx_addr, - &buf[0], static_cast(buf.length())); + const char* s = + inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast(buf.length())); #endif CHECK(s != nullptr) << "cannot decode address"; std::ostringstream os; @@ -238,10 +234,10 @@ class Socket { * \brief bind the socket to an address * \param addr The address to be binded */ - void Bind(const SockAddr &addr) { + void Bind(const SockAddr& addr) { if (bind(sockfd, reinterpret_cast(&addr.addr), - (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : - sizeof(sockaddr_in))) == -1) { + (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == + -1) { Socket::Error("Bind"); } } @@ -256,8 +252,8 @@ class Socket { for (int port = start_port; port < end_port; ++port) { SockAddr addr(host.c_str(), port); if (bind(sockfd, reinterpret_cast(&addr.addr), - (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : - sizeof(sockaddr_in))) == 0) { + (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == + 0) { return port; } else { LOG(WARNING) << "Bind failed to " << host << ":" << port; @@ -278,7 +274,7 @@ class Socket { int GetSockError() const { int error = 0; socklen_t len = sizeof(error); - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { Error("GetSockError"); } return error; @@ -291,9 +287,7 @@ class Socket { return false; } /*! \brief check if socket is already closed */ - bool IsClosed() const { - return sockfd == INVALID_SOCKET; - } + bool IsClosed() const { return sockfd == INVALID_SOCKET; } /*! \brief close the socket */ void Close() { if (sockfd != INVALID_SOCKET) { @@ -354,7 +348,7 @@ class Socket { * \brief Report an socket error. * \param msg The error message. */ - static void Error(const char *msg) { + static void Error(const char* msg) { int errsv = GetLastError(); #ifdef _WIN32 LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv; @@ -364,8 +358,7 @@ class Socket { } protected: - explicit Socket(SockType sockfd) : sockfd(sockfd) { - } + explicit Socket(SockType sockfd) : sockfd(sockfd) {} }; /*! @@ -373,22 +366,20 @@ class Socket { */ class TCPSocket : public Socket { public: - TCPSocket() : Socket(INVALID_SOCKET) { - } + TCPSocket() : Socket(INVALID_SOCKET) {} /*! * \brief construct a TCP socket from existing descriptor * \param sockfd The descriptor */ - explicit TCPSocket(SockType sockfd) : Socket(sockfd) { - } + explicit TCPSocket(SockType sockfd) : Socket(sockfd) {} /*! * \brief enable/disable TCP keepalive * \param keepalive whether to set the keep alive option on */ void SetKeepAlive(bool keepalive) { int opt = static_cast(keepalive); - if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, - reinterpret_cast(&opt), sizeof(opt)) < 0) { + if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&opt), sizeof(opt)) < + 0) { Socket::Error("SetKeepAlive"); } } @@ -406,9 +397,7 @@ class TCPSocket : public Socket { * \brief perform listen of the socket * \param backlog backlog parameter */ - void Listen(int backlog = 16) { - listen(sockfd, backlog); - } + void Listen(int backlog = 16) { listen(sockfd, backlog); } /*! * \brief get a new connection * \return The accepted socket connection. @@ -421,14 +410,13 @@ class TCPSocket : public Socket { return TCPSocket(newfd); } /*! - * \brief get a new connection - * \param addr client address from which connection accepted - * \return The accepted socket connection. - */ - TCPSocket Accept(SockAddr *addr) { + * \brief get a new connection + * \param addr client address from which connection accepted + * \return The accepted socket connection. + */ + TCPSocket Accept(SockAddr* addr) { socklen_t addrlen = sizeof(addr->addr); - SockType newfd = accept(sockfd, reinterpret_cast(&addr->addr), - &addrlen); + SockType newfd = accept(sockfd, reinterpret_cast(&addr->addr), &addrlen); if (newfd == INVALID_SOCKET) { Socket::Error("Accept"); } @@ -453,10 +441,10 @@ class TCPSocket : public Socket { * \param addr the address to connect to * \return whether connect is successful */ - bool Connect(const SockAddr &addr) { - return connect(sockfd, reinterpret_cast(&addr.addr), - (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : - sizeof(sockaddr_in))) == 0; + bool Connect(const SockAddr& addr) { + return connect( + sockfd, reinterpret_cast(&addr.addr), + (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0; } /*! * \brief send data using the socket @@ -466,8 +454,8 @@ class TCPSocket : public Socket { * \return size of data actually sent * return -1 if error occurs */ - ssize_t Send(const void *buf_, size_t len, int flag = 0) { - const char *buf = reinterpret_cast(buf_); + ssize_t Send(const void* buf_, size_t len, int flag = 0) { + const char* buf = reinterpret_cast(buf_); return send(sockfd, buf, static_cast(len), flag); } /*! @@ -478,8 +466,8 @@ class TCPSocket : public Socket { * \return size of data actually received * return -1 if error occurs */ - ssize_t Recv(void *buf_, size_t len, int flags = 0) { - char *buf = reinterpret_cast(buf_); + ssize_t Recv(void* buf_, size_t len, int flags = 0) { + char* buf = reinterpret_cast(buf_); return recv(sockfd, buf, static_cast(len), flags); } /*! @@ -489,10 +477,10 @@ class TCPSocket : public Socket { * \param len the size of the buffer * \return size of data actually sent */ - size_t SendAll(const void *buf_, size_t len) { - const char *buf = reinterpret_cast(buf_); + size_t SendAll(const void* buf_, size_t len) { + const char* buf = reinterpret_cast(buf_); size_t ndone = 0; - while (ndone < len) { + while (ndone < len) { ssize_t ret = send(sockfd, buf, static_cast(len - ndone), 0); if (ret == -1) { if (LastErrorWouldBlock()) return ndone; @@ -510,14 +498,13 @@ class TCPSocket : public Socket { * \param len length of data to recv * \return size of data actually sent */ - size_t RecvAll(void *buf_, size_t len) { - char *buf = reinterpret_cast(buf_); + size_t RecvAll(void* buf_, size_t len) { + char* buf = reinterpret_cast(buf_); size_t ndone = 0; - while (ndone < len) { - ssize_t ret = recv(sockfd, buf, - static_cast(len - ndone), MSG_WAITALL); + while (ndone < len) { + ssize_t ret = recv(sockfd, buf, static_cast(len - ndone), MSG_WAITALL); if (ret == -1) { - if (LastErrorWouldBlock()) { + if (LastErrorWouldBlock()) { LOG(FATAL) << "would block"; return ndone; } @@ -612,7 +599,7 @@ struct PollHelper { * \param timeout the timeout counter, can be negative, which means wait until the event happen * \return 1 if success, 0 if timeout, and -1 if error occurs */ - inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) + inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) pollfd pfd; pfd.fd = fd; pfd.events = POLLPRI; diff --git a/src/support/str_escape.h b/src/support/str_escape.h index fd25c01..65eec68 100644 --- a/src/support/str_escape.h +++ b/src/support/str_escape.h @@ -25,8 +25,8 @@ #ifndef TVM_SUPPORT_STR_ESCAPE_H_ #define TVM_SUPPORT_STR_ESCAPE_H_ -#include #include +#include namespace tvm { namespace support { @@ -76,9 +76,7 @@ inline std::string StrEscape(const char* data, size_t size) { * \param size The size of the string. * \return the Result string. */ -inline std::string StrEscape(const std::string& val) { - return StrEscape(val.data(), val.length()); -} +inline std::string StrEscape(const std::string& val) { return StrEscape(val.data(), val.length()); } } // namespace support } // namespace tvm diff --git a/src/support/util.h b/src/support/util.h index 9a477e6..859b372 100644 --- a/src/support/util.h +++ b/src/support/util.h @@ -26,16 +26,16 @@ #include #ifndef _WIN32 -#include #include +#include #endif -#include -#include -#include #include #include #include #include +#include +#include +#include namespace tvm { namespace support { @@ -92,15 +92,14 @@ inline int TVMWexitstatus(int status) { #endif } - /*! * \brief IsNumber check whether string is a number. * \param str input string * \return result of operation. */ inline bool IsNumber(const std::string& str) { - return !str.empty() && std::find_if(str.begin(), - str.end(), [](char c) { return !std::isdigit(c); }) == str.end(); + return !str.empty() && + std::find_if(str.begin(), str.end(), [](char c) { return !std::isdigit(c); }) == str.end(); } /*! diff --git a/src/target/build_common.h b/src/target/build_common.h index 93687c2..ec5b522 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -24,27 +24,27 @@ #ifndef TVM_TARGET_BUILD_COMMON_H_ #define TVM_TARGET_BUILD_COMMON_H_ -#include -#include -#include #include -#include +#include +#include +#include #include +#include #include -#include + #include +#include + #include "../runtime/meta_data.h" namespace tvm { namespace codegen { -inline std::unordered_map -ExtractFuncInfo(const IRModule& mod) { +inline std::unordered_map ExtractFuncInfo(const IRModule& mod) { std::unordered_map fmap; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); runtime::FunctionInfo info; diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 848d27f..e3890ca 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -21,23 +21,22 @@ * \file codegen.cc * \brief Common utilities to generated C style code. */ +#include +#include +#include +#include +#include +#include #include #include - -#include -#include #include +#include -#include -#include -#include -#include -#include -#include -#include #include -#include #include +#include +#include +#include namespace tvm { namespace codegen { @@ -50,17 +49,14 @@ runtime::Module Build(IRModule mod, const Target& target) { std::string build_f_name = "target.build." + target->target_name; // the build function. const PackedFunc* bf = runtime::Registry::Get(build_f_name); - CHECK(bf != nullptr) - << "target.build." << target << " is not enabled"; + CHECK(bf != nullptr) << "target.build." << target << " is not enabled"; return (*bf)(mod, target->str()); } /*! \brief Helper class to serialize module */ class ModuleSerializer { public: - explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { - Init(); - } + explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); } void SerializeModule(dmlc::Stream* stream) { // Only have one DSO module and it is in the root, then @@ -109,8 +105,8 @@ class ModuleSerializer { // invariance: root module is always at location 0. // The module order is collected via DFS void CreateModuleIndex() { - std::unordered_set visited {mod_.operator->()}; - std::vector stack {mod_.operator->()}; + std::unordered_set visited{mod_.operator->()}; + std::vector stack{mod_.operator->()}; uint64_t module_index = 0; while (!stack.empty()) { @@ -139,8 +135,7 @@ class ModuleSerializer { } bool DSOExportable(const runtime::ModuleNode* mod) { - return !std::strcmp(mod->type_key(), "llvm") || - !std::strcmp(mod->type_key(), "c"); + return !std::strcmp(mod->type_key(), "llvm") || !std::strcmp(mod->type_key(), "c"); } runtime::Module mod_; @@ -148,21 +143,21 @@ class ModuleSerializer { std::unordered_map mod2index_; // index -> module std::vector mod_vec_; - std::vector import_tree_row_ptr_ {0}; + std::vector import_tree_row_ptr_{0}; std::vector import_tree_child_indices_; }; namespace { - std::string SerializeModule(const runtime::Module& mod) { - std::string bin; - dmlc::MemoryStringStream ms(&bin); - dmlc::Stream* stream = &ms; +std::string SerializeModule(const runtime::Module& mod) { + std::string bin; + dmlc::MemoryStringStream ms(&bin); + dmlc::Stream* stream = &ms; - ModuleSerializer module_serializer(mod); - module_serializer.SerializeModule(stream); + ModuleSerializer module_serializer(mod); + module_serializer.SerializeModule(stream); - return bin; - } + return bin; +} } // namespace std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { @@ -180,8 +175,8 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { << "#endif\n"; os << "TVM_EXPORT extern const unsigned char " << runtime::symbol::tvm_dev_mblob << "[];\n"; uint64_t nbytes = bin.length(); - os << "const unsigned char " << runtime::symbol::tvm_dev_mblob - << "[" << bin.length() + sizeof(nbytes) << "] = {\n "; + os << "const unsigned char " << runtime::symbol::tvm_dev_mblob << "[" + << bin.length() + sizeof(nbytes) << "] = {\n "; os << std::hex; size_t nunit = 80 / 4; for (size_t i = 0; i < sizeof(nbytes); ++i) { @@ -214,8 +209,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { return os.str(); } -runtime::Module PackImportsToLLVM(const runtime::Module& mod, - bool system_lib, +runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, const std::string& target_triple) { std::string bin = SerializeModule(mod); @@ -233,19 +227,16 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, std::string codegen_f_name = "codegen.codegen_blob"; // the codegen function. const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name); - CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented."; + CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented."; return (*codegen_f)(blob_byte_array, system_lib, target_triple); } -TVM_REGISTER_GLOBAL("target.Build") -.set_body_typed(Build); +TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build); // Export two auxiliary function to the runtime namespace. -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC") -.set_body_typed(PackImportsToC); +TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM") -.set_body_typed(PackImportsToLLVM); +TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM); } // namespace codegen } // namespace tvm diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index c16182d..99d6bee 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -16,34 +16,32 @@ * specific language governing permissions and limitations * under the License. */ -#include #include "registry.h" +#include + namespace tvm { namespace datatype { using runtime::TVMArgs; using runtime::TVMRetValue; -TVM_REGISTER_GLOBAL("runtime._datatype_register") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("runtime._datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) { datatype::Registry::Global()->Register(args[0], static_cast(args[1].operator int())); }); -TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = datatype::Registry::Global()->GetTypeCode(args[0]); }); -TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = Registry::Global()->GetTypeName(args[0].operator int()); }); TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); + }); Registry* Registry::Global() { static Registry inst; diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h index 919409f..c043592 100644 --- a/src/target/datatype/registry.h +++ b/src/target/datatype/registry.h @@ -22,6 +22,7 @@ #include #include + #include #include @@ -69,7 +70,7 @@ class Registry { * \param type_name The type name * \return The type code */ - uint8_t GetTypeCode(const std::string &type_name); + uint8_t GetTypeCode(const std::string& type_name); /*! * \brief Get type name from type code diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index 44d017f..9ad9f56 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -20,14 +20,12 @@ * \file src/target/generic_func.cc */ #include - -#include -#include #include #include -#include -#include +#include #include +#include +#include #include #include @@ -43,8 +41,7 @@ struct GenericFunc::Manager { // mutex std::mutex mutex; - Manager() { - } + Manager() {} static Manager* Global() { static Manager inst; @@ -76,25 +73,23 @@ void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) m->fmap[name] = func; } -GenericFunc& GenericFunc::set_default(const PackedFunc value, - bool allow_override) { +GenericFunc& GenericFunc::set_default(const PackedFunc value, bool allow_override) { auto node = static_cast(operator->()); if (!allow_override) { CHECK(node->generic_func_ == nullptr) - << "Generic function already registered for " << node->name_; + << "Generic function already registered for " << node->name_; } node->generic_func_ = value; return *this; } GenericFunc& GenericFunc::register_func(const std::vector& tags, - const PackedFunc value, - bool allow_override) { - for (auto &t : tags) { + const PackedFunc value, bool allow_override) { + for (auto& t : tags) { if (!allow_override) { auto iter = (*this)->dispatch_dict_.find(t); CHECK(iter == (*this)->dispatch_dict_.end()) - << "Tag " << t << " already registered for schedule factory " << (*this)->name_; + << "Tag " << t << " already registered for schedule factory " << (*this)->name_; } (*this)->dispatch_dict_[t] = value; } @@ -107,7 +102,7 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { PackedFunc func; if (target.defined()) { - for (auto &k : target->keys()) { + for (auto& k : target->keys()) { auto iter = node->dispatch_dict_.find(k); if (iter != node->dispatch_dict_.end()) { func = iter->second; @@ -124,30 +119,25 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { func.CallPacked(args, ret); } -TVM_REGISTER_GLOBAL("target.GenericFuncCreate") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncCreate").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = GenericFunc(make_object()); - }); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal").set_body([](TVMArgs args, TVMRetValue* ret) { std::string func_name = args[0]; *ret = GenericFunc::Get(func_name); - }); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); bool allow_override = args[2]; - generic_func - .set_default(*func, allow_override); - }); + generic_func.set_default(*func, allow_override); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); @@ -159,17 +149,14 @@ TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc") tags_vector.push_back(tag); } - generic_func - .register_func(tags_vector, *func, allow_override); - }); + generic_func.register_func(tags_vector, *func, allow_override); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1); - generic_func - .CallPacked(func_args, ret); - }); + generic_func.CallPacked(func_args, ret); +}); } // namespace tvm diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index b95974f..37855fb 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -21,123 +21,99 @@ * \file intrin_rule_default.cc * \brief Default intrinsic rules. */ -#include #include "intrin_rule.h" +#include + namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - *rv = one / sqrt(call->args[0]); - }); + auto one = make_const(call->args[0].dtype(), 1); + *rv = one / sqrt(call->args[0]); + }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - *rv = one / (one + exp(-call->args[0])); - }); + auto one = make_const(call->args[0].dtype(), 1); + *rv = one / (one + exp(-call->args[0])); + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); - *rv = isfinite(call->args[0]); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); + *rv = isfinite(call->args[0]); + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); - *rv = isinf(call->args[0]); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); + *rv = isinf(call->args[0]); + }); } // namespace intrin } // namespace codegen diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 0914742..8a5a440 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -24,9 +24,9 @@ #ifndef TVM_TARGET_INTRIN_RULE_H_ #define TVM_TARGET_INTRIN_RULE_H_ -#include -#include #include +#include + #include namespace tvm { @@ -49,21 +49,18 @@ struct FloatSuffix { // Return the intrinsic name struct Direct { - std::string operator()(DataType t, std::string name) const { - return name; - } + std::string operator()(DataType t, std::string name) const { return name; } }; // Call pure extern function. -template +template inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); std::string name = T()(call->dtype, call->name); if (name.length() != 0) { - *rv = CallNode::make( - call->dtype, name, call->args, CallNode::PureExtern); + *rv = CallNode::make(call->dtype, name, call->args, CallNode::PureExtern); } else { *rv = e; } diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 8935809..280c999 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -23,12 +23,13 @@ */ #ifdef TVM_LLVM_VERSION -#include #include +#include #include -#include "codegen_llvm.h" -#include "../build_common.h" + #include "../../runtime/rocm/rocm_module.h" +#include "../build_common.h" +#include "codegen_llvm.h" namespace tvm { namespace codegen { @@ -45,8 +46,8 @@ static inline int DetectROCMmaxThreadsPerBlock() { TVMRetValue val; api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val); if (val.operator int() == 1) { - tvm::runtime::DeviceAPI::Get(tvm_ctx)-> - GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, &val); + tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, + &val); return val.operator int(); } } @@ -73,8 +74,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Value* buf = nullptr; int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { @@ -88,9 +88,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 alloca->setAlignment(llvm::Align(info.alignment)); @@ -104,12 +103,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get( - DTypeToLLVMType(op->dtype), constant_size); + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", - nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(info.alignment)); #else @@ -119,8 +117,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } buf = builder_->CreatePointerCast( - buf, DTypeToLLVMType(op->dtype)->getPointerTo( - buf->getType()->getPointerAddressSpace())); + buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -132,18 +129,32 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; if (ts.rank == 1) { switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break; - case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break; - case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break; - default: LOG(FATAL) << "unknown workitem idx"; + case 0: + intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; + break; + default: + LOG(FATAL) << "unknown workitem idx"; } } else { CHECK_EQ(ts.rank, 0); switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break; - case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break; - case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break; - default: LOG(FATAL) << "unknown workgroup idx"; + case 0: + intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; + break; + default: + LOG(FATAL) << "unknown workgroup idx"; } } llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); @@ -155,9 +166,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { if (sync == "warp") { return nullptr; } else if (sync == "shared") { - llvm::Function* f = llvm::Intrinsic::getDeclaration( - module_.get(), - ::llvm::Intrinsic::amdgcn_s_barrier); + llvm::Function* f = + llvm::Intrinsic::getDeclaration(module_.get(), ::llvm::Intrinsic::amdgcn_s_barrier); return builder_->CreateCall(f, {}); } else { LOG(FATAL) << "Do not support sync " << sync; @@ -169,9 +179,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { // Additional optimization hook to tweak the builder. } - unsigned GetGlobalAddressSpace() const final { - return 1; - } + unsigned GetGlobalAddressSpace() const final { return 1; } protected: void InitTarget(llvm::TargetMachine* tm) final { @@ -211,13 +219,10 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { // issue #4087 for a discussion #endif InitializeLLVM(); - CHECK(target.length() >= 4 && - target.substr(0, 4) == "rocm"); + CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm"); std::ostringstream config; - config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" - << DetectROCMComputeVersion(target) - << " -mattr=-code-object-v3 " - << target.substr(4, target.length() - 4); + config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target) + << " -mattr=-code-object-v3 " << target.substr(4, target.length() - 4); std::unique_ptr tm = GetLLVMTargetMachine(config.str()); std::unique_ptr ctx(new llvm::LLVMContext()); // careful: cg will hold a naked pointer reference to ctx, so it should @@ -226,18 +231,16 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); cg->AddFunction(f); } - const auto *find_rocm_bitcodes = - tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); + const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); Array bitcode_files = (*find_rocm_bitcodes)(); - for (auto &bitcode_path : bitcode_files) { + for (auto& bitcode_path : bitcode_files) { std::string path = bitcode_path; llvm::SMDiagnostic err; std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); @@ -248,7 +251,7 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { } mlib->setTargetTriple(tm->getTargetTriple().str()); mlib->setDataLayout(tm->createDataLayout()); - for (llvm::Function &f : mlib->functions()) { + for (llvm::Function& f : mlib->functions()) { f.addFnAttr(llvm::Attribute::AlwaysInline); } cg->AddLinkModule(std::move(mlib)); @@ -271,33 +274,28 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { llvm::legacy::PassManager pass; #if TVM_LLVM_VERSION <= 60 - CHECK(tm->addPassesToEmitFile( - pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + CHECK(tm->addPassesToEmitFile(pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm->addPassesToEmitFile( - pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + CHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #else - CHECK(tm->addPassesToEmitFile( - pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + CHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*mObj); std::string obj(dataObj.begin(), dataObj.end()); llvm::legacy::PassManager passAsm; #if TVM_LLVM_VERSION <= 60 - CHECK(tm->addPassesToEmitFile(passAsm, destAsm, - llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(passAsm, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #else - CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, - llvm::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #endif passAsm.run(*mAsm); @@ -315,8 +313,7 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly); } -TVM_REGISTER_GLOBAL("target.build.rocm") -.set_body_typed(BuildAMDGPU); +TVM_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 73d849a..ba45115 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -47,8 +47,7 @@ class CodeGenARM final : public CodeGenCPU { llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { - llvm::Intrinsic::ID id = static_cast( - Downcast(op->args[0])->value); + llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); return CodeGenCPU::CreateIntrinsic(e.as()); @@ -57,21 +56,21 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { return CodeGenCPU::CreateIntrinsic(op); } -PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { +PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { using namespace tir; const PrimExpr& e = call->args[2]; ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop; ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu; // Fallback to default llvm lowering rule if input type not a full vector or half vector length - int total_size = call->dtype.bits() * call->dtype.lanes(); + int total_size = call->dtype.bits() * call->dtype.lanes(); if (!call->dtype.is_vector() || call->dtype.bits() == 8 || - (total_size != 128 && total_size != 64)) { + (total_size != 128 && total_size != 64)) { Array vcnt_args; vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); + return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } // Popcount lowering rule: @@ -80,12 +79,11 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // to return back to original input type // Dvisions are always divisible (number of bits = 64 or 128) - DataType uint8_type = DataType( - e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8); - DataType uint16_type = DataType( - uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16); - DataType uint32_type = DataType( - uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); + DataType uint8_type = DataType(e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8); + DataType uint16_type = + DataType(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16); + DataType uint32_type = + DataType(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); // Interpret input as vector of 8bit values PrimExpr input8 = reinterpret(uint8_type, e); @@ -96,16 +94,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = tir::CallNode::make( - uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); + PrimExpr vcnt8 = + tir::CallNode::make(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = tir::CallNode::make( - uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); + PrimExpr vcnt16 = + tir::CallNode::make(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 16) { return vcnt16; } @@ -115,8 +113,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = tir::CallNode::make( - uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); + PrimExpr vcnt32 = + tir::CallNode::make(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 32) { return vcnt32; } @@ -126,15 +124,14 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return tir::CallNode::make( - call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); + return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - CodeGenLLVM* cg = new CodeGenARM(); - *rv = static_cast(cg); - }); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + CodeGenLLVM* cg = new CodeGenARM(); + *rv = static_cast(cg); + }); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index be8ef92..b7c48c7 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -21,17 +21,17 @@ * \file codegen_blob.cc */ #ifdef TVM_LLVM_VERSION +#include "codegen_blob.h" + #include + #include -#include "codegen_blob.h" namespace tvm { namespace codegen { -std::pair, - std::shared_ptr> CodeGenBlob(const std::string& data, - bool system_lib, - const std::string& target_triple) { +std::pair, std::shared_ptr> CodeGenBlob( + const std::string& data, bool system_lib, const std::string& target_triple) { InitializeLLVM(); auto tm = GetLLVMTargetMachine(std::string("-target ") + target_triple); auto triple = tm->getTargetTriple(); @@ -41,10 +41,9 @@ std::pair, module->setTargetTriple(triple.str()); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); - auto* tvm_dev_mblob = new llvm::GlobalVariable(*module, blob_value->getType(), true, - llvm::GlobalValue::ExternalLinkage, blob_value, - runtime::symbol::tvm_dev_mblob, nullptr, - llvm::GlobalVariable::NotThreadLocal, 0); + auto* tvm_dev_mblob = new llvm::GlobalVariable( + *module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value, + runtime::symbol::tvm_dev_mblob, nullptr, llvm::GlobalVariable::NotThreadLocal, 0); #if TVM_LLVM_VERSION >= 100 tvm_dev_mblob->setAlignment(llvm::Align(1)); @@ -64,11 +63,9 @@ std::pair, auto int8_ptr_ty = int8_ty->getPointerTo(0); llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty); - auto* tvm_dev_mblob_reg = - new llvm::GlobalVariable(*module, int32_ty, - false, llvm::GlobalValue::InternalLinkage, - constant_zero, - std::string(runtime::symbol::tvm_dev_mblob) + "_reg_"); + auto* tvm_dev_mblob_reg = new llvm::GlobalVariable( + *module, int32_ty, false, llvm::GlobalValue::InternalLinkage, constant_zero, + std::string(runtime::symbol::tvm_dev_mblob) + "_reg_"); auto tvm_dev_mblob_reg_alignment = module->getDataLayout().getABITypeAlignment(int32_ty); #if TVM_LLVM_VERSION >= 100 tvm_dev_mblob_reg->setAlignment(llvm::Align(tvm_dev_mblob_reg_alignment)); @@ -80,11 +77,9 @@ std::pair, llvm::ArrayType::get(int8_ty, std::strlen(runtime::symbol::tvm_dev_mblob) + 1); auto* tvm_dev_mblob_string_value = llvm::ConstantDataArray::getString(*ctx, runtime::symbol::tvm_dev_mblob, true); - auto* tvm_dev_mblob_string = - new llvm::GlobalVariable(*module, tvm_dev_mblob_string_ty, - true, llvm::GlobalValue::PrivateLinkage, - tvm_dev_mblob_string_value, - std::string(runtime::symbol::tvm_dev_mblob) + ".str"); + auto* tvm_dev_mblob_string = new llvm::GlobalVariable( + *module, tvm_dev_mblob_string_ty, true, llvm::GlobalValue::PrivateLinkage, + tvm_dev_mblob_string_value, std::string(runtime::symbol::tvm_dev_mblob) + ".str"); #if TVM_LLVM_VERSION >= 100 tvm_dev_mblob_string->setAlignment(llvm::Align(1)); #else @@ -92,33 +87,30 @@ std::pair, #endif // Global init function - llvm::Function* init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false), - llvm::GlobalValue::InternalLinkage, - llvm::Twine("_GLOBAL__sub_I_", module_name), - module.get()); + llvm::Function* init_fn = llvm::Function::Create( + llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage, + llvm::Twine("_GLOBAL__sub_I_", module_name), module.get()); // Create variable initialization function. - llvm::Function* var_init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false), - llvm::GlobalValue::InternalLinkage, - llvm::Twine("__cxx_global_var_init"), - module.get()); + llvm::Function* var_init_fn = llvm::Function::Create( + llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage, + llvm::Twine("__cxx_global_var_init"), module.get()); // Create TVMBackendRegisterSystemLibSymbol function llvm::Function* tvm_backend_fn = llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false), llvm::GlobalValue::ExternalLinkage, - llvm::Twine("TVMBackendRegisterSystemLibSymbol"), - module.get()); + llvm::Twine("TVMBackendRegisterSystemLibSymbol"), module.get()); // Set necessary fn sections auto get_static_init_section_specifier = [&triple]() -> std::string { - if (triple.isOSLinux()) { - return ".text.startup"; - } else if (triple.isOSDarwin()) { - return "__TEXT,__StaticInit,regular,pure_instructions"; - } else { - return ""; - } + if (triple.isOSLinux()) { + return ".text.startup"; + } else if (triple.isOSDarwin()) { + return "__TEXT,__StaticInit,regular,pure_instructions"; + } else { + return ""; + } }; auto static_init_section_specifier = get_static_init_section_specifier(); @@ -144,11 +136,9 @@ std::pair, llvm::Constant* indices[] = {constant_zero, constant_zero}; llvm::SmallVector args; args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_dev_mblob_string_ty, - tvm_dev_mblob_string, - indices)); - args.push_back(llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), - tvm_dev_mblob, - indices)); + tvm_dev_mblob_string, indices)); + args.push_back( + llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), tvm_dev_mblob, indices)); auto* tvm_backend_fn_ret_value = ir_builder.CreateCall(tvm_backend_fn, args); ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_dev_mblob_reg); ir_builder.CreateRetVoid(); diff --git a/src/target/llvm/codegen_blob.h b/src/target/llvm/codegen_blob.h index a394f77..2821f44 100644 --- a/src/target/llvm/codegen_blob.h +++ b/src/target/llvm/codegen_blob.h @@ -24,9 +24,10 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_BLOB_H_ #define TVM_TARGET_LLVM_CODEGEN_BLOB_H_ #ifdef TVM_LLVM_VERSION -#include #include #include +#include + #include "llvm_common.h" namespace tvm { @@ -40,10 +41,8 @@ namespace codegen { * * \return LLVM module and LLVM context */ -std::pair, - std::shared_ptr> CodeGenBlob(const std::string& data, - bool system_lib, - const std::string& target_triple); +std::pair, std::shared_ptr> CodeGenBlob( + const std::string& data, bool system_lib, const std::string& target_triple); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index e474b9c..03b5496 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -22,20 +22,19 @@ */ #ifdef TVM_LLVM_VERSION +#include "codegen_cpu.h" + #include #include + #include #include -#include "codegen_cpu.h" namespace tvm { namespace codegen { -void CodeGenCPU::Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup) { +void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, + llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) { CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup); dbg_info_ = CreateDebugInfo(module_.get()); static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); @@ -46,53 +45,34 @@ void CodeGenCPU::Init(const std::string& module_name, t_tvm_context_ = llvm::StructType::create({t_int_, t_int_}); t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_}); t_tvm_func_handle_ = t_void_p_; - t_tvm_array_ = llvm::StructType::create( - {t_void_p_, - t_tvm_context_, - t_int_, - t_tvm_type_, - t_tvm_shape_index_->getPointerTo(), - t_tvm_shape_index_->getPointerTo(), - t_int64_}); + t_tvm_array_ = llvm::StructType::create({t_void_p_, t_tvm_context_, t_int_, t_tvm_type_, + t_tvm_shape_index_->getPointerTo(), + t_tvm_shape_index_->getPointerTo(), t_int64_}); t_tvm_value_ = llvm::StructType::create({t_float64_}); - t_tvm_parallel_group_env_ = llvm::StructType::create({ - t_int32_->getPointerTo(), t_int32_}); + t_tvm_parallel_group_env_ = llvm::StructType::create({t_int32_->getPointerTo(), t_int32_}); ftype_tvm_parallel_lambda_ = llvm::FunctionType::get( - t_int_, - {t_int_, - t_tvm_parallel_group_env_->getPointerTo(), - t_void_p_}, false); + t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo(), t_void_p_}, false); md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_); // Runtime functions. - ftype_tvm_func_call_ = llvm::FunctionType::get(t_int_, { - t_tvm_func_handle_, - t_tvm_value_->getPointerTo(), - t_int_->getPointerTo(), + ftype_tvm_func_call_ = llvm::FunctionType::get( t_int_, - t_tvm_value_->getPointerTo(), - t_int_->getPointerTo()}, false); - ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(t_int_, { - t_void_p_, - t_char_->getPointerTo(), - t_tvm_func_handle_->getPointerTo()}, false); - ftype_tvm_api_set_last_error_ = llvm::FunctionType::get( - t_void_, {t_char_->getPointerTo()}, false); - ftype_tvm_parallel_launch_ = - llvm::FunctionType::get(t_int_, { - ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_} - , false); + {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, + t_tvm_value_->getPointerTo(), t_int_->getPointerTo()}, + false); + ftype_tvm_get_func_from_env_ = llvm::FunctionType::get( + t_int_, {t_void_p_, t_char_->getPointerTo(), t_tvm_func_handle_->getPointerTo()}, false); + ftype_tvm_api_set_last_error_ = + llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false); + ftype_tvm_parallel_launch_ = llvm::FunctionType::get( + t_int_, {ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}, false); ftype_tvm_parallel_barrier_ = - llvm::FunctionType::get(t_int_, { - t_int_, t_tvm_parallel_group_env_->getPointerTo()} - , false); - ftype_tvm_static_init_callback_ = - llvm::FunctionType::get(t_int_, {t_void_p_}, false); + llvm::FunctionType::get(t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo()}, false); + ftype_tvm_static_init_callback_ = llvm::FunctionType::get(t_int_, {t_void_p_}, false); ftype_tvm_static_init_ = - llvm::FunctionType::get(t_int_, { - t_void_p_->getPointerTo(), - ftype_tvm_static_init_callback_->getPointerTo(), - t_void_p_, t_int_} - , false); + llvm::FunctionType::get(t_int_, + {t_void_p_->getPointerTo(), + ftype_tvm_static_init_callback_->getPointerTo(), t_void_p_, t_int_}, + false); // initialize TVM runtime API if (system_lib) { // We will need this in environment for backward registration. @@ -103,21 +83,20 @@ void CodeGenCPU::Init(const std::string& module_name, f_tvm_register_system_symbol_ = nullptr; } if (dynamic_lookup || system_lib) { - f_tvm_func_call_ = llvm::Function::Create( - ftype_tvm_func_call_, - llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get()); - f_tvm_get_func_from_env_ = llvm::Function::Create( - ftype_tvm_get_func_from_env_, - llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get()); - f_tvm_api_set_last_error_ = llvm::Function::Create( - ftype_tvm_api_set_last_error_, - llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get()); - f_tvm_parallel_launch_ = llvm::Function::Create( - ftype_tvm_parallel_launch_, - llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get()); - f_tvm_parallel_barrier_ = llvm::Function::Create( - ftype_tvm_parallel_barrier_, - llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get()); + f_tvm_func_call_ = llvm::Function::Create(ftype_tvm_func_call_, llvm::Function::ExternalLinkage, + "TVMFuncCall", module_.get()); + f_tvm_get_func_from_env_ = + llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage, + "TVMBackendGetFuncFromEnv", module_.get()); + f_tvm_api_set_last_error_ = + llvm::Function::Create(ftype_tvm_api_set_last_error_, llvm::Function::ExternalLinkage, + "TVMAPISetLastError", module_.get()); + f_tvm_parallel_launch_ = + llvm::Function::Create(ftype_tvm_parallel_launch_, llvm::Function::ExternalLinkage, + "TVMBackendParallelLaunch", module_.get()); + f_tvm_parallel_barrier_ = + llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage, + "TVMBackendParallelBarrier", module_.get()); } this->InitGlobalContext(dynamic_lookup); } @@ -152,22 +131,13 @@ void CodeGenCPU::AddDebugInformation(llvm::Function* function) { #if TVM_LLVM_VERSION >= 80 auto* DIFunction = dbg_info_->di_builder_->createFunction( - dbg_info_->file_, function->getName(), "", - dbg_info_->file_, - 0 /* line number */, - DIFunctionTy, - false /* internal linkage */); + dbg_info_->file_, function->getName(), "", dbg_info_->file_, 0 /* line number */, + DIFunctionTy, false /* internal linkage */); #else auto* DIFunction = dbg_info_->di_builder_->createFunction( - dbg_info_->file_, function->getName(), "", - dbg_info_->file_, - 0 /* line number */, - DIFunctionTy, - false, /* internal linkage */ - true, - 0 /* line number */, - llvm::DINode::FlagPrototyped, - true /* isOptimized */); + dbg_info_->file_, function->getName(), "", dbg_info_->file_, 0 /* line number */, + DIFunctionTy, false, /* internal linkage */ + true, 0 /* line number */, llvm::DINode::FlagPrototyped, true /* isOptimized */); #endif CHECK(DIFunction); @@ -236,9 +206,8 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { llvm::Function* f = module_->getFunction(entry_func_name); CHECK(f) << "Function " << entry_func_name << "does not in module"; llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1); - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0, - runtime::symbol::tvm_module_main); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0, runtime::symbol::tvm_module_main); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else @@ -254,8 +223,8 @@ std::unique_ptr CodeGenCPU::Finish() { } return CodeGenLLVM::Finish(); } -llvm::Value* CodeGenCPU::CreateStructRefPtr( - DataType t, llvm::Value* buf, llvm::Value* index, int kind) { +llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, + int kind) { if (kind < intrinsic::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); @@ -280,27 +249,22 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr( return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)}); } case intrinsic::kArrTypeCode: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(3), ConstInt32(0)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)}); } case intrinsic::kArrTypeBits: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(3), ConstInt32(1)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)}); } case intrinsic::kArrTypeLanes: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(3), ConstInt32(2)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)}); } case intrinsic::kArrByteOffset: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)}); } case intrinsic::kArrDeviceId: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(1), ConstInt32(1)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)}); } case intrinsic::kArrDeviceType: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(1), ConstInt32(0)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)}); } case intrinsic::kTVMValueContent: { CHECK_EQ(t.lanes(), 1); @@ -318,7 +282,9 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr( return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()); } } - default: LOG(FATAL) << "unknown field code"; return nullptr; + default: + LOG(FATAL) << "unknown field code"; + return nullptr; } } @@ -331,8 +297,8 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { for (llvm::Value* v : arg_values) { arg_types.push_back(v->getType()); } - llvm::FunctionType* ftype = llvm::FunctionType::get( - GetLLVMType(GetRef(op)), arg_types, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_types, false); // Check if it is available in global function table as injected function. auto it = gv_func_map_.find(op->name); if (it != gv_func_map_.end()) { @@ -349,8 +315,7 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { } else { llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { - f = llvm::Function::Create( - ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); + f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(f); @@ -361,12 +326,9 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { } } -llvm::GlobalVariable* CodeGenCPU::InitContextPtr( - llvm::Type* p_type, std::string name) { +llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) { llvm::GlobalVariable* gv = new llvm::GlobalVariable( - *module_, p_type, false, - llvm::GlobalValue::LinkOnceAnyLinkage, 0, - name); + *module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, 0, name); #if TVM_LLVM_VERSION >= 100 gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(p_type))); #else @@ -384,9 +346,8 @@ llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { #else llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); #endif - faddr->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); + faddr->setMetadata("tbaa", + md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); return faddr; } @@ -399,16 +360,15 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_)); } else { if (!dynamic_lookup) { - gv_tvm_func_call_ = InitContextPtr( - ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall"); - gv_tvm_get_func_from_env_ = InitContextPtr( - ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv"); - gv_tvm_api_set_last_error_ = InitContextPtr( - ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError"); - gv_tvm_parallel_launch_ = InitContextPtr( - ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch"); - gv_tvm_parallel_barrier_ = InitContextPtr( - ftype_tvm_parallel_barrier_->getPointerTo(), "__TVMBackendParallelBarrier"); + gv_tvm_func_call_ = InitContextPtr(ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall"); + gv_tvm_get_func_from_env_ = InitContextPtr(ftype_tvm_get_func_from_env_->getPointerTo(), + "__TVMBackendGetFuncFromEnv"); + gv_tvm_api_set_last_error_ = + InitContextPtr(ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError"); + gv_tvm_parallel_launch_ = + InitContextPtr(ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch"); + gv_tvm_parallel_barrier_ = InitContextPtr(ftype_tvm_parallel_barrier_->getPointerTo(), + "__TVMBackendParallelBarrier"); // Mark as context functions gv_func_map_["TVMBackendAllocWorkspace"] = nullptr; gv_func_map_["TVMBackendFreeWorkspace"] = nullptr; @@ -419,12 +379,9 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) { // create emit codes that checks and load the function. using llvm::BasicBlock; - BasicBlock* fail_block = BasicBlock::Create( - *ctx_, "call_fail", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "call_end", function_); - llvm::Value* succ = builder_->CreateICmpEQ( - retcode, llvm::ConstantInt::get(t_int_, 0)); + BasicBlock* fail_block = BasicBlock::Create(*ctx_, "call_fail", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "call_end", function_); + llvm::Value* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0)); builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); builder_->SetInsertPoint(fail_block); // return the code. @@ -448,20 +405,14 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { arg_values.push_back(value); arg_types.push_back(value->getType()); } - llvm::FunctionType* ftype = - llvm::FunctionType::get(t_int_, arg_types, false); - llvm::Function* fcompute = - llvm::Function::Create(ftype, - llvm::Function::PrivateLinkage, - op->value.as()->value, - module_.get()); - BasicBlock* compute_call_end = CheckCallSuccess( - builder_->CreateCall(fcompute, arg_values)); + llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_types, false); + llvm::Function* fcompute = llvm::Function::Create( + ftype, llvm::Function::PrivateLinkage, op->value.as()->value, module_.get()); + BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values)); // setup compute fuinction. std::unordered_map new_vmap; size_t idx = 0; - for (auto it = fcompute->arg_begin(); - it != fcompute->arg_end(); ++it, ++idx) { + for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); ++it, ++idx) { llvm::Argument* v = &(*it); const Var& var = vargs[idx]; new_vmap[var.get()] = v; @@ -478,7 +429,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } std::swap(function_, fcompute); std::swap(new_vmap, var_map_); - BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_); + BasicBlock* compute_entry = BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(compute_entry); this->VisitStmt(op->body); builder_->CreateRet(ConstInt32(0)); @@ -503,48 +454,41 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array& vfields, uint64_t* nu llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1)); llvm::Value* zero = ConstInt32(0); for (size_t i = 0; i < vfields.size(); ++i) { - builder_->CreateStore( - var_map_.at(vfields[i].get()), - builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); + builder_->CreateStore(var_map_.at(vfields[i].get()), + builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); } *num_bytes = data_layout_->getTypeAllocSize( llvm::cast(cdata->getType())->getElementType()); return cdata; } -void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, - const Array& vfields, +void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, const Array& vfields, std::unordered_map* vmap) { for (size_t i = 0; i < vfields.size(); ++i) { (*vmap)[vfields[i].get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP( - cdata, {ConstInt32(0), ConstInt32(i)})); + builder_->CreateLoad(builder_->CreateInBoundsGEP(cdata, {ConstInt32(0), ConstInt32(i)})); } } void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { using llvm::BasicBlock; // closure data - llvm::Function* f = llvm::Function::Create( - ftype_tvm_parallel_lambda_, - llvm::Function::PrivateLinkage, - "__tvm_parallel_lambda", module_.get()); + llvm::Function* f = + llvm::Function::Create(ftype_tvm_parallel_lambda_, llvm::Function::PrivateLinkage, + "__tvm_parallel_lambda", module_.get()); // allocate and setup the closure, call the closure. Array vfields = tir::UndefinedVars(body, {}); uint64_t nbytes; llvm::Value* cdata = PackClosureData(vfields, &nbytes); #if TVM_LLVM_VERSION >= 90 - auto launch_callee = llvm::FunctionCallee( - ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); + auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); #else auto launch_callee = RuntimeTVMParallelLaunch(); #endif - BasicBlock* par_launch_end = CheckCallSuccess( - builder_->CreateCall( - launch_callee, - {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)})); + BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall( + launch_callee, {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)})); // Setup the closure function. - BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); + BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); llvm::Value* task_id = &(*it++); @@ -558,9 +502,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { par_env.task_id = Var("task_id", DataType::Int(32)); par_env.num_task = Var("num_task", DataType::Int(32)); new_vmap[par_env.task_id.get()] = task_id; - new_vmap[par_env.num_task.get()] = builder_->CreateLoad( - builder_->CreateInBoundsGEP( - penv, {ConstInt32(0), ConstInt32(1)})); + new_vmap[par_env.num_task.get()] = + builder_->CreateLoad(builder_->CreateInBoundsGEP(penv, {ConstInt32(0), ConstInt32(1)})); par_env.penv = penv; std::swap(function_, f); std::swap(parallel_env_, par_env); @@ -571,16 +514,13 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { std::swap(var_map_, new_vmap); std::swap(parallel_env_, par_env); std::swap(function_, f); - CHECK_NE(par_env.parallel_loop_count, 0) - << "Cannot find parallel loop within parallel launch"; + CHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch"; builder_->SetInsertPoint(par_launch_end); } llvm::Value* CodeGenCPU::CreateStaticHandle() { llvm::GlobalVariable* gv = new llvm::GlobalVariable( - *module_, t_void_p_, false, - llvm::GlobalValue::PrivateLinkage, 0, - "__tvm_static_handle"); + *module_, t_void_p_, false, llvm::GlobalValue::PrivateLinkage, 0, "__tvm_static_handle"); #if TVM_LLVM_VERSION >= 100 gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(t_void_p_))); #else @@ -593,26 +533,23 @@ llvm::Value* CodeGenCPU::CreateStaticHandle() { void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) { using llvm::BasicBlock; // closure data - llvm::Function* f = llvm::Function::Create( - ftype_tvm_static_init_callback_, - llvm::Function::PrivateLinkage, - "__tvm_static_init_lambda", module_.get()); + llvm::Function* f = + llvm::Function::Create(ftype_tvm_static_init_callback_, llvm::Function::PrivateLinkage, + "__tvm_static_init_lambda", module_.get()); llvm::Value* gv = CreateStaticHandle(); llvm::Function* finit = module_->getFunction(init_fname); if (finit == nullptr) { - finit = llvm::Function::Create( - ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get()); + finit = llvm::Function::Create(ftype_tvm_static_init_, llvm::Function::ExternalLinkage, + init_fname, module_.get()); } // allocate and setup the closure, call the closure. uint64_t nbytes; Array vfields = tir::UndefinedVars(body, {}); llvm::Value* cdata = PackClosureData(vfields, &nbytes); - BasicBlock* init_end = CheckCallSuccess( - builder_->CreateCall( - finit, - {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)})); + BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( + finit, {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)})); // Setup the closure function. - BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); + BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); @@ -642,9 +579,9 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { if (it == func_handle_map_.end()) { // create global location for the handle // create the function handle - hptr = new llvm::GlobalVariable( - *module_, t_tvm_func_handle_, false, - llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname); + hptr = + new llvm::GlobalVariable(*module_, t_tvm_func_handle_, false, + llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname); #if TVM_LLVM_VERSION >= 100 hptr->setAlignment(llvm::Align(align)); #else @@ -657,42 +594,34 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { } // create emit codes that checks and load the function. BasicBlock* pre_block = builder_->GetInsertBlock(); - BasicBlock* init_block = BasicBlock::Create( - *ctx_, "handle_init", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "handle_init_end", function_); + BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align)); #else llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); #endif - llvm::Value* handle_not_null = builder_->CreateICmpNE( - handle, llvm::Constant::getNullValue(t_tvm_func_handle_)); - builder_->CreateCondBr( - handle_not_null, end_block, init_block, md_very_likely_branch_); + llvm::Value* handle_not_null = + builder_->CreateICmpNE(handle, llvm::Constant::getNullValue(t_tvm_func_handle_)); + builder_->CreateCondBr(handle_not_null, end_block, init_block, md_very_likely_branch_); // Initialize the handle if needed. builder_->SetInsertPoint(init_block); - llvm::Value* out = WithFunctionEntry([&]() { - return builder_->CreateAlloca(t_tvm_func_handle_); - }); + llvm::Value* out = + WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = builder_->CreateAlignedLoad( - gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx = + builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); #else - llvm::LoadInst* ctx = builder_->CreateAlignedLoad( - gv_mod_ctx_, gv_mod_ctx_->getAlignment()); + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif - ctx->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); + ctx->setMetadata("tbaa", + md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); #if TVM_LLVM_VERSION >= 90 - auto env_callee = llvm::FunctionCallee( - ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv()); + auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv()); #else auto env_callee = RuntimeTVMGetFuncFromEnv(); #endif - llvm::Value* retcode = builder_->CreateCall( - env_callee, {ctx, GetConstString(fname), out}); + llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align)); @@ -710,38 +639,33 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { return phi; } -llvm::BasicBlock * -CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, - llvm::Value **ret_tcode, const DataType &r_type, - const int64_t begin, const int64_t end) { +llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array& args, llvm::Value** rvalue, + llvm::Value** ret_tcode, const DataType& r_type, + const int64_t begin, const int64_t end) { using llvm::BasicBlock; std::string func_name = args[0].as()->value; - llvm::Value *handle = GetPackedFuncHandle(func_name); + llvm::Value* handle = GetPackedFuncHandle(func_name); // call the function int64_t nargs = end - begin; CHECK_GE(nargs, 0); - llvm::Value *stack_value = MakeValue(args[1]); - llvm::Value *stack_tcode = MakeValue(args[2]); - llvm::Value *arg_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), - ConstInt32(begin)); - llvm::Value *arg_tcode = - CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); - llvm::Value *ret_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), - ConstInt32(end)); + llvm::Value* stack_value = MakeValue(args[1]); + llvm::Value* stack_tcode = MakeValue(args[2]); + llvm::Value* arg_value = builder_->CreateInBoundsGEP( + builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); + llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + llvm::Value* ret_value = builder_->CreateInBoundsGEP( + builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); #else auto call_callee = RuntimeTVMFuncCall(); #endif - BasicBlock *end_block = CheckCallSuccess(builder_->CreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), - ret_value, *ret_tcode})); + BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall( + call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode})); DataType r_api_type = tir::APIType(r_type); - llvm::Value* load_ptr = builder_->CreatePointerCast( - ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); + llvm::Value* load_ptr = + builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); #if TVM_LLVM_VERSION >= 110 *rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); #else @@ -751,47 +675,44 @@ CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, return end_block; } -llvm::Value *CodeGenCPU::CreateCallPacked(const CallNode *op) { +llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { CHECK_EQ(op->args.size(), 5U); - llvm::Value *rvalue = nullptr; - llvm::Value *ret_tcode = nullptr; - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, - op->args[3].as()->value, + llvm::Value* rvalue = nullptr; + llvm::Value* ret_tcode = nullptr; + MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, op->args[4].as()->value); return rvalue; } -llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) { +llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { using llvm::BasicBlock; CHECK_EQ(op->args.size(), 6U); - llvm::Value *rvalue = nullptr; - llvm::Value *ret_tcode = nullptr; - BasicBlock *end_block = MakeCallPacked( - op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + llvm::Value* rvalue = nullptr; + llvm::Value* ret_tcode = nullptr; + BasicBlock* end_block = + MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); // Get traced value. - llvm::Value *traced_value = MakeValue(op->args[5]); + llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - BasicBlock *update_block = - BasicBlock::Create(*ctx_, "update_block", function_); + BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_); // The continue_block handles case when we need to return original // traced value. - BasicBlock *continue_block = - BasicBlock::Create(*ctx_, "continue_block", function_); + BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_); #if TVM_LLVM_VERSION >= 110 - llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); + llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); #else - llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); + llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); #endif // Check the ret_type_code and create cmp instruction. - llvm::Value *cmp = builder_->CreateICmpNE( - ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); + llvm::Value* cmp = + builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); builder_->CreateCondBr(cmp, update_block, continue_block); builder_->SetInsertPoint(update_block); builder_->CreateBr(continue_block); builder_->SetInsertPoint(continue_block); // The return value depends on from what bb we come from. - llvm::PHINode *phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); + llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); phi_rvalue->addIncoming(rvalue, update_block); phi_rvalue->addIncoming(traced_value, end_block); return phi_rvalue; @@ -823,17 +744,14 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { void CodeGenCPU::AddStartupFunction() { if (export_system_symbols_.size() != 0) { llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false); - function_ = llvm::Function::Create( - ftype, - llvm::Function::InternalLinkage, - "__tvm_module_startup", module_.get()); + function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage, + "__tvm_module_startup", module_.get()); llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(startup_entry); for (const auto& kv : export_system_symbols_) { llvm::Value* name = GetConstString(kv.first); - builder_->CreateCall( - f_tvm_register_system_symbol_, { - name, builder_->CreateBitCast(kv.second, t_void_p_)}); + builder_->CreateCall(f_tvm_register_system_symbol_, + {name, builder_->CreateBitCast(kv.second, t_void_p_)}); } llvm::appendToGlobalCtors(*module_, function_, 65535); builder_->CreateRet(nullptr); @@ -853,9 +771,8 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; - llvm::Value* ref = this->CreateStructRefPtr( - op->dtype, MakeValue(op->args[0]), - MakeValue(op->args[1]), kind); + llvm::Value* ref = + this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == intrinsic::kArrAddr) { return builder_->CreatePointerCast(ref, t_void_p_); } else { @@ -865,13 +782,11 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); - llvm::Value* ref = this->CreateStructRefPtr( - op->args[3].dtype(), MakeValue(op->args[0]), - MakeValue(op->args[1]), kind); + llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), + MakeValue(op->args[1]), kind); CHECK(kind != intrinsic::kArrAddr); if (value->getType()->isPointerTy()) { - value = builder_->CreatePointerCast( - value, ref->getType()->getPointerElementType()); + value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); } builder_->CreateStore(value, ref); return ConstInt32(0); @@ -879,22 +794,22 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; return WithFunctionEntry([&]() -> llvm::AllocaInst* { - const int64_t* pval = as_const_int(op->args[1]); - CHECK(pval) << "require stack alloca to contain constant value"; - llvm::Value* num = ConstInt32(pval[0]); - if (type == "shape") { - return builder_->CreateAlloca(t_tvm_shape_index_, num); - } else if (type == "arg_value") { - return builder_->CreateAlloca(t_tvm_value_, num); - } else if (type == "arg_tcode") { - return builder_->CreateAlloca(t_int_, num); - } else if (type == "array") { - return builder_->CreateAlloca(t_tvm_array_, num); - } else { - LOG(FATAL) << "Unknown stack alloca type " << type; - return nullptr; - } - }); + const int64_t* pval = as_const_int(op->args[1]); + CHECK(pval) << "require stack alloca to contain constant value"; + llvm::Value* num = ConstInt32(pval[0]); + if (type == "shape") { + return builder_->CreateAlloca(t_tvm_shape_index_, num); + } else if (type == "arg_value") { + return builder_->CreateAlloca(t_tvm_value_, num); + } else if (type == "arg_tcode") { + return builder_->CreateAlloca(t_int_, num); + } else if (type == "array") { + return builder_->CreateAlloca(t_tvm_array_, num); + } else { + LOG(FATAL) << "Unknown stack alloca type " << type; + return nullptr; + } + }); } else { return CodeGenLLVM::CreateIntrinsic(op); } @@ -909,16 +824,14 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { os << ", " << op->message.as()->value; } llvm::Value* msg = GetConstString(os.str()); - BasicBlock* fail_block = BasicBlock::Create( - *ctx_, "assert_fail", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "assert_end", function_); + BasicBlock* fail_block = BasicBlock::Create(*ctx_, "assert_fail", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "assert_end", function_); builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_); // fail condition. builder_->SetInsertPoint(fail_block); #if TVM_LLVM_VERSION >= 90 - auto err_callee = llvm::FunctionCallee( - ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); + auto err_callee = + llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); #else auto err_callee = RuntimeTVMAPISetLastError(); #endif @@ -932,7 +845,7 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::coproc_uop_scope) { this->CreateStaticInit(op->value.as()->value, op->body); - } else if (op->attr_key == tir::attr::compute_scope) { + } else if (op->attr_key == tir::attr::compute_scope) { this->CreateComputeScope(op); } else if (tir::attr::IsPragmaKey(op->attr_key)) { if (op->attr_key == "pragma_parallel_stride_pattern") { @@ -943,20 +856,18 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == "pragma_parallel_launch_point") { CreateParallelLaunch(op->body, 0); } else if (op->attr_key == "pragma_parallel_barrier_when_finish") { - CHECK(parallel_env_.penv != nullptr) - << "Cannot run barrier without parallel environment"; + CHECK(parallel_env_.penv != nullptr) << "Cannot run barrier without parallel environment"; CHECK(!parallel_env_.in_parallel_loop) << "Cannot not place within parallel loop as the workload may differ, " << " place it between parallel and parallel_launch_point"; this->VisitStmt(op->body); #if TVM_LLVM_VERSION >= 90 - auto bar_callee = llvm::FunctionCallee( - ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier()); + auto bar_callee = + llvm::FunctionCallee(ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier()); #else auto bar_callee = RuntimeTVMParallelBarrier(); #endif - builder_->CreateCall( - bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv}); + builder_->CreateCall(bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv}); } else if (op->attr_key == tir::attr::pragma_import_llvm) { const StringImmNode* value = op->value.as(); CHECK(value != nullptr); @@ -973,15 +884,13 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { void CodeGenCPU::VisitStmt_(const ForNode* op) { CHECK(is_zero(op->min)); - if (op->for_type == ForType::Serial || - op->for_type == ForType::Unrolled) { + if (op->for_type == ForType::Serial || op->for_type == ForType::Unrolled) { CodeGenLLVM::VisitStmt_(op); } else if (op->for_type == ForType::Parallel) { if (parallel_env_.penv == nullptr) { CreateParallelLaunch( - ForNode::make( - op->loop_var, op->min, op->extent, - op->for_type, op->device_api, op->body), 0); + ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), + 0); } else { // already in parallel env. CHECK(parallel_env_.task_id.defined()); @@ -994,20 +903,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { << "Nested parallel loop is not supported by threadpool, try fuse them instead"; parallel_env_.in_parallel_loop = true; if (parallel_env_.stride_pattern) { - CreateSerialFor(MakeValue(task_id), - MakeValue(op->extent), - MakeValue(num_task), - op->loop_var, - op->body); + CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task), + op->loop_var, op->body); } else { PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; PrimExpr begin = MinNode::make(task_id * step, op->extent); PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); - CreateSerialFor(MakeValue(begin), - MakeValue(end), - llvm::ConstantInt::getSigned(GetLLVMType(end), 1), - op->loop_var, - op->body); + CreateSerialFor(MakeValue(begin), MakeValue(end), + llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); } parallel_env_.in_parallel_loop = false; ++parallel_env_.parallel_loop_count; diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index aa8371c..7a14b8f 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -24,11 +24,12 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_CPU_H_ #define TVM_TARGET_LLVM_CODEGEN_CPU_H_ -#include -#include #include #include #include +#include +#include + #include "codegen_llvm.h" namespace tvm { @@ -37,11 +38,8 @@ namespace codegen { // CPU host code generation class CodeGenCPU : public CodeGenLLVM { public: - void Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup) override; + void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, + bool system_lib, bool dynamic_lookup) override; void AddFunction(const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; std::unique_ptr Finish() override; @@ -95,20 +93,18 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* CreateStaticHandle(); llvm::Value* GetPackedFuncHandle(const std::string& str); - llvm::Value* PackClosureData(const Array& fields, uint64_t *num_bytes); + llvm::Value* PackClosureData(const Array& fields, uint64_t* num_bytes); llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); - void UnpackClosureData(llvm::Value*cdata, - const Array& fields, + void UnpackClosureData(llvm::Value* cdata, const Array& fields, std::unordered_map* vmap); // Make packed call. - llvm::BasicBlock *MakeCallPacked(const Array &args, - llvm::Value **rvalue, - llvm::Value **ret_tcode, const DataType &r_type, + llvm::BasicBlock* MakeCallPacked(const Array& args, llvm::Value** rvalue, + llvm::Value** ret_tcode, const DataType& r_type, const int64_t begin, const int64_t end); // create call into tvm packed function. llvm::Value* CreateCallPacked(const CallNode* op); // Create trace call into tvm packed function. - llvm::Value* CreateCallTracePacked(const CallNode *op); + llvm::Value* CreateCallTracePacked(const CallNode* op); // Create static initialization void CreateStaticInit(const std::string& init_fname, const Stmt& body); // Create parallel launch diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 74bda71..f664532 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -22,20 +22,21 @@ */ #ifdef TVM_LLVM_VERSION // Part of the code are adapted from Halide's CodeGen_LLVM -#include +#include "codegen_llvm.h" + #include +#include #include #include -#include "codegen_llvm.h" -#include "codegen_cpu.h" #include "../../arith/pattern_match.h" #include "../build_common.h" +#include "codegen_cpu.h" namespace tvm { namespace codegen { -std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine *tm) { +std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { std::string target = tm->getTarget().getName(); std::string factory_name = "tvm.codegen.llvm.target_" + target; const PackedFunc* f = runtime::Registry::Get(factory_name); @@ -47,11 +48,8 @@ std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine *tm) { } } -void CodeGenLLVM::Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup) { +void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, + llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) { InitializeLLVM(); ctx_ = ctx; builder_.reset(new IRBuilder(*ctx_)); @@ -68,7 +66,7 @@ void CodeGenLLVM::Init(const std::string& module_name, t_int64_ = llvm::Type::getInt64Ty(*ctx_); t_float64_ = llvm::Type::getDoubleTy(*ctx_); // meta data - md_very_likely_branch_ = md_builder_->createBranchWeights(1<<20, 1); + md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1); md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa"); md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_); this->InitTarget(tm); @@ -96,9 +94,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { } } -void CodeGenLLVM::AddFunction(const PrimFunc& f) { - this->AddFunctionInternal(f, false); -} +void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); } void CodeGenLLVM::InitFuncState() { var_map_.clear(); @@ -108,7 +104,6 @@ void CodeGenLLVM::InitFuncState() { analyzer_.reset(new arith::Analyzer()); } - void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { this->InitFuncState(); @@ -126,8 +121,8 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { // TODO(tvm-team): // Update the function type to respect the ret_type field of f. // Once we allow more flexibility in the PrimFunc. - llvm::FunctionType* ftype = llvm::FunctionType::get( - ret_void ? t_void_ : t_int_, param_types, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) @@ -135,9 +130,8 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { CHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) << "Function " << global_symbol << " already exist in module"; - function_ = llvm::Function::Create( - ftype, llvm::Function::ExternalLinkage, - global_symbol.value().operator std::string(), module_.get()); + function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + global_symbol.value().operator std::string(), module_.get()); function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); @@ -169,7 +163,6 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } - std::unique_ptr CodeGenLLVM::Finish() { this->AddStartupFunction(); for (size_t i = 0; i < link_modules_.size(); ++i) { @@ -182,13 +175,11 @@ std::unique_ptr CodeGenLLVM::Finish() { return std::move(module_); } - void CodeGenLLVM::HandleImport(const std::string& code) { std::unique_ptr mlib; llvm::SMDiagnostic err; if (code.length() >= 3 && - (code.substr(code.length() - 3) == ".ll" || - code.substr(code.length() - 3) == ".bc")) { + (code.substr(code.length() - 3) == ".ll" || code.substr(code.length() - 3) == ".bc")) { mlib = llvm::parseIRFile(code, err, *ctx_); if (mlib.get() == nullptr) { std::string msg = std::string(err.getMessage()); @@ -196,20 +187,19 @@ void CodeGenLLVM::HandleImport(const std::string& code) { << "line " << err.getLineNo() << ":" << msg; } } else { - std::unique_ptr buf = - llvm::MemoryBuffer::getMemBuffer(code); + std::unique_ptr buf = llvm::MemoryBuffer::getMemBuffer(code); mlib = llvm::parseIR(*buf, err, *ctx_); if (mlib.get() == nullptr) { std::string msg = std::string(err.getMessage()); LOG(FATAL) << "Fail to load llvm ir " - << "line " << err.getLineNo() << ":" << msg - << "\ncontent:\n" << code; + << "line " << err.getLineNo() << ":" << msg << "\ncontent:\n" + << code; } } mlib->setTargetTriple(target_machine_->getTargetTriple().str()); mlib->setDataLayout(target_machine_->createDataLayout()); // mark all the functions as force inline - for (llvm::Function &f : mlib->functions()) { + for (llvm::Function& f : mlib->functions()) { f.removeFnAttr(llvm::Attribute::NoInline); f.addFnAttr(llvm::Attribute::AlwaysInline); f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage); @@ -238,35 +228,27 @@ llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) { class FPassManager : public llvm::legacy::FunctionPassManager { public: - explicit FPassManager(llvm::Module* m) - : llvm::legacy::FunctionPassManager(m) {} + explicit FPassManager(llvm::Module* m) : llvm::legacy::FunctionPassManager(m) {} // override add to allow messaging - void add(llvm::Pass* p) final { - llvm::legacy::FunctionPassManager::add(p); - } + void add(llvm::Pass* p) final { llvm::legacy::FunctionPassManager::add(p); } }; class MPassManager : public llvm::legacy::PassManager { public: // override add to allow messaging - void add(llvm::Pass* p) final { - llvm::legacy::PassManager::add(p); - } + void add(llvm::Pass* p) final { llvm::legacy::PassManager::add(p); } }; -void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) { -} +void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {} void CodeGenLLVM::Optimize() { // pass manager FPassManager fpass(module_.get()); MPassManager mpass; mpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : - llvm::TargetIRAnalysis())); + target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); fpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : - llvm::TargetIRAnalysis())); + target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); // place optimization pass llvm::PassManagerBuilder builder; @@ -300,9 +282,7 @@ int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) co return native_vector_bits_; } -unsigned CodeGenLLVM::GetGlobalAddressSpace() const { - return 0; -} +unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; } llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { if (dtype.is_handle()) { @@ -317,10 +297,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { etype = llvm::Type::getIntNTy(*ctx_, dtype.bits()); } else if (dtype.is_float()) { switch (dtype.bits()) { - case 16: etype = llvm::Type::getHalfTy(*ctx_); break; - case 32: etype = llvm::Type::getFloatTy(*ctx_); break; - case 64: etype = llvm::Type::getDoubleTy(*ctx_); break; - default: LOG(FATAL) << "do not support " << dtype; + case 16: + etype = llvm::Type::getHalfTy(*ctx_); + break; + case 32: + etype = llvm::Type::getFloatTy(*ctx_); + break; + case 64: + etype = llvm::Type::getDoubleTy(*ctx_); + break; + default: + LOG(FATAL) << "do not support " << dtype; } } if (dtype.lanes() != 1) { @@ -355,16 +342,12 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const { // // This trick comes from Halide's CodeGen_LLVM // -void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, - const VarNode* buffer, - PrimExpr index, +void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, PrimExpr index, DataType type) { if (alias_var_set_.count(buffer) != 0) { // Mark all possibly aliased pointer as same type. llvm::MDNode* meta = md_tbaa_alias_set_; - inst->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(meta, meta, 0)); + inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); return; } @@ -405,16 +388,11 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta); } } - inst->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(meta, meta, 0)); + inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); } -void CodeGenLLVM::GetAlignment(DataType t, - const VarNode* buf_var, - const PrimExpr& index, - int* p_alignment, - int* p_native_bits) { +void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, + int* p_alignment, int* p_native_bits) { int max_align_bits = t.bits(); auto it = alloc_storage_info_.find(buf_var); if (it != alloc_storage_info_.end()) { @@ -430,11 +408,9 @@ void CodeGenLLVM::GetAlignment(DataType t, int64_t coeff = me->coeff; int align_bits = t.bits(); - while (align_bits < max_align_bits && - base % 2 == 0 && - coeff % 2 == 0) { - base = base / 2; - coeff = coeff / 2; + while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) { + base = base / 2; + coeff = coeff / 2; align_bits *= 2; } if (align_bits < 8) { @@ -443,8 +419,7 @@ void CodeGenLLVM::GetAlignment(DataType t, *p_alignment = align_bits / 8; } -std::unique_ptr -CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { +std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { #if TVM_LLVM_VERSION >= 100 auto debug_info = std::make_unique(); debug_info->di_builder_ = std::make_unique(*module); @@ -463,8 +438,7 @@ CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { } llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { - llvm::Constant* undef = llvm::UndefValue::get( - llvm::VectorType::get(value->getType(), lanes)); + llvm::Constant* undef = llvm::UndefValue::get(llvm::VectorType::get(value->getType(), lanes)); llvm::Constant* zero = ConstInt32(0); value = builder_->CreateInsertElement(undef, value, zero); #if TVM_LLVM_VERSION >= 110 @@ -506,8 +480,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { } llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { - llvm::Value* mask = llvm::UndefValue::get( - DTypeToLLVMType(DataType::Int(32, target_lanes))); + llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes))); int num_elems = llvm::cast(vec->getType())->getNumElements(); if (num_elems == target_lanes) return vec; CHECK_LT(num_elems, target_lanes); @@ -558,28 +531,21 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { return CreateVecSlice(vecs[0], 0, total_lanes); } - -void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, - llvm::Value* end, - llvm::Value* stride, - const Var& loop_var, - const Stmt& body) { +void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, + const Var& loop_var, const Stmt& body) { using llvm::BasicBlock; BasicBlock* pre_block = builder_->GetInsertBlock(); - BasicBlock* for_begin = BasicBlock::Create( - *ctx_, "for_begin", function_); - BasicBlock* for_body = BasicBlock::Create( - *ctx_, "for_body", function_); - BasicBlock* for_end = BasicBlock::Create( - *ctx_, "for_end", function_); + BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_); + BasicBlock* for_body = BasicBlock::Create(*ctx_, "for_body", function_); + BasicBlock* for_end = BasicBlock::Create(*ctx_, "for_end", function_); builder_->CreateBr(for_begin); builder_->SetInsertPoint(for_begin); llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2); loop_value->addIncoming(begin, pre_block); CHECK(!var_map_.count(loop_var.get())); var_map_[loop_var.get()] = loop_value; - builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), - for_body, for_end, md_very_likely_branch_); + builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end, + md_very_likely_branch_); builder_->SetInsertPoint(for_body); this->VisitStmt(body); var_map_.erase(loop_var.get()); @@ -591,7 +557,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, // cast operatpr llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { - llvm::Type * target = DTypeToLLVMType(to); + llvm::Type* target = DTypeToLLVMType(to); if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); @@ -628,8 +594,8 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { auto it = str_map_.find(str); if (it != str_map_.end()) return it->second; llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str"); + llvm::GlobalVariable* global = + new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str"); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else @@ -638,14 +604,12 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str)); llvm::Constant* zero = ConstInt32(0); llvm::Constant* indices[] = {zero, zero}; - llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr( - type, global, indices); + llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(type, global, indices); str_map_[str] = ptr; return ptr; } -llvm::Value* CodeGenLLVM::CreateBufferPtr( - DataType t, llvm::Value* buffer, llvm::Value* index) { +llvm::Value* CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index) { CHECK_EQ(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); @@ -657,13 +621,11 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr( return builder_->CreateInBoundsGEP(buffer, index); } -llvm::Value* CodeGenLLVM::CreateBufferVecPtr( - DataType t, llvm::Value* buffer, llvm::Value* index) { +llvm::Value* CodeGenLLVM::CreateBufferVecPtr(DataType t, llvm::Value* buffer, llvm::Value* index) { CHECK_GT(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); - llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo( - btype->getAddressSpace()); + llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace()); if (btype != ptype) { buffer = builder_->CreatePointerCast(buffer, ptype); } @@ -683,21 +645,18 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { arg_value.push_back(MakeValue(op->args[i])); arg_type.push_back(arg_value.back()->getType()); } - llvm::FunctionType* ftype = llvm::FunctionType::get( - GetLLVMType(GetRef(op)), arg_type, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_type, false); llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { - f = llvm::Function::Create( - ftype, llvm::Function::ExternalLinkage, - op->name, module_.get()); + f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); } llvm::CallInst* call = builder_->CreateCall(f, arg_value); return call; } -llvm::Function* CodeGenLLVM::GetIntrinsicDecl( - llvm::Intrinsic::ID id, llvm::Type* ret_type, - llvm::ArrayRef arg_types) { +llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, + llvm::ArrayRef arg_types) { llvm::Module* module = module_.get(); if (!llvm::Intrinsic::isOverloaded(id)) { @@ -712,8 +671,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl( auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) { overload_types.clear(); llvm::ArrayRef ref(infos); - auto match = - llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); + auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) { bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref); if (error) { @@ -748,7 +706,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl( // Failed to identify the type. return nullptr; -#else // TVM_LLVM_VERSION +#else // TVM_LLVM_VERSION llvm::ArrayRef ref(infos); // matchIntrinsicType returns true on error. if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) { @@ -766,9 +724,8 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl( llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); - llvm::Intrinsic::ID id = static_cast( - Downcast(op->args[0])->value); - int64_t num_signature = Downcast(op->args[1])->value; + llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); + int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; std::vector arg_type; for (size_t i = 2; i < op->args.size(); ++i) { @@ -784,9 +741,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { // mismatch will have to be treated specially here. // TODO(kparzysz-quic): fix this once TVM prefetch uses the same // type as LLVM. - llvm::Type *return_type = (id != llvm::Intrinsic::prefetch) - ? GetLLVMType(GetRef(op)) - : llvm::Type::getVoidTy(*ctx_); + llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef(op)) + : llvm::Type::getVoidTy(*ctx_); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " @@ -811,22 +767,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return CreateStorageSync(op); } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); - const RampNode *r = l->index.as(); + const RampNode* r = l->index.as(); llvm::Value* ptr; unsigned addrspace; if (!r) { - ptr = CreateBufferPtr( - l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); - addrspace = llvm::dyn_cast( - ptr->getType())->getAddressSpace(); + ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); + addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); } else { - PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); - ptr = CreateBufferVecPtr( - l->dtype, MakeValue(l->buffer_var), MakeValue(index)); - addrspace = llvm::dyn_cast( - ptr->getType())->getAddressSpace(); + PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); + ptr = CreateBufferVecPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); + addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); } return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace)); } else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) { @@ -840,15 +792,11 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { uint64_t val = (high << 32U) | low; return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { - CHECK_EQ(op->args[0].dtype().lanes(), 1) - << "if_then_else can only take scalar condition"; + CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; using llvm::BasicBlock; - BasicBlock* then_block = BasicBlock::Create( - *ctx_, "if_then", function_); - BasicBlock* else_block = BasicBlock::Create( - *ctx_, "if_else", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "if_end", function_); + BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_); + BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_); builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block); builder_->SetInsertPoint(then_block); llvm::Value* then_value = MakeValue(op->args[1]); @@ -864,23 +812,23 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(else_value, else_value_block); return value; } else if (op->is_intrinsic(CallNode::reinterpret)) { - llvm::Type * target = DTypeToLLVMType(op->dtype); + llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); } else if (op->is_intrinsic(CallNode::isnan)) { // TODO(hgt312): set fast math flag llvm::Value* a = MakeValue(op->args[0]); return builder_->CreateFCmpUNO(a, a); } else if (op->is_intrinsic("vectorlow")) { - llvm::Value *v = MakeValue(op->args[0]); + llvm::Value* v = MakeValue(op->args[0]); int l = llvm::cast(v->getType())->getNumElements(); - return CreateVecSlice(v, 0, l/2); + return CreateVecSlice(v, 0, l / 2); } else if (op->is_intrinsic("vectorhigh")) { - llvm::Value *v = MakeValue(op->args[0]); + llvm::Value* v = MakeValue(op->args[0]); int l = llvm::cast(v->getType())->getNumElements(); - return CreateVecSlice(v, l/2, l/2); + return CreateVecSlice(v, l / 2, l / 2); } else if (op->is_intrinsic("vectorcombine")) { - llvm::Value *v0 = MakeValue(op->args[0]); - llvm::Value *v1 = MakeValue(op->args[1]); + llvm::Value* v0 = MakeValue(op->args[0]); + llvm::Value* v1 = MakeValue(op->args[1]); int num_elems = llvm::cast(v0->getType())->getNumElements() * 2; #if TVM_LLVM_VERSION >= 110 std::vector indices; @@ -897,8 +845,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } } -void CodeGenLLVM::Scalarize(const PrimExpr& e, - std::function f) { +void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { for (int i = 0; i < ramp->dtype.lanes(); ++i) { PrimExpr offset = ramp->base + (ramp->stride * i); @@ -912,11 +859,8 @@ void CodeGenLLVM::Scalarize(const PrimExpr& e, } } - // Visitors -llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { - return GetVarValue(op); -} +llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); } llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) { return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value)); @@ -929,52 +873,48 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { - return GetConstString(op->value); -} - -#define DEFINE_CODEGEN_BINARY_OP(Op) \ - llvm::Value* CodeGenLLVM::Create ## Op( \ - DataType t, llvm::Value* a, llvm::Value *b) { \ - if (t.is_int()) { \ - if (t.bits() >= 32) { \ - return builder_->CreateNSW ## Op (a, b); \ - } else { \ - return builder_->Create ## Op (a, b); \ - } \ - } else if (t.is_uint()) { \ - if (t.bits() >= 32) { \ - return builder_->CreateNUW ## Op (a, b); \ - } else { \ - return builder_->Create ## Op (a, b); \ - } \ - } else { \ - CHECK(t.is_float()); \ - return builder_->CreateF ## Op (a, b); \ - } \ - } \ - llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \ - return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ +llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); } + +#define DEFINE_CODEGEN_BINARY_OP(Op) \ + llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ + if (t.is_int()) { \ + if (t.bits() >= 32) { \ + return builder_->CreateNSW##Op(a, b); \ + } else { \ + return builder_->Create##Op(a, b); \ + } \ + } else if (t.is_uint()) { \ + if (t.bits() >= 32) { \ + return builder_->CreateNUW##Op(a, b); \ + } else { \ + return builder_->Create##Op(a, b); \ + } \ + } else { \ + CHECK(t.is_float()); \ + return builder_->CreateF##Op(a, b); \ + } \ + } \ + llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ + return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ } DEFINE_CODEGEN_BINARY_OP(Add); DEFINE_CODEGEN_BINARY_OP(Sub); DEFINE_CODEGEN_BINARY_OP(Mul); -#define DEFINE_CODEGEN_CMP_OP(Op) \ - llvm::Value* CodeGenLLVM::Create ## Op( \ - DataType t, llvm::Value* a, llvm::Value* b) { \ - if (t.is_int()) { \ - return builder_->CreateICmpS ## Op (a, b); \ - } else if (t.is_uint()) { \ - return builder_->CreateICmpU ## Op (a, b); \ - } else { \ - CHECK(t.is_float()); \ - return builder_->CreateFCmpO ## Op (a, b); \ - } \ -} \ - llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \ - return Create ## Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \ +#define DEFINE_CODEGEN_CMP_OP(Op) \ + llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ + if (t.is_int()) { \ + return builder_->CreateICmpS##Op(a, b); \ + } else if (t.is_uint()) { \ + return builder_->CreateICmpU##Op(a, b); \ + } else { \ + CHECK(t.is_float()); \ + return builder_->CreateFCmpO##Op(a, b); \ + } \ + } \ + llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ + return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \ } DEFINE_CODEGEN_CMP_OP(LT); @@ -1053,10 +993,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) { - return builder_->CreateSelect( - MakeValue(op->condition), - MakeValue(op->true_value), - MakeValue(op->false_value)); + return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value), + MakeValue(op->false_value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { @@ -1077,8 +1015,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); llvm::Value* ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = - builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); #else llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); #endif @@ -1086,20 +1023,17 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { return load; } else { // vector load - unsigned addrspace = llvm::dyn_cast( - buffer->getType())->getAddressSpace(); + unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); CHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr( - t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast( - ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad( - ptr, llvm::Align(alignment), is_volatile); + llvm::LoadInst* load = + builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); #else llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); #endif @@ -1114,11 +1048,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { auto f = [&](int i, llvm::Value* index) { llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad( - ptr, llvm::Align(basic_align), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(basic_align), is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad( - ptr, basic_align, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile); #endif ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); AddAliasInfo(load, op->buffer_var.get(), PrimExpr(), t); @@ -1128,16 +1060,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { return CreateIntrinsic(op); - } else if (op->call_type == CallNode::Extern || - op->call_type == CallNode::PureExtern) { + } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { return CreateCallExtern(op); } else { - LOG(FATAL) << "Unknown call type " << - "name= " << op->name << - " call_type= " << op->call_type; + LOG(FATAL) << "Unknown call type " + << "name= " << op->name << " call_type= " << op->call_type; return nullptr; } } @@ -1146,14 +1075,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); for (int i = 0; i < op->lanes; ++i) { vec = builder_->CreateInsertElement( - vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), - ConstInt32(i)); + vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i)); } return vec; } llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { - std::vector vecs(op->vectors.size()); + std::vector vecs(op->vectors.size()); int total_lanes = 0; for (int i = 0, e = op->vectors.size(); i < e; ++i) { vecs[i] = VisitExpr(op->vectors[i]); @@ -1162,9 +1090,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { llvm::Value* v0 = CreateVecConcat(vecs); std::vector idx(op->indices.size()); for (int i = 0, e = op->indices.size(); i < e; ++i) { - const int64_t *val = as_const_int(op->indices[i]); - CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, " - << "but get " << op->indices[i] << "\n"; + const int64_t* val = as_const_int(op->indices[i]); + CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, " + << "but get " << op->indices[i] << "\n"; idx[i] = *val; } llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx); @@ -1198,15 +1126,13 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { return; } else { // vector store - unsigned addrspace = llvm::dyn_cast( - buffer->getType())->getAddressSpace(); + unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); CHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr( - t.element_of(), buffer, MakeValue(ramp->base)); + llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = @@ -1226,12 +1152,10 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), - ptr, llvm::Align(basic_align), is_volatile); + builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), - ptr, basic_align, is_volatile); + llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), + ptr, basic_align, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), PrimExpr(), op->value.dtype()); }; @@ -1248,21 +1172,16 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { CHECK(op->for_type == ForType::Serial); } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), - op->loop_var, op->body); + llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); } - void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { using llvm::BasicBlock; llvm::Value* cond = MakeValue(op->condition); - BasicBlock* then_block = BasicBlock::Create( - *ctx_, "if_then", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "if_end", function_); + BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_); if (op->else_case.defined()) { - BasicBlock* else_block = BasicBlock::Create( - *ctx_, "if_else", function_); + BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_); builder_->CreateCondBr(cond, then_block, else_block); builder_->SetInsertPoint(then_block); this->VisitStmt(op->then_case); @@ -1279,39 +1198,35 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { builder_->SetInsertPoint(end_block); } - void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } - // maximum necessary alignment in the NV devices - if (info.alignment > 16) { - info.alignment = 16; - } - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + // maximum necessary alignment in the NV devices + if (info.alignment > 16) { + info.alignment = 16; + } + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif - } - info.alignment = alloca->getAlignment(); - buf = alloca; + } + info.alignment = alloca->getAlignment(); + buf = alloca; buf = builder_->CreatePointerCast( - buf, DTypeToLLVMType(op->dtype)->getPointerTo( - buf->getType()->getPointerAddressSpace())); + buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -1334,8 +1249,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); CHECK(v); - alloc_storage_info_[v].alignment = - static_cast(op->value.as()->value); + alloc_storage_info_[v].alignment = static_cast(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); @@ -1367,9 +1281,7 @@ void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { } } -void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { - MakeValue(op->value); -} +void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e851f37..4522c15 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -25,27 +25,27 @@ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #ifdef TVM_LLVM_VERSION +#include #include #include -#include +#include #include -#include -#include #include +#include +#include #include -#include - #include -#include -#include #include #include #include -#include "llvm_common.h" -#include "../../runtime/thread_storage_scope.h" +#include +#include + #include "../../arith/compute_expr.h" +#include "../../runtime/thread_storage_scope.h" #include "../../tir/transforms/ir_util.h" +#include "llvm_common.h" namespace tvm { namespace codegen { @@ -55,9 +55,8 @@ using namespace tir; /*! * \brief A base class to generate a LLVM. */ -class CodeGenLLVM : - public ExprFunctor, - public StmtFunctor { +class CodeGenLLVM : public ExprFunctor, + public StmtFunctor { public: /*! * \brief Create new code generator based on target machine. @@ -74,11 +73,8 @@ class CodeGenLLVM : * \param dynamic_lookup Whether dynamically lookup runtime function * or use the runtime function table passed by caller. */ - virtual void Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup); + virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, + bool system_lib, bool dynamic_lookup); /*! * \brief Compile and add function f to the current module. * \param f The function to be added. @@ -104,9 +100,7 @@ class CodeGenLLVM : * \param e The expression to be created value for. * \return created value. */ - llvm::Value* MakeValue(const PrimExpr& e) { - return VisitExpr(e); - } + llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); } // Short hande code to get a constant int 32 llvm::Constant* ConstInt32(int64_t value) const { return llvm::ConstantInt::getSigned(t_int32_, value); @@ -170,7 +164,7 @@ class CodeGenLLVM : * \tparam F The function to be executed. * \return The result. */ - template + template llvm::AllocaInst* WithFunctionEntry(F falloca) { llvm::BasicBlock* current = builder_->GetInsertBlock(); llvm::BasicBlock* entry = &(function_->getEntryBlock()); @@ -191,8 +185,7 @@ class CodeGenLLVM : virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder); // Scalarize by iterating elements of e. // f is a callback that takes index and v. - virtual void Scalarize(const PrimExpr& e, - std::function f); + virtual void Scalarize(const PrimExpr& e, std::function f); // Initialize target virtual void InitTarget(llvm::TargetMachine* tm); // Add module startup function if needed. @@ -205,8 +198,7 @@ class CodeGenLLVM : virtual unsigned GetGlobalAddressSpace() const; void AddFunctionInternal(const PrimFunc& f, bool ret_void); // Create extern call - llvm::CallInst* CreateCallExtern(llvm::Type* ret, - const std::string& name, + llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name, const std::vector& value); /*! * \brief Get the LLVM Type for a given runtime type. @@ -243,20 +235,18 @@ class CodeGenLLVM : * could not be generated (e.g. if the argument/return types do not * match). */ - llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, - llvm::Type* ret_type, + llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, llvm::ArrayRef arg_types); // initialize the function state. void InitFuncState(); // Get alignment given index. - void GetAlignment( - DataType t, const VarNode* buf_var, const PrimExpr& index, - int* p_alignment, int* p_native_bits); + void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, + int* p_native_bits); // Get constant string llvm::Value* GetConstString(const std::string& str); // do a scalarize call with f - llvm::Value* CreateScalarizedCall( - const CallNode* op, llvm::Function* f, const std::vector& args); + llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, + const std::vector& args); // handle module import void HandleImport(const std::string& code); // cast operatpr @@ -279,9 +269,7 @@ class CodeGenLLVM : llvm::Value* CreateVecConcat(std::vector vecs); llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes); // Create serial for - void CreateSerialFor(llvm::Value* begin, - llvm::Value* end, - llvm::Value* stride, + void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, const Var& loop_var, const Stmt& body); // add alias information. void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index, DataType type); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 40dc653..a0687b9 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -24,9 +24,10 @@ #ifdef TVM_LLVM_VERSION #include -#include "codegen_llvm.h" -#include "../build_common.h" + #include "../../runtime/cuda/cuda_module.h" +#include "../build_common.h" +#include "codegen_llvm.h" namespace tvm { namespace codegen { @@ -39,10 +40,9 @@ class CodeGenNVPTX : public CodeGenLLVM { CodeGenLLVM::AddFunctionInternal(f, true); // annotate as kernel function module_->getOrInsertNamedMetadata("nvvm.annotations") - ->addOperand(llvm::MDNode::get(*ctx_, { - llvm::ValueAsMetadata::get(function_), - llvm::MDString::get(*ctx_, "kernel"), - llvm::ValueAsMetadata::get(ConstInt32(1)) })); + ->addOperand(llvm::MDNode::get( + *ctx_, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx_, "kernel"), + llvm::ValueAsMetadata::get(ConstInt32(1))})); } void VisitStmt_(const AllocateNode* op) final { @@ -50,8 +50,7 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::Value* buf = nullptr; int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); @@ -65,9 +64,8 @@ class CodeGenNVPTX : public CodeGenLLVM { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 alloca->setAlignment(llvm::Align(info.alignment)); @@ -81,12 +79,11 @@ class CodeGenNVPTX : public CodeGenLLVM { << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get( - DTypeToLLVMType(op->dtype), constant_size); + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", - nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(info.alignment)); #else @@ -96,8 +93,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } buf = builder_->CreatePointerCast( - buf, DTypeToLLVMType(op->dtype)->getPointerTo( - buf->getType()->getPointerAddressSpace())); + buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -109,18 +105,32 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; if (ts.rank == 1) { switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; break; - case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; break; - case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break; - default: LOG(FATAL) << "unknown thread idx"; + case 0: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; + break; + default: + LOG(FATAL) << "unknown thread idx"; } } else { CHECK_EQ(ts.rank, 0); switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break; - case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; break; - case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break; - default: LOG(FATAL) << "unknown thread idx"; + case 0: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; + break; + default: + LOG(FATAL) << "unknown thread idx"; } } llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); @@ -133,9 +143,8 @@ class CodeGenNVPTX : public CodeGenLLVM { // TODO(tqchen) warp sync in CUDA9 return nullptr; } else if (sync == "shared") { - llvm::Function* f = llvm::Intrinsic::getDeclaration( - module_.get(), - ::llvm::Intrinsic::nvvm_barrier0); + llvm::Function* f = + llvm::Intrinsic::getDeclaration(module_.get(), ::llvm::Intrinsic::nvvm_barrier0); return builder_->CreateCall(f, {}); } else { LOG(FATAL) << "Do not support sync " << sync; @@ -174,11 +183,9 @@ inline int DetectCUDAComputeVersion() { tvm_ctx.device_type = kDLGPU; tvm_ctx.device_id = 0; TVMRetValue val; - tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr( - tvm_ctx, tvm::runtime::kExist, &val); + tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kExist, &val); if (val.operator int() == 1) { - tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr( - tvm_ctx, tvm::runtime::kComputeVersion, &val); + tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val); std::string version = val; std::istringstream is(version); double ver; @@ -191,12 +198,10 @@ inline int DetectCUDAComputeVersion() { runtime::Module BuildNVPTX(IRModule mod, std::string target) { InitializeLLVM(); - CHECK(target.length() >= 5 && - target.substr(0, 5) == "nvptx"); + CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx"); int compute_ver = DetectCUDAComputeVersion(); std::ostringstream config; - config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" - << compute_ver + config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver << target.substr(5, target.length() - 5); std::unique_ptr tm = GetLLVMTargetMachine(config.str()); std::unique_ptr cg(new CodeGenNVPTX()); @@ -204,15 +209,13 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) { cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); cg->AddFunction(f); } - const auto* flibdevice_path = - tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); + const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); if (flibdevice_path != nullptr) { std::string path = (*flibdevice_path)(compute_ver); if (path.length() != 0) { @@ -239,16 +242,14 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) { // emit ptx llvm::legacy::PassManager pass; #if TVM_LLVM_VERSION <= 60 - CHECK(tm->addPassesToEmitFile( - pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm->addPassesToEmitFile( - pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) << "Cannot emit target CGFT_ObjectFile"; #else - CHECK(tm->addPassesToEmitFile( - pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*module); @@ -256,8 +257,7 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) { return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll); } -TVM_REGISTER_GLOBAL("target.build.nvptx") -.set_body_typed(BuildNVPTX); +TVM_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 570bb0d..d0038b8 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -24,8 +24,8 @@ #ifdef TVM_LLVM_VERSION #include -#include "codegen_cpu.h" +#include "codegen_cpu.h" #include "llvm/MC/MCSubtargetInfo.h" namespace tvm { @@ -89,14 +89,12 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, DTypeToLLVMType(DataType::Float(32, from.lanes())), { - MakeValue(tir::CallNode::make( - DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, - tir::CallNode::PureIntrinsic)), - MakeValue( - tir::BroadcastNode::make( - FloatImm(DataType::Float(32), 0), from.lanes())), - /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), - /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), + MakeValue(tir::CallNode::make(DataType::Int(16, from.lanes()), + tir::CallNode::reinterpret, {op->value}, + tir::CallNode::PureIntrinsic)), + MakeValue(tir::BroadcastNode::make(FloatImm(DataType::Float(32), 0), from.lanes())), + /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), + /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); } @@ -105,12 +103,11 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto has_f16c = TargetHasFeature(*target_machine_, "f16c"); if (from.lanes() >= 8 && has_f16c) { - return CallVectorIntrin( - ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, - DTypeToLLVMType(DataType::Float(32, from.lanes())), - {MakeValue(tir::CallNode::make( - DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, - tir::CallNode::PureIntrinsic))}); + return CallVectorIntrin(::llvm::Intrinsic::x86_vcvtph2ps_256, 8, + DTypeToLLVMType(DataType::Float(32, from.lanes())), + {MakeValue(tir::CallNode::make( + DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, + {op->value}, tir::CallNode::PureIntrinsic))}); } #endif } @@ -150,10 +147,10 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - CodeGenLLVM* cg = new CodeGenX86_64(); - *rv = static_cast(cg); - }); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + CodeGenLLVM* cg = new CodeGenX86_64(); + *rv = static_cast(cg); + }); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 58bfb37..d0bef46 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -22,153 +22,148 @@ */ #ifdef TVM_LLVM_VERSION -#include #include "intrin_rule_llvm.h" +#include + namespace tvm { namespace codegen { namespace llvm { TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch") -.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); + .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr ln10 = make_const(x.dtype(), 2.302585093); - PrimExpr ret = tir::CallNode::make( - x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); - *rv = ret; -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr ln10 = make_const(x.dtype(), 2.302585093); + PrimExpr ret = + tir::CallNode::make(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); + *rv = ret; + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1); - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_two = make_const(x.dtype(), -2); - - PrimExpr exp_neg2x = tir::CallNode::make( - x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_pos2x = tir::CallNode::make( - x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); - - PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); - PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - *rv = tir::SelectNode::make( - x >= make_zero(x.dtype()), tanh_pos, tanh_neg); -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1); + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_two = make_const(x.dtype(), -2); + + PrimExpr exp_neg2x = + tir::CallNode::make(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_pos2x = + tir::CallNode::make(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); + + PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); + PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); + *rv = tir::SelectNode::make(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); CHECK(call != nullptr); const PrimExpr& x = call->args[0]; - PrimExpr sin_x = tir::CallNode::make( - x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); - PrimExpr cos_x = tir::CallNode::make( - x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); + PrimExpr sin_x = tir::CallNode::make(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); + PrimExpr cos_x = tir::CallNode::make(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); PrimExpr tan_x = sin_x / cos_x; *rv = tan_x; }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::CallNode::make( - x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::CallNode::make( - x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); - PrimExpr ret = (exp_posx + exp_negx) / two; - *rv = ret; -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_one = make_const(x.dtype(), -1); + PrimExpr exp_negx = + tir::CallNode::make(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::CallNode::make(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr ret = (exp_posx + exp_negx) / two; + *rv = ret; + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::CallNode::make( - x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::CallNode::make( - x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); - PrimExpr ret = (exp_posx - exp_negx) / two; - *rv = ret; -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_one = make_const(x.dtype(), -1); + PrimExpr exp_negx = + tir::CallNode::make(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::CallNode::make(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr ret = (exp_posx - exp_negx) / two; + *rv = ret; + }); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index bb9ff66..8c5053b 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -25,17 +25,18 @@ #define TVM_TARGET_LLVM_INTRIN_RULE_LLVM_H_ #ifdef TVM_LLVM_VERSION -#include #include - #include +#include + #include + #include "llvm_common.h" namespace tvm { namespace codegen { // num_signature means number of arguments used to query signature -template +template inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); @@ -48,11 +49,10 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make( - call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::CallNode::make(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); } -template +template inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); @@ -64,8 +64,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make( - call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); + *rv = tir::CallNode::make(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 0dc1272..ffe35ca 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -22,9 +22,9 @@ */ #ifdef TVM_LLVM_VERSION -#include -#include #include +#include + #include namespace tvm { @@ -39,77 +39,54 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { std::ostringstream intrinsic_name; intrinsic_name << "__nv_" << call->name; if (call->dtype.bits() == 32) intrinsic_name << "f"; - *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, - CallNode::PureExtern); + *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); } namespace llvm { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchExternLibDevice); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 3699c9f..52447a1 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -22,9 +22,8 @@ */ #ifdef TVM_LLVM_VERSION -#include -#include #include +#include #include @@ -38,77 +37,54 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { CHECK(call != nullptr); std::ostringstream intrinsic_name; intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits(); - *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, - CallNode::PureExtern); + *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); } namespace llvm { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchExternOCML); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index 29e4db3..5534a64 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -22,11 +22,13 @@ */ #ifdef TVM_LLVM_VERSION +#include "llvm_common.h" + #include + #include -#include #include -#include "llvm_common.h" +#include namespace tvm { namespace codegen { @@ -56,15 +58,11 @@ void InitializeLLVM() { } } -void ParseLLVMTargetOptions(const std::string& target_str, - std::string* triple, - std::string* mcpu, - std::string* mattr, - llvm::TargetOptions* options) { +void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu, + std::string* mattr, llvm::TargetOptions* options) { // setup target triple size_t start = 0; - if (target_str.length() >= 4 && - target_str.substr(0, 4) == "llvm") { + if (target_str.length() >= 4 && target_str.substr(0, 4) == "llvm") { start = 4; } // simple parser @@ -82,16 +80,13 @@ void ParseLLVMTargetOptions(const std::string& target_str, } size_t pos = key.find('='); if (pos != std::string::npos) { - CHECK_GE(key.length(), pos + 1) - << "invalid argument " << key; + CHECK_GE(key.length(), pos + 1) << "invalid argument " << key; value = key.substr(pos + 1, key.length() - 1); key = key.substr(0, pos); } else { - CHECK(is >> value) - << "Unspecified value for option " << key; + CHECK(is >> value) << "Unspecified value for option " << key; } - if (key == "-target" || - key == "-mtriple") { + if (key == "-target" || key == "-mtriple") { *triple = value; } else if (key == "-mcpu") { *mcpu = value; @@ -115,16 +110,15 @@ void ParseLLVMTargetOptions(const std::string& target_str, } } - if (triple->length() == 0 || - *triple == "default") { + if (triple->length() == 0 || *triple == "default") { *triple = llvm::sys::getDefaultTargetTriple(); } // set target option llvm::TargetOptions& opt = *options; opt = llvm::TargetOptions(); - #if TVM_LLVM_VERSION < 50 +#if TVM_LLVM_VERSION < 50 opt.LessPreciseFPMADOption = true; - #endif +#endif opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; @@ -136,21 +130,14 @@ void ParseLLVMTargetOptions(const std::string& target_str, } } - -std::unique_ptr -GetLLVMTargetMachine(const std::string& target_str, - bool allow_null) { +std::unique_ptr GetLLVMTargetMachine(const std::string& target_str, + bool allow_null) { std::string target_triple, mcpu, mattr; llvm::TargetOptions opt; - ParseLLVMTargetOptions(target_str, - &target_triple, - &mcpu, - &mattr, - &opt); + ParseLLVMTargetOptions(target_str, &target_triple, &mcpu, &mattr, &opt); - if (target_triple.length() == 0 || - target_triple == "default") { + if (target_triple.length() == 0 || target_triple == "default") { target_triple = llvm::sys::getDefaultTargetTriple(); } if (mcpu.length() == 0) { @@ -158,14 +145,13 @@ GetLLVMTargetMachine(const std::string& target_str, } std::string err; - const llvm::Target* target = - llvm::TargetRegistry::lookupTarget(target_triple, err); + const llvm::Target* target = llvm::TargetRegistry::lookupTarget(target_triple, err); if (target == nullptr) { CHECK(allow_null) << err << " target_triple=" << target_triple; return nullptr; } - llvm::TargetMachine* tm = target->createTargetMachine( - target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); + llvm::TargetMachine* tm = + target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); return std::unique_ptr(tm); } diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 85ee1ee..49389fe 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -25,14 +25,12 @@ #define TVM_TARGET_LLVM_LLVM_COMMON_H_ #ifdef TVM_LLVM_VERSION -#include - #include #include -#include - -#include +#include #include +#include +#include #if TVM_LLVM_VERSION >= 100 #include #include @@ -42,43 +40,41 @@ #include #include #include -#include #include +#include #include #include #include #include +#include +#include #include #include -#include #include - -#include +#include +#include #include #include -#include -#include #if TVM_LLVM_VERSION >= 100 #include #endif +#include +#include +#include +#include #include #include #include -#include -#include #include #include +#include #include #include -#include -#include - -#include -#include -#include #include +#include +#include namespace tvm { namespace codegen { @@ -97,11 +93,8 @@ void InitializeLLVM(); * \param options the options * \param mattr The attributes */ -void ParseLLVMTargetOptions(const std::string& target_str, - std::string* triple, - std::string* mcpu, - std::string* mattr, - llvm::TargetOptions* options); +void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu, + std::string* mattr, llvm::TargetOptions* options); /*! * \brief Get target machine from target_str string. @@ -109,8 +102,8 @@ void ParseLLVMTargetOptions(const std::string& target_str, * \param allow_null Whether allow null to be returned. * \return target machine */ -std::unique_ptr -GetLLVMTargetMachine(const std::string& target_str, bool allow_null = false); +std::unique_ptr GetLLVMTargetMachine(const std::string& target_str, + bool allow_null = false); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index d1a244d..1151b33 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -23,23 +23,25 @@ */ #ifdef TVM_LLVM_VERSION +#include #include #include -#include #include + #include -#include "llvm_common.h" -#include "codegen_llvm.h" -#include "codegen_blob.h" + #include "../../runtime/file_util.h" #include "../../runtime/library_module.h" +#include "codegen_blob.h" +#include "codegen_llvm.h" +#include "llvm_common.h" namespace tvm { namespace codegen { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; class LLVMModuleNode final : public runtime::ModuleNode { public: @@ -51,24 +53,15 @@ class LLVMModuleNode final : public runtime::ModuleNode { } } - const char* type_key() const { - return "llvm"; - } + const char* type_key() const { return "llvm"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "__tvm_is_system_module") { - bool flag = - (mptr_->getFunction("__tvm_module_startup") != nullptr); - return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) { - * rv = flag; - }); + bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr); + return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); } else if (name == "_get_target_triple") { std::string target_triple = tm_->getTargetTriple().str(); - return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) { - *rv = target_triple; - }); + return PackedFunc([target_triple](TVMArgs args, TVMRetValue* rv) { *rv = target_triple; }); } if (ee_ == nullptr) LazyInitJIT(); @@ -76,8 +69,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_module_main)); + const char* entry_name = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_main)); CHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; faddr = reinterpret_cast(GetFunctionAddr(entry_name)); @@ -88,13 +81,11 @@ class LLVMModuleNode final : public runtime::ModuleNode { return WrapPackedFunc(faddr, sptr_to_self); } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = runtime::GetFileFormat(file_name, format); std::error_code ecode; llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None); - CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name - << " " << ecode.message(); + CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); if (fmt == "o" || fmt == "obj") { #if TVM_LLVM_VERSION <= 60 std::unique_ptr m = llvm::CloneModule(mptr_); @@ -104,16 +95,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::legacy::PassManager pass; CHECK(tm_); #if TVM_LLVM_VERSION <= 60 - CHECK(tm_->addPassesToEmitFile( - pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == + 0) << "Cannot emit target CGFT_ObjectFile"; #else - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*m); @@ -126,16 +115,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::legacy::PassManager pass; CHECK(tm_); #if TVM_LLVM_VERSION <= 60 - CHECK(tm_->addPassesToEmitFile( - pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) << "Cannot emit target CGFT_AssemblyFile"; #else - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #endif pass.run(*m); @@ -148,8 +135,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::WriteBitcodeToFile(*mptr_, dest); #endif } else { - LOG(FATAL) << "Do not know how to save file " - << file_name << " with format=\'"<< format << "\'"; + LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format + << "\'"; } dest.close(); } @@ -165,28 +152,26 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::raw_svector_ostream rso(str); if (fmt == "s" || fmt == "asm") { - #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(mptr_); - #else - std::unique_ptr m = llvm::CloneModule(*mptr_); - #endif - llvm::legacy::PassManager pass; - CHECK(tm_); - #if TVM_LLVM_VERSION <= 60 - CHECK(tm_->addPassesToEmitFile( - pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; - #elif TVM_LLVM_VERSION <= 90 - CHECK(tm_->addPassesToEmitFile( - pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; - #else - CHECK(tm_->addPassesToEmitFile( - pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; - #endif - pass.run(*m); - return rso.str().str(); +#if TVM_LLVM_VERSION <= 60 + std::unique_ptr m = llvm::CloneModule(mptr_); +#else + std::unique_ptr m = llvm::CloneModule(*mptr_); +#endif + llvm::legacy::PassManager pass; + CHECK(tm_); +#if TVM_LLVM_VERSION <= 60 + CHECK(tm_->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; +#elif TVM_LLVM_VERSION <= 90 + CHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) + << "Cannot emit target CGFT_AssemblyFile"; +#else + CHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; +#endif + pass.run(*m); + return rso.str().str(); } else if (fmt == "" || fmt == "ll") { std::string type_str; llvm::raw_string_ostream rso(type_str); @@ -194,8 +179,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { mptr_->print(rso, nullptr); return rso.str(); } else { - LOG(FATAL) << "Do not know how to get source code with format: " - << format << "\'"; + LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; } return ""; } @@ -209,9 +193,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::vector funcs; std::string entry_func; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); @@ -251,8 +234,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { mptr_ = module_.get(); } - void Init(std::unique_ptr module, - std::shared_ptr ctx) { + void Init(std::unique_ptr module, std::shared_ptr ctx) { InitializeLLVM(); ctx_ = ctx; llvm::SMDiagnostic err; @@ -319,20 +301,17 @@ class LLVMModuleNode final : public runtime::ModuleNode { CHECK(layout == mptr_->getDataLayout()) << "Data layout mismatch between module(" << mptr_->getDataLayout().getStringRepresentation() << ")" - << " and ExecutionEngine (" - << layout.getStringRepresentation() << ")"; + << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; ee_ = builder.create(tm.release()); - CHECK(ee_ != nullptr) - << "Failed to initialize jit engine for " << mptr_->getTargetTriple(); + CHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << mptr_->getTargetTriple(); ee_->runStaticConstructorsDestructors(false); - if (void** ctx_addr = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { *ctx_addr = this; } - runtime::InitContextFunctions([this](const char *name) { - return reinterpret_cast(GetGlobalAddr(name)); - }); + runtime::InitContextFunctions( + [this](const char* name) { return reinterpret_cast(GetGlobalAddr(name)); }); } // Get global address from execution engine. uint64_t GetGlobalAddr(const std::string& name) const { @@ -357,7 +336,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { // JIT lock std::mutex mutex_; // execution engine - llvm::ExecutionEngine *ee_{nullptr}; + llvm::ExecutionEngine* ee_{nullptr}; // The raw pointer to the module. llvm::Module* mptr_{nullptr}; // The target machine @@ -372,17 +351,13 @@ unsigned LookupLLVMIntrinsic(const std::string& name) { return llvm::Function::lookupIntrinsicID(name); } - -TVM_REGISTER_GLOBAL("target.build.llvm") -.set_body_typed([](IRModule mod, std::string target) { +TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed([](IRModule mod, std::string target) { auto n = make_object(); n->Init(mod, target); return runtime::Module(n); }); - -TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate").set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); auto target = args[0].operator std::string(); auto module_name = args[1].operator std::string(); @@ -403,35 +378,29 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") *rv = runtime::Module(n); }); -TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = static_cast(LookupLLVMIntrinsic(args[0])); - }); - -TVM_REGISTER_GLOBAL("target.llvm_version_major") -.set_body([](TVMArgs args, TVMRetValue* rv) { - int major = TVM_LLVM_VERSION / 10; - *rv = major; - }); - -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->LoadIR(args[0]); - *rv = runtime::Module(n); - }); - -TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") -.set_body([](TVMArgs args, TVMRetValue* rv) { - InitializeLLVM(); - *rv = (GetLLVMTargetMachine(args[0], true) != nullptr); - }); +TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = static_cast(LookupLLVMIntrinsic(args[0])); +}); + +TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body([](TVMArgs args, TVMRetValue* rv) { + int major = TVM_LLVM_VERSION / 10; + *rv = major; +}); + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll").set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + n->LoadIR(args[0]); + *rv = runtime::Module(n); +}); + +TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled").set_body([](TVMArgs args, TVMRetValue* rv) { + InitializeLLVM(); + *rv = (GetLLVMTargetMachine(args[0], true) != nullptr); +}); -TVM_REGISTER_GLOBAL("codegen.codegen_blob") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("codegen.codegen_blob").set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); - auto p = CodeGenBlob(args[0].operator std::string(), - args[1].operator bool(), + auto p = CodeGenBlob(args[0].operator std::string(), args[1].operator bool(), args[2].operator std::string()); n->Init(std::move(p.first), p.second); *rv = runtime::Module(n); diff --git a/src/target/opt/build_aocl_off.cc b/src/target/opt/build_aocl_off.cc index 2585ac2..9f9d098 100644 --- a/src/target/opt/build_aocl_off.cc +++ b/src/target/opt/build_aocl_off.cc @@ -20,17 +20,14 @@ /*! * Optional module when build aocl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opencl/opencl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module AOCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module AOCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { LOG(WARNING) << "AOCL runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "aocl"); } diff --git a/src/target/opt/build_cuda_off.cc b/src/target/opt/build_cuda_off.cc index 4f941a5..893eb67 100644 --- a/src/target/opt/build_cuda_off.cc +++ b/src/target/opt/build_cuda_off.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,11 +24,9 @@ namespace tvm { namespace runtime { -Module CUDAModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string cuda_source) { +Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source) { LOG(FATAL) << "CUDA is not enabled"; return Module(); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 99dc5ad..c9471d1 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -27,30 +27,26 @@ #include #endif #include - #include + #include -#include "../build_common.h" -#include "../source/codegen_cuda.h" #include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_module.h" - +#include "../build_common.h" +#include "../source/codegen_cuda.h" namespace tvm { namespace codegen { -#define NVRTC_CALL(x) \ - { \ - nvrtcResult result = x; \ - if (result != NVRTC_SUCCESS) { \ - LOG(FATAL) \ - << "NvrtcError: " #x " failed with error: " \ - << nvrtcGetErrorString(result); \ - } \ +#define NVRTC_CALL(x) \ + { \ + nvrtcResult result = x; \ + if (result != NVRTC_SUCCESS) { \ + LOG(FATAL) << "NvrtcError: " #x " failed with error: " << nvrtcGetErrorString(result); \ + } \ } - std::string FindCUDAIncludePath() { #if defined(_WIN32) const std::string delimiter = "\\"; @@ -78,7 +74,6 @@ std::string FindCUDAIncludePath() { return cuda_include_path; } - std::string NVRTCCompile(const std::string& code, bool include_path = false) { std::vector compile_params; std::vector param_cstrings{}; @@ -104,16 +99,15 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { } for (const auto& string : compile_params) { - param_cstrings.push_back(string.c_str()); + param_cstrings.push_back(string.c_str()); } - NVRTC_CALL(nvrtcCreateProgram( - &prog, code.c_str(), nullptr, 0, nullptr, nullptr)); - nvrtcResult compile_res = - nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); + NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr)); + nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); size_t log_size; NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size)); - std::string log; log.resize(log_size); + std::string log; + log.resize(log_size); NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0])); CHECK_EQ(compile_res, NVRTC_SUCCESS) << log; size_t ptx_size; @@ -133,9 +127,8 @@ runtime::Module BuildCUDA(IRModule mod, std::string target) { CodeGenCUDA cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenCUDA: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -161,7 +154,6 @@ runtime::Module BuildCUDA(IRModule mod, std::string target) { return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.cuda") -.set_body_typed(BuildCUDA); +TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA); } // namespace codegen } // namespace tvm diff --git a/src/target/opt/build_hexagon_off.cc b/src/target/opt/build_hexagon_off.cc index ce06700..c734eec 100644 --- a/src/target/opt/build_hexagon_off.cc +++ b/src/target/opt/build_hexagon_off.cc @@ -23,9 +23,8 @@ namespace tvm { namespace runtime { Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi) { LOG(WARNING) << "Hexagon runtime is not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hex"); diff --git a/src/target/opt/build_metal_off.cc b/src/target/opt/build_metal_off.cc index ff796d8..3cfe131 100644 --- a/src/target/opt/build_metal_off.cc +++ b/src/target/opt/build_metal_off.cc @@ -20,16 +20,14 @@ /*! * Optional module when build metal is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/metal/metal_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module MetalModuleCreate(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module MetalModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { LOG(WARNING) << "Metal runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "metal"); } diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc index 6e796b1..2367500 100644 --- a/src/target/opt/build_opencl_off.cc +++ b/src/target/opt/build_opencl_off.cc @@ -20,17 +20,14 @@ /*! * Optional module when build opencl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opencl/opencl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module OpenCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl"); } diff --git a/src/target/opt/build_opengl_off.cc b/src/target/opt/build_opengl_off.cc index 781bf51..2e860ce 100644 --- a/src/target/opt/build_opengl_off.cc +++ b/src/target/opt/build_opengl_off.cc @@ -20,14 +20,13 @@ /*! * Optional module when build opencl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opengl/opengl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module OpenGLModuleCreate(std::unordered_map shaders, - std::string fmt, +Module OpenGLModuleCreate(std::unordered_map shaders, std::string fmt, std::unordered_map fmap) { LOG(WARNING) << "OpenGL runtime not enabled, return a source module..."; auto data = ToJSON(shaders); diff --git a/src/target/opt/build_rocm_off.cc b/src/target/opt/build_rocm_off.cc index 64ab759..476e5a8 100644 --- a/src/target/opt/build_rocm_off.cc +++ b/src/target/opt/build_rocm_off.cc @@ -20,19 +20,15 @@ /*! * Optional module when build rocm is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/rocm/rocm_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module ROCMModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string rocm_source, - std::string assembly) { - +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string rocm_source, + std::string assembly) { LOG(WARNING) << "ROCM runtime is not enabled, return a source module..."; auto fget_source = [rocm_source, assembly](const std::string& format) { if (format.length() == 0) return assembly; @@ -40,8 +36,7 @@ Module ROCMModuleCreate( if (format == "asm") return assembly; return std::string(""); }; - return codegen::DeviceSourceModuleCreate( - data, fmt, fmap, "hsaco", fget_source); + return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hsaco", fget_source); } } // namespace runtime diff --git a/src/target/opt/build_sdaccel_off.cc b/src/target/opt/build_sdaccel_off.cc index 8c58c3f..0de305c 100644 --- a/src/target/opt/build_sdaccel_off.cc +++ b/src/target/opt/build_sdaccel_off.cc @@ -20,17 +20,14 @@ /*! * Optional module when build opencl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opencl/opencl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module SDAccelModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module SDAccelModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { LOG(WARNING) << "OpenCL runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "sdaccel"); } diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index 64674e3..2b77869 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -21,28 +21,27 @@ * \file codegen_aocl.cc */ #include -#include + #include -#include "codegen_opencl.h" -#include "../build_common.h" -#include "../../runtime/opencl/aocl/aocl_module.h" +#include + #include "../../runtime/file_util.h" +#include "../../runtime/opencl/aocl/aocl_module.h" +#include "../build_common.h" +#include "codegen_opencl.h" namespace tvm { namespace codegen { -runtime::Module BuildAOCL(IRModule mod, - std::string target_str, - bool emulation) { +runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) { // Get code. using tvm::runtime::Registry; bool output_ssa = false; CodeGenOpenCL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodegenOpenCL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodegenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -80,15 +79,13 @@ runtime::Module BuildAOCL(IRModule mod, return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.aocl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildAOCL(args[0], args[1], false); - }); +TVM_REGISTER_GLOBAL("target.build.aocl").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildAOCL(args[0], args[1], false); +}); -TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildAOCL(args[0], args[1], true); - }); +TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildAOCL(args[0], args[1], true); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index adb84e4..a992851 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -20,20 +20,20 @@ /*! * \file codegen_c.cc */ -#include -#include #include "codegen_c.h" -#include "../../arith/pattern_match.h" + +#include +#include + #include "../../arith/compute_expr.h" +#include "../../arith/pattern_match.h" namespace tvm { namespace codegen { using namespace tir; -void CodeGenC::Init(bool output_ssa) { - print_ssa_form_ = output_ssa; -} +void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; } void CodeGenC::InitFuncState(const PrimFunc& f) { alloc_storage_scope_.clear(); @@ -79,8 +79,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { ReserveKeywordsAsUnique(); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); @@ -124,16 +123,11 @@ void CodeGenC::AddFunction(const PrimFunc& f) { this->stream << "}\n\n"; } -void CodeGenC::PrintFuncPrefix() { - stream << "void"; -} +void CodeGenC::PrintFuncPrefix() { stream << "void"; } -void CodeGenC::PrintFinalReturn() { -} +void CodeGenC::PrintFinalReturn() {} -std::string CodeGenC::Finish() { - return decl_stream.str() + stream.str(); -} +std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) if (print_ssa_form_) { @@ -145,12 +139,10 @@ void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) } } -void CodeGenC::PrintSSAAssign( - const std::string& target, const std::string& src, DataType t) { +void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, DataType t) { PrintType(t, stream); stream << ' ' << target << " = "; - if (src.length() > 3 && - src[0] == '(' && src[src.length() - 1] == ')') { + if (src.length() > 3 && src[0] == '(' && src[src.length() - 1] == ')') { stream << src.substr(1, src.length() - 2); } else { stream << src; @@ -159,8 +151,7 @@ void CodeGenC::PrintSSAAssign( } // Print a reference expression to a buffer. -std::string CodeGenC::GetBufferRef( - DataType t, const VarNode* buffer, PrimExpr index) { +std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) { std::ostringstream os; std::string vid = GetVarID(buffer); std::string scope; @@ -186,8 +177,7 @@ std::string CodeGenC::GetBufferRef( os << "[("; PrintExpr(index, os); os << ")"; - if (t.bits() == 4 || - (t.bits() == 1 && t.is_int())) { + if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { os << " / " << (32 / t.bits()); } os << ']'; @@ -198,8 +188,7 @@ std::string CodeGenC::GetBufferRef( // optimize for constant access if (auto* ptr = index.as()) { int64_t offset = ptr->value; - CHECK_EQ(offset % t.lanes(), 0) - << "Find unaligned vector load to a vector type"; + CHECK_EQ(offset % t.lanes(), 0) << "Find unaligned vector load to a vector type"; os << vid << '[' << (offset / t.lanes()) << ']'; return os.str(); } @@ -224,8 +213,7 @@ std::string CodeGenC::GetBufferRef( os << vid << " + ("; PrintExpr(index, os); os << ")"; - if (t.bits() == 4 || - (t.bits() == 1 && t.is_int())) { + if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { os << " / " << (32 / t.bits()); } os << "))[0]"; @@ -234,8 +222,8 @@ std::string CodeGenC::GetBufferRef( } // Print a reference expression to a buffer. -std::string CodeGenC::GetStructRef( - DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) { +std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, + int kind) { if (kind < intrinsic::kArrKindBound_) { std::ostringstream os; os << "(((DLTensor*)"; @@ -252,17 +240,38 @@ std::string CodeGenC::GetStructRef( os << "]."; // other case: get fields. switch (kind) { - case intrinsic::kArrData: os << "data"; break; - case intrinsic::kArrShape: os << "shape"; break; - case intrinsic::kArrStrides: os << "strides"; break; - case intrinsic::kArrNDim: os << "ndim"; break; - case intrinsic::kArrTypeCode: os << "dtype.code"; break; - case intrinsic::kArrTypeBits: os << "dtype.bits"; break; - case intrinsic::kArrByteOffset: os << "byte_offset"; break; - case intrinsic::kArrTypeLanes: os << "dtype.lanes"; break; - case intrinsic::kArrDeviceId: os << "ctx.device_id"; break; - case intrinsic::kArrDeviceType: os << "ctx.device_type"; break; - default: LOG(FATAL) << "unknown field code"; + case intrinsic::kArrData: + os << "data"; + break; + case intrinsic::kArrShape: + os << "shape"; + break; + case intrinsic::kArrStrides: + os << "strides"; + break; + case intrinsic::kArrNDim: + os << "ndim"; + break; + case intrinsic::kArrTypeCode: + os << "dtype.code"; + break; + case intrinsic::kArrTypeBits: + os << "dtype.bits"; + break; + case intrinsic::kArrByteOffset: + os << "byte_offset"; + break; + case intrinsic::kArrTypeLanes: + os << "dtype.lanes"; + break; + case intrinsic::kArrDeviceId: + os << "ctx.device_id"; + break; + case intrinsic::kArrDeviceType: + os << "ctx.device_type"; + break; + default: + LOG(FATAL) << "unknown field code"; } os << ')'; return os.str(); @@ -297,32 +306,26 @@ void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) { if (it == handle_data_type_.end()) { handle_data_type_[buf_var] = t; } else { - CHECK(it->second == t) - << "conflicting buf var type"; + CHECK(it->second == t) << "conflicting buf var type"; } } -void CodeGenC::PrintVecElemLoad(const std::string& vec, - DataType t, int i, +void CodeGenC::PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << ".s" << std::hex << i << std::dec; } -void CodeGenC::PrintVecElemStore(const std::string& vec, - DataType t, int i, +void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) { this->PrintIndent(); - stream << vec << ".s" << std::hex << i - << " = " << value << ";\n" << std::dec; + stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec; } -std::string CodeGenC::GetVecLoad( - DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenC::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { return GetBufferRef(t, buffer, base); } -void CodeGenC::PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, +void CodeGenC::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) { std::string ref = GetBufferRef(t, buffer, base); this->PrintIndent(); @@ -338,49 +341,58 @@ std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType targ return os.str(); } -void CodeGenC::BindThreadIndex(const IterVar& iv) { - LOG(FATAL) << "not implemented"; -} +void CodeGenC::BindThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented"; } -void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*) +void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*) } -void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) CHECK_EQ(scope, "global"); } void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*) - CHECK_EQ(t.lanes(), 1) - << "do not yet support vector types"; + CHECK_EQ(t.lanes(), 1) << "do not yet support vector types"; if (t.is_handle()) { - os << "void*"; return; + os << "void*"; + return; } if (t.is_float()) { if (t.bits() == 32) { - os << "float"; return; + os << "float"; + return; } if (t.bits() == 64) { - os << "double"; return; + os << "double"; + return; } } else if (t.is_uint()) { switch (t.bits()) { - case 8: case 16: case 32: case 64: { - os << "uint" << t.bits() << "_t"; return; + case 8: + case 16: + case 32: + case 64: { + os << "uint" << t.bits() << "_t"; + return; } - case 1: os << "int"; return; + case 1: + os << "int"; + return; } } else if (t.is_int()) { switch (t.bits()) { - case 8: case 16: case 32: case 64: { - os << "int" << t.bits() << "_t"; return; + case 8: + case 16: + case 32: + case 64: { + os << "int" << t.bits() << "_t"; + return; } } } LOG(FATAL) << "Cannot convert type " << t << " to C type"; } - -void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) if (auto* ptr = type.as()) { return PrintType(ptr->dtype, os); } else if (auto* ptr = type.as()) { @@ -393,8 +405,7 @@ void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) } } - -inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) +inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) if (op->dtype == DataType::Int(32)) { std::ostringstream temp; temp << op->value; @@ -407,8 +418,8 @@ inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // } } - -inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*) +inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, + CodeGenC* p) { // NOLINT(*) if (dtype == DataType::UInt(32)) { std::ostringstream temp; temp << val << "U"; @@ -421,9 +432,10 @@ inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeG } } -inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) +inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) switch (op->dtype.bits()) { - case 64: case 32: { + case 64: + case 32: { std::ostringstream temp; temp << std::scientific << op->value; if (op->dtype.bits() == 32) temp << 'f'; @@ -434,10 +446,11 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { case 16: { os << '('; p->PrintType(op->dtype, os); - os << ')' << std::scientific <value << 'f'; + os << ')' << std::scientific << op->value << 'f'; break; } - default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; } } @@ -445,16 +458,15 @@ void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(* PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) os << "\"" << op->value << "\""; } -template -inline void PrintBinaryExpr(const T* op, - const char* opstr, +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { if (op->dtype.lanes() == 1) { @@ -476,10 +488,9 @@ inline void PrintBinaryExpr(const T* op, } } -inline void PrintBinaryIntrinsic(const CallNode* op, - const char* opstr, - std::ostream& os, // NOLINT(*) - CodeGenC* p) { +inline void PrintBinaryIntrinsic(const CallNode* op, const char* opstr, + std::ostream& os, // NOLINT(*) + CodeGenC* p) { if (op->dtype.lanes() == 1) { CHECK_EQ(op->args.size(), 2U); os << '('; @@ -550,8 +561,7 @@ void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) } void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->call_type == CallNode::Extern || - op->call_type == CallNode::PureExtern) { + if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { os << op->name << "("; for (size_t i = 0; i < op->args.size(); i++) { this->PrintExpr(op->args[i], os); @@ -590,19 +600,16 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->args[2], os); os << ")"; } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); os << "(("; this->PrintType(l->dtype.element_of(), os); - os << " *)" << this->GetVarID(l->buffer_var.get()) - << " + "; + os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; this->PrintExpr(l->index, os); os << ')'; } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); - os << GetStructRef( - op->dtype, op->args[0], op->args[1], - op->args[2].as()->value); + os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { CHECK_EQ(op->args.size(), 1U); os << "("; @@ -622,19 +629,16 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) this->PrintExpr(op->args[0], os); os << ")"; } else { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name - << " with return type " << op->dtype; + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { + LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } } } -void CodeGenC::PrintVecBinaryOp( - const std::string& op, DataType t, - PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) { // NOLINT(*) if (isalpha(op[0])) { os << op << "("; this->PrintExpr(lhs, os); @@ -642,7 +646,7 @@ void CodeGenC::PrintVecBinaryOp( this->PrintExpr(rhs, os); os << ")"; } else { - os <<"("; + os << "("; this->PrintExpr(lhs, os); os << ' ' << op << ' '; this->PrintExpr(rhs, os); @@ -657,8 +661,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index); HandleVolatileLoads(ref, op, os); } else { - CHECK(is_one(op->predicate)) - << "predicated load is not supported"; + CHECK(is_one(op->predicate)) << "predicated load is not supported"; arith::PVar base; if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) { @@ -698,12 +701,11 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { DataType t = op->value.dtype(); if (t.lanes() == 1) { std::string value = this->PrintExpr(op->value); - std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); + std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); this->PrintIndent(); stream << ref << " = " << value << ";\n"; } else { - CHECK(is_one(op->predicate)) - << "Predicated store is not supported"; + CHECK(is_one(op->predicate)) << "Predicated store is not supported"; arith::PVar base; if (arith::ramp(base, 1, t.lanes()).Match(op->index)) { std::string value = this->PrintExpr(op->value); @@ -756,9 +758,9 @@ void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) CHECK_EQ(op->base.dtype(), DataType::Int(32)); os << "((int" << op->lanes << ")("; for (int i = 0; i < op->lanes; i++) { - os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")"; - if (i != op->lanes - 1) - os << ", "; + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != op->lanes - 1) os << ", "; } os << "))"; } @@ -767,7 +769,7 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { LOG(FATAL) << "Shuffle: not supported "; } -void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Broadcast: not supported "; } @@ -788,19 +790,14 @@ void CodeGenC::VisitStmt_(const LetStmtNode* op) { var_idmap_[op->var.get()] = value; } else { PrintIndent(); - if (op->var.dtype() == DataType::Handle() && - handle_data_type_.count(op->var.get())) { + if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) { PrintType(handle_data_type_.at(op->var.get()), stream); - stream << "* " - << AllocVarID(op->var.get()) - << " = ("; + stream << "* " << AllocVarID(op->var.get()) << " = ("; PrintType(handle_data_type_.at(op->var.get()), stream); - stream << "*)" << value << ";\n"; + stream << "*)" << value << ";\n"; } else { PrintType(op->var.dtype(), this->stream); - this->stream << ' ' - << AllocVarID(op->var.get()) - << " = " << value << ";\n"; + this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n"; } } PrintStmt(op->body); @@ -810,15 +807,14 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); - this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - std::string scope = alloc_storage_scope_.at(buffer); - PrintStorageScope(scope, stream); - PrintType(op->dtype, stream); - stream << ' ' << vid << '[' << constant_size << "];\n"; + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + const VarNode* buffer = op->buffer_var.as(); + std::string scope = alloc_storage_scope_.at(buffer); + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); + stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); @@ -867,9 +863,7 @@ void CodeGenC::VisitStmt_(const ForNode* op) { CHECK(is_zero(op->min)); stream << "for ("; PrintType(op->loop_var.dtype(), stream); - stream << ' ' << vid << " = 0; " - << vid << " < " << extent - << "; ++" << vid << ") {\n"; + stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); @@ -911,15 +905,13 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { const CallNode* call = op->value.as(); if (call) { if (call->is_intrinsic(intrinsic::tvm_storage_sync)) { - this->PrintStorageSync(call); return; + this->PrintStorageSync(call); + return; } else if (call->is_intrinsic(intrinsic::tvm_struct_set)) { CHECK_EQ(call->args.size(), 4); std::string value = PrintExpr(call->args[3]); - std::string ref = GetStructRef( - call->args[3].dtype(), - call->args[0], - call->args[1], - call->args[2].as()->value); + std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], + call->args[2].as()->value); this->PrintIndent(); this->stream << ref << " = " << value << ";\n"; return; @@ -932,8 +924,7 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { } } -void CodeGenC::PrintVecElemLoadExpr( - DataType t, int i, const std::string& value, std::ostream& os) { +void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) { CHECK_GT(t.lanes(), 1); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (i != 0) { diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 4fb4b7e..309eb06 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -24,16 +24,18 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_ #define TVM_TARGET_SOURCE_CODEGEN_C_H_ +#include +#include #include -#include #include +#include #include -#include -#include + #include -#include #include #include +#include + #include "codegen_source_base.h" namespace tvm { @@ -50,10 +52,9 @@ using namespace tir; * and OpenCL-C. You might find some odd variant features, e.g., type `int3` for * a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`. */ -class CodeGenC : - public ExprFunctor, - public StmtFunctor, - public CodeGenSourceBase { +class CodeGenC : public ExprFunctor, + public StmtFunctor, + public CodeGenSourceBase { public: /*! * \brief Initialize the code generator. @@ -75,9 +76,7 @@ class CodeGenC : * \brief Print the Stmt n to CodeGenC->stream * \param n The statement to be printed. */ - void PrintStmt(const Stmt& n) { - VisitStmt(n); - } + void PrintStmt(const Stmt& n) { VisitStmt(n); } /*! * \brief Print the expression n(or its ssa id if in ssa mode) into os * \param n The expression to be printed. @@ -99,11 +98,11 @@ class CodeGenC : * * Example: stream << "void"; */ - virtual void PrintFuncPrefix(); // NOLINT(*) + virtual void PrintFuncPrefix(); // NOLINT(*) /*! * \brief Print the final return at the end the function. */ - virtual void PrintFinalReturn(); // NOLINT(*) + virtual void PrintFinalReturn(); // NOLINT(*) /*! * \brief Insert statement before function body. * \param f The function to be compiled. @@ -115,33 +114,33 @@ class CodeGenC : */ virtual void InitFuncState(const PrimFunc& f); // expression - void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment void VisitStmt_(const LetStmtNode* op) override; @@ -158,36 +157,34 @@ class CodeGenC : * \param t The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) + virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) /*! * Print Type represetnation of type type. * \param type The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*) + virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*) /*! * \brief Print expr representing the thread tag * \param IterVar iv The thread index to be binded; */ - virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*) - virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) - virtual void PrintStorageSync(const CallNode* op); // NOLINT(*) + virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*) + virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) + virtual void PrintStorageSync(const CallNode* op); // NOLINT(*) // Binary vector op. - virtual void PrintVecBinaryOp( - const std::string&op, DataType op_type, - PrimExpr lhs, PrimExpr rhs, std::ostream& os); // NOLINT(*) + virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs, + std::ostream& os); // NOLINT(*) // print vector load virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base); // print vector store - virtual void PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, + virtual void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value); // NOLINT(*) // print load of single element - virtual void PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os); // NOLINT(*) + virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os); // NOLINT(*) // print store of single element. - virtual void PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value); + virtual void PrintVecElemStore(const std::string& vec, DataType t, int i, + const std::string& value); // Get a cast type from to virtual std::string CastFromTo(std::string value, DataType from, DataType target); // Get load of single element with expression @@ -195,11 +192,9 @@ class CodeGenC : protected: // Print reference to struct location - std::string GetStructRef( - DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); + std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); // Print reference to a buffer as type t in index. - virtual std::string GetBufferRef( - DataType t, const VarNode* buffer, PrimExpr index); + virtual std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index); /*! * \brief Handle volatile loads. @@ -209,8 +204,7 @@ class CodeGenC : * does not implement volatile member functions. CUDA codegen will cast * away volatile qualifier from CUDA __half types. */ - virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, - std::ostream& os) { + virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) { // By default, do nothing but print the loaded value. os << value; } @@ -223,9 +217,7 @@ class CodeGenC : * or "__constant__" is not part of type but a storage class (like * C/C++ static). */ - virtual bool IsScopePartOfType() const { - return true; - } + virtual bool IsScopePartOfType() const { return true; } /*! * \brief If buffer is allocated as type t. @@ -240,15 +232,12 @@ class CodeGenC : */ void RegisterHandleType(const VarNode* buf_var, DataType t); // override - void PrintSSAAssign( - const std::string& target, const std::string& src, DataType t) final; + void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final; /*! \brief reserves common C keywords */ void ReserveKeywordsAsUnique(); /*! \brief Check if buf_var is volatile or not. */ - bool IsVolatile(const VarNode *buf_var) const { - return volatile_buf_.count(buf_var) != 0; - } + bool IsVolatile(const VarNode* buf_var) const { return volatile_buf_.count(buf_var) != 0; } /*! \brief restrict keyword */ std::string restrict_keyword_{""}; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 5e5db82..b11b3d8 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -20,18 +20,19 @@ /*! * \file codegen_c_host.cc */ +#include "codegen_c_host.h" + #include -#include + #include +#include + #include "../build_common.h" -#include "codegen_c_host.h" namespace tvm { namespace codegen { -CodeGenCHost::CodeGenCHost() { - module_name_ = GetUniqueName("__tvm_module_ctx"); -} +CodeGenCHost::CodeGenCHost() { module_name_ = GetUniqueName("__tvm_module_ctx"); } void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) { emit_asserts_ = emit_asserts; @@ -57,12 +58,13 @@ void CodeGenCHost::PrintFinalReturn() { // NOLINT(*) void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "does not support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "does not support vector types"; + os << "void*"; + return; } if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } bool fail = false; if (t.is_float()) { @@ -70,37 +72,55 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) case 16: os << "half"; break; - case 32: os << "float"; break; + case 32: + os << "float"; + break; case 64: os << "double"; break; - default: fail = true; break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { os << 'u'; } switch (t.bits()) { - case 8: os << "int8_t"; break; - case 16: os << "int16_t"; break; - case 32: os << "int32_t"; break; - case 64: os << "int64_t"; break; - case 1: os << "int32_t"; break; - default: fail = true; break; + case 8: + os << "int8_t"; + break; + case 16: + os << "int16_t"; + break; + case 32: + os << "int32_t"; + break; + case 64: + os << "int64_t"; + break; + case 1: + os << "int32_t"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to C type"; } -void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); os << "(("; PrintType(op->dtype, os); @@ -118,9 +138,8 @@ void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name, this->stream << "if (" << packed_func_name << " == NULL) {\n"; int packed_func_if_scope = this->BeginScope(); this->PrintIndent(); - this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ - << ", \"" << func_name << "\"" - << ", &" << packed_func_name << ") != 0) {\n"; + this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \"" << func_name << "\"" + << ", &" << packed_func_name << ") != 0) {\n"; int get_func_env_scope = this->BeginScope(); this->PrintIndent(); this->stream << "return -1;\n"; @@ -141,9 +160,12 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar this->stream << "int " << ret_type_code << ";\n"; this->PrintIndent(); this->stream << "if (TVMFuncCall(" << packed_func_name << ", " - << "(TVMValue*) stack_value" << ", " << "(int*) stack_tcode" << ", " - << num_args << ", " << "&" << ret_val << ", " << "&" - << ret_type_code << ") != 0) {\n"; + << "(TVMValue*) stack_value" + << ", " + << "(int*) stack_tcode" + << ", " << num_args << ", " + << "&" << ret_val << ", " + << "&" << ret_type_code << ") != 0) {\n"; int func_call_scope = this->BeginScope(); this->PrintIndent(); this->stream << "return -1;\n"; @@ -152,7 +174,7 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar this->stream << "}\n"; } -void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { std::string stack_name = GetUniqueName("stack"); const std::string& type = op->args[0].as()->value; @@ -188,8 +210,8 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT( std::string packed_func_name = func_name + "_packed"; if (declared_globals_.insert(packed_func_name).second) { // Still reserve the name among unique names. - CHECK(GetUniqueName(packed_func_name) == packed_func_name) << - "Expected name " << packed_func_name << " to not be taken"; + CHECK(GetUniqueName(packed_func_name) == packed_func_name) + << "Expected name " << packed_func_name << " to not be taken"; decl_stream << "static void* " << packed_func_name << " = NULL;\n"; } this->PrintGetFuncFromBackend(func_name, packed_func_name); @@ -202,7 +224,7 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT( } } -void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) +void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) { // NOLINT(*) if (emit_asserts_) { std::string cond = PrintExpr(op->condition); PrintIndent(); @@ -219,18 +241,17 @@ void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) this->PrintStmt(op->body); } -void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) PrintTernaryCondExpr(op, "<", os); } -void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) PrintTernaryCondExpr(op, ">", os); } template -inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, - const char* compare, - std::ostream& os) { // NOLINT(*) +inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare, + std::ostream& os) { // NOLINT(*) std::ostringstream temp_a; VisitExpr(op->a, temp_a); std::string a_id = SSAGetID(temp_a.str(), op->a.dtype()); @@ -250,8 +271,7 @@ runtime::Module BuildCHost(IRModule mod) { cg.Init(output_ssa, emit_asserts); for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodegenCHost: Can only take PrimFunc"; + CHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; auto f = Downcast(kv.second); cg.AddFunction(f); } @@ -260,8 +280,7 @@ runtime::Module BuildCHost(IRModule mod) { return CSourceModuleCreate(code, "c"); } -TVM_REGISTER_GLOBAL("target.build.c") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("target.build.c").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildCHost(args[0]); }); } // namespace codegen diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index bec9686..94a76fa 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -26,9 +26,10 @@ #include #include + +#include "codegen_c.h" #include "tvm/target/codegen.h" #include "tvm/tir/expr.h" -#include "codegen_c.h" namespace tvm { namespace codegen { @@ -38,19 +39,19 @@ class CodeGenCHost final : public CodeGenC { CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts); - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void PrintFuncPrefix() final; // NOLINT(*) - void PrintFinalReturn() final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintFuncPrefix() final; // NOLINT(*) + void PrintFinalReturn() final; // NOLINT(*) // overload visitor functions - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const CallNode *op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) // overload min and max to use the ternary operator, so we don't rely on the // standard library implementations - void VisitExpr_(const MinNode *op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const MaxNode *op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) final; // NOLINT(*) - void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*) + void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*) private: std::string module_name_; @@ -70,8 +71,7 @@ class CodeGenCHost final : public CodeGenC { * \param os stream reference to print into */ template - inline void PrintTernaryCondExpr(const T* op, - const char* compare, + inline void PrintTernaryCondExpr(const T* op, const char* compare, std::ostream& os); // NOLINT(*) }; diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 591e4d0..cf7a74f 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -21,21 +21,21 @@ * \file codegen_cuda.cc */ +#include "codegen_cuda.h" + #include #include +#include #include #include -#include + #include "literal/cuda_half_t.h" -#include "codegen_cuda.h" namespace tvm { namespace codegen { -CodeGenCUDA::CodeGenCUDA() { - restrict_keyword_ = "__restrict__"; -} +CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; } void CodeGenCUDA::Init(bool output_ssa) { CodeGenC::Init(output_ssa); @@ -44,10 +44,7 @@ void CodeGenCUDA::Init(bool output_ssa) { CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); } - -void CodeGenCUDA::PrintFuncPrefix() { - stream << "extern \"C\" __global__ void"; -} +void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; } std::string CodeGenCUDA::Finish() { if (enable_fp16_) { @@ -96,16 +93,15 @@ void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { CHECK(!var_idmap_.count(iv->var.get())); - var_idmap_[iv->var.get()] = - CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); + var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); } void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "do not yet support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "do not yet support vector types"; + os << "void*"; + return; } bool fail = false; if (t.is_float()) { @@ -130,22 +126,31 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) fail = true; } break; - case 32: os << "float"; break; - case 64: os << "double"; break; - default: fail = true; break; + case 32: + os << "float"; + break; + case 64: + os << "double"; + break; + default: + fail = true; + break; } if (!fail && (lanes == 1 || t.bits() == 16)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } else if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } else if (t.is_vector_bool()) { // CUDA does not support bool vectors. // Use ushort vectors to represent instead. int n = t.lanes(); if (n <= 4) { - os << "ushort" << n; return; + os << "ushort" << n; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { @@ -158,31 +163,41 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) switch (t.bits()) { case 1: { if (t.lanes() == 1) { - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 8) { - os << "int8_t"; return; + os << "int8_t"; + return; } else if (t.lanes() == 16) { - os << "int16_t"; return; + os << "int16_t"; + return; } else if (t.lanes() == 32) { - os << "int"; return; + os << "int"; + return; } else { LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; } } case 4: { if (t.lanes() == 1) { - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 4) { - os << "int16_t"; return; + os << "int16_t"; + return; } else if (t.lanes() == 8) { // directly 8 4-bit int in integer. - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 16) { - os << "int2"; return; + os << "int2"; + return; } else if (t.lanes() == 32) { - os << "int4"; return; + os << "int4"; + return; } else if (t.lanes() == 64) { - os << "int8"; return; + os << "int8"; + return; } else { LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; } @@ -195,51 +210,65 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // We use int for int8x4 instead of char4 because using char4 is // likely to produce extra instructions to pack four int8 elements // into 32-bit data. - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 8) { enable_int8_ = true; - os << "int2"; return; + os << "int2"; + return; } else if (t.lanes() == 16) { enable_int8_ = true; - os << "int4"; return; + os << "int4"; + return; } else if (!t.is_uint() && t.lanes() == 1) { - os << "signed char"; break; + os << "signed char"; + break; } else { - os << "char"; break; + os << "char"; + break; } } - case 16: os << "short"; break; - case 32: os << "int"; break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; case 64: { - if (sizeof(long) != 8) { // NOLINT(*) + if (sizeof(long) != 8) { // NOLINT(*) if (t.lanes() == 1) { - os << "long long"; break; + os << "long long"; + break; } else if (t.lanes() == 2) { - os << "longlong"; break; + os << "longlong"; + break; } else { // No longlong3, longlong4 LOG(FATAL) << "Cannot convert type " << t << " to CUDA type on a L32 platform"; break; } } else { - os << "long"; break; + os << "long"; + break; } } - default: fail = true; break; + default: + fail = true; + break; } if (!fail && lanes == 1) { return; } if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } -void CodeGenCUDA::PrintVecBinaryOp( - const std::string& op, DataType t, - PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) { // NOLINT(*) // Delcare the result. std::string sret = GetUniqueName("_"); this->PrintIndent(); @@ -271,8 +300,8 @@ void CodeGenCUDA::PrintVecBinaryOp( os << sret; } -void CodeGenCUDA::PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) { // NOLINT(*) if (t.is_scalar()) { os << vec; return; @@ -293,15 +322,14 @@ void CodeGenCUDA::PrintVecElemLoad( os << "((unsigned char)(" << vec << " >> " << i * 8 << "))"; } } else if (t.is_float16()) { - os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2]; + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else { os << vec << "." << access[i]; } } -void CodeGenCUDA::PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value) { +void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, + const std::string& value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4)); @@ -318,8 +346,8 @@ void CodeGenCUDA::PrintVecElemStore( stream << "(" << value << " << " << i * 8 << ");\n"; } } else if (t.is_float16()) { - stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2] << " = " << value << ";\n"; + stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " + << value << ";\n"; } else { stream << vec << "." << access[i] << " = " << value << ";\n"; } @@ -335,8 +363,8 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { } else if (sync == "global") { if (!need_global_barrier_) { need_global_barrier_ = true; - this->decl_stream << "extern \"C\" __device__ unsigned " - << vid_global_barrier_state_ << ";\n"; + this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_ + << ";\n"; } // global synchronizer std::string is_load = PrintExpr(op->args[1]); @@ -344,30 +372,28 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { this->PrintIndent(); // In theory only threadfence is needed // but we observed problems with only threadfence - this->stream <<"__threadfence_system();\n"; + this->stream << "__threadfence_system();\n"; this->PrintIndent(); - this->stream <<"if (" << is_load << ") {\n"; + this->stream << "if (" << is_load << ") {\n"; int wb = this->BeginScope(); this->PrintIndent(); this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n"; this->PrintIndent(); std::string ptr = GetUniqueName("pf"); - this->stream << "volatile unsigned* " - << ptr << " = &" << vid_global_barrier_state_<< ";\n"; + this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n"; this->PrintIndent(); this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n"; this->PrintIndent(); - this->stream <<"while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n"; + this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n"; this->EndScope(wb); this->PrintIndent(); - this->stream <<"}\n"; + this->stream << "}\n"; this->PrintIndent(); - this->stream <<"__syncthreads();\n"; + this->stream << "__syncthreads();\n"; } } -void CodeGenCUDA::PrintStorageScope( - const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) CHECK_NE(scope, "global"); if (scope == "shared") { os << "__shared__ "; @@ -380,8 +406,7 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { CHECK_EQ(target_ty.lanes(), from_ty.lanes()); // Emit simple C-style type conversion. - if (from_ty.is_scalar()) - return CodeGenC::VisitExpr_(op, os); + if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); // We could emit make_float4 like calls, but the emitted code looks // too compact to read. Emit this as vectorized unary ops. @@ -407,8 +432,7 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // This is only for backward compatibility with __shfl_{up/down}. // A macro will be used to replace *_sync calls to legacy ones. - if (op->is_intrinsic("__shfl_sync") || - op->is_intrinsic("__shfl_up_sync") || + if (op->is_intrinsic("__shfl_sync") || op->is_intrinsic("__shfl_up_sync") || op->is_intrinsic("__shfl_down_sync")) { enable_warp_shuffle_ = true; } @@ -446,7 +470,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[4], os); os << "], "; this->PrintExpr(op->args[6], os); - if (const StringImmNode *str = op->args[7].as()) { + if (const StringImmNode* str = op->args[7].as()) { os << ", nvcuda::wmma::mem_" << str->value; } else { LOG(FATAL) << "Invalid parameters"; @@ -460,7 +484,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2], os); os << "["; this->PrintExpr(op->args[i * 2 + 1], os); - os << "]" << ((i < 3) ? ", ": ")"); + os << "]" << ((i < 3) ? ", " : ")"); } } else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) { need_mma_h_ = true; @@ -470,7 +494,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2], os); os << "["; this->PrintExpr(op->args[i * 2 + 1], os); - os << "]" << ((i < 3) ? ", ": ")"); + os << "]" << ((i < 3) ? ", " : ")"); } } else if (op->call_type == CallNode::PureExtern && op->dtype.is_vector()) { // @@ -509,8 +533,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::ostringstream scall; scall << op->name << "("; for (size_t j = 0; j < op->args.size(); ++j) { - if (j > 0) - scall << ", "; + if (j > 0) scall << ", "; PrintVecElemLoad(sargs[j], op->args[j].dtype(), i, scall); } scall << ")"; @@ -542,25 +565,20 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation for now"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; const VarNode* buffer = op->buffer_var.as(); std::string scope = alloc_storage_scope_.at(buffer); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - CHECK(op->dtype == DataType::Float(16) || - op->dtype == DataType::Int(8) || - op->dtype == DataType::UInt(8) || - op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) - << "Matrix_a and matrix_b only support half or char or unsigned char " - << "or uint4 or int4 or int1 type for now"; + CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || + op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || + op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1)) + << "Matrix_a and matrix_b only support half or char or unsigned char " + << "or uint4 or int4 or int1 type for now"; } else { - CHECK(op->dtype == DataType::Float(16) || - op->dtype == DataType::Float(32) || + CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32)) - << "Accumulator only support half, float and int type for now"; + << "Accumulator only support half, float and int type for now"; } constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); @@ -568,19 +586,18 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { PrintStorageScope(scope, stream); PrintType(op->dtype, stream); } - if ((op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) && scope == "shared") { + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) && + scope == "shared") { constant_size = constant_size / (32 / op->dtype.bits()); } - stream << ' '<< vid << '[' - << constant_size << "];\n"; + stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); } -void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) { +void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; const CallNode* call = op->value.as(); if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) { @@ -600,17 +617,17 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) { void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { os << "((make_int" << op->lanes << ")("; for (int i = 0; i < op->lanes; i++) { - os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")"; - if (i != op->lanes - 1) - os << ", "; + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != op->lanes - 1) os << ", "; } os << "))"; } -void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) { // make_int8x4 - const int64_t *p = as_const_int(op->value); + const int64_t* p = as_const_int(op->value); CHECK(p); int64_t v = *p & 0xFF; v = (v << 24) | (v << 16) | (v << 8) | v; @@ -629,7 +646,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N os << '('; for (int i = 0; i < op->lanes / 2; ++i) { if (i != 0) os << ", "; - os << "__pack_half2(" << v << ", " << v << ")"; + os << "__pack_half2(" << v << ", " << v << ")"; } os << ')'; return; @@ -646,7 +663,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N os << ')'; } -void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) { +void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) { std::vector to_shuffle(op->vectors.size()); for (int i = 0, e = op->vectors.size(); i < e; ++i) { CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!"; @@ -656,15 +673,15 @@ void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) { PrintType(op->dtype, os); os << '('; for (int i = 0, e = op->indices.size(); i < e; ++i) { - const int64_t *val = as_const_int(op->indices[i]); - CHECK(val && *val >= 0 && (int) *val < (int) to_shuffle.size()); + const int64_t* val = as_const_int(op->indices[i]); + CHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size()); if (i != 0) os << ", "; os << to_shuffle[*val]; } os << ')'; } -void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { +void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { // Non-vector cases. if (!op->dtype.is_vector()) { CodeGenC::VisitExpr_(op, os); @@ -672,8 +689,7 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { } // Codegen vector condition case by serializing the select op. - CHECK(op->false_value->dtype == op->dtype && - op->true_value->dtype == op->dtype && + CHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && op->dtype.lanes() == op->condition.dtype().lanes()); std::string r_var = GetUniqueName("_"); @@ -704,9 +720,10 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { os << r_var; } -inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) +inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) switch (op->dtype.bits()) { - case 64: case 32: { + case 64: + case 32: { std::ostringstream temp; if (std::isinf(op->value)) { if (op->value < 0) { @@ -730,17 +747,17 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) os << '(' << std::scientific << op->value << 'f' << ')'; break; } - default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; } } - -void CodeGenCUDA::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, - const VarNode* variable, std::ostream &os) { +void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, + std::ostream& os) { std::stringstream type; PrintType(t, type); std::string shape_str = fragment_shapes[variable]; @@ -765,22 +782,22 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, if (scope == "wmma.matrix_a") { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.matrix_b") { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.accumulator") { need_mma_h_ = true; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } } -int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, - const VarNode* variable, int32_t size) { +int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, + int32_t size) { std::string shape_str = fragment_shapes[variable]; size_t m, n, k; size_t last_pos = 0, pos = 0; @@ -801,8 +818,8 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, return 0; } -void CodeGenCUDA::HandleVolatileLoads(const std::string& value, - const LoadNode* op, std::ostream& os) { +void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op, + std::ostream& os) { // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. // @@ -815,15 +832,15 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, } } -void CodeGenCUDA::PrintVecElemLoadExpr( - DataType t, int i, const std::string& value, std::ostream& os) { +void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + std::ostream& os) { CHECK_GT(t.lanes(), 1); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (!(t.lanes() == 2 || t.lanes() == 3)) { if (i != 0) { os << "|"; } - os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; return; } } diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index ed17638..f9ab0ad 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -26,8 +26,10 @@ #include #include + #include #include + #include "codegen_c.h" namespace tvm { @@ -46,37 +48,32 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintVecBinaryOp( - const std::string& op, DataType t, - PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) - void PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) final; // NOLINT(*) + void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; // overload visitor - void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; - void VisitExpr_(const CallNode *op, std::ostream& os) final; + void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; + void VisitExpr_(const CallNode* op, std::ostream& os) final; void VisitExpr_(const CastNode* op, std::ostream& os) final; - void VisitStmt_(const EvaluateNode *op) final; - void VisitStmt_(const AllocateNode *op) final; - void VisitStmt_(const AttrStmtNode *op) final; + void VisitStmt_(const EvaluateNode* op) final; + void VisitStmt_(const AllocateNode* op) final; + void VisitStmt_(const AttrStmtNode* op) final; private: // Handle volatile loads - void HandleVolatileLoads(const std::string& value, const LoadNode* op, - std::ostream& os) final; + void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final; // Whether scope such as "__shared__" or "__constant__" is part of type. - bool IsScopePartOfType() const final { - return false; - } + bool IsScopePartOfType() const final { return false; } // Whether global barrier is needed. bool need_global_barrier_{false}; @@ -98,10 +95,9 @@ class CodeGenCUDA final : public CodeGenC { std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p); - void PrintWmmaScope( - const std::string& scope, DataType t, const VarNode* variable, std::ostream& os); - int32_t GetWmmaFragmentSize( - const std::string &scope, const VarNode* variable, int32_t size); + void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, + std::ostream& os); + int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size); }; } // namespace codegen diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 5a50679..e381afb 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -20,13 +20,15 @@ /*! * \file codegen_metal.cc */ -#include -#include -#include #include "codegen_metal.h" -#include "../build_common.h" + +#include +#include +#include + #include "../../runtime/metal/metal_module.h" #include "../../runtime/thread_storage_scope.h" +#include "../build_common.h" namespace tvm { namespace codegen { @@ -57,8 +59,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { // add to alloc buffer type. auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; @@ -67,7 +68,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { size_t num_buffer = 0; for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { Var v = f->params[i]; - if (!v.dtype().is_handle()) break; + if (!v.dtype().is_handle()) break; stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); @@ -83,17 +84,15 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { RegisterHandleType(v.get(), prim->dtype); } } - stream << ' ' << vid - << " [[ buffer(" << i << ") ]],\n"; + stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. size_t nargs = f->params.size() - num_buffer; std::string varg = GetUniqueName("arg"); if (nargs != 0) { - std::string arg_buf_type = - static_cast(global_symbol.value()) + "_args_t"; - stream << " constant " << arg_buf_type << "& " << varg - << " [[ buffer(" << num_buffer << ") ]],\n"; + std::string arg_buf_type = static_cast(global_symbol.value()) + "_args_t"; + stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer + << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < f->params.size(); ++i) { @@ -120,8 +119,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>( - tir::attr::kDeviceThreadAxis).value(); + auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); @@ -164,23 +162,31 @@ void CodeGenMetal::BindThreadIndex(const IterVar& iv) { void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "do not yet support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "do not yet support vector types"; + os << "void*"; + return; } if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } bool fail = false; if (t.is_float()) { switch (t.bits()) { - case 16: os << "half"; break; - case 32: os << "float"; break; - default: fail = true; break; + case 16: + os << "half"; + break; + case 32: + os << "float"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { @@ -188,18 +194,30 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } if (t.bits() == 8 && t.lanes() == 4) { // directly 4 8 bit int in integer. - os << "int"; return; + os << "int"; + return; } switch (t.bits()) { - case 8: os << "char"; break; - case 16: os << "short"; break; - case 32: os << "int"; break; - case 1: os << "bool"; break; - default: fail = true; break; + case 8: + os << "char"; + break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; + case 1: + os << "bool"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; @@ -218,22 +236,19 @@ void CodeGenMetal::PrintStorageSync(const CallNode* op) { } } -void CodeGenMetal::PrintVecElemLoad(const std::string& vec, - DataType t, int i, +void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << "[" << i << "]"; } -void CodeGenMetal::PrintVecElemStore(const std::string& vec, - DataType t, int i, +void CodeGenMetal::PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << "[" << i << "]" << " = " << value << ";\n"; } -void CodeGenMetal::PrintStorageScope( - const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) if (scope == "global") { os << "device "; } else if (scope == "shared") { @@ -243,7 +258,7 @@ void CodeGenMetal::PrintStorageScope( } } -void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); PrintType(op->dtype, os); os << "("; @@ -273,9 +288,8 @@ runtime::Module BuildMetal(IRModule mod) { CodeGenMetal cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenMetal: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -294,9 +308,8 @@ runtime::Module BuildMetal(IRModule mod) { return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source); } -TVM_REGISTER_GLOBAL("target.build.metal") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildMetal(args[0]); - }); +TVM_REGISTER_GLOBAL("target.build.metal").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildMetal(args[0]); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 644c962..26abe34 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -25,7 +25,9 @@ #define TVM_TARGET_SOURCE_CODEGEN_METAL_H_ #include + #include + #include "codegen_c.h" namespace tvm { @@ -36,22 +38,21 @@ class CodeGenMetal final : public CodeGenC { CodeGenMetal(); // override print thread tag. void PrintArgUnionDecl(); - void AddFunction(const PrimFunc& f); // NOLINT(*) + void AddFunction(const PrimFunc& f); // NOLINT(*) void InitFuncState(const PrimFunc& f) final; - void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // print load of single element - void PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) + void PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) final; // NOLINT(*) // print store of single element. - void PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) // overload visitor - void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) // reuse parent's function. using CodeGenC::PrintType; diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 7a23abd..746d418 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -20,20 +20,20 @@ /*! * \file codegen_opencl.cc */ +#include "codegen_opencl.h" + #include -#include #include -#include "codegen_opencl.h" -#include "../build_common.h" -#include "../../runtime/thread_storage_scope.h" +#include + #include "../../runtime/opencl/opencl_module.h" +#include "../../runtime/thread_storage_scope.h" +#include "../build_common.h" namespace tvm { namespace codegen { -CodeGenOpenCL::CodeGenOpenCL() { - restrict_keyword_ = "restrict"; -} +CodeGenOpenCL::CodeGenOpenCL() { restrict_keyword_ = "restrict"; } void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); @@ -44,34 +44,30 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { } } -void CodeGenOpenCL::PrintFuncPrefix() { - stream << "__kernel void"; -} +void CodeGenOpenCL::PrintFuncPrefix() { stream << "__kernel void"; } std::string CodeGenOpenCL::Finish() { // inject extension enable pragma for fp16 and fp64 if (enable_fp16_) { - decl_stream - << "#ifdef cl_khr_fp16\n" - "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" - "#elif defined(cl_amd_fp16)\n" - "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n" - "#else\n" - "#error \"Half precision floating point not supported" - "by OpenCL implementation on your device.\" \n" - "#endif\n\n"; + decl_stream << "#ifdef cl_khr_fp16\n" + "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" + "#elif defined(cl_amd_fp16)\n" + "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n" + "#else\n" + "#error \"Half precision floating point not supported" + "by OpenCL implementation on your device.\" \n" + "#endif\n\n"; } if (enable_fp64_) { - decl_stream - << "#ifdef cl_khr_fp64\n" - "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" - "#elif defined(cl_amd_fp64)\n" - "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n" - "#else\n" - "#error \"Double precision floating point not supported" - "by OpenCL implementation on your device.\" \n" - "#endif\n\n"; + decl_stream << "#ifdef cl_khr_fp64\n" + "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" + "#elif defined(cl_amd_fp64)\n" + "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n" + "#else\n" + "#error \"Double precision floating point not supported" + "by OpenCL implementation on your device.\" \n" + "#endif\n\n"; } return CodeGenC::Finish(); @@ -86,19 +82,19 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { } else { os << "get_group_id(" << ts.dim_index << ")"; } - var_idmap_[iv->var.get()] = - CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype()); + var_idmap_[iv->var.get()] = CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype()); } void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "do not yet support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "do not yet support vector types"; + os << "void*"; + return; } if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } bool fail = false; if (t.is_float()) { @@ -107,16 +103,21 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "half"; enable_fp16_ = true; break; - case 32: os << "float"; break; + case 32: + os << "float"; + break; case 64: os << "double"; enable_fp64_ = true; break; - default: fail = true; break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { @@ -124,26 +125,40 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } if (t.bits() == 8 && t.lanes() == 4) { // directly 4 8 bit int in integer. - os << "int"; return; + os << "int"; + return; } switch (t.bits()) { - case 8: os << "char"; break; - case 16: os << "short"; break; - case 32: os << "int"; break; - case 64: os << "long"; break; - case 1: os << "int"; break; - default: fail = true; break; + case 8: + os << "char"; + break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; + case 64: + os << "long"; + break; + case 1: + os << "int"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; } -void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, - PrimExpr base, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, + std::ostream& os) { // NOLINT(*) if (!HandleTypeMatch(buffer, t.element_of())) { os << '('; auto it = alloc_storage_scope_.find(buffer); @@ -156,8 +171,7 @@ void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, os << GetVarID(buffer) << " + "; PrintExpr(base, os); } -std::string CodeGenOpenCL::GetVecLoad( - DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { std::ostringstream os; os << "vload" << t.lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); @@ -165,8 +179,7 @@ std::string CodeGenOpenCL::GetVecLoad( return os.str(); } -void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) { this->PrintIndent(); stream << "vstore" << t.lanes() << "(" << value << ", 0, "; @@ -187,8 +200,7 @@ void CodeGenOpenCL::PrintStorageSync(const CallNode* op) { } } -void CodeGenOpenCL::PrintStorageScope( - const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) if (scope == "global") { os << "__global "; } else if (scope == "shared") { @@ -212,7 +224,7 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType return os.str(); } -void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); os << "(("; PrintType(op->dtype, os); @@ -224,7 +236,7 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // os << "))"; } -void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) if (std::isinf(op->value)) { if (op->value < 0) { os << "-"; @@ -243,9 +255,8 @@ runtime::Module BuildOpenCL(IRModule mod, std::string target) { CodeGenOpenCL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenOpenCL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -260,7 +271,6 @@ runtime::Module BuildOpenCL(IRModule mod, std::string target) { return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.opencl") -.set_body_typed(BuildOpenCL); +TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index cc1fe99..32a98e4 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -25,7 +25,9 @@ #define TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_ #include + #include + #include "codegen_c.h" namespace tvm { @@ -38,24 +40,22 @@ class CodeGenOpenCL final : public CodeGenC { // override print thread tag. void InitFuncState(const PrimFunc& f) final; - void PrintFuncPrefix() final; // NOLINT(*) - void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) - void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - std::string GetVecLoad(DataType t, const VarNode* buffer, - PrimExpr base) final; - void PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, + void PrintFuncPrefix() final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) final; + void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) // the address of load/store - void PrintVecAddr(const VarNode* buffer, DataType t, - PrimExpr base, std::ostream& os); // NOLINT(*) - std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) + void PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, + std::ostream& os); // NOLINT(*) + std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) // overload visitor - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) private: // whether enable fp16 and fp64 extension diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc index 0b85e26..fd5c3ba 100644 --- a/src/target/source/codegen_opengl.cc +++ b/src/target/source/codegen_opengl.cc @@ -23,19 +23,20 @@ * We are targeting OpenGL 3.3. The reason of not targeting a recent version * of OpenGL is to have better compatibility of WebGL 2. */ -#include +#include "codegen_opengl.h" + #include -#include #include -#include "codegen_opengl.h" -#include "../build_common.h" +#include +#include + #include "../../runtime/thread_storage_scope.h" +#include "../build_common.h" namespace tvm { namespace codegen { -CodeGenOpenGL::CodeGenOpenGL() - : output_(nullptr), output_iter_var_(nullptr) {} +CodeGenOpenGL::CodeGenOpenGL() : output_(nullptr), output_iter_var_(nullptr) {} void CodeGenOpenGL::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); @@ -160,20 +161,16 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; - shaders_[static_cast(global_symbol.value())] = runtime::OpenGLShader( - this->decl_stream.str() + this->stream.str(), - std::move(arg_names), std::move(arg_kinds), - this->thread_extent_var_); + shaders_[static_cast(global_symbol.value())] = + runtime::OpenGLShader(this->decl_stream.str() + this->stream.str(), std::move(arg_names), + std::move(arg_kinds), this->thread_extent_var_); } -std::unordered_map CodeGenOpenGL::Finish() { - return shaders_; -} +std::unordered_map CodeGenOpenGL::Finish() { return shaders_; } void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) { CHECK_EQ(iv->thread_tag, "threadIdx.x") << "Must be threadIdx.x"; - CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end()) - << "Only support one thread iter var"; + CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end()) << "Only support one thread iter var"; CHECK(output_iter_var_ == nullptr) << "Only support one thread iter var"; var_idmap_[iv->var.get()] = iv->thread_tag; @@ -211,8 +208,7 @@ std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, PrimExpr index) { // Print a reference expression to a buffer. // Format: texelFetch(buffer, index, 0).r -std::string CodeGenOpenGL::GetBufferRef( - DataType t, const VarNode* buffer, PrimExpr index) { +std::string CodeGenOpenGL::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) { CHECK_EQ(t.lanes(), 1) << "Vector type not supported."; CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported."; @@ -274,11 +270,10 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) { // Doesn't support store to vector. auto type = value.dtype(); - CHECK_EQ(type.lanes(), 1) - << "Vectorized store not implemented, type = " << type; + CHECK_EQ(type.lanes(), 1) << "Vectorized store not implemented, type = " << type; CHECK(inputs_.find(buffer) == inputs_.cend()) - << "Texture has been read from before. Must not store to it."; + << "Texture has been read from before. Must not store to it."; if (output_ == nullptr) { output_ = buffer; // Record that this texture is the output. } else { @@ -294,9 +289,8 @@ runtime::Module BuildOpenGL(IRModule mod, std::string target) { CodeGenOpenGL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenOpenGL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenOpenGL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -308,8 +302,7 @@ runtime::Module BuildOpenGL(IRModule mod, std::string target) { return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(mod)); } -TVM_REGISTER_GLOBAL("target.build.opengl") -.set_body_typed(BuildOpenGL); +TVM_REGISTER_GLOBAL("target.build.opengl").set_body_typed(BuildOpenGL); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opengl.h b/src/target/source/codegen_opengl.h index 954806b..2748ae2 100644 --- a/src/target/source/codegen_opengl.h +++ b/src/target/source/codegen_opengl.h @@ -25,11 +25,13 @@ #define TVM_TARGET_SOURCE_CODEGEN_OPENGL_H_ #include + #include -#include #include -#include "codegen_c.h" +#include + #include "../../runtime/opengl/opengl_module.h" +#include "codegen_c.h" namespace tvm { namespace codegen { @@ -45,11 +47,11 @@ class CodeGenOpenGL final : public CodeGenC { void VisitStmt_(const StoreNode* op) final; std::string TexelFetch(const VarNode* buffer, PrimExpr index); std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) final; - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) // Codegen for immediate values - void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*) // Match glsl_texture_store Call. diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 0859428..9b2f034 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -70,8 +70,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { } std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { - CHECK(!var_idmap_.count(v)) - << "Need input to be in SSA form dup " << v->name_hint; + CHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint; std::string key = v->name_hint; std::string vid = GetUniqueName(key); var_idmap_[v] = vid; @@ -80,8 +79,7 @@ std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const { auto it = var_idmap_.find(v); - CHECK(it != var_idmap_.end()) - << "Find undefined Variable " << v->name_hint; + CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; return it->second; } diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 6723767..3901659 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -24,13 +24,15 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ +#include #include #include -#include -#include -#include + #include +#include #include +#include + #include "../../runtime/meta_data.h" namespace tvm { @@ -103,8 +105,7 @@ class CodeGenSourceBase { * \param src The source expression. * \param t The type of target. */ - virtual void PrintSSAAssign( - const std::string& target, const std::string& src, DataType t) = 0; + virtual void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) = 0; /*! \brief the declaration stream */ std::ostringstream decl_stream; @@ -147,11 +148,8 @@ runtime::Module CSourceModuleCreate(std::string code, std::string fmt); * \param fget_source a closure to replace default get source behavior. */ runtime::Module DeviceSourceModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string type_key, - std::function fget_source = nullptr); + std::string data, std::string fmt, std::unordered_map fmap, + std::string type_key, std::function fget_source = nullptr); } // namespace codegen } // namespace tvm #endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 71c3626..e60e1f5 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -20,11 +20,13 @@ /*! * \file codegen_vhls.cc */ -#include -#include #include "codegen_vhls.h" -#include "../build_common.h" + +#include +#include + #include "../../runtime/opencl/sdaccel/sdaccel_module.h" +#include "../build_common.h" namespace tvm { namespace codegen { @@ -40,37 +42,45 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) { if (t.is_uint()) { switch (t.bits()) { case 8: - os << "unsigned char"; break; + os << "unsigned char"; + break; case 16: - os << "unsigned short"; break; + os << "unsigned short"; + break; case 32: - os << "unsigned int"; break; + os << "unsigned int"; + break; case 64: - os << "unsigned long long"; break; + os << "unsigned long long"; + break; default: - os << "ap_uint<" << t.bits() << ">"; break; + os << "ap_uint<" << t.bits() << ">"; + break; } } else if (t.is_int()) { switch (t.bits()) { case 8: - os << "char"; break; + os << "char"; + break; case 16: - os << "short"; break; + os << "short"; + break; case 32: - os << "int"; break; + os << "int"; + break; case 64: - os << "long long"; break; + os << "long long"; + break; default: - os << "ap_int<" << t.bits() << ">"; break; + os << "ap_int<" << t.bits() << ">"; + break; } } else { CodeGenC::PrintType(t, os); } } -void CodeGenVivadoHLS::PrintFuncPrefix() { - stream << "extern \"C\" void"; -} +void CodeGenVivadoHLS::PrintFuncPrefix() { stream << "extern \"C\" void"; } void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) { for (size_t i = 0; i < f->params.size(); ++i) { @@ -84,9 +94,8 @@ void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) { this->stream << "#pragma HLS INTERFACE s_axilite port=return bundle=control\n\n"; } -template -inline void PrintBinaryExpr(const T* op, - const char *opstr, +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenVivadoHLS* p) { os << opstr << '('; @@ -96,35 +105,38 @@ inline void PrintBinaryExpr(const T* op, os << ')'; } -void CodeGenVivadoHLS::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*) - const char *opstr = "std::min"; +void CodeGenVivadoHLS::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) + const char* opstr = "std::min"; if (op->dtype.is_float()) { switch (op->dtype.bits()) { case 32: - opstr = "fminf"; break; + opstr = "fminf"; + break; case 64: - opstr = "fmin"; break; + opstr = "fmin"; + break; } } PrintBinaryExpr(op, opstr, os, this); } -void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*) - const char *opstr = "std::max"; +void CodeGenVivadoHLS::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) + const char* opstr = "std::max"; if (op->dtype.is_float()) { switch (op->dtype.bits()) { case 32: - opstr = "fmaxf"; break; + opstr = "fmaxf"; + break; case 64: - opstr = "fmax"; break; + opstr = "fmax"; + break; } } PrintBinaryExpr(op, opstr, os, this); } - runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { using tvm::runtime::Registry; bool output_ssa = false; @@ -133,9 +145,8 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { // Generate source code for get_source(). cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenVHLS: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenVHLS: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -148,9 +159,8 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { // Generate source code for compilation. Array > kernel_info; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenOpenCL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); CodeGenVivadoHLS cg; cg.Init(output_ssa); @@ -176,8 +186,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(mod), whole_code); } -TVM_REGISTER_GLOBAL("target.build.sdaccel") -.set_body_typed(BuildSDAccel); +TVM_REGISTER_GLOBAL("target.build.sdaccel").set_body_typed(BuildSDAccel); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_vhls.h b/src/target/source/codegen_vhls.h index 10f9ea7..b9bec51 100644 --- a/src/target/source/codegen_vhls.h +++ b/src/target/source/codegen_vhls.h @@ -27,7 +27,9 @@ #include #include #include + #include + #include "codegen_c.h" namespace tvm { @@ -40,8 +42,8 @@ class CodeGenVivadoHLS final : public CodeGenC { void PrintFuncPrefix() final; void PreFunctionBody(const PrimFunc& f) final; - void VisitExpr_(const MinNode *op, std::ostream& os) final; - void VisitExpr_(const MaxNode *op, std::ostream& os) final; + void VisitExpr_(const MinNode* op, std::ostream& os) final; + void VisitExpr_(const MaxNode* op, std::ostream& os) final; }; } // namespace codegen diff --git a/src/target/source/intrin_rule_aocl.cc b/src/target/source/intrin_rule_aocl.cc index 6317a2f..0cafd02 100644 --- a/src/target/source/intrin_rule_aocl.cc +++ b/src/target/source/intrin_rule_aocl.cc @@ -27,73 +27,49 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 47425c3..4e4abd9 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -31,10 +31,14 @@ struct CUDAMath { std::string operator()(DataType t, std::string name) const { if (t.is_float()) { switch (t.bits()) { - case 64: return name; - case 32: return name + 'f'; - case 16: return 'h' + name; - default: return ""; + case 64: + return name; + case 32: + return name + 'f'; + case 16: + return 'h' + name; + default: + return ""; } } return ""; @@ -55,14 +59,18 @@ struct CUDAFastMath : public CUDAMath { struct CUDAFastMathTan : public CUDAMath { std::string operator()(DataType t, std::string name) const { if (t.is_float()) { - switch (t.bits()) { - case 64: return name; - // `__tanf` seems to produce some values too deviant from numpy tan version. - // So, let's use just `tanf` instead. - case 32: return name + 'f'; - case 16: LOG(FATAL) << "cuda tan unsupported for float16"; - default: return ""; - } + switch (t.bits()) { + case 64: + return name; + // `__tanf` seems to produce some values too deviant from numpy tan version. + // So, let's use just `tanf` instead. + case 32: + return name + 'f'; + case 16: + LOG(FATAL) << "cuda tan unsupported for float16"; + default: + return ""; + } } return ""; } @@ -72,16 +80,18 @@ struct CUDAPopcount { std::string operator()(DataType t, std::string name) const { if (t.is_uint()) { switch (t.bits()) { - case 32: return "__popc"; - case 64: return "__popcll"; - default: return ""; + case 32: + return "__popc"; + case 64: + return "__popcll"; + default: + return ""; } } return ""; } }; - struct CUDAWarpIntrinsic { const char* operator()(DataType t, const std::string& name) const { if (name == intrinsic::tvm_warp_shuffle) { @@ -111,86 +121,63 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern); } -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle") -.set_body(DispatchCUDAShuffle); + .set_body(DispatchCUDAShuffle); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_up") -.set_body(DispatchCUDAShuffle); + .set_body(DispatchCUDAShuffle); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down") -.set_body(DispatchCUDAShuffle); + .set_body(DispatchCUDAShuffle); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask") -.set_body(DispatchExtern); + .set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index 8bc87d2..00fb9f9 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -27,65 +27,45 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index d7f63a6..60fbde7 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -22,71 +22,52 @@ * \brief OpenCL intrinsic rules. */ #include + #include "../intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchExtern); // There is no warp shuffle instruction in standard OpenCL // When shuffle is used, we assume it is intel's shuffle extension @@ -97,14 +78,12 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size arith::Analyzer analyzer; CHECK(analyzer.CanProve(call->args[3] == call->args[4])) - << "Intel warp shuffle dose not support width != warp_size"; + << "Intel warp shuffle dose not support width != warp_size"; Array opencl_args{{call->args[1], call->args[2]}}; - *rv = CallNode::make(call->dtype, "intel_sub_group_shuffle", - opencl_args, CallNode::PureExtern); + *rv = CallNode::make(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern); } -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle") -.set_body(DispatchIntelShuffle); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_opengl.cc b/src/target/source/intrin_rule_opengl.cc index 1710d45..1f2a21a 100644 --- a/src/target/source/intrin_rule_opengl.cc +++ b/src/target/source/intrin_rule_opengl.cc @@ -27,53 +27,37 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_vhls.cc b/src/target/source/intrin_rule_vhls.cc index 41e76f2..fb01d65 100644 --- a/src/target/source/intrin_rule_vhls.cc +++ b/src/target/source/intrin_rule_vhls.cc @@ -27,62 +27,43 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 5f13321..ba7f075 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,43 +23,36 @@ */ #include #include -#include "codegen_source_base.h" + #include "../../runtime/file_util.h" #include "../../runtime/meta_data.h" +#include "codegen_source_base.h" namespace tvm { namespace codegen { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; +using runtime::FunctionInfo; using runtime::GetFileFormat; using runtime::GetMetaFilePath; -using runtime::FunctionInfo; using runtime::SaveBinaryToFile; // Simulator function class SourceModuleNode : public runtime::ModuleNode { public: - SourceModuleNode(std::string code, - std::string fmt) - : code_(code), fmt_(fmt) {} - const char* type_key() const { - return "source"; - } + SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} + const char* type_key() const { return "source"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); } - std::string GetSource(const std::string& format) final { - return code_; - } + std::string GetSource(const std::string& format) final { return code_; } protected: std::string code_; @@ -74,35 +67,25 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) { // Simulator function class CSourceModuleNode : public runtime::ModuleNode { public: - CSourceModuleNode(std::string code, - std::string fmt) - : code_(code), fmt_(fmt) {} - const char* type_key() const { - return "c"; - } + CSourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} + const char* type_key() const { return "c"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "C Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); } - std::string GetSource(const std::string& format) final { - return code_; - } + std::string GetSource(const std::string& format) final { return code_; } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cc") { CHECK_NE(code_.length(), 0); SaveBinaryToFile(file_name, code_); } else { - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; } } @@ -119,20 +102,12 @@ runtime::Module CSourceModuleCreate(std::string code, std::string fmt) { // supports limited save without cross compile class DeviceSourceModuleNode final : public runtime::ModuleNode { public: - DeviceSourceModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string type_key, + DeviceSourceModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string type_key, std::function fget_source) - : data_(data), - fmt_(fmt), - fmap_(fmap), - type_key_(type_key), - fget_source_(fget_source) {} - - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); @@ -146,15 +121,11 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { } } - const char* type_key() const { - return type_key_.c_str(); - } + const char* type_key() const { return type_key_.c_str(); } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -175,19 +146,14 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { }; runtime::Module DeviceSourceModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string type_key, - std::function fget_source) { + std::string data, std::string fmt, std::unordered_map fmap, + std::string type_key, std::function fget_source) { auto n = make_object(data, fmt, fmap, type_key, fget_source); return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate") -.set_body_typed(SourceModuleCreate); +TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); -TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") -.set_body_typed(CSourceModuleCreate); +TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate").set_body_typed(CSourceModuleCreate); } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 825bdcb..86d1614 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -22,44 +22,37 @@ * \brief Build SPIRV block */ // Use libspirv for parsing and validating code. -#include #include +#include #include -#include "codegen_spirv.h" -#include "../build_common.h" - -#include "../../runtime/vulkan/vulkan_shader.h" #include "../../runtime/vulkan/vulkan_module.h" +#include "../../runtime/vulkan/vulkan_shader.h" +#include "../build_common.h" +#include "codegen_spirv.h" namespace tvm { namespace codegen { class SPIRVTools { public: - SPIRVTools() { - ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); - } - ~SPIRVTools() { - spvContextDestroy(ctx_); - } + SPIRVTools() { ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); } + ~SPIRVTools() { spvContextDestroy(ctx_); } std::string BinaryToText(const std::vector& bin) { spv_text text = nullptr; spv_diagnostic diagnostic; spv_const_binary_t spv_bin{bin.data(), bin.size()}; spv_result_t res; - res = spvBinaryToText( - ctx_, spv_bin.code, spv_bin.wordCount, - SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | - SPV_BINARY_TO_TEXT_OPTION_INDENT, - &text, &diagnostic); + res = + spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount, + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, + &text, &diagnostic); - CHECK_EQ(res, SPV_SUCCESS) - << " line=" << diagnostic->position.line - << " column=" << diagnostic->position.column - << " index=" << diagnostic->position.index - << " error:" << diagnostic->error; + CHECK_EQ(res, SPV_SUCCESS) << " line=" << diagnostic->position.line + << " column=" << diagnostic->position.column + << " index=" << diagnostic->position.index + << " error:" << diagnostic->error; std::string ret(text->str); spvTextDestroy(text); @@ -84,9 +77,8 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restric CodeGenSPIRV cg; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenSPIRV: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -103,8 +95,7 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restric if (webgpu_restriction) { for (auto param : f->params) { - CHECK(param.dtype().is_handle()) - << "WebGPU does not yet support non-buffer arguments"; + CHECK(param.dtype().is_handle()) << "WebGPU does not yet support non-buffer arguments"; } } @@ -122,17 +113,14 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restric smap[f_name] = std::move(shader); } - return runtime::VulkanModuleCreate( - smap, ExtractFuncInfo(mod), code_data.str()); + return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), code_data.str()); } -TVM_REGISTER_GLOBAL("target.build.vulkan") -.set_body_typed([](IRModule mod, std::string target) { +TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, std::string target) { return BuildSPIRV(mod, target, false); }); -TVM_REGISTER_GLOBAL("target.build.webgpu") -.set_body_typed([](IRModule mod, std::string target) { +TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, std::string target) { return BuildSPIRV(mod, target, true); }); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 032a72a..e76e8be 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -21,21 +21,21 @@ * \file codegen_spirv.cc * \brief Generate SPIRV block */ -#include +#include "codegen_spirv.h" + #include +#include + #include -#include "codegen_spirv.h" + #include "../../arith/compute_expr.h" namespace tvm { namespace codegen { -std::vector CodeGenSPIRV::BuildFunction( - const PrimFunc& f, - const std::string& name) { +std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); - CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) - << "SPIRV only takes restricted memory model"; + CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; uint32_t num_buffer = 0; @@ -46,8 +46,8 @@ std::vector CodeGenSPIRV::BuildFunction( auto* prim = ptr->element_type.as(); CHECK(prim); DataType value_type = prim->dtype; - spirv::Value arg_value = builder_->BufferArgument( - builder_->GetSType(value_type), 0, num_buffer); + spirv::Value arg_value = + builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer); storage_info_[arg.get()].UpdateContentType(value_type); var_map_[arg.get()] = arg_value; } else { @@ -69,8 +69,7 @@ std::vector CodeGenSPIRV::BuildFunction( } spirv::Value ptr = builder_->DeclarePushConstant(value_types); for (size_t i = 0; i < pod_args.size(); ++i) { - spirv::Value value = builder_->GetPushConstant( - ptr, value_types[i], static_cast(i)); + spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); var_map_[pod_args[i].get()] = value; } } @@ -93,15 +92,14 @@ void CodeGenSPIRV::InitFuncState() { builder_->InitHeader(); } -spirv::Value CodeGenSPIRV::GetThreadIndex( - const IterVar& iv, const PrimExpr& extent) { +spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) { runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); spirv::Value v; if (ts.rank == 1) { v = builder_->GetLocalID(ts.dim_index); auto* sizeptr = extent.as(); - CHECK(sizeptr) - << "SPIRV only allows constant thread group size " << " get " << extent; + CHECK(sizeptr) << "SPIRV only allows constant thread group size " + << " get " << extent; CHECK_LT(ts.dim_index, 3); workgroup_size_[ts.dim_index] = static_cast(sizeptr->value); } else { @@ -118,12 +116,12 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { } else if (sync == "shared") { auto type_int = builder_->GetSType(DataType::Int(32)); builder_->MakeInst( - spv::OpControlBarrier, - builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), - builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), - builder_->IntImm(type_int, static_cast( - spv::MemorySemanticsSequentiallyConsistentMask | - spv::MemorySemanticsWorkgroupMemoryMask))); + spv::OpControlBarrier, + builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), + builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), + builder_->IntImm(type_int, + static_cast(spv::MemorySemanticsSequentiallyConsistentMask | + spv::MemorySemanticsWorkgroupMemoryMask))); } else { LOG(FATAL) << "Do not support sync " << sync; } @@ -227,8 +225,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const NotNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) { - return builder_->Select(MakeValue(op->condition), - MakeValue(op->true_value), + return builder_->Select(MakeValue(op->condition), MakeValue(op->true_value), MakeValue(op->false_value)); } @@ -242,14 +239,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { if (op->is_intrinsic("spirv_glsl450")) { CHECK_GE(op->args.size(), 2U); - uint32_t inst_id = static_cast( - op->args[0].as()->value); + uint32_t inst_id = static_cast(op->args[0].as()->value); std::vector values; for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); } - return builder_->CallGLSL450( - builder_->GetSType(op->dtype), inst_id, values); + return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values); } else if (op->is_intrinsic(CallNode::bitwise_and)) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); @@ -300,10 +295,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Label then_label = builder_->NewLabel(); spirv::Label else_label = builder_->NewLabel(); spirv::Label merge_label = builder_->NewLabel(); - builder_->MakeInst( - spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); - builder_->MakeInst( - spv::OpBranchConditional, cond, then_label, else_label); + builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label); // then block, must get label after we see the value builder_->StartLabel(then_label); spirv::Value then_value = MakeValue(op->args[1]); @@ -321,19 +314,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { phi.SetIncoming(1, else_value, else_value_label); return phi; } else if (op->is_intrinsic("popcount")) { - return builder_->MakeValue( - spv::OpBitCount, - builder_->GetSType(op->dtype), - MakeValue(op->args[0])); + return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype), + MakeValue(op->args[0])); } else { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name - << " with return type " << op->dtype; - } else if (op->call_type == CallNode::Extern || - op->call_type == CallNode::PureExtern) { - LOG(FATAL) << "Unresolved extern " << op->name - << " with return type " << op->dtype; + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { + LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; + } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { + LOG(FATAL) << "Unresolved extern " << op->name << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } @@ -347,8 +334,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { for (int i = 0; i < op->lanes; ++i) { spirv::Value v = base; if (i != 0) { - spirv::Value offset = MakeValue( - make_const(op->stride.dtype(), i) * op->stride); + spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) * op->stride); v = builder_->Add(v, offset); } values.push_back(v); @@ -376,8 +362,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { spirv::SType content_type = builder_->GetSType(info.content_type); spirv::Value buffer = MakeValue(op->buffer_var); - spirv::SType ptr_type = builder_->GetPointerType( - content_type, buffer.stype.storage_class); + spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); uint32_t mask = spv::MemoryAccessMaskNone; if (info.is_volatile) { @@ -387,18 +372,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { CHECK_EQ(info.content_type, op->dtype) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); } else { if (op->dtype.element_of() == info.content_type) { // because content type is element type, we can only do scalarize load. std::vector values; auto f = [&](int i, spirv::Value index) { - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); - values.emplace_back( - builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); + values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); }; this->Scalarize(op->index, f); return builder_->Concat(values); @@ -407,13 +389,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->dtype.lanes()); arith::ModularSet me = analyzer_->modular_set(ramp->base); - CHECK((me->coeff % ramp->lanes) == 0 && - (me->base % ramp->lanes) == 0) + CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = analyzer_->Simplify( - ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, MakeValue(vec_index)); + PrimExpr vec_index = + analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); } } @@ -424,8 +404,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { return spirv::Value(); } -void CodeGenSPIRV::Scalarize(const PrimExpr& e, - std::function f) { +void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { for (int i = 0; i < ramp->dtype.lanes(); ++i) { PrimExpr offset = ramp->base + ramp->stride * i; @@ -435,8 +414,7 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, spirv::SType etype = builder_->GetSType(e.dtype().element_of()); spirv::Value value = MakeValue(e); for (int i = 0; i < e.dtype().lanes(); ++i) { - f(i, builder_->MakeValue( - spv::OpCompositeExtract, etype, value, i)); + f(i, builder_->MakeValue(spv::OpCompositeExtract, etype, value, i)); } } } @@ -454,8 +432,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { spirv::SType content_type = builder_->GetSType(info.content_type); spirv::Value buffer = MakeValue(op->buffer_var); spirv::Value value = MakeValue(op->value); - spirv::SType ptr_type = builder_->GetPointerType( - content_type, buffer.stype.storage_class); + spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); uint32_t mask = spv::MemoryAccessMaskNone; if (info.is_volatile) { @@ -466,17 +443,14 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { CHECK_EQ(info.content_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); } else { if (op->value.dtype().element_of() == info.content_type) { // because content type is element type, we can only do scalarize load. auto f = [&](int i, spirv::Value index) { - spirv::Value elem = builder_->MakeValue( - spv::OpCompositeExtract, content_type, value, i); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); + spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, elem, mask); }; this->Scalarize(op->index, f); @@ -485,13 +459,11 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->value.dtype().lanes()); arith::ModularSet me = analyzer_->modular_set(ramp->base); - CHECK((me->coeff % ramp->lanes) == 0 && - (me->base % ramp->lanes) == 0) + CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = analyzer_->Simplify( - ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, MakeValue(vec_index)); + PrimExpr vec_index = + analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); builder_->MakeInst(spv::OpStore, ptr, value, mask); return; } @@ -519,14 +491,11 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); loop_var.SetIncoming(0, init_value, init_label); spirv::Value loop_cond = builder_->LT(loop_var, extent_value); - uint32_t control = ( - op->for_type == ForType::Unrolled ? - spv::LoopControlUnrollMask : spv::LoopControlMaskNone); - builder_->MakeInst( - spv::OpLoopMerge, merge_label, continue_label, control); - builder_->MakeInst( - spv::OpBranchConditional, loop_cond, body_label, merge_label, - weight_likely_branch_, 1); + uint32_t control = + (op->for_type == ForType::Unrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone); + builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control); + builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label, + weight_likely_branch_, 1); // loop body builder_->StartLabel(body_label); @@ -536,10 +505,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // loop continue builder_->StartLabel(continue_label); - spirv::Value one = - op->loop_var.dtype().is_int() ? - builder_->IntImm(loop_var.stype, 1) : - builder_->UIntImm(loop_var.stype, 1); + spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) + : builder_->UIntImm(loop_var.stype, 1); spirv::Value next_value = builder_->Add(loop_var, one); loop_var.SetIncoming(1, next_value, builder_->CurrentLabel()); builder_->MakeInst(spv::OpBranch, head_label); @@ -553,10 +520,8 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { spirv::Label merge_label = builder_->NewLabel(); if (op->else_case.defined()) { spirv::Label else_label = builder_->NewLabel(); - builder_->MakeInst( - spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); - builder_->MakeInst( - spv::OpBranchConditional, cond, then_label, else_label); + builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label); // then block builder_->StartLabel(then_label); this->VisitStmt(op->then_case); @@ -566,11 +531,9 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { this->VisitStmt(op->else_case); builder_->MakeInst(spv::OpBranch, merge_label); } else { - builder_->MakeInst( - spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); - builder_->MakeInst( - spv::OpBranchConditional, cond, then_label, merge_label, - weight_likely_branch_, 1); + builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + builder_->MakeInst(spv::OpBranchConditional, cond, then_label, merge_label, + weight_likely_branch_, 1); // then block builder_->StartLabel(then_label); this->VisitStmt(op->then_case); @@ -584,23 +547,20 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); CHECK(!op->dtype.is_handle()); int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; StorageInfo& info = storage_info_[op->buffer_var.get()]; spirv::SType etype = builder_->GetSType(op->dtype); if (info.scope.rank == runtime::StorageRank::kLocal) { - buf = builder_->Allocate( - etype, static_cast(constant_size), - spv::StorageClassFunction); + buf = + builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); } else { // shared memory CHECK(info.scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory - buf = builder_->Allocate( - etype, static_cast(constant_size), - spv::StorageClassWorkgroup); + buf = + builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassWorkgroup); } CHECK(!info.content_fixed); info.UpdateContentType(op->dtype); @@ -621,8 +581,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::storage_scope) { const VarNode* v = op->node.as(); CHECK(v); - storage_info_[v].scope = - runtime::StorageScope::make(op->value.as()->value); + storage_info_[v].scope = runtime::StorageScope::make(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); @@ -650,9 +609,7 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { } } -void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { - MakeValue(op->value); -} +void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index adbb59b..a8af29a 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -26,16 +26,16 @@ #include #include -#include #include +#include -#include #include -#include #include +#include +#include -#include "ir_builder.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_builder.h" namespace tvm { namespace codegen { @@ -45,9 +45,8 @@ using namespace tir; /*! * \brief Code generator into SPIRV */ -class CodeGenSPIRV: - public ExprFunctor, - public StmtFunctor { +class CodeGenSPIRV : public ExprFunctor, + public StmtFunctor { public: /*! * \brief Compile and add function f to the current module. @@ -55,16 +54,13 @@ class CodeGenSPIRV: * \param name The name of the target function. * \return The final spirv module. */ - virtual std::vector BuildFunction(const PrimFunc& f, - const std::string& name); + virtual std::vector BuildFunction(const PrimFunc& f, const std::string& name); /*! * \brief Create Value for expression e * \param e The expression to be created value for. * \return created value. */ - spirv::Value MakeValue(const PrimExpr& e) { - return VisitExpr(e); - } + spirv::Value MakeValue(const PrimExpr& e) { return VisitExpr(e); } // override codegen spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const CastNode* op) override; @@ -119,8 +115,7 @@ class CodeGenSPIRV: // Update content type if it hasn't beenupdated. void UpdateContentType(DataType type) { if (content_fixed) { - CHECK_EQ(type, content_type) - << "Cannot use two different content type in GLSL model"; + CHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model"; } else { this->content_type = type; content_fixed = true; @@ -132,8 +127,7 @@ class CodeGenSPIRV: // Get the thread index spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent); spirv::Value CreateStorageSync(const CallNode* op); - void Scalarize(const PrimExpr& e, - std::function f); + void Scalarize(const PrimExpr& e, std::function f); // The builder std::unique_ptr builder_; // Work group size of three @@ -151,5 +145,4 @@ class CodeGenSPIRV: } // namespace codegen } // namespace tvm - #endif // TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_ diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index d8b9e71..6b31bd7 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -20,9 +20,9 @@ /*! * \file intrin_rule_spirv.cc */ +#include #include #include -#include namespace tvm { namespace codegen { @@ -31,7 +31,7 @@ namespace spirv { using namespace runtime; // num_signature means number of arguments used to query signature -template +template inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); @@ -43,71 +43,55 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make( - call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::CallNode::make(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil").set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc") -.set_body(DispatchGLSLPureIntrin); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh").set_body(DispatchGLSLPureIntrin); // WebGPU rules. TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.ceil") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.ceil").set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.round") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.trunc") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.fabs") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.fabs").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.exp") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.exp").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.log") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.log").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.sqrt") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.sqrt").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.pow") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.pow").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.tanh") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.tanh").set_body(DispatchGLSLPureIntrin); } // namespace spirv } // namespace codegen diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 7573b47..305464a 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -49,9 +49,9 @@ void IRBuilder::InitHeader() { // shader ib_.Begin(spv::OpCapability).Add(spv::CapabilityShader).Commit(&header_); // memory model - ib_.Begin(spv::OpMemoryModel).AddSeq( - spv::AddressingModelLogical, - spv::MemoryModelGLSL450).Commit(&entry_); + ib_.Begin(spv::OpMemoryModel) + .AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) + .Commit(&entry_); this->InitPreDefs(); } @@ -66,8 +66,7 @@ void IRBuilder::InitPreDefs() { t_void_.id = id_counter_++; ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_); t_void_func_.id = id_counter_++; - ib_.Begin(spv::OpTypeFunction) - .AddSeq(t_void_func_, t_void_).Commit(&global_); + ib_.Begin(spv::OpTypeFunction).AddSeq(t_void_func_, t_void_).Commit(&global_); } SType IRBuilder::GetSType(const DataType& dtype) { @@ -93,8 +92,7 @@ SType IRBuilder::GetSType(const DataType& dtype) { return t; } -SType IRBuilder::GetPointerType(const SType& value_type, - spv::StorageClass storage_class) { +SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass storage_class) { CHECK_NE(storage_class, spv::StorageClassMax); auto key = std::make_pair(value_type.id, storage_class); auto it = pointer_type_tbl_.find(key); @@ -106,14 +104,12 @@ SType IRBuilder::GetPointerType(const SType& value_type, t.type = DataType::Handle(); t.element_type_id = value_type.id; t.storage_class = storage_class; - ib_.Begin(spv::OpTypePointer) - .AddSeq(t, storage_class, value_type).Commit(&global_); + ib_.Begin(spv::OpTypePointer).AddSeq(t, storage_class, value_type).Commit(&global_); pointer_type_tbl_[key] = t; return t; } -SType IRBuilder::GetStructArrayType(const SType& value_type, - uint32_t num_elems) { +SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems) { auto key = std::make_pair(value_type.id, num_elems); auto it = struct_array_type_tbl_.find(key); if (it != struct_array_type_tbl_.end()) { @@ -127,63 +123,50 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, if (num_elems != 0) { Value length = UIntImm(GetSType(DataType::UInt(32)), num_elems); - ib_.Begin(spv::OpTypeArray) - .AddSeq(arr_type, value_type, length).Commit(&global_); + ib_.Begin(spv::OpTypeArray).AddSeq(arr_type, value_type, length).Commit(&global_); } else { - ib_.Begin(spv::OpTypeRuntimeArray) - .AddSeq(arr_type, value_type).Commit(&global_); + ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_); } int nbits = value_type.type.bits() * value_type.type.lanes(); CHECK_EQ(nbits % 8, 0); uint32_t nbytes = static_cast(nbits) / 8; // decorate the array type. - this->Decorate(spv::OpDecorate, - arr_type, spv::DecorationArrayStride, nbytes); + this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); // declare struct of array SType struct_type; struct_type.id = id_counter_++; struct_type.type = DataType::Handle(); struct_type.element_type_id = value_type.id; - ib_.Begin(spv::OpTypeStruct) - .AddSeq(struct_type, arr_type).Commit(&global_); + ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_); // decorate the array type. ib_.Begin(spv::OpMemberDecorate) .AddSeq(struct_type, 0, spv::DecorationOffset, 0) .Commit(&decorate_); - #if SPV_VERSION < 0x10300 // NOTE: BufferBlock was deprecated in SPIRV 1.3 // use StorageClassStorageBuffer instead. // runtime array are always decorated as BufferBlock(shader storage buffer) if (num_elems == 0) { - this->Decorate(spv::OpDecorate, - struct_type, spv::DecorationBufferBlock); + this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock); } #else - this->Decorate(spv::OpDecorate, - struct_type, spv::DecorationBlock); + this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); #endif struct_array_type_tbl_[key] = struct_type; return struct_type; } -Value IRBuilder::StructArrayAccess(const SType& res_type, - Value buffer, - Value index) { +Value IRBuilder::StructArrayAccess(const SType& res_type, Value buffer, Value index) { CHECK(buffer.flag == kStructArrayPtr); - return MakeValue(spv::OpInBoundsAccessChain, - res_type, buffer, - const_i32_zero_, index); + return MakeValue(spv::OpInBoundsAccessChain, res_type, buffer, const_i32_zero_, index); } Value IRBuilder::IntImm(const SType& dtype, int64_t value) { return GetConst_(dtype, reinterpret_cast(&value)); } -Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) { - return GetConst_(dtype, &value); -} +Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) { return GetConst_(dtype, &value); } Value IRBuilder::FloatImm(const SType& dtype, double value) { if (dtype.type.bits() == 64) { @@ -195,13 +178,11 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { return GetConst_(dtype, &data); } else { CHECK_EQ(dtype.type.bits(), 16); - return Cast(dtype, - FloatImm(GetSType(DataType::Float(32)), value)); + return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value)); } } -Value IRBuilder::BufferArgument(const SType& value_type, - uint32_t descriptor_set, +Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding) { // NOTE: BufferBlock was deprecated in SPIRV 1.3 // use StorageClassStorageBuffer instead. @@ -215,13 +196,10 @@ Value IRBuilder::BufferArgument(const SType& value_type, SType ptr_type = GetPointerType(sarr_type, storage_class); Value val = NewValue(ptr_type, kStructArrayPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, storage_class).Commit(&global_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_); - this->Decorate(spv::OpDecorate, - val, spv::DecorationDescriptorSet, descriptor_set); - this->Decorate(spv::OpDecorate, - val, spv::DecorationBinding, binding); + this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set); + this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); return val; } @@ -243,37 +221,30 @@ Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { .Commit(&decorate_); DataType t = value_types[i].type; uint32_t nbits = t.bits() * t.lanes(); - CHECK_EQ(nbits % 8 , 0); + CHECK_EQ(nbits % 8, 0); offset += nbits / 8; } // Decorate push constants as UBO this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); - SType ptr_type = GetPointerType( - struct_type, spv::StorageClassPushConstant); + SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant); Value val = NewValue(ptr_type, kPushConstantPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_); return val; } -Value IRBuilder::GetPushConstant( - Value ptr_push_const, const SType& v_type, uint32_t index) { +Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) { SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant); - Value ptr = this->MakeValue( - spv::OpAccessChain, ptr_vtype, ptr_push_const, - IntImm(t_int32_, static_cast(index))); + Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const, + IntImm(t_int32_, static_cast(index))); return this->MakeValue(spv::OpLoad, v_type, ptr); } -Value IRBuilder::NewFunction() { - return NewValue(t_void_func_, kFunction); -} +Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); } void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) { CHECK_EQ(func.flag, kFunction); - ib_.Begin(spv::OpEntryPoint) - .AddSeq(spv::ExecutionModelGLCompute, func, name); + ib_.Begin(spv::OpEntryPoint).AddSeq(spv::ExecutionModelGLCompute, func, name); if (workgroup_id_.id != 0) { ib_.Add(workgroup_id_); } @@ -286,36 +257,30 @@ void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) void IRBuilder::StartFunction(const Value& func) { CHECK_EQ(func.flag, kFunction); // add function declaration to the header. - ib_.Begin(spv::OpFunction).AddSeq( - t_void_, func, 0, t_void_func_).Commit(&func_header_); + ib_.Begin(spv::OpFunction).AddSeq(t_void_, func, 0, t_void_func_).Commit(&func_header_); spirv::Label start_label = this->NewLabel(); ib_.Begin(spv::OpLabel).AddSeq(start_label).Commit(&func_header_); curr_label_ = start_label; } -void IRBuilder::SetLocalSize(const Value& func, - uint32_t local_size[3]) { +void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) { CHECK_EQ(func.flag, kFunction); ib_.Begin(spv::OpExecutionMode) - .AddSeq(func, spv::ExecutionModeLocalSize, - local_size[0], local_size[1], local_size[2]) + .AddSeq(func, spv::ExecutionModeLocalSize, local_size[0], local_size[1], local_size[2]) .Commit(&exec_mode_); } -Value IRBuilder::Allocate(const SType& value_type, - uint32_t num_elems, +Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems, spv::StorageClass storage_class) { CHECK_NE(num_elems, 0U); SType sarr_type = GetStructArrayType(value_type, num_elems); SType ptr_type = GetPointerType(sarr_type, storage_class); Value val = NewValue(ptr_type, kStructArrayPtr); if (storage_class == spv::StorageClassFunction) { - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, storage_class).Commit(&func_header_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&func_header_); } else { - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, storage_class).Commit(&global_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_); } return val; } @@ -323,19 +288,16 @@ Value IRBuilder::Allocate(const SType& value_type, Value IRBuilder::GetWorkgroupID(uint32_t dim_index) { if (workgroup_id_.id == 0) { SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); - SType ptr_type = this->GetPointerType( - vec3_type, spv::StorageClassInput); + SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput); workgroup_id_ = NewValue(ptr_type, kVectorPtr); ib_.Begin(spv::OpVariable) .AddSeq(ptr_type, workgroup_id_, spv::StorageClassInput) .Commit(&global_); - this->Decorate(spv::OpDecorate, workgroup_id_, - spv::DecorationBuiltIn, spv::BuiltInWorkgroupId); + this->Decorate(spv::OpDecorate, workgroup_id_, spv::DecorationBuiltIn, spv::BuiltInWorkgroupId); } SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput); - Value ptr = this->MakeValue( - spv::OpAccessChain, pint_type, workgroup_id_, - IntImm(t_int32_, static_cast(dim_index))); + Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, workgroup_id_, + IntImm(t_int32_, static_cast(dim_index))); return this->MakeValue(spv::OpLoad, t_int32_, ptr); } @@ -344,16 +306,13 @@ Value IRBuilder::GetLocalID(uint32_t dim_index) { SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput); local_id_ = NewValue(ptr_type, kVectorPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, local_id_, spv::StorageClassInput) - .Commit(&global_); - this->Decorate(spv::OpDecorate, local_id_, - spv::DecorationBuiltIn, spv::BuiltInLocalInvocationId); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, local_id_, spv::StorageClassInput).Commit(&global_); + this->Decorate(spv::OpDecorate, local_id_, spv::DecorationBuiltIn, + spv::BuiltInLocalInvocationId); } SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput); - Value ptr = this->MakeValue( - spv::OpAccessChain, pint_type, local_id_, - UIntImm(t_int32_, static_cast(dim_index))); + Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, local_id_, + UIntImm(t_int32_, static_cast(dim_index))); return this->MakeValue(spv::OpLoad, t_int32_, ptr); } @@ -380,9 +339,8 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { if (dtype.type.bits() > 32) { if (dtype.type.is_int()) { int64_t sign_mask = 0xFFFFFFFFL; - const int64_t* sign_ptr = - reinterpret_cast(pvalue); - ib_.Add(static_cast((sign_ptr[0] >> 32L) & sign_mask)); + const int64_t* sign_ptr = reinterpret_cast(pvalue); + ib_.Add(static_cast((sign_ptr[0] >> 32L) & sign_mask)); } else { ib_.Add(static_cast((pvalue[0] >> 32UL) & mask)); } @@ -416,8 +374,7 @@ SType IRBuilder::DeclareType(const DataType& dtype) { t.id = id_counter_++; t.type = dtype; SType base_type = GetSType(dtype.element_of()); - ib_.Begin(spv::OpTypeVector).AddSeq( - t, base_type, dtype.lanes()).Commit(&global_); + ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); return t; } } @@ -437,12 +394,10 @@ PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) { return phi; } -Value IRBuilder::CallGLSL450(const SType& ret_type, - uint32_t inst_id, +Value IRBuilder::CallGLSL450(const SType& ret_type, uint32_t inst_id, const std::vector& args) { Value val = NewValue(ret_type, kNormal); - ib_.Begin(spv::OpExtInst) - .AddSeq(ret_type, val, ext_glsl450_, inst_id); + ib_.Begin(spv::OpExtInst).AddSeq(ret_type, val, ext_glsl450_, inst_id); for (const Value& v : args) { ib_.Add(v); } @@ -512,14 +467,12 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { return MakeValue(spv::OpUConvert, dst_type, value); } else if (from.is_uint() && to.is_int()) { if (from.bits() != to.bits()) { - value = MakeValue( - spv::OpUConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpUConvert, GetSType(from.with_bits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); } else if (from.is_int() && to.is_uint()) { if (from.bits() != to.bits()) { - value = MakeValue( - spv::OpSConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpSConvert, GetSType(from.with_bits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); } else if (from.is_float() && to.is_int()) { @@ -533,21 +486,20 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } else if (from.is_float() && to.is_float()) { return MakeValue(spv::OpFConvert, dst_type, value); } else { - LOG(FATAL) << "do not support type cast from " - << from << " to " << to; + LOG(FATAL) << "do not support type cast from " << from << " to " << to; return Value(); } } -#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI ## _Op, a.stype, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpF ## _Op, a.stype, a, b); \ - } \ +#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ + return MakeValue(spv::OpI##_Op, a.stype, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpF##_Op, a.stype, a, b); \ + } \ } #define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ @@ -580,19 +532,19 @@ Value IRBuilder::Mod(Value a, Value b) { } } -#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ +#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS##_Op, bool_type, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU##_Op, bool_type, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ + if (a.stype.type.is_int()) { \ + return MakeValue(spv::OpS##_Op, bool_type, a, b); \ + } else if (a.stype.type.is_uint()) { \ + return MakeValue(spv::OpU##_Op, bool_type, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_OP(LT, LessThan); @@ -600,17 +552,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual); DEFINE_BUILDER_CMP_OP(GT, GreaterThan); DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); -#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ +#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI##_Op, bool_type, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ + if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ + return MakeValue(spv::OpI##_Op, bool_type, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_UOP(EQ, Equal); diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index e9e04e8..c52f92f 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -27,20 +27,20 @@ #include #include +// clang-format off #include -#include -#include -#include #include +#include #include - +#include +#include #include +// clang-format on namespace tvm { namespace codegen { namespace spirv { - /*! \brief Represent the SPIRV Type */ struct SType { /*! \brief The Id to represent type */ @@ -86,9 +86,7 @@ struct Label { class Instr { public: /*! \return the word count */ - uint32_t WordCount() const { - return word_count_; - } + uint32_t WordCount() const { return word_count_; } /*! * \brief Access idx-th word of instruction * \param idx The index @@ -123,9 +121,7 @@ struct PhiValue : public Value { * \param value The value to come * \param parent The parent label. */ - void SetIncoming(uint32_t index, - const Value& value, - const Label& parent) { + void SetIncoming(uint32_t index, const Value& value, const Label& parent) { CHECK_EQ(this->stype.id, value.stype.id); instr[3 + index * 2] = value.id; instr[3 + index * 2 + 1] = parent.id; @@ -204,12 +200,10 @@ class InstrBuilder { */ InstrBuilder& Add(const std::string& v) { const uint32_t kWordSize = sizeof(uint32_t); - uint32_t nwords = - (static_cast(v.length()) + kWordSize) / kWordSize; + uint32_t nwords = (static_cast(v.length()) + kWordSize) / kWordSize; size_t begin = data_.size(); data_.resize(begin + nwords, 0U); - std::copy(v.begin(), v.end(), - reinterpret_cast(&data_[begin])); + std::copy(v.begin(), v.end(), reinterpret_cast(&data_[begin])); return *this; } /*! @@ -218,8 +212,8 @@ class InstrBuilder { * \return reference to self. * \tparams Args The positional arguments */ - template - InstrBuilder& AddSeq(Args&& ...args) { + template + InstrBuilder& AddSeq(Args&&... args) { AddSeqHelper helper; helper.builder = this; runtime::detail::for_each(helper, std::forward(args)...); @@ -253,7 +247,7 @@ class InstrBuilder { // The reference to builder InstrBuilder* builder; // invoke function - template + template void operator()(size_t, const T& v) const { builder->Add(v); } @@ -324,17 +318,15 @@ class IRBuilder { curr_label_ = label; } /*! \return The current label */ - Label CurrentLabel() const { - return curr_label_; - } + Label CurrentLabel() const { return curr_label_; } /*! * \brief Add code to debug segment. * \param op The operator * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void Debug(spv::Op op, Args&& ...args) { + template + void Debug(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&debug_); } /*! @@ -343,10 +335,9 @@ class IRBuilder { * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void ExecutionMode(Value func, Args&& ...args) { - ib_.Begin(spv::OpExecutionMode).AddSeq( - func, std::forward(args)...).Commit(&exec_mode_); + template + void ExecutionMode(Value func, Args&&... args) { + ib_.Begin(spv::OpExecutionMode).AddSeq(func, std::forward(args)...).Commit(&exec_mode_); } /*! * \brief Add code to decorate segment. @@ -354,8 +345,8 @@ class IRBuilder { * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void Decorate(spv::Op op, Args&& ...args) { + template + void Decorate(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&decorate_); } /*! @@ -364,8 +355,8 @@ class IRBuilder { * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void DeclareGlobal(spv::Op op, Args&& ...args) { + template + void DeclareGlobal(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&decorate_); } /*! @@ -376,8 +367,8 @@ class IRBuilder { * \return The result SSA value. * \tparams Args The positional arguments */ - template - Instr MakeInst(spv::Op op, Args&& ...args) { + template + Instr MakeInst(spv::Op op, Args&&... args) { return ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&function_); } /*! @@ -389,8 +380,8 @@ class IRBuilder { * \return The result SSA value. * \tparams Args The positional arguments */ - template - Value MakeValue(spv::Op op, const SType& out_type, Args&& ...args) { + template + Value MakeValue(spv::Op op, const SType& out_type, Args&&... args) { Value val = NewValue(out_type, kNormal); MakeInst(op, out_type, val, std::forward(args)...); return val; @@ -411,9 +402,7 @@ class IRBuilder { * \param args The arguments * \return The result value. */ - Value CallGLSL450(const SType& ret_type, - uint32_t inst_id, - const std::vector& args); + Value CallGLSL450(const SType& ret_type, uint32_t inst_id, const std::vector& args); /*! * \brief Build vector by concatenating components * @@ -433,8 +422,7 @@ class IRBuilder { * \param storage_class The storage class * \return The corresponding spirv type. */ - SType GetPointerType(const SType& value_type, - spv::StorageClass storage_class); + SType GetPointerType(const SType& value_type, spv::StorageClass storage_class); /*! * \brief Get a struct{ value_type[num_elems] } type. * \param value_type the content value type. @@ -443,17 +431,14 @@ class IRBuilder { * * \return The corresponding spirv type. */ - SType GetStructArrayType(const SType& value_type, - uint32_t num_elems); + SType GetStructArrayType(const SType& value_type, uint32_t num_elems); /*! * \brief Get a struct array access with a given index. * \param ptr_type The pointer type. * \param buffer The buffer ptr to struct array * \param index The array index. */ - Value StructArrayAccess(const SType& ptr_type, - Value buffer, - Value index); + Value StructArrayAccess(const SType& ptr_type, Value buffer, Value index); /*! * \brief Create a cast that cast value to dst_type * \param dst_type The target type. @@ -487,9 +472,7 @@ class IRBuilder { * \param binding The binding locaiton in descriptor set. * \param The argument type. */ - Value BufferArgument(const SType& value_type, - uint32_t descriptor_set, - uint32_t binding); + Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding); /*! * \brief Declare POD arguments through push constants. * @@ -535,9 +518,7 @@ class IRBuilder { * \param num_elems Number of elements to allocate. * \param storage_class The storage class we want to store to. */ - Value Allocate(const SType& value_type, - uint32_t num_elems, - spv::StorageClass storage_class); + Value Allocate(const SType& value_type, uint32_t num_elems, spv::StorageClass storage_class); /* * \brief Get the i-th workgroup id. * \return The value representing the workgroup id. @@ -612,7 +593,7 @@ class IRBuilder { std::vector debug_; /*! \brief Annotation segment */ std::vector decorate_; - /*! \brief Global segment: types, variables, types */ + /*! \brief Global segment: types, variables, types */ std::vector global_; /*! \brief Function header segment */ std::vector func_header_; diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index b125f37..6dd2ca0 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -20,14 +20,17 @@ /*! * \file codegen_stackvm.cc */ -#include -#include +#include "codegen_stackvm.h" + #include -#include +#include +#include #include +#include + #include #include -#include "codegen_stackvm.h" + #include "../../runtime/stackvm/stackvm_module.h" namespace tvm { @@ -40,19 +43,32 @@ using namespace tir; StackVM::StructFieldKind MapFieldKind(int64_t kind) { auto val = static_cast(kind); switch (val) { - case intrinsic::kArrData: return StackVM::kArrData; - case intrinsic::kArrShape: return StackVM::kArrShape; - case intrinsic::kArrAddr: return StackVM::kArrAddr; - case intrinsic::kArrStrides: return StackVM::kArrStrides; - case intrinsic::kArrNDim: return StackVM::kArrNDim; - case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode; - case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits; - case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes; - case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset; - case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId; - case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType; - case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent; - default: LOG(FATAL) << "Do not know how to map field " << kind; + case intrinsic::kArrData: + return StackVM::kArrData; + case intrinsic::kArrShape: + return StackVM::kArrShape; + case intrinsic::kArrAddr: + return StackVM::kArrAddr; + case intrinsic::kArrStrides: + return StackVM::kArrStrides; + case intrinsic::kArrNDim: + return StackVM::kArrNDim; + case intrinsic::kArrTypeCode: + return StackVM::kArrTypeCode; + case intrinsic::kArrTypeBits: + return StackVM::kArrTypeBits; + case intrinsic::kArrTypeLanes: + return StackVM::kArrTypeLanes; + case intrinsic::kArrByteOffset: + return StackVM::kArrByteOffset; + case intrinsic::kArrDeviceId: + return StackVM::kArrDeviceId; + case intrinsic::kArrDeviceType: + return StackVM::kArrDeviceType; + case intrinsic::kTVMValueContent: + return StackVM::kTVMValueContent; + default: + LOG(FATAL) << "Do not know how to map field " << kind; } return StackVM::kArrData; } @@ -84,8 +100,7 @@ void CodeGenStackVM::PushOp(StackVM::OpCode opcode) { } void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) { - CHECK(operand >= std::numeric_limits::min() && - operand <= std::numeric_limits::max()); + CHECK(operand >= std::numeric_limits::min() && operand <= std::numeric_limits::max()); vm_.code.at(operand_index).v_int = static_cast(operand); } @@ -120,8 +135,7 @@ int CodeGenStackVM::AllocVarID(const VarNode* v) { int CodeGenStackVM::GetVarID(const VarNode* v) const { auto it = var_idmap_.find(v); - CHECK(it != var_idmap_.end()) - << "Find undefined Variable " << v->name_hint; + CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; return it->second; } @@ -161,7 +175,7 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { void CodeGenStackVM::VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); this->Push(l->index); @@ -261,9 +275,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { } } -void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, - const PrimExpr& a, - const PrimExpr& b) { +void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b) { this->Push(a); this->Push(b); DataType t = a.dtype(); @@ -295,7 +307,7 @@ void CodeGenStackVM::VisitExpr_(const IntImmNode* op) { CHECK(op->value >= std::numeric_limits::min() && op->value <= std::numeric_limits::max()) << "Int constant exceed bound"; - this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); + this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) { @@ -312,25 +324,15 @@ void CodeGenStackVM::VisitExpr_(const CastNode* op) { PushCast(op->dtype, op->value.dtype()); } -void CodeGenStackVM::VisitExpr_(const AddNode* op) { - PushBinary(StackVM::ADD_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const AddNode* op) { PushBinary(StackVM::ADD_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const SubNode* op) { - PushBinary(StackVM::SUB_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const SubNode* op) { PushBinary(StackVM::SUB_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const MulNode* op) { - PushBinary(StackVM::MUL_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const MulNode* op) { PushBinary(StackVM::MUL_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const DivNode* op) { - PushBinary(StackVM::DIV_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const DivNode* op) { PushBinary(StackVM::DIV_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const ModNode* op) { - PushBinary(StackVM::MOD_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const ModNode* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); } void CodeGenStackVM::VisitExpr_(const MinNode* op) { this->Push(op->a); @@ -350,22 +352,16 @@ void CodeGenStackVM::VisitExpr_(const MaxNode* op) { this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitExpr_(const EQNode* op) { - PushBinary(StackVM::EQ_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const EQNode* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const LENode* op) { - PushBinary(StackVM::LE_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const LENode* op) { PushBinary(StackVM::LE_I64, op->a, op->b); } void CodeGenStackVM::VisitExpr_(const NENode* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitExpr_(const LTNode* op) { - PushBinary(StackVM::LT_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const LTNode* op) { PushBinary(StackVM::LT_I64, op->a, op->b); } void CodeGenStackVM::VisitExpr_(const GENode* op) { PushBinary(StackVM::LT_I64, op->a, op->b); @@ -431,7 +427,7 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { } } -void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) { +void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) { if (is_const(ev->value)) return; const CallNode* op = ev->value.as(); if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) { @@ -482,9 +478,7 @@ void CodeGenStackVM::VisitStmt_(const LetStmtNode* op) { this->Push(op->body); } -void CodeGenStackVM::VisitExpr_(const RampNode* op) { - LOG(FATAL) << "Ramp is not supported"; -} +void CodeGenStackVM::VisitExpr_(const RampNode* op) { LOG(FATAL) << "Ramp is not supported"; } void CodeGenStackVM::VisitExpr_(const BroadcastNode* op) { LOG(FATAL) << "Broadcast is not supported"; @@ -506,9 +500,7 @@ void CodeGenStackVM::VisitStmt_(const AssertStmtNode* op) { this->Push(op->body); } -void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) { - this->Push(op->body); -} +void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) { this->Push(op->body); } void CodeGenStackVM::VisitExpr_(const LetNode* op) { this->Push(op->value); @@ -521,17 +513,15 @@ runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) { std::unordered_map fmap; std::string entry_func; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenStackVM: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenStackVM: Can only take PrimFunc"; auto f = Downcast(kv.second); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); StackVM vm = codegen::CodeGenStackVM().Compile(f); - CHECK(!fmap.count(f_name)) - << "Function name " << f_name << "already exist in list"; + CHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list"; fmap[f_name] = std::move(vm); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { @@ -542,7 +532,6 @@ runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) { return runtime::StackVMModuleCreate(fmap, entry_func); } -TVM_REGISTER_GLOBAL("target.build.stackvm") -.set_body_typed(BuildStackVM); +TVM_REGISTER_GLOBAL("target.build.stackvm").set_body_typed(BuildStackVM); } // namespace codegen } // namespace tvm diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index 3103682..b77c406 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -24,12 +24,14 @@ #ifndef TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ #define TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ +#include #include +#include #include -#include + #include -#include #include +#include #include "../../runtime/stackvm/stackvm.h" @@ -44,11 +46,10 @@ using runtime::StackVM; * This module is used to generate host wrapper * into device function when only device JIT is available. */ -class CodeGenStackVM - : public ExprFunctor, - public StmtFunctor { +class CodeGenStackVM : public ExprFunctor, + public StmtFunctor { public: - /*! + /*! * \brief Generate a stack VM representing * \param f The function to be compiled * \param device_funcs The extern device functions to be linked. @@ -59,9 +60,7 @@ class CodeGenStackVM /*! \brief Push stmt to generate new code */ void Push(const Stmt& n); /*! \brief Push expr to generate new code */ - void Push(const PrimExpr& n) { - VisitExpr(n); - } + void Push(const PrimExpr& n) { VisitExpr(n); } /*! * \brief Push the opcode to the code. * \param opcode The code to be pushed. @@ -81,9 +80,7 @@ class CodeGenStackVM */ void SetOperand(int64_t operand_index, int64_t operand); /*! \return The current program pointer */ - int64_t GetPC() const { - return static_cast(vm_.code.size()); - } + int64_t GetPC() const { return static_cast(vm_.code.size()); } /*! * \brief Get string id in vm * \param key The string to get id. @@ -103,9 +100,7 @@ class CodeGenStackVM */ int GetVarID(const VarNode* v) const; // Push binary operator - void PushBinary(StackVM::OpCode op_int64, - const PrimExpr& a, - const PrimExpr& b); + void PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b); // push cast; void PushCast(DataType dst, DataType src); // overloadable functions diff --git a/src/target/target.cc b/src/target/target.cc index c733eae..010a14a 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -21,11 +21,9 @@ * \file src/target/target.cc */ #include - -#include #include +#include #include - #include #include @@ -33,28 +31,27 @@ namespace tvm { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; TVM_REGISTER_NODE_TYPE(TargetNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->str(); - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->str(); + }); /*! -* \brief Construct a Target node from the given name and options. -* \param target_name The major target name. Should be one of -* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm", -* "metal", "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"} -* \param options Additional options appended to the target -* \return The constructed Target -*/ -Target CreateTarget(const std::string& target_name, - const std::vector& options) { + * \brief Construct a Target node from the given name and options. + * \param target_name The major target name. Should be one of + * {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm", + * "metal", "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"} + * \param options Additional options appended to the target + * \return The constructed Target + */ +Target CreateTarget(const std::string& target_name, const std::vector& options) { auto t = make_object(); t->target_name = target_name; @@ -110,9 +107,7 @@ Target CreateTarget(const std::string& target_name, if (t->device_name == "intel_graphics") { t->thread_warp_size = 16; } - } else if (target_name == "metal" || - target_name == "vulkan" || - target_name == "webgpu") { + } else if (target_name == "metal" || target_name == "vulkan" || target_name == "webgpu") { if (target_name == "metal") { t->device_type = kDLMetal; } else if (target_name == "vulkan") { @@ -154,8 +149,7 @@ Target CreateTarget(const std::string& target_name, return Target(t); } -TVM_REGISTER_GLOBAL("target.TargetCreate") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_name = args[0]; std::vector options; for (int i = 1; i < args.num_args; ++i) { @@ -164,13 +158,12 @@ TVM_REGISTER_GLOBAL("target.TargetCreate") } *ret = CreateTarget(target_name, options); - }); +}); -TVM_REGISTER_GLOBAL("target.TargetFromString") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_str = args[0]; *ret = Target::Create(target_str); - }); +}); std::vector TargetNode::keys() const { std::vector result; @@ -200,14 +193,13 @@ const std::string& TargetNode::str() const { if (str_repr_.length() != 0) return str_repr_; std::ostringstream result; result << target_name; - for (const auto &x : options()) { + for (const auto& x : options()) { result << " " << x; } str_repr_ = result.str(); return str_repr_; } - bool StartsWith(const std::string& str, const std::string& pattern) { return str.compare(0, pattern.length(), pattern) == 0; } @@ -257,104 +249,75 @@ struct TVMTargetThreadLocalEntry { typedef dmlc::ThreadLocalStore TVMTargetThreadLocalStore; void Target::EnterWithScope() { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); entry->context_stack.push(*this); } void Target::ExitWithScope() { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } tvm::Target Target::Current(bool allow_not_defined) { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); } CHECK(allow_not_defined) - << "Target context required. Please set it by constructing a TargetContext"; + << "Target context required. Please set it by constructing a TargetContext"; return Target(); } -TVM_REGISTER_GLOBAL("target.GetCurrentTarget") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body([](TVMArgs args, TVMRetValue* ret) { bool allow_not_defined = args[0]; *ret = Target::Current(allow_not_defined); - }); +}); class Target::Internal { public: - static void EnterScope(Target target) { - target.EnterWithScope(); - } - static void ExitScope(Target target) { - target.ExitWithScope(); - } + static void EnterScope(Target target) { target.EnterWithScope(); } + static void ExitScope(Target target) { target.ExitWithScope(); } }; -TVM_REGISTER_GLOBAL("target.EnterTargetScope") -.set_body_typed(Target::Internal::EnterScope); +TVM_REGISTER_GLOBAL("target.EnterTargetScope").set_body_typed(Target::Internal::EnterScope); -TVM_REGISTER_GLOBAL("target.ExitTargetScope") -.set_body_typed(Target::Internal::ExitScope); +TVM_REGISTER_GLOBAL("target.ExitTargetScope").set_body_typed(Target::Internal::ExitScope); namespace target { std::vector MergeOptions(std::vector opts, - const std::vector& new_opts) { + const std::vector& new_opts) { opts.insert(opts.end(), new_opts.begin(), new_opts.end()); return opts; } -Target llvm(const std::vector& options) { - return CreateTarget("llvm", options); -} +Target llvm(const std::vector& options) { return CreateTarget("llvm", options); } -Target cuda(const std::vector& options) { - return CreateTarget("cuda", options); -} +Target cuda(const std::vector& options) { return CreateTarget("cuda", options); } -Target rocm(const std::vector& options) { - return CreateTarget("rocm", options); -} +Target rocm(const std::vector& options) { return CreateTarget("rocm", options); } -Target opencl(const std::vector& options) { - return CreateTarget("opencl", options); -} +Target opencl(const std::vector& options) { return CreateTarget("opencl", options); } -Target metal(const std::vector& options) { - return CreateTarget("metal", options); -} +Target metal(const std::vector& options) { return CreateTarget("metal", options); } Target mali(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, { - "-device=mali" - })); + return CreateTarget("opencl", MergeOptions(options, {"-device=mali"})); } Target intel_graphics(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, { - "-device=intel_graphics" - })); + return CreateTarget("opencl", MergeOptions(options, {"-device=intel_graphics"})); } -Target stackvm(const std::vector& options) { - return CreateTarget("stackvm", options); -} +Target stackvm(const std::vector& options) { return CreateTarget("stackvm", options); } -Target ext_dev(const std::vector& options) { - return CreateTarget("ext_dev", options); -} +Target ext_dev(const std::vector& options) { return CreateTarget("ext_dev", options); } -Target hexagon(const std::vector& options) { - return CreateTarget("hexagon", options); -} +Target hexagon(const std::vector& options) { return CreateTarget("hexagon", options); } } // namespace target -BuildConfig BuildConfig::Create() { - return BuildConfig(make_object()); -} +BuildConfig BuildConfig::Create() { return BuildConfig(make_object()); } /*! \brief Entry to hold the BuildConfig context stack. */ struct TVMBuildConfigThreadLocalEntry { @@ -364,28 +327,26 @@ struct TVMBuildConfigThreadLocalEntry { /*! \brief The current build config context */ std::stack context_stack; - TVMBuildConfigThreadLocalEntry() : - default_config(BuildConfig::Create()) { - } + TVMBuildConfigThreadLocalEntry() : default_config(BuildConfig::Create()) {} }; /*! \brief Thread local store to hold the BuildConfig context stack. */ typedef dmlc::ThreadLocalStore TVMBuildConfigThreadLocalStore; void BuildConfig::EnterWithScope() { - TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); + TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get(); entry->context_stack.push(*this); } void BuildConfig::ExitWithScope() { - TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); + TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } tvm::BuildConfig BuildConfig::Current() { - TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); + TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); } @@ -396,80 +357,73 @@ tvm::BuildConfig BuildConfig::Current() { TVM_REGISTER_NODE_TYPE(BuildConfigNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "build_config("; - p->stream << "data_alignment=" << op->data_alignment << ", "; - p->stream << "offset_factor=" << op->offset_factor << ", "; - p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", "; - p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", "; - p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", "; - p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", "; - p->stream << "unroll_explicit=" << op->unroll_explicit << ", "; - p->stream << "restricted_func=" << op->restricted_func << ", "; - p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", "; - p->stream << "partition_const_loop=" << op->partition_const_loop << ", "; - p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", "; - p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; - p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; - p->stream << "disable_vectorize=" << op->disable_vectorize; - p->stream << "disable_assert=" << op->disable_assert; - p->stream << ")"; -}); - -TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig") -.set_body([](TVMArgs args, TVMRetValue* ret) { + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "build_config("; + p->stream << "data_alignment=" << op->data_alignment << ", "; + p->stream << "offset_factor=" << op->offset_factor << ", "; + p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", "; + p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", "; + p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", "; + p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", "; + p->stream << "unroll_explicit=" << op->unroll_explicit << ", "; + p->stream << "restricted_func=" << op->restricted_func << ", "; + p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", "; + p->stream << "partition_const_loop=" << op->partition_const_loop << ", "; + p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", "; + p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; + p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; + p->stream << "disable_vectorize=" << op->disable_vectorize; + p->stream << "disable_assert=" << op->disable_assert; + p->stream << ")"; + }); + +TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = BuildConfig::Current(); - }); +}); class BuildConfig::Internal { public: - static void EnterScope(BuildConfig target) { - target.EnterWithScope(); - } - static void ExitScope(BuildConfig target) { - target.ExitWithScope(); - } + static void EnterScope(BuildConfig target) { target.EnterWithScope(); } + static void ExitScope(BuildConfig target) { target.ExitWithScope(); } }; TVM_REGISTER_GLOBAL("target.EnterBuildConfigScope") -.set_body_typed(BuildConfig::Internal::EnterScope); + .set_body_typed(BuildConfig::Internal::EnterScope); -TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope") -.set_body_typed(BuildConfig::Internal::ExitScope); +TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope").set_body_typed(BuildConfig::Internal::ExitScope); TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass") -.set_body([](TVMArgs args, TVMRetValue* ret) { - BuildConfig cfg = args[0]; - std::vector> add_lower_pass; - CHECK_EQ(args.size() % 2, 1); - for (int i = 1; i < args.size(); i += 2) { - add_lower_pass.push_back(std::make_pair( - args[i].operator int(), - args[i + 1].operator transform::Pass())); - } - cfg->add_lower_pass = add_lower_pass; - }); + .set_body([](TVMArgs args, TVMRetValue* ret) { + BuildConfig cfg = args[0]; + std::vector> add_lower_pass; + CHECK_EQ(args.size() % 2, 1); + for (int i = 1; i < args.size(); i += 2) { + add_lower_pass.push_back( + std::make_pair(args[i].operator int(), args[i + 1].operator transform::Pass())); + } + cfg->add_lower_pass = add_lower_pass; + }); TVM_REGISTER_GLOBAL("target.BuildConfigGetAddLowerPassInfo") -.set_body([](TVMArgs args, TVMRetValue* ret) { - // Return one of the following: - // * Size of add_lower_pass if num_args == 1 - // * Phase index of pass if args are (config, index, true) - // * Function of pass if args are (config, index, false) - BuildConfig cfg = args[0]; - if (args.num_args == 1) { - *ret = static_cast(cfg->add_lower_pass.size()); - } else { - int index = args[1]; - bool get_phase = args[2]; - auto item = cfg->add_lower_pass[index]; - if (get_phase) { - *ret = item.first; - } else { - *ret = item.second; - } - } -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + // Return one of the following: + // * Size of add_lower_pass if num_args == 1 + // * Phase index of pass if args are (config, index, true) + // * Function of pass if args are (config, index, false) + BuildConfig cfg = args[0]; + if (args.num_args == 1) { + *ret = static_cast(cfg->add_lower_pass.size()); + } else { + int index = args[1]; + bool get_phase = args[2]; + auto item = cfg->add_lower_pass[index]; + if (get_phase) { + *ret = item.first; + } else { + *ret = item.second; + } + } + }); } // namespace tvm diff --git a/src/target/target_info.cc b/src/target/target_info.cc index 73fe011..5ebb7ed 100644 --- a/src/target/target_info.cc +++ b/src/target/target_info.cc @@ -20,21 +20,21 @@ /*! * \file target/target_info.cc */ -#include #include +#include #include namespace tvm { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "mem-info(" - << "unit_bits=" << op->unit_bits << ", " - << "max_num_bits=" << op->max_num_bits << ", " - << "max_simd_bits=" << op->max_simd_bits << ", " - << "head_address=" << op->head_address << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "mem-info(" + << "unit_bits=" << op->unit_bits << ", " + << "max_num_bits=" << op->max_num_bits << ", " + << "max_simd_bits=" << op->max_simd_bits << ", " + << "head_address=" << op->head_address << ")"; + }); TVM_REGISTER_NODE_TYPE(MemoryInfoNode); diff --git a/src/te/autodiff/ad_util.cc b/src/te/autodiff/ad_util.cc index b1c97e3..874a512 100644 --- a/src/te/autodiff/ad_util.cc +++ b/src/te/autodiff/ad_util.cc @@ -21,10 +21,12 @@ * \file ad_util.cc * \brief Utility for tensor-level auto-differentiation. */ +#include "ad_util.h" + #include #include + #include -#include "ad_util.h" namespace tvm { namespace te { @@ -34,8 +36,7 @@ std::pair, Map> CloneIterVars(const Array Map vmap; for (const IterVar& iv : vars) { IterVar new_v = - IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), - iv->iter_type, iv->thread_tag); + IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), iv->iter_type, iv->thread_tag); new_vars.push_back(new_v); vmap.Set(iv->var, new_v->var); } @@ -53,8 +54,8 @@ PrimExpr CloneReduction(const PrimExpr& expr) { src_with_newaxis.push_back(tir::Substitute(src, vmap)); } - return ReduceNode::make(red->combiner, src_with_newaxis, - new_axis, tir::Substitute(red->condition, vmap), red->value_index); + return ReduceNode::make(red->combiner, src_with_newaxis, new_axis, + tir::Substitute(red->condition, vmap), red->value_index); } else { return expr; } diff --git a/src/te/autodiff/ad_util.h b/src/te/autodiff/ad_util.h index 7e511b1..56ab6c1 100644 --- a/src/te/autodiff/ad_util.h +++ b/src/te/autodiff/ad_util.h @@ -24,11 +24,12 @@ #ifndef TVM_TE_AUTODIFF_AD_UTIL_H_ #define TVM_TE_AUTODIFF_AD_UTIL_H_ -#include #include -#include +#include + #include #include +#include namespace tvm { namespace te { diff --git a/src/te/autodiff/adjoint.cc b/src/te/autodiff/adjoint.cc index 0c54764..4afca68 100644 --- a/src/te/autodiff/adjoint.cc +++ b/src/te/autodiff/adjoint.cc @@ -30,11 +30,12 @@ * (3) and sum them together to get the adjoint of the input itself. * The three steps are computed recursively. */ +#include +#include #include #include #include -#include -#include + #include #include @@ -47,27 +48,25 @@ Tensor Identity(const Tensor& output) { // add extra dimension for Jacobian shape.push_back(e); } - auto func = - [&output](const Array& input_indices) { - PrimExpr res = const_true(); - for (size_t i = 0; i < output->shape.size(); ++i) { - res = res && (PrimExpr(input_indices[i]) == - PrimExpr(input_indices[output->shape.size() + i])); - } - return CastNode::make(output->dtype, res); - }; + auto func = [&output](const Array& input_indices) { + PrimExpr res = const_true(); + for (size_t i = 0; i < output->shape.size(); ++i) { + res = + res && (PrimExpr(input_indices[i]) == PrimExpr(input_indices[output->shape.size() + i])); + } + return CastNode::make(output->dtype, res); + }; return te::compute(shape, func, "identity"); } -Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head) { +Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head) { Tensor jac = Jacobian(output, input); Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(), output->op->name + "." + input->op->name + ".grad"); return result; } -Array Gradient(const Tensor& output, - const Array& inputs, +Array Gradient(const Tensor& output, const Array& inputs, const Tensor& head_or_null) { // Diagonal identity tensor Tensor head = head_or_null.get() ? head_or_null : Identity(output); @@ -95,41 +94,40 @@ Array Gradient(const Tensor& output, // This is a recursive function that does all the work. It computes the adjoint for a given // tensor, adds it to the map, and returns it std::function compute_adjoint; - compute_adjoint = - [&compute_adjoint, &adjoints, &reverse_dependencies, &head, &output] - (const Tensor& tensor) { - if (!adjoints.count(tensor)) { - // Here the adjoint hasn't been computed yet - Tensor res_adjoint; - std::vector direct_consumers = reverse_dependencies[tensor]; - if (direct_consumers.empty()) { - // No reverse dependencies means that the output does not depend on this tensor, - // return a zero tensor of the appropriate shape - // (i.e., output shape + tensor shape, aka shape of Jacobian) - Array result_shape(head->shape.begin(), - head->shape.end() + (-output->shape.size())); - for (auto e : tensor->shape) { - result_shape.push_back(e); - } - res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype)); - } else { - // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied - // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian - // and the multiplication is done in the function VectorJacobianProduct - for (const Tensor& direct_consumer : direct_consumers) { - // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor) - Tensor part = VectorJacobianProduct( - direct_consumer, tensor, compute_adjoint(direct_consumer)); - res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part; - } + compute_adjoint = [&compute_adjoint, &adjoints, &reverse_dependencies, &head, + &output](const Tensor& tensor) { + if (!adjoints.count(tensor)) { + // Here the adjoint hasn't been computed yet + Tensor res_adjoint; + std::vector direct_consumers = reverse_dependencies[tensor]; + if (direct_consumers.empty()) { + // No reverse dependencies means that the output does not depend on this tensor, + // return a zero tensor of the appropriate shape + // (i.e., output shape + tensor shape, aka shape of Jacobian) + Array result_shape(head->shape.begin(), + head->shape.end() + (-output->shape.size())); + for (auto e : tensor->shape) { + result_shape.push_back(e); } - - adjoints[tensor] = res_adjoint; - return res_adjoint; + res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype)); } else { - return adjoints[tensor]; + // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied + // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian + // and the multiplication is done in the function VectorJacobianProduct + for (const Tensor& direct_consumer : direct_consumers) { + // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor) + Tensor part = + VectorJacobianProduct(direct_consumer, tensor, compute_adjoint(direct_consumer)); + res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part; + } } - }; + + adjoints[tensor] = res_adjoint; + return res_adjoint; + } else { + return adjoints[tensor]; + } + }; // Adjoints corresponding to inputs Array result; @@ -141,15 +139,14 @@ Array Gradient(const Tensor& output, return result; } -TVM_REGISTER_GLOBAL("te.Gradient") -.set_body([](TVMArgs args, TVMRetValue *ret) { - LOG(WARNING) << "te.Gradient is an experimental feature."; - if (args.size() == 2) { - *ret = Gradient(args[0], args[1]); - } else if (args.size() == 3) { - *ret = Gradient(args[0], args[1], args[2]); - } - }); +TVM_REGISTER_GLOBAL("te.Gradient").set_body([](TVMArgs args, TVMRetValue* ret) { + LOG(WARNING) << "te.Gradient is an experimental feature."; + if (args.size() == 2) { + *ret = Gradient(args[0], args[1]); + } else if (args.size() == 3) { + *ret = Gradient(args[0], args[1], args[2]); + } +}); } // namespace te } // namespace tvm diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index d5b6fec..f770169 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -23,19 +23,23 @@ * X must be direct input tensor of Y. * The result Jacobian shape will be (Y.shape, X.shape) */ -#include #include #include +#include #include #include + #include "ad_util.h" namespace tvm { namespace te { -#define NOT_IMPLEMENTED \ - { LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef(op); throw; } +#define NOT_IMPLEMENTED \ + { \ + LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef(op); \ + throw; \ + } /*! \brief Differentiate an expression wrt a variable or a tensor element */ class JacobianMutator : public ExprMutator { @@ -46,7 +50,7 @@ class JacobianMutator : public ExprMutator { * \param indices The indices of the element with respect to which to differentiate. */ explicit JacobianMutator(Tensor input, Array indices) - : input_(input), indices_(indices) {} + : input_(input), indices_(indices) {} /*! * \brief Differentiate wrt the input variable. * \param input The input variable. @@ -71,14 +75,13 @@ class JacobianMutator : public ExprMutator { } } - PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const CallNode* op) { PrimExpr expr = GetRef(op); if (op->call_type == CallNode::CallType::Halide) { - if (input_.get() && op->func.same_as(input_->op) && - op->value_index == input_->value_index) { + if (input_.get() && op->func.same_as(input_->op) && op->value_index == input_->value_index) { // Tensor(indices) CHECK_EQ(indices_.size(), op->args.size()); PrimExpr condition = const_true(); @@ -99,86 +102,71 @@ class JacobianMutator : public ExprMutator { return MulNode::make(Mutate(op->args[0]), MulNode::make(expr, SubNode::make(FloatImm(expr.dtype(), 1.0), expr))); } else if (op->name == "sqrt") { - return DivNode::make(Mutate(op->args[0]), - MulNode::make(expr, FloatImm(expr.dtype(), 2.0))); + return DivNode::make(Mutate(op->args[0]), MulNode::make(expr, FloatImm(expr.dtype(), 2.0))); } else if (op->name == "tanh") { return MulNode::make(Mutate(op->args[0]), SubNode::make(FloatImm(expr.dtype(), 1.0), MulNode::make(expr, expr))); } else if (op->name == "pow") { auto x = op->args[0], y = op->args[1]; - return expr * (Mutate(y)*log(x) + Mutate(x)*y/x); + return expr * (Mutate(y) * log(x) + Mutate(x) * y / x); } else if (op->name == "fabs") { auto type = op->args[0].dtype(); return MulNode::make(Mutate(op->args[0]), SelectNode::make(GENode::make(op->args[0], make_zero(type)), FloatImm(type, 1.0), FloatImm(type, -1.0))); } else if (op->name == intrinsic::tvm_if_then_else) { - Array new_args = {op->args[0], - Mutate(op->args[1]), - Mutate(op->args[2])}; - return CallNode::make(op->dtype, op->name, new_args, - op->call_type, op->func, op->value_index); + Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; + return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func, + op->value_index); } else if (piecewise_const.count(op->name)) { return FloatImm(expr.dtype(), 0.0); } else { throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); } } - NOT_IMPLEMENTED + NOT_IMPLEMENTED; } - PrimExpr VisitExpr_(const AddNode* op) { - return AddNode::make(Mutate(op->a), Mutate(op->b)); - } + PrimExpr VisitExpr_(const AddNode* op) { return AddNode::make(Mutate(op->a), Mutate(op->b)); } - PrimExpr VisitExpr_(const SubNode* op) { - return SubNode::make(Mutate(op->a), Mutate(op->b)); - } + PrimExpr VisitExpr_(const SubNode* op) { return SubNode::make(Mutate(op->a), Mutate(op->b)); } PrimExpr VisitExpr_(const MulNode* op) { - return AddNode::make( - MulNode::make(Mutate(op->a), op->b), - MulNode::make(op->a, Mutate(op->b))); + return AddNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))); } PrimExpr VisitExpr_(const DivNode* op) { return DivNode::make( - SubNode::make( - MulNode::make(Mutate(op->a), op->b), - MulNode::make(op->a, Mutate(op->b))), + SubNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))), MulNode::make(op->b, op->b)); } - PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const FloorDivNode* op) { return FloorDivNode::make( - SubNode::make( - MulNode::make(Mutate(op->a), op->b), - MulNode::make(op->a, Mutate(op->b))), + SubNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))), MulNode::make(op->b, op->b)); } - PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const MinNode* op) { - return SelectNode::make(LENode::make(op->a, op->b), - Mutate(op->a), Mutate(op->b)); + return SelectNode::make(LENode::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); } PrimExpr VisitExpr_(const MaxNode* op) { - return SelectNode::make(GENode::make(op->a, op->b), - Mutate(op->a), Mutate(op->b)); + return SelectNode::make(GENode::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); } - PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const ReduceNode* op) { // This case is relatively difficult because a reduction expression @@ -265,9 +253,8 @@ class JacobianMutator : public ExprMutator { CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); // Also simplify the resulting combiner // (mostly to get rid of unused components, e.g., the original expressions) - return analyzer_.Simplify( - ReduceNode::make(new_combiner, new_source, new_op->axis, - new_op->condition, new_op->value_index)); + return analyzer_.Simplify(ReduceNode::make(new_combiner, new_source, new_op->axis, + new_op->condition, new_op->value_index)); } PrimExpr VisitExpr_(const CastNode* op) { @@ -278,26 +265,21 @@ class JacobianMutator : public ExprMutator { } } - PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const SelectNode* op) { - return SelectNode::make(op->condition, - Mutate(op->true_value), Mutate(op->false_value)); + return SelectNode::make(op->condition, Mutate(op->true_value), Mutate(op->false_value)); } - PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED; - PrimExpr VisitExpr_(const IntImmNode* op) { - return IntImm(op->dtype, 0); - } + PrimExpr VisitExpr_(const IntImmNode* op) { return IntImm(op->dtype, 0); } - PrimExpr VisitExpr_(const FloatImmNode* op) { - return FloatImm(op->dtype, 0); - } + PrimExpr VisitExpr_(const FloatImmNode* op) { return FloatImm(op->dtype, 0); } - PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED; private: Tensor input_; @@ -336,8 +318,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { Array input_indices; size_t i = 0; for (PrimExpr ext : input->shape) { - IterVar new_v = IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i++)), - IterVarType::kDataPar); + IterVar new_v = + IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i++)), IterVarType::kDataPar); // Append jacobian iter to new_axis new_axis.push_back(new_v); // Differentiate wrt input[input_indices] @@ -345,8 +327,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { } arith::Analyzer analzyer; // Compute Jacobian - PrimExpr new_body = Jacobian( - Substitute(op->body[output->value_index], vmap), input, input_indices); + PrimExpr new_body = + Jacobian(Substitute(op->body[output->value_index], vmap), input, input_indices); new_body = analzyer.Simplify(new_body); int value_index = 0; @@ -358,14 +340,14 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { value_index = red->value_index; for (size_t idx = 0; idx < red->source.size(); ++idx) { new_bodies.push_back( - ReduceNode::make(red->combiner, red->source, red->axis, red->condition, idx)); + ReduceNode::make(red->combiner, red->source, red->axis, red->condition, idx)); } } else { new_bodies.push_back(new_body); } - auto new_op = ComputeOpNode::make( - op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); + auto new_op = + ComputeOpNode::make(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); // Jacobian shape = output.shape + input.shape Array new_shape = output->shape; diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 1248547..d8ad839 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -21,46 +21,45 @@ * \brief Compute Op. * \file compute_op.cc */ +#include "compute_op.h" + +#include #include #include -#include -#include #include +#include #include -#include + #include +#include #include -#include "compute_op.h" -#include "op_util.h" -#include "../schedule/message_passing.h" + #include "../../arith/compute_expr.h" #include "../../arith/interval_set.h" +#include "../schedule/message_passing.h" +#include "op_util.h" namespace tvm { namespace te { using namespace tir; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "compute(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "compute(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(ComputeOpNode); /// Verify if ComputeOp is valid with respect to Reduce operations. -static void VerifyComputeOp(const ComputeOpNode *op); +static void VerifyComputeOp(const ComputeOpNode* op); inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - return (a->combiner.same_as(b->combiner)) && - (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && - (a->condition.same_as(b->condition)); + return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)); } -int ComputeOpNode::num_outputs() const { - return body.size(); -} +int ComputeOpNode::num_outputs() const { return body.size(); } Array BaseComputeOpNode::root_iter_vars() const { if (reduce_axis.size() == 0) return axis; @@ -87,10 +86,7 @@ Array BaseComputeOpNode::output_shape(size_t idx) const { return shape; } -Tensor compute(Array shape, - FCompute fcompute, - std::string name, - std::string tag, +Tensor compute(Array shape, FCompute fcompute, std::string name, std::string tag, Map attrs) { auto op_node = make_object(); // compute dimension. @@ -100,20 +96,16 @@ Tensor compute(Array shape, for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; - axis.emplace_back(IterVarNode::make( - Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + axis.emplace_back( + IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } - return ComputeOpNode::make( - name, tag, attrs, axis, {fcompute(args)}).output(0); + return ComputeOpNode::make(name, tag, attrs, axis, {fcompute(args)}).output(0); } -Array compute(Array shape, - FBatchCompute fcompute, - std::string name, - std::string tag, - Map attrs) { +Array compute(Array shape, FBatchCompute fcompute, std::string name, + std::string tag, Map attrs) { auto op_node = make_object(); // compute dimension. size_t ndim = shape.size(); @@ -122,8 +114,8 @@ Array compute(Array shape, for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; - axis.emplace_back(IterVarNode::make( - Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + axis.emplace_back( + IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } @@ -135,11 +127,8 @@ Array compute(Array shape, return outputs; } -Operation ComputeOpNode::make(std::string name, - std::string tag, - Map attrs, - Array axis, - Array body) { +Operation ComputeOpNode::make(std::string name, std::string tag, Map attrs, + Array axis, Array body) { if (!attrs.defined()) { attrs = Map(); } @@ -157,9 +146,7 @@ Operation ComputeOpNode::make(std::string name, return Operation(n); } -TVM_REGISTER_GLOBAL("te.ComputeOp") -.set_body_typed(ComputeOpNode::make); - +TVM_REGISTER_GLOBAL("te.ComputeOp").set_body_typed(ComputeOpNode::make); // The schedule related logics Array ComputeOpNode::InputTensors() const { @@ -167,22 +154,21 @@ Array ComputeOpNode::InputTensors() const { std::unordered_set visited; for (auto& e : body) { tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { - const tir::CallNode *call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); - if (!visited.count(t)) { - ret.push_back(t); - visited.insert(t); - } + const tir::CallNode* call = n.as(); + if (call != nullptr && call->func.defined()) { + Tensor t = Downcast(call->func).output(call->value_index); + if (!visited.count(t)) { + ret.push_back(t); + visited.insert(t); } - }); + } + }); } return ret; } -Operation ComputeOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation ComputeOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); VerifyComputeOp(this); Array arr; @@ -202,26 +188,22 @@ Operation ComputeOpNode::ReplaceInputs( arr = this->body; } } else { - arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) { - return te::ReplaceTensor(e, rmap); - }); + arr = + UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); }); } if (!arr.same_as(this->body)) { - return ComputeOpNode::make( - this->name, this->tag, this->attrs, this->axis, arr); + return ComputeOpNode::make(this->name, this->tag, this->attrs, this->axis, arr); } else { return self; } } -void ComputeOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { +void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { - auto *call = n.as(); + auto* call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); if (t->op.defined() && out_dom_map->count(t)) { @@ -260,10 +242,9 @@ void ComputeOpNode::PropBoundToInputs( for (auto& e : body) tir::PostOrderVisit(e, fvisit); } -void BaseComputeOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { +void BaseComputeOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); const TensorDom& tdom = tensor_dom.at(self.output(0)); for (size_t i = 0; i < this->axis.size(); ++i) { @@ -277,10 +258,9 @@ void BaseComputeOpNode::GatherBound( } } -Stmt BaseComputeOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const { +Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { CHECK_EQ(stage->op.get(), this); Region bounds; for (IterVar iv : this->axis) { @@ -288,24 +268,22 @@ Stmt BaseComputeOpNode::BuildRealize( } Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { - Tensor t = stage->op.output(i-1); - realize = tir::RealizeNode::make(t->op, t->value_index, - t->dtype, bounds, const_true(), realize); + Tensor t = stage->op.output(i - 1); + realize = + tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); if (it != stage->iter_var_attrs.end()) { IterVarAttr attr = (*it).second; if (attr->dim_align_factor != 0) { - Array tuple = {static_cast(i), - attr->dim_align_factor, - attr->dim_align_offset}; - realize = tir::AttrStmtNode::make( - t, tir::attr::buffer_dim_align, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), - realize); + Array tuple = {static_cast(i), attr->dim_align_factor, + attr->dim_align_offset}; + realize = + tir::AttrStmtNode::make(t, tir::attr::buffer_dim_align, + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, + tuple, CallNode::Intrinsic), + realize); } } } @@ -313,16 +291,12 @@ Stmt BaseComputeOpNode::BuildRealize( return realize; } -size_t ComputeOpNode::num_schedulable_dims() const { - return axis.size(); -} +size_t ComputeOpNode::num_schedulable_dims() const { return axis.size(); } // Build a reduction body. -void MakeReduction(const ComputeOpNode* op, - const Array& tensors, - Stmt* init, +void MakeReduction(const ComputeOpNode* op, const Array& tensors, Stmt* init, Stmt* provide) { - Array args; + Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); } @@ -341,10 +315,8 @@ void MakeReduction(const ComputeOpNode* op, Array update_value = (*combiner)(lhs, reduce->source); for (size_t i = 0; i < size; ++i) { Tensor t = tensors[i]; - inits.emplace_back(ProvideNode::make( - t->op, t->value_index, init_value[i], args)); - provides.emplace_back(ProvideNode::make( - t->op, t->value_index, update_value[i], args)); + inits.emplace_back(ProvideNode::make(t->op, t->value_index, init_value[i], args)); + provides.emplace_back(ProvideNode::make(t->op, t->value_index, update_value[i], args)); } *init = SeqStmt::Flatten(inits); *provide = SeqStmt::Flatten(provides); @@ -354,8 +326,7 @@ void MakeReduction(const ComputeOpNode* op, } // Normal computation. -Stmt MakeProvide(const ComputeOpNode* op, - const Tensor& t) { +Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) { Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); @@ -363,8 +334,7 @@ Stmt MakeProvide(const ComputeOpNode* op, return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args); } -Stmt MakeComputeStmt(const ComputeOpNode* self, - const Stage& stage, +Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { // grab the nest structure @@ -383,10 +353,10 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, init = MergeNest(n.init_nest, init); init = Substitute(init, n.init_vmap); // common nest - std::vector > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector > reduce( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end()); + std::vector > common(n.main_nest.begin(), + n.main_nest.begin() + n.num_common_loop + 1); + std::vector > reduce(n.main_nest.begin() + n.num_common_loop + 1, + n.main_nest.end()); provide = MergeNest(reduce, provide); if (debug_keep_trivial_loop) { provide = MergeNest(common, provide); @@ -409,14 +379,9 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, } } -enum class ComputeType { - kNormal, - kCrossThreadReduction, - kTensorize -}; +enum class ComputeType { kNormal, kCrossThreadReduction, kTensorize }; -ComputeType DetectComputeType(const ComputeOpNode* self, - const Stage& stage) { +ComputeType DetectComputeType(const ComputeOpNode* self, const Stage& stage) { // Verify correctness of leaf nest. int normal_red = 0, thread_red = 0, tensorize = 0; @@ -436,13 +401,11 @@ ComputeType DetectComputeType(const ComputeOpNode* self, ++normal_red; } } else { - CHECK_EQ(thread_red, 0) - << "Cross thread reduce cannot swap with normal data axis"; + CHECK_EQ(thread_red, 0) << "Cross thread reduce cannot swap with normal data axis"; } } if (tensorize != 0) { - CHECK(thread_red == 0) - << "Cannot mix cross thread reduction with Tensorize"; + CHECK(thread_red == 0) << "Cannot mix cross thread reduction with Tensorize"; return ComputeType::kTensorize; } if (thread_red != 0) { @@ -453,10 +416,9 @@ ComputeType DetectComputeType(const ComputeOpNode* self, } // implement the provide utility. -Stmt ComputeOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt ComputeOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); ComputeType ctype = DetectComputeType(this, stage); if (ctype == ComputeType::kCrossThreadReduction) { @@ -469,20 +431,16 @@ Stmt ComputeOpNode::BuildProvide( } } -ComputeLoopNest ComputeLoopNest::make( - const BaseComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { +ComputeLoopNest ComputeLoopNest::make(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) { CHECK_EQ(stage->op.operator->(), self); ComputeLoopNest ret; // make main loop nest - ret.main_nest = MakeLoopNest( - stage, dom_map, 0, false, std::unordered_set(), &ret.main_vmap, - debug_keep_trivial_loop); - ret.main_predicates = MakeBoundCheck( - stage, dom_map, ret.main_vmap, false, - std::unordered_set()); + ret.main_nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set(), + &ret.main_vmap, debug_keep_trivial_loop); + ret.main_predicates = + MakeBoundCheck(stage, dom_map, ret.main_vmap, false, std::unordered_set()); for (auto& e : ret.main_predicates) { e = likely(e); } @@ -508,7 +466,8 @@ ComputeLoopNest ComputeLoopNest::make( auto iv = leaf_iter_vars[i]; int flag = update_state.at(iv); if ((flag & 2) != 0) { - begin_loop = i; break; + begin_loop = i; + break; } ret.init_vmap[iv] = ret.main_vmap.at(iv); } @@ -519,11 +478,9 @@ ComputeLoopNest ComputeLoopNest::make( int flag = kv.second; if (flag == 2) skip_iter.insert(kv.first); } - ret.init_nest = MakeLoopNest( - stage, dom_map, begin_loop, true, - skip_iter, &(ret.init_vmap), debug_keep_trivial_loop); - ret.init_predicates = MakeBoundCheck( - stage, dom_map, ret.init_vmap, true, skip_iter); + ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap), + debug_keep_trivial_loop); + ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter); for (auto& e : ret.init_predicates) { e = likely(e); } @@ -563,14 +520,12 @@ class ComputeVerifier final : protected tir::ExprVisitor { for (const PrimExpr e : compute_->body) { // Check for consistency of top level reductions const tir::ReduceNode* reduce = e.as(); - CHECK((reduce && reduce_) || (!reduce && !reduce_)) - << "All ComputeOp should be consistent " - << "with being Reduce operation or not."; + CHECK((reduce && reduce_) || (!reduce && !reduce_)) << "All ComputeOp should be consistent " + << "with being Reduce operation or not."; if (reduce && reduce_) { - CHECK(ReduceEqual(reduce, reduce_)) - << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; + CHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; } level_ = 0; @@ -589,16 +544,15 @@ class ComputeVerifier final : protected tir::ExprVisitor { void VisitExpr_(const tir::ReduceNode* op) final { // Check for non top level reductions - CHECK(0 == level_) - << "Reductions are only allowed at the top level of compute. " - << "Please create another tensor for further composition."; + CHECK(0 == level_) << "Reductions are only allowed at the top level of compute. " + << "Please create another tensor for further composition."; } //@} private: - const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify - const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation - int level_{0}; ///< Level of op being processed + const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify + const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation + int level_{0}; ///< Level of op being processed }; } // namespace @@ -608,11 +562,8 @@ static void VerifyComputeOp(const ComputeOpNode* op) { v.Run(); } -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - const ComputeLoopNest& n, - Stmt body, - Stmt update) { +Stmt TransformUpdate(const Stage& stage, const std::unordered_map& dom_map, + const ComputeLoopNest& n, Stmt body, Stmt update) { Array conds; std::unordered_set banned; for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { @@ -633,19 +584,17 @@ Stmt TransformUpdate(const Stage& stage, } } - auto fbanned = [&](const VarNode* node) { - return banned.count(node); - }; + auto fbanned = [&](const VarNode* node) { return banned.count(node); }; for (const PrimExpr& pred : n.main_predicates) { if (tir::ExprUseVar(pred, fbanned)) { - LOG(FATAL) << "Tensorize update transform failed, the condition " - << pred << " has a conflict with the reset condition"; + LOG(FATAL) << "Tensorize update transform failed, the condition " << pred + << " has a conflict with the reset condition"; } } - return IfThenElseNode::make(arith::ComputeReduce(conds, const_true(1)), - update, body); + return IfThenElseNode::make(arith::ComputeReduce(conds, const_true(1)), update, + body); } } // namespace te diff --git a/src/te/operation/compute_op.h b/src/te/operation/compute_op.h index 08db74f..610c014 100644 --- a/src/te/operation/compute_op.h +++ b/src/te/operation/compute_op.h @@ -24,10 +24,11 @@ #ifndef TVM_TE_OPERATION_COMPUTE_OP_H_ #define TVM_TE_OPERATION_COMPUTE_OP_H_ -#include #include -#include +#include + #include +#include namespace tvm { namespace te { @@ -58,11 +59,9 @@ struct ComputeLoopNest { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The constructed loop nest */ - static ComputeLoopNest make( - const BaseComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); + static ComputeLoopNest make(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop); }; /*! @@ -73,11 +72,9 @@ struct ComputeLoopNest { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The created statement. */ -Stmt MakeCrossThreadReduction( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); +Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop); /*! * \brief Build body of compute for tensorization. @@ -87,10 +84,8 @@ Stmt MakeCrossThreadReduction( * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The created statement. */ -Stmt MakeTensorize(const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); +Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, bool debug_keep_trivial_loop); /*! * \brief Transform the update part when there is no init func in tensorizing @@ -101,11 +96,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, * \param update The update func in tensorize intrin * \return Transformed result. */ -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - const ComputeLoopNest& n, - Stmt body, - Stmt update); +Stmt TransformUpdate(const Stage& stage, const std::unordered_map& dom_map, + const ComputeLoopNest& n, Stmt body, Stmt update); } // namespace te } // namespace tvm diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 1ec17e9..0905631 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -28,21 +28,17 @@ namespace tvm { namespace te { using namespace tir; -Stmt MakeCrossThreadReduction( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { - Array args; +Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) { + Array args; for (IterVar iv : self->axis) { args.push_back(iv->var); } std::unordered_map value_map; - auto nest = MakeLoopNest( - stage, dom_map, 0, false, std::unordered_set(), &value_map, debug_keep_trivial_loop); - auto conds = MakeBoundCheck( - stage, dom_map, value_map, false, - std::unordered_set()); + auto nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set(), &value_map, + debug_keep_trivial_loop); + auto conds = MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set()); size_t size = self->body.size(); CHECK_GT(size, 0); @@ -96,10 +92,10 @@ Stmt MakeCrossThreadReduction( Array update_value = (*combiner)(lhs, reduces[0]->source); for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; - normal_init.emplace_back(StoreNode::make( - normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); - normal_update.emplace_back(StoreNode::make( - normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); + normal_init.emplace_back( + StoreNode::make(normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); + normal_update.emplace_back( + StoreNode::make(normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); } } @@ -108,8 +104,7 @@ Stmt MakeCrossThreadReduction( for (size_t i = 0; i < size; ++i) { if (!normal_red.empty()) { DataType t = reduces[i]->dtype; - freduce_args.push_back(LoadNode::make( - t, normal_res_handles[i], 0, const_true(t.lanes()))); + freduce_args.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes()))); } else { freduce_args.push_back(reduces[0]->source[i]); } @@ -124,8 +119,7 @@ Stmt MakeCrossThreadReduction( for (IterVar iv : stage->leaf_iter_vars) { if (iv->iter_type == kCommReduce) { auto it = stage->iter_var_attrs.find(iv); - if (it != stage->iter_var_attrs.end() && - (*it).second->bind_thread.defined()) { + if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { IterVar tv = (*it).second->bind_thread; freduce_args.push_back(tv->var); } @@ -138,14 +132,9 @@ Stmt MakeCrossThreadReduction( } Stmt reduce_body = EvaluateNode::make(CallNode::make( - DataType::Handle(), - tir::intrinsic::tvm_thread_allreduce, - freduce_args, CallNode::Intrinsic)); - reduce_body = AttrStmtNode::make( - reduces[0]->combiner, - tir::attr::reduce_scope, - make_zero(DataType::Handle()), - reduce_body); + DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic)); + reduce_body = AttrStmtNode::make(reduces[0]->combiner, tir::attr::reduce_scope, + make_zero(DataType::Handle()), reduce_body); if (!normal_red.empty()) { Stmt init_body = SeqStmt::Flatten(normal_init); @@ -159,23 +148,22 @@ Stmt MakeCrossThreadReduction( for (size_t idx = 0; idx < size; ++idx) { DataType t = reduces[idx]->dtype; assigns[idx] = ProvideNode::make( - stage->op, idx, - LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args); + stage->op, idx, LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args); } Stmt assign_body = SeqStmt::Flatten(assigns); assign_body = MergeNest(MakeIfNest(thread_head_check), assign_body); assign_body = MergeNest(MakeIfNest(conds), assign_body); Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { - body = AllocateNode::make( - res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmtNode::make( - res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body); + body = + AllocateNode::make(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = AttrStmtNode::make(res_handles[idx - 1], tir::attr::storage_scope, + StringImmNode::make("local"), body); if (!normal_red.empty()) { - body = AllocateNode::make( - normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmtNode::make( - normal_res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body); + body = AllocateNode::make(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, + const_true(), body); + body = AttrStmtNode::make(normal_res_handles[idx - 1], tir::attr::storage_scope, + StringImmNode::make("local"), body); } } body = Substitute(body, value_map); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 9d95e32..59d1ec1 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -21,11 +21,13 @@ * \brief External computation rule. * \file extern_op.cc */ +#include #include #include -#include #include + #include + #include "op_util.h" namespace tvm { @@ -33,37 +35,24 @@ namespace te { using namespace tir; // ExternOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "extern(" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "extern(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(ExternOpNode); -int ExternOpNode::num_outputs() const { - return static_cast(output_placeholders.size()); -} - -Array ExternOpNode::root_iter_vars() const { - return {}; -} +int ExternOpNode::num_outputs() const { return static_cast(output_placeholders.size()); } -DataType ExternOpNode::output_dtype(size_t i) const { - return output_placeholders[i]->dtype; -} +Array ExternOpNode::root_iter_vars() const { return {}; } -Array ExternOpNode::output_shape(size_t i) const { - return output_placeholders[i]->shape; -} +DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders[i]->dtype; } +Array ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; } -Operation ExternOpNode::make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array input_placeholders, - Array output_placeholders, - Stmt body) { +Operation ExternOpNode::make(std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body) { if (!attrs.defined()) { attrs = Map(); } @@ -76,7 +65,7 @@ Operation ExternOpNode::make(std::string name, CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype); CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size()); for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) { - CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); + CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); } CHECK_EQ(input_placeholders[i]->strides.size(), 0U); } @@ -87,17 +76,12 @@ Operation ExternOpNode::make(std::string name, return Operation(n); } -TVM_REGISTER_GLOBAL("te.ExternOp") -.set_body_typed(ExternOpNode::make); +TVM_REGISTER_GLOBAL("te.ExternOp").set_body_typed(ExternOpNode::make); +Array ExternOpNode::InputTensors() const { return inputs; } -Array ExternOpNode::InputTensors() const { - return inputs; -} - -Operation ExternOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation ExternOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = ReplaceTensor(this->body, rmap); @@ -108,65 +92,54 @@ Operation ExternOpNode::ReplaceInputs( } } - if (body.same_as(n->body) && - inputs.same_as(n->inputs)) { + if (body.same_as(n->body) && inputs.same_as(n->inputs)) { return self; } else { return Operation(n); } } -void ExternOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { +void ExternOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { for (Tensor t : this->inputs) { auto it = out_dom_map->find(t); if (it == out_dom_map->end()) continue; TensorDom& dom = it->second; for (size_t i = 0; i < t->shape.size(); ++i) { dom.data[i].emplace_back(IntSet::range( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i]))); + Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]))); } } } -void ExternOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { -} +void ExternOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const {} -Stmt ExternOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const { +Stmt ExternOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { CHECK_EQ(stage->op.get(), this); Stmt realize_body = body; for (int k = 0; k < num_outputs(); ++k) { Tensor t = stage->op.output(k); Region bounds; for (size_t i = 0; i < t->shape.size(); ++i) { - bounds.push_back( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i])); + bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::RealizeNode::make( - t->op, t->value_index, t->dtype, - bounds, const_true(), realize_body); + realize_body = + tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize_body); } return realize_body; } -Stmt ExternOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt ExternOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmtNode::make( - make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); + Stmt ret = + AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { Array bind_spec; Array tuple; diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 7bb5d61..0022b6f 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -21,54 +21,44 @@ * \brief Hybrid computation rule. * \file hybrid_op.cc */ +#include "hybrid_op.h" + +#include #include #include -#include -#include -#include #include +#include #include -#include +#include + #include +#include #include + #include "op_util.h" -#include "hybrid_op.h" namespace tvm { namespace te { using namespace tir; // HybridOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "hybrid(" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "hybrid(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(HybridOpNode); -int HybridOpNode::num_outputs() const { - return static_cast(outputs.size()); -} - -Array HybridOpNode::root_iter_vars() const { - return this->axis; -} +int HybridOpNode::num_outputs() const { return static_cast(outputs.size()); } -DataType HybridOpNode::output_dtype(size_t i) const { - return outputs[i]->dtype; -} +Array HybridOpNode::root_iter_vars() const { return this->axis; } -Array HybridOpNode::output_shape(size_t i) const { - return outputs[i]->shape; -} +DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; } +Array HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } -Operation HybridOpNode::make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array outputs, - Stmt body) { +Operation HybridOpNode::make(std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, Stmt body) { if (!attrs.defined()) { attrs = Map(); } @@ -84,9 +74,7 @@ Operation HybridOpNode::make(std::string name, return res; } -TVM_REGISTER_GLOBAL("te.HybridOp") -.set_body_typed(HybridOpNode::make); - +TVM_REGISTER_GLOBAL("te.HybridOp").set_body_typed(HybridOpNode::make); Array HybridOpNode::InputTensors() const { // Because input tensors could be potentially inlined into hybrid scripts, @@ -98,21 +86,20 @@ Array HybridOpNode::InputTensors() const { std::unordered_set visited; Array curr_inputs; tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { - const tir::CallNode *call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); - if (orig_inputs.count(t) && !visited.count(t)) { - curr_inputs.push_back(t); - visited.insert(t); - } + const tir::CallNode* call = n.as(); + if (call != nullptr && call->func.defined()) { + Tensor t = Downcast(call->func).output(call->value_index); + if (orig_inputs.count(t) && !visited.count(t)) { + curr_inputs.push_back(t); + visited.insert(t); } + } }); return curr_inputs; } -Operation HybridOpNode::ReplaceInputs( - const Operation &self, - const std::unordered_map &rmap) const { +Operation HybridOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = te::ReplaceTensor(this->body, rmap); @@ -123,46 +110,40 @@ Operation HybridOpNode::ReplaceInputs( } } - if (body.same_as(n->body) && - inputs.same_as(n->inputs)) { + if (body.same_as(n->body) && inputs.same_as(n->inputs)) { return self; } else { return Operation(n); } } -void HybridOpNode::PropBoundToInputs( - const Operation &self, - arith::Analyzer* analyzer, - const std::unordered_map &dom_map, - std::unordered_map* out_dom_map) const { +void HybridOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { auto curr_inputs = InputTensors(); for (Tensor t : curr_inputs) { auto it = out_dom_map->find(t); if (it == out_dom_map->end()) continue; - TensorDom &dom = it->second; + TensorDom& dom = it->second; for (size_t i = 0; i < t->shape.size(); ++i) { dom.data[i].emplace_back(IntSet::range( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i]))); + Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]))); } } } -void HybridOpNode::GatherBound( - const Operation &self, - const std::unordered_map &tensor_dom, - std::unordered_map* out_dom_map) const { +void HybridOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { for (auto iter_var : axis) { CHECK(!out_dom_map->count(iter_var)); out_dom_map->operator[](iter_var) = iter_var->dom; } } -Stmt HybridOpNode::BuildRealize( - const Stage &stage, - const std::unordered_map &realize_map, - const Stmt &body) const { +Stmt HybridOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { // TODO(@were): Add attribute inject here and remove it from hybrid parser. CHECK_EQ(stage->op.get(), this); Stmt realize_body = body; @@ -170,24 +151,20 @@ Stmt HybridOpNode::BuildRealize( Tensor t = stage->op.output(k); Region bounds; for (size_t i = 0; i < t->shape.size(); ++i) { - bounds.push_back( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i])); + bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::RealizeNode::make( - t->op, t->value_index, t->dtype, - bounds, const_true(), realize_body); + realize_body = + tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize_body); } return realize_body; } -Stmt HybridOpNode::BuildProvide( - const Stage &stage, - const std::unordered_map &dom_map, - bool debug_keep_trivial_loop) const { +Stmt HybridOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmtNode::make( - make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); + Stmt ret = + AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); std::unordered_map rmap; for (int i = 0; i < this->num_outputs(); ++i) { rmap[outputs[i]] = stage->op.output(i); @@ -223,45 +200,44 @@ Stmt HybridOpNode::BuildProvide( return ret; } -Stmt ApplyLoopShapes(const Stage &stage, - const std::unordered_map &dom_map, Stmt stmt) { +Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt) { class LoopSpliter : public StmtExprMutator { PrimExpr factor; - const VarNode *parent; + const VarNode* parent; IterVar inner, outer; public: bool splitted; - LoopSpliter(const SplitNode *split, - const std::unordered_map &dom_map) : - factor(split->factor), splitted(false) { + LoopSpliter(const SplitNode* split, const std::unordered_map& dom_map) + : factor(split->factor), splitted(false) { parent = split->parent->var.get(); - auto &inner_ = split->inner; + auto& inner_ = split->inner; CHECK(dom_map.count(inner_)); - auto &inner_dom = dom_map.find(inner_)->second; + auto& inner_dom = dom_map.find(inner_)->second; CHECK(is_const_int(inner_dom->min, 0)); - auto &outer_ = split->outer; + auto& outer_ = split->outer; CHECK(dom_map.count(outer_)); - auto &outer_dom = dom_map.find(outer_)->second; + auto& outer_dom = dom_map.find(outer_)->second; CHECK(is_const_int(outer_dom->min, 0)); inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type); outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type); } - Stmt VisitStmt_(const ForNode *op) final { + Stmt VisitStmt_(const ForNode* op) final { if (op->loop_var.get() == parent) { - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = inner + outer * factor; Stmt ret = tir::Substitute(op->body, rmap); PrimExpr cond = likely(outer * factor < (op->extent - inner)); ret = IfThenElseNode::make(cond, ret); ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent, - IterVarTypeToForType(inner->iter_type), op->device_api, ret); + IterVarTypeToForType(inner->iter_type), op->device_api, ret); ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent, - IterVarTypeToForType(outer->iter_type), op->device_api, ret); + IterVarTypeToForType(outer->iter_type), op->device_api, ret); splitted = true; return ret; } @@ -270,24 +246,27 @@ Stmt ApplyLoopShapes(const Stage &stage, }; class LoopFuser : public StmtExprMutator { - const IterVar &parent; - const VarNode *inner; - const VarNode *outer; + const IterVar& parent; + const VarNode* inner; + const VarNode* outer; bool under_outer; PrimExpr extent; public: bool fused; - explicit LoopFuser(const FuseNode *fuse_) - : parent(fuse_->fused), inner(fuse_->inner->var.get()), - outer(fuse_->outer->var.get()), under_outer(false), - extent(0), fused(false) {} + explicit LoopFuser(const FuseNode* fuse_) + : parent(fuse_->fused), + inner(fuse_->inner->var.get()), + outer(fuse_->outer->var.get()), + under_outer(false), + extent(0), + fused(false) {} // TODO(@were): Handle imperfect loops Stmt VisitStmt_(const ForNode* op) final { if (op->loop_var.get() == inner) { CHECK(under_outer); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(parent, op->extent); extent = op->extent; fused = true; @@ -295,15 +274,15 @@ Stmt ApplyLoopShapes(const Stage &stage, } else if (op->loop_var.get() == outer) { under_outer = true; Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexdiv(parent, extent); body = tir::Substitute(body, rmap); under_outer = false; - return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, - op->for_type, op->device_api, body); + return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, op->for_type, + op->device_api, body); } else if (under_outer) { Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent); body = tir::Substitute(body, rmap); extent = extent * op->extent; @@ -313,12 +292,12 @@ Stmt ApplyLoopShapes(const Stage &stage, } }; - for (auto &rel : stage->relations) { - if (const SplitNode *split = rel.as()) { + for (auto& rel : stage->relations) { + if (const SplitNode* split = rel.as()) { LoopSpliter Spliter(split, dom_map); stmt = Spliter(stmt); CHECK(Spliter.splitted); - } else if (const FuseNode *fuse = rel.as()) { + } else if (const FuseNode* fuse = rel.as()) { LoopFuser Fuser(fuse); stmt = Fuser(stmt); CHECK(Fuser.fused); @@ -328,45 +307,45 @@ Stmt ApplyLoopShapes(const Stage &stage, return stmt; } -Stmt ApplyLoopAnnotations(const Stage &stage, - const std::unordered_map &rebased, Stmt stmt) { +Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map& rebased, + Stmt stmt) { class LoopAnnotator : public StmtMutator { - const VarNode *var; - const IterVarAttr &attr; + const VarNode* var; + const IterVarAttr& attr; public: - LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {} + LoopAnnotator(const VarNode* var_, const IterVarAttr& attr_) : var(var_), attr(attr_) {} - Stmt VisitStmt_(const ForNode *op) final { + Stmt VisitStmt_(const ForNode* op) final { tir::ExprDeepEqual expr_equal; if (op->loop_var.get() == var) { if (attr->bind_thread.defined()) { - const auto &iter_var = attr->bind_thread; + const auto& iter_var = attr->bind_thread; if (iter_var->dom.defined()) { CHECK(is_const_int(iter_var->dom->min, 0)); CHECK(expr_equal(iter_var->dom->extent, op->extent)) - << "Thread extent and loop extent mismatch!\n"; + << "Thread extent and loop extent mismatch!\n"; } - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = iter_var; Stmt body = tir::Substitute(op->body, rmap); return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body); } else { return ForNode::make(op->loop_var, op->min, op->extent, - IterVarTypeToForType(attr->iter_type), op->device_api, op->body); + IterVarTypeToForType(attr->iter_type), op->device_api, op->body); } } return StmtMutator::VisitStmt_(op); } }; - for (auto &iter_var : stage->leaf_iter_vars) { + for (auto& iter_var : stage->leaf_iter_vars) { bool need_change = false; int found = 0; - const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; - const VarNode *var = actual->var.get(); + const IterVar& actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; + const VarNode* var = actual->var.get(); ForType expected = IterVarTypeToForType(iter_var->iter_type); IterVarAttr attr; if (stage->iter_var_attrs.count(iter_var)) { @@ -374,9 +353,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage, expected = IterVarTypeToForType(attr->iter_type); } - PostOrderVisit(stmt, - [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) { - if (const ForNode *op = node.as()) { + PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) { + if (const ForNode* op = node.as()) { if (op->loop_var.get() == var) { ++found; need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined()); @@ -392,23 +370,21 @@ Stmt ApplyLoopAnnotations(const Stage &stage, return stmt; } -Stmt ApplyLoopOrder(const Stage &stage, - const std::unordered_map &dom_map, - const std::unordered_map &rebased, Stmt stmt) { +Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map& dom_map, + const std::unordered_map& rebased, Stmt stmt) { std::vector current_order; PostOrderVisit(stmt, [¤t_order](const ObjectRef& node) { - if (const ForNode *op = node.as()) - current_order.push_back(op->loop_var.get()); + if (const ForNode* op = node.as()) current_order.push_back(op->loop_var.get()); }); std::reverse(current_order.begin(), current_order.end()); - auto &required_ord = stage->leaf_iter_vars; + auto& required_ord = stage->leaf_iter_vars; CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!"; - std::unordered_map reorder; + std::unordered_map reorder; bool need_reorder = false; for (size_t i = 0; i < current_order.size(); ++i) { - auto ¤t = current_order[i]; - const IterVar &iter_var = required_ord[i]; - const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; + auto& current = current_order[i]; + const IterVar& iter_var = required_ord[i]; + const IterVar& required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n"; reorder[current] = required; if (current != required->var.get()) { @@ -417,15 +393,14 @@ Stmt ApplyLoopOrder(const Stage &stage, } class LoopReorder : public StmtMutator { - const Stage &stage; - const std::unordered_map &dom_map; - const std::unordered_map &reorder; + const Stage& stage; + const std::unordered_map& dom_map; + const std::unordered_map& reorder; public: - LoopReorder(const Stage &stage, - const std::unordered_map &dom_map, - const std::unordered_map &reorder) - : stage(stage), dom_map(dom_map), reorder(reorder) {} + LoopReorder(const Stage& stage, const std::unordered_map& dom_map, + const std::unordered_map& reorder) + : stage(stage), dom_map(dom_map), reorder(reorder) {} Stmt VisitStmt_(const ForNode* op) final { // Reorder from in to out @@ -434,25 +409,23 @@ Stmt ApplyLoopOrder(const Stage &stage, auto target = reorder.find(op->loop_var.get())->second; if (body_.same_as(op->body) && op->loop_var.get() == target->var.get()) return GetRef(op); - const Stmt &body = op->body.same_as(body_) ? op->body : body_; + const Stmt& body = op->body.same_as(body_) ? op->body : body_; ForType for_type = IterVarTypeToForType(target->iter_type); if (stage->iter_var_attrs.count(target)) { for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type); } - const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second; - return ForNode::make(target->var, range->min, range->extent, - for_type, DeviceAPI::None, body); + const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second; + return ForNode::make(target->var, range->min, range->extent, for_type, DeviceAPI::None, body); } }; - if (need_reorder) - return LoopReorder(stage, dom_map, reorder)(stmt); + if (need_reorder) return LoopReorder(stage, dom_map, reorder)(stmt); return stmt; } -Stmt ApplySchedule(const Stage &stage, - const std::unordered_map &dom_map, Stmt stmt) { +Stmt ApplySchedule(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt) { // TODO(@were): Eliminate loop rebase in script parser and move the burden here // Gather rebased variables std::unordered_map rebased; @@ -473,7 +446,7 @@ std::vector GatherLoopVars(Stmt stmt) { // TODO(@were): Write a comprehensive pass to analyze iter var types std::vector res_; PostOrderVisit(stmt, [&res_](const ObjectRef& node) { - if (const ForNode *op = node.as()) { + if (const ForNode* op = node.as()) { Var loop_var(op->loop_var); Range dom = Range::make_by_min_extent(op->min, op->extent); res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type))); @@ -486,15 +459,14 @@ std::vector GatherLoopVars(Stmt stmt) { // replacer to replace tensors' usage in Provide class ProviderReplacer : public tir::StmtMutator { public: - explicit ProviderReplacer(const std::unordered_map &vmap) - : vmap_(vmap) {} + explicit ProviderReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} Stmt VisitStmt_(const tir::ProvideNode* op) final { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); if (it != vmap_.end()) { - Stmt ret = tir::ProvideNode::make( - it->second->op, it->second->value_index, op->value, op->args); + Stmt ret = + tir::ProvideNode::make(it->second->op, it->second->value_index, op->value, op->args); found = true; return this->VisitStmt(ret); } @@ -505,11 +477,10 @@ class ProviderReplacer : public tir::StmtMutator { bool found{false}; private: - const std::unordered_map &vmap_; + const std::unordered_map& vmap_; }; -Stmt ReplaceProvideTensor(Stmt stmt, - const std::unordered_map &replace) { +Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map& replace) { ProviderReplacer repl(replace); Stmt ret = repl(stmt); return repl.found ? ret : stmt; diff --git a/src/te/operation/hybrid_op.h b/src/te/operation/hybrid_op.h index dadfecd..a11ae89 100644 --- a/src/te/operation/hybrid_op.h +++ b/src/te/operation/hybrid_op.h @@ -24,16 +24,16 @@ #ifndef TVM_TE_OPERATION_HYBRID_OP_H_ #define TVM_TE_OPERATION_HYBRID_OP_H_ -#include #include +#include #include #include #include -#include "../schedule/message_passing.h" -#include "../../tir/transforms/ir_util.h" #include "../../tir/transforms/arg_binder.h" +#include "../../tir/transforms/ir_util.h" +#include "../schedule/message_passing.h" namespace tvm { namespace te { @@ -49,8 +49,7 @@ std::vector GatherLoopVars(Stmt stmt); * \param stmt The statement to be processed. * \param replace The replacement rule. */ -Stmt ReplaceProvideTensor(Stmt stmt, - const std::unordered_map& replace); +Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map& replace); /*! * \brief Apply the schedule manipulation on the function body. @@ -58,8 +57,8 @@ Stmt ReplaceProvideTensor(Stmt stmt, * \param dom_map The extents of the iterative variables may be used. * \param stage The schedule information to be applied. */ -Stmt ApplySchedule(const Stage& stage, - const std::unordered_map& dom_map, Stmt stmt); +Stmt ApplySchedule(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt); /*! * \brief Apply loop splits and fuses in the schedule on the function body. @@ -67,9 +66,8 @@ Stmt ApplySchedule(const Stage& stage, * \param dom_map The extents of the iterative variables may be used. * \param stmt The statement to be processed. */ -Stmt ApplyLoopShapes(const Stage &stage, - const std::unordered_map& dom_map, Stmt stmt); - +Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt); /*! * \brief Apply loop annotation in the schedule on the function body. @@ -77,8 +75,8 @@ Stmt ApplyLoopShapes(const Stage &stage, * \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables. * \param stmt The statement to be processed. */ -Stmt ApplyLoopAnnotations(const Stage &stage, - const std::unordered_map& rebased, Stmt stmt); +Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map& rebased, + Stmt stmt); /*! * \brief Apply loop order in the schedule on the function body. @@ -87,9 +85,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage, * \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables. * \param stmt The statement to be processed. */ -Stmt ApplyLoopOrder(const Stage &stage, - const std::unordered_map &dom_map, - const std::unordered_map &rebased, Stmt stmt); +Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map& dom_map, + const std::unordered_map& rebased, Stmt stmt); } // namespace te } // namespace tvm diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index bee573e..5b200ac 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -21,14 +21,17 @@ * \brief Utility to make loop nest. * \file op_util.cc */ +#include "op_util.h" + +#include #include #include -#include + #include -#include "op_util.h" -#include "../schedule/message_passing.h" + #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" +#include "../schedule/message_passing.h" namespace tvm { namespace te { @@ -36,14 +39,12 @@ namespace te { using namespace arith; using namespace tir; -std::vector > -MakeLoopNest(const Stage& stage, - const std::unordered_map& dom_map, - size_t begin_iter_pos, - bool new_loop_var, - const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, - bool debug_keep_trivial_loop) { +std::vector > MakeLoopNest(const Stage& stage, + const std::unordered_map& dom_map, + size_t begin_iter_pos, bool new_loop_var, + const std::unordered_set& skip_iter, + std::unordered_map* p_value_map, + bool debug_keep_trivial_loop) { auto leaf_iter_vars = stage->leaf_iter_vars; Stmt no_op = EvaluateNode::make(0); // create the loop nest @@ -84,14 +85,21 @@ MakeLoopNest(const Stage& stage, } if (it_attr.defined()) { switch (it_attr->iter_type) { - case kUnrolled: for_type = ForType::Unrolled; break; - case kVectorized: for_type = ForType::Vectorized; break; - case kParallelized: for_type = ForType::Parallel; break; - case kDataPar: break; - case kTensorized: break; - default: LOG(FATAL) << "Unknown iter type" - << it_attr->iter_type - << " in the iter_var_attrs"; + case kUnrolled: + for_type = ForType::Unrolled; + break; + case kVectorized: + for_type = ForType::Vectorized; + break; + case kParallelized: + for_type = ForType::Parallel; + break; + case kDataPar: + break; + case kTensorized: + break; + default: + LOG(FATAL) << "Unknown iter type" << it_attr->iter_type << " in the iter_var_attrs"; } CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size()); for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) { @@ -105,38 +113,30 @@ MakeLoopNest(const Stage& stage, } } if (!debug_keep_trivial_loop && is_one(dom->extent)) { - nest[i + 1].emplace_back( - LetStmtNode::make(var, dom->min, no_op)); + nest[i + 1].emplace_back(LetStmtNode::make(var, dom->min, no_op)); value_map[iv] = dom->min; } else if (is_zero(dom->min)) { nest[i + 1].emplace_back( - ForNode::make(var, 0, dom->extent, - for_type, DeviceAPI::None, no_op)); + ForNode::make(var, 0, dom->extent, for_type, DeviceAPI::None, no_op)); value_map[iv] = var; } else { Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype()); nest[i + 1].emplace_back( - ForNode::make(idx, 0, dom->extent, - for_type, DeviceAPI::None, no_op)); + ForNode::make(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op)); PrimExpr new_value = dom->min + idx; value_map[iv] = new_value; - nest[i + 1].emplace_back( - LetStmtNode::make(var, new_value, no_op)); + nest[i + 1].emplace_back(LetStmtNode::make(var, new_value, no_op)); } if (it_attr.defined() && it_attr->prefetch_data.size() != 0) { - CHECK(!is_one(dom->extent)) - << "Cannot prefetch on trivial loop with extent=1"; - CHECK_EQ(it_attr->prefetch_data.size(), - it_attr->prefetch_offset.size()); + CHECK(!is_one(dom->extent)) << "Cannot prefetch on trivial loop with extent=1"; + CHECK_EQ(it_attr->prefetch_data.size(), it_attr->prefetch_offset.size()); for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) { - nest[i + 1].emplace_back( - AttrStmtNode::make(it_attr->prefetch_data[j], - tir::attr::prefetch_scope, - it_attr->prefetch_offset[j], no_op)); + nest[i + 1].emplace_back(AttrStmtNode::make(it_attr->prefetch_data[j], + tir::attr::prefetch_scope, + it_attr->prefetch_offset[j], no_op)); } } - } else if (bind_iv->thread_tag == "vthread" || - bind_iv->thread_tag == "cthread") { + } else if (bind_iv->thread_tag == "vthread" || bind_iv->thread_tag == "cthread") { // virtual thread // Always restrict threaded IterVar to starts from 0. CHECK(is_zero(dom->min)); @@ -173,9 +173,9 @@ MakeLoopNest(const Stage& stage, value_map[iv] = var; } else { LOG(WARNING) - << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. " - << "TVM assumes only threadIdx.x indicates threads inside a warp, " - << "while threadIdx.y and threadIdx.z indicates different warps."; + << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. " + << "TVM assumes only threadIdx.x indicates threads inside a warp, " + << "while threadIdx.y and threadIdx.z indicates different warps."; value_map[iv] = dom->min; } } else { @@ -185,8 +185,7 @@ MakeLoopNest(const Stage& stage, } // annotate the extent of the IterVar if (!new_loop_var) { - nest[i + 1].emplace_back( - AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op)); + nest[i + 1].emplace_back(AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op)); } } // message passing to get offset of root iter vars. @@ -206,17 +205,15 @@ std::vector MakeIfNest(const std::vector& predicates) { // replacer to replace tensors class TensorReplacer : public tir::StmtExprMutator { public: - explicit TensorReplacer(const std::unordered_map& vmap) - : vmap_(vmap) {} + explicit TensorReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} PrimExpr VisitExpr_(const tir::CallNode* op) final { if (op->call_type == tir::CallNode::Halide) { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); if (it != vmap_.end()) { - PrimExpr ret = tir::CallNode::make( - op->dtype, it->second->op->name, op->args, - op->call_type, it->second->op, it->second->value_index); + PrimExpr ret = tir::CallNode::make(op->dtype, it->second->op->name, op->args, op->call_type, + it->second->op, it->second->value_index); found = true; return this->VisitExpr(ret); } @@ -231,22 +228,18 @@ class TensorReplacer : public tir::StmtExprMutator { const std::unordered_map& vmap_; }; -Stmt ReplaceTensor(Stmt stmt, - const std::unordered_map& replace) { +Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace) { TensorReplacer repl(replace); Stmt ret = repl(stmt); return repl.found ? ret : stmt; } -PrimExpr ReplaceTensor(PrimExpr expr, - const std::unordered_map& replace) { +PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& replace) { TensorReplacer repl(replace); PrimExpr ret = repl(expr); return repl.found ? ret : expr; } - -Stmt Substitute(Stmt s, - const std::unordered_map& value_map) { +Stmt Substitute(Stmt s, const std::unordered_map& value_map) { std::unordered_map init; for (const auto& kv : value_map) { init[kv.first->var.get()] = kv.second; @@ -256,31 +249,31 @@ Stmt Substitute(Stmt s, IterVarType ForTypeToIterVarType(tir::ForType for_type) { switch (for_type) { - case ForType::Serial: - return kDataPar; - case ForType::Parallel: - return kParallelized; - case ForType::Vectorized: - return kVectorized; - case ForType::Unrolled: - return kUnrolled; - default: - return kDataPar; + case ForType::Serial: + return kDataPar; + case ForType::Parallel: + return kParallelized; + case ForType::Vectorized: + return kVectorized; + case ForType::Unrolled: + return kUnrolled; + default: + return kDataPar; } } tir::ForType IterVarTypeToForType(IterVarType iter_type) { switch (iter_type) { - case kDataPar: - return ForType::Serial; - case kParallelized: - return ForType::Parallel; - case kVectorized: - return ForType::Vectorized; - case kUnrolled: - return ForType::Unrolled; - default: - return ForType::Serial; + case kDataPar: + return ForType::Serial; + case kParallelized: + return ForType::Parallel; + case kVectorized: + return ForType::Vectorized; + case kUnrolled: + return ForType::Unrolled; + default: + return ForType::Serial; } } diff --git a/src/te/operation/op_util.h b/src/te/operation/op_util.h index f95f84a..6c864fc 100644 --- a/src/te/operation/op_util.h +++ b/src/te/operation/op_util.h @@ -24,13 +24,15 @@ #ifndef TVM_TE_OPERATION_OP_UTIL_H_ #define TVM_TE_OPERATION_OP_UTIL_H_ -#include #include +#include + #include #include #include -#include "../../tir/transforms/ir_util.h" + #include "../../tir/transforms/arg_binder.h" +#include "../../tir/transforms/ir_util.h" #include "../schedule/message_passing.h" namespace tvm { @@ -49,14 +51,12 @@ using tir::MergeNest; * \param p_value_map The result value of each IterVar. * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 */ -std::vector > -MakeLoopNest(const Stage& stage, - const std::unordered_map& dom_map, - size_t begin_iter_pos, - bool new_loop_var, - const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, - bool debug_keep_trivial_loop); +std::vector > MakeLoopNest(const Stage& stage, + const std::unordered_map& dom_map, + size_t begin_iter_pos, bool new_loop_var, + const std::unordered_set& skip_iter, + std::unordered_map* p_value_map, + bool debug_keep_trivial_loop); /*! * \brief Create a nest of if checking the predicates. @@ -71,15 +71,13 @@ std::vector MakeIfNest(const std::vector& predicates); * \param stmt The statement to be processed. * \param replace The replacement rule. */ -Stmt ReplaceTensor(Stmt stmt, - const std::unordered_map& replace); +Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace); /*! * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. * \param expr The expression to be processed. * \param replace The replacement rule. */ -PrimExpr ReplaceTensor(PrimExpr expr, - const std::unordered_map& replace); +PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& replace); /*! * \brief Substitute the variables of stmt by value map. @@ -87,8 +85,7 @@ PrimExpr ReplaceTensor(PrimExpr expr, * \param value_map The value map. * \return Substituted result. */ -Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map); +Stmt Substitute(Stmt stmt, const std::unordered_map& value_map); /*! * \brief Converts Halide ForType to its corresponding IterVarType diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index d48be4c..9c536eb 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -29,20 +29,16 @@ namespace te { // PlaceholderOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "placeholder(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "placeholder(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(PlaceholderOpNode); -int PlaceholderOpNode::num_outputs() const { - return 1; -} +int PlaceholderOpNode::num_outputs() const { return 1; } -Array PlaceholderOpNode::root_iter_vars() const { - return {}; -} +Array PlaceholderOpNode::root_iter_vars() const { return {}; } DataType PlaceholderOpNode::output_dtype(size_t i) const { CHECK_EQ(i, 0U); @@ -54,9 +50,7 @@ Array PlaceholderOpNode::output_shape(size_t i) const { return shape; } -Operation PlaceholderOpNode::make(std::string name, - Array shape, - DataType dtype) { +Operation PlaceholderOpNode::make(std::string name, Array shape, DataType dtype) { auto n = make_object(); n->name = name; n->shape = shape; @@ -69,44 +63,35 @@ Tensor placeholder(Array shape, DataType dtype, std::string name) { } TVM_REGISTER_GLOBAL("te.Placeholder") -.set_body_typed([](Array shape, DataType dtype, std::string name) { - return placeholder(shape, dtype, name); -}); + .set_body_typed([](Array shape, DataType dtype, std::string name) { + return placeholder(shape, dtype, name); + }); -Array PlaceholderOpNode::InputTensors() const { - return {}; -} +Array PlaceholderOpNode::InputTensors() const { return {}; } -Operation PlaceholderOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation PlaceholderOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { return self; } void PlaceholderOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, + const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { -} + std::unordered_map* out_dom_map) const {} -void PlaceholderOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { -} +void PlaceholderOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const {} -Stmt PlaceholderOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const { +Stmt PlaceholderOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { return body; } -Stmt PlaceholderOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt PlaceholderOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { return Stmt(); } } // namespace te diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 4992928..582e290 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -24,23 +24,22 @@ #include #include #include -#include "op_util.h" + #include "../schedule/graph.h" +#include "op_util.h" namespace tvm { namespace te { using namespace tir; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "scan(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "scan(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(ScanOpNode); -int ScanOpNode::num_outputs() const { - return static_cast(update.size()); -} +int ScanOpNode::num_outputs() const { return static_cast(update.size()); } Array ScanOpNode::root_iter_vars() const { Array ret{scan_axis}; for (IterVar iv : spatial_axis_) { @@ -49,23 +48,16 @@ Array ScanOpNode::root_iter_vars() const { return ret; } -DataType ScanOpNode::output_dtype(size_t i) const { - return update[i]->dtype; -} +DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; } Array ScanOpNode::output_shape(size_t i) const { CHECK_LT(i, state_placeholder.size()); return state_placeholder[i]->shape; } -Operation ScanOpNode::make(std::string name, - std::string tag, - Map attrs, - IterVar axis, - Array init, - Array update, - Array state_placeholder, - Array inputs) { +Operation ScanOpNode::make(std::string name, std::string tag, Map attrs, + IterVar axis, Array init, Array update, + Array state_placeholder, Array inputs) { if (!attrs.defined()) { attrs = Map(); } @@ -82,31 +74,26 @@ Operation ScanOpNode::make(std::string name, CHECK_EQ(init[i]->dtype, update[i]->dtype); CHECK(prove_equal(init[i]->shape[0], axis->dom->min)) << "init.shape[0] need to match scan_axis.dom.min"; - CHECK(prove_equal( - state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) + CHECK(prove_equal(state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) << "state_placeholder.shape[0] need to match" << " scan_axis.dom.min + scan_axis.dom.extent"; CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim()) << "The dimension of init need to match state_placeholder"; CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim()) << "The update.ndim need to be state_placeholder.ndim - 1"; - for (size_t k = 0; k < update[i].ndim(); ++k) { - CHECK(prove_equal( - update[i]->shape[k], state_placeholder[i]->shape[k])); + for (size_t k = 0; k < update[i].ndim(); ++k) { + CHECK(prove_equal(update[i]->shape[k], state_placeholder[i]->shape[k])); if (k != 0) { // setup spatial axis std::ostringstream spatial_name; spatial_name << name << ".out" << i << ".i" << k; - n->spatial_axis_.push_back( - IterVarNode::make( - Range::make_by_min_extent(0, update[i]->shape[k]), - Var(spatial_name.str()), kOpaque)); + n->spatial_axis_.push_back(IterVarNode::make( + Range::make_by_min_extent(0, update[i]->shape[k]), Var(spatial_name.str()), kOpaque)); } } - for (size_t k = 1; k < init[i].ndim(); ++k) { - CHECK(prove_equal( - init[i]->shape[k], state_placeholder[i]->shape[k])); + for (size_t k = 1; k < init[i].ndim(); ++k) { + CHECK(prove_equal(init[i]->shape[k], state_placeholder[i]->shape[k])); } } n->name = std::move(name); @@ -120,25 +107,16 @@ Operation ScanOpNode::make(std::string name, return Operation(n); } -TVM_REGISTER_GLOBAL("te.ScanOp") -.set_body_typed(ScanOpNode::make); - +TVM_REGISTER_GLOBAL("te.ScanOp").set_body_typed(ScanOpNode::make); -Array scan(Array init, - Array update, - Array state_placeholder, - Array inputs, - std::string name, - std::string tag, +Array scan(Array init, Array update, Array state_placeholder, + Array inputs, std::string name, std::string tag, Map attrs) { - IterVar scan_axis = - IterVarNode::make( - Range::make_by_min_extent( - init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), - Var(name + ".idx"), kOrdered); - Operation op = ScanOpNode::make( - name, tag, attrs, scan_axis, - init, update, state_placeholder, inputs); + IterVar scan_axis = IterVarNode::make( + Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), + Var(name + ".idx"), kOrdered); + Operation op = + ScanOpNode::make(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); Array res; for (int i = 0; i < op->num_outputs(); ++i) { res.push_back(op.output(i)); @@ -157,9 +135,8 @@ Array ScanOpNode::InputTensors() const { return ret; } -Operation ScanOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation ScanOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); for (size_t i = 0; i < n->init.size(); ++i) { @@ -170,19 +147,16 @@ Operation ScanOpNode::ReplaceInputs( n->update.Set(i, rmap.at(n->update[i])); } } - if (!n->init.same_as(init) || - !n->update.same_as(update)) { + if (!n->init.same_as(init) || !n->update.same_as(update)) { return Operation(n); } else { return self; } } -void ScanOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { +void ScanOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) { TensorDom* init_dom = nullptr; @@ -195,8 +169,8 @@ void ScanOpNode::PropBoundToInputs( } // first dimension, always needed. if (init_dom) { - init_dom->data[0].push_back(IntSet::range( - Range::make_by_min_extent(0, this->init[i]->shape[0]))); + init_dom->data[0].push_back( + IntSet::range(Range::make_by_min_extent(0, this->init[i]->shape[0]))); } if (update_dom) { update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get())); @@ -214,10 +188,9 @@ void ScanOpNode::PropBoundToInputs( } } -void ScanOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { +void ScanOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); CHECK(!out_dom_map->count(this->scan_axis)); std::vector output(this->num_outputs()); @@ -234,8 +207,8 @@ void ScanOpNode::GatherBound( arith::Analyzer analyzer; Range sdom = this->scan_axis->dom; Range r = arith::Union(time_dom).cover_range(sdom); - (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent( - sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min)); + (*out_dom_map)[this->scan_axis] = + Range::make_by_min_extent(sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min)); Map fix_pt = ScanFixPointAnalysis(self); // Update for spatial axis. size_t sp_idx = 0; @@ -256,15 +229,12 @@ void ScanOpNode::GatherBound( } } -Stmt ScanOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& dom_map, - const Stmt& body) const { +Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map& dom_map, + const Stmt& body) const { arith::Analyzer analyzer; CHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); - Range tdom = Range::make_by_min_extent( - 0, analyzer.Simplify(sdom->extent + sdom->min)); + Range tdom = Range::make_by_min_extent(0, analyzer.Simplify(sdom->extent + sdom->min)); Stmt ret = body; size_t sp_idx = 0; for (size_t i = 0; i < update.size(); ++i) { @@ -276,25 +246,19 @@ Stmt ScanOpNode::BuildRealize( IterVar sp_ax = this->spatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::RealizeNode::make(t->op, t->value_index, t->dtype, - bounds, const_true(), ret); + ret = tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), ret); } return ret; } -Stmt ScanOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt ScanOpNode::BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt provide = AttrStmtNode::make( - stage->op, tir::attr::scan_update_scope, this->scan_axis->var, - EvaluateNode::make(0)); - Stmt init = AttrStmtNode::make( - stage->op, tir::attr::scan_init_scope, 0, - EvaluateNode::make(0)); + Stmt provide = AttrStmtNode::make(stage->op, tir::attr::scan_update_scope, this->scan_axis->var, + EvaluateNode::make(0)); + Stmt init = AttrStmtNode::make(stage->op, tir::attr::scan_init_scope, 0, EvaluateNode::make(0)); size_t begin_scan = 0; - for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) { CHECK_EQ(begin_scan, i); begin_scan = i + 1; @@ -302,12 +266,9 @@ Stmt ScanOpNode::BuildProvide( } std::unordered_map vmap; std::unordered_set empty; - auto nest = MakeLoopNest( - stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop); + auto nest = MakeLoopNest(stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop); nest[begin_scan].push_back(init); - nest.push_back( - MakeIfNest( - MakeBoundCheck(stage, dom_map, vmap, false, empty))); + nest.push_back(MakeIfNest(MakeBoundCheck(stage, dom_map, vmap, false, empty))); return MergeNest(nest, provide); } } // namespace te diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index f714691..236aff6 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -21,26 +21,27 @@ * \brief Tensor Compute Op. * \file tensor_compute_op.cc */ +#include #include #include -#include #include #include + #include -#include "./op_util.h" -#include "./compute_op.h" #include "../../arith/compute_expr.h" +#include "./compute_op.h" +#include "./op_util.h" namespace tvm { namespace te { using namespace tir; // TensorComputeOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(TensorComputeOpNode); @@ -52,15 +53,10 @@ DataType TensorComputeOpNode::output_dtype(size_t i) const { return this->intrin->buffers[this->inputs.size() + i]->dtype; } -Operation TensorComputeOpNode::make(std::string name, - std::string tag, - Array axis, - Array reduce_axis, - int schedulable_ndim, - TensorIntrin intrin, - Array tensors, - Array regions, - Array scalar_inputs) { +Operation TensorComputeOpNode::make(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, + TensorIntrin intrin, Array tensors, + Array regions, Array scalar_inputs) { auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); @@ -74,17 +70,12 @@ Operation TensorComputeOpNode::make(std::string name, return Operation(n); } -TVM_REGISTER_GLOBAL("te.TensorComputeOp") -.set_body_typed(TensorComputeOpNode::make); +TVM_REGISTER_GLOBAL("te.TensorComputeOp").set_body_typed(TensorComputeOpNode::make); +Array TensorComputeOpNode::InputTensors() const { return inputs; } -Array TensorComputeOpNode::InputTensors() const { - return inputs; -} - -Operation TensorComputeOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation TensorComputeOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); auto intrin = make_object(*(this->intrin.operator->())); @@ -104,8 +95,7 @@ Operation TensorComputeOpNode::ReplaceInputs( if (intrin->body.same_as(n->intrin->body) && intrin->reduce_init.same_as(n->intrin->reduce_init) && - intrin->reduce_update.same_as(n->intrin->reduce_update) && - inputs.same_as(n->inputs)) { + intrin->reduce_update.same_as(n->intrin->reduce_update) && inputs.same_as(n->inputs)) { return self; } else { n->intrin = TensorIntrin(intrin); @@ -114,8 +104,7 @@ Operation TensorComputeOpNode::ReplaceInputs( } void TensorComputeOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, + const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { for (size_t i = 0; i < this->inputs.size(); ++i) { @@ -131,14 +120,11 @@ void TensorComputeOpNode::PropBoundToInputs( } } -size_t TensorComputeOpNode::num_schedulable_dims() const { - return schedulable_ndim; -} +size_t TensorComputeOpNode::num_schedulable_dims() const { return schedulable_ndim; } -Stmt TensorComputeOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); // Start bind data. @@ -161,9 +147,8 @@ Stmt TensorComputeOpNode::BuildProvide( } input_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + nop)); } // output binding @@ -187,9 +172,8 @@ Stmt TensorComputeOpNode::BuildProvide( output_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + nop)); } // Check variable remap @@ -213,8 +197,7 @@ Stmt TensorComputeOpNode::BuildProvide( ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop); if (this->reduce_axis.size() == 0) { - std::vector > nest( - n.main_nest.begin(), n.main_nest.begin() + tloc + 1); + std::vector > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); nest.emplace_back(MakeIfNest(n.main_predicates)); CHECK_EQ(n.init_predicates.size(), 0U); CHECK(this->intrin->body.defined()) @@ -224,24 +207,23 @@ Stmt TensorComputeOpNode::BuildProvide( body = tir::Substitute(body, vmap); body = MergeNest(binder.asserts(), body); body = te::Substitute(body, n.main_vmap); - Stmt ret = MergeNest(nest, body); + Stmt ret = MergeNest(nest, body); return ret; } else { // Need to split reduction - CHECK(this->intrin->reduce_update.defined()) - << "Reduction update op is not defined"; + CHECK(this->intrin->reduce_update.defined()) << "Reduction update op is not defined"; // Need init and update steps CHECK_NE(this->reduce_axis.size(), 0U); - std::vector > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector > update_nest( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); + std::vector > common(n.main_nest.begin(), + n.main_nest.begin() + n.num_common_loop + 1); + std::vector > update_nest(n.main_nest.begin() + n.num_common_loop + 1, + n.main_nest.begin() + tloc + 1); update_nest.emplace_back(MakeIfNest(n.main_predicates)); if (this->intrin->reduce_init.defined()) { // init nest - std::vector > init_nest( - n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + std::vector > init_nest(n.init_nest.begin(), + n.init_nest.begin() + tloc + 1); init_nest.emplace_back(MakeIfNest(n.init_predicates)); Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init); init = te::Substitute(init, n.init_vmap); @@ -256,11 +238,9 @@ Stmt TensorComputeOpNode::BuildProvide( return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. - CHECK(this->intrin->body.defined()) - << "Normal body op is not defined"; - Stmt update = TransformUpdate(stage, dom_map, n, - this->intrin->body, - this->intrin->reduce_update); + CHECK(this->intrin->body.defined()) << "Normal body op is not defined"; + Stmt update = + TransformUpdate(stage, dom_map, n, this->intrin->body, this->intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); update = tir::Substitute(update, vmap); diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 31d4b36..f322e12 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -21,14 +21,14 @@ * \brief Logics related to tensorize, used by ComputeOpNode. * \file tensorize.cc */ +#include +#include #include #include -#include -#include -#include "op_util.h" -#include "compute_op.h" #include "../schedule/message_passing.h" +#include "compute_op.h" +#include "op_util.h" namespace tvm { namespace te { @@ -39,12 +39,10 @@ using namespace tir; // out_dom: the domain of root iter vars in output op // in_region: region of each input tensor. // return The location of the tensorized scope start. -size_t InferTensorizeRegion( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - std::unordered_map* out_dom, - std::unordered_map >* in_region) { +size_t InferTensorizeRegion(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + std::unordered_map* out_dom, + std::unordered_map >* in_region) { // Get the bound of the tensorized scope. bool found_point = false; size_t loc_scope = 0; @@ -52,8 +50,7 @@ size_t InferTensorizeRegion( // Loop over the leafs for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) { IterVar iv = stage->leaf_iter_vars[i - 1]; - CHECK(iv->iter_type == kDataPar || - iv->iter_type == kCommReduce); + CHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce); auto vit = dom_map.find(iv); CHECK(vit != dom_map.end()); const Range& vrange = vit->second; @@ -69,8 +66,7 @@ size_t InferTensorizeRegion( if (iit != stage->iter_var_attrs.end()) { const IterVarAttr& attr = (*iit).second; if (!found_point) { - CHECK(!attr->bind_thread.defined()) - << "Do not allow thread in tensorize scope"; + CHECK(!attr->bind_thread.defined()) << "Do not allow thread in tensorize scope"; } if (attr->iter_type == kTensorized) { CHECK(!found_point) << "Do not allow two tensorized point"; @@ -113,18 +109,15 @@ size_t InferTensorizeRegion( return loc_scope; } -void VerifyTensorizeLoopNest(const ComputeOpNode* self, - const Stage& stage, - const ComputeLoopNest& n, - size_t tloc) { +void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage, + const ComputeLoopNest& n, size_t tloc) { // Veirfication step. std::unordered_set banned; CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1); - CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || - n.init_nest.size() == 0); + CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || n.init_nest.size() == 0); auto f_push_banned = [&banned](const Stmt& s) { if (const ForNode* op = s.as()) { - banned.insert(op->loop_var.get()); + banned.insert(op->loop_var.get()); } else if (const AttrStmtNode* op = s.as()) { if (const IterVarNode* iv = op->node.as()) { banned.insert(iv->var.get()); @@ -144,20 +137,18 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, } } - auto fbanned = [&](const VarNode* node) { - return banned.count(node); - }; + auto fbanned = [&](const VarNode* node) { return banned.count(node); }; for (const PrimExpr& pred : n.main_predicates) { if (tir::ExprUseVar(pred, fbanned)) { - LOG(FATAL) << "Tensorize failed, split condition " - << pred << " relies on var defined inside tensorize scope"; + LOG(FATAL) << "Tensorize failed, split condition " << pred + << " relies on var defined inside tensorize scope"; } } for (const PrimExpr& pred : n.init_predicates) { if (tir::ExprUseVar(pred, fbanned)) { - LOG(FATAL) << "Tensorize failed, split condition " - << pred << " relies on var defined inside tensorize scope"; + LOG(FATAL) << "Tensorize failed, split condition " << pred + << " relies on var defined inside tensorize scope"; } } } @@ -178,9 +169,8 @@ class TensorIntrinMatcher final : public StmtExprMutator { for (size_t i = e.start; i < e.region.size(); ++i) { args.push_back(op->args[i] - e.region[i]->min); } - return CallNode::make( - op->dtype, e.tensor->op->name, args, - op->call_type, e.tensor->op, e.tensor->value_index); + return CallNode::make(op->dtype, e.tensor->op->name, args, op->call_type, e.tensor->op, + e.tensor->value_index); } } return expr; @@ -205,16 +195,13 @@ class TensorIntrinMatcher final : public StmtExprMutator { axis.push_back(it->second); } } - return ReduceNode::make( - op->combiner, op->source, axis, op->condition, op->value_index); + return ReduceNode::make(op->combiner, op->source, axis, op->condition, op->value_index); } - void Init(const ComputeOpNode* self, - const Stage& stage, + void Init(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, const std::unordered_map& out_dom, - const std::unordered_map >& in_region, - const TensorIntrin& intrin, + const std::unordered_map >& in_region, const TensorIntrin& intrin, Map* compute_intrin_iter_space) { CHECK(self == stage->op.get()); @@ -243,8 +230,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { CHECK(is_one(canonical_extent)) << "Tensorize " << intrin->name << ":" << " Input dimension mismatch with tensor intrin " - << " expected shape=" << e.tensor->shape - << ", given region=" << e.region; + << " expected shape=" << e.tensor->shape << ", given region=" << e.region; } in_remap_[inputs[i]] = e; } @@ -257,10 +243,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { size_t axis_start = self->axis.size() - intrin_compute->axis.size(); for (size_t i = 0; i < axis_start; ++i) { Range r = out_dom.at(self->axis[i]); - CHECK(is_one(r->extent)) - << "Tensorize: Output mismatch with tensor intrin " - << " intrin-dim=" << intrin_compute->axis.size() - << ", tensorize-dim=" << self->axis.size(); + CHECK(is_one(r->extent)) << "Tensorize: Output mismatch with tensor intrin " + << " intrin-dim=" << intrin_compute->axis.size() + << ", tensorize-dim=" << self->axis.size(); var_remap_[self->axis[i]->var.get()] = r->min; } // Assume we tensorize at regin axis i [min, min + extent) @@ -280,10 +265,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size(); for (size_t i = 0; i < axis_start; ++i) { Range r = out_dom.at(self->reduce_axis[i]); - CHECK(is_one(r->extent)) - << "Tensorize: Reduction mismatch with tensor intrin " - << " intrin-dim=" << intrin_compute->reduce_axis.size() - << ", tensorize-dim=" << self->reduce_axis.size(); + CHECK(is_one(r->extent)) << "Tensorize: Reduction mismatch with tensor intrin " + << " intrin-dim=" << intrin_compute->reduce_axis.size() + << ", tensorize-dim=" << self->reduce_axis.size(); var_remap_[self->reduce_axis[i]->var.get()] = r->min; } for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) { @@ -314,14 +298,12 @@ class TensorIntrinMatcher final : public StmtExprMutator { }; // Try to match tensor dataflow of the stage with the intrinsic -Array MatchTensorizeBody( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - const std::unordered_map& out_dom, - const std::unordered_map >& in_region, - const TensorIntrin& intrin, - Map* compute_intrin_iter_space) { +Array MatchTensorizeBody(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + const std::unordered_map& out_dom, + const std::unordered_map >& in_region, + const TensorIntrin& intrin, + Map* compute_intrin_iter_space) { TensorIntrinMatcher matcher; matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space); Array ret; @@ -331,21 +313,18 @@ Array MatchTensorizeBody( return ret; } -void VerifyTensorizeBody( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - const std::unordered_map& out_dom, - const std::unordered_map >& in_region, - const TensorIntrin& intrin) { +void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + const std::unordered_map& out_dom, + const std::unordered_map >& in_region, + const TensorIntrin& intrin) { StructuralEqual expr_equal; Map compute_intrin_iter_space; Array body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin, - &compute_intrin_iter_space); + &compute_intrin_iter_space); const ComputeOpNode* intrin_compute = intrin->op.as(); CHECK(intrin_compute) << "Only support compute intrinsic for now"; - CHECK_EQ(body.size(), intrin_compute->body.size()) - << "Tensorize failed: body size mismatch"; + CHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch"; arith::Analyzer ana; ana.Bind(compute_intrin_iter_space); @@ -353,29 +332,23 @@ void VerifyTensorizeBody( PrimExpr lhs = ana.Simplify(body[i]); PrimExpr rhs = ana.Simplify(intrin_compute->body[i]); if (lhs.dtype() != rhs.dtype()) { - LOG(FATAL) - << "Failed to match the data type with TensorIntrin " - << intrin->name << "'s declaration " - << " provided=" << lhs.dtype() - << ", intrin=" << rhs.dtype(); + LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name + << "'s declaration " + << " provided=" << lhs.dtype() << ", intrin=" << rhs.dtype(); } - CHECK(expr_equal(lhs, rhs)) - << "Failed to match the compute with TensorIntrin " - << intrin->name << "'s declaration " - << " provided= " << lhs - << ", intrin= " << rhs; + CHECK(expr_equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " << intrin->name + << "'s declaration " + << " provided= " << lhs << ", intrin= " << rhs; } } -Stmt MakeTensorize(const ComputeOpNode* self, - const Stage& stage, +Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { std::unordered_map out_dom; std::unordered_map > in_region; size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); - TensorIntrin intrin = stage->iter_var_attrs.at( - stage->leaf_iter_vars[tloc])->tensor_intrin; + TensorIntrin intrin = stage->iter_var_attrs.at(stage->leaf_iter_vars[tloc])->tensor_intrin; CHECK(intrin.defined()); ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop); VerifyTensorizeLoopNest(self, stage, n, tloc); @@ -384,8 +357,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, Stmt nop = EvaluateNode::make(0); std::vector input_bind_nest, output_bind_nest; Array inputs = self->InputTensors(); - CHECK_EQ(inputs.size(), intrin->inputs.size()) - << "Tensorize failed: input size mismatch "; + CHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch "; // input binding for (size_t i = 0; i < intrin->inputs.size(); ++i) { Tensor tensor = inputs[i]; @@ -401,9 +373,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, } input_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + nop)); } // output binding const ComputeOpNode* intrin_compute = intrin->op.as(); @@ -423,9 +394,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, Array bind_spec{buffer, tensor}; output_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + nop)); } // Check variable remap std::unordered_map vmap; @@ -437,8 +407,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, IterVar iv = self->reduce_axis[i]; auto it = out_dom.find(iv); CHECK(it != out_dom.end()); - CHECK(is_one(it->second->extent)) - << "Tensorization fail: reduction axis size do not match"; + CHECK(is_one(it->second->extent)) << "Tensorization fail: reduction axis size do not match"; } for (size_t i = start; i < self->reduce_axis.size(); ++i) { IterVar iv = self->reduce_axis[i]; @@ -447,17 +416,14 @@ Stmt MakeTensorize(const ComputeOpNode* self, CHECK(it != out_dom.end()); binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0), "tensir_intrin.reduction.min"); - binder.Bind(target->dom->extent, it->second->extent, - "tensir_intrin.reduction.extent"); + binder.Bind(target->dom->extent, it->second->extent, "tensir_intrin.reduction.extent"); } if (tloc <= n.num_common_loop) { // Do no need to split reduction - std::vector > nest( - n.main_nest.begin(), n.main_nest.begin() + tloc + 1); + std::vector > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); nest.emplace_back(MakeIfNest(n.main_predicates)); CHECK_EQ(n.init_predicates.size(), 0U); - CHECK(intrin->body.defined()) - << "Normal store op for intrin " << intrin << " is not defined"; + CHECK(intrin->body.defined()) << "Normal store op for intrin " << intrin << " is not defined"; Stmt body = MergeNest(output_bind_nest, intrin->body); body = MergeNest(input_bind_nest, body); body = tir::Substitute(body, vmap); @@ -470,16 +436,16 @@ Stmt MakeTensorize(const ComputeOpNode* self, << "Reduction update op for intrin " << intrin << " is not defined"; // Need init and update steps CHECK_NE(self->reduce_axis.size(), 0U); - std::vector > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector > update_nest( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); + std::vector > common(n.main_nest.begin(), + n.main_nest.begin() + n.num_common_loop + 1); + std::vector > update_nest(n.main_nest.begin() + n.num_common_loop + 1, + n.main_nest.begin() + tloc + 1); update_nest.emplace_back(MakeIfNest(n.main_predicates)); if (intrin->reduce_init.defined()) { // init nest - std::vector > init_nest( - n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + std::vector > init_nest(n.init_nest.begin(), + n.init_nest.begin() + tloc + 1); init_nest.emplace_back(MakeIfNest(n.init_predicates)); Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); init = te::Substitute(init, n.init_vmap); @@ -494,11 +460,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. - CHECK(intrin->body.defined()) - << "Normal body op for intrin " << intrin << " is not defined"; - Stmt update = TransformUpdate(stage, dom_map, n, - intrin->body, - intrin->reduce_update); + CHECK(intrin->body.defined()) << "Normal body op for intrin " << intrin << " is not defined"; + Stmt update = TransformUpdate(stage, dom_map, n, intrin->body, intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); update = tir::Substitute(update, vmap); @@ -511,36 +474,26 @@ Stmt MakeTensorize(const ComputeOpNode* self, } // Register functions for unittests -TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Stage stage = args[0]; - Map dmap = args[1]; - std::unordered_map out_dom; - std::unordered_map > in_region; - CHECK(stage->op.as()); - InferTensorizeRegion(stage->op.as(), - stage, - as_unordered_map(dmap), - &out_dom, &in_region); - *ret = Array{Map(out_dom), - Map >(in_region)}; - }); +TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion").set_body([](TVMArgs args, TVMRetValue* ret) { + Stage stage = args[0]; + Map dmap = args[1]; + std::unordered_map out_dom; + std::unordered_map > in_region; + CHECK(stage->op.as()); + InferTensorizeRegion(stage->op.as(), stage, as_unordered_map(dmap), &out_dom, + &in_region); + *ret = Array{Map(out_dom), Map >(in_region)}; +}); -TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Stage stage = args[0]; - Map out_dom = args[1]; - Map > in_region = args[2]; - TensorIntrin intrin = args[3]; - Map vrange; - CHECK(stage->op.as()); - *ret = MatchTensorizeBody(stage->op.as(), - stage, - {{}}, - as_unordered_map(out_dom), - as_unordered_map(in_region), - intrin, - &vrange); - }); +TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody").set_body([](TVMArgs args, TVMRetValue* ret) { + Stage stage = args[0]; + Map out_dom = args[1]; + Map > in_region = args[2]; + TensorIntrin intrin = args[3]; + Map vrange; + CHECK(stage->op.as()); + *ret = MatchTensorizeBody(stage->op.as(), stage, {{}}, as_unordered_map(out_dom), + as_unordered_map(in_region), intrin, &vrange); +}); } // namespace te } // namespace tvm diff --git a/src/te/schedule/auto_inline_elem_wise.cc b/src/te/schedule/auto_inline_elem_wise.cc index 6d79f4a..e2b7215 100644 --- a/src/te/schedule/auto_inline_elem_wise.cc +++ b/src/te/schedule/auto_inline_elem_wise.cc @@ -21,8 +21,8 @@ * \file auto_inline_elem_wise.cc */ #include -#include #include +#include #include namespace tvm { @@ -61,7 +61,6 @@ class ElemWiseDetector : public tir::ExprVisitor { Array axis_; }; - bool IsElemWise(const Operation& op) { if (const ComputeOpNode* compute = op.as()) { ElemWiseDetector v = ElemWiseDetector(compute->axis); @@ -112,12 +111,9 @@ void AutoInlineInjective(Schedule sch) { } } -TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise") -.set_body_typed(AutoInlineElemWise); - +TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise").set_body_typed(AutoInlineElemWise); -TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective") -.set_body_typed(AutoInlineInjective); +TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective").set_body_typed(AutoInlineInjective); } // namespace te } // namespace tvm diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 552d7b7..01d4f93 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -22,13 +22,15 @@ * \brief The bound inference logic. */ #include -#include #include +#include + #include #include + +#include "../../runtime/thread_storage_scope.h" #include "graph.h" #include "message_passing.h" -#include "../../runtime/thread_storage_scope.h" namespace tvm { namespace te { @@ -49,13 +51,11 @@ struct GraphContext { std::unordered_map op2stage_; }; -bool NeedRelax(const IterVar& iv, - bool found_attach, +bool NeedRelax(const IterVar& iv, bool found_attach, const std::unordered_map& bind_map, const runtime::StorageScope& scope) { auto it = bind_map.find(iv); - const std::string& tag = ( - it != bind_map.end() ? it->second->thread_tag : iv->thread_tag); + const std::string& tag = (it != bind_map.end() ? it->second->thread_tag : iv->thread_tag); if (tag.length() == 0 || tag == "pipeline") { return !found_attach; } @@ -63,25 +63,21 @@ bool NeedRelax(const IterVar& iv, // When there is warp memory // threadIdx.x must be set to be warp index. - if (scope.rank == StorageRank::kWarp && - ts.rank == 1 && - ts.dim_index == 0) { + if (scope.rank == StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) { return true; } return static_cast(scope.rank) <= ts.rank; } // infer storage scope, if not given -StorageScope InferStorageScope( - const Stage& stage, const GraphContext& ctx) { +StorageScope InferStorageScope(const Stage& stage, const GraphContext& ctx) { if (stage->scope.length() != 0) { return StorageScope::make(stage->scope); } int max_rank = -1; for (IterVar iv : ctx.attach_path.at(stage->op)) { auto it = ctx.bind_map.find(iv); - const std::string& tag = ( - it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag); + const std::string& tag = (it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag); if (tag != "pipeline" && tag.length() != 0) { max_rank = std::max(max_rank, ThreadScope::make(tag).rank); } @@ -91,20 +87,16 @@ StorageScope InferStorageScope( return s; } - -void InferRootBound(const Stage& stage, - const GraphContext& ctx, +void InferRootBound(const Stage& stage, const GraphContext& ctx, std::unordered_map* rmap) { - CHECK_NE(stage->attach_type, kInline) - << "call schedule.normalize before scheduleops"; + CHECK_NE(stage->attach_type, kInline) << "call schedule.normalize before scheduleops"; if (stage->attach_type == kInlinedAlready) return; if (stage->is_output) { // verify correctness. - CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot) - << "Output must be attached at root"; + CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot) << "Output must be attached at root"; } if (stage->is_output || stage->op.as()) { - for (auto iv : stage->op->root_iter_vars()) { + for (auto iv : stage->op->root_iter_vars()) { CHECK(iv->dom.defined()); CHECK(!rmap->count(iv)); (*rmap)[iv] = iv->dom; @@ -154,9 +146,8 @@ void InferRootBound(const Stage& stage, if (is_one(vrange->extent)) { up_state[iv] = IntSet::single_point(vrange->min); } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) { - CHECK(is_zero(vrange->min)) - << "InferBound requires every leaf iter var's min equals 0, " - << " call schedule.normalize to achieve this. "; + CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, " + << " call schedule.normalize to achieve this. "; if (ctx.bind_map.count(iv)) { up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var); } else { @@ -172,9 +163,8 @@ void InferRootBound(const Stage& stage, found_attach = true; } Range vrange = rmap->at(iv); - CHECK(is_zero(vrange->min)) - << "InferBound requires every leaf iter var's min equals 0, " - << "call schedule.normalize to achieve this."; + CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, " + << "call schedule.normalize to achieve this."; if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) { relax_set.Set(iv->var, IntSet::range(vrange)); if (ctx.bind_map.count(iv)) { @@ -201,9 +191,9 @@ void InferRootBound(const Stage& stage, r = iv->dom; } if (relax_set.size() != 0) { - dom_map[iv->var.get()] = IntSet::interval( - analyzer.int_set(r->min, relax_set).min(), - analyzer.int_set(r->min + r->extent - 1, relax_set).max()); + dom_map[iv->var.get()] = + IntSet::interval(analyzer.int_set(r->min, relax_set).min(), + analyzer.int_set(r->min + r->extent - 1, relax_set).max()); } else { dom_map[iv->var.get()] = IntSet::range(r); } @@ -257,15 +247,13 @@ Map InferBound(const Schedule& sch) { } } for (auto& p : ret) { - ret[p.first] = Range::make_by_min_extent( - analyzer.Simplify(p.second->min), - analyzer.Simplify(p.second->extent)); + ret[p.first] = Range::make_by_min_extent(analyzer.Simplify(p.second->min), + analyzer.Simplify(p.second->extent)); } return Map(ret.begin(), ret.end()); } -TVM_REGISTER_GLOBAL("schedule.InferBound") -.set_body_typed(InferBound); +TVM_REGISTER_GLOBAL("schedule.InferBound").set_body_typed(InferBound); } // namespace te } // namespace tvm diff --git a/src/te/schedule/graph.cc b/src/te/schedule/graph.cc index 9dce36f..6414822 100644 --- a/src/te/schedule/graph.cc +++ b/src/te/schedule/graph.cc @@ -21,14 +21,16 @@ * \file graph.cc * \brief Utilities to get information about schedule graph. */ +#include "graph.h" + #include +#include #include #include -#include -#include -#include + #include -#include "graph.h" +#include +#include namespace tvm { namespace te { @@ -39,22 +41,14 @@ struct TensorDimKey { int dim; TensorDimKey() {} TensorDimKey(const tir::CallNode* op, int dim) - : f(op->func), value_index(op->value_index), dim(dim) { - } - TensorDimKey(const Tensor& t, int dim) - : f(t->op), value_index(t->value_index), dim(dim) { - } + : f(op->func), value_index(op->value_index), dim(dim) {} + TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {} TensorDimKey(const Tensor& t, size_t dim) - : f(t->op), value_index(t->value_index), dim(static_cast(dim)) { - } + : f(t->op), value_index(t->value_index), dim(static_cast(dim)) {} inline bool operator==(const TensorDimKey& other) const { - return f == other.f && - value_index == other.value_index && - dim == other.dim; - } - inline bool operator!=(const TensorDimKey& other) const { - return !operator==(other); + return f == other.f && value_index == other.value_index && dim == other.dim; } + inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); } }; } // namespace te } // namespace tvm @@ -64,15 +58,13 @@ template <> struct hash<::tvm::te::TensorDimKey> { std::size_t operator()(const ::tvm::te::TensorDimKey& k) const { size_t lhs = ::tvm::ObjectHash()(k.f); - size_t rhs = static_cast(k.value_index) << 16UL | - static_cast(k.dim); + size_t rhs = static_cast(k.value_index) << 16UL | static_cast(k.dim); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); return lhs; } }; } // namespace std - namespace tvm { namespace te { @@ -105,12 +97,9 @@ ReadGraph CreateReadGraph(const Array& roots) { // Do DFS visit to get the subgraph. // Return if op is inside the subgraph. -bool GetSubGraphByPostDFS_( - const Operation& op, - const std::unordered_set& boundary, - bool include_bounary, - std::unordered_map* visited, - Array* result) { +bool GetSubGraphByPostDFS_(const Operation& op, const std::unordered_set& boundary, + bool include_bounary, std::unordered_map* visited, + Array* result) { if (visited->count(op.get())) { return visited->at(op.get()); } @@ -127,9 +116,7 @@ bool GetSubGraphByPostDFS_( // check if we can reach boundary. bool reach_boundary = false; for (Tensor t : op->InputTensors()) { - if (GetSubGraphByPostDFS_(t->op, boundary, - include_bounary, - visited, result)) { + if (GetSubGraphByPostDFS_(t->op, boundary, include_bounary, visited, result)) { reach_boundary = true; } } @@ -140,8 +127,7 @@ bool GetSubGraphByPostDFS_( return reach_boundary; } -Array GetSubGraph(const Array& outputs, - const Array& inputs, +Array GetSubGraph(const Array& outputs, const Array& inputs, bool include_inputs) { Array result; std::unordered_set boundary; @@ -150,16 +136,12 @@ Array GetSubGraph(const Array& outputs, } std::unordered_map visited; for (Tensor t : outputs) { - GetSubGraphByPostDFS_(t->op, boundary, include_inputs, - &visited, &result); + GetSubGraphByPostDFS_(t->op, boundary, include_inputs, &visited, &result); } return result; } - -void PostDFSOrder(const Operation& op, - const ReadGraph& g, - std::unordered_set* visited, +void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set* visited, Array* post_order) { if (visited->count(op)) return; visited->insert(op); @@ -169,9 +151,7 @@ void PostDFSOrder(const Operation& op, post_order->push_back(op); } -Array PostDFSOrder( - const Array& roots, - const ReadGraph& g) { +Array PostDFSOrder(const Array& roots, const ReadGraph& g) { std::unordered_set visited; Array post_order; for (Operation op : roots) { @@ -196,8 +176,7 @@ AttachPath CreateAttachPath(Schedule sch) { std::unordered_set visited; Array path; for (Stage s = stage; s.defined();) { - CHECK(!visited.count(s.get())) - << "Find loop in compute_at attach group"; + CHECK(!visited.count(s.get())) << "Find loop in compute_at attach group"; visited.insert(s.get()); Stage spec = s.GetAttachSpec(); bool start_attach; @@ -221,9 +200,8 @@ AttachPath CreateAttachPath(Schedule sch) { } if (start_attach) path.push_back(iv); } - CHECK(start_attach) - << "Invalid Schedule: cannot find attach point " << attach_ivar - << " in the schedule of " << s->op; + CHECK(start_attach) << "Invalid Schedule: cannot find attach point " << attach_ivar + << " in the schedule of " << s->op; } if (!ret.count(stage->op)) { ret.Set(stage->op, path); @@ -233,7 +211,7 @@ AttachPath CreateAttachPath(Schedule sch) { } // graph of push reach relation of tensor dimensions -using ReachGraph = std::unordered_map >; +using ReachGraph = std::unordered_map>; ReachGraph GetReachGraph(const Array& ops) { ReachGraph reach; @@ -249,10 +227,8 @@ ReachGraph GetReachGraph(const Array& ops) { for (size_t i = 0; i < update.size(); ++i) { Tensor t = op.output(i); for (int k = 1; k < static_cast(update[i]->shape.size()); ++k) { - reach[TensorDimKey(t, k)].emplace_back( - TensorDimKey(update[i], k)); - reach[TensorDimKey(t, k)].emplace_back( - TensorDimKey(init[i], k)); + reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(update[i], k)); + reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(init[i], k)); } } } else if (const auto* compute_op = op.as()) { @@ -264,13 +240,13 @@ ReachGraph GetReachGraph(const Array& ops) { reach[TensorDimKey(t, i)] = {}; } auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) { - const tir::CallNode *call = n.as(); + const tir::CallNode* call = n.as(); if (call != nullptr && call->func.defined()) { if (!bset.count(call->func.get())) return; for (size_t i = 0; i < call->args.size(); ++i) { TensorDimKey dkey(call, static_cast(i)); auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) { - const VarNode *v = node.as(); + const VarNode* v = node.as(); auto it = vmap.find(v); if (it != vmap.end()) { reach[it->second].push_back(dkey); @@ -315,8 +291,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } // merge exact reach - auto f_merge_key = [&exact_reach, &fail_set]( - const TensorDimKey& dst, const TensorDimKey& src) { + auto f_merge_key = [&exact_reach, &fail_set](const TensorDimKey& dst, const TensorDimKey& src) { auto sit = exact_reach.find(src); if (sit == exact_reach.end()) return; auto dit = exact_reach.find(dst); @@ -343,7 +318,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } } else if (const auto* compute_op = op.as()) { - std::unordered_map > vmap; + std::unordered_map> vmap; const auto& axis = compute_op->axis; for (size_t i = 0; i < axis.size(); ++i) { std::vector keys; @@ -352,9 +327,8 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } vmap[axis[i]->var.get()] = std::move(keys); } - auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set]( - const ObjectRef& n) { - const tir::CallNode *call = n.as(); + auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](const ObjectRef& n) { + const tir::CallNode* call = n.as(); if (call != nullptr && call->func.defined()) { for (size_t i = 0; i < call->args.size(); ++i) { auto it = vmap.find(call->args[i].get()); @@ -391,8 +365,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { TensorDimKey key(scan->update[i], k); TensorDimKey target(scan->state_placeholder[i], k); IterVar sp_iv = scan->spatial_axis_[sp_idx]; - if (fail_set.count(sp_iv.get()) || - !exact_reach.count(key) || + if (fail_set.count(sp_iv.get()) || !exact_reach.count(key) || exact_reach.at(key) != sp_iv.get()) { ret.Set(sp_iv, make_const(DataType::Int(32), 0)); } else { @@ -430,24 +403,18 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { return ret; } - -TVM_REGISTER_GLOBAL("schedule.CreateReadGraph") -.set_body_typed(CreateReadGraph); +TVM_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph); TVM_REGISTER_GLOBAL("schedule.PostDFSOrder") -.set_body_typed([](const Array& roots, - const ReadGraph& g) { - return PostDFSOrder(roots, g); -}); + .set_body_typed([](const Array& roots, const ReadGraph& g) { + return PostDFSOrder(roots, g); + }); -TVM_REGISTER_GLOBAL("schedule.CreateAttachPath") -.set_body_typed(CreateAttachPath); +TVM_REGISTER_GLOBAL("schedule.CreateAttachPath").set_body_typed(CreateAttachPath); -TVM_REGISTER_GLOBAL("schedule.ScanGetBody") -.set_body_typed(ScanGetBody); +TVM_REGISTER_GLOBAL("schedule.ScanGetBody").set_body_typed(ScanGetBody); -TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis") -.set_body_typed(ScanFixPointAnalysis); +TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis").set_body_typed(ScanFixPointAnalysis); } // namespace te } // namespace tvm diff --git a/src/te/schedule/graph.h b/src/te/schedule/graph.h index c3478c7..bb98ff4 100644 --- a/src/te/schedule/graph.h +++ b/src/te/schedule/graph.h @@ -24,9 +24,10 @@ #ifndef TVM_TE_SCHEDULE_GRAPH_H_ #define TVM_TE_SCHEDULE_GRAPH_H_ -#include -#include #include +#include +#include + #include #include #include @@ -72,8 +73,7 @@ ReadGraph CreateReadGraph(const Array& roots); * * \return The subgraph. */ -Array GetSubGraph(const Array& outputs, - const Array& inputs, +Array GetSubGraph(const Array& outputs, const Array& inputs, bool include_inputs); /*! @@ -85,8 +85,7 @@ Array GetSubGraph(const Array& outputs, * \note PostDFSOrder is a special case of Topoligical order, * and can be used when topoligical order is needed. */ -Array PostDFSOrder( - const Array& roots, const ReadGraph& g); +Array PostDFSOrder(const Array& roots, const ReadGraph& g); /*! * \brief Create feedgraph for given Schedule diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 6ae7464..4f0e982 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -21,9 +21,11 @@ * \file message_passing.cc * \brief The message passing domain. */ +#include "message_passing.h" + #include #include -#include "message_passing.h" + #include "../../arith/compute_expr.h" namespace tvm { @@ -31,22 +33,18 @@ namespace te { using namespace tir; -void Update(std::unordered_map* p_state, - const IterVar& iv, - Range r, +void Update(std::unordered_map* p_state, const IterVar& iv, Range r, arith::Analyzer* analyzer) { auto it = p_state->find(iv); if (it == p_state->end()) { (*p_state)[iv] = r; analyzer->Bind(iv->var, r); } else { - bool match = is_zero(it->second->min) && - analyzer->CanProve(r->extent - it->second->extent == 0); - CHECK(match) - << iv - << " domain already inferred," - << " cannot prove their extents are the same " - << it->second->extent << " vs " << r->extent; + bool match = + is_zero(it->second->min) && analyzer->CanProve(r->extent - it->second->extent == 0); + CHECK(match) << iv << " domain already inferred," + << " cannot prove their extents are the same " << it->second->extent << " vs " + << r->extent; } } @@ -89,10 +87,8 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map* } } -void PassDownDomain(const Stage& stage, - std::unordered_map* p_state, - arith::Analyzer* actx, - bool allow_missing) { +void PassDownDomain(const Stage& stage, std::unordered_map* p_state, + arith::Analyzer* actx, bool allow_missing) { auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(indexmod(a, b) == 0)) { return actx->Simplify(indexdiv(a, b)); @@ -100,7 +96,7 @@ void PassDownDomain(const Stage& stage, return actx->Simplify(indexdiv(a + (b - 1), b)); }; - auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) { + auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(a < b)) { return actx->Simplify(a); } @@ -138,20 +134,16 @@ void PassDownDomain(const Stage& stage, }; if (r->factor.defined()) { Update(p_state, r->inner, - Range::make_by_min_extent( - 0, resolve_min_extent_for_split(r->inner, r->factor)), + Range::make_by_min_extent(0, resolve_min_extent_for_split(r->inner, r->factor)), actx); Update(p_state, r->outer, - Range::make_by_min_extent( - 0, ceil_div(range_parent->extent, r->factor)), actx); + Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->factor)), actx); } else { Update(p_state, r->outer, - Range::make_by_min_extent( - 0, resolve_min_extent_for_split(r->outer, r->nparts)), + Range::make_by_min_extent(0, resolve_min_extent_for_split(r->outer, r->nparts)), actx); Update(p_state, r->inner, - Range::make_by_min_extent( - 0, ceil_div(range_parent->extent, r->nparts)), actx); + Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->nparts)), actx); } } else if (const FuseNode* r = rel.as()) { if (!state.count(r->outer) || !state.count(r->inner)) { @@ -160,16 +152,13 @@ void PassDownDomain(const Stage& stage, } const Range& range_outer = state.at(r->outer); const Range& range_inner = state.at(r->inner); - state[r->fused] = Range::make_by_min_extent( - 0, range_outer->extent * range_inner->extent); + state[r->fused] = Range::make_by_min_extent(0, range_outer->extent * range_inner->extent); } else if (const RebaseNode* r = rel.as()) { if (!state.count(r->parent)) { CHECK(allow_missing); continue; } - Update(p_state, r->rebased, - Range::make_by_min_extent( - 0, state.at(r->parent)->extent), actx); + Update(p_state, r->rebased, Range::make_by_min_extent(0, state.at(r->parent)->extent), actx); } else if (const SingletonNode* s = rel.as()) { Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx); } else { @@ -185,10 +174,8 @@ void PassDownDomain(const Stage& stage, } } -void PassUpIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing) { +void PassUpIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { IterVarRelation rel = stage->relations[i - 1]; @@ -244,10 +231,8 @@ void PassUpIndex(const Stage& stage, } } -void PassDownIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing) { +void PassDownIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (IterVarRelation rel : stage->relations) { if (const SplitNode* s = rel.as()) { @@ -292,16 +277,10 @@ void PassDownIndex(const Stage& stage, } // Domain message passing. -void PassUpDomain(const SplitNode* s, - const std::unordered_map& dom_map, - const IntSet& outer, - const IntSet& inner, - IntSet* parent) { - if (dom_map.count(s->outer) && - dom_map.count(s->inner) && - dom_map.count(s->parent) && - outer.match_range(dom_map.at(s->outer)) && - inner.match_range(dom_map.at(s->inner))) { +void PassUpDomain(const SplitNode* s, const std::unordered_map& dom_map, + const IntSet& outer, const IntSet& inner, IntSet* parent) { + if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) && + outer.match_range(dom_map.at(s->outer)) && inner.match_range(dom_map.at(s->inner))) { *parent = IntSet::range(dom_map.at(s->parent)); return; } @@ -310,16 +289,12 @@ void PassUpDomain(const SplitNode* s, CHECK(outer.defined()); CHECK(inner.defined()); CHECK(factor.defined()); - *parent = arith::EvalSet( - s->outer->var * factor + s->inner->var + parent_min, - {{s->outer, outer}, {s->inner, inner}}); + *parent = arith::EvalSet(s->outer->var * factor + s->inner->var + parent_min, + {{s->outer, outer}, {s->inner, inner}}); } -void PassUpDomain(const FuseNode* s, - const std::unordered_map& dom_map, - const IntSet& fused, - IntSet* outer, - IntSet* inner) { +void PassUpDomain(const FuseNode* s, const std::unordered_map& dom_map, + const IntSet& fused, IntSet* outer, IntSet* inner) { CHECK(dom_map.count(s->outer)); CHECK(dom_map.count(s->inner)); CHECK(dom_map.count(s->fused)); @@ -336,8 +311,8 @@ void PassUpDomain(const FuseNode* s, if (fused.is_single_point()) { PrimExpr value = fused.point_value(); PrimExpr factor = dom_map.at(s->inner)->extent; - PrimExpr v_outer = indexdiv(value, factor); - PrimExpr v_inner = indexmod(value, factor); + PrimExpr v_outer = indexdiv(value, factor); + PrimExpr v_inner = indexmod(value, factor); if (!is_zero(outer_min)) v_outer = v_outer + outer_min; if (!is_zero(inner_min)) v_inner = v_inner + inner_min; *outer = IntSet::single_point(v_outer); @@ -345,9 +320,8 @@ void PassUpDomain(const FuseNode* s, } else { PrimExpr fused_extent = (fused.max() - fused.min() + 1); PrimExpr inner_extent = dom_map.at(s->inner)->extent; - *outer = IntSet::interval( - outer_min + indexdiv(fused.min(), inner_extent), - outer_min + indexdiv(fused.max(), inner_extent)); + *outer = IntSet::interval(outer_min + indexdiv(fused.min(), inner_extent), + outer_min + indexdiv(fused.max(), inner_extent)); if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) && is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) { // fused never spans multiple rows, make a tight bounding box @@ -357,8 +331,8 @@ void PassUpDomain(const FuseNode* s, } else { // fused may span multiple rows, use full row widths if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) || !is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) { - LOG(WARNING) << - "fused and original axes are not aligned, this may cause redundant computations"; + LOG(WARNING) + << "fused and original axes are not aligned, this may cause redundant computations"; } *inner = IntSet::range(dom_map.at(s->inner)); } @@ -366,44 +340,34 @@ void PassUpDomain(const FuseNode* s, } } -void PassUpDomain(const RebaseNode* s, - const std::unordered_map& dom_map, - const IntSet& rebased, - IntSet* parent) { +void PassUpDomain(const RebaseNode* s, const std::unordered_map& dom_map, + const IntSet& rebased, IntSet* parent) { CHECK(dom_map.count(s->parent)); if (rebased.match_range(dom_map.at(s->rebased))) { *parent = IntSet::range(dom_map.at(s->parent)); return; } PrimExpr parent_min = dom_map.at(s->parent)->min; - *parent = arith::EvalSet(s->rebased->var + parent_min, - {{s->rebased, rebased}}); + *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}}); } -void PassUpDomain(const Stage& stage, - const std::unordered_map& dom_map, +void PassUpDomain(const Stage& stage, const std::unordered_map& dom_map, std::unordered_map* p_state) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { IterVarRelation rel = stage->relations[i - 1]; if (const SplitNode* r = rel.as()) { IntSet parent; - PassUpDomain(r, dom_map, - state.at(r->outer), state.at(r->inner), - &parent); + PassUpDomain(r, dom_map, state.at(r->outer), state.at(r->inner), &parent); state[r->parent] = parent; } else if (const FuseNode* r = rel.as()) { IntSet outer, inner; - PassUpDomain(r, dom_map, - state.at(r->fused), - &outer, &inner); + PassUpDomain(r, dom_map, state.at(r->fused), &outer, &inner); state[r->outer] = outer; state[r->inner] = inner; } else if (const RebaseNode* r = rel.as()) { IntSet parent; - PassUpDomain(r, dom_map, - state.at(r->rebased), - &parent); + PassUpDomain(r, dom_map, state.at(r->rebased), &parent); state[r->parent] = parent; } else if (rel.as()) { } else { @@ -413,8 +377,7 @@ void PassUpDomain(const Stage& stage, } // Pass up bit mask with or relation. -void PassUpBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassUpBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { @@ -461,8 +424,7 @@ void PassUpBitMaskOr(const Stage& stage, } } -void PassDownBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassDownBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (IterVarRelation rel : stage->relations) { @@ -509,17 +471,14 @@ void PassDownBitMaskOr(const Stage& stage, } } - /*! * \brief message passing to find if boundary checking on IterVar is needed. * \param s The stage to be used. * \param p_state The message passing state * IterVar->flag */ -void PassUpBoundCheck(const Stage& s, - const Map& dom_map, - std::unordered_map* p_state, - arith::Analyzer* analyzer) { +void PassUpBoundCheck(const Stage& s, const Map& dom_map, + std::unordered_map* p_state, arith::Analyzer* analyzer) { auto& state = *p_state; for (size_t i = s->relations.size(); i != 0; --i) { IterVarRelation rel = s->relations[i - 1]; @@ -560,16 +519,14 @@ bool IsRangeSame(const Range input_1, const Range input_2) { arith::Analyzer analyzer; if (input_1.same_as(input_2)) return true; - return (analyzer.CanProve(input_1->min == input_2->min) - && analyzer.CanProve(input_1->extent == input_2->extent)); + return (analyzer.CanProve(input_1->min == input_2->min) && + analyzer.CanProve(input_1->extent == input_2->extent)); } -std::vector MakeBoundCheck( - const Stage& stage, - const Map& dom_map, - const std::unordered_map& value_map, - bool skip_ivar_domain, - const std::unordered_set& skip_iter) { +std::vector MakeBoundCheck(const Stage& stage, const Map& dom_map, + const std::unordered_map& value_map, + bool skip_ivar_domain, + const std::unordered_set& skip_iter) { arith::Analyzer analyzer; std::unordered_map bound_state; diff --git a/src/te/schedule/message_passing.h b/src/te/schedule/message_passing.h index 1877235..c382b90 100644 --- a/src/te/schedule/message_passing.h +++ b/src/te/schedule/message_passing.h @@ -25,10 +25,11 @@ #ifndef TVM_TE_SCHEDULE_MESSAGE_PASSING_H_ #define TVM_TE_SCHEDULE_MESSAGE_PASSING_H_ -#include -#include -#include #include +#include +#include +#include + #include #include #include @@ -45,11 +46,8 @@ namespace te { * \param analyzer Analyzer context, storing information about bounds in p_state. * \param allow_missing Whether allow missing value. */ -void PassDownDomain( - const Stage& stage, - std::unordered_map* p_state, - arith::Analyzer* analyzer, - bool allow_missing = false); +void PassDownDomain(const Stage& stage, std::unordered_map* p_state, + arith::Analyzer* analyzer, bool allow_missing = false); /*! * \param Upward inference of index of each IterVar. @@ -60,10 +58,8 @@ void PassDownDomain( * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassUpIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing = false); +void PassUpIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing = false); /*! * \param Downward inference of index of each IterVar. @@ -74,10 +70,8 @@ void PassUpIndex(const Stage& stage, * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassDownIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing = false); +void PassDownIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing = false); /*! * \param Upward inference of domain set of each IterVar. @@ -87,8 +81,7 @@ void PassDownIndex(const Stage& stage, * \param dom_map The domain map of each iteration variable's maximum domain. * \param p_state The index state of each IterVar. */ -void PassUpDomain(const Stage& stage, - const std::unordered_map& dom_map, +void PassUpDomain(const Stage& stage, const std::unordered_map& dom_map, std::unordered_map* p_state); /*! @@ -97,8 +90,7 @@ void PassUpDomain(const Stage& stage, * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassUpBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassUpBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing = false); /*! @@ -107,8 +99,7 @@ void PassUpBitMaskOr(const Stage& stage, * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassDownBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassDownBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing = false); /*! @@ -120,13 +111,10 @@ void PassDownBitMaskOr(const Stage& stage, * \param skip_iter The set of variables to skip bound condition. * \return List of predicates that we need to check. */ -std::vector -MakeBoundCheck( - const Stage& stage, - const Map& dom_map, - const std::unordered_map& value_map, - bool skip_ivar_domain, - const std::unordered_set& skip_iter); +std::vector MakeBoundCheck(const Stage& stage, const Map& dom_map, + const std::unordered_map& value_map, + bool skip_ivar_domain, + const std::unordered_set& skip_iter); } // namespace te } // namespace tvm diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc index c3f333e..8c8f092 100644 --- a/src/te/schedule/operation_inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -20,14 +20,16 @@ /*! * \file operation_inline.cc */ +#include "operation_inline.h" + +#include #include #include -#include #include + #include -#include "operation_inline.h" -#include "../../tir/transforms/ir_util.h" +#include "../../tir/transforms/ir_util.h" namespace tvm { namespace te { @@ -62,8 +64,7 @@ class OperationInliner final : public StmtExprMutator { for (size_t i = 0; i < args_.size(); ++i) { vmap.Set(args_[i], op->args[i]); } - expr = Substitute( - EvaluateNode::make(expr), vmap).as()->value; + expr = Substitute(EvaluateNode::make(expr), vmap).as()->value; } return expr; } else { @@ -77,12 +78,8 @@ class OperationInliner final : public StmtExprMutator { PrimExpr body_; }; -Stmt Inline(Stmt stmt, - Operation f, - Array args, - PrimExpr body) { - CHECK_EQ(f->num_outputs(), 1) - << "can only inline output single value operation"; +Stmt Inline(Stmt stmt, Operation f, Array args, PrimExpr body) { + CHECK_EQ(f->num_outputs(), 1) << "can only inline output single value operation"; Stmt ret = OperationInliner(f, args, body)(std::move(stmt)); if (ret.same_as(stmt)) return ret; return ConvertSSA(ret); diff --git a/src/te/schedule/operation_inline.h b/src/te/schedule/operation_inline.h index d7d55cc..d475fbe 100644 --- a/src/te/schedule/operation_inline.h +++ b/src/te/schedule/operation_inline.h @@ -22,10 +22,10 @@ #ifndef TVM_TE_SCHEDULE_OPERATION_INLINE_H_ #define TVM_TE_SCHEDULE_OPERATION_INLINE_H_ -#include -#include #include #include +#include +#include namespace tvm { namespace te { @@ -41,10 +41,7 @@ namespace te { * * \note All the passes in this file uses SSA form and outputs SSA form. */ -Stmt Inline(Stmt stmt, - Operation op, - Array args, - PrimExpr body); +Stmt Inline(Stmt stmt, Operation op, Array args, PrimExpr body); } // namespace te } // namespace tvm diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index f3e76a4..ed28806 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -20,20 +20,21 @@ /*! * \file schedule_dataflow_rewrite.cc */ -#include #include +#include #include + #include -#include "message_passing.h" -#include "operation_inline.h" -#include "../../tir/transforms/ir_util.h" #include "../../arith/compute_expr.h" +#include "../../tir/transforms/ir_util.h" +#include "message_passing.h" +#include "operation_inline.h" namespace tvm { namespace te { // find first occurance location in leaf -template +template size_t FindNodeRef(ArrayNode* array_node, const T& v) { const Object* n = v.get(); for (size_t i = 0; i < array_node->data.size(); ++i) { @@ -45,9 +46,7 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) { // The replacer of cache. class VarReplacer : public tir::StmtExprMutator { public: - explicit VarReplacer( - const std::unordered_map& vsub) - : vsub_(vsub) {} + explicit VarReplacer(const std::unordered_map& vsub) : vsub_(vsub) {} PrimExpr VisitExpr_(const VarNode* op) final { auto it = vsub_.find(op); if (it != vsub_.end()) return it->second; @@ -56,19 +55,16 @@ class VarReplacer : public tir::StmtExprMutator { tir::CommReducer MutateCommReducer(tir::CommReducer combiner) { // Replace free variables in combiner - auto new_identity = tir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) { - return this->VisitExpr(e); - }); - auto new_result = tir::UpdateArray(combiner->result, [this] (const PrimExpr& e) { - return this->VisitExpr(e); - }); + auto new_identity = tir::UpdateArray(combiner->identity_element, + [this](const PrimExpr& e) { return this->VisitExpr(e); }); + auto new_result = tir::UpdateArray(combiner->result, + [this](const PrimExpr& e) { return this->VisitExpr(e); }); if (combiner->identity_element.same_as(new_identity) && combiner->identity_element.same_as(new_result)) { return combiner; } else { - return tir::CommReducerNode::make( - combiner->lhs, combiner->rhs, new_result, new_identity); + return tir::CommReducerNode::make(combiner->lhs, combiner->rhs, new_result, new_identity); } } @@ -79,12 +75,8 @@ class VarReplacer : public tir::StmtExprMutator { if (op->combiner.same_as(new_combiner)) { return new_e; } else { - return tir::ReduceNode::make( - new_combiner, - new_reduce->source, - new_reduce->axis, - new_reduce->condition, - new_reduce->value_index); + return tir::ReduceNode::make(new_combiner, new_reduce->source, new_reduce->axis, + new_reduce->condition, new_reduce->value_index); } } @@ -92,8 +84,7 @@ class VarReplacer : public tir::StmtExprMutator { const std::unordered_map& vsub_; }; -PrimExpr InjectPredicate(const Array& predicates, - PrimExpr body) { +PrimExpr InjectPredicate(const Array& predicates, PrimExpr body) { using tir::ReduceNode; using tir::SelectNode; if (predicates.size() == 0) return body; @@ -103,16 +94,14 @@ PrimExpr InjectPredicate(const Array& predicates, n->condition = n->condition && arith::ComputeReduce(predicates, PrimExpr()); return PrimExpr(n); } - return SelectNode::make(arith::ComputeReduce(predicates, PrimExpr()), - body, - make_zero(body.dtype())); + return SelectNode::make(arith::ComputeReduce(predicates, PrimExpr()), body, + make_zero(body.dtype())); } // Replace data flow appears in all stages given the tensor change. // Also update vmap if subsequent dataflow need to be replaced. // Need to keep an update to the date transitive closure property on the vmap by a reverse map. -void ReplaceDataFlow(const Array& stages, - std::unordered_map* vmap, +void ReplaceDataFlow(const Array& stages, std::unordered_map* vmap, std::unordered_map* rvmap) { for (Stage s : stages) { Operation op = s->op->ReplaceInputs(s->op, *vmap); @@ -132,14 +121,11 @@ void ReplaceDataFlow(const Array& stages, } inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - return (a->combiner.same_as(b->combiner)) && - (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && - (a->condition.same_as(b->condition)); + return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)); } -Tensor Schedule::cache_read(const Tensor& tensor, - const std::string& scope, +Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, const Array& readers) { (*this)->InvalidateCache(); // create identity mapping. @@ -153,9 +139,12 @@ Tensor Schedule::cache_read(const Tensor& tensor, std::unordered_map vsub; Stage s = operator[](tensor->op); Tensor sugar_tensor = s->op.output(tensor->value_index); - Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array& i) { - return sugar_tensor(Array(i.begin(), i.end())); - }, os.str()); + Tensor cache = compute( + sugar_tensor->shape, + [&sugar_tensor](const Array& i) { + return sugar_tensor(Array(i.begin(), i.end())); + }, + os.str()); vsub[sugar_tensor] = cache; std::unordered_map vmap; @@ -163,9 +152,7 @@ Tensor Schedule::cache_read(const Tensor& tensor, for (Operation op : readers) { Stage s = operator[](op); Operation repl_op = s->op->ReplaceInputs(s->op, vsub); - CHECK(!repl_op.same_as(s->op)) - << "Cannot find " << tensor - << " in the inputs of " << s->op; + CHECK(!repl_op.same_as(s->op)) << "Cannot find " << tensor << " in the inputs of " << s->op; vmap[s->op.output(0)] = repl_op.output(0); rvmap[repl_op.output(0)] = s->op.output(0); s->op = repl_op; @@ -177,8 +164,7 @@ Tensor Schedule::cache_read(const Tensor& tensor, Stage cache_stage = Stage(cache->op); cache_stage.set_scope(scope); CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos + 1, - cache_stage); + stages->data.insert(stages->data.begin() + pos + 1, cache_stage); (*this)->stage_map.Set(cache->op, cache_stage); // Update group cache_stage->group = op_stage->group; @@ -188,12 +174,9 @@ Tensor Schedule::cache_read(const Tensor& tensor, return cache; } -template -void PrepareAxisMapping(Stage orig_stage, - OpType* op, - std::unordered_set* p_red_axis, - Array* p_new_axis, - std::unordered_map* p_dom_map, +template +void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_set* p_red_axis, + Array* p_new_axis, std::unordered_map* p_dom_map, std::unordered_map* p_vsub, std::unordered_map* p_vsub2newvar, std::vector* p_predicates) { @@ -218,11 +201,9 @@ void PrepareAxisMapping(Stage orig_stage, std::unordered_map value_map; for (IterVar iv : orig_stage->leaf_iter_vars) { if (red_axis.count(iv)) continue; - CHECK_EQ(iv->iter_type, kDataPar) - << "Can only relayout with in data parallel dimensions"; + CHECK_EQ(iv->iter_type, kDataPar) << "Can only relayout with in data parallel dimensions"; Range dom = dom_map.at(iv); - IterVar new_iv = IterVarNode::make( - dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + IterVar new_iv = IterVarNode::make(dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); if (is_one(dom->min)) { value_map[iv] = dom->min; @@ -237,8 +218,7 @@ void PrepareAxisMapping(Stage orig_stage, skip_bound_check.insert(iv); } PassUpIndex(orig_stage, dom_map, &value_map, true); - predicates = MakeBoundCheck( - orig_stage, dom_map, value_map, true, skip_bound_check); + predicates = MakeBoundCheck(orig_stage, dom_map, value_map, true, skip_bound_check); // The root axis for (IterVar iv : op->axis) { if (value_map.count(iv)) { @@ -248,12 +228,8 @@ void PrepareAxisMapping(Stage orig_stage, } } -Array ReplaceOriginalOp(Schedule sch, - Stage orig_stage, - const std::string& scope, - Operation cache_op, - Operation orig_new_op, - size_t tensor_size) { +Array ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::string& scope, + Operation cache_op, Operation orig_new_op, size_t tensor_size) { Array cache_tensor_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); @@ -280,8 +256,7 @@ Array ReplaceOriginalOp(Schedule sch, Stage cache_stage = Stage(cache_op); cache_stage.set_scope(scope); CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos, - cache_stage); + stages->data.insert(stages->data.begin() + pos, cache_stage); sch->stage_map.Set(cache_op, cache_stage); // Update group cache_stage->group = orig_stage->group; @@ -291,10 +266,8 @@ Array ReplaceOriginalOp(Schedule sch, return cache_tensor_list; } - // Cache write and relayout the data according to loop pattern -Array CacheWriteWithReLayout(Schedule sch, - const Array& tensor_array, +Array CacheWriteWithReLayout(Schedule sch, const Array& tensor_array, const std::string& scope) { size_t tensor_size = tensor_array.size(); sch->InvalidateCache(); @@ -310,8 +283,8 @@ Array CacheWriteWithReLayout(Schedule sch, std::unordered_map vsub2newvar; std::vector predicates; - PrepareAxisMapping(orig_stage, compute, - &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); + PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, + &predicates); PrimExpr body; Array body_list; @@ -326,17 +299,14 @@ Array CacheWriteWithReLayout(Schedule sch, const tir::ReduceNode* reduce_body = body.as(); if (first_reduce != nullptr) { CHECK(ReduceEqual(reduce_body, first_reduce)); - body = tir::ReduceNode::make(first_reduce->combiner, - first_reduce->source, - first_reduce->axis, - first_reduce->condition, - reduce_body->value_index); + body = + tir::ReduceNode::make(first_reduce->combiner, first_reduce->source, first_reduce->axis, + first_reduce->condition, reduce_body->value_index); } else { first_reduce = reduce_body; } } else { - CHECK(first_reduce == nullptr) - << "cannot mix reduce and other node in ONE compute bodys"; + CHECK(first_reduce == nullptr) << "cannot mix reduce and other node in ONE compute bodys"; } body_list.push_back(body); } @@ -354,26 +324,21 @@ Array CacheWriteWithReLayout(Schedule sch, args.push_back(value_map.at(iv)); } } - Operation cache_op = ComputeOpNode::make( - compute->name + "." + scope, compute->tag, compute->attrs, - new_axis, body_list); + Operation cache_op = ComputeOpNode::make(compute->name + "." + scope, compute->tag, + compute->attrs, new_axis, body_list); Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); } - Operation orig_new_op = ComputeOpNode::make( - compute->name, compute->tag, compute->attrs, - compute->axis, cache_expr_list); - return ReplaceOriginalOp(sch, orig_stage, scope, - cache_op, orig_new_op, tensor_size); + Operation orig_new_op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, + compute->axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } - // for tensor compute op -Array CacheWriteWithReLayoutTensor(Schedule sch, - const Array& tensor_array, +Array CacheWriteWithReLayoutTensor(Schedule sch, const Array& tensor_array, const std::string& scope) { size_t tensor_size = tensor_array.size(); sch->InvalidateCache(); @@ -391,14 +356,12 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, std::unordered_map vsub2newvar; std::vector predicates; - PrepareAxisMapping(orig_stage, tensor_op, - &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); - + PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, + &predicates); for (int i = tensor_op->schedulable_ndim; i < static_cast(tensor_op->axis.size()); ++i) { IterVar iv = tensor_op->axis[i]; - IterVar new_iv = IterVarNode::make( - iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + IterVar new_iv = IterVarNode::make(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); } Array new_regions; @@ -417,10 +380,10 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input)); } - Operation cache_op = TensorComputeOpNode::make( - tensor_op->name + "." + scope, tensor_op->tag, new_axis, - tensor_op->reduce_axis, tensor_op->schedulable_ndim, - tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs); + Operation cache_op = TensorComputeOpNode::make(tensor_op->name + "." + scope, tensor_op->tag, + new_axis, tensor_op->reduce_axis, + tensor_op->schedulable_ndim, tensor_op->intrin, + tensor_op->inputs, new_regions, new_scalar_inputs); // axis will be used in generating compute op Array compute_axis = tensor_op->axis; @@ -455,19 +418,14 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); } - Operation orig_new_op = ComputeOpNode::make( - tensor_op->name, tensor_op->tag, {}, - compute_axis, cache_expr_list); - return ReplaceOriginalOp(sch, orig_stage, scope, - cache_op, orig_new_op, tensor_size); + Operation orig_new_op = + ComputeOpNode::make(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } - -Array Schedule::cache_write(const Array& tensor_array, - const std::string& scope) { +Array Schedule::cache_write(const Array& tensor_array, const std::string& scope) { (*this)->InvalidateCache(); - CHECK(tensor_array.size() > 0) - << "size of tensor_array must be greater than 0"; + CHECK(tensor_array.size() > 0) << "size of tensor_array must be greater than 0"; Tensor tensor = tensor_array[0]; Stage orig_stage = operator[](tensor->op); const ComputeOpNode* compute = tensor->op.as(); @@ -475,15 +433,12 @@ Array Schedule::cache_write(const Array& tensor_array, << "size of input tensor list must be same as number of stage outputs"; for (size_t i = 1; i < tensor_array.size(); i++) { Stage tmp_stage = operator[](tensor_array[i]->op); - CHECK(orig_stage.same_as(tmp_stage)) - << "Input tensor list must be generated by ONE computeOp"; + CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp"; } return CacheWriteWithReLayout(*this, tensor_array, scope); } - -Tensor Schedule::cache_write(const Tensor& tensor, - const std::string& scope) { +Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { // support original compute and tensor compute both (*this)->InvalidateCache(); if (tensor->op.as()) { @@ -496,7 +451,6 @@ Tensor Schedule::cache_write(const Tensor& tensor, } } - void RebaseNonZeroMinLoop(const Schedule& sch) { std::unordered_map rebase_map; for (Stage s : sch->stages) { @@ -506,16 +460,14 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); for (IterVar iv : root_iter_vars) { size_t idx = FindNodeRef(leaf_vars, iv); - auto it = s->iter_var_attrs.find(iv); + auto it = s->iter_var_attrs.find(iv); // don;t need to rebase path that are binded. - if (it != s->iter_var_attrs.end() && - (*it).second->bind_thread.defined()) { + if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { continue; } if (idx < leaf_vars->data.size()) { // insert rebase - IterVar rebased = IterVarNode::make( - Range(), iv->var.copy_with_suffix(""), iv->iter_type); + IterVar rebased = IterVarNode::make(Range(), iv->var.copy_with_suffix(""), iv->iter_type); s->relations.push_back(RebaseNode::make(iv, rebased)); if (s->iter_var_attrs.count(iv)) { s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv)); @@ -557,13 +509,11 @@ void InjectInline(ScheduleNode* sch) { { // setup args const ComputeOpNode* compute = stage->op.as(); - CHECK(compute) - << "can only inline compute op"; + CHECK(compute) << "can only inline compute op"; for (auto iv : compute->axis) { args.push_back(iv->var); } - CHECK_EQ(compute->body.size(), 1U) - << "can only inline compute op with 1 output"; + CHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output"; body = compute->body[0]; } for (size_t j = i; j < sch->stages.size(); ++j) { @@ -580,12 +530,13 @@ void InjectInline(ScheduleNode* sch) { for (size_t k = 1; k < new_body[j].size(); ++k) { const tir::ReduceNode* reduce_ = new_body[j][k].as(); CHECK(reduce_); - CHECK(ReduceEqual(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; + CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; } - PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][0]), - stage->op, args, body).as()->value; + PrimExpr new_value = + Inline(tir::EvaluateNode::make(new_body[j][0]), stage->op, args, body) + .as() + ->value; if (!new_value.same_as(new_body[j][0])) { changed[j] = true; const tir::ReduceNode* r = new_value.as(); @@ -600,8 +551,10 @@ void InjectInline(ScheduleNode* sch) { } } else { for (size_t k = 0; k < new_body[j].size(); ++k) { - PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][k]), - stage->op, args, body).as()->value; + PrimExpr new_value = + Inline(tir::EvaluateNode::make(new_body[j][k]), stage->op, args, body) + .as() + ->value; if (!new_value.same_as(new_body[j][k])) { new_body[j].Set(k, new_value); changed[j] = true; @@ -632,9 +585,8 @@ void InjectInline(ScheduleNode* sch) { CHECK(compute); Operation op = s->op; if (changed[i]) { - op = ComputeOpNode::make( - compute->name, compute->tag, compute->attrs, - compute->axis, new_body[i]); + op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, compute->axis, + new_body[i]); } op = op->ReplaceInputs(op, repl); if (!op.same_as(s->op)) { @@ -646,9 +598,8 @@ void InjectInline(ScheduleNode* sch) { } else if (hybrid_changed[i]) { const HybridOpNode* hybrid = sch->stages[i]->op.as(); CHECK(hybrid); - Operation op = HybridOpNode::make( - hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, - hybrid->outputs, new_hybrid_body[i]); + Operation op = HybridOpNode::make(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, + hybrid->outputs, new_hybrid_body[i]); op = op->ReplaceInputs(op, repl); for (int idx = 0; idx < s->op->num_outputs(); ++idx) { repl[s->op.output(idx)] = op.output(idx); @@ -674,13 +625,10 @@ Schedule Schedule::normalize() { } // Handle reduction factor. -Array Schedule::rfactor(const Tensor& tensor, - const IterVar& axis, - int factor_axis) { +Array Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) { (*this)->InvalidateCache(); using tir::ReduceNode; - CHECK_EQ(axis->iter_type, kCommReduce) - << "Can only factor reduction axis"; + CHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis"; Stage reduce_stage = operator[](tensor->op); const ComputeOpNode* compute_op = reduce_stage->op.as(); CHECK(compute_op) << "Can only factor ComputeOp"; @@ -699,8 +647,7 @@ Array Schedule::rfactor(const Tensor& tensor, std::unordered_set skip_bound_check; // Verify normal axis are not touched. for (IterVar iv : compute_op->axis) { - CHECK(!touch_map.count(iv)) - << "Factor axis touches normal axis."; + CHECK(!touch_map.count(iv)) << "Factor axis touches normal axis."; skip_bound_check.insert(iv); } // get analyzer. @@ -728,11 +675,11 @@ Array Schedule::rfactor(const Tensor& tensor, } } te::PassUpIndex(reduce_stage, dom_map, &value_map, true); - std::vector predicates = MakeBoundCheck( - reduce_stage, dom_map, value_map, true, skip_bound_check); + std::vector predicates = + MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check); // Get the factored op node. - const int factor_axis_pos = \ + const int factor_axis_pos = factor_axis >= 0 ? factor_axis : static_cast(compute_op->axis.size() + 1) + factor_axis; CHECK_LE(factor_axis_pos, compute_op->axis.size()); auto n = make_object(); @@ -741,8 +688,7 @@ Array Schedule::rfactor(const Tensor& tensor, // axis relacement. auto iv_node = make_object(); iv_node->dom = dom_map.at(axis); - CHECK(is_zero(iv_node->dom->min)) - << "Can only factor reduction domain starting from 0"; + CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0"; iv_node->var = axis->var; iv_node->iter_type = kDataPar; @@ -786,18 +732,15 @@ Array Schedule::rfactor(const Tensor& tensor, } } VarReplacer replacer(vsub); - Array new_source = tir::UpdateArray(reduce->source, - [&replacer] (const PrimExpr& e) { return replacer(e); }); + Array new_source = + tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); }); PrimExpr new_pred = replacer(predicate); std::vector body; for (size_t idx = 0; idx < reduce->source.size(); ++idx) { - body.emplace_back(ReduceNode::make(reduce->combiner, - new_source, - n->reduce_axis, - new_pred, - idx)); + body.emplace_back( + ReduceNode::make(reduce->combiner, new_source, n->reduce_axis, new_pred, idx)); } n->body = Array(body); // refresh relations, keep the un-touched relations. @@ -824,16 +767,14 @@ Array Schedule::rfactor(const Tensor& tensor, Stage factor_stage = Stage(factor_op); factor_stage->relations = rels; CHECK_LT(stage_pos, stages->data.size()); - stages->data.insert(stages->data.begin() + stage_pos, - factor_stage); + stages->data.insert(stages->data.begin() + stage_pos, factor_stage); (*this)->stage_map.Set(factor_op, factor_stage); factor_stage->group = reduce_stage->group; if (factor_stage->group.defined()) { ++factor_stage->group->num_child_stages; } // Replace the old reduction. - IterVar repl_red_axis = reduce_axis( - dom_map.at(axis), axis->var->name_hint + ".v"); + IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v"); Array factor_tensors; Array old_tensors; int size = factor_op->num_outputs(); @@ -841,32 +782,33 @@ Array Schedule::rfactor(const Tensor& tensor, factor_tensors.push_back(factor_op.output(idx)); old_tensors.push_back(reduce_stage->op.output(idx)); } - Array repl_tensors = compute(old_tensors[0]->shape, - [&](const Array& i) { - Array indices; - const int idx_size = static_cast(i.size()); - for (int idx = 0; idx < idx_size; ++idx) { - if (factor_axis_pos == idx) { - indices.push_back(repl_red_axis->var); + Array repl_tensors = compute( + old_tensors[0]->shape, + [&](const Array& i) { + Array indices; + const int idx_size = static_cast(i.size()); + for (int idx = 0; idx < idx_size; ++idx) { + if (factor_axis_pos == idx) { + indices.push_back(repl_red_axis->var); + } + indices.push_back(i[idx]); } - indices.push_back(i[idx]); - } - if (factor_axis_pos == idx_size) { + if (factor_axis_pos == idx_size) { indices.push_back(repl_red_axis->var); - } - Array factor_exprs; - for (int idx = 0; idx < size; ++idx) { - factor_exprs.push_back(factor_tensors[idx](indices)); - } - Array reductions; - Array axis = {repl_red_axis}; - PrimExpr cond = const_true(); - for (int idx = 0; idx < size; ++idx) { - reductions.push_back(ReduceNode::make(reduce->combiner, - factor_exprs, axis, cond, idx)); - } - return reductions; - }, reduce_stage->op->name + ".repl"); + } + Array factor_exprs; + for (int idx = 0; idx < size; ++idx) { + factor_exprs.push_back(factor_tensors[idx](indices)); + } + Array reductions; + Array axis = {repl_red_axis}; + PrimExpr cond = const_true(); + for (int idx = 0; idx < size; ++idx) { + reductions.push_back(ReduceNode::make(reduce->combiner, factor_exprs, axis, cond, idx)); + } + return reductions; + }, + reduce_stage->op->name + ".repl"); std::unordered_map vmap; std::unordered_map rvmap; diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index bfee0d5..74ddca5 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -22,17 +22,19 @@ */ #include #include -#include #include +#include + #include #include + #include "graph.h" namespace tvm { namespace te { // find first occurance location in leaf -template +template size_t FindNodeRef(ArrayNode* array_node, const T& v) { const Object* n = v.get(); for (size_t i = 0; i < array_node->data.size(); ++i) { @@ -46,30 +48,23 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) if (pos < leaf_vars->data.size()) return pos; if (FindNodeRef(all_vars, v) < all_vars->data.size()) { - LOG(FATAL) << "Operate on iter var " << v - << "that has already been split"; + LOG(FATAL) << "Operate on iter var " << v << "that has already been split"; } else { - LOG(FATAL) << "Operate on iter var " << v - << "that is not part of the schedule"; + LOG(FATAL) << "Operate on iter var " << v << "that is not part of the schedule"; } return 0; } -void Split(StageNode* self, - IterVar parent, - PrimExpr factor, - PrimExpr nparts, - IterVar* p_outer, +void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // Check if split is valid. - CHECK(parent->iter_type == kDataPar || - parent->iter_type == kCommReduce || + CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || parent->iter_type == kOrdered) << "Cannot split on " << IterVarType2String(parent->iter_type); - IterVar outer = IterVarNode::make( - Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); - IterVar inner = IterVarNode::make( - Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); + IterVar outer = + IterVarNode::make(Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); + IterVar inner = + IterVarNode::make(Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); *p_outer = outer; *p_inner = inner; // The splits @@ -112,8 +107,7 @@ bool Stage::is_scheduled() const { Stage Stage::GetAttachSpec() const { Stage attach_spec = *this; - while (attach_spec->attach_type == kGroupRoot && - attach_spec->group.defined()) { + while (attach_spec->attach_type == kGroupRoot && attach_spec->group.defined()) { attach_spec = attach_spec->group; } return attach_spec; @@ -124,9 +118,8 @@ Stage& Stage::set_scope(std::string scope) { // NOLINT(*) return *this; } -Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) - CHECK_NE((*this)->attach_type, kScanUpdate) - << "Cannot specify compute_at for scan updates"; +Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; // Group constraint checking. Stage group = (*this)->group; if (group.defined()) { @@ -134,8 +127,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) while (pg.defined() && !pg.same_as(group)) { pg = pg->group; } - CHECK(pg.same_as(group)) - << "Can only assign compute_at to stages within the same group"; + CHECK(pg.same_as(group)) << "Can only assign compute_at to stages within the same group"; } (*this)->attach_type = kScope; @@ -144,34 +136,30 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) bool found = false; for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) { if (scope == parent->leaf_iter_vars[i]) { - found = true; break; + found = true; + break; } } - CHECK(found) - << "Cannot find the axis " << scope - << " in parent's leaf_iter_vars" - << " parent=" << parent; + CHECK(found) << "Cannot find the axis " << scope << " in parent's leaf_iter_vars" + << " parent=" << parent; return *this; } -Stage& Stage::compute_inline() { // NOLINT(*) - CHECK_NE((*this)->attach_type, kScanUpdate) - << "Cannot specify compute_at for scan updates"; +Stage& Stage::compute_inline() { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kInline; return *this; } -Stage& Stage::compute_root() { // NOLINT(*) - CHECK_NE((*this)->attach_type, kScanUpdate) - << "Cannot specify compute_at for scan updates"; +Stage& Stage::compute_root() { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kGroupRoot; return *this; } -Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) +Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) StageNode* self = operator->(); - CHECK(ivar->iter_type == kDataPar || - ivar->iter_type == kCommReduce) + CHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce) << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread"; CHECK(thread_ivar->iter_type == kThreadIndex) << "Cannot rebase by " << IterVarType2String(ivar->iter_type) @@ -184,10 +172,8 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) ObjectPtr n; if (it != self->iter_var_attrs.end()) { n = make_object(*(*it).second.operator->()); - if (n->bind_thread.defined() && - !n->bind_thread.same_as(thread_ivar)) { - LOG(WARNING) << "Axis " << ivar - << " is already bind to another thread " << n->bind_thread; + if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) { + LOG(WARNING) << "Axis " << ivar << " is already bind to another thread " << n->bind_thread; } } else { n = make_object(); @@ -201,18 +187,15 @@ Stage& Stage::env_threads(Array threads) { StageNode* self = operator->(); CHECK(self->op.defined() && self->op.as()) << "env_threads is only valid for composite ops such as ScanOp"; - CHECK_EQ(self->env_threads.size(), 0U) - << "Already set env_threads"; + CHECK_EQ(self->env_threads.size(), 0U) << "Already set env_threads"; ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); std::vector temp; for (IterVar iv : threads) { temp.push_back(iv); } - leaf_vars->data.insert( - leaf_vars->data.begin(), temp.begin(), temp.end()); - all_vars->data.insert( - all_vars->data.end(), temp.begin(), temp.end()); + leaf_vars->data.insert(leaf_vars->data.begin(), temp.begin(), temp.end()); + all_vars->data.insert(all_vars->data.end(), temp.begin(), temp.end()); self->env_threads = threads; return *this; } @@ -223,36 +206,32 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) { return *this; } -Stage& Stage::split( - IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) +Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, + IterVar* p_inner) { // NOLINT(*) Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); return *this; } -Stage& Stage::split_by_nparts( - IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) +Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, + IterVar* p_inner) { // NOLINT(*) Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); return *this; } Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*) StageNode* self = operator->(); - CHECK(outer->iter_type == kDataPar || - outer->iter_type == kCommReduce || + CHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce || outer->iter_type == kOrdered) << "Cannot fuse " << IterVarType2String(outer->iter_type); - CHECK(inner->iter_type == kDataPar || - inner->iter_type == kCommReduce || + CHECK(inner->iter_type == kDataPar || inner->iter_type == kCommReduce || inner->iter_type == kOrdered) << "Cannot fuse " << IterVarType2String(inner->iter_type); IterVarType iter_type = outer->iter_type; if (inner->iter_type > iter_type) iter_type = inner->iter_type; - std::string fused_name = - outer->var->name_hint + "." + inner->var->name_hint + ".fused"; + std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused"; - IterVar fused = IterVarNode::make( - Range(), Var(fused_name, outer->var.dtype()), iter_type); + IterVar fused = IterVarNode::make(Range(), Var(fused_name, outer->var.dtype()), iter_type); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); @@ -269,8 +248,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT all_vars->data.push_back(fused); leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer, leaf_vars->data.begin() + pos_inner + 1); - leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, - fused); + leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, fused); *p_target = fused; return *this; } @@ -286,9 +264,8 @@ Stage& Stage::fuse(const Array& axes, IterVar* p_target) { // NOLINT(* StageNode* self = operator->(); // special handle fuse empty array. // insert at the outer most loop - IterVar singleton = IterVarNode::make( - Range::make_by_min_extent(0, 1), - Var("singleton", DataType::Int(32)), kDataPar); + IterVar singleton = IterVarNode::make(Range::make_by_min_extent(0, 1), + Var("singleton", DataType::Int(32)), kDataPar); self->relations.push_back(SingletonNode::make(singleton)); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); @@ -303,14 +280,11 @@ Stage& Stage::reorder(const Array& order) { // NOLINT(*) std::unordered_set seen_var; StageNode* self = operator->(); for (IterVar iv : order) { - CHECK(iv->iter_type == kDataPar || - iv->iter_type == kCommReduce || + CHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce || iv->iter_type == kThreadIndex) - << "Cannot reorder IterVar(" - << IterVarType2String(iv->iter_type) << ")"; + << "Cannot reorder IterVar(" << IterVarType2String(iv->iter_type) << ")"; - CHECK_EQ(seen_var.count(iv), 0) - << "Same axis can not appear more than once " << iv; + CHECK_EQ(seen_var.count(iv), 0) << "Same axis can not appear more than once " << iv; seen_var.insert(iv); } ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); @@ -331,20 +305,16 @@ Stage& Stage::reorder(const Array& order) { // NOLINT(*) return *this; } -Stage& Stage::tile(IterVar x_parent, IterVar y_parent, - PrimExpr x_factor, PrimExpr y_factor, - IterVar* p_x_outer, IterVar* p_y_outer, - IterVar* p_x_inner, IterVar* p_y_inner) { +Stage& Stage::tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor, + IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner) { split(x_parent, x_factor, p_x_outer, p_x_inner); split(y_parent, y_factor, p_y_outer, p_y_inner); reorder(Array({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner})); return *this; } -template -inline void UpdateIterVarAttr(StageNode* self, - IterVar var, - FUpdate fupdate, +template +inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate, bool need_leaf = true) { if (need_leaf) { ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); @@ -363,60 +333,53 @@ inline void UpdateIterVarAttr(StageNode* self, } inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) { - UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { - n->iter_type = iter_type; - }); + UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { n->iter_type = iter_type; }); } -Stage& Stage::vectorize(IterVar var) { // NOLINT(*) - CHECK(var->iter_type == kDataPar || - var->iter_type == kOpaque || - var->iter_type == kUnrolled || - var->iter_type == kVectorized || - var->iter_type == kTensorized || +Stage& Stage::vectorize(IterVar var) { // NOLINT(*) + CHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || var->iter_type == kUnrolled || + var->iter_type == kVectorized || var->iter_type == kTensorized || var->iter_type == kParallelized) << "Cannot vectorize on " << IterVarType2String(var->iter_type); SetAttrIterType(operator->(), var, kVectorized); return *this; } -Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) +Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) { - n->iter_type = kTensorized; - n->tensor_intrin = f; - }); + n->iter_type = kTensorized; + n->tensor_intrin = f; + }); return *this; } -Stage& Stage::unroll(IterVar var) { // NOLINT(*) +Stage& Stage::unroll(IterVar var) { // NOLINT(*) SetAttrIterType(operator->(), var, kUnrolled); return *this; } -Stage& Stage::parallel(IterVar var) { // NOLINT(*) +Stage& Stage::parallel(IterVar var) { // NOLINT(*) SetAttrIterType(operator->(), var, kParallelized); return *this; } -Stage& Stage::pragma(IterVar var, - const std::string& pragma_type, - const PrimExpr& pragma_value) { // NOLINT(*) +Stage& Stage::pragma(IterVar var, const std::string& pragma_type, + const PrimExpr& pragma_value) { // NOLINT(*) if (pragma_type == "unroll") { this->unroll(var); } else if (pragma_type == "vectorize") { this->vectorize(var); } else { - UpdateIterVarAttr( - operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { - n->pragma_keys.push_back(tir::StringImmNode::make(pragma_type)); - n->pragma_values.push_back(pragma_value); - }); + UpdateIterVarAttr(operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { + n->pragma_keys.push_back(tir::StringImmNode::make(pragma_type)); + n->pragma_values.push_back(pragma_value); + }); } return *this; } -Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) { - StageNode *self = operator->(); +Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) { + StageNode* self = operator->(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); FindLeafVar(all_vars, leaf_vars, var); @@ -434,16 +397,19 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) { } Stage& Stage::storage_align(IterVar axis, int factor, int offset) { - StageNode *self = operator->(); - UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) { - n->dim_align_factor = factor; - n->dim_align_offset = offset; - }, false); + StageNode* self = operator->(); + UpdateIterVarAttr( + self, axis, + [factor, offset](IterVarAttrNode* n) { + n->dim_align_factor = factor; + n->dim_align_offset = offset; + }, + false); return *this; } Stage& Stage::double_buffer() { - StageNode *self = operator->(); + StageNode* self = operator->(); CHECK(!self->is_output) << "Cannot apply double buffer on output"; self->double_buffer = true; return *this; @@ -451,7 +417,7 @@ Stage& Stage::double_buffer() { Stage& Stage::opengl() { CHECK(!is_scheduled()) << "Must be a fresh schedule"; - StageNode *self = operator->(); + StageNode* self = operator->(); auto all_iter_vars = self->all_iter_vars; // curr version of all_iter_vars CHECK(!all_iter_vars.empty()) << "At least one iter var"; @@ -475,8 +441,7 @@ Stage& Stage::opengl() { break; } default: { - LOG(ERROR) << "Invalid iter var type " - << IterVarType2String(iter_var->iter_type); + LOG(ERROR) << "Invalid iter var type " << IterVarType2String(iter_var->iter_type); break; } } @@ -492,8 +457,7 @@ Stage& Stage::opengl() { } Stage CopyStage(const Stage& s) { - ObjectPtr n = - make_object(*s.operator->()); + ObjectPtr n = make_object(*s.operator->()); return Stage(n); } @@ -521,24 +485,22 @@ Schedule Schedule::copy() const { for (Stage s : n->stages) { if (s->attach_stage.defined()) { CHECK(smap.find(s->attach_stage) != smap.end()) - << s->attach_stage << " not found in " << (*this); + << s->attach_stage << " not found in " << (*this); s->attach_stage = smap.at(s->attach_stage); } if (s->group.defined()) { - CHECK(smap.find(s->group) != smap.end()) - << s->group << " not found in " << (*this); + CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); s->group = smap.at(s->group); } } for (Stage s : n->groups) { if (s->attach_stage.defined()) { CHECK(smap.find(s->attach_stage) != smap.end()) - << s->attach_stage << " not found in " << (*this); + << s->attach_stage << " not found in " << (*this); s->attach_stage = smap.at(s->attach_stage); } if (s->group.defined()) { - CHECK(smap.find(s->group) != smap.end()) - << s->group << " not found in " << (*this); + CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); s->group = smap.at(s->group); } } @@ -548,8 +510,7 @@ Schedule Schedule::copy() const { Stage Schedule::operator[](const Operation& op) { auto it = (*this)->stage_map.find(op); CHECK(it != (*this)->stage_map.end()) - << "Cannot find Stage for operator " << op - << " in the schedule"; + << "Cannot find Stage for operator " << op << " in the schedule"; return (*it).second; } @@ -570,15 +531,13 @@ Stage LeastCommonAncestor(Stage g1, Stage g2) { return g; } -Array RemapTensor(ScheduleNode* self, - const Array& arr) { +Array RemapTensor(ScheduleNode* self, const Array& arr) { self->InitCache(); const auto& op2stage_cache = self->op2stage_cache_; Array ret; for (Tensor t : arr) { if (!op2stage_cache.count(t->op.get())) { - CHECK(self->stage_map.count(t->op)) - << "Given tensor is not in the schedule plan"; + CHECK(self->stage_map.count(t->op)) << "Given tensor is not in the schedule plan"; t = self->stage_map[t->op]->op.output(t->value_index); } ret.push_back(t); @@ -587,17 +546,14 @@ Array RemapTensor(ScheduleNode* self, } // Group the schedule stages. -Stage Schedule::create_group(const Array& outputs, - const Array& inputs, +Stage Schedule::create_group(const Array& outputs, const Array& inputs, bool include_inputs) { ScheduleNode* self = operator->(); self->InitCache(); const auto& op2stage_cache = self->op2stage_cache_; // Get the ops. - Array ops = te::GetSubGraph( - RemapTensor(self, outputs), - RemapTensor(self, inputs), - include_inputs); + Array ops = + te::GetSubGraph(RemapTensor(self, outputs), RemapTensor(self, inputs), include_inputs); // local counter entry // Automatically initialize to 0 during creation. struct Entry { @@ -631,7 +587,7 @@ Stage Schedule::create_group(const Array& outputs, // Propagate the counter statistics from by checking if subgroup // Is full and propagate. std::vector stack; - for (auto &kv : counter) { + for (auto& kv : counter) { if (!kv.first.same_as(parent_group)) { if (kv.first->num_child_stages == kv.second.count) { stack.push_back(kv.first); @@ -650,7 +606,7 @@ Stage Schedule::create_group(const Array& outputs, } } // Verification and remappig the subgroups. - for (auto &kv : counter) { + for (auto& kv : counter) { if (kv.first.same_as(parent_group)) continue; CHECK_EQ(kv.first->num_child_stages, kv.second.count) << "Trying to group region that intersect with an already existed group"; @@ -695,9 +651,7 @@ Stage Schedule::create_group(const Array& outputs, return gstage; } -void ScheduleNode::InvalidateCache() { - op2stage_cache_.clear(); -} +void ScheduleNode::InvalidateCache() { op2stage_cache_.clear(); } void ScheduleNode::InitCache() { if (op2stage_cache_.size() == stages.size()) return; @@ -753,10 +707,7 @@ Schedule ScheduleNode::make(Array ops) { return sch; } -IterVarRelation SplitNode::make(IterVar parent, - IterVar outer, - IterVar inner, - PrimExpr factor, +IterVarRelation SplitNode::make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { auto n = make_object(); n->parent = parent; @@ -767,8 +718,7 @@ IterVarRelation SplitNode::make(IterVar parent, return IterVarRelation(n); } -IterVarRelation FuseNode::make( - IterVar outer, IterVar inner, IterVar fused) { +IterVarRelation FuseNode::make(IterVar outer, IterVar inner, IterVar fused) { auto n = make_object(); n->outer = outer; n->inner = inner; @@ -805,19 +755,19 @@ struct TVMSpecializationThreadLocalEntry { typedef dmlc::ThreadLocalStore TVMSpecializationThreadLocalStore; void SpecializedCondition::EnterWithScope() { - TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get(); + TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); entry->condition_stack.push(*this); } void SpecializedCondition::ExitWithScope() { - TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get(); + TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); CHECK(!entry->condition_stack.empty()); CHECK(entry->condition_stack.top().same_as(*this)); entry->condition_stack.pop(); } SpecializedCondition SpecializedCondition::Current() { - TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get(); + TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); SpecializedCondition cond; if (entry->condition_stack.size() > 0) { cond = entry->condition_stack.top(); @@ -827,13 +777,9 @@ SpecializedCondition SpecializedCondition::Current() { class SpecializedCondition::Internal { public: - static void EnterScope(SpecializedCondition cond) { - cond.EnterWithScope(); - } + static void EnterScope(SpecializedCondition cond) { cond.EnterWithScope(); } - static void ExitScope(SpecializedCondition cond) { - cond.ExitWithScope(); - } + static void ExitScope(SpecializedCondition cond) { cond.ExitWithScope(); } }; TVM_REGISTER_NODE_TYPE(StageNode); @@ -847,193 +793,158 @@ TVM_REGISTER_NODE_TYPE(SpecializedConditionNode); // Printer TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - if (op->op.defined()) { - p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; - } else { - p->stream << "group-stage(" << op << ")"; - } -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << IterVarType2String(op->iter_type); -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "split(parent="; - p->Print(op->parent); - p->stream << ", outer="; - p->Print(op->outer); - p->stream << ", inner="; - p->Print(op->inner); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "split("; - p->stream << "outer="; - p->Print(op->outer); - p->stream << ", inner="; - p->Print(op->inner); - p->stream << ", fused="; - p->Print(op->fused); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "rebase("; - p->stream << "parent="; - p->Print(op->parent); - p->stream << ", rebased="; - p->Print(op->rebased); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "singleton("; - p->Print(op->iter); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "schedule(" << op << ")"; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "specialized_condition("; - p->Print(op->clauses); - p->stream << ')'; -}); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + if (op->op.defined()) { + p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; + } else { + p->stream << "group-stage(" << op << ")"; + } + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << IterVarType2String(op->iter_type); + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "split(parent="; + p->Print(op->parent); + p->stream << ", outer="; + p->Print(op->outer); + p->stream << ", inner="; + p->Print(op->inner); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "split("; + p->stream << "outer="; + p->Print(op->outer); + p->stream << ", inner="; + p->Print(op->inner); + p->stream << ", fused="; + p->Print(op->fused); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "rebase("; + p->stream << "parent="; + p->Print(op->parent); + p->stream << ", rebased="; + p->Print(op->rebased); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "singleton("; + p->Print(op->iter); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "schedule(" << op << ")"; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "specialized_condition("; + p->Print(op->clauses); + p->stream << ')'; + }); -TVM_REGISTER_GLOBAL("te.CreateSchedule") -.set_body_typed(create_schedule); +TVM_REGISTER_GLOBAL("te.CreateSchedule").set_body_typed(create_schedule); -TVM_REGISTER_GLOBAL("te.StageSetScope") -.set_body_method(&Stage::set_scope); +TVM_REGISTER_GLOBAL("te.StageSetScope").set_body_method(&Stage::set_scope); -TVM_REGISTER_GLOBAL("te.StageBind") -.set_body_method(&Stage::bind); +TVM_REGISTER_GLOBAL("te.StageBind").set_body_method(&Stage::bind); TVM_REGISTER_GLOBAL("te.StageSplitByFactor") -.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { - IterVar outer, inner; - stage.split(parent, factor, &outer, &inner); - return Array({outer, inner}); -}); + .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { + IterVar outer, inner; + stage.split(parent, factor, &outer, &inner); + return Array({outer, inner}); + }); TVM_REGISTER_GLOBAL("te.StageSplitByNParts") -.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { - IterVar outer, inner; - stage.split_by_nparts(parent, nparts, &outer, &inner); - return Array({outer, inner}); -}); + .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { + IterVar outer, inner; + stage.split_by_nparts(parent, nparts, &outer, &inner); + return Array({outer, inner}); + }); -TVM_REGISTER_GLOBAL("te.StageFuse") -.set_body_typed([](Stage stage, Array axes) { - IterVar fused; - stage.fuse(axes, &fused); - return fused; - }); +TVM_REGISTER_GLOBAL("te.StageFuse").set_body_typed([](Stage stage, Array axes) { + IterVar fused; + stage.fuse(axes, &fused); + return fused; +}); -TVM_REGISTER_GLOBAL("te.StageComputeAt") -.set_body_method(&Stage::compute_at); +TVM_REGISTER_GLOBAL("te.StageComputeAt").set_body_method(&Stage::compute_at); -TVM_REGISTER_GLOBAL("te.StageComputeInline") -.set_body_method(&Stage::compute_inline); +TVM_REGISTER_GLOBAL("te.StageComputeInline").set_body_method(&Stage::compute_inline); -TVM_REGISTER_GLOBAL("te.StageComputeRoot") -.set_body_method(&Stage::compute_root); +TVM_REGISTER_GLOBAL("te.StageComputeRoot").set_body_method(&Stage::compute_root); -TVM_REGISTER_GLOBAL("te.StageReorder") -.set_body_method(&Stage::reorder); +TVM_REGISTER_GLOBAL("te.StageReorder").set_body_method(&Stage::reorder); TVM_REGISTER_GLOBAL("te.StageTile") -.set_body_typed([]( - Stage stage, - IterVar x_parent, IterVar y_parent, - PrimExpr x_factor, PrimExpr y_factor -) { - IterVar x_outer, y_outer, x_inner, y_inner; - stage.tile(x_parent, y_parent, - x_factor, y_factor, - &x_outer, &y_outer, - &x_inner, &y_inner); - return Array({x_outer, y_outer, x_inner, y_inner}); - }); + .set_body_typed([](Stage stage, IterVar x_parent, IterVar y_parent, PrimExpr x_factor, + PrimExpr y_factor) { + IterVar x_outer, y_outer, x_inner, y_inner; + stage.tile(x_parent, y_parent, x_factor, y_factor, &x_outer, &y_outer, &x_inner, &y_inner); + return Array({x_outer, y_outer, x_inner, y_inner}); + }); -TVM_REGISTER_GLOBAL("te.StageEnvThreads") -.set_body_method(&Stage::env_threads); +TVM_REGISTER_GLOBAL("te.StageEnvThreads").set_body_method(&Stage::env_threads); -TVM_REGISTER_GLOBAL("te.StageSetStorePredicate") -.set_body_method(&Stage::set_store_predicate); +TVM_REGISTER_GLOBAL("te.StageSetStorePredicate").set_body_method(&Stage::set_store_predicate); -TVM_REGISTER_GLOBAL("te.StageUnroll") -.set_body_method(&Stage::unroll); +TVM_REGISTER_GLOBAL("te.StageUnroll").set_body_method(&Stage::unroll); -TVM_REGISTER_GLOBAL("te.StageVectorize") -.set_body_method(&Stage::vectorize); +TVM_REGISTER_GLOBAL("te.StageVectorize").set_body_method(&Stage::vectorize); -TVM_REGISTER_GLOBAL("te.StageTensorize") -.set_body_method(&Stage::tensorize); +TVM_REGISTER_GLOBAL("te.StageTensorize").set_body_method(&Stage::tensorize); -TVM_REGISTER_GLOBAL("te.StageParallel") -.set_body_method(&Stage::parallel); +TVM_REGISTER_GLOBAL("te.StageParallel").set_body_method(&Stage::parallel); -TVM_REGISTER_GLOBAL("te.StagePragma") -.set_body_method(&Stage::pragma); +TVM_REGISTER_GLOBAL("te.StagePragma").set_body_method(&Stage::pragma); -TVM_REGISTER_GLOBAL("te.StagePrefetch") -.set_body_method(&Stage::prefetch); +TVM_REGISTER_GLOBAL("te.StagePrefetch").set_body_method(&Stage::prefetch); -TVM_REGISTER_GLOBAL("te.StageStorageAlign") -.set_body_method(&Stage::storage_align); +TVM_REGISTER_GLOBAL("te.StageStorageAlign").set_body_method(&Stage::storage_align); -TVM_REGISTER_GLOBAL("te.StageDoubleBuffer") -.set_body_method(&Stage::double_buffer); +TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffer); -TVM_REGISTER_GLOBAL("te.StageOpenGL") -.set_body_method(&Stage::opengl); +TVM_REGISTER_GLOBAL("te.StageOpenGL").set_body_method(&Stage::opengl); -TVM_REGISTER_GLOBAL("te.ScheduleNormalize") -.set_body_method(&Schedule::normalize); +TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize); -TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup") -.set_body_method(&Schedule::create_group); +TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group); -TVM_REGISTER_GLOBAL("te.ScheduleCacheRead") -.set_body_method(&Schedule::cache_read); +TVM_REGISTER_GLOBAL("te.ScheduleCacheRead").set_body_method(&Schedule::cache_read); -TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite") -.set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[1].IsObjectRef()) { - *ret = args[0].operator Schedule() - .cache_write(args[1].operator Tensor(), args[2]); - } else { - *ret = args[0].operator Schedule() - .cache_write(args[1].operator Array(), args[2]); - } - }); +TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args[1].IsObjectRef()) { + *ret = args[0].operator Schedule().cache_write(args[1].operator Tensor(), args[2]); + } else { + *ret = args[0].operator Schedule().cache_write(args[1].operator Array(), args[2]); + } +}); -TVM_REGISTER_GLOBAL("te.ScheduleRFactor") -.set_body_method(&Schedule::rfactor); +TVM_REGISTER_GLOBAL("te.ScheduleRFactor").set_body_method(&Schedule::rfactor); -TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition") -.set_body_typed([](Array condition) { - return SpecializedCondition(condition); +TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition").set_body_typed([](Array condition) { + return SpecializedCondition(condition); }); -TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = SpecializedCondition::Current(); +TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = SpecializedCondition::Current(); }); TVM_REGISTER_GLOBAL("te.EnterSpecializationScope") -.set_body_typed(SpecializedCondition::Internal::EnterScope); + .set_body_typed(SpecializedCondition::Internal::EnterScope); TVM_REGISTER_GLOBAL("te.ExitSpecializationScope") -.set_body_typed(SpecializedCondition::Internal::ExitScope); + .set_body_typed(SpecializedCondition::Internal::ExitScope); } // namespace te } // namespace tvm diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index bdb77b6..3a26e98 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -21,31 +21,30 @@ * \file schedule_ops.cc */ #include -#include -#include -#include #include #include -#include +#include +#include +#include + #include #include -#include "graph.h" -#include "../operation/op_util.h" +#include + #include "../../tir/transforms/ir_util.h" +#include "../operation/op_util.h" +#include "graph.h" namespace tvm { namespace te { using namespace tir; -Stmt MakePipeline(const Stage& s, - const std::unordered_map& dom_map, - Stmt consumer, +Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_map, Stmt consumer, bool debug_keep_trivial_loop) { Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop); if (s->double_buffer) { - producer = AttrStmtNode::make( - s->op, tir::attr::double_buffer_scope, 1, producer); + producer = AttrStmtNode::make(s->op, tir::attr::double_buffer_scope, 1, producer); } Stmt pipeline = producer; @@ -54,14 +53,12 @@ Stmt MakePipeline(const Stage& s, } pipeline = s->op->BuildRealize(s, dom_map, pipeline); // use attribute to mark scope of the operation. - pipeline = AttrStmtNode::make( - s->op, tir::attr::realize_scope, - StringImmNode::make(s->scope), - pipeline); + pipeline = + AttrStmtNode::make(s->op, tir::attr::realize_scope, StringImmNode::make(s->scope), pipeline); if (s->is_opengl) { - pipeline = AttrStmtNode::make( - s->op, tir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline); + pipeline = + AttrStmtNode::make(s->op, tir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline); } return pipeline; } @@ -69,28 +66,25 @@ Stmt MakePipeline(const Stage& s, // inject the operator's realization on the stmt. class InjectAttach : public StmtMutator { public: - InjectAttach(const Stage& stage, - const Stage& attach_spec, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) - : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map), + InjectAttach(const Stage& stage, const Stage& attach_spec, + const std::unordered_map& dom_map, bool debug_keep_trivial_loop) + : stage_(stage), + attach_spec_(attach_spec), + dom_map_(dom_map), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} Stmt VisitStmt(const Stmt& input_stmt) final { CHECK(input_stmt.defined()); auto stmt = StmtMutator::VisitStmt(input_stmt); const AttrStmtNode* op = stmt.as(); - if (op != nullptr && - op->attr_key == tir::attr::loop_scope) { - if (attach_spec_->attach_type == kScope && - op->node == attach_spec_->attach_ivar) { - CHECK(!found_attach) - << "Find IterVar" << attach_spec_->attach_ivar - << " in multiple places in the IR"; + if (op != nullptr && op->attr_key == tir::attr::loop_scope) { + if (attach_spec_->attach_type == kScope && op->node == attach_spec_->attach_ivar) { + CHECK(!found_attach) << "Find IterVar" << attach_spec_->attach_ivar + << " in multiple places in the IR"; found_attach = true; - stmt = AttrStmtNode::make( - op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); + stmt = + AttrStmtNode::make(op->node, op->attr_key, op->value, + MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); } } return stmt; @@ -113,27 +107,27 @@ class InjectAttach : public StmtMutator { // inject the operator's realization on the stmt. class InjectScanStep : public StmtMutator { public: - InjectScanStep(const Stage& stage, - const Operation& scan_op, - const std::unordered_map& dom_map, - bool is_init, + InjectScanStep(const Stage& stage, const Operation& scan_op, + const std::unordered_map& dom_map, bool is_init, bool debug_keep_trivial_loop) - : stage_(stage), scan_op_(scan_op), - dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} + : stage_(stage), + scan_op_(scan_op), + dom_map_(dom_map), + is_init_(is_init), + debug_keep_trivial_loop_(debug_keep_trivial_loop) {} Stmt VisitStmt(const Stmt& input_stmt) final { CHECK(input_stmt.defined()); auto stmt = StmtMutator::VisitStmt(input_stmt); // update const AttrStmtNode* op = stmt.as(); - if (op != nullptr && - ((op->attr_key == tir::attr::scan_update_scope && !is_init_) || - (op->attr_key == tir::attr::scan_init_scope && is_init_))) { + if (op != nullptr && ((op->attr_key == tir::attr::scan_update_scope && !is_init_) || + (op->attr_key == tir::attr::scan_init_scope && is_init_))) { if (op->node.same_as(scan_op_)) { found_attach = true; - stmt = AttrStmtNode::make( - op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); + stmt = + AttrStmtNode::make(op->node, op->attr_key, op->value, + MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); } } return stmt; @@ -169,8 +163,7 @@ class SchedulePostProc : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::loop_scope || - op->attr_key == tir::attr::scan_init_scope) { + if (op->attr_key == tir::attr::loop_scope || op->attr_key == tir::attr::scan_init_scope) { return this->VisitStmt(op->body); } else if (op->attr_key == tir::attr::scan_update_scope) { const ScanOpNode* scan = op->node.as(); @@ -194,8 +187,7 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - Stmt ret = AttrStmtNode::make( - it->second, op->attr_key, op->value, op->body); + Stmt ret = AttrStmtNode::make(it->second, op->attr_key, op->value, op->body); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); @@ -208,8 +200,8 @@ class SchedulePostProc : public StmtExprMutator { if (it != replace_op_.end()) { if (it->second.defined()) { return AttrStmtNode::make( - Array{tuple[0], it->second.output(tensor->value_index)}, - op->attr_key, op->value, this->VisitStmt(op->body)); + Array{tuple[0], it->second.output(tensor->value_index)}, op->attr_key, + op->value, this->VisitStmt(op->body)); } else { return this->VisitStmt(op->body); } @@ -219,9 +211,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(tensor->op.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - return AttrStmtNode::make( - it->second.output(tensor->value_index), - op->attr_key, op->value, this->VisitStmt(op->body)); + return AttrStmtNode::make(it->second.output(tensor->value_index), op->attr_key, op->value, + this->VisitStmt(op->body)); } else { return this->VisitStmt(op->body); } @@ -235,9 +226,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = RealizeNode::make( - it->second->op, it->second->value_index, - op->dtype, op->bounds, op->condition, op->body); + Stmt ret = RealizeNode::make(it->second->op, it->second->value_index, op->dtype, op->bounds, + op->condition, op->body); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); @@ -252,8 +242,7 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; - Stmt ret = ProvideNode::make( - dst->op, dst->value_index, op->value, op->args); + Stmt ret = ProvideNode::make(dst->op, dst->value_index, op->value, op->args); return this->VisitStmt(ret); } else { return StmtExprMutator::VisitStmt_(op); @@ -266,9 +255,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; - PrimExpr ret = CallNode::make( - op->dtype, dst->op->name, op->args, - op->call_type, dst->op, dst->value_index); + PrimExpr ret = CallNode::make(op->dtype, dst->op->name, op->args, op->call_type, dst->op, + dst->value_index); return this->VisitExpr(ret); } } @@ -299,8 +287,7 @@ class SchedulePostProc : public StmtExprMutator { if (!s->op.same_as(s->origin_op)) { for (int i = 0; i < s->op->num_outputs(); ++i) { Tensor target = s->origin_op.output(i); - AddReplace(s->op.output(i), target, - target, s->origin_op); + AddReplace(s->op.output(i), target, target, s->origin_op); } } // Specially add replacements for scan op. @@ -316,9 +303,7 @@ class SchedulePostProc : public StmtExprMutator { } private: - void AddReplace(Tensor src, - Tensor dst, - Tensor repl_realize = Tensor(), + void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(), Operation repl_op = Operation()) { TensorKey key{src->op, src->value_index}; replace_buffer_[key] = dst; @@ -339,8 +324,7 @@ class SchedulePostProc : public StmtExprMutator { arith::Analyzer analyzer_; }; -Stmt ScheduleOps( - Schedule sch, Map dom_map_, bool debug_keep_trivial_loop) { +Stmt ScheduleOps(Schedule sch, Map dom_map_, bool debug_keep_trivial_loop) { Stmt body = Stmt(); std::unordered_map dom_map = as_unordered_map(dom_map_); // scan init and scan updates @@ -350,8 +334,7 @@ Stmt ScheduleOps( if (!scan) continue; for (Tensor t : scan->init) { if (scan_init.count(t->op)) { - CHECK(scan_init.at(t->op).same_as(s->op)) - << "Scan init tensor can only belong to one scan"; + CHECK(scan_init.at(t->op).same_as(s->op)) << "Scan init tensor can only belong to one scan"; } else { scan_init[t->op] = s->op; } @@ -365,8 +348,7 @@ Stmt ScheduleOps( // reverse the post DFS order. for (size_t i = sch->stages.size(); i != 0; --i) { Stage s = sch->stages[i - 1]; - CHECK_NE(s->attach_type, kInline) - << "call schedule.normalize before scheduleops"; + CHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops"; CHECK(s->op.defined()); // no need to specify place holder op. if (s->op.as()) continue; @@ -377,15 +359,13 @@ Stmt ScheduleOps( CHECK(body.defined()); InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop); body = mu(std::move(body)); - CHECK(mu.found_attach) - << "did not find attachment point for scan.init"; + CHECK(mu.found_attach) << "did not find attachment point for scan.init"; } else if (attach_spec->attach_type == kScanUpdate) { // Handle scan update CHECK(body.defined()); InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop); body = mu(std::move(body)); - CHECK(mu.found_attach) - << "did not find attachment point for scan.update"; + CHECK(mu.found_attach) << "did not find attachment point for scan.update"; } else if (attach_spec->attach_type == kInlinedAlready) { // do nothing } else if (attach_spec->attach_type == kGroupRoot) { @@ -396,11 +376,10 @@ Stmt ScheduleOps( CHECK(body.defined()); InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop); body = mutator(std::move(body)); - CHECK(mutator.found_attach) - << "did not find attachment point for " << s << " in " - << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar - << ", body:\n" - << body; + CHECK(mutator.found_attach) << "did not find attachment point for " << s << " in " + << attach_spec->attach_stage->op << " x " + << attach_spec->attach_ivar << ", body:\n" + << body; } } SchedulePostProc post_proc; @@ -408,8 +387,7 @@ Stmt ScheduleOps( return post_proc(std::move(body)); } -TVM_REGISTER_GLOBAL("schedule.ScheduleOps") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("schedule.ScheduleOps").set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 2) *ret = ScheduleOps(args[0], args[1], false); else diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 2198827..84166d1 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -23,17 +23,19 @@ * \brief Rewrite the Stmt generated by ScheduleOps * to accomondate tensorcore. */ +#include #include +#include +#include +#include +#include #include +#include #include -#include #include -#include -#include -#include -#include -#include + #include + #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" @@ -41,10 +43,10 @@ namespace tvm { namespace te { using namespace te; +using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; -using intrinsic::tvm_address_of; struct Tile { int m{-1}; @@ -61,7 +63,7 @@ std::string simplify_name(std::string input) { } } -PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) { +PrimExpr unpack_type_cast(const PrimExpr& input, const DataType& target_type) { auto cast = input.as(); if (cast == nullptr) { return input; @@ -74,7 +76,7 @@ PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) { // MMAMatcher matches C = Cast(A)*Cast(B)+C, // where A & B are fp16/int8 local buffers, // and C is fp32/int32 local buffer. -class MMAMatcher: public StmtVisitor { +class MMAMatcher : public StmtVisitor { public: explicit MMAMatcher(Map extern_buffer) { for (auto kv : extern_buffer) { @@ -130,7 +132,7 @@ class MMAMatcher: public StmtVisitor { } } - inline bool Matched() const {return matched_;} + inline bool Matched() const { return matched_; } friend class ScheduleAnalyser; friend class BufferAnalyser; @@ -141,7 +143,7 @@ class MMAMatcher: public StmtVisitor { DataType dtype; bool external{false}; bool released{false}; - bool same_as(const BufferInfo &bi) { + bool same_as(const BufferInfo& bi) { if (this->dtype != bi.dtype) return false; if (this->name != bi.name) return false; if (this->external != bi.external) return false; @@ -183,10 +185,8 @@ class MMAMatcher: public StmtVisitor { auto* load_c = add->a.as(); BufferInfo buffer_c; - if (!check_local_buffer_(load_c, &buffer_c) - || !buffer_c.same_as(store_buffer) - || !(buffer_c.dtype == DataType::Float(32) || - buffer_c.dtype == DataType::Int(32))) { + if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) || + !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) { return false; } @@ -198,26 +198,20 @@ class MMAMatcher: public StmtVisitor { auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype); auto load_a = load_a_expr.as(); BufferInfo buffer_a; - if (!check_local_buffer_(load_a, &buffer_a) - || !(buffer_a.dtype == DataType::Float(16) || - buffer_a.dtype == DataType::Int(8) || - buffer_a.dtype == DataType::UInt(8) || - buffer_a.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || - buffer_a.dtype == DataType::Int(1))) { + if (!check_local_buffer_(load_a, &buffer_a) || + !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) || + buffer_a.dtype == DataType::UInt(8) || buffer_a.dtype == DataType::Int(4) || + buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { return false; } auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype); auto load_b = load_b_expr.as(); BufferInfo buffer_b; - if (!check_local_buffer_(load_b, &buffer_b) - || !(buffer_b.dtype == DataType::Float(16) || - buffer_b.dtype == DataType::Int(8) || - buffer_b.dtype == DataType::UInt(8) || - buffer_b.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || - buffer_a.dtype == DataType::Int(1))) { + if (!check_local_buffer_(load_b, &buffer_b) || + !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) || + buffer_b.dtype == DataType::UInt(8) || buffer_b.dtype == DataType::Int(4) || + buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { return false; } @@ -226,8 +220,7 @@ class MMAMatcher: public StmtVisitor { frag_reg_.insert(buffer_b.name); buf_name_.insert(std::make_pair(load_a, buffer_a.name)); buf_name_.insert(std::make_pair(load_b, buffer_b.name)); - mma_sync_.insert(std::make_pair(op, - Array{load_a_expr, load_b_expr, add->a})); + mma_sync_.insert(std::make_pair(op, Array{load_a_expr, load_b_expr, add->a})); return true; } @@ -280,9 +273,8 @@ class BodyVisitor : public StmtExprVisitor { // ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major class ScheduleAnalyser { public: - explicit ScheduleAnalyser(const MMAMatcher &mma_matcher) - : mma_sync_(mma_matcher.mma_sync_), - buf_name_(mma_matcher.buf_name_) {} + explicit ScheduleAnalyser(const MMAMatcher& mma_matcher) + : mma_sync_(mma_matcher.mma_sync_), buf_name_(mma_matcher.buf_name_) {} bool MatrixIdentify(Schedule schedule) { // TODO(minmin): handle the case where MatMul is not the output stage @@ -299,8 +291,8 @@ class ScheduleAnalyser { } const VarNode* axis_var[2]; const VarNode* reduce_axis_var; - axis_var[0] = axis[axis.size()-2]->var.as(); - axis_var[1] = axis[axis.size()-1]->var.as(); + axis_var[0] = axis[axis.size() - 2]->var.as(); + axis_var[1] = axis[axis.size() - 1]->var.as(); reduce_axis_var = reduce_axis[0]->var.as(); BodyVisitor body_visitor; @@ -342,8 +334,8 @@ class ScheduleAnalyser { matrix_major_.insert(std::make_pair(compute->name, "col_major")); } - for (auto &mma_sync : mma_sync_) { - auto &operands = mma_sync.second; + for (auto& mma_sync : mma_sync_) { + auto& operands = mma_sync.second; auto* load_a = operands[0].as(); auto* load_b = operands[1].as(); auto input0 = simplify_name(buf_name_.find(load_a)->second); @@ -398,8 +390,7 @@ class IndexVisitor : public StmtExprVisitor { class BufferAnalyser : public StmtExprVisitor { public: explicit BufferAnalyser(Map extern_buffer, - const ScheduleAnalyser &schedule_analyser, - const MMAMatcher &mma_matcher) + const ScheduleAnalyser& schedule_analyser, const MMAMatcher& mma_matcher) : matrix_abc_(schedule_analyser.matrix_abc_), matrix_major_(schedule_analyser.matrix_major_), frag_reg_(mma_matcher.frag_reg_) { @@ -418,9 +409,7 @@ class BufferAnalyser : public StmtExprVisitor { if (op->attr_key == tir::attr::thread_extent) { if (const IntImmNode* value = op->value.as()) { thread_extent_.insert( - std::make_pair( - op->node.as()->var->name_hint, - value->value)); + std::make_pair(op->node.as()->var->name_hint, value->value)); } StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == tir::attr::realize_scope) { @@ -447,11 +436,9 @@ class BufferAnalyser : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f; const BufferInfo& bi = it->second; - CHECK(!bi.released) - << "Read a buffer that is already out of scope"; + CHECK(!bi.released) << "Read a buffer that is already out of scope"; if (matrix_abc_.count(key.GetName())) { if (bi.shape.size() < 2) { @@ -483,12 +470,7 @@ class BufferAnalyser : public StmtExprVisitor { strides_.insert(std::make_pair(key.GetName(), strides)); if (frag_reg_.count(bi.name)) { - PrimExpr dst = CallNode::make(bi.dtype, - bi.name, - op->args, - CallNode::Halide, - op->func, - 0); + PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0); frag_load_.insert(std::make_pair(op, dst)); auto rel_index = bi.RelIndex(op->args); @@ -545,12 +527,7 @@ class BufferAnalyser : public StmtExprVisitor { const CallNode* value = op->value.as(); if (value != nullptr && frag_reg_.count(value->name)) { - PrimExpr dst = CallNode::make(bi.dtype, - bi.name, - op->args, - CallNode::Halide, - op->func, - 0); + PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0); frag_store_.insert(std::make_pair(op, dst)); } } @@ -560,11 +537,9 @@ class BufferAnalyser : public StmtExprVisitor { if (op->call_type == CallNode::Halide) { TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f; const BufferInfo& bi = it->second; - CHECK(!bi.released) - << "Read a buffer that is already out of scope"; + CHECK(!bi.released) << "Read a buffer that is already out of scope"; if (matrix_abc_.count(op->name)) { if (bi.shape.size() < 2) { @@ -642,8 +617,7 @@ class BufferAnalyser : public StmtExprVisitor { if (dim < avec.size() && avec[dim].align_factor != 0) { PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + \ - indexmod(factor + offset - indexmod(stride, factor), factor); + stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); stride = analyzer_.Simplify(stride); } rstrides.push_back(stride); @@ -730,29 +704,19 @@ class BufferAnalyser : public StmtExprVisitor { } bool supported_warp_tile_() { - if (warp_tile_.m == 16 && - warp_tile_.n == 16 && - warp_tile_.k == 16) { + if (warp_tile_.m == 16 && warp_tile_.n == 16 && warp_tile_.k == 16) { return true; } - if (warp_tile_.m == 8 && - warp_tile_.n == 32 && - warp_tile_.k == 16) { + if (warp_tile_.m == 8 && warp_tile_.n == 32 && warp_tile_.k == 16) { return true; } - if (warp_tile_.m == 32 && - warp_tile_.n == 8 && - warp_tile_.k == 16) { + if (warp_tile_.m == 32 && warp_tile_.n == 8 && warp_tile_.k == 16) { return true; } - if (warp_tile_.m == 8 && - warp_tile_.n == 8 && - warp_tile_.k == 32) { + if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 32) { return true; } - if (warp_tile_.m == 8 && - warp_tile_.n == 8 && - warp_tile_.k == 128) { + if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 128) { return true; } @@ -760,7 +724,7 @@ class BufferAnalyser : public StmtExprVisitor { } std::unordered_map buf_map_; - std::unordered_map > dim_align_; + std::unordered_map> dim_align_; std::unordered_map storage_scope_; std::unordered_map matrix_abc_; std::unordered_map matrix_major_; @@ -780,7 +744,7 @@ class BufferAnalyser : public StmtExprVisitor { // ThreadIdxMutator does the thread index unification inside a warp class ThreadIdxMutator : public StmtExprMutator { public: - explicit ThreadIdxMutator(PrimExpr warp_y): warp_y_(warp_y) {} + explicit ThreadIdxMutator(PrimExpr warp_y) : warp_y_(warp_y) {} PrimExpr VisitExpr_(const VarNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); @@ -807,18 +771,18 @@ class ThreadIdxMutator : public StmtExprMutator { // based on tensor core intrinsics class TensorCoreIRMutator : public StmtExprMutator { public: - explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser, - const BufferAnalyser &buffer_analyser) + explicit TensorCoreIRMutator(const ScheduleAnalyser& schedule_analyser, + const BufferAnalyser& buffer_analyser) : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - mma_sync_(schedule_analyser.mma_sync_), - strides_(buffer_analyser.strides_), - frag_reg_(buffer_analyser.frag_reg_), - loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), - frag_load_(buffer_analyser.frag_load_), - frag_store_(buffer_analyser.frag_store_), - warp_tile_(buffer_analyser.warp_tile_), - warp_threads_y_(buffer_analyser.warp_threads_y_) {} + matrix_major_(schedule_analyser.matrix_major_), + mma_sync_(schedule_analyser.mma_sync_), + strides_(buffer_analyser.strides_), + frag_reg_(buffer_analyser.frag_reg_), + loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), + frag_load_(buffer_analyser.frag_load_), + frag_store_(buffer_analyser.frag_store_), + warp_tile_(buffer_analyser.warp_tile_), + warp_threads_y_(buffer_analyser.warp_threads_y_) {} Stmt VisitStmt_(const RealizeNode* op) final { TensorKey key{op->func, op->value_index}; @@ -836,16 +800,14 @@ class TensorCoreIRMutator : public StmtExprMutator { for (size_t i = 0; i < op->bounds.size() - 2; ++i) { new_bounds.push_back(op->bounds[i]); } - CHECK_GE(op->bounds.size(), 2) - << "Less than 2 dimensions for matrix " << key.GetName(); - new_bounds.push_back(Range::make_by_min_extent( - op->bounds[op->bounds.size() - 2]->min, new_extents[0])); - new_bounds.push_back(Range::make_by_min_extent( - op->bounds[op->bounds.size() - 1]->min, new_extents[1])); - - return RealizeNode::make(op->func, op->value_index, - op->dtype, new_bounds, - op->condition, op->body); + CHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key.GetName(); + new_bounds.push_back( + Range::make_by_min_extent(op->bounds[op->bounds.size() - 2]->min, new_extents[0])); + new_bounds.push_back( + Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, new_extents[1])); + + return RealizeNode::make(op->func, op->value_index, op->dtype, new_bounds, op->condition, + op->body); } return stmt; } @@ -860,14 +822,10 @@ class TensorCoreIRMutator : public StmtExprMutator { } auto it = matrix_abc_.find(simplify_name(node->name)); - CHECK(it != matrix_abc_.end()) - << "Cannot find matrix info for " << node->name; + CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second); Stmt body = this->VisitStmt(op->body); - return AttrStmtNode::make(op->node, - op->attr_key, - matrix_abc, - body); + return AttrStmtNode::make(op->node, op->attr_key, matrix_abc, body); } } return stmt; @@ -877,7 +835,7 @@ class TensorCoreIRMutator : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); auto it = mma_sync_.find(op); if (it != mma_sync_.end()) { - const auto &operands = it->second; + const auto& operands = it->second; PrimExpr a = operands[0]; auto ca = a.as(); PrimExpr b = operands[1]; @@ -889,97 +847,75 @@ class TensorCoreIRMutator : public StmtExprMutator { ObjectPtr buffer_node_b = make_object(); ObjectPtr buffer_node_c = make_object(); - auto mma_sync_call = - [&buffer_node_a, &buffer_node_b, &ca, &cb] - (const Buffer &buffer) { - Buffer buffer_a(buffer_node_a); - Buffer buffer_b(buffer_node_b); - if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_bmma_sync, - {buffer->data, buffer->elem_offset, - buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, - buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); - } else { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_mma_sync, - {buffer->data, buffer->elem_offset, - buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, - buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); - } - }; + auto mma_sync_call = [&buffer_node_a, &buffer_node_b, &ca, &cb](const Buffer& buffer) { + Buffer buffer_a(buffer_node_a); + Buffer buffer_b(buffer_node_b); + if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { + return EvaluateNode::make(CallNode::make( + DataType::Handle(), intrinsic::tvm_bmma_sync, + {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, + CallNode::Intrinsic)); + } else { + return EvaluateNode::make(CallNode::make( + DataType::Handle(), intrinsic::tvm_mma_sync, + {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, + CallNode::Intrinsic)); + } + }; - auto call_add_c = - [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer &buffer) { - return add_buffer_bind_scope_(cc, buffer_node_c, - TensorKey{cc->func, cc->value_index}, mma_sync_call, cc->dtype); - }; + auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) { + return add_buffer_bind_scope_(cc, buffer_node_c, TensorKey{cc->func, cc->value_index}, + mma_sync_call, cc->dtype); + }; - auto call_add_b = - [this, &cb, &buffer_node_b, &call_add_c](const Buffer &buffer) { - return add_buffer_bind_scope_(cb, buffer_node_b, - TensorKey{cb->func, cb->value_index}, call_add_c, cb->dtype); - }; + auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) { + return add_buffer_bind_scope_(cb, buffer_node_b, TensorKey{cb->func, cb->value_index}, + call_add_c, cb->dtype); + }; - return add_buffer_bind_scope_(ca, buffer_node_a, - TensorKey{ca->func, ca->value_index}, call_add_b, ca->dtype); + return add_buffer_bind_scope_(ca, buffer_node_a, TensorKey{ca->func, ca->value_index}, + call_add_b, ca->dtype); } auto it2 = frag_load_.find(op); if (it2 != frag_load_.end()) { PrimExpr dst = it2->second; - if (op->value.as() != nullptr || - op->value.as() != nullptr) { + if (op->value.as() != nullptr || op->value.as() != nullptr) { auto call = dst.as(); - auto fill_fragment_call = - [this, &op](const Buffer &buffer) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_fill_fragment, - {buffer->data, - warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, op->value}, - CallNode::Intrinsic)); - }; + auto fill_fragment_call = [this, &op](const Buffer& buffer) { + return EvaluateNode::make(CallNode::make(DataType::Handle(), intrinsic::tvm_fill_fragment, + {buffer->data, warp_tile_.m, warp_tile_.n, + warp_tile_.k, buffer->elem_offset, op->value}, + CallNode::Intrinsic)); + }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, - TensorKey{call->func, call->value_index}, + return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, fill_fragment_call, call->dtype); } const CallNode* value = op->value.as(); - CHECK(value != nullptr) - << "Can only load fragment from a buffer"; + CHECK(value != nullptr) << "Can only load fragment from a buffer"; auto it = strides_.find(value->name); - CHECK(it != strides_.end()) - << "Cannot find stride for " << value->name; + CHECK(it != strides_.end()) << "Cannot find stride for " << value->name; auto strides = it->second; CHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size()-2]; + PrimExpr stride = strides[strides.size() - 2]; // thread index unification inside a warp PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); - PrimExpr src = CallNode::make(value->dtype, - "&", - {mutated_value}, - CallNode::Extern); + PrimExpr src = CallNode::make(value->dtype, "&", {mutated_value}, CallNode::Extern); auto call = dst.as(); PrimExpr matrix_major; auto iter2 = matrix_major_.find(simplify_name(call->name)); - CHECK(iter2 != matrix_major_.end()) - << "Can not determine matrix major for " << call->name; + CHECK(iter2 != matrix_major_.end()) << "Can not determine matrix major for " << call->name; if (iter2->second == "col_major") { matrix_major = StringImmNode::make("col_major"); } else if (iter2->second == "row_major") { @@ -988,20 +924,16 @@ class TensorCoreIRMutator : public StmtExprMutator { LOG(FATAL) << "invalid matrix major for " << call->name; } - auto load_matrix_call = - [this, &src, &stride, &matrix_major](const Buffer &buffer) { + auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_load_matrix_sync, - {buffer->data, - warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major}, - CallNode::Intrinsic)); + CallNode::make(DataType::Handle(), intrinsic::tvm_load_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, src, stride, matrix_major}, + CallNode::Intrinsic)); }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, - TensorKey{op->func, op->value_index}, + return add_buffer_bind_scope_(call, buffer_node, TensorKey{op->func, op->value_index}, load_matrix_call, call->dtype); } @@ -1009,39 +941,30 @@ class TensorCoreIRMutator : public StmtExprMutator { if (it3 != frag_store_.end()) { TensorKey key{op->func, op->value_index}; auto it = strides_.find(key.GetName()); - CHECK(it != strides_.end()) - << "Cannot find stride for " << key.GetName(); + CHECK(it != strides_.end()) << "Cannot find stride for " << key.GetName(); auto strides = it->second; CHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size()-2]; + PrimExpr stride = strides[strides.size() - 2]; PrimExpr dst = it3->second; // thread index unification inside a warp PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); - dst = CallNode::make(DataType::Handle(), - "&", - {dst}, - CallNode::Extern); + dst = CallNode::make(DataType::Handle(), "&", {dst}, CallNode::Extern); auto call = op->value.as(); - auto store_matrix_call = - [this, &dst, &stride](const Buffer &buffer) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_store_matrix_sync, - {buffer->data, - warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, - StringImmNode::make("col_major")}, - CallNode::Intrinsic)); - }; + auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { + return EvaluateNode::make( + CallNode::make(DataType::Handle(), intrinsic::tvm_store_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, dst, stride, StringImmNode::make("col_major")}, + CallNode::Intrinsic)); + }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, - TensorKey{call->func, call->value_index}, + return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, store_matrix_call, call->dtype); } @@ -1056,54 +979,54 @@ class TensorCoreIRMutator : public StmtExprMutator { if (it != loop_scaling_.end()) { int scale_factor = it->second; int scaled_extent_value = 1; - if (const IntImmNode *ori_extent = op->extent.as()) { + if (const IntImmNode* ori_extent = op->extent.as()) { int ori_extent_value = ori_extent->value; scaled_extent_value = ori_extent_value / scale_factor; } PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type, - op->device_api, op->body); + stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, + op->body); } } return stmt; } private: - Array get_tile_size_(const std::string &name) { - auto it = matrix_abc_.find(name); - auto it2 = matrix_major_.find(name); - CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) - << "Cannot find matrix info for " << name; - PrimExpr size0 = make_const(DataType::Int(32), 16); - PrimExpr size1 = make_const(DataType::Int(32), 16); - if (it->second == "matrix_a" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.m); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.n); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_c") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - Array tile_size = {size0, size1}; - return tile_size; + Array get_tile_size_(const std::string& name) { + auto it = matrix_abc_.find(name); + auto it2 = matrix_major_.find(name); + CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) + << "Cannot find matrix info for " << name; + PrimExpr size0 = make_const(DataType::Int(32), 16); + PrimExpr size1 = make_const(DataType::Int(32), 16); + if (it->second == "matrix_a" && it2->second == "col_major") { + size0 = make_const(DataType::Int(32), warp_tile_.k); + size1 = make_const(DataType::Int(32), warp_tile_.m); + } + if (it->second == "matrix_a" && it2->second == "row_major") { + size0 = make_const(DataType::Int(32), warp_tile_.m); + size1 = make_const(DataType::Int(32), warp_tile_.k); + } + if (it->second == "matrix_b" && it2->second == "row_major") { + size0 = make_const(DataType::Int(32), warp_tile_.k); + size1 = make_const(DataType::Int(32), warp_tile_.n); + } + if (it->second == "matrix_b" && it2->second == "col_major") { + size0 = make_const(DataType::Int(32), warp_tile_.n); + size1 = make_const(DataType::Int(32), warp_tile_.k); + } + if (it->second == "matrix_c") { + size0 = make_const(DataType::Int(32), warp_tile_.n); + size1 = make_const(DataType::Int(32), warp_tile_.m); + } + Array tile_size = {size0, size1}; + return tile_size; } - Stmt add_buffer_bind_scope_(const CallNode* call, - const ObjectPtr &buffer_node, const TensorKey &key, - const std::function &call_back, - DataType datatype) { + Stmt add_buffer_bind_scope_(const CallNode* call, const ObjectPtr& buffer_node, + const TensorKey& key, + const std::function& call_back, + DataType datatype) { auto it = bounds_.find(key); CHECK(it != bounds_.end()); Array min_bound; @@ -1134,13 +1057,11 @@ class TensorCoreIRMutator : public StmtExprMutator { CHECK_EQ(call->args.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { elem_offset = AddNode::make( - elem_offset, MulNode::make( - strides[i], SubNode::make(call->args[i], min_bound[i]))); + elem_offset, MulNode::make(strides[i], SubNode::make(call->args[i], min_bound[i]))); } auto it2 = matrix_abc_.find(simplify_name(call->name)); - CHECK(it2 != matrix_abc_.end()) - << "Cannot find matrix info for " << call->name; + CHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << call->name; buffer_node->data = Var(call->name, DataType::Handle()); buffer_node->name = call->name; buffer_node->scope = "wmma." + it2->second; @@ -1164,15 +1085,10 @@ class TensorCoreIRMutator : public StmtExprMutator { args.push_back(call->args[i]); args.push_back(shape[i]); } - auto tuple = CallNode::make(DataType::Handle(), - intrinsic::tvm_tuple, - args, - CallNode::Intrinsic); + auto tuple = + CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); Array node = {buffer, tensor}; - return AttrStmtNode::make(node, - "buffer_bind_scope", - tuple, - call_back(buffer)); + return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer)); } std::unordered_map matrix_abc_; @@ -1189,10 +1105,8 @@ class TensorCoreIRMutator : public StmtExprMutator { int warp_threads_y_{-1}; }; -Stmt SchedulePostProcRewriteForTensorCore( - Stmt stmt, - Schedule schedule, - Map extern_buffer) { +Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, + Map extern_buffer) { // Check if current lower target is CUDA auto target = tvm::Target::Current(true); if (target.defined() && target->target_name != "cuda") { @@ -1217,8 +1131,7 @@ Stmt SchedulePostProcRewriteForTensorCore( return stmt; } - BufferAnalyser buffer_analyser(extern_buffer, - schedule_analyser, mma_matcher); + BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher); buffer_analyser(stmt); if (!buffer_analyser.QualifiedForTensorCore()) { return stmt; @@ -1228,12 +1141,9 @@ Stmt SchedulePostProcRewriteForTensorCore( } TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore") -.set_body_typed([](Stmt stmt, - Schedule schedule, - Map extern_buffer) { - return SchedulePostProcRewriteForTensorCore( - stmt, schedule, extern_buffer); -}); + .set_body_typed([](Stmt stmt, Schedule schedule, Map extern_buffer) { + return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer); + }); } // namespace te } // namespace tvm diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index bb52be4..57e5528 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -36,14 +36,15 @@ * - Add annotation of extern buffers using the buffer_map field * in the PrimFunc type. */ -#include #include +#include +#include #include #include #include -#include -#include + #include +#include namespace tvm { namespace te { @@ -62,8 +63,7 @@ Buffer CreateBufferFor(const Tensor& tensor) { class TensorToBufferMapper : public StmtExprMutator { public: explicit TensorToBufferMapper(std::unordered_map buffer_map) - : buffer_map_(buffer_map) { - } + : buffer_map_(buffer_map) {} Stmt VisitStmt_(const AttrStmtNode* op) final { auto ret = StmtExprMutator::VisitStmt_(op); @@ -76,22 +76,19 @@ class TensorToBufferMapper : public StmtExprMutator { Operation operation = Downcast(op->node); for (int i = operation->num_outputs(); i != 0; --i) { Buffer buffer = GetOrAllocBuffer(operation.output(i - 1)); - body = AttrStmtNode::make( - buffer, op->attr_key, op->value, body); + body = AttrStmtNode::make(buffer, op->attr_key, op->value, body); } return body; } else if (op->attr_key == tir::attr::buffer_bind_scope) { - Array tuple = Downcast >(op->node); + Array tuple = Downcast>(op->node); Tensor tensor = Downcast(tuple[1]); - return AttrStmtNode::make( - Array{tuple[0], GetOrAllocBuffer(tensor)}, - op->attr_key, op->value, op->body); - } else if (op->attr_key == tir::attr::buffer_dim_align|| + return AttrStmtNode::make(Array{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, + op->value, op->body); + } else if (op->attr_key == tir::attr::buffer_dim_align || op->attr_key == tir::attr::prefetch_scope) { Tensor tensor = Downcast(op->node); Buffer buffer = GetOrAllocBuffer(tensor); - return AttrStmtNode::make( - buffer, op->attr_key, op->value, op->body); + return AttrStmtNode::make(buffer, op->attr_key, op->value, op->body); } else { return ret; } @@ -131,9 +128,7 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor) { - return GetBuffer(tensor, true); - } + Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); } Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { auto it = buffer_map_.find(tensor); @@ -149,9 +144,7 @@ class TensorToBufferMapper : public StmtExprMutator { std::unordered_map buffer_map_; }; - -PrimFunc SchedulePostProcToPrimFunc(Array arg_list, - Stmt body, +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, Optional> extern_buffer_opt) { std::unordered_map extern_buffer; @@ -188,7 +181,7 @@ PrimFunc SchedulePostProcToPrimFunc(Array arg_list, } TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc") -.set_body_typed(SchedulePostProcToPrimFunc); + .set_body_typed(SchedulePostProcToPrimFunc); } // namespace te } // namespace tvm diff --git a/src/te/schedule/verify_compact_buffer.cc b/src/te/schedule/verify_compact_buffer.cc index 759adb9..0089c36 100644 --- a/src/te/schedule/verify_compact_buffer.cc +++ b/src/te/schedule/verify_compact_buffer.cc @@ -22,12 +22,12 @@ * \brief Verify if there was any compact buffer bound to a statement. */ #include +#include +#include #include #include #include #include -#include -#include #include @@ -57,8 +57,7 @@ bool VerifyCompactBuffer(const Stmt& stmt) { return verifier.Verify(stmt); } -TVM_REGISTER_GLOBAL("schedule.VerifyCompactBuffer") -.set_body_typed(VerifyCompactBuffer); +TVM_REGISTER_GLOBAL("schedule.VerifyCompactBuffer").set_body_typed(VerifyCompactBuffer); } // namespace te } // namespace tvm diff --git a/src/te/tensor.cc b/src/te/tensor.cc index cb14f6a..606797d 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -21,27 +21,24 @@ * \file tensor.cc */ #include -#include #include +#include #include + #include namespace tvm { namespace te { IterVar thread_axis(Range dom, std::string tag) { - return IterVarNode::make( - dom, Var(tag), kThreadIndex, tag); + return IterVarNode::make(dom, Var(tag), kThreadIndex, tag); } IterVar reduce_axis(Range dom, std::string name) { - return IterVarNode::make( - dom, Var(name), kCommReduce); + return IterVarNode::make(dom, Var(name), kCommReduce); } -Var var(std::string name_hint, DataType t) { - return Var(name_hint, t); -} +Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor PrimExpr Tensor::operator()(Array indices) const { @@ -52,13 +49,11 @@ PrimExpr Tensor::operator()(Array indices) const { PrimExpr Tensor::operator()(Array indices) const { using tir::CallNode; if (ndim() != 0) { - CHECK_EQ(ndim(), indices.size()) - << "Tensor dimension mismatch in read" - << "ndim = " << ndim() << ", indices.size=" << indices.size(); + CHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read" + << "ndim = " << ndim() << ", indices.size=" << indices.size(); } - auto n = CallNode::make( - (*this)->dtype, (*this)->op->name, indices, CallNode::Halide, - (*this)->op, (*this)->value_index); + auto n = CallNode::make((*this)->dtype, (*this)->op->name, indices, CallNode::Halide, (*this)->op, + (*this)->value_index); return n; } @@ -71,10 +66,7 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -Tensor TensorNode::make(Array shape, - DataType dtype, - Operation op, - int value_index) { +Tensor TensorNode::make(Array shape, DataType dtype, Operation op, int value_index) { auto n = make_object(); n->shape = std::move(shape); n->dtype = dtype; @@ -84,25 +76,18 @@ Tensor TensorNode::make(Array shape, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* t = static_cast(node.get()); - p->stream << "Tensor(shape=" << t->shape - << ", op.name=" << t->op->name << ')'; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* t = static_cast(node.get()); + p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; + }); TVM_REGISTER_NODE_TYPE(TensorNode); - // TensorIntrin -TensorIntrin TensorIntrinNode::make(std::string name, - Operation op, - Array inputs, - Array buffers, - Array scalar_params, - Stmt body, - Stmt reduce_init, - Stmt reduce_update) { +TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Array inputs, + Array buffers, Array scalar_params, Stmt body, + Stmt reduce_init, Stmt reduce_update) { auto n = make_object(); n->name = std::move(name); n->op = std::move(op); @@ -116,20 +101,17 @@ TensorIntrin TensorIntrinNode::make(std::string name, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(TensorIntrinNode); - // TensorIntrinCall -TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, - Array tensors, - Array regions, - Array reduce_axis, +TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array tensors, + Array regions, Array reduce_axis, Array scalar_inputs) { auto n = make_object(); n->intrin = std::move(intrin); @@ -141,40 +123,32 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* n = static_cast(node.get()); - p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* n = static_cast(node.get()); + p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; + }); TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); -TVM_REGISTER_GLOBAL("te.Tensor") -.set_body_typed(TensorNode::make); +TVM_REGISTER_GLOBAL("te.Tensor").set_body_typed(TensorNode::make); -TVM_REGISTER_GLOBAL("te.TensorIntrin") -.set_body_typed(TensorIntrinNode::make); +TVM_REGISTER_GLOBAL("te.TensorIntrin").set_body_typed(TensorIntrinNode::make); -TVM_REGISTER_GLOBAL("te.TensorIntrinCall") -.set_body_typed(TensorIntrinCallNode::make); +TVM_REGISTER_GLOBAL("te.TensorIntrinCall").set_body_typed(TensorIntrinCallNode::make); -TVM_REGISTER_GLOBAL("te.TensorEqual") -.set_body_method(&Tensor::operator==); +TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); -TVM_REGISTER_GLOBAL("te.TensorHash") -.set_body_typed([](Tensor tensor) -> int64_t { - return static_cast(std::hash()(tensor)); - }); +TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { + return static_cast(std::hash()(tensor)); +}); -TVM_REGISTER_GLOBAL("te.OpGetOutput") -.set_body_typed([](Operation op, int64_t output) { +TVM_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { return op.output(static_cast(output)); }); -TVM_REGISTER_GLOBAL("te.OpNumOutputs") -.set_body_method(&OperationNode::num_outputs); +TVM_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method(&OperationNode::num_outputs); -TVM_REGISTER_GLOBAL("te.OpInputTensors") -.set_body_method(&OperationNode::InputTensors); +TVM_REGISTER_GLOBAL("te.OpInputTensors").set_body_method(&OperationNode::InputTensors); } // namespace te } // namespace tvm diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 763e3eb..7eb8013 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -21,16 +21,15 @@ * \file tir/analysis/deep_equal.cc * \brief Deep equality checking. */ -#include #include +#include #include #include namespace tvm { namespace tir { -class DeepCmpSEqualHandler : - public SEqualReducer::Handler { +class DeepCmpSEqualHandler : public SEqualReducer::Handler { public: // use direct recursion. bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final { @@ -41,12 +40,9 @@ class DeepCmpSEqualHandler : return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, false)); } - ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { - return ObjectRef(nullptr); - } + ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); } - void MarkGraphNode() final { - } + void MarkGraphNode() final {} private: // reflection vtable @@ -67,9 +63,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { } TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") -.set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { - return ExprDeepEqual()(lhs, rhs); -}); + .set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { + return ExprDeepEqual()(lhs, rhs); + }); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/side_effect.cc b/src/tir/analysis/side_effect.cc index 10039d9..b5fb328 100644 --- a/src/tir/analysis/side_effect.cc +++ b/src/tir/analysis/side_effect.cc @@ -21,9 +21,9 @@ * \file side_effect.cc * \brief side effect analysis */ +#include #include #include -#include namespace tvm { namespace tir { @@ -37,7 +37,8 @@ class ExprSideEffect : public ExprVisitor { void VisitExpr_(const CallNode* op) final { if (!op->is_pure()) { - has_side_effect_ = true; return; + has_side_effect_ = true; + return; } else { ExprVisitor::VisitExpr_(op); } diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc index ffc7792..2a23329 100644 --- a/src/tir/analysis/var_touch.cc +++ b/src/tir/analysis/var_touch.cc @@ -21,27 +21,23 @@ * \file simple_analysis.cc * \brief Implementation of simple passes */ +#include #include #include -#include namespace tvm { namespace tir { class VarTouchVisitor : public ExprVisitor { public: - explicit VarTouchVisitor( - std::function var_set) - : var_set_(var_set) {} + explicit VarTouchVisitor(std::function var_set) : var_set_(var_set) {} void VisitExpr(const PrimExpr& e) final { if (use_var_) return; ExprVisitor::VisitExpr(e); } - void VisitExpr_(const VarNode* op) final { - Handle(op); - } + void VisitExpr_(const VarNode* op) final { Handle(op); } void VisitExpr_(const LoadNode* op) final { Handle(op->buffer_var.get()); @@ -58,9 +54,7 @@ class VarTouchVisitor : public ExprVisitor { std::function var_set_; }; - -bool ExprUseVar(const PrimExpr& e, - std::function var_set) { +bool ExprUseVar(const PrimExpr& e, std::function var_set) { VarTouchVisitor visitor(var_set); visitor(e); return visitor.use_var_; diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 3dd1500..2ad20ff 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -35,12 +35,8 @@ namespace tir { class GPUCodeVerifier : public StmtVisitor { public: - bool Verify(Stmt stmt, - int64_t max_local_memory_per_block, - int64_t max_shared_memory_per_block, - int64_t max_threads_per_block, - int64_t max_thread_x, - int64_t max_thread_y, + bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, + int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); @@ -84,7 +80,7 @@ class GPUCodeVerifier : public StmtVisitor { } Var var = op->node.as()->var; - const auto *extent = op->value.as(); + const auto* extent = op->value.as(); CHECK(extent); // record the number of threads in a block @@ -136,8 +132,8 @@ class GPUCodeVerifier : public StmtVisitor { private: int nest_level_{0}; - std::unordered_set visited_local_buffers_; - std::unordered_set visited_shared_buffers_; + std::unordered_set visited_local_buffers_; + std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; @@ -164,8 +160,7 @@ class GPUCodeVerifier : public StmtVisitor { } }; -bool VerifyGPUCode(const PrimFunc& func, - Map constraints) { +bool VerifyGPUCode(const PrimFunc& func, Map constraints) { GPUCodeVerifier verifier; int64_t max_local_memory_per_block = INT64_MAX; @@ -193,18 +188,11 @@ bool VerifyGPUCode(const PrimFunc& func, LOG(FATAL) << "Invalid check item: " << iter.first; } - return verifier.Verify(func->body, - max_local_memory_per_block, - max_shared_memory_per_block, - max_threads_per_block, - max_thread_x, - max_thread_y, - max_thread_z); + return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, + max_threads_per_block, max_thread_x, max_thread_y, max_thread_z); } - -TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code") -.set_body_typed(VerifyGPUCode); +TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); namespace transform { @@ -213,9 +201,7 @@ Pass VerifyGPUCode(Map constraints) { for (auto kv : mod->functions) { if (auto* n = kv.second.as()) { auto func = GetRef(n); - CHECK(VerifyGPUCode(func, constraints)) - << "RuntimeError: GPU constraint violated" - << func; + CHECK(VerifyGPUCode(func, constraints)) << "RuntimeError: GPU constraint violated" << func; } } return mod; @@ -223,8 +209,7 @@ Pass VerifyGPUCode(Map constraints) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode") -.set_body_typed(VerifyGPUCode); +TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 03a3606..8eb846b 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -21,13 +21,12 @@ * \file verify_memory.cc * \brief Pass to check if memory accesses are legal. */ -#include #include +#include +#include #include +#include #include -#include -#include - namespace tvm { namespace tir { @@ -47,13 +46,12 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { public: /// Special member functions //@{ - explicit MemoryAccessVerifier(PrimFunc f, int device_type) - : func_(f), dev_type_(device_type) {} + explicit MemoryAccessVerifier(PrimFunc f, int device_type) : func_(f), dev_type_(device_type) {} virtual ~MemoryAccessVerifier() = default; - MemoryAccessVerifier(const MemoryAccessVerifier &) = delete; - MemoryAccessVerifier(MemoryAccessVerifier &&) = delete; - MemoryAccessVerifier &operator=(const MemoryAccessVerifier &) = delete; - MemoryAccessVerifier &operator=(MemoryAccessVerifier &&) = delete; + MemoryAccessVerifier(const MemoryAccessVerifier&) = delete; + MemoryAccessVerifier(MemoryAccessVerifier&&) = delete; + MemoryAccessVerifier& operator=(const MemoryAccessVerifier&) = delete; + MemoryAccessVerifier& operator=(MemoryAccessVerifier&&) = delete; //@} /// Interface to perform memory access verification @@ -68,12 +66,12 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { protected: /// Visitor implementation //@{ - void VisitExpr(const PrimExpr &n) final { + void VisitExpr(const PrimExpr& n) final { if (Failed()) return; StmtExprVisitor::VisitExpr(n); } - void VisitStmt(const Stmt &n) final { + void VisitStmt(const Stmt& n) final { if (Failed()) return; StmtExprVisitor::VisitStmt(n); } @@ -85,8 +83,8 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (!InThreadEnv() && (op->attr_key == attr::thread_extent || - op->attr_key == attr::pipeline_exec_scope)) { + if (!InThreadEnv() && + (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) { EnterThreadEnv(); StmtExprVisitor::VisitStmt_(op); ExitThreadEnv(); @@ -107,8 +105,8 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { //@} /// Check if the value of a Variable comes from function argument. - bool IsFromFunctionArgs(const VarNode *var) const { - const VarNode *V = var; + bool IsFromFunctionArgs(const VarNode* var) const { + const VarNode* V = var; for (auto kv : func_->buffer_map) { if (V == kv.second->data.get()) return true; } @@ -119,9 +117,9 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { // The value is expected to come from a tvm_struct_get Call. // Get the first argument of tvm_struct_get, and continue. - const auto &iter = defs_.find(V); + const auto& iter = defs_.find(V); if (iter == defs_.end()) return false; - const CallNode *C = iter->second.as(); + const CallNode* C = iter->second.as(); if (!C || C->name != intrinsic::tvm_struct_get) return false; V = C->args[0].as(); } @@ -129,7 +127,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } /// Handle memory access to a Variable - void HandleLoadStoreToVariable(const Var &var) { + void HandleLoadStoreToVariable(const Var& var) { // We skip the access within thread env. if (InThreadEnv()) return; @@ -153,14 +151,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device. static bool IsGPUDevice(int dev_type) { - return kDLGPU == dev_type || kDLOpenCL == dev_type || - kDLVulkan == dev_type || kDLMetal == dev_type || - kDLROCM == dev_type || kOpenGL == dev_type; + return kDLGPU == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type || + kDLMetal == dev_type || kDLROCM == dev_type || kOpenGL == dev_type; } /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device. - static bool IsFPGADevice(int dev_type) { - return kDLSDAccel == dev_type || kDLAOCL == dev_type; - } + static bool IsFPGADevice(int dev_type) { return kDLSDAccel == dev_type || kDLAOCL == dev_type; } private: /// Status of visitor @@ -168,21 +163,19 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { bool in_thread_env_{false}; bool failure_{false}; ///< If the verification fails (i.e. has illegal access) //@} - tir::PrimFunc func_{nullptr}; ///< Function to be verified. - int dev_type_{kDLCPU}; ///< Device type - std::unordered_map defs_; ///< Variable definitions + tir::PrimFunc func_{nullptr}; ///< Function to be verified. + int dev_type_{kDLCPU}; ///< Device type + std::unordered_map defs_; ///< Variable definitions }; } // namespace /// Interface of VerifyMemory pass bool VerifyMemory(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerWarpMemory: Require the target attribute"; + CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - if (func->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDefault) { + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->device_type); v.Run(); return !v.Failed(); @@ -191,29 +184,28 @@ bool VerifyMemory(const PrimFunc& func) { } } -TVM_REGISTER_GLOBAL("tir.analysis.verify_memory") -.set_body_typed(VerifyMemory); +TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); namespace transform { Pass VerifyMemory() { - auto pass_func = [=](IRModule mod, PassContext ctx) { - for (auto kv : mod->functions) { - if (auto* n = kv.second.as()) { - auto func = GetRef(n); - CHECK(VerifyMemory(func)) - << "RuntimeError: Direct host side access to device memory is detected." - << " Did you forget to bind?\n" - << func; - } - } - return mod; - }; + auto pass_func = + [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + auto func = GetRef(n); + CHECK(VerifyMemory(func)) + << "RuntimeError: Direct host side access to device memory is detected." + << " Did you forget to bind?\n" + << func; + } + } + return mod; + }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory") -.set_body_typed(VerifyMemory); +TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index 97eaf24..c57cbf7 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -24,11 +24,12 @@ * \file verify_ssa.cc */ #include +#include #include #include -#include -#include + #include +#include #include namespace tvm { @@ -101,7 +102,8 @@ class IRVerifySSA final : public StmtExprVisitor { void MarkDef(const VarNode* v, bool allow_dup = false) { if (defined_.count(v) != 0) { if (!allow_dup) { - is_ssa = false; return; + is_ssa = false; + return; } } else { defined_[v] = 1; @@ -112,16 +114,13 @@ class IRVerifySSA final : public StmtExprVisitor { std::unordered_map defined_; }; - bool VerifySSA(const PrimFunc& func) { IRVerifySSA visitor; visitor.Run(func); return visitor.is_ssa; } -TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa") -.set_body_typed(VerifySSA); - +TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA); namespace transform { @@ -130,9 +129,7 @@ Pass VerifySSA() { for (auto kv : mod->functions) { if (auto* n = kv.second.as()) { auto func = GetRef(n); - CHECK(VerifySSA(func)) - << "RuntimeError: IR is not in SSA form" - << func; + CHECK(VerifySSA(func)) << "RuntimeError: IR is not in SSA form" << func; } } return mod; @@ -140,8 +137,7 @@ Pass VerifySSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifySSA") -.set_body_typed(VerifySSA); +TVM_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA); } // namespace transform diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 0f1c572..45b9680 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -20,15 +20,16 @@ /*! * \file buffer.cc */ +#include +#include #include +#include #include -#include #include -#include -#include #include #include + #include "../../arith/compute_expr.h" namespace tvm { @@ -44,23 +45,13 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { return array; } -Buffer decl_buffer(Array shape, - DataType dtype, - std::string name) { - return BufferNode::make( - Var(name, PointerType(PrimType(dtype))), - dtype, - shape, - Array(), - PrimExpr(), - name, - "", - 0, 0, - kDefault); +Buffer decl_buffer(Array shape, DataType dtype, std::string name) { + return BufferNode::make(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), + PrimExpr(), name, "", 0, 0, kDefault); } // Split the given expression w.r.t the add operator -inline std::vector ExprSplitAddition(const PrimExpr &expr) { +inline std::vector ExprSplitAddition(const PrimExpr& expr) { using namespace tir; std::vector ret; std::stack split_buffer; @@ -79,7 +70,6 @@ inline std::vector ExprSplitAddition(const PrimExpr &expr) { return ret; } - // Searches for the following types of expr: // mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki // mod_l_expr = c @@ -87,9 +77,9 @@ inline std::vector ExprSplitAddition(const PrimExpr &expr) { // If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c) // Currently the we will not search the add/mult combinations exhaustively // as it will take too much computation. -inline std::pair MergeMulModInner(const PrimExpr &mult_expr, - const PrimExpr &mod_l_expr, - const PrimExpr &mod_r_expr) { +inline std::pair MergeMulModInner(const PrimExpr& mult_expr, + const PrimExpr& mod_l_expr, + const PrimExpr& mod_r_expr) { using namespace tir; const MulNode* mult_ptr = mult_expr.as(); if (!mult_ptr) return std::make_pair(false, PrimExpr()); @@ -124,9 +114,8 @@ inline std::pair MergeMulModInner(const PrimExpr &mult_expr, return std::make_pair(false, PrimExpr()); } else if (inner_div_ptr) { PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer; - if (expr_equal(overall_mult, inner_div_ptr->b) - && expr_equal(overall_mult, mod_r_expr) - && expr_equal(inner_div_ptr->a, mod_l_expr)) { + if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) && + expr_equal(inner_div_ptr->a, mod_l_expr)) { // Found! PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; return std::make_pair(true, ret); @@ -157,9 +146,7 @@ inline std::pair MergeMulModInner(const PrimExpr &mult_expr, inline void MergeMulModInsertElements(const std::vector& eles, std::list* mult_exprs, std::list >* mod_exprs, - PrimExpr* no_opt_sum, - bool* has_mult, - bool* has_mod) { + PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) { using namespace tir; *has_mult = false; *has_mod = false; @@ -185,7 +172,7 @@ inline void MergeMulModInsertElements(const std::vector& eles, // The search will be performed repeatively until no pattern is found. // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized -inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { +inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { using namespace tir; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and @@ -199,8 +186,7 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { PrimExpr no_opt_sum; bool has_mult; bool has_mod; - MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, - &no_opt_sum, &has_mult, &has_mod); + MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod); bool find_opt = false; std::list >::iterator search_mod_it = mod_exprs.begin(); // 2. Exhaustive Search @@ -208,9 +194,8 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { std::list::iterator mult_it = mult_exprs.begin(); bool inner_find_opt = false; while (mult_it != mult_exprs.end()) { - std::pair ret = MergeMulModInner(*mult_it, - search_mod_it->first, - search_mod_it->second); + std::pair ret = + MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second); if (ret.first) { inner_find_opt = true; auto temp_mod_it = search_mod_it; @@ -218,8 +203,8 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { mod_exprs.erase(temp_mod_it); mult_exprs.erase(mult_it); std::vector ret_eles = ExprSplitAddition(ret.second); - MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, - &no_opt_sum, &has_mult, &has_mod); + MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, + &has_mod); if (has_mult) { search_mod_it = mod_exprs.begin(); } else if (has_mod && search_mod_it == mod_exprs.end()) { @@ -242,9 +227,9 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it; } for (std::list >::iterator it = mod_exprs.begin(); - it != mod_exprs.end(); ++it) { - no_opt_sum = no_opt_sum.get() ? - no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second); + it != mod_exprs.end(); ++it) { + no_opt_sum = no_opt_sum.get() ? no_opt_sum + indexmod(it->first, it->second) + : indexmod(it->first, it->second); } return no_opt_sum; } @@ -300,20 +285,16 @@ inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataTyp PrimExpr Buffer::vload(Array begin, DataType dtype) const { // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); - CHECK(dtype.element_of() == n->dtype.element_of() && - dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot load " << dtype - << " from buffer of " << n->dtype; + CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot load " << dtype << " from buffer of " << n->dtype; if (dtype == DataType::Bool()) { return tir::CastNode::make( DataType::Bool(), - tir::LoadNode::make( - DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), - const_true())); + tir::LoadNode::make(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), + const_true())); } else { - return tir::LoadNode::make( - dtype, n->data, BufferOffset(n, begin, dtype), - const_true(dtype.lanes())); + return tir::LoadNode::make(dtype, n->data, BufferOffset(n, begin, dtype), + const_true(dtype.lanes())); } } @@ -321,18 +302,14 @@ Stmt Buffer::vstore(Array begin, PrimExpr value) const { // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); DataType dtype = value.dtype(); - CHECK(dtype.element_of() == n->dtype.element_of() && - dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot load " << dtype - << " from buffer of " << n->dtype; + CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot load " << dtype << " from buffer of " << n->dtype; if (value.dtype() == DataType::Bool()) { - return tir::StoreNode::make(n->data, - tir::CastNode::make(DataType::Int(8), value), - BufferOffset(n, begin, DataType::Int(8)), - const_true()); + return tir::StoreNode::make(n->data, tir::CastNode::make(DataType::Int(8), value), + BufferOffset(n, begin, DataType::Int(8)), const_true()); } else { return tir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype), - const_true(dtype.lanes())); + const_true(dtype.lanes())); } } @@ -342,7 +319,7 @@ Buffer Buffer::MakeStrideView() const { std::vector temp; auto n = make_object(*operator->()); PrimExpr acc = make_const(n->DefaultIndexType(), 1); - for (size_t i = n->shape.size(); i != 0 ; --i) { + for (size_t i = n->shape.size(); i != 0; --i) { temp.push_back(acc); acc = acc * n->shape[i - 1]; } @@ -364,8 +341,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const // check if stride is needed. for (size_t i = 0; i < extents.size(); ++i) { if (!can_relax) { - if (!is_zero(begins[i]) || - !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { + if (!is_zero(begins[i]) || !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { need_stride = true; } } @@ -376,21 +352,11 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return BufferNode::make(n->data, - n->dtype, - extents, - strides, - elem_offset, - n->name + "_slice", - n->scope, - n->data_alignment, - 0, - n->buffer_type); + return BufferNode::make(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", + n->scope, n->data_alignment, 0, n->buffer_type); } -PrimExpr Buffer::access_ptr(int access_mask, - DataType ptr_type, - int content_lanes, +PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset) const { const BufferNode* self = operator->(); PrimExpr e_dtype; @@ -407,28 +373,19 @@ PrimExpr Buffer::access_ptr(int access_mask, if (content_lanes > 1) { e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); extent = extent / make_const(self->elem_offset.dtype(), content_lanes); - elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), - content_lanes); + elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes); } else { e_dtype = tir::TypeAnnotation(self->dtype); } - Array acc_args{ - e_dtype, self->data, elem_offset, - extent, make_const(DataType::Int(32), access_mask)}; - return tir::CallNode::make( - ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); + Array acc_args{e_dtype, self->data, elem_offset, extent, + make_const(DataType::Int(32), access_mask)}; + return tir::CallNode::make(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, + tir::CallNode::Intrinsic); } -Buffer BufferNode::make(Var data, - DataType dtype, - Array shape, - Array strides, - PrimExpr elem_offset, - std::string name, - std::string scope, - int data_alignment, - int offset_factor, - BufferType buffer_type) { +Buffer BufferNode::make(Var data, DataType dtype, Array shape, Array strides, + PrimExpr elem_offset, std::string name, std::string scope, + int data_alignment, int offset_factor, BufferType buffer_type) { auto n = make_object(); n->data = std::move(data); n->dtype = dtype; @@ -461,31 +418,26 @@ Buffer BufferNode::make(Var data, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "buffer(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "buffer(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(BufferNode); +TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size(), 10); + auto buffer_type = args[9].operator std::string(); + BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; + *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], + args[8], type); +}); -TVM_REGISTER_GLOBAL("tir.Buffer") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args.size(), 10); - auto buffer_type = args[9].operator std::string(); - BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; - *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], - args[5], args[6], args[7], args[8], type); - }); - -TVM_REGISTER_GLOBAL("tir.BufferAccessPtr") -.set_body_method(&Buffer::access_ptr); +TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); -TVM_REGISTER_GLOBAL("tir.BufferVLoad") -.set_body_method(&Buffer::vload); +TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); -TVM_REGISTER_GLOBAL("tir.BufferVStore") -.set_body_method(&Buffer::vstore); +TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 77de9f4..23e13ed 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -21,46 +21,43 @@ * \file src/lang/data_layout.cc * \brief Data Layout expression. */ +#include #include #include #include -#include #include namespace tvm { namespace tir { -using tir::Var; using tir::IterVar; using tir::IterVarNode; +using tir::Var; TVM_REGISTER_NODE_TYPE(LayoutNode); TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode); const LayoutAxis LayoutAxis::UPPER_CASE[] = { - LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'), - LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'), - LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'), - LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'), - LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'), - LayoutAxis('Z') -}; + LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'), + LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'), + LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'), + LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'), + LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'), + LayoutAxis('Z')}; const LayoutAxis LayoutAxis::LOWER_CASE[] = { - LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'), - LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'), - LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'), - LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'), - LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'), - LayoutAxis('z') -}; + LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'), + LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'), + LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'), + LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'), + LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'), + LayoutAxis('z')}; const LayoutAxis& LayoutAxis::Get(const char name) { CHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z')) - << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z."; - return (name >= 'A' && name <= 'Z') ? - LayoutAxis::UPPER_CASE[name-'A'] : - LayoutAxis::LOWER_CASE[name-'a']; + << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z."; + return (name >= 'A' && name <= 'Z') ? LayoutAxis::UPPER_CASE[name - 'A'] + : LayoutAxis::LOWER_CASE[name - 'a']; } const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) { @@ -83,8 +80,8 @@ Layout::Layout(const Array& axes) { CHECK_GT(factor->value, 0); repr << factor->value; } - CHECK_EQ(axis->var.get()->name_hint.size(), 1) << "Invalid layout axis " - << axis->var.get()->name_hint; + CHECK_EQ(axis->var.get()->name_hint.size(), 1) + << "Invalid layout axis " << axis->var.get()->name_hint; char c = axis->var.get()->name_hint[0]; CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c; repr << axis->var.get()->name_hint; @@ -93,7 +90,7 @@ Layout::Layout(const Array& axes) { data_ = std::move(node); } -Layout::Layout(const std::string& name) { // NOLINT(*) +Layout::Layout(const std::string& name) { // NOLINT(*) if (name == "__undef__") return; auto node = make_object(); @@ -105,19 +102,18 @@ Layout::Layout(const std::string& name) { // NOLINT(*) int32_t factor = 0; for (char c : name) { if (c >= 'A' && c <= 'Z') { - CHECK_EQ(factor, 0) << "Invalid layout " << name - << ": invalid factor size " << factor + CHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " before dimension " << c; std::string shape_name("_shape"); shape_name.insert(0, 1, c); - IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), - Var(std::string(1, c)), tir::kDataPar); + IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), + tir::kDataPar); node->axes.push_back(axis); } else if (c >= 'a' && c <= 'z') { - CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " - << factor << " for dimension " << c; - IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), - Var(std::string(1, c)), tir::kDataPar); + CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor + << " for dimension " << c; + IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), + tir::kDataPar); node->axes.push_back(axis); factor = 0; } else if (c >= '0' && c <= '9') { @@ -141,16 +137,14 @@ Layout::Layout(const std::string& name) { // NOLINT(*) for (const IterVar& v : node->axes) { char axis = v->var.get()->name_hint[0]; if (axis >= 'a' && axis <= 'z') { - CHECK(exist_axis[axis-'a'+'A']) << "Invalid layout " << name << ": missing axis " - << std::toupper(axis); + CHECK(exist_axis[axis - 'a' + 'A']) + << "Invalid layout " << name << ": missing axis " << std::toupper(axis); } } data_ = std::move(node); } -Layout LayoutNode::make(const std::string& layout) { - return Layout(layout); -} +Layout LayoutNode::make(const std::string& layout) { return Layout(layout); } Layout Layout::SubLayout(size_t pos, size_t len) const { if (!defined() || pos > ndim()) return Layout::Undef(); @@ -164,16 +158,16 @@ Layout Layout::SubLayout(size_t pos, size_t len) const { return Layout(new_layout); } -Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const { +Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const { if (!defined()) return Layout::Undef(); const std::string& name = operator->()->name; const auto axes = operator->()->axes; - CHECK(target_pos <= this->ndim()) << "Invalid split position " - << target_pos << " for layout " << name; + CHECK(target_pos <= this->ndim()) + << "Invalid split position " << target_pos << " for layout " << name; CHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis; CHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name; - CHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis - << " has already been split in " << name; + CHECK(!this->Contains(axis.ToSubordinate())) + << "Axis " << axis << " has already been split in " << name; CHECK(factor > 0) << "Invalid split size " << factor; Array new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { @@ -202,16 +196,15 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* l = static_cast(node.get()); - p->stream << "Layout(" << l->name << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* l = static_cast(node.get()); + p->stream << "Layout(" << l->name << ")"; + }); -inline bool GetStoreRule(Array* rule, - const Layout& src_layout, +inline bool GetStoreRule(Array* rule, const Layout& src_layout, const Layout& dst_layout) { - if (!src_layout.defined() || src_layout.name().empty() || - !dst_layout.defined() || dst_layout.name().empty()) { + if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() || + dst_layout.name().empty()) { return false; } for (size_t i = 0; i < dst_layout.ndim(); ++i) { @@ -273,16 +266,15 @@ Array BijectiveLayout::ForwardIndex(const Array& src_index) CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); CHECK_EQ(src_index.size(), self->src_layout->axes.size()) - << "Input mismatch with layout " << self->src_layout; + << "Input mismatch with layout " << self->src_layout; return TransformIndex(src_index, self->src_layout->axes, self->forward_rule); } - Array BijectiveLayout::BackwardIndex(const Array& dst_index) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); CHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) - << "Output mismatch with layout " << self->dst_layout; + << "Output mismatch with layout " << self->dst_layout; return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule); } @@ -310,8 +302,8 @@ inline Array TransformShape(const Array& src_shape, const auto* orig_axis_extent = orig_axis->dom->extent.as(); if (orig_shape_const) { CHECK_EQ(orig_shape_const->value, orig_axis_extent->value) - << "Input shape mismatch at index " << i << ". Expected " - << orig_axis->dom->extent << ", get " << orig_shape; + << "Input shape mismatch at index " << i << ". Expected " << orig_axis->dom->extent + << ", get " << orig_shape; } } bind_map[orig_axis->var.get()] = PrimExpr(0); @@ -343,15 +335,13 @@ inline Array TransformShape(const Array& src_shape, Array BijectiveLayout::ForwardShape(const Array& shape) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->src_layout->axes, - self->dst_layout->axes, self->forward_rule); + return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->forward_rule); } Array BijectiveLayout::BackwardShape(const Array& shape) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->dst_layout->axes, - self->src_layout->axes, self->backward_rule); + return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->backward_rule); } BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { @@ -369,51 +359,47 @@ BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* b = static_cast(node.get()); - p->stream << "BijectiveLayout(" << b->src_layout.name() - << "->" << b->dst_layout.name() << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* b = static_cast(node.get()); + p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name() + << ")"; + }); -TVM_REGISTER_GLOBAL("tir.Layout") -.set_body_typed(LayoutNode::make); +TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed(LayoutNode::make); -TVM_REGISTER_GLOBAL("tir.LayoutIndexOf") -.set_body_typed([](Layout layout, std::string axis) -> int { +TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int { return layout.IndexOf(LayoutAxis::make(axis)); }); TVM_REGISTER_GLOBAL("tir.LayoutFactorOf") -.set_body_typed([](Layout layout, std::string axis) -> int { - return layout.FactorOf(LayoutAxis::make(axis)); -}); + .set_body_typed([](Layout layout, std::string axis) -> int { + return layout.FactorOf(LayoutAxis::make(axis)); + }); -TVM_REGISTER_GLOBAL("tir.LayoutNdim") -.set_body_typed([](Layout layout) -> int { +TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { return layout.ndim(); }); -TVM_REGISTER_GLOBAL("tir.LayoutGetItem") -.set_body_typed([](Layout layout, int idx) -> std::string { +TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string { const LayoutAxis& axis = layout[idx]; return axis.name(); }); TVM_REGISTER_GLOBAL("tir.BijectiveLayout") -.set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { - return BijectiveLayout(src_layout, dst_layout); -}); + .set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { + return BijectiveLayout(src_layout, dst_layout); + }); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex") -.set_body_method(&BijectiveLayout::ForwardIndex); + .set_body_method(&BijectiveLayout::ForwardIndex); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex") -.set_body_method(&BijectiveLayout::BackwardIndex); + .set_body_method(&BijectiveLayout::BackwardIndex); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape") -.set_body_method(&BijectiveLayout::ForwardShape); + .set_body_method(&BijectiveLayout::ForwardShape); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape") -.set_body_method(&BijectiveLayout::BackwardShape); + .set_body_method(&BijectiveLayout::BackwardShape); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index a36d81f..5694155 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -22,11 +22,12 @@ */ #include #include -#include #include +#include #include -#include + #include +#include #include "../../support/str_escape.h" @@ -68,9 +69,7 @@ SizeVar::SizeVar(std::string name_hint, DataType dtype) { data_ = std::move(n); } - -TVM_REGISTER_GLOBAL("tir.Var") -.set_body_typed([](std::string name_hint, runtime::TVMArgValue type) { +TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](std::string name_hint, runtime::TVMArgValue type) { if (type.IsObjectRef()) { return Var(name_hint, type.operator Type()); } else { @@ -78,16 +77,11 @@ TVM_REGISTER_GLOBAL("tir.Var") } }); -TVM_REGISTER_GLOBAL("tir.SizeVar") -.set_body_typed([](std::string s, DataType t) { - return SizeVar(s, t); +TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](std::string s, DataType t) { + return SizeVar(s, t); }); - -IterVar IterVarNode::make(Range dom, - Var var, - IterVarType t, - std::string thread_tag) { +IterVar IterVarNode::make(Range dom, Var var, IterVarType t, std::string thread_tag) { ObjectPtr n = make_object(); n->dom = dom; n->var = var; @@ -97,29 +91,25 @@ IterVar IterVarNode::make(Range dom, } TVM_REGISTER_GLOBAL("tir.IterVar") -.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) { - return IterVarNode::make( - dom, var, - static_cast(iter_type), - thread_tag); -}); + .set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) { + return IterVarNode::make(dom, var, static_cast(iter_type), thread_tag); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "iter_var("; - if (op->var->name_hint.length() != 0) { - p->stream << op->var->name_hint << ", "; - } - if (op->dom.defined()) { - p->stream << op->dom; - } - if (op->thread_tag.length() != 0) { - p->stream << ", " << op->thread_tag; - } - p->stream << ")"; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "iter_var("; + if (op->var->name_hint.length() != 0) { + p->stream << op->var->name_hint << ", "; + } + if (op->dom.defined()) { + p->stream << op->dom; + } + if (op->thread_tag.length() != 0) { + p->stream << ", " << op->thread_tag; + } + p->stream << ")"; + }); TVM_REGISTER_NODE_TYPE(IterVarNode); @@ -130,9 +120,7 @@ PrimExpr StringImmNode::make(std::string value) { return PrimExpr(node); } -TVM_REGISTER_GLOBAL("tir.StringImm") -.set_body_typed(StringImmNode::make); - +TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed(StringImmNode::make); PrimExpr CastNode::make(DataType t, PrimExpr value) { CHECK(value.defined()); @@ -143,7 +131,6 @@ PrimExpr CastNode::make(DataType t, PrimExpr value) { return PrimExpr(node); } - PrimExpr AndNode::make(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; @@ -172,7 +159,6 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) { return PrimExpr(node); } - PrimExpr NotNode::make(PrimExpr a) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.dtype().is_bool()); @@ -183,15 +169,12 @@ PrimExpr NotNode::make(PrimExpr a) { return PrimExpr(node); } - - PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { CHECK(condition.defined()) << "ValueError: condition is undefined"; CHECK(true_value.defined()) << "ValueError: true_value is undefined"; CHECK(false_value.defined()) << "ValueError: true_value is undefined"; CHECK(condition.dtype().is_bool()); - CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || - condition.dtype().lanes() == 1); + CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; ObjectPtr node = make_object(); @@ -259,11 +242,24 @@ PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) { return PrimExpr(node); } -const char* CallNode::vectorizable_intrinsics[] = { - "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt", - "log", "sin", "cos", "pow", "tan", tir::CallNode::shift_left, tir::CallNode::shift_right, - tir::CallNode::likely, tir::CallNode::popcount -}; +const char* CallNode::vectorizable_intrinsics[] = {"floor", + "ceil", + "sign", + "trunc", + "fabs", + "round", + "exp", + "tanh", + "sqrt", + "log", + "sin", + "cos", + "pow", + "tan", + tir::CallNode::shift_left, + tir::CallNode::shift_right, + tir::CallNode::likely, + tir::CallNode::popcount}; bool CallNode::is_vectorizable() const { size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*); @@ -275,12 +271,8 @@ bool CallNode::is_vectorizable() const { return false; } -PrimExpr CallNode::make(DataType dtype, - std::string name, - Array args, - CallType call_type, - FunctionRef func, - int value_index) { +PrimExpr CallNode::make(DataType dtype, std::string name, Array args, CallType call_type, + FunctionRef func, int value_index) { for (size_t i = 0; i < args.size(); ++i) { CHECK(args[i].defined()); } @@ -301,8 +293,7 @@ PrimExpr CallNode::make(DataType dtype, return PrimExpr(node); } -PrimExpr ShuffleNode::make(Array vectors, - Array indices) { +PrimExpr ShuffleNode::make(Array vectors, Array indices) { CHECK_NE(vectors.size(), 0U); CHECK_NE(indices.size(), 0U); @@ -341,9 +332,7 @@ PrimExpr ShuffleNode::make_extract_element(PrimExpr vector, int index) { return make({vector}, {Integer(index)}); } -CommReducer CommReducerNode::make(Array lhs, - Array rhs, - Array result, +CommReducer CommReducerNode::make(Array lhs, Array rhs, Array result, Array identity_element) { auto node = make_object(); node->lhs = lhs; @@ -363,24 +352,19 @@ Array CommReducerNode::operator()(Array a, Array b value_map.Set(rhs[i], b[i]); } auto ret = this->result; - ret.MutateByApply([&value_map] (const PrimExpr& e) { - return Substitute(e, value_map); - }); + ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); }); return ret; } -TVM_REGISTER_GLOBAL("tir.CommReducer") -.set_body_typed(CommReducerNode::make); +TVM_REGISTER_GLOBAL("tir.CommReducer").set_body_typed(CommReducerNode::make); TVM_REGISTER_GLOBAL("tir.CommReducerCombine") -.set_body_method(&tir::CommReducerNode::operator()); - + .set_body_method(&tir::CommReducerNode::operator()); -PrimExpr ReduceNode::make(CommReducer combiner, Array source, - Array axis, PrimExpr condition, int value_index) { +PrimExpr ReduceNode::make(CommReducer combiner, Array source, Array axis, + PrimExpr condition, int value_index) { for (size_t i = 0; i < axis.size(); ++i) { - CHECK_EQ(axis[i]->iter_type, kCommReduce) - << "Can only take axis created by reduce_axis"; + CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; } if (!condition.defined()) { condition = const_true(); @@ -399,10 +383,7 @@ PrimExpr ReduceNode::make(CommReducer combiner, Array source, return PrimExpr(n); } - -TVM_REGISTER_GLOBAL("tir.Reduce") -.set_body_typed(ReduceNode::make); - +TVM_REGISTER_GLOBAL("tir.Reduce").set_body_typed(ReduceNode::make); PrimExpr AnyNode::make() { auto n = make_object(); @@ -417,285 +398,277 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.BufferLoad") -.set_body_typed([](Buffer buffer, Array indices) { +TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array indices) { return BufferLoad(buffer, indices); }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '\"' << support::StrEscape(op->value) << '\"'; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '\"' << support::StrEscape(op->value) << '\"'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->dtype << '('; - p->Print(op->value); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - // omit the type - // stream << op->name << "." << op->type; - p->stream << op->name_hint; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " + "; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " - "; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << "*"; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << "/"; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " % "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "min("; - p->Print(op->a); - p->stream << ", "; - p->Print(op->b); - p->stream << ")"; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "max("; - p->Print(op->a); - p->stream << ", "; - p->Print(op->b); - p->stream << ")"; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " == "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " != "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " < "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " <= "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " > "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " >= "; - p->Print(op->b); - p->stream << ')'; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->dtype << '('; + p->Print(op->value); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + // omit the type + // stream << op->name << "." << op->type; + p->stream << op->name_hint; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " + "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " - "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << "*"; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << "/"; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " % "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "min("; + p->Print(op->a); + p->stream << ", "; + p->Print(op->b); + p->stream << ")"; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "max("; + p->Print(op->a); + p->stream << ", "; + p->Print(op->b); + p->stream << ")"; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " == "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " != "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " < "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " <= "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " > "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " >= "; + p->Print(op->b); + p->stream << ')'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "floordiv(" << op->a << ", " << op->b << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "floordiv(" << op->a << ", " << op->b << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "floormod(" << op->a << ", " << op->b << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "floormod(" << op->a << ", " << op->b << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " && "; - p->Print(op->b); - p->stream << ')'; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " && "; + p->Print(op->b); + p->stream << ')'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " || "; - p->Print(op->b); - p->stream << ')'; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " || "; + p->Print(op->b); + p->stream << ')'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '!'; - p->Print(op->a); -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '!'; + p->Print(op->a); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "select("; - p->Print(op->condition); - p->stream << ", "; - p->Print(op->true_value); - p->stream << ", "; - p->Print(op->false_value); - p->stream << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "select("; + p->Print(op->condition); + p->stream << ", "; + p->Print(op->true_value); + p->stream << ", "; + p->Print(op->false_value); + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->buffer_var << "["; - p->Print(op->index); - p->stream << "]"; - if (!is_one(op->predicate)) { + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer_var << "["; + p->Print(op->index); + p->stream << "]"; + if (!is_one(op->predicate)) { p->stream << " if "; p->Print(op->predicate); - } -}); + } + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "ramp("; - p->Print(op->base); - p->stream << ", "; - p->Print(op->stride); - p->stream << ", " << op->lanes << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "ramp("; + p->Print(op->base); + p->stream << ", "; + p->Print(op->stride); + p->stream << ", " << op->lanes << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "x" << op->lanes << "("; - p->Print(op->value); - p->stream << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "x" << op->lanes << "("; + p->Print(op->value); + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->name << "("; - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->name << "("; + for (size_t i = 0; i < op->args.size(); ++i) { + p->Print(op->args[i]); + if (i < op->args.size() - 1) { + p->stream << ", "; + } } - } - p->stream << ")"; - }); + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } } - } - p->stream << "]"; - }); + p->stream << "]"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "(let " << op->var << " = "; - p->Print(op->value); - p->stream << " in "; - p->Print(op->body); - p->stream << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "(let " << op->var << " = "; + p->Print(op->value); + p->stream << " in "; + p->Print(op->body); + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << "?"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "reduce(combiner=" - << op->combiner; - p->stream << ", source=" << op->source; - p->stream << ", axis=" << op->axis; - p->stream << ", where=" << op->condition; - p->stream << ", value_index=" << op->value_index; - p->stream << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "reduce(combiner=" << op->combiner; + p->stream << ", source=" << op->source; + p->stream << ", axis=" << op->axis; + p->stream << ", where=" << op->condition; + p->stream << ", value_index=" << op->value_index; + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "comm_reducer(result=" << op->result - << ", lhs=" << op->lhs - << ", rhs=" << op->rhs - << ", identity_element=" << op->identity_element - << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs + << ", rhs=" << op->rhs << ", identity_element=" << op->identity_element << ")"; + }); TVM_REGISTER_NODE_TYPE(StringImmNode); TVM_REGISTER_NODE_TYPE(CastNode); @@ -728,112 +701,78 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode); TVM_REGISTER_NODE_TYPE(ReduceNode); TVM_REGISTER_NODE_TYPE(AnyNode); +TVM_REGISTER_GLOBAL("tir.Add").set_body_typed(AddNode::make); -TVM_REGISTER_GLOBAL("tir.Add") -.set_body_typed(AddNode::make); +TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed(SubNode::make); -TVM_REGISTER_GLOBAL("tir.Sub") -.set_body_typed(SubNode::make); +TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed(MulNode::make); -TVM_REGISTER_GLOBAL("tir.Mul") -.set_body_typed(MulNode::make); +TVM_REGISTER_GLOBAL("tir.Div").set_body_typed(DivNode::make); -TVM_REGISTER_GLOBAL("tir.Div") -.set_body_typed(DivNode::make); +TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed(ModNode::make); -TVM_REGISTER_GLOBAL("tir.Mod") -.set_body_typed(ModNode::make); +TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed(FloorDivNode::make); -TVM_REGISTER_GLOBAL("tir.FloorDiv") -.set_body_typed(FloorDivNode::make); +TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed(FloorModNode::make); -TVM_REGISTER_GLOBAL("tir.FloorMod") -.set_body_typed(FloorModNode::make); +TVM_REGISTER_GLOBAL("tir.Min").set_body_typed(MinNode::make); -TVM_REGISTER_GLOBAL("tir.Min") -.set_body_typed(MinNode::make); +TVM_REGISTER_GLOBAL("tir.Max").set_body_typed(MaxNode::make); -TVM_REGISTER_GLOBAL("tir.Max") -.set_body_typed(MaxNode::make); +TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed(EQNode::make); -TVM_REGISTER_GLOBAL("tir.EQ") -.set_body_typed(EQNode::make); +TVM_REGISTER_GLOBAL("tir.NE").set_body_typed(NENode::make); -TVM_REGISTER_GLOBAL("tir.NE") -.set_body_typed(NENode::make); +TVM_REGISTER_GLOBAL("tir.LT").set_body_typed(LTNode::make); -TVM_REGISTER_GLOBAL("tir.LT") -.set_body_typed(LTNode::make); +TVM_REGISTER_GLOBAL("tir.LE").set_body_typed(LENode::make); -TVM_REGISTER_GLOBAL("tir.LE") -.set_body_typed(LENode::make); +TVM_REGISTER_GLOBAL("tir.GT").set_body_typed(GTNode::make); -TVM_REGISTER_GLOBAL("tir.GT") -.set_body_typed(GTNode::make); +TVM_REGISTER_GLOBAL("tir.GE").set_body_typed(GENode::make); -TVM_REGISTER_GLOBAL("tir.GE") -.set_body_typed(GENode::make); +TVM_REGISTER_GLOBAL("tir.And").set_body_typed(AndNode::make); -TVM_REGISTER_GLOBAL("tir.And") -.set_body_typed(AndNode::make); +TVM_REGISTER_GLOBAL("tir.Or").set_body_typed(OrNode::make); -TVM_REGISTER_GLOBAL("tir.Or") -.set_body_typed(OrNode::make); +TVM_REGISTER_GLOBAL("tir.Not").set_body_typed(NotNode::make); -TVM_REGISTER_GLOBAL("tir.Not") -.set_body_typed(NotNode::make); +TVM_REGISTER_GLOBAL("tir.Select").set_body_typed(SelectNode::make); -TVM_REGISTER_GLOBAL("tir.Select") -.set_body_typed(SelectNode::make); +TVM_REGISTER_GLOBAL("tir.Ramp").set_body_typed(RampNode::make); -TVM_REGISTER_GLOBAL("tir.Ramp") -.set_body_typed(RampNode::make); +TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed(CastNode::make); -TVM_REGISTER_GLOBAL("tir.Cast") -.set_body_typed(CastNode::make); +TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed(BroadcastNode::make); -TVM_REGISTER_GLOBAL("tir.Broadcast") -.set_body_typed(BroadcastNode::make); +TVM_REGISTER_GLOBAL("tir.Shuffle").set_body_typed(ShuffleNode::make); -TVM_REGISTER_GLOBAL("tir.Shuffle") -.set_body_typed(ShuffleNode::make); - -TVM_REGISTER_GLOBAL("tir.Let") -.set_body_typed(LetNode::make); - -TVM_REGISTER_GLOBAL("tir.Load") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DataType t = args[0]; - if (args.size() == 3) { - *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes())); - } else { - *ret = LoadNode::make(t, args[1], args[2], args[3]); - } - }); +TVM_REGISTER_GLOBAL("tir.Let").set_body_typed(LetNode::make); -TVM_REGISTER_GLOBAL("tir.Call") -.set_body_typed([]( - DataType type, std::string name, - Array args, int call_type, - FunctionRef func, int value_index -) { - Array prim_expr_args; - for (const auto& it : args) { - CHECK(it->IsInstance() || - it->IsInstance()); - if (const auto* str = it.as()) { - prim_expr_args.push_back(StringImmNode::make(str->data)); - } else { - prim_expr_args.push_back(Downcast(it)); - } +TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) { + DataType t = args[0]; + if (args.size() == 3) { + *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes())); + } else { + *ret = LoadNode::make(t, args[1], args[2], args[3]); } - return CallNode::make(type, - name, - prim_expr_args, - static_cast(call_type), - func, - value_index); }); +TVM_REGISTER_GLOBAL("tir.Call") + .set_body_typed([](DataType type, std::string name, Array args, int call_type, + FunctionRef func, int value_index) { + Array prim_expr_args; + for (const auto& it : args) { + CHECK(it->IsInstance() || it->IsInstance()); + if (const auto* str = it.as()) { + prim_expr_args.push_back(StringImmNode::make(str->data)); + } else { + prim_expr_args.push_back(Downcast(it)); + } + } + return CallNode::make(type, name, prim_expr_args, static_cast(call_type), + func, value_index); + }); + } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 57ff627..7f30abe 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -20,6 +20,7 @@ * \file expr_functor.cc */ #include + #include "functor_common.h" namespace tvm { @@ -49,10 +50,10 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); } -#define DEFINE_BINOP_VISIT_(OP) \ - void ExprVisitor::VisitExpr_(const OP* op) { \ - this->VisitExpr(op->a); \ - this->VisitExpr(op->b); \ +#define DEFINE_BINOP_VISIT_(OP) \ + void ExprVisitor::VisitExpr_(const OP* op) { \ + this->VisitExpr(op->a); \ + this->VisitExpr(op->b); \ } DEFINE_BINOP_VISIT_(AddNode); @@ -79,20 +80,16 @@ void ExprVisitor::VisitExpr_(const StringImmNode* op) {} void ExprVisitor::VisitExpr_(const ReduceNode* op) { VisitArray(op->axis, [this](const IterVar& r) { - this->VisitExpr(r->dom->min); - this->VisitExpr(r->dom->extent); - }); + this->VisitExpr(r->dom->min); + this->VisitExpr(r->dom->extent); + }); VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->condition); } -void ExprVisitor::VisitExpr_(const CastNode* op) { - this->VisitExpr(op->value); -} +void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::VisitExpr_(const NotNode* op) { - this->VisitExpr(op->a); -} +void ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); } void ExprVisitor::VisitExpr_(const SelectNode* op) { this->VisitExpr(op->condition); @@ -110,13 +107,9 @@ void ExprVisitor::VisitExpr_(const ShuffleNode* op) { VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); }); } -void ExprVisitor::VisitExpr_(const BroadcastNode* op) { - this->VisitExpr(op->value); -} +void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); } -PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { - return GetRef(op); -} +PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { return this->VisitExpr_(static_cast(op)); @@ -145,8 +138,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); @@ -160,34 +152,26 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { if (args.same_as(op->args)) { return GetRef(op); } else { - return CallNode::make(op->dtype, - op->name, - args, - op->call_type, - op->func, - op->value_index); + return CallNode::make(op->dtype, op->name, args, op->call_type, op->func, op->value_index); } } -#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP *op) { \ - return GetRef(op); \ - } +#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ + PrimExpr ExprMutator::VisitExpr_(const OP* op) { return GetRef(op); } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) -#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP* op) { \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return OP::make(a, b); \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ + PrimExpr ExprMutator::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return OP::make(a, b); \ + } \ } DEFINE_BIOP_EXPR_MUTATE_(AddNode); @@ -209,17 +193,15 @@ DEFINE_BIOP_EXPR_MUTATE_(AndNode); DEFINE_BIOP_EXPR_MUTATE_(OrNode); PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { - auto fitervar = [this](const IterVar& v) { + auto fitervar = [this](const IterVar& v) { Range r = v->dom; PrimExpr min = this->VisitExpr(r->min); PrimExpr extent = this->VisitExpr(r->extent); - if (min.same_as(r->min) && - extent.same_as(r->extent)) { + if (min.same_as(r->min) && extent.same_as(r->extent)) { return v; } else { - return IterVarNode::make( - Range::make_by_min_extent(min, extent), - v->var, v->iter_type, v->thread_tag); + return IterVarNode::make(Range::make_by_min_extent(min, extent), v->var, v->iter_type, + v->thread_tag); } }; Array axis = MutateArray(op->axis, fitervar); @@ -229,13 +211,10 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { PrimExpr condition = this->VisitExpr(op->condition); - if (axis.same_as(op->axis) && - source.same_as(op->source) && - condition.same_as(op->condition)) { + if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition)) { return GetRef(op); } else { - return ReduceNode::make( - op->combiner, source, axis, condition, op->value_index); + return ReduceNode::make(op->combiner, source, axis, condition, op->value_index); } } @@ -261,8 +240,7 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr true_value = this->VisitExpr(op->true_value); PrimExpr false_value = this->VisitExpr(op->false_value); - if (condition.same_as(op->condition) && - true_value.same_as(op->true_value) && + if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { return GetRef(op); } else { @@ -273,8 +251,7 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); - if (base.same_as(op->base) && - stride.same_as(op->stride)) { + if (base.same_as(op->base) && stride.same_as(op->stride)) { return GetRef(op); } else { return RampNode::make(base, stride, op->lanes); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index ecaad58..1149e03 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -29,11 +29,8 @@ namespace tvm { namespace tir { // Get the function type of a PrimFunc -PrimFunc::PrimFunc(Array params, - Stmt body, - Type ret_type, - Map buffer_map, - DictAttrs attrs) { +PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, + Map buffer_map, DictAttrs attrs) { // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. if (!ret_type.defined()) { @@ -60,29 +57,25 @@ FuncType PrimFuncNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(PrimFuncNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - // TODO(tvm-team) redirect to Text printer once we have a good text format. - auto* node = static_cast(ref.get()); - p->stream << "PrimFunc(" << node->params << ") "; - if (node->attrs.defined()) { - p->stream << "attrs=" << node->attrs; - } - p->stream << " {\n"; - p->indent += 2; - p->Print(node->body); - p->indent -= 2; - p->stream << "}\n"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + // TODO(tvm-team) redirect to Text printer once we have a good text format. + auto* node = static_cast(ref.get()); + p->stream << "PrimFunc(" << node->params << ") "; + if (node->attrs.defined()) { + p->stream << "attrs=" << node->attrs; + } + p->stream << " {\n"; + p->indent += 2; + p->Print(node->body); + p->indent -= 2; + p->stream << "}\n"; + }); TVM_REGISTER_GLOBAL("tir.PrimFunc") -.set_body_typed([](Array params, - Stmt body, - Type ret_type, - Map buffer_map, - DictAttrs attrs) { - return PrimFunc(params, body, ret_type, buffer_map, attrs); -}); + .set_body_typed([](Array params, Stmt body, Type ret_type, + Map buffer_map, DictAttrs attrs) { + return PrimFunc(params, body, ret_type, buffer_map, attrs); + }); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h index 76a91ea..f63dcfe 100644 --- a/src/tir/ir/functor_common.h +++ b/src/tir/ir/functor_common.h @@ -27,7 +27,7 @@ namespace tvm { namespace tir { // Implementation of Visitors -template +template inline void VisitArray(const Array& arr, F fvisit) { for (size_t i = 0; i < arr.size(); i++) { fvisit(arr[i]); @@ -35,10 +35,8 @@ inline void VisitArray(const Array& arr, F fvisit) { } // Implementation of mutators -template -inline Array MutateArray(const Array& arr, - F fmutate, - bool allow_copy_on_write = false) { +template +inline Array MutateArray(const Array& arr, F fmutate, bool allow_copy_on_write = false) { if (allow_copy_on_write) { // if we allow copy on write, we can directly // call the inplace mutate function. diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc index 6224321..2757c2f 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/ir/op.cc @@ -24,6 +24,7 @@ #include #include #include + #include // Centralized header for constant folders. #include "../../arith/const_fold.h" @@ -32,17 +33,15 @@ namespace tvm { using namespace tir; - runtime::DataType GetRuntimeDataType(const Type& type) { - if (auto * n = type.as()) { + if (auto* n = type.as()) { return n->dtype; } else if (type.as()) { return DataType::Handle(); } else if (IsVoidType(type)) { return DataType::Void(); } else { - LOG(FATAL) << "Type " << type - << " does not have a corresponding runtime::DataType"; + LOG(FATAL) << "Type " << type << " does not have a corresponding runtime::DataType"; return DataType::Handle(); } } @@ -74,8 +73,7 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { return tir::CallNode::make( t, tir::intrinsic::tvm_large_uint_imm, - {make_const(DataType::UInt(32), low), - make_const(DataType::UInt(32), high)}, + {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, tir::CallNode::PureIntrinsic); } @@ -89,8 +87,7 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { rhs = tir::BroadcastNode::make(rhs, ltype.lanes()); } else { - CHECK(ltype.lanes() == rtype.lanes()) - << "Cannot match type " << ltype << " vs " << rtype; + CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; } if (lhs.dtype() == rhs.dtype()) return; // Only do very simple type coversion @@ -197,8 +194,8 @@ PrimExpr infinity(const DataType& dtype) { } namespace tir { -template -inline bool ConstPowerHelper(ValueType val, int *shift) { +template +inline bool ConstPowerHelper(ValueType val, int* shift) { if (val <= 0) return false; shift[0] = 0; while (val != 0) { @@ -254,8 +251,7 @@ PrimExpr cast(const DataType& t, PrimExpr value) { PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return tir::CallNode::make( - t, tir::CallNode::reinterpret, { value }, tir::CallNode::PureIntrinsic); + return tir::CallNode::make(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic); } PrimExpr operator+(PrimExpr a, PrimExpr b) { @@ -267,8 +263,8 @@ PrimExpr operator+(PrimExpr a, PrimExpr b) { // negation PrimExpr operator-(PrimExpr a) { - using tir::IntImmNode; using tir::FloatImmNode; + using tir::IntImmNode; const IntImmNode* pa = a.as(); const FloatImmNode* fa = a.as(); if (pa) return IntImm(a.dtype(), -pa->value); @@ -310,22 +306,14 @@ PrimExpr truncmod(PrimExpr a, PrimExpr b) { return tir::ModNode::make(a, b); } -PrimExpr operator/(PrimExpr a, PrimExpr b) { - return div(a, b); -} +PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); } -PrimExpr operator%(PrimExpr a, PrimExpr b) { - return truncmod(a, b); -} +PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); } // TODO(tqchen): switch to floordiv -PrimExpr indexdiv(PrimExpr a, PrimExpr b) { - return floordiv(a, b); -} +PrimExpr indexdiv(PrimExpr a, PrimExpr b) { return floordiv(a, b); } -PrimExpr indexmod(PrimExpr a, PrimExpr b) { - return floormod(a, b); -} +PrimExpr indexmod(PrimExpr a, PrimExpr b) { return floormod(a, b); } PrimExpr floordiv(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; @@ -347,8 +335,8 @@ PrimExpr floormod(PrimExpr a, PrimExpr b) { PrimExpr min(PrimExpr a, PrimExpr b) { // inf-aware simplificaiton - using arith::is_pos_inf; using arith::is_neg_inf; + using arith::is_pos_inf; if (is_pos_inf(a)) return b; if (is_neg_inf(a)) return a; if (is_pos_inf(b)) return a; @@ -361,8 +349,8 @@ PrimExpr min(PrimExpr a, PrimExpr b) { PrimExpr max(PrimExpr a, PrimExpr b) { // inf-aware simplificaiton - using arith::is_pos_inf; using arith::is_neg_inf; + using arith::is_pos_inf; if (is_pos_inf(a)) return a; if (is_neg_inf(a)) return b; if (is_pos_inf(b)) return b; @@ -384,19 +372,14 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) return false_value; } } - return tir::CallNode::make( - true_value.dtype(), - tir::intrinsic::tvm_if_then_else, - {cond, true_value, false_value}, - tir::CallNode::PureIntrinsic); + return tir::CallNode::make(true_value.dtype(), tir::intrinsic::tvm_if_then_else, + {cond, true_value, false_value}, tir::CallNode::PureIntrinsic); } PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; - return tir::CallNode::make(cond.dtype(), - tir::CallNode::likely, - { cond }, - tir::CallNode::PureIntrinsic); + return tir::CallNode::make(cond.dtype(), tir::CallNode::likely, {cond}, + tir::CallNode::PureIntrinsic); } PrimExpr operator>(PrimExpr a, PrimExpr b) { @@ -469,17 +452,18 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) << - "Shift amount must be non-negative and less than " << rtype.bits() - << " for type " << rtype; - if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); - if (pb) { - if (pb->value == 0) return a; - } - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::shift_right, { a, b }, tir::CallNode::PureIntrinsic); + const DataType& rtype = a.dtype(); + if (pb) + CHECK(pb->value >= 0 && pb->value < rtype.bits()) + << "Shift amount must be non-negative and less than " << rtype.bits() << " for type " + << rtype; + if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); + if (pb) { + if (pb->value == 0) return a; + } + }); + return tir::CallNode::make(a.dtype(), tir::CallNode::shift_right, {a, b}, + tir::CallNode::PureIntrinsic); } PrimExpr operator<<(PrimExpr a, PrimExpr b) { @@ -487,17 +471,18 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) << - "Shift amount must be non-negative and less than " << rtype.bits() - << " for type " << rtype; - if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); - if (pb) { - if (pb->value == 0) return a; - } - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::shift_left, { a, b }, tir::CallNode::PureIntrinsic); + const DataType& rtype = a.dtype(); + if (pb) + CHECK(pb->value >= 0 && pb->value < rtype.bits()) + << "Shift amount must be non-negative and less than " << rtype.bits() << " for type " + << rtype; + if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); + if (pb) { + if (pb->value == 0) return a; + } + }); + return tir::CallNode::make(a.dtype(), tir::CallNode::shift_left, {a, b}, + tir::CallNode::PureIntrinsic); } PrimExpr operator&(PrimExpr a, PrimExpr b) { @@ -505,11 +490,11 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::bitwise_and, { a, b }, tir::CallNode::PureIntrinsic); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); + }); + return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_and, {a, b}, + tir::CallNode::PureIntrinsic); } PrimExpr operator|(PrimExpr a, PrimExpr b) { @@ -517,11 +502,11 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::bitwise_or, { a, b }, tir::CallNode::PureIntrinsic); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); + }); + return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_or, {a, b}, + tir::CallNode::PureIntrinsic); } PrimExpr operator^(PrimExpr a, PrimExpr b) { @@ -529,24 +514,23 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::bitwise_xor, { a, b }, tir::CallNode::PureIntrinsic); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); + }); + return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_xor, {a, b}, + tir::CallNode::PureIntrinsic); } PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); - return tir::CallNode::make( - a.dtype(), tir::CallNode::bitwise_not, { a }, tir::CallNode::PureIntrinsic); + return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_not, {a}, + tir::CallNode::PureIntrinsic); } PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; - return tir::CallNode::make( - x.dtype(), "pow", { x, y }, tir::CallNode::PureIntrinsic); + return tir::CallNode::make(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic); } PrimExpr abs(PrimExpr x) { @@ -568,7 +552,7 @@ PrimExpr abs(PrimExpr x) { return x; } else { LOG(FATAL) << "Data type " << x.dtype() - <<" not supported for absolute op. Skipping absolute op..."; + << " not supported for absolute op. Skipping absolute op..."; return x; } } @@ -585,14 +569,13 @@ PrimExpr isnan(PrimExpr x) { } if (x.dtype().bits() == 16) { return tir::CallNode::make(t, tir::CallNode::isnan, - {cast(DataType::Float(32, t.lanes()), std::move(x))}, - tir::CallNode::PureIntrinsic); + {cast(DataType::Float(32, t.lanes()), std::move(x))}, + tir::CallNode::PureIntrinsic); } else { return tir::CallNode::make(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); } } else { - LOG(FATAL) << "Data type " << x.dtype() - <<" not supported for isnan op. Skipping isnan op..."; + LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op..."; return x; } } @@ -616,8 +599,7 @@ PrimExpr sum(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::AddNode::make(x, y); PrimExpr identity_element = make_zero(source.dtype()); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } @@ -626,8 +608,7 @@ PrimExpr all(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::AndNode::make(x, y); PrimExpr identity_element = make_const(source.dtype(), true); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } @@ -636,8 +617,7 @@ PrimExpr any(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::OrNode::make(x, y); PrimExpr identity_element = make_const(source.dtype(), false); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } @@ -645,8 +625,7 @@ PrimExpr max(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::MaxNode::make(x, y); PrimExpr identity_element = min_value(source.dtype()); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } @@ -654,8 +633,7 @@ PrimExpr min(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::MinNode::make(x, y); PrimExpr identity_element = max_value(source.dtype()); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } @@ -663,15 +641,14 @@ PrimExpr prod(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::MulNode::make(x, y); PrimExpr identity_element = make_const(source.dtype(), 1); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; - return tir::CallNode::make(x.dtype(), "fmod", { x, y }, tir::CallNode::PureIntrinsic); + return tir::CallNode::make(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic); } PrimExpr floor(PrimExpr x) { @@ -721,91 +698,69 @@ PrimExpr trunc(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { - return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : - std::floor(fx->value))); + return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } return tir::CallNode::make(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); } - // expose basic functions to node namespace -TVM_REGISTER_GLOBAL("node._const") -.set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[0].type_code() == kDLInt) { - *ret = tir::make_const(args[1], args[0].operator int64_t()); - } else if (args[0].type_code() == kDLFloat) { - *ret = tir::make_const(args[1], args[0].operator double()); - } else { - LOG(FATAL) << "only accept int or float"; - } - }); - -TVM_REGISTER_GLOBAL("node.LargeUIntImm") -.set_body_typed(LargeUIntImm); - -TVM_REGISTER_GLOBAL("node.String") -.set_body_typed(tir::StringImmNode::make); +TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args[0].type_code() == kDLInt) { + *ret = tir::make_const(args[1], args[0].operator int64_t()); + } else if (args[0].type_code() == kDLFloat) { + *ret = tir::make_const(args[1], args[0].operator double()); + } else { + LOG(FATAL) << "only accept int or float"; + } +}); -TVM_REGISTER_GLOBAL("tir.min_value") -.set_body_typed(min_value); +TVM_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm); -TVM_REGISTER_GLOBAL("tir.max_value") -.set_body_typed(max_value); +TVM_REGISTER_GLOBAL("node.String").set_body_typed(tir::StringImmNode::make); -TVM_REGISTER_GLOBAL("tir.abs") -.set_body_typed(tvm::abs); +TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value); -TVM_REGISTER_GLOBAL("tir.isnan") -.set_body_typed(tvm::isnan); +TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value); -TVM_REGISTER_GLOBAL("tir.isfinite") -.set_body_typed(tvm::isfinite); +TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs); -TVM_REGISTER_GLOBAL("tir.isinf") -.set_body_typed(tvm::isinf); +TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan); -TVM_REGISTER_GLOBAL("tir.floor") -.set_body_typed(tvm::floor); +TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite); -TVM_REGISTER_GLOBAL("tir.ceil") -.set_body_typed(tvm::ceil); +TVM_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf); -TVM_REGISTER_GLOBAL("tir.round") -.set_body_typed(tvm::round); +TVM_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor); -TVM_REGISTER_GLOBAL("tir.nearbyint") -.set_body_typed(tvm::nearbyint); +TVM_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil); -TVM_REGISTER_GLOBAL("tir.trunc") -.set_body_typed(tvm::trunc); +TVM_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round); -TVM_REGISTER_GLOBAL("tir._cast") -.set_body_typed(tvm::cast); +TVM_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint); +TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); +TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); // operator overloading, smarter than make -#define REGISTER_MAKE_BINARY_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("tir."#Node) \ - .set_body_typed([](PrimExpr a, PrimExpr b) { \ - return (Func(a, b)); \ +#define REGISTER_MAKE_BINARY_OP(Node, Func) \ + TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b) { \ + return (Func(a, b)); \ }) -#define REGISTER_MAKE_BIT_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("tir."#Node) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - bool lhs_is_int = args[0].type_code() == kDLInt; \ - bool rhs_is_int = args[1].type_code() == kDLInt; \ - if (lhs_is_int) { \ - *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \ - } else if (rhs_is_int) { \ - *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \ - } else { \ - *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \ - } \ +#define REGISTER_MAKE_BIT_OP(Node, Func) \ + TVM_REGISTER_GLOBAL("tir." #Node).set_body([](TVMArgs args, TVMRetValue* ret) { \ + bool lhs_is_int = args[0].type_code() == kDLInt; \ + bool rhs_is_int = args[1].type_code() == kDLInt; \ + if (lhs_is_int) { \ + *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \ + } else if (rhs_is_int) { \ + *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \ + } else { \ + *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \ + } \ }) - REGISTER_MAKE_BINARY_OP(_OpAdd, operator+); REGISTER_MAKE_BINARY_OP(_OpSub, operator-); REGISTER_MAKE_BINARY_OP(_OpMul, operator*); @@ -822,20 +777,20 @@ REGISTER_MAKE_BINARY_OP(_OpMin, min); REGISTER_MAKE_BINARY_OP(_OpMax, max); REGISTER_MAKE_BINARY_OP(_OpEQ, operator==); REGISTER_MAKE_BINARY_OP(_OpNE, operator!=); -REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*) -REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*) -REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*) REGISTER_MAKE_BINARY_OP(_OpGE, operator>=); REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&); REGISTER_MAKE_BINARY_OP(_OpOr, operator||); REGISTER_MAKE_BIT_OP(bitwise_and, operator&); REGISTER_MAKE_BIT_OP(bitwise_or, operator|); REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); -REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) +REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, operator>>); TVM_REGISTER_GLOBAL("tir._OpIfThenElse") -.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { - return if_then_else(cond, true_value, false_value); -}); + .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { + return if_then_else(cond, true_value, false_value); + }); } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index cc61e7e..4c58fd6 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -21,9 +21,8 @@ * \file tvm/tir/stmt.cc */ #include -#include #include - +#include namespace tvm { namespace tir { @@ -40,13 +39,9 @@ Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) { return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.LetStmt") -.set_body_typed(LetStmtNode::make); +TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed(LetStmtNode::make); -Stmt AttrStmtNode::make(ObjectRef node, - std::string attr_key, - PrimExpr value, - Stmt body) { +Stmt AttrStmtNode::make(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) { auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); @@ -55,15 +50,12 @@ Stmt AttrStmtNode::make(ObjectRef node, return Stmt(n); } -TVM_REGISTER_GLOBAL("tir.AttrStmt") -.set_body_typed(AttrStmtNode::make); +TVM_REGISTER_GLOBAL("tir.AttrStmt").set_body_typed(AttrStmtNode::make); Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { CHECK(condition.defined()); - CHECK(message.dtype() == DataType::Int(32) || - message.as()) - << "TypeError: AssertStmt message must be an int or string:" - << message << "\n"; + CHECK(message.dtype() == DataType::Int(32) || message.as()) + << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; ObjectPtr node = make_object(); node->condition = std::move(condition); @@ -73,21 +65,17 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { } TVM_REGISTER_GLOBAL("tir.AssertStmt") -.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { - if (const auto* str = message.as()) { - auto msg = StringImmNode::make(str->data); - return AssertStmtNode::make(condition, msg, body); - } else { - return AssertStmtNode::make(condition, Downcast(message), body); - } -}); + .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { + if (const auto* str = message.as()) { + auto msg = StringImmNode::make(str->data); + return AssertStmtNode::make(condition, msg, body); + } else { + return AssertStmtNode::make(condition, Downcast(message), body); + } + }); -Stmt ForNode::make(Var loop_var, - PrimExpr min, - PrimExpr extent, - ForType for_type, - DeviceAPI device_api, - Stmt body) { +Stmt ForNode::make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, + DeviceAPI device_api, Stmt body) { CHECK(min.defined()); CHECK(extent.defined()); CHECK(min.dtype().is_scalar()); @@ -105,19 +93,12 @@ Stmt ForNode::make(Var loop_var, return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.For") -.set_body_typed([]( - Var loop_var, PrimExpr min, PrimExpr extent, - int for_type, int device_api, Stmt body) { - return ForNode::make(loop_var, - min, - extent, - static_cast(for_type), - static_cast(device_api), - body); +TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent, + int for_type, int device_api, Stmt body) { + return ForNode::make(loop_var, min, extent, static_cast(for_type), + static_cast(device_api), body); }); - Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { CHECK(value.defined()); CHECK(index.defined()); @@ -133,20 +114,17 @@ Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr pr return Stmt(node); } - -TVM_REGISTER_GLOBAL("tir.Store") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PrimExpr value = args[1]; - if (args.size() == 3) { - *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes())); - } else { - *ret = StoreNode::make(args[0], value, args[2], args[3]); - } - }); - +TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) { + PrimExpr value = args[1]; + if (args.size() == 3) { + *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes())); + } else { + *ret = StoreNode::make(args[0], value, args[2], args[3]); + } +}); Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array args) { - CHECK(value_index >=0 && value_index < func->num_outputs()) + CHECK(value_index >= 0 && value_index < func->num_outputs()) << "value index output function return value bound"; CHECK(value.defined()) << "Provide of undefined value\n"; @@ -162,45 +140,39 @@ Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array< return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.Provide") -.set_body_typed(ProvideNode::make); +TVM_REGISTER_GLOBAL("tir.Provide").set_body_typed(ProvideNode::make); - -Stmt AllocateNode::make(Var buffer_var, - DataType dtype, - Array extents, - PrimExpr condition, +Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body) { - for (size_t i = 0; i < extents.size(); ++i) { - CHECK(extents[i].defined()); - CHECK(extents[i].dtype().is_scalar()); - } - CHECK(body.defined()); - CHECK(condition.defined()); - CHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->buffer_var = std::move(buffer_var); - node->dtype = dtype; - node->extents = std::move(extents); - node->condition = std::move(condition); - node->body = std::move(body); - return Stmt(node); + for (size_t i = 0; i < extents.size(); ++i) { + CHECK(extents[i].defined()); + CHECK(extents[i].dtype().is_scalar()); + } + CHECK(body.defined()); + CHECK(condition.defined()); + CHECK(condition.dtype().is_bool()); + + ObjectPtr node = make_object(); + node->buffer_var = std::move(buffer_var); + node->dtype = dtype; + node->extents = std::move(extents); + node->condition = std::move(condition); + node->body = std::move(body); + return Stmt(node); } // overloaded, needs special handling // has default args TVM_REGISTER_GLOBAL("tir.Allocate") -.set_body_typed([]( - Var buffer_var, DataType type, Array extents, PrimExpr condition, Stmt body - ){ - return AllocateNode::make(buffer_var, type, extents, condition, body); -}); + .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, + Stmt body) { + return AllocateNode::make(buffer_var, type, extents, condition, body); + }); int32_t AllocateNode::constant_allocation_size(const Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { - if (const IntImmNode *int_size = extents[i].as()) { + if (const IntImmNode* int_size = extents[i].as()) { result *= int_size->value; if (result > std::numeric_limits::max()) { return 0; @@ -218,16 +190,10 @@ Stmt FreeNode::make(Var buffer_var) { return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.Free") -.set_body_typed(FreeNode::make); - +TVM_REGISTER_GLOBAL("tir.Free").set_body_typed(FreeNode::make); -Stmt RealizeNode::make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds, - PrimExpr condition, - Stmt body) { +Stmt RealizeNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds, + PrimExpr condition, Stmt body) { for (size_t i = 0; i < bounds.size(); ++i) { CHECK(bounds[i]->min.defined()); CHECK(bounds[i]->extent.defined()); @@ -248,29 +214,23 @@ Stmt RealizeNode::make(FunctionRef func, return Stmt(node); } - -TVM_REGISTER_GLOBAL("tir.Realize") -.set_body_typed(RealizeNode::make); - +TVM_REGISTER_GLOBAL("tir.Realize").set_body_typed(RealizeNode::make); Prefetch::Prefetch(Buffer buffer, Array bounds) { data_ = make_object(buffer, bounds); } -TVM_REGISTER_GLOBAL("tir.Prefetch") -.set_body_typed([](Buffer buffer, Array bounds) { +TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array bounds) { return Prefetch(buffer, bounds); }); - SeqStmt::SeqStmt(Array seq) { auto node = make_object(); node->seq = std::move(seq); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.SeqStmt") -.set_body_typed([](Array seq) { +TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq) { return SeqStmt(std::move(seq)); }); @@ -286,9 +246,7 @@ Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) { return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.IfThenElse") -.set_body_typed(IfThenElseNode::make); - +TVM_REGISTER_GLOBAL("tir.IfThenElse").set_body_typed(IfThenElseNode::make); Stmt EvaluateNode::make(PrimExpr value) { CHECK(value.defined()); @@ -298,8 +256,7 @@ Stmt EvaluateNode::make(PrimExpr value) { return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.Evaluate") -.set_body_typed(EvaluateNode::make); +TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed(EvaluateNode::make); BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) { ObjectPtr node = make_object(); @@ -310,69 +267,60 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) } TVM_REGISTER_GLOBAL("tir.BufferStore") -.set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { - return BufferStore(buffer, value, indices); -}); + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { + return BufferStore(buffer, value, indices); + }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); - -BufferRealize::BufferRealize(Buffer buffer, - Array bounds, - PrimExpr condition, - Stmt body) { - data_ = make_object( - buffer, bounds, condition, body); +BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { + data_ = make_object(buffer, bounds, condition, body); } TVM_REGISTER_GLOBAL("tir.BufferRealize") -.set_body_typed([](Buffer buffer, - Array bounds, - PrimExpr condition, - Stmt body) { - return BufferRealize(buffer, bounds, condition, body); -}); + .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { + return BufferRealize(buffer, bounds, condition, body); + }); TVM_REGISTER_NODE_TYPE(BufferRealizeNode); // Printers TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "let " << op->var << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "let " << op->var << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "// attr ["; - p->Print(op->node); - p->stream << "] " - << op->attr_key << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "// attr ["; + p->Print(op->node); + p->stream << "] " << op->attr_key << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "assert("; - p->Print(op->condition); - p->stream << ", "; - p->Print(op->message); - p->stream << ")\n"; - p->Print(op->body); - }); - -std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "assert("; + p->Print(op->condition); + p->stream << ", "; + p->Print(op->message); + p->stream << ")\n"; + p->Print(op->body); + }); + +std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*) switch (type) { case ForType::Serial: out << "for"; @@ -391,221 +339,221 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->for_type << " (" << op->loop_var << ", "; - p->Print(op->min); - p->stream << ", "; - p->Print(op->extent); - p->stream << ") {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - p->stream << "}\n"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->for_type << " (" << op->loop_var << ", "; + p->Print(op->min); + p->stream << ", "; + p->Print(op->extent); + p->stream << ") {\n"; -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->buffer_var << "["; - p->Print(op->index); - p->stream << "] = "; - p->Print(op->value); - if (!is_one(op->predicate)) { - p->stream << " if "; - p->Print(op->predicate); - } - p->stream << '\n'; - }); + p->indent += 2; + p->Print(op->body); + p->indent -= 2; -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->func->func_name() << "("; - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } - p->stream << " ="; - p->Print(op->value); - p->stream << '\n'; - }); + p->PrintIndent(); + p->stream << "}\n"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) p->stream << ", "; - } - p->stream << "]"; - p->stream << " = "; - p->Print(op->value); - p->stream << '\n'; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer_var << "["; + p->Print(op->index); + p->stream << "] = "; + p->Print(op->value); + if (!is_one(op->predicate)) { + p->stream << " if "; + p->Print(op->predicate); + } + p->stream << '\n'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "allocate " << op->buffer_var << "[" << op->dtype; - for (size_t i = 0; i < op->extents.size(); ++i) { - p->stream << " * "; - p->Print(op->extents[i]); - } - p->stream << "]"; - if (!is_one(op->condition)) { - p->stream << " if "; - p->Print(op->condition); - } - p->stream << "\n"; - p->Print(op->body); - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->func->func_name() << "("; + for (size_t i = 0; i < op->args.size(); ++i) { + p->Print(op->args[i]); + if (i < op->args.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (op->func->num_outputs() != 1) { + p->stream << ".value[" << op->value_index << "]"; + } + p->stream << " ="; + p->Print(op->value); + p->stream << '\n'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "free " << op->buffer_var; - p->stream << '\n'; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) p->stream << ", "; + } + p->stream << "]"; + p->stream << " = "; + p->Print(op->value); + p->stream << '\n'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "buffer_realize " << op->buffer->name << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - p->stream << "["; - p->Print(op->bounds[i]->min); - p->stream << ", "; - p->Print(op->bounds[i]->extent); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "allocate " << op->buffer_var << "[" << op->dtype; + for (size_t i = 0; i < op->extents.size(); ++i) { + p->stream << " * "; + p->Print(op->extents[i]); + } p->stream << "]"; - if (i < op->bounds.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (!is_one(op->condition)) { - p->stream << " if "; - p->Print(op->condition); - } - p->stream << " {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - p->stream << "}\n"; - }); + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << "\n"; + p->Print(op->body); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "realize " << op->func->func_name() << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - p->stream << "["; - p->Print(op->bounds[i]->min); - p->stream << ", "; - p->Print(op->bounds[i]->extent); - p->stream << "]"; - if (i < op->bounds.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } - if (!is_one(op->condition)) { - p->stream << " if "; - p->Print(op->condition); - } - p->stream << " {\n"; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "free " << op->buffer_var; + p->stream << '\n'; + }); - p->indent += 2; - p->Print(op->body); - p->indent -= 2; +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "buffer_realize " << op->buffer->name << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << " {\n"; - p->PrintIndent(); - p->stream << "}\n"; - }); + p->indent += 2; + p->Print(op->body); + p->indent -= 2; -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "prefetch " << op->buffer << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - p->stream << "["; - p->Print(op->bounds[i]->min); - p->stream << ", "; - p->Print(op->bounds[i]->extent); - p->stream << "]"; - if (i < op->bounds.size() - 1) p->stream << ", "; - } - p->stream << ")"; - }); + p->PrintIndent(); + p->stream << "}\n"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - for (Stmt stmt : op->seq) { - p->Print(stmt); - } - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "realize " << op->func->func_name() << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (op->func->num_outputs() != 1) { + p->stream << ".value[" << op->value_index << "]"; + } + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << " {\n"; -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - while (true) { - p->stream << "if (" << op->condition << ") {\n"; p->indent += 2; - p->Print(op->then_case); + p->Print(op->body); p->indent -= 2; - if (!op->else_case.defined()) { - break; + p->PrintIndent(); + p->stream << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "prefetch " << op->buffer << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; } + p->stream << ")"; + }); - if (const IfThenElseNode *nested_if = op->else_case.as()) { - p->PrintIndent(); - p->stream << "} else "; - op = nested_if; - } else { - p->PrintIndent(); - p->stream << "} else {\n"; +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + for (Stmt stmt : op->seq) { + p->Print(stmt); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + while (true) { + p->stream << "if (" << op->condition << ") {\n"; p->indent += 2; - p->Print(op->else_case); + p->Print(op->then_case); p->indent -= 2; - break; + + if (!op->else_case.defined()) { + break; + } + + if (const IfThenElseNode* nested_if = op->else_case.as()) { + p->PrintIndent(); + p->stream << "} else "; + op = nested_if; + } else { + p->PrintIndent(); + p->stream << "} else {\n"; + p->indent += 2; + p->Print(op->else_case); + p->indent -= 2; + break; + } } - } - p->PrintIndent(); - p->stream << "}\n"; -}); + p->PrintIndent(); + p->stream << "}\n"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->Print(op->value); - p->stream << "\n"; - }); - -template -void PrintList(const Array &exprs, ReprPrinter* p) { + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->Print(op->value); + p->stream << "\n"; + }); + +template +void PrintList(const Array& exprs, ReprPrinter* p) { for (size_t i = 0; i < exprs.size(); ++i) { p->Print(exprs[i]); if (i < exprs.size() - 1) { @@ -615,14 +563,14 @@ void PrintList(const Array &exprs, ReprPrinter* p) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "shuffle("; - PrintList(op->vectors, p); - p->stream << ", "; - PrintList(op->indices, p); - p->stream << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "shuffle("; + PrintList(op->vectors, p); + p->stream << ", "; + PrintList(op->indices, p); + p->stream << ")"; + }); TVM_REGISTER_NODE_TYPE(AttrStmtNode); TVM_REGISTER_NODE_TYPE(PrefetchNode); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index ec97b03..13d0b09 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -21,7 +21,9 @@ */ #include #include + #include + #include "functor_common.h" namespace tvm { @@ -62,9 +64,9 @@ void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) { VisitArray(op->bounds, [this](const Range& r) { - this->VisitExpr(r->min); - this->VisitExpr(r->extent); - }); + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); this->VisitExpr(op->condition); this->VisitStmt(op->body); } @@ -92,30 +94,25 @@ void StmtVisitor::VisitStmt_(const ProvideNode* op) { void StmtVisitor::VisitStmt_(const RealizeNode* op) { VisitArray(op->bounds, [this](const Range& r) { - this->VisitExpr(r->min); - this->VisitExpr(r->extent); - }); + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); this->VisitStmt(op->body); this->VisitExpr(op->condition); } void StmtVisitor::VisitStmt_(const PrefetchNode* op) { VisitArray(op->bounds, [this](const Range& r) { - this->VisitExpr(r->min); - this->VisitExpr(r->extent); - }); + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); } void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { - VisitArray(op->seq, [this](const Stmt& s) { - this->VisitStmt(s); - }); -} - -void StmtVisitor::VisitStmt_(const EvaluateNode* op) { - this->VisitExpr(op->value); + VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); } +void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); } class StmtMutator::Internal { public: @@ -146,8 +143,7 @@ class StmtMutator::Internal { Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -160,8 +156,7 @@ Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -175,9 +170,7 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); Stmt body = this->VisitStmt(op->body); - if (min.same_as(op->min) && - extent.same_as(op->extent) && - body.same_as(op->body)) { + if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -193,9 +186,7 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); - if (extents.same_as(op->extents) && - body.same_as(op->body) && - condition.same_as(op->condition)) { + if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -213,8 +204,7 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { if (op->else_case.defined()) { else_case = this->VisitStmt(op->else_case); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { @@ -230,9 +220,7 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr index = this->VisitExpr(op->index); PrimExpr predicate = this->VisitExpr(op->predicate); - if (value.same_as(op->value) && - index.same_as(op->index) && - predicate.same_as(op->predicate)) { + if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -247,8 +235,7 @@ Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { PrimExpr value = this->VisitExpr(op->value); Array indices = Internal::Mutate(this, op->indices); - if (value.same_as(op->value) && - indices.same_as(op->indices)) { + if (value.same_as(op->value) && indices.same_as(op->indices)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -263,9 +250,7 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt body = this->VisitStmt(op->body); - if (bounds.same_as(op->bounds) && - condition.same_as(op->condition) && - body.same_as(op->body)) { + if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -279,8 +264,7 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { Array args = Internal::Mutate(this, op->args); PrimExpr value = this->VisitExpr(op->value); - if (args.same_as(op->args) && - value.same_as(op->value)) { + if (args.same_as(op->args) && value.same_as(op->value)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -294,9 +278,7 @@ Stmt StmtMutator::VisitStmt_(const RealizeNode* op) { Region bounds = Internal::Mutate(this, op->bounds); Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); - if (bounds.same_as(op->bounds) && - body.same_as(op->body) && - condition.same_as(op->condition)) { + if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -330,8 +312,7 @@ Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { } // advanced visit function for seqstmt. -Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, - bool flatten_before_visit, +Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, std::function fmutate) { if (flatten_before_visit) { // Pass 1, check if we need to flatten. @@ -344,10 +325,8 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, } // function to run the visit. auto frunvisit = [&](const SeqStmtNode* op) { - Array seq = - fmutate != nullptr ? - MutateArray(op->seq, fmutate, allow_copy_on_write_) : - Internal::Mutate(this, op->seq); + Array seq = fmutate != nullptr ? MutateArray(op->seq, fmutate, allow_copy_on_write_) + : Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { return GetRef(op); } else { @@ -380,9 +359,7 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { PrimExpr message = this->VisitExpr(op->message); Stmt body = this->VisitStmt(op->body); - if (condition.same_as(op->condition) && - message.same_as(op->message) && - body.same_as(op->body)) { + if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -404,14 +381,10 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { } } -Stmt StmtMutator::VisitStmt_(const FreeNode* op) { - return GetRef(op); -} - +Stmt StmtMutator::VisitStmt_(const FreeNode* op) { return GetRef(op); } // Implementations of IRTransform, PostOrderVisit and Substitute -class IRApplyVisit : - public StmtExprVisitor { +class IRApplyVisit : public StmtExprVisitor { public: explicit IRApplyVisit(std::function f) : f_(f) {} @@ -434,8 +407,7 @@ class IRApplyVisit : std::unordered_set visited_; }; -void PostOrderVisit(const ObjectRef& node, - std::function fvisit) { +void PostOrderVisit(const ObjectRef& node, std::function fvisit) { if (node.as()) { IRApplyVisit visitor(fvisit); visitor(Downcast(node)); @@ -445,42 +417,29 @@ void PostOrderVisit(const ObjectRef& node, } } -class IRTransformer final : - public StmtExprMutator { +class IRTransformer final : public StmtExprMutator { public: - IRTransformer(const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, + IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder, const std::unordered_set& only_enable) - : f_preorder_(f_preorder), - f_postorder_(f_postorder), - only_enable_(only_enable) { - } + : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {} Stmt VisitStmt(const Stmt& stmt) final { - return MutateInternal(stmt, [this](const Stmt& s) { - return this->BaseVisitStmt(s); - }); + return MutateInternal(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); }); } PrimExpr VisitExpr(const PrimExpr& expr) final { - return MutateInternal(expr, [this](const PrimExpr& e) { - return this->BaseVisitExpr(e); - }); + return MutateInternal(expr, + [this](const PrimExpr& e) { return this->BaseVisitExpr(e); }); } private: // NOTE: redirect to parent's call // This is used to get around limitation of gcc-4.8 - Stmt BaseVisitStmt(const Stmt& s) { - return StmtMutator::VisitStmt(s); - } - PrimExpr BaseVisitExpr(const PrimExpr& e) { - return ExprMutator::VisitExpr(e); - } + Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); } + PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); } template T MutateInternal(const T& node, F fmutate) { - if (only_enable_.size() && - !only_enable_.count(node->type_index())) { + if (only_enable_.size() && !only_enable_.count(node->type_index())) { return fmutate(node); } if (f_preorder_ != nullptr) { @@ -501,10 +460,8 @@ class IRTransformer final : const std::unordered_set& only_enable_; }; -Stmt IRTransform(Stmt ir_node, - const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, - Optional> only_enable) { +Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, Optional> only_enable) { std::unordered_set only_type_index; if (only_enable.defined()) { for (auto s : only_enable.value()) { @@ -517,9 +474,7 @@ Stmt IRTransform(Stmt ir_node, class IRSubstitue : public StmtExprMutator { public: - explicit IRSubstitue(std::function(const Var&)> vmap) - : vmap_(vmap) { - } + explicit IRSubstitue(std::function(const Var&)> vmap) : vmap_(vmap) {} PrimExpr VisitExpr_(const VarNode* op) final { Var var = GetRef(op); @@ -533,8 +488,7 @@ class IRSubstitue : public StmtExprMutator { PrimExpr ret = StmtExprMutator::VisitExpr_(op); op = ret.as(); if (auto mapped_var = vmap_(op->buffer_var)) { - return LoadNode::make( - op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); + return LoadNode::make(op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); } else { return ret; } @@ -545,8 +499,8 @@ class IRSubstitue : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); if (auto mapped_var = vmap_(op->buffer_var)) { - return StoreNode::make( - Downcast(mapped_var.value()), op->value, op->index, op->predicate); + return StoreNode::make(Downcast(mapped_var.value()), op->value, op->index, + op->predicate); } else { return ret; } @@ -556,36 +510,28 @@ class IRSubstitue : public StmtExprMutator { std::function(const Var&)> vmap_; }; -Stmt Substitute(Stmt stmt, - std::function(const Var&)> vmap) { +Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { return IRSubstitue(vmap)(std::move(stmt)); } -PrimExpr Substitute(PrimExpr expr, - std::function(const Var&)> vmap) { +PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> vmap) { return IRSubstitue(vmap)(std::move(expr)); } +TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); -TVM_REGISTER_GLOBAL("tir.IRTransform") -.set_body_typed(IRTransform); - - -TVM_REGISTER_GLOBAL("tir.PostOrderVisit") -.set_body_typed([](ObjectRef node, PackedFunc f) { - tir::PostOrderVisit(node, [f](const ObjectRef& n) { - f(n); - }); +TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { + tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }); TVM_REGISTER_GLOBAL("tir.Substitute") -.set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef{ - if (node->IsInstance()) { - return Substitute(Downcast(node), vmap); - } else { - return Substitute(Downcast(node), vmap); - } -}); + .set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef { + if (node->IsInstance()) { + return Substitute(Downcast(node), vmap); + } else { + return Substitute(Downcast(node), vmap); + } + }); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index dda9ff4..30d5f0f 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -21,16 +21,14 @@ * \file tir/ir/transform.cc * \brief TIR specific transformation passes. */ -#include #include +#include #include - namespace tvm { namespace tir { namespace transform { - /*! * \brief Function level pass that applies transformations to all * TIR functions within the module. @@ -43,9 +41,7 @@ class PrimFuncPassNode : public PassNode { /*! \brief The pass function called on each. */ runtime::TypedPackedFunc pass_func; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! * \brief Run a function pass on given pass context. @@ -90,8 +86,7 @@ PrimFuncPass::PrimFuncPass( } // Perform Module -> Module optimizations at the PrimFunc level. -IRModule PrimFuncPassNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); pass_ctx.Trace(mod, pass_info, true); @@ -123,9 +118,7 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { + int opt_level, const std::string& name, const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); return PrimFuncPass(pass_func, pass_info); } @@ -133,18 +126,16 @@ Pass CreatePrimFuncPass( TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") -.set_body_typed([](runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - return PrimFuncPass(pass_func, pass_info); -}); + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return PrimFuncPass(pass_func, pass_info); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "PrimFuncPass(" << info->name - << ", opt_level=" << info->opt_level << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "PrimFuncPass(" << info->name << ", opt_level=" << info->opt_level << ")"; + }); } // namespace transform } // namespace tir diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc index 1e6f6c6..67a88f5 100644 --- a/src/tir/pass/hoist_if_then_else.cc +++ b/src/tir/pass/hoist_if_then_else.cc @@ -20,15 +20,15 @@ /*! * \file hoist_if_then_else.cc */ +#include #include #include #include -#include -#include +#include #include #include -#include + #include "../../arith/interval_set.h" #include "../../runtime/thread_storage_scope.h" @@ -152,13 +152,12 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { } }); - PackedFunc replace_target_for = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& current_for = args[0]; - if (current_for.get() == top_for_node) { - *ret = new_if_stmt; - } - }); + PackedFunc replace_target_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& current_for = args[0]; + if (current_for.get() == top_for_node) { + *ret = new_if_stmt; + } + }); return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array{"For"}); } @@ -170,21 +169,19 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { Stmt else_for; CHECK(if_stmt.as()); - PackedFunc replace_then_case = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& node = args[0]; - if (node == if_stmt) { - *ret = node.as()->then_case; - } - }); + PackedFunc replace_then_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->then_case; + } + }); - PackedFunc replace_else_case = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& node = args[0]; - if (node == if_stmt) { - *ret = node.as()->else_case; - } - }); + PackedFunc replace_else_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->else_case; + } + }); then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array{"IfThenElse"}); if (if_stmt.as()->else_case.defined()) { @@ -196,7 +193,7 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { // Locate all For nodes and capture child IfThenElse nodes. void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { - PostOrderVisit(stmt, [&](const ObjectRef& node){ + PostOrderVisit(stmt, [&](const ObjectRef& node) { const ForNode* for_node = node.as(); if (!for_node) return; @@ -269,10 +266,8 @@ void IfThenElseHoist::LocateTopFor() { CHECK(for_node); std::vector new_for_list{for_stmt}; for_tracking_map_.insert({for_stmt.get(), new_for_list}); - if (cond_var_map_[if_stmt] - .count(for_node->loop_var.get())) { - std::vector updated_for_list(for_list.begin(), - for_list.begin() + i); + if (cond_var_map_[if_stmt].count(for_node->loop_var.get())) { + std::vector updated_for_list(for_list.begin(), for_list.begin() + i); if2for_map_[if_stmt] = updated_for_list; break; } else { @@ -315,13 +310,11 @@ void IfThenElseHoist::LocateTopFor() { // We keep all For nodes tracing in for_tracking_map_. When we get a // hoisted IfThenElse, we match it with tracing For nodes to pick // the updated one. -size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, - const Stmt& if_stmt) { +size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt) { std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; size_t updated_for_idx = 0; for (size_t i = 0; i < tracked_for_list.size(); ++i) { - const Stmt& current_for = - tracked_for_list.at(tracked_for_list.size() - 1 - i); + const Stmt& current_for = tracked_for_list.at(tracked_for_list.size() - 1 - i); if (is_first_if(current_for, if_stmt)) { updated_for_idx = tracked_for_list.size() - 1 - i; break; @@ -340,11 +333,11 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i); size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if); - const Stmt& updated_for_node = - for_tracking_map_[for_stmt.get()].at(updated_for_idx); + const Stmt& updated_for_node = for_tracking_map_[for_stmt.get()].at(updated_for_idx); auto generated_for_pair = RemoveIf(updated_for_node, new_if); const Stmt& then_for = generated_for_pair.first; - const Stmt& else_for = generated_for_pair.second;; + const Stmt& else_for = generated_for_pair.second; + for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for; if (else_for.get()) { @@ -356,12 +349,10 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { new_if = IfThenElseNode::make(new_if_node->condition, then_for, else_for); if (i < if2for_map_[if_stmt.get()].size() - 1) { const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); - const Stmt& actual_next_for = - for_tracking_map_[original_next_for.get()].at(updated_for_idx); + const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(updated_for_idx); Stmt update_for_stmt = update_for(actual_next_for, new_if); - for_tracking_map_[original_next_for.get()]. - at(updated_for_idx) = update_for_stmt; + for_tracking_map_[original_next_for.get()].at(updated_for_idx) = update_for_stmt; } } return new_if; @@ -369,56 +360,46 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { // Mutate For nodes in post order DFS manner. Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { - PackedFunc replace_top_for = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& current_for = args[0]; - const ForNode* for_node = current_for.as(); - if (!for_node) return; - - if (top_for_var_map_.count(for_node->loop_var.get())) { - std::vector new_if_list; - for (const Stmt& if_stmt : - top_for_var_map_[for_node->loop_var.get()]) { - new_if_list.emplace_back(HoistIf(if_stmt)); - } + PackedFunc replace_top_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& current_for = args[0]; + const ForNode* for_node = current_for.as(); + if (!for_node) return; - const IfThenElseNode* next_if_node; - const IfThenElseNode* current_if_node = - new_if_list.back().as(); - Stmt new_for = Stmt(); - for (size_t i = new_if_list.size() - 1; i > 0; --i) { - CHECK(current_if_node); - const Stmt current_if_stmt = - IfThenElseNode::make(current_if_node->condition, - current_if_node->then_case, - current_if_node->else_case); - next_if_node = new_if_list[i - 1].as(); - CHECK(next_if_node); - new_for = IfThenElseNode::make(next_if_node->condition, current_if_stmt, - next_if_node->else_case); - current_if_node = new_for.as(); - } + if (top_for_var_map_.count(for_node->loop_var.get())) { + std::vector new_if_list; + for (const Stmt& if_stmt : top_for_var_map_[for_node->loop_var.get()]) { + new_if_list.emplace_back(HoistIf(if_stmt)); + } - if (!new_for.get()) { - const IfThenElseNode* first_if_node = new_if_list[0].as(); - CHECK(first_if_node); - new_for = IfThenElseNode::make(first_if_node->condition, - first_if_node->then_case, - first_if_node->else_case); - } - *ret = new_for; + const IfThenElseNode* next_if_node; + const IfThenElseNode* current_if_node = new_if_list.back().as(); + Stmt new_for = Stmt(); + for (size_t i = new_if_list.size() - 1; i > 0; --i) { + CHECK(current_if_node); + const Stmt current_if_stmt = IfThenElseNode::make( + current_if_node->condition, current_if_node->then_case, current_if_node->else_case); + next_if_node = new_if_list[i - 1].as(); + CHECK(next_if_node); + new_for = + IfThenElseNode::make(next_if_node->condition, current_if_stmt, next_if_node->else_case); + current_if_node = new_for.as(); } - }); - return IRTransform(stmt, nullptr, replace_top_for, Array{"For"}); -} -Stmt HoistIfThenElse(Stmt stmt) { - return IfThenElseHoist().VisitAndMutate(stmt); + if (!new_for.get()) { + const IfThenElseNode* first_if_node = new_if_list[0].as(); + CHECK(first_if_node); + new_for = IfThenElseNode::make(first_if_node->condition, first_if_node->then_case, + first_if_node->else_case); + } + *ret = new_for; + } + }); + return IRTransform(stmt, nullptr, replace_top_for, Array{"For"}); } +Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); } -TVM_REGISTER_GLOBAL("testing.HoistIfThenElse") -.set_body_typed(HoistIfThenElse); +TVM_REGISTER_GLOBAL("testing.HoistIfThenElse").set_body_typed(HoistIfThenElse); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index a68e4ee..01a6996 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -21,23 +21,23 @@ * \file arg_binder.cc * \brief Helper utility to match and bind arguments. */ -#include -#include -#include "ir_util.h" #include "arg_binder.h" + +#include +#include + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { -void BinderAddAssert(arith::Analyzer* ana, - PrimExpr cond, - const std::string& arg_name, +void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg_name, std::vector* asserts) { PrimExpr scond = ana->Simplify(cond); if (is_zero(scond)) { - LOG(FATAL) << "Bind have an unmet assertion: " - << cond << ", " << " on argument " << arg_name; + LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " + << " on argument " << arg_name; } if (!is_one(scond)) { std::ostringstream os; @@ -47,9 +47,7 @@ void BinderAddAssert(arith::Analyzer* ana, } } -bool ArgBinder::Bind_(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, +bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_lets) { CHECK_EQ(arg.dtype(), value.dtype()); if (const VarNode* v = arg.as()) { @@ -73,18 +71,14 @@ bool ArgBinder::Bind_(const PrimExpr& arg, return false; } -void ArgBinder::Bind(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, +void ArgBinder::Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_let) { Bind_(arg, value, arg_name, with_let); } -void ArgBinder::BindArray(const Array& arg, - const Array& value, +void ArgBinder::BindArray(const Array& arg, const Array& value, const std::string& arg_name) { - CHECK_EQ(arg.size(), value.size()) - << "Argument " << arg_name << " array size mismatch"; + CHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch"; for (size_t i = 0; i < arg.size(); ++i) { std::ostringstream os; os << arg_name << "[" << i << "]"; @@ -92,16 +86,11 @@ void ArgBinder::BindArray(const Array& arg, } } -void ArgBinder::BindBuffer(const Buffer& arg, - const Buffer& value, - const std::string& arg_name, +void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match) { - CHECK_EQ(arg->scope, value->scope) - << "Argument " << arg_name - << " Buffer bind scope mismatch"; + CHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch"; CHECK_EQ(arg->dtype, value->dtype) - << "Argument " << arg_name - << " Buffer bind data type mismatch"; + << "Argument " << arg_name << " Buffer bind data type mismatch"; if (value->data_alignment % arg->data_alignment != 0) { LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " << " required_alignment=" << arg->data_alignment @@ -121,9 +110,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, PrimExpr offset = value->elem_offset; PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(&analyzer_, - truncmod(offset, factor) == zero, - arg_name + ".elem_offset", &asserts_); + BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset", + &asserts_); } } @@ -132,8 +120,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, size_t diff = value->shape.size() - arg->shape.size(); for (size_t i = 0; i < diff; ++i) { CHECK(is_one(analyzer_.Simplify(value->shape[i]))) - << "Argument " << arg_name << " shape mismatch" - << arg->shape << " vs " << value->shape; + << "Argument " << arg_name << " shape mismatch" << arg->shape << " vs " << value->shape; } for (size_t i = 0; i < arg->shape.size(); ++i) { std::ostringstream os; @@ -159,22 +146,17 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind k return TVMStructGet(t, arr, 0, kind); } -void ArgBinder::BindDLTensor(const Buffer& buffer, - const PrimExpr& device_type, - const PrimExpr& device_id, - const Var& handle, +void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, + const PrimExpr& device_id, const Var& handle, const std::string& arg_name) { const DataType tvm_shape_type = DataType::ShapeIndex(); const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = EvaluateNode::make(0); // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); - PrimExpr a_ndim = make_const(tvm_ndim_type, - static_cast(buffer->shape.size())); + PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; - ndim_err_msg << arg_name - << ".ndim is expected to equal " - << buffer->shape.size(); + ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str()); asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop)); // type checks @@ -182,14 +164,12 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == - IntImm(DataType::UInt(8), dtype.code()) && - TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == - IntImm(DataType::UInt(8), dtype.bits()) && - TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == - IntImm(DataType::UInt(16), dtype.lanes())); - if (!(dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || - dtype == DataType::Int(1))) { + IntImm(DataType::UInt(8), dtype.code()) && + TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == + IntImm(DataType::UInt(8), dtype.bits()) && + TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == + IntImm(DataType::UInt(16), dtype.lanes())); + if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str()); asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop)); @@ -200,9 +180,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Var vptr(buffer->data); def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); // mark alignment of external bufs - init_nest_.emplace_back(AttrStmtNode::make( - vptr, tir::attr::storage_alignment, - IntImm(DataType::Int(32), buffer->data_alignment), nop)); + init_nest_.emplace_back(AttrStmtNode::make(vptr, tir::attr::storage_alignment, + IntImm(DataType::Int(32), buffer->data_alignment), + nop)); } Var v_shape(arg_name + ".shape", DataType::Handle()); @@ -210,28 +190,24 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, init_nest_.emplace_back(LetStmtNode::make( v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { - if (dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || - dtype == DataType::Int(1)) { + if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { break; } std::ostringstream field_name; field_name << v_shape->name_hint << '[' << k << ']'; - Bind_(buffer->shape[k], - cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_shape, - IntImm(DataType::Int(32), k), const_true(1))), - field_name.str(), true); + Bind_( + buffer->shape[k], + cast(buffer->shape[k].dtype(), + LoadNode::make(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))), + field_name.str(), true); } // strides field Var v_strides(arg_name + ".strides", DataType::Handle()); def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back(LetStmtNode::make( - v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), - nop)); - PrimExpr is_null = CallNode::make( - DataType::Bool(1), intrinsic::tvm_handle_is_null, - {v_strides}, CallNode::PureIntrinsic); + v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); + PrimExpr is_null = CallNode::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, + CallNode::PureIntrinsic); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -239,10 +215,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - PrimExpr svalue = cast( - stype, - LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr svalue = cast(stype, LoadNode::make(tvm_shape_type, v_strides, + IntImm(DataType::Int(32), k), const_true(1))); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } @@ -251,9 +225,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, << " expected to be compact array"; if (conds.size() != 0) { auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str()); - Stmt check = - AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), - stride_msg, EvaluateNode::make(0)); + Stmt check = AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), + stride_msg, EvaluateNode::make(0)); check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt()); asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)})); } @@ -264,9 +237,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, size_t k = i - 1; std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; - PrimExpr value = cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr value = cast( + buffer->shape[k].dtype(), + LoadNode::make(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); @@ -283,8 +256,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, field_name << v_strides->name_hint << '[' << k << ']'; Bind_(buffer->strides[k], cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))), + LoadNode::make(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), + const_true(1))), field_name.str(), true); } } @@ -293,7 +266,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, if (const auto* const_offset = buffer->elem_offset.as()) { Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes), - TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), + TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), arg_name + ".byte_offset", true); } else { if (Bind_(buffer->elem_offset, @@ -305,18 +278,15 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, PrimExpr offset = buffer->elem_offset; PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(&analyzer_, - truncmod(offset, factor) == zero, - arg_name + ".elem_offset", &asserts_); + BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset", + &asserts_); } } } // device info. - Bind_(device_type, - TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), + Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), arg_name + ".device_type", true); - Bind_(device_id, - TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), + Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), arg_name + ".device_id", true); } diff --git a/src/tir/transforms/arg_binder.h b/src/tir/transforms/arg_binder.h index 1769950..657ebdb 100644 --- a/src/tir/transforms/arg_binder.h +++ b/src/tir/transforms/arg_binder.h @@ -24,13 +24,13 @@ #ifndef TVM_TIR_TRANSFORMS_ARG_BINDER_H_ #define TVM_TIR_TRANSFORMS_ARG_BINDER_H_ -#include -#include #include +#include +#include #include -#include #include +#include namespace tvm { namespace tir { @@ -63,10 +63,7 @@ class ArgBinder { * \param def_map A definition map that contains definition of known variables. * ArgBinder will update this def_map when adding new definitions. */ - explicit ArgBinder( - std::unordered_map* def_map) - : def_map_(def_map) { - } + explicit ArgBinder(std::unordered_map* def_map) : def_map_(def_map) {} /*! * \brief Try to bind arg to value, generate constraint if necessary. * \param arg The argument to be binded. @@ -74,9 +71,7 @@ class ArgBinder { * \param arg_name argument name. * \param with_let Whether add lets during bind */ - void Bind(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, + void Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_let = false); /*! * \brief Bind array to array @@ -84,19 +79,17 @@ class ArgBinder { * \param value The target expression value * \param arg_name argument name. */ - void BindArray(const Array& arg, - const Array& value, + void BindArray(const Array& arg, const Array& value, const std::string& arg_name); /*! * \brief Bind symbolic buffer to another symbolic buffer * \param arg The argument to be binded. * \param value The target expression value * \param arg_name argument name. - * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1. + * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as + * arg's higher dimensions are of 1. */ - void BindBuffer(const Buffer& arg, - const Buffer& value, - const std::string& arg_name, + void BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match); /*! * \brief Bind symbolic buffer to a DLTensor handle. @@ -106,20 +99,13 @@ class ArgBinder { * \param handle The DLTensor handle. * \param arg_name argument name. */ - void BindDLTensor(const Buffer& buffer, - const PrimExpr& device_type, - const PrimExpr& device_id, - const Var& handle, - const std::string& arg_name); + void BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id, + const Var& handle, const std::string& arg_name); /*! \return The defs generated in binding. */ - const std::vector& defs() const { - return defs_; - } + const std::vector& defs() const { return defs_; } /*! \return The asserts generated in binding */ - const std::vector& asserts() const { - return asserts_; - } + const std::vector& asserts() const { return asserts_; } /*! * \brief Initialization nest generated * This is only non-empty when BindDLTensor is called. @@ -131,19 +117,13 @@ class ArgBinder { * Let statement is usually generated when bind to DLTensor and memory load is involved. * \return The initialization nest generated during binding. */ - const std::vector& init_nest() const { - return init_nest_; - } + const std::vector& init_nest() const { return init_nest_; } /*! \return Handle data type of the data */ - const Map& def_handle_dtype() const { - return def_handle_dtype_; - } + const Map& def_handle_dtype() const { return def_handle_dtype_; } private: // Internal bind function - bool Bind_(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, + bool Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_lets); /*! \brief The definition map, can be uses to substitute */ std::unordered_map* def_map_; diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 4b1c009..2e1e5b9 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -22,15 +22,16 @@ */ // Instrument checkers for out of the bounds access. -#include #include +#include #include #include -#include #include -#include +#include + #include #include +#include namespace tvm { namespace tir { @@ -41,20 +42,19 @@ class BoundCollector : public StmtVisitor { void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tir::attr::buffer_bound) { - if (const VarNode *key = op->node.as()) { + if (const VarNode* key = op->node.as()) { mem_to_shape[key] = op->value; } } StmtVisitor::VisitStmt_(op); } // Hashtable which maps buffer_var to shape. - std::unordered_map mem_to_shape; + std::unordered_map mem_to_shape; }; class BoundChecker : public StmtExprMutator { public: - explicit BoundChecker( - const std::unordered_map &mem_to_shape) + explicit BoundChecker(const std::unordered_map& mem_to_shape) : mem_to_shape_(mem_to_shape) {} Stmt VisitStmt_(const AllocateNode* op) final { @@ -86,10 +86,8 @@ class BoundChecker : public StmtExprMutator { PrimExpr condition = MakeCondition(); if (!condition.as()) { Stmt nop = EvaluateNode::make(1); - Stmt then_case = - StoreNode::make(op->buffer_var, op->value, op->index, op->predicate); - Stmt else_case = - AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop); + Stmt then_case = StoreNode::make(op->buffer_var, op->value, op->index, op->predicate); + Stmt else_case = AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop); Stmt body = IfThenElseNode::make(condition, then_case, else_case); return body; } @@ -109,9 +107,7 @@ class BoundChecker : public StmtExprMutator { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const Var& buffer_var, - const Array& new_shape, - const DataType& type) { + void Update(const Var& buffer_var, const Array& new_shape, const DataType& type) { // Sanity check at first. if (!new_shape.size()) { return; @@ -126,11 +122,11 @@ class BoundChecker : public StmtExprMutator { // Scalarize the shape. PrimExpr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()), - CastNode::make(DataType::UInt(64), new_shape[0])); + CastNode::make(DataType::UInt(64), new_shape[0])); for (size_t i = 1; i < new_shape.size(); ++i) { // Cast to unsigned to avoid integer overlow at frist. shape = MulNode::make(shape, MulNode::make(make_const(DataType::UInt(64), type.lanes()), - CastNode::make(DataType::UInt(64), new_shape[i]))); + CastNode::make(DataType::UInt(64), new_shape[i]))); } mem_to_shape_[buffer_var.get()] = shape; } @@ -140,23 +136,21 @@ class BoundChecker : public StmtExprMutator { return false; } - if (const RampNode *ramp_index = index.as()) { - return ramp_index->base.defined() && - ramp_index->base.dtype().is_scalar() && - ramp_index->stride.defined() && - ramp_index->stride.dtype().is_scalar() && (ramp_index->lanes > 0); + if (const RampNode* ramp_index = index.as()) { + return ramp_index->base.defined() && ramp_index->base.dtype().is_scalar() && + ramp_index->stride.defined() && ramp_index->stride.dtype().is_scalar() && + (ramp_index->lanes > 0); } return true; } bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const { - return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && - IndexIsValid(index) && !unsafe_rewritten_; + return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) && + !unsafe_rewritten_; } void Collect(PrimExpr index, Var buffer_var) { - store_scope_bound_collector_.push_back( - std::make_pair(index, mem_to_shape_[buffer_var.get()])); + store_scope_bound_collector_.push_back(std::make_pair(index, mem_to_shape_[buffer_var.get()])); } PrimExpr MakeCondition() { @@ -166,13 +160,12 @@ class BoundChecker : public StmtExprMutator { PrimExpr index = buffer_to_mem.first; PrimExpr upper_bound = buffer_to_mem.second; - if (const RampNode *ramp_index = index.as()) { + if (const RampNode* ramp_index = index.as()) { // In case index is base + stride * i. // Non inclusive range. - index = AddNode::make( - ramp_index->base, - MulNode::make(ramp_index->stride, make_const(ramp_index->stride.dtype(), - ramp_index->lanes - 1))); + index = AddNode::make(ramp_index->base, MulNode::make(ramp_index->stride, + make_const(ramp_index->stride.dtype(), + ramp_index->lanes - 1))); } // Try to simplify index and bound. @@ -188,8 +181,7 @@ class BoundChecker : public StmtExprMutator { PrimExpr current_condition = AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound)); - condition = - !i ? current_condition : AndNode::make(condition, current_condition); + condition = !i ? current_condition : AndNode::make(condition, current_condition); } return condition; } @@ -201,9 +193,9 @@ class BoundChecker : public StmtExprMutator { // Pool which collects the pair of index and shape for specific store/load. std::vector> store_scope_bound_collector_; // Error message. - const char *const error_message_ = "OUT OF THE BOUNDS"; + const char* const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape. - std::unordered_map mem_to_shape_; + std::unordered_map mem_to_shape_; // internal analyzer arith::Analyzer analyzer_; }; @@ -230,7 +222,7 @@ Pass InstrumentBoundCheckers() { } TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers") -.set_body_typed(InstrumentBoundCheckers); + .set_body_typed(InstrumentBoundCheckers); } // namespace transform diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index c17d665..9e5e4ae 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -22,14 +22,13 @@ * * \file combine_context_call.cc */ +#include +#include +#include #include #include #include #include -#include -#include -#include - #include @@ -44,7 +43,7 @@ class ContextCallCombiner final : public StmtExprMutator { if (op->is_intrinsic(intrinsic::tvm_thread_context)) { CHECK_EQ(op->args.size(), 1U); PrimExpr ctx = op->args[0]; - auto it = ctx_map_.find(ctx); + auto it = ctx_map_.find(ctx); if (it != ctx_map_.end()) { return it->second; } else { @@ -65,8 +64,7 @@ class ContextCallCombiner final : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::coproc_uop_scope) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::coproc_uop_scope) { // Map of comparison expression to variable std::unordered_map temp; std::swap(temp, ctx_map_); @@ -91,14 +89,11 @@ class ContextCallCombiner final : public StmtExprMutator { } } - Stmt Combine(Stmt stmt) { - return BuildContext(ctx_map_, this->VisitStmt(stmt)); - } + Stmt Combine(Stmt stmt) { return BuildContext(ctx_map_, this->VisitStmt(stmt)); } private: static Stmt BuildContext( - const std::unordered_map& cmap, - Stmt body) { + const std::unordered_map& cmap, Stmt body) { for (const auto& kv : cmap) { body = LetStmtNode::make(kv.second, kv.first, body); } @@ -108,7 +103,6 @@ class ContextCallCombiner final : public StmtExprMutator { std::unordered_map ctx_map_; }; - namespace transform { Pass CombineContextCall() { @@ -120,8 +114,7 @@ Pass CombineContextCall() { return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); } -TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall") -.set_body_typed(CombineContextCall); +TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 174564f..41e1124 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -21,11 +21,13 @@ * \file coproc_sync.cc */ #include -#include #include #include +#include + #include #include + #include "ir_util.h" #include "storage_access.h" @@ -89,11 +91,9 @@ class CoProcTouchedBuffer : public StmtExprVisitor { // Synchronization planning with co-processor. class CoProcSyncPlanner : public StorageAccessVisitor { public: - explicit CoProcSyncPlanner( - const std::unordered_set& touched, - const std::string& coproc_name) - : touched_(touched), coproc_name_(coproc_name) { - } + explicit CoProcSyncPlanner(const std::unordered_set& touched, + const std::string& coproc_name) + : touched_(touched), coproc_name_(coproc_name) {} void Plan(const Stmt& stmt) { this->VisitStmt(stmt); @@ -107,22 +107,19 @@ class CoProcSyncPlanner : public StorageAccessVisitor { std::unordered_map > sync_; protected: - bool Enabled(const VarNode* buf, - const StorageScope& scope) const final { + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return touched_.count(buf); } // Plan the sync - std::vector Summarize( - std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const ForNode* loop) final { return PlanSync(seq, loop, false); } private: // Plan write synchronization if write is not coherent - std::vector PlanSync( - std::vector seq, const ForNode* loop, - bool force_sync_at_end) { + std::vector PlanSync(std::vector seq, const ForNode* loop, + bool force_sync_at_end) { // detect write barriers // access by the co-processor. std::vector co_access; @@ -131,8 +128,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { auto find_conflict = [&](const AccessEntry& acc) { for (const AccessEntry& x : co_access) { if (x.buffer.same_as(acc.buffer) && - ((acc.type == kRead && x.type == kWrite) || - acc.type == kWrite)) { + ((acc.type == kRead && x.type == kWrite) || acc.type == kWrite)) { return true; } } @@ -143,7 +139,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { bool sync_write = false; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && find_conflict(acc)) { - sync_write = true; break; + sync_write = true; + break; } if (acc.type == kSync) { co_access.clear(); @@ -169,7 +166,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { const StmtEntry& s = seq[i]; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && find_conflict(acc)) { - sync_at_end = true; break; + sync_at_end = true; + break; } } if (sync_.count(s.stmt) || sync_at_end) break; @@ -197,10 +195,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } std::vector GetSync(std::string sync_name) { - return {EvaluateNode::make(CallNode::make( - DataType::Int(32), - sync_name, - {}, CallNode::Intrinsic))}; + return { + EvaluateNode::make(CallNode::make(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; } const std::unordered_set& touched_; @@ -210,9 +206,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { // Detect memory barriers when coproc read/write memory class CoProcBarrierDetector : public StorageAccessVisitor { public: - explicit CoProcBarrierDetector( - const std::unordered_set& touched, - const std::string& coproc_name) + explicit CoProcBarrierDetector(const std::unordered_set& touched, + const std::string& coproc_name) : touched_(touched) { read_barrier_name_ = coproc_name + ".coproc_read_barrier"; write_barrier_name_ = coproc_name + ".coproc_write_barrier"; @@ -233,14 +228,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor { std::unordered_map > barrier_after_; protected: - bool Enabled(const VarNode* buf, - const StorageScope& scope) const final { + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return touched_.count(buf); } // Plan the sync - std::vector Summarize( - std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const ForNode* loop) final { if (read_barrier_) { return PlanReadBarrier(seq, loop); } else { @@ -250,17 +243,15 @@ class CoProcBarrierDetector : public StorageAccessVisitor { private: // Plan write barrier at Read after write point. - std::vector PlanWriteBarrier( - std::vector seq, const ForNode* loop) { + std::vector PlanWriteBarrier(std::vector seq, const ForNode* loop) { std::vector read_seq; std::unordered_map > write_set; auto fupdate = [&](size_t i, const AccessEntry& acc) { - auto it = write_set.find(acc.buffer.get()); + auto it = write_set.find(acc.buffer.get()); if (it != write_set.end()) { CHECK_NE(i, 0U); - barrier_after_[seq[i - 1].stmt].push_back( - MakeBarrier(write_barrier_name_, it->second)); + barrier_after_[seq[i - 1].stmt].push_back(MakeBarrier(write_barrier_name_, it->second)); write_set.erase(it); } }; @@ -284,23 +275,21 @@ class CoProcBarrierDetector : public StorageAccessVisitor { fupdate(seq.size(), acc); } } - for (const auto &kv : write_set) { + for (const auto& kv : write_set) { read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end()); } return read_seq; } - std::vector PlanReadBarrier( - std::vector seq, const ForNode* loop) { + std::vector PlanReadBarrier(std::vector seq, const ForNode* loop) { std::vector write_seq; std::unordered_map > read_set; auto fupdate = [&](size_t i, const AccessEntry& acc) { - auto it = read_set.find(acc.buffer.get()); + auto it = read_set.find(acc.buffer.get()); if (it != read_set.end()) { CHECK_NE(i, seq.size()); - barrier_before_[seq[i].stmt].push_back( - MakeBarrier(read_barrier_name_, it->second)); + barrier_before_[seq[i].stmt].push_back(MakeBarrier(read_barrier_name_, it->second)); read_set.erase(it); } }; @@ -325,7 +314,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { fupdate(0, acc); } } - for (const auto &kv : read_set) { + for (const auto& kv : read_set) { write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end()); } return write_seq; @@ -340,13 +329,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor { } Range none; Range r = arith::Union(wset).cover_range(none); - CHECK(r.defined()) - << "Cannot deduce write range of " << wvec[0].buffer; + CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; PrimExpr min = r->min; PrimExpr extent = r->extent; return EvaluateNode::make(CallNode::make( - DataType::Int(32), func, - {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic)); + DataType::Int(32), func, {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, + CallNode::Intrinsic)); } // Write barrier name bool read_barrier_{false}; @@ -355,12 +343,9 @@ class CoProcBarrierDetector : public StorageAccessVisitor { const std::unordered_set& touched_; }; - class CoProcInstDepDetector : public StmtVisitor { public: - explicit CoProcInstDepDetector( - const IterVar& coproc_axis, - const std::string& coproc_name) + explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name) : coproc_axis_(coproc_axis) { sync_push_name_ = coproc_name + ".coproc_dep_push"; sync_pop_name_ = coproc_name + ".coproc_dep_pop"; @@ -375,8 +360,7 @@ class CoProcInstDepDetector : public StmtVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::coproc_scope && - op->node.same_as(coproc_axis_)) { + if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) { const IntImmNode* ctx_id = op->value.as(); CHECK(ctx_id != nullptr); curr_state_.clear(); @@ -399,9 +383,7 @@ class CoProcInstDepDetector : public StmtVisitor { curr_state_.node = op; CHECK(first_state_.node != nullptr); // loop carry dependency - InjectSync(last_state_, first_state_, - &(curr_state_.exit_push), - &(curr_state_.enter_pop)); + InjectSync(last_state_, first_state_, &(curr_state_.exit_push), &(curr_state_.enter_pop)); curr_state_.enter_ctx = first_state_.enter_ctx; curr_state_.exit_ctx = last_state_.exit_ctx; } @@ -423,12 +405,8 @@ class CoProcInstDepDetector : public StmtVisitor { curr_state.node = op; MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); - curr_state.enter_ctx.insert( - first_state_.enter_ctx.begin(), - first_state_.enter_ctx.end()); - curr_state.exit_ctx.insert( - last_state_.exit_ctx.begin(), - last_state_.exit_ctx.end()); + curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); + curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); } first_state_.clear(); last_state_.clear(); @@ -439,12 +417,8 @@ class CoProcInstDepDetector : public StmtVisitor { curr_state.node = op; MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); - curr_state.enter_ctx.insert( - first_state_.enter_ctx.begin(), - first_state_.enter_ctx.end()); - curr_state.exit_ctx.insert( - last_state_.exit_ctx.begin(), - last_state_.exit_ctx.end()); + curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); + curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); } } // update in the trace. @@ -487,15 +461,14 @@ class CoProcInstDepDetector : public StmtVisitor { // record the push/pop sequence that could be possibly un-matched. // return the push/pop message at enter/exit of the Block // after considering the existing unmatcheded events and added events - void InjectSync(const SyncState& prev, - const SyncState& next, + void InjectSync(const SyncState& prev, const SyncState& next, std::vector >* prev_exit_push, std::vector >* next_enter_pop) { prev_exit_push->clear(); next_enter_pop->clear(); // quick path - if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && - prev.exit_ctx.size() == 1 && next.enter_ctx.size() == 1) { + if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && prev.exit_ctx.size() == 1 && + next.enter_ctx.size() == 1) { int from = *prev.exit_ctx.begin(); int to = *next.enter_ctx.begin(); if (from != to) { @@ -520,15 +493,11 @@ class CoProcInstDepDetector : public StmtVisitor { // policy 1 std::vector prev_after, next_before; for (const std::pair& p : pending) { - if (std::find(prev.exit_push.begin(), - prev.exit_push.end(), p) == - prev.exit_push.end()) { + if (std::find(prev.exit_push.begin(), prev.exit_push.end(), p) == prev.exit_push.end()) { vpush.push_back(p); prev_after.emplace_back(MakePush(p.first, p.second)); } - if (std::find(next.enter_pop.begin(), - next.enter_pop.end(), p) == - next.enter_pop.end()) { + if (std::find(next.enter_pop.begin(), next.enter_pop.end(), p) == next.enter_pop.end()) { vpop.push_back(p); next_before.emplace_back(MakePop(p.first, p.second)); } @@ -549,18 +518,18 @@ class CoProcInstDepDetector : public StmtVisitor { } } if (prev_after.size() != 0) { - auto &v1 = insert_after_[prev.node]; + auto& v1 = insert_after_[prev.node]; v1.insert(v1.end(), prev_after.begin(), prev_after.end()); } if (next_before.size() != 0) { - auto &v2 = insert_before_[next.node]; + auto& v2 = insert_before_[next.node]; v2.insert(v2.end(), next_before.begin(), next_before.end()); } } void MatchFixEnterPop(const SyncState& state) { if (state.enter_pop.size() == 0) return; - auto &vec = insert_before_[state.node]; + auto& vec = insert_before_[state.node]; for (const std::pair& p : state.enter_pop) { vec.push_back(MakePush(p.first, p.second)); } @@ -568,7 +537,7 @@ class CoProcInstDepDetector : public StmtVisitor { void MatchFixExitPush(const SyncState& state) { if (state.exit_push.size() == 0) return; - auto &vec = insert_after_[state.node]; + auto& vec = insert_after_[state.node]; for (const std::pair& p : state.exit_push) { vec.push_back(MakePop(p.first, p.second)); } @@ -587,16 +556,16 @@ class CoProcInstDepDetector : public StmtVisitor { } Stmt MakePush(int from, int to) { - return EvaluateNode::make(CallNode::make( - DataType::Int(32), sync_push_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + return EvaluateNode::make( + CallNode::make(DataType::Int(32), sync_push_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } Stmt MakePop(int from, int to) { - return EvaluateNode::make(CallNode::make( - DataType::Int(32), sync_pop_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + return EvaluateNode::make( + CallNode::make(DataType::Int(32), sync_pop_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } // sync states. SyncState first_state_, last_state_, curr_state_; @@ -605,7 +574,6 @@ class CoProcInstDepDetector : public StmtVisitor { std::string sync_push_name_, sync_pop_name_; }; - class CoProcSyncInserter : public StmtMutator { public: Stmt Insert(Stmt stmt) { @@ -614,7 +582,7 @@ class CoProcSyncInserter : public StmtMutator { if (visitor.coproc_.size() == 0) return stmt; std::unordered_set touched; - for (const auto &kv : visitor.touched_) { + for (const auto& kv : visitor.touched_) { if (kv.second.normal && kv.second.coproc) { touched.insert(kv.first); } @@ -641,8 +609,7 @@ class CoProcSyncInserter : public StmtMutator { vec.insert(vec.end(), kv.second.begin(), kv.second.end()); } // Detect barrier - CoProcInstDepDetector sync_detector( - *visitor.coproc_.begin(), coproc_name); + CoProcInstDepDetector sync_detector(*visitor.coproc_.begin(), coproc_name); sync_detector.Plan(stmt); for (const auto& kv : sync_detector.insert_before_) { auto& vec = insert_before_[kv.first]; @@ -661,9 +628,8 @@ class CoProcSyncInserter : public StmtMutator { Stmt new_stmt = StmtMutator::VisitStmt(stmt); return SeqStmt::Flatten( - it_before != insert_before_.end() ? it_before->second : std::vector(), - new_stmt, - it_after != insert_after_.end() ? it_after->second : std::vector()); + it_before != insert_before_.end() ? it_before->second : std::vector(), new_stmt, + it_after != insert_after_.end() ? it_after->second : std::vector()); } private: @@ -673,10 +639,7 @@ class CoProcSyncInserter : public StmtMutator { std::unordered_map > insert_after_; }; - -Stmt CoProcSync(Stmt stmt) { - return CoProcSyncInserter().Insert(std::move(stmt)); -} +Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); } namespace transform { @@ -689,8 +652,7 @@ Pass CoProcSync() { return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {}); } -TVM_REGISTER_GLOBAL("tir.transform.CoProcSync") -.set_body_typed(CoProcSync); +TVM_REGISTER_GLOBAL("tir.transform.CoProcSync").set_body_typed(CoProcSync); } // namespace transform diff --git a/src/tir/transforms/decorate_device_scope.cc b/src/tir/transforms/decorate_device_scope.cc index 7ff2e3f..0decb94 100644 --- a/src/tir/transforms/decorate_device_scope.cc +++ b/src/tir/transforms/decorate_device_scope.cc @@ -21,18 +21,15 @@ * \file decorate_device_scope.cc */ #include -#include #include +#include #include namespace tvm { namespace tir { Stmt DecorateDeviceScope(Stmt&& stmt) { - Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), - tir::attr::device_scope, - 0, - stmt); + Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt); return body; } @@ -47,8 +44,7 @@ Pass DecorateDeviceScope() { return CreatePrimFuncPass(pass_func, 0, "tir.DecorateDeviceScope", {}); } -TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope") -.set_body_typed(DecorateDeviceScope); +TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope").set_body_typed(DecorateDeviceScope); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 86bbefc..5af1a39 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -21,12 +21,13 @@ * \brief Replace certain copy with copy intrinsics. * \file copy_intrin_rewrite.cc */ -#include -#include -#include #include +#include +#include #include #include +#include + #include "../../arith/pattern_match.h" namespace tvm { @@ -36,11 +37,9 @@ using runtime::PackedFunc; class CopyIntrinInjector : public StmtMutator { public: - CopyIntrinInjector(const std::string& pragma_key, - const PackedFunc& flower_copy_fromto) - : pragma_key_(attr::pragma_scope_prefix+ pragma_key), - flower_copy_fromto_(flower_copy_fromto) { - } + CopyIntrinInjector(const std::string& pragma_key, const PackedFunc& flower_copy_fromto) + : pragma_key_(attr::pragma_scope_prefix + pragma_key), + flower_copy_fromto_(flower_copy_fromto) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::storage_scope) { @@ -48,15 +47,14 @@ class CopyIntrinInjector : public StmtMutator { storage_scope_[buf] = op->value.as()->value; } else if (op->attr_key == pragma_key_) { Stmt ret; - CHECK(MatchCopyPattern(op->body, &ret)) - << "Cannot match copy pattern of " << op->body; + CHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body; return ret; } return StmtMutator::VisitStmt_(op); } private: - bool MatchCopyPattern(Stmt stmt, Stmt *out) { + bool MatchCopyPattern(Stmt stmt, Stmt* out) { using namespace arith; Stmt body = stmt; @@ -72,9 +70,8 @@ class CopyIntrinInjector : public StmtMutator { // Expr sel_cond, sel_true_value, sel_false_value; // match select or if PVar sel_cond, sel_true_value, sel_false_value; - bool has_cond = - if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || - select(sel_cond, sel_true_value, sel_false_value).Match(store->value); + bool has_cond = if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || + select(sel_cond, sel_true_value, sel_false_value).Match(store->value); const CastNode* cast = store->value.as(); const LoadNode* load = store->value.as(); @@ -95,11 +92,9 @@ class CopyIntrinInjector : public StmtMutator { for (const ForNode* op : loops) { loop_vars.push_back(op->loop_var); } - Array store_strides = - arith::DetectLinearEquation(store->index, loop_vars); - Array load_strides = - arith::DetectLinearEquation(load->index, loop_vars); - if (load_strides.size() == 0 || store_strides.size() == 0) return false; + Array store_strides = arith::DetectLinearEquation(store->index, loop_vars); + Array load_strides = arith::DetectLinearEquation(load->index, loop_vars); + if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array dst_shape; const size_t loop_var_size = loop_vars.size(); if (loop_var_size == 0) { @@ -114,8 +109,7 @@ class CopyIntrinInjector : public StmtMutator { PrimExpr pad_value; PrimExpr src_elem_offset = load_strides[loop_var_size]; if (has_cond) { - Array clip_bound = - arith::DetectClipBound(sel_cond.Eval(), loop_vars); + Array clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars); pad_value = sel_false_value.Eval(); if (clip_bound.size() == 0) return false; CHECK_EQ(src_shape.size(), loop_vars.size()); @@ -150,27 +144,15 @@ class CopyIntrinInjector : public StmtMutator { Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); if (loop_var_size == 0) { - src_strides.push_back(make_const(DataType::Int(32), 1)); - dst_strides.push_back(make_const(DataType::Int(32), 1)); + src_strides.push_back(make_const(DataType::Int(32), 1)); + dst_strides.push_back(make_const(DataType::Int(32), 1)); } - Buffer dst = BufferNode::make( - store->buffer_var, - store->value.dtype(), - dst_shape, - dst_strides, - store_strides[loop_var_size], - store->buffer_var->name_hint, - GetStorageScope(store->buffer_var.get()), - 0, 0, kDefault); - Buffer src = BufferNode::make( - load->buffer_var, - load->dtype, - src_shape, - src_strides, - src_elem_offset, - load->buffer_var->name_hint, - GetStorageScope(load->buffer_var.get()), - 0, 0, kDefault); + Buffer dst = BufferNode::make(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, + store_strides[loop_var_size], store->buffer_var->name_hint, + GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); + Buffer src = BufferNode::make(load->buffer_var, load->dtype, src_shape, src_strides, + src_elem_offset, load->buffer_var->name_hint, + GetStorageScope(load->buffer_var.get()), 0, 0, kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); CHECK(out->defined()) << "flower function did not return correct stmt"; return true; @@ -194,28 +176,23 @@ class CopyIntrinInjector : public StmtMutator { arith::Analyzer analyzer_; }; -Stmt InjectCopyIntrin(Stmt stmt, - const std::string& pragma_key, +Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key, const PackedFunc& flower_copy_fromto) { return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt)); } - namespace transform { -Pass InjectCopyIntrin(std::string pragma_key, - PackedFunc flower_copy_fromto) { +Pass InjectCopyIntrin(std::string pragma_key, PackedFunc flower_copy_fromto) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - n->body = CopyIntrinInjector( - pragma_key, flower_copy_fromto)(std::move(n->body)); + n->body = CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin") -.set_body_typed(InjectCopyIntrin); +TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin").set_body_typed(InjectCopyIntrin); } // namespace transform diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 4e5d08c..0189978 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -22,11 +22,12 @@ * \file inject_double_buffer.cc */ #include -#include -#include #include -#include "ir_util.h" +#include +#include + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -52,7 +53,6 @@ class DoubleBufferDetector : public StmtExprVisitor { std::unordered_set touched_; }; - class StripDoubleBufferWrite : public StmtMutator { public: Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -66,8 +66,7 @@ class StripDoubleBufferWrite : public StmtMutator { class DoubleBufferInjector : public StmtExprMutator { public: - explicit DoubleBufferInjector(int split_loop) - : split_loop_(split_loop) {} + explicit DoubleBufferInjector(int split_loop) : split_loop_(split_loop) {} Stmt Inject(Stmt stmt) { DoubleBufferDetector detector; @@ -99,8 +98,8 @@ class DoubleBufferInjector : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { - it->second.stride = arith::ComputeReduce( - op->extents, PrimExpr()) * op->dtype.lanes(); + it->second.stride = + arith::ComputeReduce(op->extents, PrimExpr()) * op->dtype.lanes(); Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); Array new_extents{make_const(op->extents[0].dtype(), 2)}; @@ -109,13 +108,11 @@ class DoubleBufferInjector : public StmtExprMutator { } CHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back(AttrStmtNode::make( - op->buffer_var, attr::storage_scope, - StringImmNode::make(it->second.scope), - EvaluateNode::make(0))); - alloc_nest.emplace_back(AllocateNode::make( - op->buffer_var, op->dtype, new_extents, op->condition, - EvaluateNode::make(0))); + alloc_nest.emplace_back(AttrStmtNode::make(op->buffer_var, attr::storage_scope, + StringImmNode::make(it->second.scope), + EvaluateNode::make(0))); + alloc_nest.emplace_back(AllocateNode::make(op->buffer_var, op->dtype, new_extents, + op->condition, EvaluateNode::make(0))); return op->body; } else { return StmtExprMutator::VisitStmt_(op); @@ -134,8 +131,7 @@ class DoubleBufferInjector : public StmtExprMutator { << "It is better to split with multiple of 2"; CHECK(is_zero(old_loop->min)); PrimExpr zero = old_loop->min; - PrimExpr new_ext = - old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); + PrimExpr new_ext = old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); PrimExpr factor = make_const(new_ext.dtype(), split_loop_); PrimExpr outer_ext = new_ext / factor; PrimExpr tail_base = outer_ext * factor; @@ -146,9 +142,8 @@ class DoubleBufferInjector : public StmtExprMutator { vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i); loop_seq.emplace_back(Substitute(old_loop->body, vmap)); } - Stmt loop = ForNode::make( - outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, - SeqStmt::Flatten(loop_seq)); + Stmt loop = ForNode::make(outer_var, zero, outer_ext, old_loop->for_type, + old_loop->device_api, SeqStmt::Flatten(loop_seq)); // tail std::vector tail_seq; Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); @@ -156,8 +151,7 @@ class DoubleBufferInjector : public StmtExprMutator { PrimExpr idx = tail_base + make_const(tail_base.dtype(), i); vmap[old_loop->loop_var.get()] = idx; tail_seq.emplace_back( - IfThenElseNode::make(idx < old_loop->extent, - Substitute(tail_body, vmap))); + IfThenElseNode::make(idx < old_loop->extent, Substitute(tail_body, vmap))); } stmt = SeqStmt::Flatten(loop, tail_seq); } @@ -179,10 +173,8 @@ class DoubleBufferInjector : public StmtExprMutator { const StorageEntry& e = it->second; CHECK(in_double_buffer_scope_); CHECK(e.stride.defined()); - return StoreNode::make(op->buffer_var, - op->value, - e.switch_write_var * e.stride + op->index, - op->predicate); + return StoreNode::make(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index, + op->predicate); } else { return stmt; } @@ -196,10 +188,8 @@ class DoubleBufferInjector : public StmtExprMutator { const StorageEntry& e = it->second; CHECK(e.stride.defined()); CHECK(e.switch_read_var.defined()); - return LoadNode::make(op->dtype, - op->buffer_var, - e.switch_read_var * e.stride + op->index, - op->predicate); + return LoadNode::make(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index, + op->predicate); } else { return expr; } @@ -213,8 +203,7 @@ class DoubleBufferInjector : public StmtExprMutator { private: Stmt MakeProducer(const AttrStmtNode* op) { const Var buffer = Downcast(op->node); - CHECK_NE(loop_nest_.size(), 0U) - << "Double buffer scope must be inside a loop"; + CHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop"; auto it = dbuffer_info_.find(buffer.get()); if (it == dbuffer_info_.end()) { LOG(WARNING) << "Skip double buffer scope " << op->node; @@ -226,8 +215,7 @@ class DoubleBufferInjector : public StmtExprMutator { PrimExpr one = make_const(e.loop->loop_var.dtype(), 1); PrimExpr two = make_const(e.loop->loop_var.dtype(), 2); PrimExpr loop_shift = e.loop->loop_var + one; - e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", - e.loop->loop_var.dtype()); + e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", e.loop->loop_var.dtype()); e.switch_read_var = indexmod(e.loop->loop_var, two); in_double_buffer_scope_ = true; Stmt body = this->VisitStmt(op->body); @@ -270,12 +258,10 @@ class DoubleBufferInjector : public StmtExprMutator { std::unordered_map dbuffer_info_; }; - Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) { return DoubleBufferInjector(split_loop).Inject(stmt); } - namespace transform { Pass InjectDoubleBuffer(int split_loop) { @@ -287,8 +273,7 @@ Pass InjectDoubleBuffer(int split_loop) { return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer") -.set_body_typed(InjectDoubleBuffer); +TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer); } // namespace transform diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc index e9dae0a..3b626f0 100644 --- a/src/tir/transforms/inject_prefetch.cc +++ b/src/tir/transforms/inject_prefetch.cc @@ -21,20 +21,21 @@ * \file inject_prefetch.cc */ // Inject prefetch op in HalideIR +#include +#include #include #include #include #include #include -#include -#include + #include namespace tvm { namespace tir { -using arith::IntSet; using arith::DomainTouched; +using arith::IntSet; class PrefetchInjector : public StmtMutator { public: @@ -68,7 +69,7 @@ class PrefetchInjector : public StmtMutator { } Stmt VisitStmt_(const ForNode* op) final { - auto &var = op->loop_var; + auto& var = op->loop_var; loop_nest_.push_back(var); if (op->for_type == ForType::Vectorized) { vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1); @@ -83,16 +84,13 @@ class PrefetchInjector : public StmtMutator { private: std::vector loop_nest_; - std::unordered_map vectorized_; + std::unordered_map vectorized_; static const Range none; }; const Range PrefetchInjector::none; -Stmt InjectPrefetch(Stmt stmt) { - return PrefetchInjector()(std::move(stmt)); -} - +Stmt InjectPrefetch(Stmt stmt) { return PrefetchInjector()(std::move(stmt)); } namespace transform { @@ -105,8 +103,7 @@ Pass InjectPrefetch() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch") -.set_body_typed(InjectPrefetch); +TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch").set_body_typed(InjectPrefetch); } // namespace transform diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 01fb6fe..834a7e9 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -24,9 +24,11 @@ #include #include #include + #include -#include "ir_util.h" + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -34,8 +36,7 @@ namespace tir { // If expression is touched by var. class ExprTouched final : public StmtExprVisitor { public: - explicit ExprTouched(const std::unordered_set &touched, - bool check_write) + explicit ExprTouched(const std::unordered_set& touched, bool check_write) : touched_var_(touched), check_write_(check_write) {} void VisitExpr(const PrimExpr& n) final { @@ -43,19 +44,17 @@ class ExprTouched final : public StmtExprVisitor { if (expr_touched_ && !check_write_) return; StmtExprVisitor::VisitExpr(n); } - void VisitStmt(const Stmt& n) final { + void VisitStmt(const Stmt& n) final { // early stopping if (expr_touched_ && !check_write_) return; StmtExprVisitor::VisitStmt(n); } - void VisitExpr_(const LoadNode *op) final { + void VisitExpr_(const LoadNode* op) final { HandleUseVar(op->buffer_var.get()); StmtExprVisitor::VisitExpr_(op); } - void VisitExpr_(const VarNode *op) final { - HandleUseVar(op); - } - void VisitExpr_(const CallNode *op) final { + void VisitExpr_(const VarNode* op) final { HandleUseVar(op); } + void VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { const auto* rw_mask = op->args[4].as(); const VarNode* buffer_var = op->args[1].as(); @@ -84,9 +83,7 @@ class ExprTouched final : public StmtExprVisitor { used_vars_.push_back(var); } } - void HandleWriteVar(const VarNode* var) { - write_vars_.push_back(var); - } + void HandleWriteVar(const VarNode* var) { write_vars_.push_back(var); } // the fields. bool expr_touched_{false}; std::vector used_vars_; @@ -134,8 +131,7 @@ class VarTouchedAnalysis : public StmtVisitor { Record(op->buffer_var.get(), tc); this->VisitStmt(op->body); } - void Record(const VarNode* var, - const ExprTouched& tc) { + void Record(const VarNode* var, const ExprTouched& tc) { if (touched_var_.count(var)) return; if (tc.expr_touched_) { touched_var_.insert(var); @@ -148,14 +144,11 @@ class VarTouchedAnalysis : public StmtVisitor { } } - std::unordered_set - TouchedVar(const Stmt& stmt, - const VarNode* var) { + std::unordered_set TouchedVar(const Stmt& stmt, const VarNode* var) { touched_var_.insert(var); this->VisitStmt(stmt); // do a DFS to push affect around dependency. - std::vector pending( - touched_var_.begin(), touched_var_.end()); + std::vector pending(touched_var_.begin(), touched_var_.end()); while (!pending.empty()) { const VarNode* v = pending.back(); pending.pop_back(); @@ -173,29 +166,26 @@ class VarTouchedAnalysis : public StmtVisitor { // Whether variable is touched by the thread variable. std::unordered_set touched_var_; // x -> all the buffers x read from - std::unordered_map > affect_; + std::unordered_map > affect_; }; - // Inject virtual thread loop // rewrite the buffer access pattern when necessary. class VTInjector : public StmtExprMutator { public: // constructor - VTInjector(Var var, - int num_threads, - const std::unordered_set& touched_var, + VTInjector(Var var, int num_threads, const std::unordered_set& touched_var, bool allow_share) - : var_(var), num_threads_(num_threads), - touched_var_(touched_var), allow_share_(allow_share) { - } + : var_(var), + num_threads_(num_threads), + touched_var_(touched_var), + allow_share_(allow_share) {} // Inject VTLoop when needed. Stmt VisitStmt(const Stmt& s) final { CHECK(!visit_touched_var_); auto stmt = StmtExprMutator::VisitStmt(s); if (visit_touched_var_ || trigger_base_inject_) { - if (!vt_loop_injected_) { + if (!vt_loop_injected_) { return InjectVTLoop(stmt, false); } visit_touched_var_ = false; @@ -205,8 +195,7 @@ class VTInjector : public StmtExprMutator { } // Variable PrimExpr VisitExpr_(const VarNode* op) final { - CHECK(!alloc_remap_.count(op)) - << "Buffer address may get rewritten in virtual thread"; + CHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread"; if (touched_var_.count(op)) { visit_touched_var_ = true; } @@ -224,9 +213,8 @@ class VTInjector : public StmtExprMutator { } auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - return LoadNode::make(op->dtype, op->buffer_var, - RewriteIndex(op->index, it->second), - op->predicate); + return LoadNode::make(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), + op->predicate); } else { return expr; } @@ -242,13 +230,10 @@ class VTInjector : public StmtExprMutator { visit_touched_var_ = true; PrimExpr offset = this->VisitExpr(op->args[2]); PrimExpr extent = this->VisitExpr(op->args[3]); - PrimExpr stride = - it->second / make_const(offset.dtype(), dtype.lanes()); + PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; - return CallNode::make( - op->dtype, op->name, - {op->args[0], op->args[1], offset, extent, op->args[4]}, - op->call_type); + return CallNode::make(op->dtype, op->name, + {op->args[0], op->args[1], offset, extent, op->args[4]}, op->call_type); } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { return allow_share_ ? GetRef(op) : var_; } else { @@ -269,10 +254,8 @@ class VTInjector : public StmtExprMutator { trigger_base_inject_ = !allow_share_; auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - return StoreNode::make(op->buffer_var, - op->value, - RewriteIndex(op->index, it->second), - op->predicate); + return StoreNode::make(op->buffer_var, op->value, RewriteIndex(op->index, it->second), + op->predicate); } else { return stmt; } @@ -283,13 +266,11 @@ class VTInjector : public StmtExprMutator { if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } else if (!allow_share_ && !vt_loop_injected_ && - (op->attr_key == attr::coproc_uop_scope || - op->attr_key == attr::coproc_scope)) { + (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) { return InjectVTLoop(GetRef(op), true); } else { Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return AttrStmtNode::make(op->node, op->attr_key, value, body); @@ -304,8 +285,7 @@ class VTInjector : public StmtExprMutator { } visit_touched_var_ = false; Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetStmtNode::make(op->var, value, body); @@ -323,12 +303,10 @@ class VTInjector : public StmtExprMutator { visit_touched_var_ = false; Stmt body = this->VisitStmt(op->body); ++max_loop_depth_; - if (extent.same_as(op->extent) && - body.same_as(op->body)) { + if (extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { - return ForNode::make( - op->loop_var, op->min, extent, op->for_type, op->device_api, body); + return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body); } } // IfThenElse @@ -347,8 +325,7 @@ class VTInjector : public StmtExprMutator { else_case = this->VisitStmt(op->else_case); max_loop_depth_ = std::max(temp, max_loop_depth_); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { @@ -391,8 +368,7 @@ class VTInjector : public StmtExprMutator { // always rewrite if not allow sharing. if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. - PrimExpr stride = arith::ComputeReduce( - op->extents, PrimExpr()) * op->dtype.lanes(); + PrimExpr stride = arith::ComputeReduce(op->extents, PrimExpr()) * op->dtype.lanes(); Array other; other.push_back(make_const(op->extents[0].dtype(), num_threads_)); for (PrimExpr e : extents) { @@ -408,14 +384,10 @@ class VTInjector : public StmtExprMutator { // Mutate the body. body = this->VisitStmt(op->body); } - if (!changed && - body.same_as(op->body) && - condition.same_as(op->condition)) { + if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { - return AllocateNode::make( - op->buffer_var, op->dtype, - extents, condition, body); + return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body); } } @@ -445,9 +417,8 @@ class VTInjector : public StmtExprMutator { Var idx(var_->name_hint + ".s", var_->dtype); Map values{{var_, idx}}; stmt = Substitute(stmt, values); - return ForNode::make(idx, make_zero(idx.dtype()), - make_const(idx.dtype(), num_threads_), - ForType::Serial, DeviceAPI::None, stmt); + return ForNode::make(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_), + ForType::Serial, DeviceAPI::None, stmt); } } @@ -472,7 +443,6 @@ class VTInjector : public StmtExprMutator { std::unordered_map alloc_remap_; }; - class VirtualThreadInjector : public StmtMutator { public: Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -513,8 +483,7 @@ Pass InjectVirtualThread() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread") -.set_body_typed(InjectVirtualThread); +TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread); } // namespace transform diff --git a/src/tir/transforms/ir_util.cc b/src/tir/transforms/ir_util.cc index 9ff3fca..ff3e941 100644 --- a/src/tir/transforms/ir_util.cc +++ b/src/tir/transforms/ir_util.cc @@ -21,11 +21,13 @@ * \file ir_util.cc * \brief Helper functions to construct and compose IR nodes. */ +#include "ir_util.h" + #include -#include -#include + #include -#include "ir_util.h" +#include +#include namespace tvm { namespace tir { @@ -84,7 +86,6 @@ Stmt MergeNest(const std::vector>& nest, Stmt body) { return body; } - class IRConvertSSA final : public StmtExprMutator { public: PrimExpr VisitExpr_(const VarNode* op) final { @@ -112,9 +113,8 @@ class IRConvertSSA final : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (scope_.count(op->buffer_var.get())) { - return LoadNode::make( - op->dtype, scope_[op->buffer_var.get()].back(), - op->index, op->predicate); + return LoadNode::make(op->dtype, scope_[op->buffer_var.get()].back(), op->index, + op->predicate); } else { return expr; } @@ -123,9 +123,8 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(op->buffer_var.get())) { - return StoreNode::make( - scope_[op->buffer_var.get()].back(), op->value, - op->index, op->predicate); + return StoreNode::make(scope_[op->buffer_var.get()].back(), op->value, op->index, + op->predicate); } else { return stmt; } @@ -152,8 +151,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return ForNode::make( - new_var, op->min, op->extent, op->for_type, op->device_api, op->body); + return ForNode::make(new_var, op->min, op->extent, op->for_type, op->device_api, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -167,9 +165,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return AllocateNode::make( - new_var, op->dtype, op->extents, op->condition, - op->body); + return AllocateNode::make(new_var, op->dtype, op->extents, op->condition, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -184,15 +180,13 @@ class IRConvertSSA final : public StmtExprMutator { if (new_alloc.same_as(op->body)) return GetRef(op); alloc = new_alloc.as(); CHECK(alloc); - return AttrStmtNode::make( - alloc->buffer_var, op->attr_key, op->value, new_alloc); + return AttrStmtNode::make(alloc->buffer_var, op->attr_key, op->value, new_alloc); } } Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { - return AttrStmtNode::make( - scope_[v].back(), op->attr_key, op->value, op->body); + return AttrStmtNode::make(scope_[v].back(), op->attr_key, op->value, op->body); } else { return stmt; } @@ -202,13 +196,11 @@ class IRConvertSSA final : public StmtExprMutator { } private: - std::unordered_map > scope_; + std::unordered_map> scope_; std::unordered_set defined_; }; -Stmt ConvertSSA(Stmt stmt) { - return IRConvertSSA()(std::move(stmt)); -} +Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_util.h b/src/tir/transforms/ir_util.h index 18f7977..69b5a39 100644 --- a/src/tir/transforms/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -24,9 +24,10 @@ #ifndef TVM_TIR_TRANSFORMS_IR_UTIL_H_ #define TVM_TIR_TRANSFORMS_IR_UTIL_H_ +#include #include #include -#include + #include namespace tvm { @@ -56,7 +57,7 @@ Stmt MergeNest(const std::vector >& nest, Stmt body); * \return if update happens, return the new array, else return the * original array */ -template +template inline Array UpdateArray(Array arr, F fupdate) { std::vector new_arr(arr.size()); bool changed = false; @@ -81,13 +82,10 @@ inline Array UpdateArray(Array arr, F fupdate) { * \param kind The data kind. * \return the get expression. */ -inline PrimExpr TVMStructGet( - DataType dtype, Var handle, int index, - intrinsic::TVMStructFieldKind kind) { - Array args ={ - handle, - make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind))}; +inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, + intrinsic::TVMStructFieldKind kind) { + Array args = {handle, make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind))}; return CallNode::make(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); } @@ -101,7 +99,7 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { return CallNode::make( DataType::Handle(), intrinsic::tvm_address_of, {LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), - const_true(dtype.lanes()))}, + const_true(dtype.lanes()))}, CallNode::PureIntrinsic); } @@ -116,11 +114,9 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return CallNode::make( - DataType::Handle(), intrinsic::tvm_address_of, - {LoadNode::make(dtype, handle, offset, - const_true(dtype.lanes()))}, - CallNode::PureIntrinsic); + return CallNode::make(DataType::Handle(), intrinsic::tvm_address_of, + {LoadNode::make(dtype, handle, offset, const_true(dtype.lanes()))}, + CallNode::PureIntrinsic); } /*! @@ -131,14 +127,10 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { * \param value The value to be set. * \return the set stmt. */ -inline Stmt TVMStructSet( - Var handle, int index, - intrinsic::TVMStructFieldKind kind, PrimExpr value) { - Array args ={ - handle, - make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind)), - value}; +inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind kind, + PrimExpr value) { + Array args = {handle, make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind)), value}; return EvaluateNode::make( CallNode::make(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); } @@ -150,8 +142,7 @@ inline Stmt TVMStructSet( */ inline DataType APIType(DataType t) { if (t.is_handle()) return t; - CHECK_EQ(t.lanes(), 1) - << "Cannot pass vector type through packed API."; + CHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; if (t.is_uint() || t.is_int()) return DataType::Int(64); CHECK(t.is_float()); return DataType::Float(64); @@ -174,7 +165,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { return align; } - /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc index 86b8cde..bb4e5f7 100644 --- a/src/tir/transforms/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -24,8 +24,9 @@ * \file lift_attr_scope.cc */ #include -#include #include +#include + #include "ir_util.h" namespace tvm { @@ -35,14 +36,12 @@ namespace tir { // to a few specified attr keys class AttrScopeLifter : public StmtMutator { public: - explicit AttrScopeLifter(std::string attr_key) - : attr_key_(attr_key) {} + explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {} Stmt Lift(Stmt stmt) { stmt = operator()(std::move(stmt)); if (attr_node_.defined()) { - stmt = AttrStmtNode::make( - attr_node_, attr_key_, attr_value_, stmt); + stmt = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, stmt); } return stmt; } @@ -52,14 +51,11 @@ class AttrScopeLifter : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (attr_node_.defined()) { - Stmt body = AttrStmtNode::make( - attr_node_, attr_key_, attr_value_, op->body); + Stmt body = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, op->body); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); - return AllocateNode::make( - op->buffer_var, op->dtype, - op->extents, op->condition, body); + return AllocateNode::make(op->buffer_var, op->dtype, op->extents, op->condition, body); } else { return stmt; } @@ -97,8 +93,7 @@ class AttrScopeLifter : public StmtMutator { // check if all decorations are common. for (size_t begin = 0; begin < attr_node.size();) { size_t end = begin + 1; - while (end < attr_node.size() && - attr_node[end].same_as(attr_node[begin]) && + while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) && ValueSame(attr_value[end], attr_value[begin])) { ++end; } @@ -116,8 +111,7 @@ class AttrScopeLifter : public StmtMutator { } Stmt stmt = SeqStmt::Flatten(seq); if (attr_node[begin].defined()) { - stmt = AttrStmtNode::make( - attr_node[begin], attr_key_, attr_value[begin], stmt); + stmt = AttrStmtNode::make(attr_node[begin], attr_key_, attr_value[begin], stmt); } reorg.push_back(stmt); begin = end; @@ -137,32 +131,25 @@ class AttrScopeLifter : public StmtMutator { std::swap(first_node, attr_node_); std::swap(first_value, attr_value_); Stmt else_case = this->VisitStmt(op->else_case); - if (attr_node_.defined() && - attr_value_.defined() && - first_node.defined() && - first_value.defined() && - attr_node_.same_as(first_node) && + if (attr_node_.defined() && attr_value_.defined() && first_node.defined() && + first_value.defined() && attr_node_.same_as(first_node) && ValueSame(attr_value_, first_value)) { - if (then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { + if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { return IfThenElseNode::make(op->condition, then_case, else_case); } } else { if (first_node.defined()) { - then_case = AttrStmtNode::make( - first_node, attr_key_, first_value, then_case); + then_case = AttrStmtNode::make(first_node, attr_key_, first_value, then_case); } if (attr_node_.defined()) { - else_case = AttrStmtNode::make( - attr_node_, attr_key_, attr_value_, else_case); + else_case = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, else_case); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); } - if (then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { + if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { return IfThenElseNode::make(op->condition, then_case, else_case); @@ -192,7 +179,6 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { return AttrScopeLifter(attr_key).Lift(std::move(stmt)); } - namespace transform { Pass LiftAttrScope(std::string attr_key) { @@ -204,8 +190,7 @@ Pass LiftAttrScope(std::string attr_key) { return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope") -.set_body_typed(LiftAttrScope); +TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope").set_body_typed(LiftAttrScope); } // namespace transform diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index dbceb37..6392e70 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -20,24 +20,26 @@ /*! * \file loop_partition.cc */ +#include +#include #include #include -#include #include -#include -#include +#include + #include #include -#include "ir_util.h" + #include "../../arith/interval_set.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { -using arith::IntSet; using arith::DeduceBound; using arith::Intersect; +using arith::IntSet; using PartitionKey = std::pair; struct PartitionKeyHash { @@ -72,8 +74,7 @@ bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) class CandidateSelector final : public StmtExprVisitor { public: using VarIsUsed = bool; - explicit CandidateSelector(bool split_const_loop) - : split_const_loop_(split_const_loop) {} + explicit CandidateSelector(bool split_const_loop) : split_const_loop_(split_const_loop) {} void VisitStmt_(const ForNode* op) final { // partition const loop when sets split_const_loop_ @@ -92,7 +93,7 @@ class CandidateSelector final : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { - const IterVarNode *iv = op->node.as(); + const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); @@ -156,16 +157,16 @@ class CandidateSelector final : public StmtExprVisitor { class PartitionFinder : public StmtExprVisitor { public: explicit PartitionFinder(Var current_var, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) - : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { - for (const auto& kv : hint_map) { - out_vars_.insert(kv.first); - } - for (const auto& kv : relax_map) { - out_vars_.insert(kv.first); - } - } + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) + : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { + for (const auto& kv : hint_map) { + out_vars_.insert(kv.first); + } + for (const auto& kv : relax_map) { + out_vars_.insert(kv.first); + } + } void VisitStmt_(const ForNode* op) final { if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; @@ -198,21 +199,18 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(CallNode::likely)) { PrimExpr cond = op->args[0]; - if (ExprUseVars(cond, - std::unordered_set({current_var_.get()}))) { + if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { // For cond, find out the interval, if exists, in which we can prove that cond is // true. Also find the interval, if exists, in which we can prove that cond is // false. - IntSet interval = - DeduceBound(current_var_, cond, hint_map_, relax_map_); + IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); if (!interval.is_nothing()) { // cond is true within interval partitions[{cond.get(), true}] = interval; } PrimExpr inverse_cond = InverseCond(cond); if (inverse_cond.defined()) { - IntSet interval = - DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); + IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); if (!interval.is_nothing()) { // cond is false within interval partitions[{cond.get(), false}] = interval; @@ -261,7 +259,7 @@ class PartitionFinder : public StmtExprVisitor { class ConditionEliminator : public StmtExprMutator { public: explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) - : ps_(ps), cond_value_(cond_value) {} + : ps_(ps), cond_value_(cond_value) {} PrimExpr VisitExpr(const PrimExpr& e) final { if (ps_.find(e.get()) != ps_.end()) { @@ -275,12 +273,11 @@ class ConditionEliminator : public StmtExprMutator { bool cond_value_; }; - // Insert the partition branch at the innermost thread scope class ThreadPartitionInserter : public StmtMutator { public: - explicit ThreadPartitionInserter(const std::unordered_set& ps, - PrimExpr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} + explicit ThreadPartitionInserter(const std::unordered_set& ps, PrimExpr cond) + : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -310,8 +307,7 @@ class ThreadPartitionInserter : public StmtMutator { // likely conditions class LoopPartitioner : public StmtMutator { public: - explicit LoopPartitioner(bool split_const_loop) - : selector(CandidateSelector(split_const_loop)) {} + explicit LoopPartitioner(bool split_const_loop) : selector(CandidateSelector(split_const_loop)) {} Stmt VisitAndMutate(Stmt stmt) { selector(stmt); @@ -320,15 +316,14 @@ class LoopPartitioner : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { if (selector.candidates.count(op)) { - Stmt s = TryPartition(op, GetRef(op), op->loop_var, - op->min, op->min + op->extent - 1, op->body, false); + Stmt s = TryPartition(op, GetRef(op), op->loop_var, op->min, op->min + op->extent - 1, + op->body, false); if (s.defined()) return s; } // normal path when loop partition fails // normal loop variable can be put into hint map. - hint_map_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); + hint_map_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = StmtMutator::VisitStmt_(op); hint_map_.erase(op->loop_var.get()); return res; @@ -339,7 +334,7 @@ class LoopPartitioner : public StmtMutator { return StmtMutator::VisitStmt_(op); } - const IterVarNode *iv = op->node.as(); + const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; if (selector.candidates.count(op)) { @@ -352,13 +347,11 @@ class LoopPartitioner : public StmtMutator { Stmt res; if (scope.rank == 1) { // threadIdx should be put into relax map, in case of divergence. - relax_map_.insert({var.get(), - IntSet::interval(make_zero(var.dtype()), op->value - 1)}); + relax_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)}); res = StmtMutator::VisitStmt_(op); relax_map_.erase(var.get()); } else { - hint_map_.insert({var.get(), - IntSet::interval(make_zero(var.dtype()), op->value - 1)}); + hint_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)}); res = StmtMutator::VisitStmt_(op); hint_map_.erase(var.get()); } @@ -366,13 +359,11 @@ class LoopPartitioner : public StmtMutator { } private: - Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, - PrimExpr min, PrimExpr max, Stmt body, bool partition_thread_scope); + Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, + Stmt body, bool partition_thread_scope); - std::pair> - GetIntervalAndCondset(const Partition &partitions, - const arith::IntervalSet &for_interval, - bool cond_value); + std::pair> GetIntervalAndCondset( + const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value); inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body); @@ -385,18 +376,15 @@ class LoopPartitioner : public StmtMutator { // Returns an interval (in the first component) in which all the conditions // given in the second component provably have value given by cond_value -std::pair> -LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, - const arith::IntervalSet &for_interval, - bool cond_value) { +std::pair> LoopPartitioner::GetIntervalAndCondset( + const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) { Array sets; std::unordered_set cond_set; - for (const auto &kv : partitions) { + for (const auto& kv : partitions) { if (kv.first.second == cond_value) { arith::IntervalSet interval = Downcast(kv.second); - arith::IntervalSet intersection = arith::Intersect( - &analyzer_, interval, for_interval); + arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval); if (!intersection->IsEmpty()) { sets.push_back(kv.second); cond_set.insert(kv.first.first); @@ -453,13 +441,8 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, * which will eventually be simplified to empty code. And because only one loop was generated * from loop 2 we stop recursing. */ -Stmt LoopPartitioner::TryPartition(const Object* node, - const Stmt& stmt, - Var var, - PrimExpr min, - PrimExpr max, - Stmt body, - bool partition_thread_scope) { +Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min, + PrimExpr max, Stmt body, bool partition_thread_scope) { using namespace arith; // include hint of var. hint_map_.insert({var.get(), IntSet::interval(min, max)}); @@ -476,7 +459,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, std::unordered_set cond_set; // find an interval in which all conditions on var are true std::tie(middle_interval, cond_set) = - GetIntervalAndCondset(finder.partitions, for_interval, true); + GetIntervalAndCondset(finder.partitions, for_interval, true); if (middle_interval.is_nothing()) { // if such interval doesn't exist, find an interval in which all // conditions on var are false @@ -507,8 +490,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, if (!analyzer_.CanProve(body_begin == min)) { PrimExpr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the pre doubt loop"; + LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; body_begin = MaxNode::make(body_begin, min); // stop recursing on this interval if we can't prove it has non-negative length pre_stmt_recurse = false; @@ -533,15 +515,13 @@ Stmt LoopPartitioner::TryPartition(const Object* node, // require the extent to be non-negative PrimExpr cond = (max - post_doubt_begin + 1 >= 0); if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the post doubt loop"; - post_doubt_begin = MinNode::make(post_doubt_begin, max+1); + LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; + post_doubt_begin = MinNode::make(post_doubt_begin, max + 1); // stop recursing on this interval if we can't prove it has non-negative length post_stmt_recurse = false; } if (!partition_thread_scope) { - Stmt post_body = - Substitute(body, {{Var{var}, var + post_doubt_begin}}); + Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); } } @@ -583,8 +563,8 @@ Stmt LoopPartitioner::TryPartition(const Object* node, return s; } -inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt body) { - const ForNode *for_node = static_cast(node); +inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) { + const ForNode* for_node = static_cast(node); CHECK(for_node); if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore @@ -597,7 +577,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b class RemoveLikelyTags : public StmtExprMutator { public: - PrimExpr VisitExpr_(const CallNode *op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(CallNode::likely)) { CHECK_EQ(op->args.size(), 1); return StmtExprMutator::VisitExpr(op->args[0]); @@ -613,7 +593,6 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop) { return stmt; } - namespace transform { Pass LoopPartition(bool split_const_loop) { @@ -625,8 +604,7 @@ Pass LoopPartition(bool split_const_loop) { return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LoopPartition") -.set_body_typed(LoopPartition); +TVM_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition); } // namespace transform diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index ce81528..92b463c 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -21,10 +21,11 @@ * \brief Pass for lowering custom datatypes */ +#include +#include #include #include -#include -#include + #include "../../target/datatype/registry.h" namespace tvm { @@ -79,9 +80,8 @@ class CustomDatatypesLowerer : public StmtExprMutator { if (toBeLowered) { auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes()); - return AllocateNode::make( - allocate->buffer_var, new_allocate_type, allocate->extents, - allocate->condition, allocate->body); + return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents, + allocate->condition, allocate->body); } return stmt; } @@ -97,19 +97,19 @@ class CustomDatatypesLowerer : public StmtExprMutator { return expr; } -#define DEFINE_MUTATE__(OP, NodeName) \ - inline PrimExpr VisitExpr_(const NodeName* op) final { \ - auto type_code = op->dtype.code(); \ +#define DEFINE_MUTATE__(OP, NodeName) \ + inline PrimExpr VisitExpr_(const NodeName* op) final { \ + auto type_code = op->dtype.code(); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ - PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ - op = expr.as(); \ - if (toBeLowered) { \ - auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ - CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ - << static_cast(type_code) << " not found"; \ - return (*lower)(expr); \ - } \ - return expr; \ + PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ + op = expr.as(); \ + if (toBeLowered) { \ + auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ + CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ + << static_cast(type_code) << " not found"; \ + return (*lower)(expr); \ + } \ + return expr; \ } DEFINE_MUTATE__(Add, AddNode); @@ -131,15 +131,13 @@ class CustomDatatypesLowerer : public StmtExprMutator { std::string target_; }; - namespace transform { Pass LowerCustomDatatypes() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerCustomDatatypes: Require the target attribute"; + CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body)); return f; @@ -147,8 +145,7 @@ Pass LowerCustomDatatypes() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes") -.set_body_typed(LowerCustomDatatypes); +TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes); } // namespace transform diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index dac426d..a842462 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -21,20 +21,21 @@ * \file lower_device_storage_access.cc * \brief Lower the special device storage access. */ -#include -#include -#include #include -#include #include -#include "ir_util.h" +#include +#include +#include +#include + #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { -using runtime::StorageScope; using runtime::StorageRank; +using runtime::StorageScope; class StorageAccessInfoLower : public StmtExprMutator { public: @@ -51,8 +52,7 @@ class StorageAccessInfoLower : public StmtExprMutator { << "Double allocation of " << it->second.scope.to_string(); if (info->head_address.defined()) { - return LetStmtNode::make( - op->buffer_var, info->head_address, op->body); + return LetStmtNode::make(op->buffer_var, info->head_address, op->body); } else { return op->body; } @@ -99,30 +99,23 @@ class StorageAccessInfoLower : public StmtExprMutator { PrimExpr offset = op->args[2]; auto it = storage_info_.find(buffer); if (it != storage_info_.end() && it->second.info.defined()) { - return MakeTaggedAccessPtr( - op->dtype, buffer_var, dtype, offset, - it->second.info); + return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second.info); } CHECK(op->dtype.is_handle()); // Change to address_of return AddressOffset(buffer_var, dtype, offset); } - PrimExpr MakeTaggedAccessPtr(DataType ptr_type, - Var buffer_var, - DataType dtype, - PrimExpr offset, + PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var, DataType dtype, PrimExpr offset, const MemoryInfo& info) { if (ptr_type.is_handle()) { - CHECK(info->head_address.defined()) - << buffer_var << " is not adddressable."; + CHECK(info->head_address.defined()) << buffer_var << " is not adddressable."; return AddressOffset(buffer_var, dtype, offset); } int dtype_bits = dtype.bits() * dtype.lanes(); CHECK_EQ(info->unit_bits % dtype_bits, 0); - return cast(ptr_type, - analyzer_.Simplify(offset / make_const( - offset.dtype(), info->unit_bits / dtype_bits))); + return cast(ptr_type, analyzer_.Simplify( + offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); } // The storage entry. struct StorageEntry { @@ -139,9 +132,7 @@ class StorageAccessInfoLower : public StmtExprMutator { arith::Analyzer analyzer_; }; -Stmt LowerStorageAccessInfo(Stmt stmt) { - return StorageAccessInfoLower()(std::move(stmt)); -} +Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower()(std::move(stmt)); } namespace transform { @@ -151,12 +142,11 @@ Pass LowerDeviceStorageAccessInfo() { n->body = StorageAccessInfoLower()(std::move(n->body)); return f; }; - return CreatePrimFuncPass( - pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); } TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo") -.set_body_typed(LowerDeviceStorageAccessInfo); + .set_body_typed(LowerDeviceStorageAccessInfo); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index a909d4c..7df8fd2 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -21,23 +21,24 @@ * Lower intrinsic calls and ops to device specific ir when possible. * \file lower_intrin.cc */ +#include +#include #include +#include #include -#include -#include -#include #include -#include "../../arith/pattern_match.h" + #include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/pattern_match.h" namespace tvm { namespace tir { class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { public: - using IRMutatorWithAnalyzer::VisitStmt_; using IRMutatorWithAnalyzer::VisitExpr_; + using IRMutatorWithAnalyzer::VisitStmt_; IntrinInjecter(arith::Analyzer* analyzer, std::string target_name) : IRMutatorWithAnalyzer(analyzer) { @@ -50,8 +51,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { PrimExpr r = ApplyPattern(op->name, GetRef(op)); if (r.defined()) return r; } @@ -78,16 +78,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { const DataType& dtype = op->dtype; CHECK(dtype.is_int() || dtype.is_uint()); - if (support_bitwise_op_ && - is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { // lower to right shift if possible. return op->a >> make_const(dtype, shift); } if (analyzer_->CanProveGreaterEqual(op->b, 0)) { // Common path, positive divisor - if (analyzer_->CanProveGreaterEqual(op->a, 0) || - analyzer_->CanProveGreaterEqual(e, 0)) { + if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) { return truncdiv(op->a, op->b); } else { DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; @@ -100,7 +98,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { - return tir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); + return tir::SelectNode::make(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); } } } else { @@ -110,9 +108,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) PrimExpr rdiv = truncdiv(op->a, op->b); PrimExpr rmod = truncmod(op->a, op->b); - return tir::SelectNode::make( - (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), - rdiv, rdiv - make_const(dtype, 1)); + return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, + rdiv - make_const(dtype, 1)); } } @@ -125,11 +122,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { const DataType& dtype = op->dtype; CHECK(dtype.is_int() || dtype.is_uint()); - if (support_bitwise_op_ && - is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { // lower to masking if possible. - int64_t mask = ( - static_cast(1) << static_cast(shift)) - 1; + int64_t mask = (static_cast(1) << static_cast(shift)) - 1; return op->a & make_const(dtype, mask); } @@ -160,9 +155,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b > 0 && rmod < 0 -> rmod + b // b < 0 && rmod < 0 -> rmod // b < 0 && rmod > 0 -> rmod + b - return tir::SelectNode::make( - (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), - rmod, rmod + op->b); + return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, + rmod + op->b); } } @@ -171,8 +165,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PVar x, y; PVar c; auto e = GetRef(op); - if (max(floordiv(x, y), c).Match(e) && - c.Eval()->value >= 0 && + if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval()); } @@ -232,15 +225,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return e; } - PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, - const AddNode* op) { + PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) { // emit fma instruction: a * b + c PrimExpr lhs = SwapBroadcastCast(a); PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = (*fma_)(CallNode::make( - op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); + PrimExpr r = + (*fma_)(CallNode::make(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { @@ -288,18 +280,15 @@ Pass LowerIntrin() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerIntrin: Require the target attribute"; + CHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - n->body = - IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); + n->body = IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin") -.set_body_typed(LowerIntrin); +TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin); } // namespace transform diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 11e420b..127b012 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -21,18 +21,18 @@ * Lower allreduce to device implementable ir. * \file lower_thread_allreduce.cc */ +#include +#include +#include #include #include #include -#include -#include -#include #include -#include "ir_util.h" #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -40,9 +40,9 @@ namespace tir { class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) - : target_(target), warp_size_(target->thread_warp_size) {} + : target_(target), warp_size_(target->thread_warp_size) {} - Stmt VisitStmt_(const AttrStmtNode *op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { thread_extents_.push_back(op); Stmt ret = StmtExprMutator::VisitStmt_(op); @@ -58,7 +58,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return ret; } } else if (op->attr_key == attr::reduce_scope) { - const CommReducerNode *combiner = op->node.as(); + const CommReducerNode* combiner = op->node.as(); CHECK(combiner); reduce_combiner_.push_back(combiner); Stmt ret = StmtExprMutator::VisitStmt_(op); @@ -85,20 +85,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { - stmt = AllocateNode::make(repl->buffer_var, repl->dtype, - repl->extents, repl->condition, op->body); + stmt = AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, + op->body); stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, - StringImmNode::make("local"), stmt); + StringImmNode::make("local"), stmt); } else { // use volatile access to shared buffer. - stmt = AttrStmtNode::make( - repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = AllocateNode::make( - repl->buffer_var, repl->dtype, - repl->extents, repl->condition, stmt); - stmt = AttrStmtNode::make( - repl->buffer_var, attr::storage_scope, - StringImmNode::make("shared"), stmt); + stmt = AttrStmtNode::make(repl->buffer_var, attr::volatile_scope, 1, op->body); + stmt = + AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); + stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, + StringImmNode::make("shared"), stmt); } return stmt; } else { @@ -130,18 +127,18 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // make allreduce. Stmt MakeAllreduce(const CallNode* call) { CHECK(!reduce_combiner_.empty()); - const CommReducerNode *combiner = reduce_combiner_.back(); + const CommReducerNode* combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); - const IntImmNode *size_of_args = call->args[0].as(); + const IntImmNode* size_of_args = call->args[0].as(); CHECK(size_of_args) << call->args[0]->GetTypeKey(); CHECK_EQ(size, size_of_args->value); Array inits = combiner->identity_element; std::vector values(size); std::vector types(size); - PrimExpr cond = call->args[size+1]; + PrimExpr cond = call->args[size + 1]; for (size_t idx = 0; idx < size; ++idx) { - values[idx] = call->args[1+idx]; + values[idx] = call->args[1 + idx]; if (!is_one(cond)) { values[idx] = SelectNode::make(cond, values[idx], inits[idx]); } @@ -149,7 +146,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } std::vector buffers(size); for (size_t idx = 0; idx < size; ++idx) { - const VarNode* buffer = call->args[2+size+idx].as(); + const VarNode* buffer = call->args[2 + size + idx].as(); CHECK(buffer); buffers[idx] = buffer; } @@ -168,12 +165,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { e.scope = runtime::ThreadScope::make(iv->thread_tag); e.iv = iv; CHECK_LE(e.scope.rank, 1); - CHECK_GE(e.scope.dim_index, 0) - << "vthread do not work with cross thread reduction"; + CHECK_GE(e.scope.dim_index, 0) << "vthread do not work with cross thread reduction"; if (e.scope.rank == 1) { const auto* ptr = attr->value.as(); - CHECK(ptr) - << "Need constant extent for reduce set " << iv; + CHECK(ptr) << "Need constant extent for reduce set " << iv; e.extent = static_cast(ptr->value); if (reduce_set.count(iv->var.get())) { vred.push_back(e); @@ -183,8 +178,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } } - CHECK_EQ(nmatch, reduce_set.size()) - << "Not all reduce index are presented in the context"; + CHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce index are presented in the context"; std::sort(vred.begin(), vred.end()); std::sort(vpar.begin(), vpar.end()); // the size of each index. @@ -221,15 +215,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr index(0); for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); + shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle()); PrimExpr pred = const_true(types[idx].lanes()); seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], index, pred)); // Uses a local variable to store the shuffled data. // Later on, this allocation will be properly attached to this statement. Var var("t" + std::to_string(idx), types[idx]); - Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred, - EvaluateNode::make(0)); + Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred, EvaluateNode::make(0)); local_vars.push_back(s); } @@ -240,13 +233,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Var mask_var("mask", DataType::UInt(32)); { PrimExpr pred = const_true(1); - PrimExpr mask = CallNode::make(DataType::UInt(32), - intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); + PrimExpr mask = CallNode::make(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, + CallNode::Intrinsic); seq.emplace_back(StoreNode::make(mask_var, mask, index, pred)); // Push allocation with an empty body. Later this will be fixed // when the entire body is ready. - auto stmt = AllocateNode::make(mask_var, mask_var->dtype, - {PrimExpr(1)}, pred, EvaluateNode::make(0)); + auto stmt = AllocateNode::make(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, + EvaluateNode::make(0)); local_vars.push_back(stmt); } @@ -316,8 +309,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Var var = shared_bufs[i]; load_remap_[buffers[i]] = LoadNode::make(types[i], var, index, pred); Array extents{PrimExpr(1)}; - auto node = AllocateNode::make(var, types[i], extents, pred, - EvaluateNode::make(0)); + auto node = AllocateNode::make(var, types[i], extents, pred, EvaluateNode::make(0)); alloc_remap_[buffers[i]] = node; warp_allocs_.insert(node.get()); } @@ -328,7 +320,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::vector stores(size); for (size_t i = 0; i < size; ++i) { PrimExpr pred = const_true(types[i].lanes()); - Var buffer_var = Downcast(call->args[2+size+i]); + Var buffer_var = Downcast(call->args[2 + size + i]); stores[i] = StoreNode::make(buffer_var, values[i], 0, pred); } return SeqStmt::Flatten(stores); @@ -341,26 +333,23 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // previous iteration on the same buffer. seq.emplace_back(SyncThread("shared")); for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); + shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle()); PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(StoreNode::make( - shared_bufs[idx], values[idx], - BufIndex(reduce_index, group_index, reduce_extent), pred)); + seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], + BufIndex(reduce_index, group_index, reduce_extent), pred)); } seq.emplace_back(SyncThread("shared")); - seq.emplace_back(MakeBufAllreduce( - combiner, types, shared_bufs, - reduce_index, group_index, reduce_extent, threadx_extent)); + seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index, + reduce_extent, threadx_extent)); for (size_t idx = 0; idx < size; ++idx) { CHECK(!load_remap_.count(buffers[idx])); PrimExpr pred = const_true(types[idx].lanes()); load_remap_[buffers[idx]] = LoadNode::make( - types[idx], shared_bufs[idx], - BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); + types[idx], shared_bufs[idx], + BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); alloc_remap_[buffers[idx]] = AllocateNode::make( - shared_bufs[idx], types[idx], - {PrimExpr(group_extent), PrimExpr(reduce_extent)}, - pred, EvaluateNode::make(0)); + shared_bufs[idx], types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, + EvaluateNode::make(0)); } } @@ -369,10 +358,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (auto var : local_vars) { const AllocateNode* repl = var.as(); if (repl) { - body = AllocateNode::make(repl->buffer_var, repl->dtype, - repl->extents, repl->condition, body); + body = + AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, - StringImmNode::make("local"), body); + StringImmNode::make("local"), body); } } @@ -380,13 +369,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // make allreduce. - Stmt MakeBufAllreduce(const CommReducerNode *combiner, - const std::vector& types, - const Array& shared_bufs, - PrimExpr reduce_index, - PrimExpr group_index, - int reduce_extent, - int threadx_extent) { + Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector& types, + const Array& shared_bufs, PrimExpr reduce_index, PrimExpr group_index, + int reduce_extent, int threadx_extent) { // Get next power of two int reduce_align = 1; while (reduce_extent > reduce_align) { @@ -402,8 +387,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Array a, b; for (size_t i = 0; i < size; ++i) { b.push_back(LoadNode::make(types[i], shared_bufs[i], - BufIndex(reduce_index + offset, group_index, reduce_extent), - const_true())); + BufIndex(reduce_index + offset, group_index, reduce_extent), + const_true())); a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true())); } Array ret = (*combiner)(a, b); @@ -423,9 +408,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } CHECK(threadx_extent >= 1 && warp_size_ >= 1); // normal synchronization - while (reduce_align > threadx_extent || - reduce_align > warp_size_) { - reduce_align = reduce_align >> 1; + while (reduce_align > threadx_extent || reduce_align > warp_size_) { + reduce_align = reduce_align >> 1; PrimExpr cond = reduce_index < reduce_align; seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); @@ -447,8 +431,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Flatten the thread index. // Also return a warp number, - PrimExpr FlattenThread(const std::vector& tvec, - int* out_total_extent) { + PrimExpr FlattenThread(const std::vector& tvec, int* out_total_extent) { int& total_extent = *out_total_extent; total_extent = 1; if (tvec.size() == 0) { @@ -477,21 +460,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // sync thread op. static Stmt SyncThread(const std::string& sync) { - return EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync)}, - CallNode::Intrinsic)); + return EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImmNode::make(sync)}, CallNode::Intrinsic)); } // Emit warp shuffle intrinsic calls. - PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, - int delta_or_lane) { + PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, int delta_or_lane) { PrimExpr pred = const_true(1); PrimExpr index(0); PrimExpr mask = LoadNode::make(DataType::UInt(32), mask_var, index, pred); PrimExpr width = IntImm(DataType::Int(32), warp_size_); - Array args{mask, val, IntImm(DataType::Int(32), delta_or_lane), - width, width}; + Array args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width}; return CallNode::make(val.dtype(), name, args, CallNode::Intrinsic); } @@ -527,9 +506,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { e.extent = static_cast(ptr->value); } - return e.extent == warp_size_ && - e.scope.dim_index == 0 && - e.scope.rank == 1; + return e.extent == warp_size_ && e.scope.dim_index == 0 && e.scope.rank == 1; } // The target. @@ -542,11 +519,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::vector thread_extents_; std::vector reduce_combiner_; // The load remap - std::unordered_map load_remap_; + std::unordered_map load_remap_; // Allocate remap - std::unordered_map alloc_remap_; + std::unordered_map alloc_remap_; // Allocate from warp reductions - std::unordered_set warp_allocs_; + std::unordered_set warp_allocs_; // Internal analyzer arith::Analyzer analyzer_; }; @@ -557,8 +534,7 @@ Pass LowerThreadAllreduce() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerThreadAllreduce: Require the target attribute"; + CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; const TargetNode* target_node = target.as(); n->body = ThreadAllreduceBuilder(target_node)(n->body); return f; @@ -566,8 +542,7 @@ Pass LowerThreadAllreduce() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce") -.set_body_typed(LowerThreadAllreduce); +TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index ee6c44d..88c4363 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -21,10 +21,10 @@ * Lower TVM related builtin intrinsics such as packed call. * \file tir/transforms/lower_tvm_buildin.cc */ +#include #include #include #include -#include #include @@ -40,10 +40,7 @@ inline PrimExpr ConstInt32(size_t index) { inline PrimExpr StackAlloca(std::string type, size_t num) { Array args = {StringImmNode::make(type), ConstInt32(num)}; - return CallNode::make( - DataType::Handle(), - intrinsic::tvm_stack_alloca, - args, CallNode::Intrinsic); + return CallNode::make(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic); } // Calculate the statistics of packed function. @@ -57,18 +54,14 @@ class BuiltinLower : public StmtExprMutator { stack_tcode_ = Var("stack_tcode", DataType::Handle()); stmt = this->VisitStmt(stmt); if (max_shape_stack_ != 0) { - stmt = LetStmtNode::make( - stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); + stmt = LetStmtNode::make(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); } if (max_array_stack_ != 0) { - stmt = LetStmtNode::make( - stack_array_, StackAlloca("array", max_array_stack_), stmt); + stmt = LetStmtNode::make(stack_array_, StackAlloca("array", max_array_stack_), stmt); } if (max_arg_stack_ != 0) { - stmt = LetStmtNode::make( - stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt); - stmt = LetStmtNode::make( - stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt); + stmt = LetStmtNode::make(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt); + stmt = LetStmtNode::make(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt); } return stmt; } @@ -109,44 +102,34 @@ class BuiltinLower : public StmtExprMutator { } CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - Stmt throw_last_error = EvaluateNode::make( - CallNode::make(DataType::Int(32), - intrinsic::tvm_throw_last_error, {}, - CallNode::Intrinsic)); + Stmt throw_last_error = EvaluateNode::make(CallNode::make( + DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); - Stmt body = SeqStmt({ - IfThenElseNode::make( - CallNode::make(DataType::Bool(1), - intrinsic::tvm_handle_is_null, - {op->buffer_var}, CallNode::PureIntrinsic), - throw_last_error), - op->body}); + Stmt body = SeqStmt( + {IfThenElseNode::make(CallNode::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, + {op->buffer_var}, CallNode::PureIntrinsic), + throw_last_error), + op->body}); Stmt alloca = LetStmtNode::make( op->buffer_var, - CallNode::make(op->buffer_var.dtype(), - "TVMBackendAllocWorkspace", - {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), - cast(DataType::UInt(64), total_bytes), - IntImm(DataType::Int(32), op->dtype.code()), - IntImm(DataType::Int(32), op->dtype.bits())}, - CallNode::Extern), + CallNode::make( + op->buffer_var.dtype(), "TVMBackendAllocWorkspace", + {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), + cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}, + CallNode::Extern), body); - PrimExpr free_op = CallNode::make(DataType::Int(32), - "TVMBackendFreeWorkspace", - {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), - op->buffer_var}, - CallNode::Extern); - Stmt free_stmt = IfThenElseNode::make( - free_op != make_zero(DataType::Int(32)), throw_last_error); + PrimExpr free_op = CallNode::make(DataType::Int(32), "TVMBackendFreeWorkspace", + {cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), op->buffer_var}, + CallNode::Extern); + Stmt free_stmt = + IfThenElseNode::make(free_op != make_zero(DataType::Int(32)), throw_last_error); body = SeqStmt({alloca, free_stmt}); - body = AttrStmtNode::make( - op->buffer_var, attr::storage_alignment, - make_const(DataType::Int(32), runtime::kTempAllocaAlignment), - body); + body = AttrStmtNode::make(op->buffer_var, attr::storage_alignment, + make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); return body; } @@ -185,9 +168,8 @@ class BuiltinLower : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); for (size_t i = 0; i < op->args.size(); ++i) { - prep_seq_.emplace_back( - StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]), - ConstInt32(stack_begin +i), const_true(1))); + prep_seq_.emplace_back(StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]), + ConstInt32(stack_begin + i), const_true(1))); } return AddressOffset(stack_shape_, DataType::Int(64), stack_begin); } @@ -197,45 +179,36 @@ class BuiltinLower : public StmtExprMutator { run_array_stack_ += 1; PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); PrimExpr strides = op->args[2]; if (!strides.defined() || is_zero(strides)) { strides = make_zero(DataType::Handle()); } - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); DataType dtype = op->args[4].dtype(); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode, make_const(DataType::UInt(8), static_cast(dtype.code())))); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, - make_const(DataType::UInt(8), dtype.bits()))); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, - make_const(DataType::UInt(16), dtype.lanes()))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, + make_const(DataType::UInt(8), dtype.bits()))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, + make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); PrimExpr byte_offset = op->args[5]; if (!is_zero(byte_offset)) { byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); } - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, - cast(DataType::UInt(64), byte_offset))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, + cast(DataType::UInt(64), byte_offset))); CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, - cast(DataType::Int(32), device_id_))); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, - cast(DataType::Int(32), device_type_))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, + cast(DataType::Int(32), device_id_))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, + cast(DataType::Int(32), device_type_))); return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); } // call packed. @@ -255,18 +228,15 @@ class BuiltinLower : public StmtExprMutator { if (t != api_type) { arg = CastNode::make(api_type, arg); } - prep_seq_.emplace_back(TVMStructSet( - stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), + intrinsic::kTVMValueContent, arg)); int arg_tcode = api_type.code(); if (api_type.is_handle() && arg.as()) { arg_tcode = kTVMStr; } if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; prep_seq_.emplace_back( - StoreNode::make(stack_tcode_, - ConstInt32(arg_tcode), - stack_index, const_true(1))); + StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_); @@ -275,19 +245,14 @@ class BuiltinLower : public StmtExprMutator { run_shape_stack_ = restore_shape_stack; run_array_stack_ = restore_array_stack; run_arg_stack_ = arg_stack_begin; - Array packed_args = { - op->args[0], - stack_value_, - stack_tcode_, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1) - }; - return CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed_lowered, - packed_args, CallNode::Intrinsic); + Array packed_args = {op->args[0], stack_value_, stack_tcode_, + ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + op->args.size() - 1)}; + return CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args, + CallNode::Intrinsic); } - PrimExpr MakeCallTracePacked(const CallNode *op) { + PrimExpr MakeCallTracePacked(const CallNode* op) { size_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; size_t arg_stack_begin = run_arg_stack_; @@ -304,15 +269,12 @@ class BuiltinLower : public StmtExprMutator { if (t != api_type) { arg = CastNode::make(api_type, arg); } - prep_seq_.emplace_back(TVMStructSet( - stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), + intrinsic::kTVMValueContent, arg)); int arg_tcode = api_type.code(); CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; prep_seq_.emplace_back( - StoreNode::make(stack_tcode_, - ConstInt32(arg_tcode), - stack_index, const_true(1))); + StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_); @@ -323,18 +285,13 @@ class BuiltinLower : public StmtExprMutator { // Update the top of the stack, so we can use more than one // packed function's arguments with the one stack. run_arg_stack_ = arg_stack_begin + args_size - 1; - Array packed_args = { - op->args[0], - stack_value_, - stack_tcode_, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1), - // Pass traced value. - op->args[args_size - 1] - }; - return CallNode::make( - op->dtype, intrinsic::tvm_call_trace_packed_lowered, - packed_args, CallNode::Intrinsic); + Array packed_args = {op->args[0], stack_value_, stack_tcode_, + ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + op->args.size() - 1), + // Pass traced value. + op->args[args_size - 1]}; + return CallNode::make(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args, + CallNode::Intrinsic); } private: @@ -379,8 +336,7 @@ Pass LowerTVMBuiltin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin") -.set_body_typed(LowerTVMBuiltin); +TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 0abbe76..4c8dec0 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -25,20 +25,19 @@ */ // Thanks to Andrew Adams and Vinod Grover for // explaining the concept of warp shuffle. -#include #include - +#include +#include +#include +#include #include #include -#include #include -#include -#include #include -#include "../../arith/pattern_match.h" #include "../../arith/compute_expr.h" +#include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { @@ -101,13 +100,8 @@ namespace tir { // store warp_mem[m * warp_index + (width * m) * y + x] class WarpStoreCoeffFinder : private StmtVisitor { public: - WarpStoreCoeffFinder(const VarNode* buffer, - Var warp_index, - arith::Analyzer* analyzer) - : buffer_(buffer), - warp_index_(warp_index), - analyzer_(analyzer) { - } + WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer) + : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {} // find the warp co-efficient in the statement given the warp size int Find(const Stmt& stmt) { this->VisitStmt(stmt); @@ -116,7 +110,7 @@ class WarpStoreCoeffFinder : private StmtVisitor { private: /// Visitor implementation - void VisitStmt_(const StoreNode *op) final { + void VisitStmt_(const StoreNode* op) final { if (op->buffer_var.get() == buffer_) { if (op->value.dtype().lanes() == 1) { UpdatePattern(op->index); @@ -133,16 +127,14 @@ class WarpStoreCoeffFinder : private StmtVisitor { } void UpdatePattern(const PrimExpr& index) { - Array m = - arith::DetectLinearEquation(index, {warp_index_}); - CHECK_EQ(m.size(), 2U) - << "LowerWarpMemory failed due to store index=" << index; + Array m = arith::DetectLinearEquation(index, {warp_index_}); + CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index; PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); const auto* mcoeff_as_int = mcoeff.as(); CHECK(mcoeff_as_int && mcoeff_as_int->value > 0) << "LowerWarpMemory failed due to store index=" << index - << ", require positive constant coefficient on warp index " << warp_index_ - << " but get " << mcoeff; + << ", require positive constant coefficient on warp index " << warp_index_ << " but get " + << mcoeff; if (warp_coeff_ != 0) { CHECK_EQ(warp_coeff_, mcoeff_as_int->value) @@ -162,13 +154,10 @@ class WarpStoreCoeffFinder : private StmtVisitor { arith::Analyzer* analyzer_; }; - // Visitor to find the warp index class WarpIndexFinder : private StmtVisitor { public: - explicit WarpIndexFinder(int warp_size) - : warp_size_(warp_size) { - } + explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {} // find the warp co-efficient and the shuffle width in the statement std::pair Find(const Stmt& stmt) { this->VisitStmt(stmt); @@ -179,21 +168,20 @@ class WarpIndexFinder : private StmtVisitor { private: /// Visitor implementation - void VisitStmt_(const AttrStmtNode *op) final { + void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag == "threadIdx.x") { auto* value_as_int = op->value.as(); - CHECK(value_as_int && - value_as_int->value <= warp_size_ && + CHECK(value_as_int && value_as_int->value <= warp_size_ && warp_size_ % value_as_int->value == 0) << "Expect threadIdx.x 's size to be no larger than, and a factor of" - << " warp size(" << warp_size_ << ")" << " to enable warp memory" + << " warp size(" << warp_size_ << ")" + << " to enable warp memory" << " but get " << op->value << " instead"; if (warp_index_.defined()) { CHECK(warp_index_.same_as(iv)) - << "Find two instance of " << warp_index_->thread_tag - << " in the same kernel. " + << "Find two instance of " << warp_index_->thread_tag << " in the same kernel. " << "Please create it using thread_axis once and reuse the axis " << "across multiple binds in the same kernel"; } else { @@ -221,27 +209,21 @@ class WarpAccessRewriter : protected StmtExprMutator { Stmt Rewrite(const AllocateNode* op) { buffer_ = op->buffer_var.get(); int alloc_size = op->constant_allocation_size(); - CHECK_GT(alloc_size, 0) - << "warp memory only support constant alloc size"; + CHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->dtype.lanes(); std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); - warp_coeff_ = WarpStoreCoeffFinder( - buffer_, warp_index_, analyzer_).Find(op->body); + warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0) << "Warp memory must be multiple of the extent of threadIdx.x"; warp_group_ = alloc_size / (width_ * warp_coeff_); - return AllocateNode::make( - op->buffer_var, - op->dtype, - {make_const(DataType::Int(32), alloc_size / width_)}, - op->condition, - this->VisitStmt(op->body)); + return AllocateNode::make(op->buffer_var, op->dtype, + {make_const(DataType::Int(32), alloc_size / width_)}, op->condition, + this->VisitStmt(op->body)); } protected: PrimExpr VisitExpr_(const VarNode* op) override { - CHECK(op != buffer_) - << "Cannot access address of warp memory directly"; + CHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); } @@ -261,16 +243,13 @@ class WarpAccessRewriter : protected StmtExprMutator { std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id CHECK(!ExprUseVar(local_index, warp_index_)) - << "LowerWarpMemory failed to rewrite load to shuffle for index " - << op->index << " local_index=" << local_index; - PrimExpr load_value = LoadNode::make( - op->dtype, op->buffer_var, local_index, op->predicate); - PrimExpr mask = CallNode::make(DataType::UInt(32), - intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); - return CallNode::make(load_value.dtype(), - intrinsic::tvm_warp_shuffle, - {mask, load_value, group, width_, warp_size_}, - CallNode::Intrinsic); + << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index + << " local_index=" << local_index; + PrimExpr load_value = LoadNode::make(op->dtype, op->buffer_var, local_index, op->predicate); + PrimExpr mask = CallNode::make(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, + CallNode::Intrinsic); + return CallNode::make(load_value.dtype(), intrinsic::tvm_warp_shuffle, + {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic); } else { return StmtExprMutator::VisitExpr_(op); } @@ -303,10 +282,8 @@ class WarpAccessRewriter : protected StmtExprMutator { PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m)); PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_); y = y * m + x; - PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)), - m); - return std::make_pair(analyzer_->canonical_simplify(y), - analyzer_->canonical_simplify(z)); + PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)), m); + return std::make_pair(analyzer_->canonical_simplify(y), analyzer_->canonical_simplify(z)); } } @@ -327,14 +304,12 @@ class WarpAccessRewriter : protected StmtExprMutator { arith::Analyzer* analyzer_; }; - // Bind bound information of variables to make analyzer more effective // TODO(tqchen): consider a pass to inline the bound info into the expr // so analysis can be context independent. class BindVarBoundInfo : public StmtVisitor { public: - explicit BindVarBoundInfo(arith::Analyzer* analyzer) - : analyzer_(analyzer) {} + explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {} void VisitStmt_(const ForNode* op) final { const Var& loop_var = op->loop_var; @@ -343,8 +318,7 @@ class BindVarBoundInfo : public StmtVisitor { } void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); if (!var_dom_.count(iv->var.get())) { @@ -366,9 +340,7 @@ class BindVarBoundInfo : public StmtVisitor { // Mutator to change the read pattern class WarpMemoryRewriter : private StmtMutator { public: - explicit WarpMemoryRewriter(int warp_size) - : warp_size_(warp_size) { - } + explicit WarpMemoryRewriter(int warp_size) : warp_size_(warp_size) {} Stmt Rewrite(Stmt stmt) { if (warp_size_ == 1) return stmt; @@ -398,8 +370,7 @@ class WarpMemoryRewriter : private StmtMutator { warp_buffer_.insert(buf); Stmt ret = StmtMutator::VisitStmt_(op); op = ret.as(); - return AttrStmtNode::make( - op->node, op->attr_key, StringImmNode::make("local"), op->body); + return AttrStmtNode::make(op->node, op->attr_key, StringImmNode::make("local"), op->body); } } return StmtMutator::VisitStmt_(op); @@ -418,16 +389,14 @@ Pass LowerWarpMemory() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerWarpMemory: Require the target attribute"; + CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory") -.set_body_typed(LowerWarpMemory); +TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory); } // namespace transform diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 4e5ca2d..b6314ad 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,22 +20,22 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ -#include -#include -#include -#include -#include -#include +#include #include #include -#include +#include +#include +#include +#include +#include +#include -#include -#include #include +#include +#include -#include "ir_util.h" #include "arg_binder.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -45,15 +45,12 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { EvaluateNode::make(0)); } -PrimFunc MakePackedAPI(PrimFunc&& func, - int num_unpacked_args) { +PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol) - << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; + CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; auto target = func->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "MakePackedAPI: Require the target attribute"; + CHECK(target.defined()) << "MakePackedAPI: Require the target attribute"; int target_device_type = target.value()->device_type; std::string name_hint = global_symbol.value(); @@ -85,15 +82,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func, // local function definitions // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { - Array call_args{ - v_packed_args, - IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; + Array call_args{v_packed_args, IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); - PrimExpr res = CallNode::make( - api_type, intrinsic::tvm_struct_get, call_args, - CallNode::PureIntrinsic); + PrimExpr res = + CallNode::make(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic); // cast to the target version. if (api_type != t) { res = CastNode::make(t, res); @@ -111,8 +105,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, std::ostringstream os; os << name_hint << ": num_args should be " << num_packed_args; - seq_init.emplace_back( - MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); + seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); } // Need to re-declare vars, in case some arguments also appears in the buffer. @@ -131,24 +124,21 @@ PrimFunc MakePackedAPI(PrimFunc&& func, } if (i < num_packed_args) { // Value loads - seq_init.emplace_back(LetStmtNode::make( - v_arg, f_arg_value(v_arg.dtype(), i), nop)); + seq_init.emplace_back(LetStmtNode::make(v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back(LetStmtNode::make( - tcode, LoadNode::make( - DataType::Int(32), v_packed_arg_type_ids, - IntImm(DataType::Int(32), i), const_true(1)), - nop)); + seq_init.emplace_back( + LetStmtNode::make(tcode, + LoadNode::make(DataType::Int(32), v_packed_arg_type_ids, + IntImm(DataType::Int(32), i), const_true(1)), + nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be pointer"; seq_check.emplace_back( - AssertStmtNode::make(tcode == kTVMOpaqueHandle || - tcode == kTVMNDArrayHandle || - tcode == kTVMDLTensorHandle || - tcode == kTVMNullptr, + AssertStmtNode::make(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || + tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, tvm::tir::StringImmNode::make(msg.str()), nop)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; @@ -188,35 +178,30 @@ PrimFunc MakePackedAPI(PrimFunc&& func, } for (const auto& kv : buffer_def) { - binder.BindDLTensor(kv.second, device_type, device_id, - kv.first, kv.first->name_hint); + binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint); } if (num_unpacked_args == 0) { func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); } - auto body = AttrStmtNode::make( - make_zero(DataType::Int(32)), attr::compute_scope, - StringImmNode::make(name_hint + "_compute_"), func_ptr->body); + auto body = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::compute_scope, + StringImmNode::make(name_hint + "_compute_"), func_ptr->body); // Set device context if (vmap.count(device_id.get())) { PrimExpr node = StringImmNode::make("default"); - seq_check.push_back(AttrStmtNode::make( - node, attr::device_context_id, device_id, nop)); - seq_check.push_back(AttrStmtNode::make( - node, attr::device_context_type, device_type, nop)); + seq_check.push_back(AttrStmtNode::make(node, attr::device_context_id, device_id, nop)); + seq_check.push_back(AttrStmtNode::make(node, attr::device_context_type, device_type, nop)); if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { Stmt set_device = EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed, - {StringImmNode::make(runtime::symbol::tvm_set_device), - device_type, device_id}, CallNode::Intrinsic)); + DataType::Int(32), intrinsic::tvm_call_packed, + {StringImmNode::make(runtime::symbol::tvm_set_device), device_type, device_id}, + CallNode::Intrinsic)); body = SeqStmt({set_device, body}); } } - func_ptr->body = MergeNest( - {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); func_ptr->params = args; Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); @@ -229,7 +214,6 @@ PrimFunc MakePackedAPI(PrimFunc&& func, LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str(); } - func_ptr->buffer_map = Map(); func_ptr->checked_type_ = func_ptr->func_type_annotation(); func_ptr->ret_type = PrimType(DataType::Int(32)); @@ -248,9 +232,8 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDefault) { + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); updates.push_back({kv.first, updated_func}); } @@ -263,12 +246,10 @@ Pass MakePackedAPI(int num_unpacked_args) { return m; }; - return tvm::transform::CreateModulePass( - pass_func, 0, "tir.MakePackedAPI", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI") -.set_body_typed(MakePackedAPI); +TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed(MakePackedAPI); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 796e39b..ad86e45 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -22,9 +22,10 @@ * \brief narrow the datatype of indexing vars */ +#include #include #include -#include + #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" @@ -55,8 +56,8 @@ namespace tir { // - Use DataTypeRewritter to rewrite the components of an indexing expression. using arith::Analyzer; -using arith::IRMutatorWithAnalyzer; using arith::ConstIntBound; +using arith::IRMutatorWithAnalyzer; // Determine the result dtype for Var, IntImm and Cast, // which will be stored in `vmap` eventually. @@ -70,8 +71,7 @@ using arith::ConstIntBound; // Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()` class DataTypeVisitor final : public StmtExprVisitor { public: - explicit DataTypeVisitor(int target_bits) - : bits_(target_bits), target_bits_(target_bits) {} + explicit DataTypeVisitor(int target_bits) : bits_(target_bits), target_bits_(target_bits) {} void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { @@ -86,7 +86,7 @@ class DataTypeVisitor final : public StmtExprVisitor { (bound->max_value <= ubound && bound->min_value >= lbound)) { bits = target_bits_; } - int tmp = bits > bits_ ? bits : bits_; + int tmp = bits > bits_ ? bits : bits_; std::swap(bits_, tmp); StmtExprVisitor::VisitExpr(e); std::swap(bits_, tmp); @@ -96,19 +96,16 @@ class DataTypeVisitor final : public StmtExprVisitor { } void VisitStmt_(const ForNode* op) { - analyzer_.Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); + analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); vextent_[op->loop_var.as()] = op->extent.dtype(); return StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, - Range::make_by_min_extent(0, op->value)); + analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); vextent_[iv->var.as()] = op->value.dtype(); StmtExprVisitor::VisitStmt_(op); } else { @@ -191,7 +188,7 @@ class DataTypeVisitor final : public StmtExprVisitor { class DataTypeRewriter : public StmtExprMutator { public: - explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {} + explicit DataTypeRewriter(int target_bits) : visitor_(target_bits) {} Stmt operator()(Stmt s) { visitor_(s); @@ -211,19 +208,15 @@ class DataTypeRewriter : public StmtExprMutator { is_index_ = true; PrimExpr index = this->VisitExpr(op->index); is_index_ = false; - Stmt s = StoreNode::make(op->buffer_var, - op->value, - index, - op->predicate); + Stmt s = StoreNode::make(op->buffer_var, op->value, index, op->predicate); return StmtExprMutator::VisitStmt_(s.as()); } Stmt VisitStmt_(const ForNode* op) final { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); - CHECK(op != nullptr) - << "Expected type to be ForNode" - << ", but get " << s->GetTypeKey(); + CHECK(op != nullptr) << "Expected type to be ForNode" + << ", but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), @@ -231,27 +224,20 @@ class DataTypeRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); - CHECK(op != nullptr) - << "Expected type to be AttrStmtNode" - << ", but get " << s->GetTypeKey(); + CHECK(op != nullptr) << "Expected type to be AttrStmtNode" + << ", but get " << s->GetTypeKey(); const IterVarNode* iv = op->node.as(); - CHECK(iv != nullptr) - << "Expected type to be IterVarNode" - << ", but get " << op->node->GetTypeKey(); + CHECK(iv != nullptr) << "Expected type to be IterVarNode" + << ", but get " << op->node->GetTypeKey(); PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag); } - return AttrStmtNode::make( - ivmap_[iv], - op->attr_key, - cast(var.dtype(), op->value), - op->body); + return AttrStmtNode::make(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); } return StmtExprMutator::VisitStmt_(op); } @@ -297,9 +283,8 @@ class DataTypeRewriter : public StmtExprMutator { if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { PrimExpr e = StmtExprMutator::VisitExpr_(op); const CastNode* new_op = e.as(); - CHECK(new_op != nullptr) - << "Expected type to be CastNode" - << ", but get " << e->GetTypeKey(); + CHECK(new_op != nullptr) << "Expected type to be CastNode" + << ", but get " << e->GetTypeKey(); return CastNode::make(visitor_.vmap[op], new_op->value); } return StmtExprMutator::VisitExpr_(op); @@ -335,40 +320,38 @@ class DataTypeRewriter : public StmtExprMutator { bool is_index_{false}; }; -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return FUNC(a, b); \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return FUNC(a, b); \ + } \ } -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); - CHECK(op != nullptr) - << "Expected type to be CallNode" - << ", but get " << e->GetTypeKey(); + CHECK(op != nullptr) << "Expected type to be CallNode" + << ", but get " << e->GetTypeKey(); if (op->call_type == CallNode::PureIntrinsic) { if (op->name == intrinsic::tvm_if_then_else) { return if_then_else(op->args[0], op->args[1], op->args[2]); @@ -389,9 +372,7 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { return e; } -Stmt NarrowDataType(Stmt stmt, int target_bits) { - return DataTypeRewriter(target_bits)(stmt); -} +Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); } namespace transform { @@ -401,12 +382,10 @@ Pass NarrowDataType(int target_bits) { n->body = DataTypeRewriter(target_bits)(std::move(n->body)); return f; }; - return CreatePrimFuncPass( - pass_func, 0, "tir.NarrowDataType", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType") -.set_body_typed(NarrowDataType); +TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index fdcfc4d..efb9e69 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -20,12 +20,12 @@ /*! * \file remap_thread_axis.cc */ +#include #include #include #include -#include -#include +#include namespace tvm { namespace tir { @@ -33,14 +33,9 @@ namespace tir { // Mutator to change the read pattern class ThreadAxisRewriter : private StmtExprMutator { public: - explicit ThreadAxisRewriter( - const std::unordered_map& tmap) - : tmap_(tmap) { - } + explicit ThreadAxisRewriter(const std::unordered_map& tmap) : tmap_(tmap) {} - Stmt Rewrite(Stmt stmt) { - return operator()(std::move(stmt)); - } + Stmt Rewrite(Stmt stmt) { return operator()(std::move(stmt)); } private: Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -57,8 +52,7 @@ class ThreadAxisRewriter : private StmtExprMutator { CHECK(vmap_[v].same_as(new_iv->var)); } Stmt body = this->VisitStmt(op->body); - return AttrStmtNode::make( - new_iv, op->attr_key, op->value, body); + return AttrStmtNode::make(new_iv, op->attr_key, op->value, body); } } return StmtExprMutator::VisitStmt_(op); @@ -75,7 +69,6 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; - PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { @@ -83,8 +76,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) } auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - CHECK(opt_thread_axis != nullptr) - << "Require attribute " << tir::attr::kDeviceThreadAxis; + CHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis; auto thread_axis = opt_thread_axis.value(); auto* n = f.CopyOnWrite(); @@ -99,7 +91,6 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis); } - namespace transform { Pass RemapThreadAxis(Map thread_map) { @@ -109,8 +100,7 @@ Pass RemapThreadAxis(Map thread_map) { return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis") -.set_body_typed(RemapThreadAxis); +TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index ceaf27b..15a7e86 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -22,11 +22,12 @@ * \brief Remove no op from the stmt */ #include -#include #include #include -#include +#include #include +#include + #include namespace tvm { @@ -105,7 +106,7 @@ class NoOpRemover : public StmtMutator { auto n = CopyOnWrite(op); size_t top = 0; for (size_t i = 0; i < n->seq.size(); ++i) { - if (!is_no_op(n->seq[i])) { + if (!is_no_op(n->seq[i])) { n->seq.Set(top++, n->seq[i]); } } @@ -147,9 +148,7 @@ class NoOpRemover : public StmtMutator { } }; -Stmt RemoveNoOp(Stmt stmt) { - return NoOpRemover()(std::move(stmt)); -} +Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); } namespace transform { @@ -162,8 +161,7 @@ Pass RemoveNoOp() { return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp") -.set_body_typed(RemoveNoOp); +TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp); } // namespace transform diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 6052cbf..149cda9 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -29,16 +29,13 @@ namespace tvm { namespace tir { - // For now, rewrite unsafe select expression to if_then_else // TODO(tqchen) pattern matching to support masked load class UnsafeExprDetector : public ExprFunctor { public: // select itself is always considered safe if condition is safe // Because we will issue guard to make sure it is. - bool VisitExpr_(const SelectNode* op) { - return VisitExpr(op->condition); - } + bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); } bool VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { return VisitExpr(op->args[0]); @@ -75,21 +72,11 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const GENode* op) final { return BinaryOp(op); } bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); } bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); } - bool VisitExpr_(const NotNode* op) final { - return VisitExpr(op->a); - } - bool VisitExpr_(const LetNode* op) final { - return VisitExpr(op->body) || VisitExpr(op->value); - } - bool VisitExpr_(const CastNode* op) final { - return VisitExpr(op->value); - } - bool VisitExpr_(const BroadcastNode* op) final { - return VisitExpr(op->value); - } - bool VisitExpr_(const RampNode* op) final { - return VisitExpr(op->base) && VisitExpr(op->stride); - } + bool VisitExpr_(const NotNode* op) final { return VisitExpr(op->a); } + bool VisitExpr_(const LetNode* op) final { return VisitExpr(op->body) || VisitExpr(op->value); } + bool VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } + bool VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); } + bool VisitExpr_(const RampNode* op) final { return VisitExpr(op->base) && VisitExpr(op->stride); } bool VisitExpr_(const ShuffleNode* op) final { for (PrimExpr e : op->vectors) { if (VisitExpr(e)) return true; @@ -102,7 +89,7 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const StringImmNode* op) final { return false; } private: - template + template bool BinaryOp(const T* op) { return VisitExpr(op->a) || VisitExpr(op->b); } @@ -115,23 +102,17 @@ class UnsafeSelectRewriter : public StmtExprMutator { op = expr.as(); UnsafeExprDetector unsafe; bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); - if ((unsafe.VisitExpr(op->true_value) || - unsafe.VisitExpr(op->false_value)) && + if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) && cond_is_scalar_bool) { - return CallNode::make( - op->dtype, - intrinsic::tvm_if_then_else, - {op->condition, op->true_value, op->false_value}, - CallNode::Intrinsic); + return CallNode::make(op->dtype, intrinsic::tvm_if_then_else, + {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic); } else { return expr; } } }; -Stmt RewriteUnsafeSelect(Stmt stmt) { - return UnsafeSelectRewriter()(std::move(stmt)); -} +Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter()(std::move(stmt)); } namespace transform { @@ -144,8 +125,7 @@ Pass RewriteUnsafeSelect() { return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect") -.set_body_typed(RewriteUnsafeSelect); +TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect); } // namespace transform diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 752939e..759b320 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -21,14 +21,13 @@ * \file simplify.cc * \brief Statement simplifier based on analyzer */ +#include #include +#include #include +#include #include -#include -#include -#include -#include #include "../../arith/ir_mutator_with_analyzer.h" namespace tvm { @@ -38,20 +37,15 @@ using namespace tir; class StmtSimplifier : public IRMutatorWithAnalyzer { public: - explicit StmtSimplifier(Analyzer* analyzer) - : IRMutatorWithAnalyzer(analyzer) {} + explicit StmtSimplifier(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} using Parent = IRMutatorWithAnalyzer; using Parent::VisitStmt; using Parent::VisitStmt_; - PrimExpr VisitExpr(const PrimExpr& expr) final { - return analyzer_->Simplify(expr); - } + PrimExpr VisitExpr(const PrimExpr& expr) final { return analyzer_->Simplify(expr); } - Stmt Simplify(Stmt stmt) { - return operator()(std::move(stmt)); - } + Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } Stmt VisitStmt_(const ForNode* op) final { analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); @@ -69,8 +63,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return this->VisitStmt(op->body); } Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -109,8 +102,7 @@ Pass Simplify() { return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); } -TVM_REGISTER_GLOBAL("tir.transform.Simplify") -.set_body_typed(Simplify); +TVM_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify); } // namespace transform diff --git a/src/tir/transforms/skip_assert.cc b/src/tir/transforms/skip_assert.cc index 4511838..d9cd6d3 100644 --- a/src/tir/transforms/skip_assert.cc +++ b/src/tir/transforms/skip_assert.cc @@ -17,10 +17,10 @@ * under the License. */ +#include #include -#include #include -#include +#include namespace tvm { namespace tir { @@ -34,9 +34,7 @@ class AssertSkipper : public StmtMutator { } }; -Stmt SkipAssert(Stmt stmt) { - return AssertSkipper()(std::move(stmt)); -} +Stmt SkipAssert(Stmt stmt) { return AssertSkipper()(std::move(stmt)); } namespace transform { @@ -49,8 +47,7 @@ Pass SkipAssert() { return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SkipAssert") -.set_body_typed(SkipAssert); +TVM_REGISTER_GLOBAL("tir.transform.SkipAssert").set_body_typed(SkipAssert); } // namespace transform diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 44f032f..9bdb0e2 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -22,14 +22,14 @@ * \brief Split device function from host. */ #include -#include -#include -#include +#include +#include +#include #include +#include +#include #include -#include -#include -#include +#include #include @@ -69,13 +69,11 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); Stmt body = this->VisitStmt(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && - !HasSideEffect(op->value)) { + if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && - value.same_as(op->value)) { + if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); } else { return LetStmtNode::make(op->var, value, body); @@ -102,13 +100,11 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); PrimExpr body = this->VisitExpr(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && - !HasSideEffect(op->value)) { + if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && - value.same_as(op->value)) { + if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); @@ -127,12 +123,10 @@ class VarUseDefAnalysis : public StmtExprMutator { } void HandleDef(const VarNode* v) { - CHECK(!def_count_.count(v)) - << "variable " << v->name_hint - << " has already been defined, the Stmt is not SSA"; - CHECK(!use_count_.count(v)) - << "variable " << v->name_hint - << " has been used before definition!"; + CHECK(!def_count_.count(v)) << "variable " << v->name_hint + << " has already been defined, the Stmt is not SSA"; + CHECK(!use_count_.count(v)) << "variable " << v->name_hint + << " has been used before definition!"; use_count_[v] = 0; def_count_[v] = 1; } @@ -161,7 +155,6 @@ class VarUseDefAnalysis : public StmtExprMutator { std::unordered_map def_count_; }; - Array UndefinedVars(const Stmt& stmt, const Array& args) { VarUseDefAnalysis m; for (Var arg : args) { @@ -171,16 +164,10 @@ Array UndefinedVars(const Stmt& stmt, const Array& args) { return m.undefined_; } - class HostDeviceSplitter : public StmtMutator { public: - explicit HostDeviceSplitter(IRModule* device_mod, - Target device_target, - std::string name_prefix) - : device_mod_(device_mod), - device_target_(device_target), - name_prefix_(name_prefix) { - } + explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) + : device_mod_(device_mod), device_target_(device_target), name_prefix_(name_prefix) {} Stmt VisitStmt_(const AllocateNode* op) final { handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); @@ -188,8 +175,7 @@ class HostDeviceSplitter : public StmtMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::pipeline_exec_scope || + if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { return SplitDeviceFunc(GetRef(op)); } @@ -216,8 +202,7 @@ class HostDeviceSplitter : public StmtMutator { // Create a new version of v. auto it = handle_data_type_.find(var.get()); if (it != handle_data_type_.end()) { - tir::Var new_var(var->name_hint, - PointerType(PrimType((*it).second->dtype))); + tir::Var new_var(var->name_hint, PointerType(PrimType((*it).second->dtype))); params.push_back(new_var); remap_vars.Set(var, new_var); } else { @@ -237,8 +222,8 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_); device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); - device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, - runtime::String(kernel_symbol)); + device_func = + WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func); @@ -252,9 +237,8 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } - return EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed, - call_args, CallNode::Intrinsic)); + return EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, + call_args, CallNode::Intrinsic)); } // target ir module @@ -268,19 +252,15 @@ class HostDeviceSplitter : public StmtMutator { std::unordered_map handle_data_type_; }; - PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { auto target = func->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "SplitHostDevice: Require the target attribute"; + CHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; - HostDeviceSplitter splitter( - device_mod, - target.value(), - static_cast(global_symbol.value())); + HostDeviceSplitter splitter(device_mod, target.value(), + static_cast(global_symbol.value())); auto* n = func.CopyOnWrite(); n->body = splitter(std::move(n->body)); @@ -289,7 +269,6 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { return std::move(func); } - namespace transform { Pass SplitHostDevice() { @@ -308,12 +287,10 @@ Pass SplitHostDevice() { return mod; }; - return tvm::transform::CreateModulePass( - pass_func, 0, "tir.SplitHostDevice", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice") -.set_body_typed(SplitHostDevice); +TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 1f28e13..35888bd 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -20,12 +20,15 @@ /*! * \file storage_access.cc */ +#include "storage_access.h" + #include + #include #include -#include "storage_access.h" -#include "ir_util.h" + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -89,8 +92,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = - StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::make(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == attr::double_buffer_write) { CHECK(double_buffer_write_ == nullptr); @@ -145,8 +147,8 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) { if (s.access.size() != 0) { // relax the touched set to contain all ranges in the loop. std::unordered_map relax_map; - relax_map[op->loop_var.get()] = arith::IntSet::range( - Range::make_by_min_extent(op->min, op->extent)); + relax_map[op->loop_var.get()] = + arith::IntSet::range(Range::make_by_min_extent(op->min, op->extent)); for (AccessEntry& e : s.access) { if (e.buffer.defined()) { CHECK(e.touched.defined()); @@ -180,7 +182,7 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { void StorageAccessVisitor::VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); StmtExprVisitor::VisitExpr_(l); } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); @@ -197,8 +199,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { e.threads = env_threads(); e.dtype = dtype; e.buffer = Downcast(op->args[1]); - e.touched = arith::IntSet::range( - Range::make_by_min_extent(offset, extent)); + e.touched = arith::IntSet::range(Range::make_by_min_extent(offset, extent)); e.scope = scope; if (flag->value & 1) { e.type = kRead; diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 12e76bd..80bbff4 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -24,19 +24,21 @@ #ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ #define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ +#include #include #include -#include #include -#include + #include +#include + #include "../../runtime/thread_storage_scope.h" namespace tvm { namespace tir { -using runtime::StorageScope; using runtime::StorageRank; +using runtime::StorageScope; /*! * \brief Base class of storage access analysis */ @@ -85,31 +87,20 @@ class StorageAccessVisitor : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final; protected: - StorageAccessVisitor() { - scope_.push_back(std::vector()); - } + StorageAccessVisitor() { scope_.push_back(std::vector()); } /*! \return number of conditions in the current scope. */ - int condition_counter() const { - return condition_counter_; - } + int condition_counter() const { return condition_counter_; } /*! \return whether we are in device environment. */ - bool in_device_env() const { - return in_device_env_; - } + bool in_device_env() const { return in_device_env_; } /*! \return environment threads */ - const Array& env_threads() const { - return env_threads_; - } + const Array& env_threads() const { return env_threads_; } /*! * \brief Whether we need analyze the buffer in current scope. * \param buffer The buffer to be checked * \param scope The scope of the buffer. * \return Whether the analysis of buffer is enabled. */ - virtual bool Enabled(const VarNode* buffer, - const StorageScope& scope) const { - return true; - } + virtual bool Enabled(const VarNode* buffer, const StorageScope& scope) const { return true; } /*! * \brief Summarize the sequence of operations into parent. * @@ -121,8 +112,7 @@ class StorageAccessVisitor : public StmtExprVisitor { * \return The summarized sequence that represent access that * the parent should taken care of to synchronize. */ - virtual std::vector Summarize( - std::vector seq, const ForNode* loop) = 0; + virtual std::vector Summarize(std::vector seq, const ForNode* loop) = 0; /*! * \brief Get the scope of the buffer array. * \return The scope of the final buffer array. diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 9668695..96d0e30 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -23,40 +23,39 @@ */ // The pass definition originates from Halide pipeline. -#include #include +#include +#include +#include +#include +#include #include +#include #include -#include #include -#include #include -#include -#include -#include + #include -#include "ir_util.h" -#include "arg_binder.h" + #include "../../arith/compute_expr.h" #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" +#include "arg_binder.h" +#include "ir_util.h" namespace tvm { namespace tir { +using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; -using intrinsic::tvm_address_of; class StorageFlattener : public StmtExprMutator { public: - explicit StorageFlattener(const Map& extern_buffer_map, - int cache_line_size, - bool create_bound_attributes, - IRVisitorWithAnalyzer* bound_analyzer) - : bound_analyzer_(bound_analyzer), - create_bound_attributes_(create_bound_attributes) { + explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, + bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) { for (auto kv : extern_buffer_map) { BufferEntry e; e.buffer = kv.second; @@ -70,8 +69,7 @@ class StorageFlattener : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && - !it->second.same_as(op->buffer_var)) { + if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); Var buf_var = Downcast(it->second); return StoreNode::make(buf_var, op->value, op->index, op->predicate); @@ -89,10 +87,8 @@ class StorageFlattener : public StmtExprMutator { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); auto it = buf_map_.find(buffer); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << buffer; - body = AttrStmtNode::make( - it->second.buffer->data, op->attr_key, op->value, std::move(body)); + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; + body = AttrStmtNode::make(it->second.buffer->data, op->attr_key, op->value, std::move(body)); return body; } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); @@ -129,31 +125,24 @@ class StorageFlattener : public StmtExprMutator { const auto& key = op->buffer; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + CHECK(!e.released) << "Read a buffer that is already out of scope"; if (is_opengl_) { - return EvaluateNode::make(CallNode::make( - DataType(), - CallNode::glsl_texture_store, - {e.buffer->data, op->value}, - CallNode::Intrinsic)); + return EvaluateNode::make(CallNode::make(DataType(), CallNode::glsl_texture_store, + {e.buffer->data, op->value}, CallNode::Intrinsic)); } else { Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - shape_collector_.push_back( - std::make_pair(e.buffer->data, e.buffer->shape)); + shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } // To create bound attribute collector should has at least one item. if (create_bound_attributes_ && shape_collector_.size()) { for (size_t i = 0; i < shape_collector_.size(); ++i) { - body = AttrStmtNode::make( - shape_collector_[i].first, tir::attr::buffer_bound, - MakeBound(e.buffer->dtype, shape_collector_[i].second), body); + body = AttrStmtNode::make(shape_collector_[i].first, tir::attr::buffer_bound, + MakeBound(e.buffer->dtype, shape_collector_[i].second), body); } } return body; @@ -176,14 +165,12 @@ class StorageFlattener : public StmtExprMutator { } // deduce current storage scope. auto it = storage_scope_.find(op->buffer.get()); - CHECK(it != storage_scope_.end()) - << "Cannot find storage scope of " << op->buffer; + CHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer; StorageScope skey; const std::string& strkey = it->second; if (strkey.length() == 0) { if (curr_thread_scope_.size() != 0) { - skey.rank = runtime::DefaultStorageRank( - curr_thread_scope_.back().rank); + skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); } } else { skey = StorageScope::make(strkey); @@ -221,11 +208,9 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = BufferNode::make( - Var(op->buffer->data->name_hint, DataType::Handle()), - op->buffer->dtype, shape, strides, PrimExpr(), - op->buffer->name, skey.to_string(), - align, 0, kDefault); + e.buffer = BufferNode::make(Var(op->buffer->data->name_hint, DataType::Handle()), + op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, + skey.to_string(), align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); @@ -240,26 +225,23 @@ class StorageFlattener : public StmtExprMutator { } if (strides.size() != 0) { int first_dim = 0; - ret = AllocateNode::make( - e.buffer->data, storage_type, - {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); + ret = AllocateNode::make(e.buffer->data, storage_type, + {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, + make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } else { shape = e.buffer->shape; if (shape.size() == 0) { shape.push_back(make_const(DataType::Int(32), 1)); } - ret = AllocateNode::make( - e.buffer->data, storage_type, shape, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); + ret = AllocateNode::make(e.buffer->data, storage_type, shape, + make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = AttrStmtNode::make( - e.buffer->data, attr::storage_scope, - StringImmNode::make(e.buffer->scope), ret); + ret = AttrStmtNode::make(e.buffer->data, attr::storage_scope, + StringImmNode::make(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound, - MakeBound(e.buffer->dtype, e.buffer->shape), ret); + MakeBound(e.buffer->dtype, e.buffer->shape), ret); } return ret; } @@ -269,8 +251,7 @@ class StorageFlattener : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && - !it->second.same_as(op->buffer_var)) { + if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); Var buf_var = Downcast(it->second); return LoadNode::make(op->dtype, buf_var, op->index, op->predicate); @@ -295,38 +276,31 @@ class StorageFlattener : public StmtExprMutator { const auto& key = op->buffer; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + CHECK(!e.released) << "Read a buffer that is already out of scope"; if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - shape_collector_.push_back( - std::make_pair(e.buffer->data, e.buffer->shape)); + shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype); } - - Stmt VisitStmt_(const PrefetchNode *op) final { + Stmt VisitStmt_(const PrefetchNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); CHECK(op != nullptr); const auto& key = op->buffer; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + CHECK(!e.released) << "Read a buffer that is already out of scope"; CHECK_EQ(e.buffer->shape.size(), op->bounds.size()) - << "Prefetch dim should be the same as buffer dim"; + << "Prefetch dim should be the same as buffer dim"; - int block_size = 1, - elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(); + int block_size = 1, elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(); int starts = op->bounds.size() - 1; @@ -344,25 +318,23 @@ class StorageFlattener : public StmtExprMutator { for (int i = op->bounds.size() - 1; i > starts; --i) { args.push_back(op->bounds[i]->min); } - auto &func_name = op->buffer->name; - vars.push_back(Var( - "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); + auto& func_name = op->buffer->name; + vars.push_back(Var("prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); args.push_back(op->bounds[starts]->min + stride * vars.back()); for (int i = starts - 1; i >= 0; --i) { - vars.push_back(Var( - "prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); + vars.push_back(Var("prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); args.push_back(vars.back() + op->bounds[i]->min); } for (int i = starts; i >= 0; --i) { if (i < starts) { - stmt = ForNode::make( - vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt); + stmt = ForNode::make(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, + stmt); } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); - PrimExpr address = CallNode::make( - DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); - PrimExpr prefetch = CallNode::make( - op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); + PrimExpr address = + CallNode::make(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); + PrimExpr prefetch = CallNode::make(op->buffer->dtype, CallNode::prefetch, + {address, 0, 3, 1}, CallNode::Intrinsic); stmt = EvaluateNode::make(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); @@ -372,9 +344,8 @@ class StorageFlattener : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - CHECK(op->call_type != CallNode::Halide) - << "Cannot handle Halide calls " - << " please run SchedulePostProcToPrimFunc first"; + CHECK(op->call_type != CallNode::Halide) << "Cannot handle Halide calls " + << " please run SchedulePostProcToPrimFunc first"; return StmtExprMutator::VisitExpr_(op); } @@ -390,7 +361,6 @@ class StorageFlattener : public StmtExprMutator { return Stmt(); } - private: // The specific tensor data layout is not determined before // StorageFlatten pass. We use buffer_bind_scope @@ -427,7 +397,7 @@ class StorageFlattener : public StmtExprMutator { // We do support a few relaxed case, such as bindingx // region with shape [1, 1, n, m] to buffer with shape [n, m] Stmt HandleBufferBindScope(const AttrStmtNode* op) { - Array arr = Downcast > (op->node); + Array arr = Downcast>(op->node); CHECK_EQ(arr.size(), 2U); const BufferNode* buffer = arr[0].as(); const BufferNode* target = arr[1].as(); @@ -437,8 +407,7 @@ class StorageFlattener : public StmtExprMutator { auto key = GetRef(target); auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find buffer of " << key; + CHECK(it != buf_map_.end()) << "Cannot find buffer of " << key; const BufferEntry& be = it->second; CHECK(!be.released); CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); @@ -452,15 +421,14 @@ class StorageFlattener : public StmtExprMutator { } else { for (size_t i = 0; i < tuple->args.size(); i += 2) { begins.push_back(tuple->args[i]); - auto new_extent = bound_analyzer_->Simplify(tuple->args[i+1]); + auto new_extent = bound_analyzer_->Simplify(tuple->args[i + 1]); extents.push_back(new_extent); } } Buffer slice = be.buffer.MakeSlice(begins, extents); if (buffer->strides.size() == 0) { CHECK_EQ(slice->strides.size(), 0U) - << "Trying to bind compact buffer to strided one strides=" - << slice->strides; + << "Trying to bind compact buffer to strided one strides=" << slice->strides; } else { slice = slice.MakeStrideView(); } @@ -508,26 +476,24 @@ class StorageFlattener : public StmtExprMutator { } }; - bool ShapeIsValid(const Array &shape) { + bool ShapeIsValid(const Array& shape) { // Zero-dimensional tensor does not need boundary check. - if (!shape.size()) - return false; + if (!shape.size()) return false; for (size_t i = 0; i < shape.size(); ++i) { - if (!shape[i].defined() || !shape[i].dtype().is_scalar() || - is_negative_const(shape[i])) { + if (!shape[i].defined() || !shape[i].dtype().is_scalar() || is_negative_const(shape[i])) { return false; } } return true; } - PrimExpr MakeBound(const DataType &type, const Array &shape) { + PrimExpr MakeBound(const DataType& type, const Array& shape) { // We have already checked the shape size to be greater then 0. PrimExpr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]); for (size_t i = 1; i < shape.size(); ++i) { - bound = MulNode::make( - bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i])); + bound = + MulNode::make(bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i])); } return bound; } @@ -538,8 +504,7 @@ class StorageFlattener : public StmtExprMutator { // Buffer map std::unordered_map buf_map_; // Dimension alignment - std::unordered_map, - ObjectHash, ObjectEqual> dim_align_; + std::unordered_map, ObjectHash, ObjectEqual> dim_align_; // Storage scope std::unordered_map storage_scope_; // The current thread scope. @@ -557,35 +522,27 @@ class StorageFlattener : public StmtExprMutator { bool create_bound_attributes_{false}; }; -PrimFunc StorageFlatten(PrimFunc func, - int cache_line_size, - bool create_bound_attributes) { +PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { auto fptr = func.CopyOnWrite(); IRVisitorWithAnalyzer bound_analyzer; bound_analyzer(fptr->body); - fptr->body = StorageFlattener(fptr->buffer_map, - cache_line_size, - create_bound_attributes, + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer)(std::move(fptr->body)); return func; } - namespace transform { // TODO(tvm-team): consolidate configs to the PassContext -Pass StorageFlatten(int cache_line_size, - bool create_bound_attributes) { +Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return StorageFlatten( - std::move(f), cache_line_size, create_bound_attributes); + return StorageFlatten(std::move(f), cache_line_size, create_bound_attributes); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {}); } -TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten") -.set_body_typed(StorageFlatten); +TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten").set_body_typed(StorageFlatten); } // namespace transform diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index ca2b5a9..fc86f2b 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -22,19 +22,21 @@ * \brief Memory access pattern analysis and optimization. * Re-write data access to enable memory sharing when possible. */ -#include #include -#include -#include +#include +#include #include +#include #include -#include +#include + #include -#include #include -#include "ir_util.h" +#include + #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -125,8 +127,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { const VarNode* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - CHECK_LT(it->second.level, scope_.size()) - << "Load memory in places other than store."; + CHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; scope_[it->second.level].touched.push_back(buf); } } @@ -142,24 +143,23 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { // Directly reference to the variable count as a read. auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - CHECK_LT(it->second.level, scope_.size()) - << " buf=" << buf->name_hint; + CHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint; scope_[it->second.level].touched.push_back(buf); } } - template + template void VisitNewScope(const T* op) { scope_.push_back(StmtEntry()); StmtEntry e; e.stmt = op; - int64_t begin_index = static_cast(linear_seq_.size()); + int64_t begin_index = static_cast(linear_seq_.size()); // before scope. linear_seq_.push_back(e); StmtExprVisitor::VisitStmt_(op); // after scope. e.touched = std::move(scope_.back().touched); scope_.pop_back(); - int64_t end_index = static_cast(linear_seq_.size()); + int64_t end_index = static_cast(linear_seq_.size()); CHECK_GT(end_index, begin_index); e.scope_pair_offset = begin_index - end_index; linear_seq_.push_back(e); @@ -179,24 +179,17 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - alloc_info_[buf].storage_scope = - StorageScope::make(op->value.as()->value); + alloc_info_[buf].storage_scope = StorageScope::make(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); } } - void VisitStmt_(const IfThenElseNode* op) final { - VisitNewScope(op); - } + void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); } - void VisitStmt_(const ForNode* op) final { - VisitNewScope(op); - } + void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } - void VisitStmt_(const AssertStmtNode* op) final { - VisitNewScope(op); - } + void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } // linearized access sequence. std::vector linear_seq_; @@ -238,9 +231,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { // class InplaceOpVerifier : public StmtExprVisitor { public: - bool Check(const Object* stmt, - const VarNode* dst, - const VarNode* src) { + bool Check(const Object* stmt, const VarNode* dst, const VarNode* src) { dst_ = dst; src_ = src; result_ = true; @@ -272,7 +263,8 @@ class InplaceOpVerifier : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { // assume all opaque access is unsafe if (op == dst_ || op == src_) { - result_ = false; return; + result_ = false; + return; } } @@ -293,9 +285,9 @@ class InplaceOpVerifier : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { // always reject extern code - if (op->attr_key == attr::extern_scope || - op->attr_key == attr::volatile_scope) { - result_ = false; return; + if (op->attr_key == attr::extern_scope || op->attr_key == attr::volatile_scope) { + result_ = false; + return; } StmtExprVisitor::VisitStmt_(op); } @@ -304,17 +296,19 @@ class InplaceOpVerifier : public StmtExprVisitor { const VarNode* buf = op->buffer_var.get(); // cannot read from dst_ (no reduction) if (buf == dst_) { - result_ = false; return; + result_ = false; + return; } // do not allow indirect memory load if (mem_nest_ != 0) { - result_ = false; return; + result_ = false; + return; } if (src_ == buf) { - if (store_ == nullptr || - store_->value.dtype() != op->dtype || + if (store_ == nullptr || store_->value.dtype() != op->dtype || !tir::ExprDeepEqual()(store_->index, op->index)) { - result_ = false; return; + result_ = false; + return; } } ++mem_nest_; @@ -322,7 +316,6 @@ class InplaceOpVerifier : public StmtExprVisitor { --mem_nest_; } - private: // result of the check bool result_{true}; @@ -358,10 +351,9 @@ class StoragePlanRewriter : public StmtExprMutator { for (StorageEntry* e : attach_map_.at(nullptr)) { // CHECK_EQ(e->scope.rank, 0); if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmtNode::make( - e->alloc_var, attr::storage_scope, - StringImmNode::make(e->scope.to_string()), - EvaluateNode::make(0))); + nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope, + StringImmNode::make(e->scope.to_string()), + EvaluateNode::make(0))); nest.push_back(e->new_alloc); } } @@ -374,20 +366,16 @@ class StoragePlanRewriter : public StmtExprMutator { op = stmt.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return stmt; - return StoreNode::make(it->second->alloc_var, - op->value, - RemapIndex(op->value.dtype(), op->index, it->second), - op->predicate); + return StoreNode::make(it->second->alloc_var, op->value, + RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); } PrimExpr VisitExpr_(const LoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return expr; - return LoadNode::make(op->dtype, - it->second->alloc_var, - RemapIndex(op->dtype, op->index, it->second), - op->predicate); + return LoadNode::make(op->dtype, it->second->alloc_var, + RemapIndex(op->dtype, op->index, it->second), op->predicate); } PrimExpr VisitExpr_(const VarNode* op) final { auto it = alloc_map_.find(op); @@ -417,10 +405,9 @@ class StoragePlanRewriter : public StmtExprMutator { if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } - return CallNode::make( - op->dtype, op->name, - {op->args[0], se->alloc_var, offset, extent, op->args[4]}, - op->call_type); + return CallNode::make(op->dtype, op->name, + {op->args[0], se->alloc_var, offset, extent, op->args[4]}, + op->call_type); } else { return StmtExprMutator::VisitExpr_(op); } @@ -429,17 +416,14 @@ class StoragePlanRewriter : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::storage_scope) { return this->VisitStmt(op->body); - } else if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread || + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return AttrStmtNode::make( - op->node, op->attr_key, op->value, - MakeAttach(svec, op->body)); + return AttrStmtNode::make(op->node, op->attr_key, op->value, MakeAttach(svec, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } @@ -448,31 +432,26 @@ class StoragePlanRewriter : public StmtExprMutator { op = stmt.as(); auto it = alloc_map_.find(op->node.as()); if (it == alloc_map_.end()) return stmt; - return AttrStmtNode::make( - it->second->alloc_var, op->attr_key, op->value, op->body); + return AttrStmtNode::make(it->second->alloc_var, op->attr_key, op->value, op->body); } else { return StmtExprMutator::VisitStmt_(op); } } Stmt VisitStmt_(const ForNode* op) final { - CHECK(op->for_type != ForType::Vectorized) - << "VectorizeLoop before LiftStorageAlloc"; + CHECK(op->for_type != ForType::Vectorized) << "VectorizeLoop before LiftStorageAlloc"; // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return ForNode::make( - op->loop_var, op->min, op->extent, op->for_type, op->device_api, - MakeAttach(svec, op->body)); + return ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, + MakeAttach(svec, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } } - Stmt VisitStmt_(const AllocateNode* op) final { - return this->VisitStmt(op->body); - } + Stmt VisitStmt_(const AllocateNode* op) final { return this->VisitStmt(op->body); } private: struct StorageEntry { @@ -517,15 +496,13 @@ class StoragePlanRewriter : public StmtExprMutator { std::vector kill; }; - Stmt MakeAttach(const std::vector& svec, - Stmt body) { + Stmt MakeAttach(const std::vector& svec, Stmt body) { std::vector nest; for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmtNode::make( - e->alloc_var, attr::storage_scope, - StringImmNode::make(e->scope.to_string()), - EvaluateNode::make(0))); + nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope, + StringImmNode::make(e->scope.to_string()), + EvaluateNode::make(0))); nest.push_back(e->new_alloc); } } @@ -545,15 +522,14 @@ class StoragePlanRewriter : public StmtExprMutator { attach_map_[e->attach_scope_].push_back(e); } // find allocation via attach map. - for (auto &kv : attach_map_) { + for (auto& kv : attach_map_) { // find the element with the most amount of bytes. std::vector& vec = kv.second; // try to find merge, for tagged memory for (size_t i = 0; i < vec.size(); ++i) { StorageEntry* e = vec[i]; if (e->scope.tag.length() != 0) { - CHECK_NE(e->const_nbits, 0U) - << "Special tagged memory must be const size"; + CHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; for (size_t j = 0; j < i; ++j) { if (e->scope == vec[j]->scope) { vec[j]->merged_children.push_back(e); @@ -568,7 +544,8 @@ class StoragePlanRewriter : public StmtExprMutator { // already merged if (e->bits_offset != 0) continue; if (e->merged_children.size() != 0) { - NewAllocTagMerged(e); continue; + NewAllocTagMerged(e); + continue; } // Get the allocation size; e->alloc_var = e->allocs[0]->buffer_var; @@ -581,10 +558,9 @@ class StoragePlanRewriter : public StmtExprMutator { if (e->allocs.size() == 1) { // simply use the original allocation. PrimExpr sz = arith::ComputeReduce(e->allocs[0]->extents, - make_const(DataType::Int(32), 1)); - e->new_alloc = AllocateNode::make( - e->alloc_var, alloc_type, {sz}, - e->allocs[0]->condition, EvaluateNode::make(0)); + make_const(DataType::Int(32), 1)); + e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, + EvaluateNode::make(0)); if (e->scope.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -595,13 +571,12 @@ class StoragePlanRewriter : public StmtExprMutator { // Build a merged allocation PrimExpr combo_size; for (const AllocateNode* op : e->allocs) { - PrimExpr sz = arith::ComputeReduce( - op->extents, make_const(DataType::Int(32), 1)); + PrimExpr sz = + arith::ComputeReduce(op->extents, make_const(DataType::Int(32), 1)); auto nbits = op->dtype.bits() * op->dtype.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { - LOG(WARNING) << "The allocation requires : " << imm->value - << " * " << nbits + LOG(WARNING) << "The allocation requires : " << imm->value << " * " << nbits << " bits, which is greater than the maximum of" " int32. The size is cast to int64." << "\n"; @@ -625,9 +600,8 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = combo_size + make_const(DataType::Int(32), 1); } combo_size = analyzer_.Simplify(combo_size); - e->new_alloc = AllocateNode::make( - e->alloc_var, alloc_type, {combo_size}, const_true(), - EvaluateNode::make(0)); + e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {combo_size}, const_true(), + EvaluateNode::make(0)); if (e->scope.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -653,7 +627,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Always align to max_simd_bits // so we can remap types by keeping this property if (total_bits % align != 0) { - total_bits += align - (total_bits % align); + total_bits += align - (total_bits % align); } e->alloc_var = e->allocs[0]->buffer_var; for (StorageEntry* child : e->merged_children) { @@ -663,15 +637,14 @@ class StoragePlanRewriter : public StmtExprMutator { child->alloc_var = e->alloc_var; total_bits += child->const_nbits; if (total_bits % align != 0) { - total_bits += align - (total_bits % align); + total_bits += align - (total_bits % align); } } uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); - PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(), - (total_bits + type_bits - 1) / type_bits); - e->new_alloc = AllocateNode::make( - e->alloc_var, e->elem_type, {alloc_size}, const_true(), - EvaluateNode::make(0)); + PrimExpr alloc_size = + make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits); + e->new_alloc = AllocateNode::make(e->alloc_var, e->elem_type, {alloc_size}, const_true(), + EvaluateNode::make(0)); if (info.defined()) { CHECK_LE(total_bits, info->max_num_bits) << "Allocation exceed bound of memory tag " << e->scope.to_string(); @@ -764,8 +737,7 @@ class StoragePlanRewriter : public StmtExprMutator { visitor.Check(s.stmt, var, src)) { uint64_t const_nbits = static_cast(ae.alloc->constant_allocation_size()) * - ae.alloc->dtype.bits() * - ae.alloc->dtype.lanes(); + ae.alloc->dtype.bits() * ae.alloc->dtype.lanes(); if (src_entry->const_nbits == const_nbits && !inplace_found) { // successfully inplace dst_entry = src_entry; @@ -786,8 +758,7 @@ class StoragePlanRewriter : public StmtExprMutator { // enter/exit new scope if (s.stmt->IsInstance()) { const auto* op = static_cast(s.stmt); - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread || + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { PlanNewScope(op); } else { @@ -816,10 +787,8 @@ class StoragePlanRewriter : public StmtExprMutator { } } // Allocate new storage entry. - StorageEntry* NewAlloc(const AllocateNode* op, - const Object* attach_scope, - const StorageScope& scope, - size_t const_nbits) { + StorageEntry* NewAlloc(const AllocateNode* op, const Object* attach_scope, + const StorageScope& scope, size_t const_nbits) { CHECK(op != nullptr); // Re-use not successful, allocate a new buffer. std::unique_ptr entry(new StorageEntry()); @@ -832,23 +801,21 @@ class StoragePlanRewriter : public StmtExprMutator { return e; } - StorageEntry* FindAlloc(const AllocateNode* op, - const Object* attach_scope, + StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope, const StorageScope& scope) { CHECK(op != nullptr); // skip plan for local variable, // compiler can do a better job with register allocation. const uint64_t match_range = 16; uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); - uint64_t const_nbits = static_cast( - op->constant_allocation_size() * op_elem_bits); + uint64_t const_nbits = static_cast(op->constant_allocation_size() * op_elem_bits); // disable reuse of small arrays, they will be lowered to registers in LLVM // This rules only apply if we are using non special memory if (scope.tag.length() == 0) { if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) { return NewAlloc(op, attach_scope, scope, const_nbits); } - if (const_nbits > 0 && const_nbits <= 32) { + if (const_nbits > 0 && const_nbits <= 32) { return NewAlloc(op, attach_scope, scope, const_nbits); } } @@ -859,7 +826,7 @@ class StoragePlanRewriter : public StmtExprMutator { auto end = const_free_map_.upper_bound(const_nbits * match_range); // start looking at the buffer that is bigger than the required size first for (auto it = mid; it != end; ++it) { - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; // when not divided, no reuse, eg, float4 vs float3 @@ -871,7 +838,7 @@ class StoragePlanRewriter : public StmtExprMutator { // then start looking at smaller buffers. for (auto it = mid; it != begin;) { --it; - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; if (e->elem_type != op->dtype.element_of()) continue; @@ -881,8 +848,7 @@ class StoragePlanRewriter : public StmtExprMutator { } } else { // Simple strategy: round roubin. - for (auto it = sym_free_list_.begin(); - it != sym_free_list_.end(); ++it) { + for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) { StorageEntry* e = *it; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; @@ -904,8 +870,7 @@ class StoragePlanRewriter : public StmtExprMutator { // This rules only apply if we are using non special memory if (e->scope.tag.length() == 0) { // Disable sharing of local memory. - if (e->scope.rank >= StorageRank::kWarp || - e->allocs[0]->dtype.is_handle()) return; + if (e->scope.rank >= StorageRank::kWarp || e->allocs[0]->dtype.is_handle()) return; // disable reuse of small arrays if (e->const_nbits > 0 && e->const_nbits <= 32) return; } @@ -936,7 +901,6 @@ class StoragePlanRewriter : public StmtExprMutator { arith::Analyzer analyzer_; }; - // Turn alloc into vector alloc // if all its access is the same vector type. class VectorAllocRewriter : public StmtExprMutator { @@ -964,19 +928,15 @@ class VectorAllocRewriter : public StmtExprMutator { op = stmt.as(); const auto& tvec = acc_map_[op->buffer_var.get()]; - if (tvec.size() == 1 && - tvec[0].element_of() == op->dtype.element_of() && - tvec[0].lanes() % op->dtype.lanes() == 0 && - tvec[0].lanes() != op->dtype.lanes()) { + if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() && + tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() != op->dtype.lanes()) { int factor = tvec[0].lanes() / op->dtype.lanes(); Array extents = op->extents; arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]); if (me->base % factor == 0 && me->coeff % factor == 0) { extents.Set(extents.size() - 1, extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - return AllocateNode::make( - op->buffer_var, tvec[0], extents, - op->condition, op->body); + return AllocateNode::make(op->buffer_var, tvec[0], extents, op->condition, op->body); } } return stmt; @@ -1000,7 +960,6 @@ Stmt StorageRewrite(Stmt stmt) { return VectorAllocRewriter()(std::move(stmt)); } - PrimFunc PointerValueTypeRewrite(PrimFunc f) { auto* n = f.CopyOnWrite(); VectorAllocRewriter rewriter; @@ -1014,8 +973,7 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { const auto& tvec = rewriter.acc_map_[var.get()]; if (tvec.size() == 1) { - tir::Var new_var(var->name_hint, - PointerType(PrimType(tvec[0]))); + tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0]))); args.push_back(new_var); remap_vars.Set(var, new_var); @@ -1023,8 +981,7 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { // always set data type to be non vectorized so // load/store can still work via scalarization if (tvec.size() != 0 && !var->type_annotation.defined()) { - tir::Var new_var(var->name_hint, - PointerType(PrimType(tvec[0].with_lanes(1)))); + tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1)))); args.push_back(new_var); remap_vars.Set(var, new_var); } else { @@ -1042,7 +999,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { return f; } - namespace transform { Pass StorageRewrite() { @@ -1055,9 +1011,7 @@ Pass StorageRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } -TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite") -.set_body_typed(StorageRewrite); - +TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite); Pass PointerValueTypeRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -1067,7 +1021,7 @@ Pass PointerValueTypeRewrite() { } TVM_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite") -.set_body_typed(PointerValueTypeRewrite); + .set_body_typed(PointerValueTypeRewrite); } // namespace transform diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 9924dd2..8650d2c 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -21,17 +21,17 @@ * \brief Infer TensorCore metadata from tensor intrinsic. * \file tensorcore_fragment.cc */ +#include #include -#include #include -#include +#include #include #include -#include "storage_access.h" -#include "ir_util.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" +#include "storage_access.h" namespace tvm { namespace tir { @@ -47,7 +47,7 @@ class FragmentGetter : public StmtExprVisitor { std::string layout; FragmentInfo() = default; FragmentInfo(int _m, int _n, int _k, const std::string& _layout) - : m(_m), n(_n), k(_k), layout(_layout) {} + : m(_m), n(_n), k(_k), layout(_layout) {} }; void VisitExpr_(const CallNode* op) final { @@ -136,13 +136,12 @@ class FragmentGetter : public StmtExprVisitor { // Check shape of fragment making sure it is a valid shape for tvm_mma_sync class FragmentChecker : public StmtExprVisitor { public: - explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} + explicit FragmentChecker(const FragmentGetter& getter) : fragment_getter(getter) {} void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); // Check shape when calling tvm_mma_sync - if (op->is_intrinsic(intrinsic::tvm_mma_sync) || - op->is_intrinsic(intrinsic::tvm_bmma_sync)) { + if (op->is_intrinsic(intrinsic::tvm_mma_sync) || op->is_intrinsic(intrinsic::tvm_bmma_sync)) { CHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var_d = op->args[0].as(); const VarNode* buffer_var_a = op->args[2].as(); @@ -170,13 +169,13 @@ class FragmentChecker : public StmtExprVisitor { return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k; } // Fragment infomation - const FragmentGetter &fragment_getter; + const FragmentGetter& fragment_getter; }; // Store the metadata into attributes class InferFragmenter : public StmtMutator { public: - explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} + explicit InferFragmenter(const FragmentGetter& getter) : fragment_getter(getter) {} Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); @@ -186,15 +185,14 @@ class InferFragmenter : public StmtMutator { FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer); // Add shape attribute to all fragments - std::string shape = std::to_string(info.m) + ", " + - std::to_string(info.n) + ", " + - std::to_string(info.k); + std::string shape = + std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k); PrimExpr shape_expr = StringImmNode::make(shape); Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); if (info.layout != "") { // Add shape attribute to matrix_a and matrix_b Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout, - StringImmNode::make(info.layout), shape_attr); + StringImmNode::make(info.layout), shape_attr); return layout_attr; } else { return shape_attr; @@ -205,7 +203,7 @@ class InferFragmenter : public StmtMutator { private: // Fragment infomation - const FragmentGetter &fragment_getter; + const FragmentGetter& fragment_getter; }; Stmt InferFragment(Stmt stmt) { @@ -228,8 +226,7 @@ Pass InferFragment() { return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InferFragment") -.set_body_typed(InferFragment); +TVM_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index a32fd64..0379fd9 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -20,39 +20,35 @@ /*! * \file thread_storage_sync.cc */ -#include +#include #include -#include +#include #include #include -#include #include #include +#include "../../runtime/thread_storage_scope.h" #include "ir_util.h" #include "storage_access.h" -#include "../../runtime/thread_storage_scope.h" namespace tvm { namespace tir { class ThreadSyncPlanner : public StorageAccessVisitor { public: - explicit ThreadSyncPlanner(StorageScope sync_scope) - : sync_scope_(sync_scope) {} + explicit ThreadSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {} - // The syncs inserted before each statement + // The syncs inserted before each statement std::unordered_set syncs_inserted_; protected: - bool Enabled(const VarNode* buf, - const StorageScope& scope) const final { + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return in_device_env() && scope == sync_scope_; } // Plan the sync - std::vector Summarize( - std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const ForNode* loop) final { // Unsynced reads and writes std::vector reads; std::vector writes; @@ -70,19 +66,23 @@ class ThreadSyncPlanner : public StorageAccessVisitor { for (const AccessEntry& acc : s.access) { if (acc.type == kRead) { if (FindConflict(writes, acc, false)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kWrite) { if (FindConflict(reads, acc, false)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kSync) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } } // If sync is inserted. remove the irrelevant things. if (sync_before_stmt) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } // Add the read/write of current statement for (const AccessEntry& acc : s.access) { @@ -91,12 +91,12 @@ class ThreadSyncPlanner : public StorageAccessVisitor { } else if (acc.type == kWrite) { writes.push_back(acc); } else if (acc.type == kSync) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } } if (sync_before_stmt) { - CHECK_EQ(condition_counter(), 0) - << "Cannot insert syncs inside condition"; + CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; syncs_inserted_.insert(s.stmt); } } @@ -109,19 +109,21 @@ class ThreadSyncPlanner : public StorageAccessVisitor { for (const AccessEntry& acc : s.access) { if (acc.type == kRead) { if (FindConflict(writes, acc, true)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kWrite) { if (FindConflict(reads, acc, true)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kSync) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } } if (sync_before_stmt) { - CHECK_EQ(condition_counter(), 0) - << "Cannot insert syncs inside condition"; + CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; syncs_inserted_.insert(s.stmt); break; } @@ -174,22 +176,16 @@ class ThreadSyncPlanner : public StorageAccessVisitor { private: // find conflicting entry in vec. - bool FindConflict(const std::vector& vec, - const AccessEntry& e, - bool loop_carry) { + bool FindConflict(const std::vector& vec, const AccessEntry& e, bool loop_carry) { for (const AccessEntry& x : vec) { if (x.buffer.same_as(e.buffer)) { // Assumes no race between threads // Same index value means no conflicts // TODO(tqchen) more standard set based testing. - if (e.touched.is_single_point() && - x.touched.is_single_point()) { - if (ExprDeepEqual()(e.touched.point_value(), - x.touched.point_value())) continue; + if (e.touched.is_single_point() && x.touched.is_single_point()) { + if (ExprDeepEqual()(e.touched.point_value(), x.touched.point_value())) continue; } - if (x.double_buffer_write && - e.type == kRead && - !loop_carry) continue; + if (x.double_buffer_write && e.type == kRead && !loop_carry) continue; return true; } } @@ -203,8 +199,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { class ThreadSyncInserter : public StmtExprMutator { public: - ThreadSyncInserter(StorageScope sync_scope, - const std::unordered_set& syncs) + ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set& syncs) : sync_scope_(sync_scope), syncs_(syncs) {} Stmt VisitStmt(const Stmt& stmt) final { @@ -214,10 +209,9 @@ class ThreadSyncInserter : public StmtExprMutator { if (sync_scope_.rank == StorageRank::kGlobal) { barrier = MakeGlobalBarrier(); } else { - barrier = EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync_scope_.to_string())}, - CallNode::Intrinsic)); + barrier = EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImmNode::make(sync_scope_.to_string())}, + CallNode::Intrinsic)); } // Mutate after query, to avoid stmt change. auto ret = StmtExprMutator::VisitStmt(stmt); @@ -258,8 +252,7 @@ class ThreadSyncInserter : public StmtExprMutator { return ret; } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = - StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::make(op->value.as()->value); return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -316,13 +309,10 @@ class ThreadSyncInserter : public StmtExprMutator { } } rw_stats_.clear(); - Stmt kinit = EvaluateNode::make( - CallNode::make( - DataType::Int(32), - intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); + Stmt kinit = EvaluateNode::make(CallNode::make( + DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); body = SeqStmt({kinit, body}); - body = AttrStmtNode::make( - op->node, op->attr_key, op->value, body); + body = AttrStmtNode::make(op->node, op->attr_key, op->value, body); return SeqStmt({prep, body}); } Stmt MakeGlobalBarrier() { @@ -334,8 +324,7 @@ class ThreadSyncInserter : public StmtExprMutator { IterVar iv = Downcast(attr->node); runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag); if (s.rank == 0) { - num_blocks_ = (num_blocks_.defined() ? - attr->value * num_blocks_ : attr->value); + num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); } else if (s.rank == 1) { PrimExpr cond = iv->var == make_zero(iv->var.dtype()); is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond; @@ -346,9 +335,8 @@ class ThreadSyncInserter : public StmtExprMutator { } return EvaluateNode::make( CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync_scope_.to_string()), - is_lead_, num_blocks_}, - CallNode::Intrinsic)); + {StringImmNode::make(sync_scope_.to_string()), is_lead_, num_blocks_}, + CallNode::Intrinsic)); } // data structure. StorageScope sync_scope_; @@ -384,8 +372,7 @@ Pass ThreadSync(std::string storage_scope) { return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ThreadSync") -.set_body_typed(ThreadSync); +TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 4fc69a3..a69ccc5 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -22,32 +22,31 @@ * \file unroll_loop.cc */ // Unrolls the loop as in Halide pipeline. +#include #include #include #include -#include #include -#include -#include +#include + #include +#include #include -#include "ir_util.h" + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { class LoopUnroller : public StmtExprMutator { public: - explicit LoopUnroller(int auto_max_step, - int auto_max_depth, - int auto_max_extent, + explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll) { - } + explicit_unroll_(explicit_unroll) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { @@ -72,24 +71,19 @@ class LoopUnroller : public StmtExprMutator { op = stmt.as(); int value = GetExtent(op); // condition for auto unroll - bool auto_unroll = ( - op->for_type == ForType::Serial && - value >= 0 && - normal_loop_depth_ == 0 && - unroll_depth_ <= auto_max_depth_); + bool auto_unroll = (op->for_type == ForType::Serial && value >= 0 && normal_loop_depth_ == 0 && + unroll_depth_ <= auto_max_depth_); - auto_unroll = auto_unroll && ( - value * step_count_ <= auto_max_step_|| - value <= auto_max_extent_); + auto_unroll = + auto_unroll && (value * step_count_ <= auto_max_step_ || value <= auto_max_extent_); if (op->for_type == ForType::Unrolled) { - CHECK_GE(value, 0) - << "Cannot unroll non-constant loop"; + CHECK_GE(value, 0) << "Cannot unroll non-constant loop"; auto_unroll = true; } if (auto_unroll) { - step_count_ *= value; + step_count_ *= value; unroll_depth_ += 1; } else { normal_loop_depth_ += 1; @@ -102,9 +96,8 @@ class LoopUnroller : public StmtExprMutator { } else { if (auto_unroll) { if (op->for_type != ForType::Unrolled) { - return ForNode::make( - op->loop_var, op->min, op->extent, - ForType::Unrolled, op->device_api, op->body); + return ForNode::make(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api, + op->body); } } return stmt; @@ -159,7 +152,7 @@ class LoopUnroller : public StmtExprMutator { int GetExtent(const ForNode* op) { // constant folding. PrimExpr extent = analyzer_.Simplify(op->extent); - const IntImmNode *v1 = extent.as(); + const IntImmNode* v1 = extent.as(); int value = -1; // integers that do not fit in int32_t are treated as symbolic, // as it's impossible to unroll such large loops @@ -186,17 +179,9 @@ class LoopUnroller : public StmtExprMutator { arith::Analyzer analyzer_; }; - -Stmt UnrollLoop(Stmt stmt, - int auto_max_step, - int auto_max_depth, - int auto_max_extent, +Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll) { - Stmt ret = LoopUnroller( - auto_max_step, - auto_max_depth, - auto_max_extent, - explicit_unroll)(stmt); + Stmt ret = LoopUnroller(auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { @@ -206,24 +191,17 @@ Stmt UnrollLoop(Stmt stmt, namespace transform { -Pass UnrollLoop(int auto_max_step, - int auto_max_depth, - int auto_max_extent, - bool explicit_unroll) { +Pass UnrollLoop(int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - n->body = UnrollLoop(std::move(f->body), - auto_max_step, - auto_max_depth, - auto_max_extent, + n->body = UnrollLoop(std::move(f->body), auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); } -TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop") -.set_body_typed(UnrollLoop); +TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop); } // namespace transform diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index e155c70..9e553cb 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -21,14 +21,16 @@ * \file vectorize_loop.cc */ // Loop vectorizer as in Halide pipeline. +#include #include #include -#include #include -#include -#include +#include + #include +#include #include + #include "../../arith/compute_expr.h" namespace tvm { @@ -41,9 +43,8 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { return BroadcastNode::make(op->value, lanes); } } - CHECK_EQ(e.dtype().lanes(), 1) - << "Cannot broadcast lane=" << e.dtype().lanes() - << " to " << lanes; + CHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to " + << lanes; return BroadcastNode::make(e, lanes); } @@ -64,9 +65,8 @@ class VecAllocAccess : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->buffer_var.get() == buf_) { - return LoadNode::make(op->dtype, op->buffer_var, - op->index * var_lanes_ + var_, - op->predicate); + return LoadNode::make(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, + op->predicate); } else { return expr; } @@ -76,10 +76,8 @@ class VecAllocAccess : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op->buffer_var.get() == buf_) { - return StoreNode::make(op->buffer_var, - op->value, - op->index * var_lanes_ + var_, - op->predicate); + return StoreNode::make(op->buffer_var, op->value, op->index * var_lanes_ + var_, + op->predicate); } else { return stmt; } @@ -96,8 +94,7 @@ class VecAllocAccess : public StmtExprMutator { class Vectorizer : public StmtExprMutator { public: - Vectorizer(Var var, int var_lanes) - : var_(var), var_lanes_(var_lanes) { + Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) { ramp_ = RampNode::make(0, 1, var_lanes); } @@ -112,17 +109,12 @@ class Vectorizer : public StmtExprMutator { } } - PrimExpr VisitExpr_(const AddNode* op) final { - return AddSubVec(op); - } - PrimExpr VisitExpr_(const SubNode* op) final { - return AddSubVec(op); - } + PrimExpr VisitExpr_(const AddNode* op) final { return AddSubVec(op); } + PrimExpr VisitExpr_(const SubNode* op) final { return AddSubVec(op); } PrimExpr VisitExpr_(const MulNode* op) final { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { + if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); @@ -130,60 +122,30 @@ class Vectorizer : public StmtExprMutator { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) { - return RampNode::make( - a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); + return RampNode::make(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); } if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) { - return RampNode::make( - b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); + return RampNode::make(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); } } return MulNode::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } return BinaryVec(op); } - PrimExpr VisitExpr_(const DivNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const ModNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const FloorDivNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const FloorModNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const MinNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const MaxNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const EQNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const NENode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const LTNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const LENode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const GTNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const GENode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const AndNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const OrNode* op) final { - return BinaryVec(op); - } + PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec(op); } PrimExpr VisitExpr_(const RampNode* op) final { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); @@ -198,29 +160,23 @@ class Vectorizer : public StmtExprMutator { stride = BroadcastTo(stride, lanes); Array elems; for (int i = 0; i < lanes; ++i) { - elems.push_back( - RampNode::make(ShuffleNode::make_extract_element(base, i), - ShuffleNode::make_extract_element(stride, i), - op->lanes)); + elems.push_back(RampNode::make(ShuffleNode::make_extract_element(base, i), + ShuffleNode::make_extract_element(stride, i), op->lanes)); } return ShuffleNode::make_concat(elems); } - PrimExpr VisitExpr_(const SelectNode *op) final { + PrimExpr VisitExpr_(const SelectNode* op) final { PrimExpr cond = this->VisitExpr(op->condition); PrimExpr t = this->VisitExpr(op->true_value); PrimExpr f = this->VisitExpr(op->false_value); - if (cond.same_as(op->condition) && - t.same_as(op->true_value) && - f.same_as(op->false_value)) { + if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { return GetRef(op); } else { - int lanes = std::max(std::max( - cond.dtype().lanes(), - t.dtype().lanes()), f.dtype().lanes()); + int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes()); return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); } } - PrimExpr VisitExpr_(const CastNode *op) final { + PrimExpr VisitExpr_(const CastNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); @@ -233,31 +189,28 @@ class Vectorizer : public StmtExprMutator { if (v == var_.get()) { return ramp_; } else if (lets_.count(v)) { - return lets_[v]; + return lets_[v]; } else { return GetRef(v); } } // IfThenElse expr - PrimExpr MutateIfThenElseExpr_(const CallNode *op) { + PrimExpr MutateIfThenElseExpr_(const CallNode* op) { PrimExpr cond = this->VisitExpr(op->args[0]); - if (cond.dtype().is_vector()) { + if (cond.dtype().is_vector()) { need_scalarize_ = true; return GetRef(op); } PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr f = this->VisitExpr(op->args[2]); - if (cond.same_as(op->args[0]) && - t.same_as(op->args[1]) && - f.same_as(op->args[2])) { + if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { return GetRef(op); } else { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); f = BroadcastTo(f, lanes); - return CallNode::make( - op->dtype.with_lanes(lanes), op->name, - {cond, t, f}, op->call_type, op->func, op->value_index); + return CallNode::make(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type, + op->func, op->value_index); } } // Call @@ -279,8 +232,8 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make( - op->dtype, op->name, new_args, op->call_type, op->func, op->value_index); + return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func, + op->value_index); } } else { int lane = 0; @@ -289,9 +242,8 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make( - op->dtype.with_lanes(lane), op->name, new_args, - op->call_type, op->func, op->value_index); + return CallNode::make(op->dtype.with_lanes(lane), op->name, new_args, op->call_type, + op->func, op->value_index); } } } @@ -303,11 +255,8 @@ class Vectorizer : public StmtExprMutator { return GetRef(op); } else { int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes()); - return LoadNode::make( - op->dtype.with_lanes(lanes), - op->buffer_var, - BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + return LoadNode::make(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes), + BroadcastTo(pred, lanes)); } } // Let @@ -320,8 +269,7 @@ class Vectorizer : public StmtExprMutator { return LetNode::make(v, value, this->VisitExpr(op->body)); } else { PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); @@ -350,10 +298,8 @@ class Vectorizer : public StmtExprMutator { } else { int lanes = std::max(value.dtype().lanes(), index.dtype().lanes()); lanes = std::max(lanes, pred.dtype().lanes()); - return StoreNode::make(op->buffer_var, - BroadcastTo(value, lanes), - BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + return StoreNode::make(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes), + BroadcastTo(pred, lanes)); } } // For @@ -368,13 +314,10 @@ class Vectorizer : public StmtExprMutator { return Scalarize(GetRef(op)); } Stmt body = this->VisitStmt(op->body); - if (extent.same_as(op->extent) && - body.same_as(op->body)) { + if (extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { - return ForNode::make( - op->loop_var, op->min, extent, - op->for_type, op->device_api, body); + return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body); } } // IfThenElse @@ -389,8 +332,7 @@ class Vectorizer : public StmtExprMutator { if (op->else_case.defined()) { else_case = this->VisitStmt(op->else_case); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { @@ -421,12 +363,9 @@ class Vectorizer : public StmtExprMutator { // place the vector lanes in least significant dimension. extents.push_back(var_lanes_); // rewrite access to buffer internally. - Stmt body = VecAllocAccess( - op->buffer_var.get(), var_, var_lanes_)(op->body); + Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); body = this->VisitStmt(body); - return AllocateNode::make( - op->buffer_var, op->dtype, - extents, condition, body); + return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body); } // scalarize the statment Stmt Scalarize(Stmt stmt) { @@ -473,24 +412,22 @@ class Vectorizer : public StmtExprMutator { if (!changed) return arr; return Array(new_arr); } - template + template PrimExpr BinaryVec(const T* op) { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { + if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } - template + template PrimExpr AddSubVec(const T* op) { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { + if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); @@ -500,12 +437,10 @@ class Vectorizer : public StmtExprMutator { if (a.dtype().lanes() == 1 && b_ramp) { return RampNode::make( arith::Compute(a, b_ramp->base), - arith::Compute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), - b_ramp->lanes); + arith::Compute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes); } if (b.dtype().lanes() == 1 && a_ramp) { - return RampNode::make( - arith::Compute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); + return RampNode::make(arith::Compute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); } } return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); @@ -529,9 +464,7 @@ class LoopVectorizer : public StmtMutator { } }; -Stmt VectorizeLoop(Stmt stmt) { - return LoopVectorizer()(std::move(stmt)); -} +Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); } class VectorizeSkipper : public StmtMutator { public: @@ -539,18 +472,15 @@ class VectorizeSkipper : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (op->for_type == ForType::Vectorized) { - return ForNode::make(op->loop_var, op->min, op->extent, - ForType::Serial, op->device_api, + return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body); } else { - return stmt; + return stmt; } } }; -Stmt SkipVectorize(Stmt stmt) { - return VectorizeSkipper()(std::move(stmt)); -} +Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); } namespace transform { @@ -568,8 +498,7 @@ Pass VectorizeLoop(bool enable_vectorize) { return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop") -.set_body_typed(VectorizeLoop); +TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop); } // namespace transform diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index f4c259f..8e9d7bc 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -25,7 +25,7 @@ TEST(Simplify, MinMax) { tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); - auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)) ; + auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)); auto e1s = ana.canonical_simplify(e1); CHECK(tvm::tir::is_zero(e1s)); @@ -37,7 +37,7 @@ TEST(Simplify, MinMax) { TEST(Simplify, Mul) { tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); - auto e = (x * x) - (x * x) ; + auto e = (x * x) - (x * x); auto es = ana.canonical_simplify(e); CHECK(tvm::tir::is_zero(es)); } @@ -53,7 +53,7 @@ TEST(Simplify, Mod) { auto es = ana.canonical_simplify(mod - x); CHECK(tvm::tir::is_zero(es)); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/attrs_test.cc b/tests/cpp/attrs_test.cc index ccf1b25..7b301bd 100644 --- a/tests/cpp/attrs_test.cc +++ b/tests/cpp/attrs_test.cc @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include namespace tvm { namespace test { @@ -33,23 +33,17 @@ struct TestAttrs : public AttrsNode { double learning_rate; TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") { - TVM_ATTR_FIELD(axis) - .set_default(10) - .set_lower_bound(1) - .set_upper_bound(10) - .describe("axis field"); - TVM_ATTR_FIELD(name) - .describe("name of the field"); + TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe( + "axis field"); + TVM_ATTR_FIELD(name).describe("name of the field"); TVM_ATTR_FIELD(expr) .describe("expression field") .set_default(tir::make_const(DataType::Int(32), 1)); - TVM_ATTR_FIELD(learning_rate) - .describe("learning_rate") - .set_default(0.1); + TVM_ATTR_FIELD(learning_rate).describe("learning_rate").set_default(0.1); } }; -} -} +} // namespace test +} // namespace tvm TEST(Attrs, Basic) { using namespace tvm; @@ -84,12 +78,11 @@ TEST(Attrs, Basic) { // Check docstring std::ostringstream os; n->PrintDocString(os); - LOG(INFO) << "docstring\n"<< os.str(); + LOG(INFO) << "docstring\n" << os.str(); CHECK(os.str().find("expr : PrimExpr, default=1") != std::string::npos); } - -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 6ea0b21..c9a91fc 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -20,12 +20,12 @@ #include #include #include -#include -#include #include +#include +#include -#include #include +#include TEST(BuildModule, Basic) { using namespace tvm; @@ -37,18 +37,17 @@ TEST(BuildModule, Basic) { auto A = placeholder(shape, DataType::Float(32), "A"); auto B = placeholder(shape, DataType::Float(32), "B"); - auto C = compute(A->shape, [&A, &B](PrimExpr i) { - return A[i] + B[i]; - }, "C"); + auto C = compute( + A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "C"); - auto s = create_schedule({ C->op }); + auto s = create_schedule({C->op}); auto cAxis = C->op.as()->axis; IterVar bx, tx; s[C].split(cAxis[0], 64, &bx, &tx); - auto args = Array({ A, B, C }); + auto args = Array({A, B, C}); std::unordered_map binds; auto config = BuildConfig::Create(); @@ -94,19 +93,16 @@ TEST(BuildModule, Heterogeneous) { auto B = placeholder(shape, DataType::Float(32), "B"); auto C = placeholder(shape, DataType::Float(32), "C"); - auto elemwise_add = compute(A->shape, [&A, &B](PrimExpr i) { - return A[i] + B[i]; - }, "elemwise_add"); + auto elemwise_add = compute( + A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "elemwise_add"); auto copy = placeholder(shape, DataType::Float(32), "__copy"); - auto elemwise_sub = compute(C->shape, [©, &C](PrimExpr i) { - return copy[i] - C[i]; - }, "elemwise_sub"); + auto elemwise_sub = compute( + C->shape, [©, &C](PrimExpr i) { return copy[i] - C[i]; }, "elemwise_sub"); With cuda_scope(target_cuda); auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); - With llvm_scope(target_llvm); auto s2 = create_schedule({elemwise_sub->op}); @@ -117,8 +113,7 @@ TEST(BuildModule, Heterogeneous) { std::unordered_map binds; auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config); auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config); - Map inputs = {{target_cuda, lowered_s1}, - {target_llvm, lowered_s2}}; + Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target(), config); // Assertion for build. @@ -148,12 +143,9 @@ TEST(BuildModule, Heterogeneous) { "\"float32\"]]}}"; // Setup inputs. - auto a_val = - runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto b_val = - runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto c_val = - runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto a_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto b_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto c_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto pa = (float*)(a_val->data); auto pb = (float*)(b_val->data); @@ -174,8 +166,8 @@ TEST(BuildModule, Heterogeneous) { const runtime::PackedFunc* graph_runtime = tvm::runtime::Registry::Get("tvm.graph_runtime.create"); - runtime::Module mod = (*graph_runtime)( - json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id); + runtime::Module mod = + (*graph_runtime)(json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id); // test FFI for module. auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) { @@ -186,7 +178,6 @@ TEST(BuildModule, Heterogeneous) { test_ffi(runtime::Module(mod), static_cast(kTVMModuleHandle)); test_ffi(Optional(mod), static_cast(kTVMModuleHandle)); - PackedFunc set_input = mod.GetFunction("set_input", false); PackedFunc run = mod.GetFunction("run", false); PackedFunc get_output = mod.GetFunction("get_output", false); @@ -204,7 +195,7 @@ TEST(BuildModule, Heterogeneous) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index c89f815..5d1f472 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include #include #include @@ -35,8 +35,7 @@ class TestErrorSwitch { public: // Need this so that destructor of temporary objects don't interrupt our // testing. - TestErrorSwitch(const TestErrorSwitch& other) - : should_fail(other.should_fail) { + TestErrorSwitch(const TestErrorSwitch& other) : should_fail(other.should_fail) { const_cast(other).should_fail = false; } @@ -50,8 +49,7 @@ class TestErrorSwitch { } }; -class TestArrayObj : public Object, - public InplaceArrayBase { +class TestArrayObj : public Object, public InplaceArrayBase { public: static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "test.TestArrayObj"; @@ -112,8 +110,7 @@ TEST(InplaceArrayBase, BadExceptionSafety) { TestErrorSwitch f2{true}; TestErrorSwitch f3{false}; std::vector fields{f1, f2, f3}; - auto ptr = - make_inplace_array_object(fields.size()); + auto ptr = make_inplace_array_object(fields.size()); try { ptr->WrongInit(fields.begin(), fields.end()); } catch (...) { @@ -133,8 +130,7 @@ TEST(InplaceArrayBase, ExceptionSafety) { // since it's not initalized. TestErrorSwitch f2{true}; std::vector fields{f1, f2}; - auto ptr = - make_inplace_array_object(fields.size()); + auto ptr = make_inplace_array_object(fields.size()); try { ptr->Init(fields.begin(), fields.end()); } catch (...) { @@ -223,8 +219,7 @@ TEST(Map, Iterator) { using namespace tvm; PrimExpr a = 1, b = 2; Map map1{{a, b}}; - std::unordered_map map2( - map1.begin(), map1.end()); + std::unordered_map map2(map1.begin(), map1.end()); CHECK(map2[a].as()->value == 2); } @@ -402,7 +397,6 @@ TEST(String, Cast) { String s2 = Downcast(r); } - TEST(Optional, Composition) { Optional opt0(nullptr); Optional opt1 = String("xyz"); diff --git a/tests/cpp/crt_memory_test.cc b/tests/cpp/crt_memory_test.cc index 1c12916..c2582ba 100644 --- a/tests/cpp/crt_memory_test.cc +++ b/tests/cpp/crt_memory_test.cc @@ -27,7 +27,7 @@ TEST(CRTMemory, Alloc) { for (int idx = 0; idx < 65536; idx++) { - void * a = vmalloc(1); + void* a = vmalloc(1); EXPECT_EQ(vleak_size, 1); vfree(a); EXPECT_EQ(vleak_size, 0); @@ -36,9 +36,9 @@ TEST(CRTMemory, Alloc) { TEST(CRTMemory, Realloc) { for (int idx = 0; idx < 65536; idx++) { - void * a = vrealloc(0, 1); + void* a = vrealloc(0, 1); EXPECT_EQ(vleak_size, 1); - void * b = vrealloc(a, 1); + void* b = vrealloc(a, 1); EXPECT_EQ(a, b); EXPECT_EQ(vleak_size, 1); vfree(a); @@ -46,7 +46,7 @@ TEST(CRTMemory, Realloc) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index e17cc73..a5d47dd 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -34,7 +34,6 @@ TEST(Expr, Basic) { CHECK(os.str() == "max(((x + 1) + 2), 100)"); } - TEST(ExprNodeRef, Basic) { using namespace tvm; using namespace tvm::tir; @@ -44,8 +43,7 @@ TEST(ExprNodeRef, Basic) { CHECK(GetRef(op).same_as(z)); } - -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 3941de5..052cba1 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -19,10 +19,10 @@ #include #include -#include -#include #include +#include #include +#include #include TEST(IRF, Basic) { @@ -32,14 +32,10 @@ TEST(IRF, Basic) { auto z = x + 1; NodeFunctor f; - f.set_dispatch([](const ObjectRef& n, int b) { - return b; - }); - f.set_dispatch([](const ObjectRef& n, int b) { - return b + 2; - }); - CHECK_EQ(f(x, 2), 2); - CHECK_EQ(f(z, 2), 4); + f.set_dispatch([](const ObjectRef& n, int b) { return b; }); + f.set_dispatch([](const ObjectRef& n, int b) { return b + 2; }); + CHECK_EQ(f(x, 2), 2); + CHECK_EQ(f(z, 2), 4); } TEST(IRF, CountVar) { @@ -51,37 +47,31 @@ TEST(IRF, CountVar) { auto z = x + 1 + y + y; tir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { if (n.as()) ++n_var; - }); + }); CHECK_EQ(n_var, 2); } - TEST(IRF, ExprTransform) { using namespace tvm; using namespace tvm::tir; Var x("x"); auto z = x + 1; - class MyExprFunctor - : public tir::ExprFunctor { + class MyExprFunctor : public tir::ExprFunctor { public: - int VisitExpr_(const VarNode* op, int b) final { - return b; - } - int VisitExpr_(const IntImmNode* op, int b) final { - return op->value; - } + int VisitExpr_(const VarNode* op, int b) final { return b; } + int VisitExpr_(const IntImmNode* op, int b) final { return op->value; } int VisitExpr_(const AddNode* op, int b) final { return VisitExpr(op->a, b) + VisitExpr(op->b, b); } }; MyExprFunctor f; - CHECK_EQ(f(x, 2), 2); - CHECK_EQ(f(z, 2), 3); + CHECK_EQ(f(x, 2), 2); + CHECK_EQ(f(z, 2), 3); try { f(z - 1, 2); LOG(FATAL) << "should fail"; - } catch(dmlc::Error) { + } catch (dmlc::Error) { } } @@ -91,43 +81,33 @@ TEST(IRF, ExprVisit) { Var x("x"); auto z = x + 1; - class MyVisitor - : public tir::ExprFunctor, - public tir::StmtFunctor { + class MyVisitor : public tir::ExprFunctor, + public tir::StmtFunctor { public: int count = 0; // implementation - void VisitExpr_(const VarNode* op) final { - ++count; - } - void VisitExpr_(const IntImmNode* op) final { - } + void VisitExpr_(const VarNode* op) final { ++count; } + void VisitExpr_(const IntImmNode* op) final {} void VisitExpr_(const AddNode* op) final { VisitExpr(op->a); VisitExpr(op->b); } - void VisitStmt_(const EvaluateNode* op) final { - VisitExpr(op->value); - } + void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); } }; MyVisitor v; v.VisitStmt(EvaluateNode::make(z)); CHECK_EQ(v.count, 1); } - TEST(IRF, StmtVisitor) { using namespace tvm; using namespace tvm::tir; Var x("x"); - class MyVisitor - : public StmtExprVisitor { + class MyVisitor : public StmtExprVisitor { public: int count = 0; // implementation - void VisitExpr_(const VarNode* op) final { - ++count; - } + void VisitExpr_(const VarNode* op) final { ++count; } }; MyVisitor v; auto fmaketest = [&]() { @@ -145,24 +125,16 @@ TEST(IRF, StmtMutator) { using namespace tvm::tir; Var x("x"); - class MyVisitor - : public tir::StmtMutator, - public tir::ExprMutator { + class MyVisitor : public tir::StmtMutator, public tir::ExprMutator { public: using StmtMutator::operator(); using ExprMutator::operator(); protected: // implementation - PrimExpr VisitExpr_(const AddNode* op) final { - return op->a; - } - Stmt VisitStmt_(const SeqStmtNode* op) final { - return StmtMutator::VisitSeqStmt_(op, true); - } - PrimExpr VisitExpr(const PrimExpr& expr) final { - return ExprMutator::VisitExpr(expr); - } + PrimExpr VisitExpr_(const AddNode* op) final { return op->a; } + Stmt VisitStmt_(const SeqStmtNode* op) final { return StmtMutator::VisitSeqStmt_(op, true); } + PrimExpr VisitExpr(const PrimExpr& expr) final { return ExprMutator::VisitExpr(expr); } }; auto fmakealloc = [&]() { auto z = x + 1; @@ -220,7 +192,8 @@ TEST(IRF, StmtMutator) { } { - auto body = EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); + auto body = + EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); auto res = v(std::move(body)); CHECK(res.as()->value.as()->args[0].same_as(x)); } @@ -255,7 +228,7 @@ TEST(IRF, StmtMutator) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/object_protocol_test.cc b/tests/cpp/object_protocol_test.cc index 438f688..0df8024 100644 --- a/tests/cpp/object_protocol_test.cc +++ b/tests/cpp/object_protocol_test.cc @@ -19,8 +19,8 @@ #include #include -#include #include +#include namespace tvm { namespace test { @@ -59,7 +59,6 @@ class ObjAA : public ObjA { TVM_DECLARE_FINAL_OBJECT_INFO(ObjAA, ObjA); }; - TVM_REGISTER_OBJECT_TYPE(ObjBase); TVM_REGISTER_OBJECT_TYPE(ObjA); TVM_REGISTER_OBJECT_TYPE(ObjB); @@ -97,7 +96,7 @@ TEST(ObjectHierachy, Basic) { CHECK(refB.as() != nullptr); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 787e0c4..523df98 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -19,11 +19,11 @@ #include #include -#include #include +#include #include -#include #include +#include TEST(PackedFunc, Basic) { using namespace tvm; @@ -34,15 +34,15 @@ TEST(PackedFunc, Basic) { DLTensor a; Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 3); - CHECK(args.values[0].v_float64 == 1.0); - CHECK(args.type_codes[0] == kDLFloat); - CHECK(args.values[1].v_handle == &a); - CHECK(args.type_codes[1] == kTVMDLTensorHandle); - CHECK(args.values[2].v_handle == &x); - CHECK(args.type_codes[2] == kTVMOpaqueHandle); - *rv = Var("a"); - })(1.0, &a, handle); + CHECK(args.num_args == 3); + CHECK(args.values[0].v_float64 == 1.0); + CHECK(args.type_codes[0] == kDLFloat); + CHECK(args.values[1].v_handle == &a); + CHECK(args.type_codes[1] == kTVMDLTensorHandle); + CHECK(args.values[2].v_handle == &x); + CHECK(args.type_codes[2] == kTVMOpaqueHandle); + *rv = Var("a"); + })(1.0, &a, handle); CHECK(v->name_hint == "a"); } @@ -52,36 +52,32 @@ TEST(PackedFunc, Node) { using namespace tvm::runtime; Var x; Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 1); - CHECK(args[0].IsObjectRef()); - Var b = args[0]; - CHECK(x.same_as(b)); - *rv = b; - })(x); + CHECK(args.num_args == 1); + CHECK(args[0].IsObjectRef()); + Var b = args[0]; + CHECK(x.same_as(b)); + *rv = b; + })(x); CHECK(t.same_as(x)); } TEST(PackedFunc, NDArray) { using namespace tvm; using namespace tvm::runtime; - auto x = NDArray::Empty( - {}, String2DLDataType("float32"), - TVMContext{kDLCPU, 0}); + auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0}); reinterpret_cast(x->data)[0] = 10.0f; CHECK(x.use_count() == 1); - PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { - *rv = args[0]; - }); + PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); NDArray ret = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - NDArray y = args[0]; - DLTensor* ptr = args[0]; - CHECK(ptr == x.operator->()); - CHECK(x.same_as(y)); - CHECK(x.use_count() == 2); - *rv = forward(y); - })(x); + NDArray y = args[0]; + DLTensor* ptr = args[0]; + CHECK(ptr == x.operator->()); + CHECK(x.same_as(y)); + CHECK(x.use_count() == 2); + *rv = forward(y); + })(x); CHECK(ret.use_count() == 2); CHECK(ret.same_as(x)); } @@ -90,48 +86,45 @@ TEST(PackedFunc, str) { using namespace tvm; using namespace tvm::runtime; PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 1); - std::string x = args[0]; - CHECK(x == "hello"); - String y = args[0]; - CHECK(y == "hello"); - *rv = x; - })("hello"); + CHECK(args.num_args == 1); + std::string x = args[0]; + CHECK(x == "hello"); + String y = args[0]; + CHECK(y == "hello"); + *rv = x; + })("hello"); PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 1); - runtime::String s = args[0]; - CHECK(s == "hello"); + CHECK(args.num_args == 1); + runtime::String s = args[0]; + CHECK(s == "hello"); })(runtime::String("hello")); } - TEST(PackedFunc, func) { using namespace tvm; using namespace tvm::runtime; - PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { - *rv = args[0].operator int() + 1; - }); + PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0].operator int() + 1; }); // function as arguments int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - PackedFunc f = args[0]; - // TVMArgValue -> Arguments as function - *rv = f(args[1]).operator int(); - })(addone, 1); + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); CHECK_EQ(r0, 2); int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - // TVMArgValue -> TVMRetValue - *rv = args[1]; - })(2, 100); + // TVMArgValue -> TVMRetValue + *rv = args[1]; + })(2, 100); CHECK_EQ(r1, 100); int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - // re-assignment - *rv = args[0]; - // TVMRetValue -> Function argument - *rv = addone(args[0].operator PackedFunc()(args[1], 1)); - })(addone, 100); + // re-assignment + *rv = args[0]; + // TVMRetValue -> Function argument + *rv = addone(args[0].operator PackedFunc()(args[1], 1)); + })(addone, 100); CHECK_EQ(r2, 102); } @@ -140,14 +133,14 @@ TEST(PackedFunc, Expr) { using namespace tvm::runtime; // automatic conversion of int to expr PackedFunc addone([](TVMArgs args, TVMRetValue* rv) { - PrimExpr x = args[0]; - *rv = x.as()->value + 1; + PrimExpr x = args[0]; + *rv = x.as()->value + 1; }); int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - PackedFunc f = args[0]; - // TVMArgValue -> Arguments as function - *rv = f(args[1]).operator int(); - })(addone, 1); + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); CHECK_EQ(r0, 2); } @@ -155,12 +148,10 @@ TEST(PackedFunc, Type) { using namespace tvm; using namespace tvm::runtime; auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - DataType x = args[0]; - *rv = x; - }); - auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - *rv = args[0]; - }); + DataType x = args[0]; + *rv = x; + }); + auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); CHECK(get_type("int32").operator DataType() == DataType::Int(32)); CHECK(get_type("float").operator DataType() == DataType::Float(32)); CHECK(get_type2("float32x2").operator DataType() == DataType::Float(32, 2)); @@ -174,9 +165,7 @@ TEST(TypedPackedFunc, HighOrder) { using BindFunc = TypedPackedFunc; BindFunc ftyped; ftyped = [](Int2Func f1, int value) -> Int1Func { - auto binded = [f1, value](int x) { - return f1(value, x); - }; + auto binded = [f1, value](int x) { return f1(value, x); }; Int1Func x(binded); return x; }; @@ -194,28 +183,23 @@ TEST(TypedPackedFunc, Deduce) { using tvm::runtime::detail::function_signature; TypedPackedFunc x; - auto f = [](int x) -> int { - return x + 1; - }; + auto f = [](int x) -> int { return x + 1; }; std::function y; - static_assert(std::is_same::FType, - int(float)>::value, "invariant1"); - static_assert(std::is_same::FType, - int(int)>::value, "invariant2"); - static_assert(std::is_same::FType, - void(float)>::value, "invariant3"); + static_assert(std::is_same::FType, int(float)>::value, + "invariant1"); + static_assert(std::is_same::FType, int(int)>::value, + "invariant2"); + static_assert(std::is_same::FType, void(float)>::value, + "invariant3"); } - TEST(PackedFunc, ObjectConversion) { using namespace tvm; using namespace tvm::tir; using namespace tvm::runtime; TVMRetValue rv; - auto x = NDArray::Empty( - {}, String2DLDataType("float32"), - TVMContext{kDLCPU, 0}); + auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0}); // assign null rv = ObjectRef(); CHECK_EQ(rv.type_code(), kTVMNullptr); @@ -232,15 +216,15 @@ TEST(PackedFunc, ObjectConversion) { CHECK(!rv.IsObjectRef()); auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle); - CHECK(args[0].operator NDArray().same_as(x)); - CHECK(args[0].operator ObjectRef().same_as(x)); - CHECK(args[1].operator ObjectRef().get() == nullptr); - CHECK(args[1].operator NDArray().get() == nullptr); - CHECK(args[1].operator Module().get() == nullptr); - CHECK(args[1].operator Array().get() == nullptr); - CHECK(!args[0].IsObjectRef()); - }); + CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle); + CHECK(args[0].operator NDArray().same_as(x)); + CHECK(args[0].operator ObjectRef().same_as(x)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(args[1].operator Array().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); pf1(x, ObjectRef()); pf1(ObjectRef(x), NDArray()); @@ -259,14 +243,14 @@ TEST(PackedFunc, ObjectConversion) { CHECK(!rv.IsObjectRef()); auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args[0].type_code(), kTVMModuleHandle); - CHECK(args[0].operator Module().same_as(m)); - CHECK(args[0].operator ObjectRef().same_as(m)); - CHECK(args[1].operator ObjectRef().get() == nullptr); - CHECK(args[1].operator NDArray().get() == nullptr); - CHECK(args[1].operator Module().get() == nullptr); - CHECK(!args[0].IsObjectRef()); - }); + CHECK_EQ(args[0].type_code(), kTVMModuleHandle); + CHECK(args[0].operator Module().same_as(m)); + CHECK(args[0].operator ObjectRef().same_as(m)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); pf2(m, ObjectRef()); pf2(ObjectRef(m), Module()); } @@ -275,13 +259,12 @@ TEST(TypedPackedFunc, RValue) { using namespace tvm; using namespace tvm::runtime; { - auto inspect = [](TVMArgs args, TVMRetValue* rv) { for (int i = 0; i < args.size(); ++i) { CHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg); } }; - PackedFunc finspect(inspect); + PackedFunc finspect(inspect); finspect(tir::Var("x")); } { @@ -325,7 +308,7 @@ TEST(TypedPackedFunc, RValue) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 5cb7910..59d0a43 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -17,9 +17,10 @@ * under the License. */ +#include "../src/arith/pattern_match.h" + #include #include -#include "../src/arith/pattern_match.h" TEST(Pattern, Basic) { using namespace tvm; @@ -64,8 +65,7 @@ TEST(Pattern, Basic) { CHECK((px >= py && px < pz).Match(x >= y && x < z)); CHECK((!(px > py || px != py)).Match(!(x > y || x != y))); { - CHECK(select(px >= pz, py, py + pz).Match( - tir::SelectNode::make((x + 1) >= 1, y, y + 1))); + CHECK(select(px >= pz, py, py + pz).Match(tir::SelectNode::make((x + 1) >= 1, y, y + 1))); CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } // bit intrinsics @@ -81,52 +81,44 @@ TEST(Pattern, Basic) { CHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2))))); // select { - CHECK(select(px > pz, py, py + pz).Match( - tir::SelectNode::make(x > 1, y, y + 1))); + CHECK(select(px > pz, py, py + pz).Match(tir::SelectNode::make(x > 1, y, y + 1))); CHECK(is_const_int(pz.Eval(), 1)); } - CHECK(!select(px > pz, py, py + pz).Match( - tir::SelectNode::make(x > 2, y, y + 1))); - CHECK(!select(px > pz, py, py).Match( - tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(!select(px > pz, py, py + pz).Match(tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(!select(px > pz, py, py).Match(tir::SelectNode::make(x > 2, y, y + 1))); { - CHECK(select(px, py, pz).Match( - tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(select(px, py, pz).Match(tir::SelectNode::make(x > 2, y, y + 1))); CHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1)); } // if_then_else { - CHECK(if_then_else(px > pz, py, py + pz).Match( - if_then_else(x > 1, y, y + 1))); + CHECK(if_then_else(px > pz, py, py + pz).Match(if_then_else(x > 1, y, y + 1))); CHECK(is_const_int(pz.Eval(), 1)); } // cast pattern { - CHECK(!cast(PConst( - DataType::Int(32)), px).Match(tir::CastNode::make(DataType::Float(64), x))); + CHECK(!cast(PConst(DataType::Int(32)), px) + .Match(tir::CastNode::make(DataType::Float(64), x))); CHECK(cast(pt, px).Match(tir::CastNode::make(DataType::Float(64), x))); CHECK(pt.Eval() == DataType::Float(64)); auto zz = cast(pt, px).Eval(); - CHECK((cast(pt, px) - cast(pt, py)).Match( - tir::CastNode::make(DataType::Float(64), x) - tir::CastNode::make(DataType::Int(64), x))); + CHECK((cast(pt, px) - cast(pt, py)) + .Match(tir::CastNode::make(DataType::Float(64), x) - + tir::CastNode::make(DataType::Int(64), x))); auto expr = tir::CastNode::make(DataType::Int(32), tir::CastNode::make(DataType::Float(64), x)); CHECK(!(cast(pt, cast(pt, px))).Match(expr)); } // ramp pattern { - CHECK(ramp(px, PConst(1), planes).Match( - tir::RampNode::make(x, 1, 10))); + CHECK(ramp(px, PConst(1), planes).Match(tir::RampNode::make(x, 1, 10))); CHECK(planes.Eval() == 10); - CHECK(!ramp(px, PConst(1), planes).Match( - tir::RampNode::make(x, 2, 10))); + CHECK(!ramp(px, PConst(1), planes).Match(tir::RampNode::make(x, 2, 10))); } // broadcast pattern { - CHECK(broadcast(px, planes).Match( - tir::BroadcastNode::make(x, 10))); + CHECK(broadcast(px, planes).Match(tir::BroadcastNode::make(x, 10))); CHECK(planes.Eval() == 10); - CHECK(broadcast(px * py , planes).Match( - tir::BroadcastNode::make(x * 10, 10))); + CHECK(broadcast(px * py, planes).Match(tir::BroadcastNode::make(x * 10, 10))); } } @@ -148,7 +140,7 @@ TEST(Pattern, IntImm) { CHECK(!(v * c).Match((tx + 1) * 3)); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 8048e10..33f6061 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -18,61 +18,59 @@ */ #include -#include -#include -#include -#include -#include -#include -#include -#include #include #include -#include +#include #include +#include +#include +#include +#include +#include +#include #include +#include #include +#include using namespace tvm; using namespace tvm::relay; TVM_REGISTER_GLOBAL("test.strategy") -.set_body_typed([](const Attrs& attrs, const Array& inputs, - const Type& out_type, const Target& target) { - FTVMCompute fcompute = [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) -> Array { + .set_body_typed([](const Attrs& attrs, const Array& inputs, const Type& out_type, + const Target& target) { + FTVMCompute fcompute = [](const Attrs& attrs, const Array& inputs, + const Type& out_type) -> Array { CHECK_EQ(inputs.size(), 2U); return {topi::add(inputs[0], inputs[1])}; - }; - FTVMSchedule fschedule = [](const Attrs& attrs, - const Array& outs, - const Target& target) { + }; + FTVMSchedule fschedule = [](const Attrs& attrs, const Array& outs, + const Target& target) { With target_scope(target); return topi::generic::schedule_injective(target, outs); - }; + }; - auto n = make_object(); - auto strategy = tvm::relay::OpStrategy(std::move(n)); - strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10); - return strategy; -}); + auto n = make_object(); + auto strategy = tvm::relay::OpStrategy(std::move(n)); + strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10); + return strategy; + }); TVM_REGISTER_GLOBAL("relay.backend.lower_call") -.set_body_typed([](const relay::Call& call, const Array& inputs, - const Target& target) { - static auto fstrategy = Op::GetAttr("FTVMStrategy"); - Op op = Downcast(call->op); - auto out_type = call->checked_type(); - OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); - auto impl = strategy->specializations[0]->implementations[0]; - auto outs = impl.Compute(call->attrs, inputs, out_type); - auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput"); - if (!f) { - LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered"; - } - return (*f)(outs, impl); -}); + .set_body_typed([](const relay::Call& call, const Array& inputs, + const Target& target) { + static auto fstrategy = Op::GetAttr("FTVMStrategy"); + Op op = Downcast(call->op); + auto out_type = call->checked_type(); + OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); + auto impl = strategy->specializations[0]->implementations[0]; + auto outs = impl.Compute(call->attrs, inputs, out_type); + auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput"); + if (!f) { + LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered"; + } + return (*f)(outs, impl); + }); TEST(Relay, BuildModule) { auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32)); @@ -178,7 +176,7 @@ TEST(Relay, GetExprRefCount) { CHECK(ref_count[z.get()] == 1); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 3c41691..cb7330d 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -19,30 +19,30 @@ #include #include -#include -#include -#include #include +#include #include +#include +#include TEST(Relay, SelfReference) { using namespace tvm; auto tensor_type = relay::TensorType({}, DataType::Bool()); auto x = relay::Var("x", relay::Type()); - auto f = relay::Function(tvm::Array{ x }, x, relay::Type(), {}); + auto f = relay::Function(tvm::Array{x}, x, relay::Type(), {}); CHECK(f->IsInstance()); auto y = relay::Var("y", tensor_type); - auto call = relay::Call(f, Array{ y }); - auto fx = relay::Function(tvm::Array{ y }, call, relay::Type(), {}); + auto call = relay::Call(f, Array{y}); + auto fx = relay::Function(tvm::Array{y}, call, relay::Type(), {}); auto mod = IRModule::FromExpr(fx); mod = relay::transform::InferType()(mod); auto type_fx = mod->Lookup("main"); - auto expected = relay::FuncType(tvm::Array{ tensor_type }, tensor_type, {}, {}); + auto expected = relay::FuncType(tvm::Array{tensor_type}, tensor_type, {}, {}); CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected)); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index d974f02..e01a6ea 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -19,27 +19,25 @@ #include #include -#include #include -#include #include +#include #include +#include #include #include #include #include #include -TVM_REGISTER_GLOBAL("schedule") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { - *rv = topi::generic::schedule_injective(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("schedule").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + *rv = topi::generic::schedule_injective(args[0], args[1]); +}); TEST(Relay, Sequential) { using namespace tvm; auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32)); - auto c_data = - tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); // Create a function for optimization. auto c = relay::Constant(c_data); @@ -53,8 +51,7 @@ TEST(Relay, Sequential) { auto z2 = relay::Call(add_op, {z, z1}); // Let expression and varaible a should be dead-code eliminated. auto z3 = relay::Let(a, c, z2); - relay::Function func = - relay::Function(relay::FreeVars(z3), z3, relay::Type(), {}); + relay::Function func = relay::Function(relay::FreeVars(z3), z3, relay::Type(), {}); // Get schedule auto reg = tvm::runtime::Registry::Get("relay.op._Register"); @@ -67,11 +64,8 @@ TEST(Relay, Sequential) { // Run sequential passes. tvm::Array pass_seqs{ - relay::transform::InferType(), - relay::transform::DeadCodeElimination(), - relay::transform::EliminateCommonSubexpr(), - relay::transform::AlterOpLayout() - }; + relay::transform::InferType(), relay::transform::DeadCodeElimination(), + relay::transform::EliminateCommonSubexpr(), relay::transform::AlterOpLayout()}; relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); auto mod = IRModule::FromExpr(func); auto pass_ctx = relay::transform::PassContext::Create(); @@ -96,8 +90,7 @@ TEST(Relay, Sequential) { y1 = relay::Call(add_op, {x1, y1}); auto zz = relay::Call(add_op, {y1, c1}); zz = relay::Call(add_op, {zz, zz}); - relay::Function expected_func = - relay::Function(relay::FreeVars(zz), zz, relay::Type(), {}); + relay::Function expected_func = relay::Function(relay::FreeVars(zz), zz, relay::Type(), {}); // Infer type for the expected function. auto mod1 = IRModule::FromExpr(expected_func); diff --git a/tests/cpp/simple_passes_test.cc b/tests/cpp/simple_passes_test.cc index be4c746..36b3645 100644 --- a/tests/cpp/simple_passes_test.cc +++ b/tests/cpp/simple_passes_test.cc @@ -33,8 +33,7 @@ TEST(SimplePasses, HasSideEffect) { CHECK(!tvm::tir::HasSideEffect(A[0])); } - -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc index a9566cb..ea02ca6 100644 --- a/tests/cpp/tensor_test.cc +++ b/tests/cpp/tensor_test.cc @@ -30,9 +30,8 @@ TEST(Tensor, Basic) { Tensor A = placeholder({m, l}, DataType::Float(32), "A"); Tensor B = placeholder({n, l}, DataType::Float(32), "B"); - auto C = compute({m, n}, [&](Var i, Var j) { - return A[i][j]; - }, "C"); + auto C = compute( + {m, n}, [&](Var i, Var j) { return A[i][j]; }, "C"); Tensor::Slice x = A[n]; } @@ -46,13 +45,12 @@ TEST(Tensor, Reduce) { te::Tensor B = te::placeholder({n, l}, DataType::Float(32), "B"); IterVar rv = reduce_axis(Range{0, l}, "k"); - auto C = te::compute({m, n}, [&](Var i, Var j) { - return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); - }, "C"); + auto C = te::compute( + {m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C"); LOG(INFO) << C->op.as()->body; } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/threading_backend_test.cc b/tests/cpp/threading_backend_test.cc index 508705c..cf7434b 100644 --- a/tests/cpp/threading_backend_test.cc +++ b/tests/cpp/threading_backend_test.cc @@ -17,13 +17,13 @@ * under the License. */ +#include +#include + #include #include #include -#include -#include - constexpr size_t N = 128; static FTVMParallelLambda atomic_add_task_id = [](int task_id, TVMParallelGroupEnv* penv, diff --git a/tests/cpp/topi_ewise_test.cc b/tests/cpp/topi_ewise_test.cc index a1ca6d7..10c7b9d 100644 --- a/tests/cpp/topi_ewise_test.cc +++ b/tests/cpp/topi_ewise_test.cc @@ -17,9 +17,9 @@ * under the License. */ -#include -#include #include +#include +#include namespace topi { TEST(Tensor, Basic) { @@ -28,9 +28,9 @@ TEST(Tensor, Basic) { Tensor A = placeholder({m, l}, DataType::Float(32), "A"); auto C = topi::exp(A); } -} +} // namespace topi -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index fa80926..c9c9f88 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -17,11 +17,11 @@ * under the License. */ -#include - #include #include + #include +#include #include #ifdef USE_MICRO_STANDALONE_RUNTIME @@ -30,9 +30,10 @@ #if defined(__APPLE__) && defined(__MACH__) #include +#include +#include #include #include -#include #include #include #include @@ -41,9 +42,7 @@ #include #include #include - -#include -#include +#include TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { *rv = topi::generic::schedule_injective(args[0], args[1]); diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h index 98614c3..1b36ace 100644 --- a/topi/include/topi/broadcast.h +++ b/topi/include/topi/broadcast.h @@ -28,8 +28,8 @@ #include #include -#include #include +#include namespace topi { @@ -49,8 +49,8 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, std::string name = "T_broadcast_to", std::string tag = kBroadcast) { CHECK_GE(output_shape.size(), t->shape.size()) - << "Not a broadcast, output dimensionality smaller than input.\noutput: " - << output_shape << "\nvs\ninput: " << t; + << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape + << "\nvs\ninput: " << t; auto bh = detail::BroadcastShape(output_shape, t->shape); CHECK_EQ(output_shape.size(), bh.common_shape.size()); for (size_t i = 0; i < output_shape.size(); ++i) { @@ -59,57 +59,39 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, auto l = [&](tvm::Array ovars) { return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); }; - return tvm::te::compute( - tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, - name, - tag); + return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), + l, name, tag); } -#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ - inline tvm::PrimExpr Name(const tvm::PrimExpr& a, \ - const tvm::PrimExpr& b) { \ - ComputeRule; \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, \ - std::string tag = kBroadcast) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return detail::WithBroadcast(l, A, B, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::PrimExpr& B, \ - std::string name = "T_" #Name, \ - std::string tag = kElementWise) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute(A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \ - return l(A(i), B); \ - }, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::PrimExpr& A, \ - const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, \ - std::string tag = kElementWise) { \ - auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute(B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \ - return l(A, B(i)); \ - }, name, tag); \ +#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ + inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kBroadcast) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return detail::WithBroadcast(l, A, B, name, tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \ } - -#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::te::Tensor& B) { \ - return topi::OpName(A, B); \ - } \ - inline tvm::te::Tensor Name(const tvm::PrimExpr& A, \ - const tvm::te::Tensor& B) { \ - return topi::OpName(A, B); \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::PrimExpr& B) { \ - return topi::OpName(A, B); \ +#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) { \ + return topi::OpName(A, B); \ + } \ + inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) { \ + return topi::OpName(A, B); \ + } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) { \ + return topi::OpName(A, B); \ } /*! diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h index f2ed029..0c04eaf 100644 --- a/topi/include/topi/contrib/cublas.h +++ b/topi/include/topi/contrib/cublas.h @@ -24,8 +24,8 @@ #ifndef TOPI_CONTRIB_CUBLAS_H_ #define TOPI_CONTRIB_CUBLAS_H_ -#include #include +#include namespace topi { namespace contrib { @@ -33,65 +33,51 @@ using namespace tvm; using namespace tvm::te; using namespace topi::detail; /*! -* \brief Create an op that multiplies lhs and rhs with cuBLAS -* -* \param lhs The left matrix operand -* \param rhs The right matrix operand -* \param transa Whether to transpose lhs -* \param transb Whether to transpose rhs -* -* \return The output tensor -*/ -inline Tensor cublas_matmul(const Tensor& lhs, - const Tensor& rhs, - bool transa, - bool transb) { + * \brief Create an op that multiplies lhs and rhs with cuBLAS + * + * \param lhs The left matrix operand + * \param rhs The right matrix operand + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { auto n = transa ? lhs->shape[1] : lhs->shape[0]; auto m = transb ? rhs->shape[0] : rhs->shape[1]; return make_extern( - { { n, m } }, { lhs->dtype }, { lhs, rhs }, - [&](Array ins, Array outs) { - return call_packed({ - StringImmNode::make("tvm.contrib.cublas.matmul"), - pack_buffer(ins[0]), - pack_buffer(ins[1]), - pack_buffer(outs[0]), - transa, - transb }); - }, "C", "", {})[0]; + {{n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImmNode::make("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), + pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } /*! -* \brief Create an op that multiplies batch matrices -* lhs and rhs with cuBLAS -* -* \param lhs The left matrix operand -* \param rhs The right matrix operand -* \param transa Whether to transpose lhs -* \param transb Whether to transpose rhs -* -* \return The output tensor -*/ -inline Tensor cublas_batch_matmul(const Tensor& lhs, - const Tensor& rhs, - bool transa, - bool transb) { + * \brief Create an op that multiplies batch matrices + * lhs and rhs with cuBLAS + * + * \param lhs The left matrix operand + * \param rhs The right matrix operand + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { auto b = lhs->shape[0]; auto n = transa ? lhs->shape[2] : lhs->shape[1]; auto m = transb ? rhs->shape[1] : rhs->shape[2]; - return make_extern( - { { b, n, m } }, { lhs->dtype }, { lhs, rhs }, - [&](Array ins, Array outs) { - return call_packed({ - StringImmNode::make("tvm.contrib.cublas.batch_matmul"), - pack_buffer(ins[0]), - pack_buffer(ins[1]), - pack_buffer(outs[0]), - transa, - transb }); - }, "C", "", {})[0]; + return make_extern({{b, n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImmNode::make("tvm.contrib.cublas.batch_matmul"), + pack_buffer(ins[0]), pack_buffer(ins[1]), + pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } } // namespace contrib diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h index f0bf926..3baf105 100644 --- a/topi/include/topi/contrib/rocblas.h +++ b/topi/include/topi/contrib/rocblas.h @@ -25,6 +25,7 @@ #define TOPI_CONTRIB_ROCBLAS_H_ #include + #include "topi/detail/extern.h" namespace topi { @@ -32,33 +33,26 @@ namespace contrib { using namespace tvm; using namespace tvm::te; /*! -* \brief Create an op that multiplies lhs and rhs with rocBLAS -* -* \param lhs The left matrix operand -* \param rhs The right matrix operand -* \param transa Whether to transpose lhs -* \param transb Whether to transpose rhs -* -* \return The output tensor -*/ -inline Tensor rocblas_matmul(const Tensor& lhs, - const Tensor& rhs, - bool transa, - bool transb) { + * \brief Create an op that multiplies lhs and rhs with rocBLAS + * + * \param lhs The left matrix operand + * \param rhs The right matrix operand + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { auto n = transa ? lhs->shape[1] : lhs->shape[0]; auto m = transb ? rhs->shape[0] : rhs->shape[1]; return make_extern( - { { n, m } }, { lhs->dtype }, { lhs, rhs }, - [&](Array ins, Array outs) { - return call_packed({ - StringImmNode::make("tvm.contrib.rocblas.matmul"), - pack_buffer(ins[0]), - pack_buffer(ins[1]), - pack_buffer(outs[0]), - transa, - transb }); - }, "C", "", {})[0]; + {{n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImmNode::make("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), + pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } } // namespace contrib diff --git a/topi/include/topi/cuda/dense.h b/topi/include/topi/cuda/dense.h index 1f0701e..145d249 100644 --- a/topi/include/topi/cuda/dense.h +++ b/topi/include/topi/cuda/dense.h @@ -24,14 +24,14 @@ #ifndef TOPI_CUDA_DENSE_H_ #define TOPI_CUDA_DENSE_H_ -#include -#include -#include -#include -#include -#include #include +#include #include +#include +#include +#include +#include +#include namespace topi { using namespace tvm; @@ -39,21 +39,19 @@ using namespace tvm::te; namespace cuda { /*! -* \brief Implementation of dense for CUDA backend -* -* \param target The target device -* \param data Tensor with shape [batch, in_dim] -* \param weight Tensor with shape [out_dim, in_dim] -* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() -* \param out_dtype Output data type. Used for mixed precision. -* -* \return Tensor with shape [batch, out_dim] -*/ -inline tvm::te::Tensor dense_cuda(const Target& target, - const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { + * \brief Implementation of dense for CUDA backend + * + * \param target The target device + * \param data Tensor with shape [batch, in_dim] + * \param weight Tensor with shape [out_dim, in_dim] + * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() + * \param out_dtype Output data type. Used for mixed precision. + * + * \return Tensor with shape [batch, out_dim] + */ +inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& data, + const tvm::te::Tensor& weight, const tvm::te::Tensor& bias, + const DataType& out_dtype) { CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { @@ -68,10 +66,8 @@ inline tvm::te::Tensor dense_cuda(const Target& target, CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported."; auto mm = topi::contrib::cublas_matmul(data, weight, false, true); if (bias.defined()) { - mm = tvm::te::compute({ batch, out_dim }, - [&](Var i, Var j) { - return mm(i, j) + bias(j); - }, "tensor", kBroadcast); + mm = tvm::te::compute( + {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast); } return mm; @@ -81,16 +77,15 @@ inline tvm::te::Tensor dense_cuda(const Target& target, } /*! -* \brief Create a CUDA schedule for dense -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_dense(const Target &target, const Array& outs) { - if (target->target_name == "cuda" && - target->libs().count("cublas")) { + * \brief Create a CUDA schedule for dense + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_dense(const Target& target, const Array& outs) { + if (target->target_name == "cuda" && target->libs().count("cublas")) { return topi::generic::schedule_extern(target, outs); } diff --git a/topi/include/topi/cuda/injective.h b/topi/include/topi/cuda/injective.h index a7792a5..5a5c5af 100644 --- a/topi/include/topi/cuda/injective.h +++ b/topi/include/topi/cuda/injective.h @@ -24,11 +24,11 @@ #ifndef TOPI_CUDA_INJECTIVE_H_ #define TOPI_CUDA_INJECTIVE_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -63,7 +63,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out * * \return A schedule for the given ops. */ -inline Schedule schedule_injective(const Target &target, const Array& outs) { +inline Schedule schedule_injective(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/cuda/normalization.h b/topi/include/topi/cuda/normalization.h index bfc209d..f8f498e 100644 --- a/topi/include/topi/cuda/normalization.h +++ b/topi/include/topi/cuda/normalization.h @@ -24,20 +24,20 @@ #ifndef TOPI_CUDA_NORMALIZATION_H_ #define TOPI_CUDA_NORMALIZATION_H_ +#include +#include #include #include -#include -#include namespace topi { using namespace tvm; using namespace tvm::te; namespace cuda { /*! -* \brief Create a CUDA schedule for LRN -* \param outs The output tensors. -* \return A schedule for the given ops. -*/ + * \brief Create a CUDA schedule for LRN + * \param outs The output tensors. + * \return A schedule for the given ops. + */ inline Schedule schedule_lrn(const Array& outs) { Array out_ops; for (auto t : outs) { diff --git a/topi/include/topi/cuda/pooling.h b/topi/include/topi/cuda/pooling.h index 75b66b3..87866f2 100644 --- a/topi/include/topi/cuda/pooling.h +++ b/topi/include/topi/cuda/pooling.h @@ -24,12 +24,12 @@ #ifndef TOPI_CUDA_POOLING_H_ #define TOPI_CUDA_POOLING_H_ +#include +#include +#include +#include #include #include -#include -#include -#include -#include namespace topi { using namespace tvm; @@ -38,14 +38,14 @@ using namespace tvm::te; namespace cuda { /*! -* \brief Create a CUDA schedule for pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_pool(const Target &target, const Array& outs) { + * \brief Create a CUDA schedule for pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_pool(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -105,14 +105,14 @@ inline Schedule schedule_pool(const Target &target, const Array& outs) { } /*! -* \brief Create a CUDA schedule for global_pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_global_pool(const Target &target, const Array& outs) { + * \brief Create a CUDA schedule for global_pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_global_pool(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -142,7 +142,7 @@ inline Schedule schedule_global_pool(const Target &target, const Array& s[out].split(i, num_thread, &by, &ty); IterVar bx, tx; s[out].split(c, num_thread, &bx, &tx); - s[out].reorder({ by, bx, ty, tx }); + s[out].reorder({by, bx, ty, tx}); s[out].bind(ty, thread_y); s[out].bind(tx, thread_x); s[out].bind(by, block_y); diff --git a/topi/include/topi/cuda/reduction.h b/topi/include/topi/cuda/reduction.h index add8d99..35ce346 100644 --- a/topi/include/topi/cuda/reduction.h +++ b/topi/include/topi/cuda/reduction.h @@ -24,11 +24,11 @@ #ifndef TOPI_CUDA_REDUCTION_H_ #define TOPI_CUDA_REDUCTION_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -45,10 +45,8 @@ namespace cuda { * an index, such as argmax or argmin. * * \return The schedule given by sch -*/ -Schedule ScheduleReduce(const Target& target, - Operation op, - Schedule sch, + */ +Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch, bool is_idx_reduce = false) { Tensor data_out; Tensor data_in; @@ -61,8 +59,8 @@ Schedule ScheduleReduce(const Target& target, } auto out_stage = sch[data_out]; - CHECK_GT(out_stage->op.as()->reduce_axis.size(), 0) << - "reduce_axis must be greater than zero"; + CHECK_GT(out_stage->op.as()->reduce_axis.size(), 0) + << "reduce_axis must be greater than zero"; bool all_reduce; int num_thread; @@ -120,10 +118,8 @@ Schedule ScheduleReduce(const Target& target, } } else { if (is_idx_reduce) { - sch[temp_idx_input].compute_at(stage_real, - stage_real->op.as()->axis[0]); - sch[temp_val_input].compute_at(stage_real, - stage_real->op.as()->axis[0]); + sch[temp_idx_input].compute_at(stage_real, stage_real->op.as()->axis[0]); + sch[temp_val_input].compute_at(stage_real, stage_real->op.as()->axis[0]); } } @@ -152,13 +148,13 @@ void TraverseBeforeReduce(Schedule s, Operation op) { } /*! -* \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each -* of the op's inputs. -* -* \param target The target to generate a schedule for. -* \param s The schedule we are building -* \param op The reduce op -*/ + * \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each + * of the op's inputs. + * + * \param target The target to generate a schedule for. + * \param s The schedule we are building + * \param op The reduce op + */ void TraverseAfterReduce(const Target& target, Schedule s, Operation op) { if (is_broadcast(op->tag)) { LOG(ERROR) << "Elementwise op after reduce is not yet supported"; @@ -178,13 +174,13 @@ void TraverseAfterReduce(const Target& target, Schedule s, Operation op) { } /*! -* \brief Create a CUDA schedule for a reduce operation. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a CUDA schedule for a reduce operation. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ Schedule schedule_reduce(const Target& target, Array outs) { CHECK_EQ(outs.size(), 1) << "outs must have size 1"; Array out_ops; diff --git a/topi/include/topi/cuda/softmax.h b/topi/include/topi/cuda/softmax.h index 4c88c3e..a3aa857 100644 --- a/topi/include/topi/cuda/softmax.h +++ b/topi/include/topi/cuda/softmax.h @@ -24,11 +24,11 @@ #ifndef TOPI_CUDA_SOFTMAX_H_ #define TOPI_CUDA_SOFTMAX_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -44,7 +44,7 @@ namespace cuda { * * \return A schedule for the given ops. */ -inline Schedule schedule_softmax(const Target &target, const Array& outs) { +inline Schedule schedule_softmax(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/detail/array_utils.h b/topi/include/topi/detail/array_utils.h index 3a3453a..d720472 100644 --- a/topi/include/topi/detail/array_utils.h +++ b/topi/include/topi/detail/array_utils.h @@ -39,7 +39,7 @@ using namespace tvm::te; * * \return True iff the given array contains the given item. */ -template +template inline bool contains(Array array, T item) { for (auto& i : array) { if (i == item) { diff --git a/topi/include/topi/detail/broadcast.h b/topi/include/topi/detail/broadcast.h index 8622920..ca30293 100644 --- a/topi/include/topi/detail/broadcast.h +++ b/topi/include/topi/detail/broadcast.h @@ -24,8 +24,8 @@ #ifndef TOPI_DETAIL_BROADCAST_H_ #define TOPI_DETAIL_BROADCAST_H_ -#include #include +#include #include #include @@ -77,10 +77,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, bh.vars1.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]); } else { - CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] - << " and " << shape2[s2_size - i] << " in: " - << tvm::Array(shape1.begin(), shape1.end()) - << " and " + CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and " + << shape2[s2_size - i] + << " in: " << tvm::Array(shape1.begin(), shape1.end()) << " and " << tvm::Array(shape2.begin(), shape2.end()); } } @@ -97,10 +96,8 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, } inline tvm::Array InputIndexFromBroadcast( - const tvm::Array& ovars, - const tvm::te::Tensor& T, - const std::deque& my_vars, - const std::deque& all_vars) { + const tvm::Array& ovars, const tvm::te::Tensor& T, + const std::deque& my_vars, const std::deque& all_vars) { tvm::Array ivars; CHECK_EQ(ovars.size(), all_vars.size()); // N^2, could use a map but NBD. @@ -125,21 +122,16 @@ inline tvm::Array InputIndexFromBroadcast( } template -inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, - const tvm::te::Tensor& A, - const tvm::te::Tensor& B, - const std::string& name = "tensor", - const std::string& tag = "") { +inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A, + const tvm::te::Tensor& B, const std::string& name = "tensor", + const std::string& tag = "") { auto bh = BroadcastShape(A->shape, B->shape); auto l = [&](tvm::Array ovars) { return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)), B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars))); }; - return tvm::te::compute( - tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, - name, - tag); + return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), + l, name, tag); } } // namespace detail diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index afa8833..9bd1251 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -24,10 +24,10 @@ #ifndef TOPI_DETAIL_CONSTANT_UTILS_H_ #define TOPI_DETAIL_CONSTANT_UTILS_H_ -#include #include -#include #include +#include +#include #include #include @@ -44,10 +44,7 @@ using namespace tvm::te; * * \return true if the given expr is a constant int or uint, false otherwise. */ -inline bool IsConstInt(PrimExpr expr) { - return - expr->IsInstance(); -} +inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance(); } /*! * \brief Get the value of the given constant integer expression. An error @@ -74,13 +71,11 @@ inline int64_t GetConstInt(PrimExpr expr) { * * \return A vector of the integer values */ -inline std::vector GetConstIntValues( - Array exprs, const std::string& var_name) { +inline std::vector GetConstIntValues(Array exprs, const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { - CHECK(IsConstInt(expr)) << "All elements of " - << var_name << " must be constant integers"; + CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers"; result.push_back(GetConstInt(expr)); } return result; @@ -95,8 +90,8 @@ inline std::vector GetConstIntValues( * * \return A vector of the int64_t values */ -inline std::vector GetConstInt64Values( - Array exprs, const std::string& var_name) { +inline std::vector GetConstInt64Values(Array exprs, + const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { @@ -107,8 +102,8 @@ inline std::vector GetConstInt64Values( } /*! - * \brief Check weather the two expressions are equal or not, if not simplify the expressions and check again - * \note This is stronger equality check than tvm::tir::Equal + * \brief Check weather the two expressions are equal or not, if not simplify the expressions and + * check again \note This is stronger equality check than tvm::tir::Equal * * \param lhs First expreesion * \param rhs Second expreesion @@ -120,7 +115,7 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { bool result = expr_equal(lhs, rhs); if (!result) { PrimExpr zero(0); - result = expr_equal(tvm::arith::Analyzer().Simplify(lhs-rhs), zero); + result = expr_equal(tvm::arith::Analyzer().Simplify(lhs - rhs), zero); } return result; } diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index ab83200..e6ede6a 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -25,9 +25,9 @@ #define TOPI_DETAIL_EXTERN_H_ #include -#include -#include +#include +#include namespace topi { namespace detail { @@ -43,13 +43,11 @@ using namespace tvm::te; * * \return The Buffer object */ -inline Buffer DeclExternBuffer(Array shape, - DataType dtype, - std::string name) { +inline Buffer DeclExternBuffer(Array shape, DataType dtype, std::string name) { auto data = var(name, DataType::Handle()); auto elem_offset = PrimExpr(); - return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - -1, 0, kDefault); + return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, + kDefault); } /*! @@ -76,15 +74,12 @@ using FExtern = std::function, Array)>; * be one output Tensor for each element of out_shapes, with dtype equal to the corresponding * element of out_types. */ -inline Array make_extern(const Array< Array >& out_shapes, +inline Array make_extern(const Array >& out_shapes, const std::vector& out_types, - const Array& inputs, - FExtern fextern, - std::string name, - std::string tag, - ::tvm::Map attrs) { + const Array& inputs, FExtern fextern, std::string name, + std::string tag, ::tvm::Map attrs) { CHECK_EQ(out_shapes.size(), out_types.size()) - << "make_extern: out_shapes and out_types must have equal size"; + << "make_extern: out_shapes and out_types must have equal size"; Array input_placeholders; for (auto t : inputs) { @@ -98,9 +93,8 @@ inline Array make_extern(const Array< Array >& out_shapes, auto body = fextern(input_placeholders, output_placeholders); auto body_stmt = tvm::tir::EvaluateNode::make(body); - auto op = ExternOpNode::make( - name, tag, attrs, inputs, - input_placeholders, output_placeholders, body_stmt); + auto op = ExternOpNode::make(name, tag, attrs, inputs, input_placeholders, output_placeholders, + body_stmt); Array outputs; for (size_t i = 0; i < output_placeholders.size(); ++i) { @@ -119,27 +113,25 @@ inline Array make_extern(const Array< Array >& out_shapes, */ inline PrimExpr pack_buffer(Buffer buf) { CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; - auto shape = tvm::tir::CallNode::make( - DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, - buf->shape, tvm::tir::CallNode::CallType::Intrinsic); + auto shape = + tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + buf->shape, tvm::tir::CallNode::CallType::Intrinsic); PrimExpr strides; if (buf->strides.size() > 0) { - strides = tvm::tir::CallNode::make( - DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, - buf->shape, tvm::tir::CallNode::CallType::Intrinsic); + strides = + tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + buf->shape, tvm::tir::CallNode::CallType::Intrinsic); } else { strides = 0; } - Array pack_args{ - buf->data, - shape, - strides, - make_const(DataType::Int(32), static_cast(buf->shape.size())), - make_const(buf->dtype, 0), - buf->elem_offset - }; + Array pack_args{buf->data, + shape, + strides, + make_const(DataType::Int(32), static_cast(buf->shape.size())), + make_const(buf->dtype, 0), + buf->elem_offset}; return tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, - pack_args, tvm::tir::CallNode::CallType::Intrinsic); + pack_args, tvm::tir::CallNode::CallType::Intrinsic); } /*! @@ -152,8 +144,8 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(Array args) { - return tvm::tir::CallNode::make(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, - args, tvm::tir::CallNode::CallType::Intrinsic); + return tvm::tir::CallNode::make(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args, + tvm::tir::CallNode::CallType::Intrinsic); } } // namespace detail diff --git a/topi/include/topi/detail/pad_utils.h b/topi/include/topi/detail/pad_utils.h index 1f2a7c5..7c416ec 100644 --- a/topi/include/topi/detail/pad_utils.h +++ b/topi/include/topi/detail/pad_utils.h @@ -18,16 +18,17 @@ */ /*! -* \file pad_utils.h -* \brief Padding helpers -*/ + * \file pad_utils.h + * \brief Padding helpers + */ #ifndef TOPI_DETAIL_PAD_UTILS_H_ #define TOPI_DETAIL_PAD_UTILS_H_ -#include +#include +#include +#include -#include "tvm/tir/expr.h" -#include "tvm/tir/op.h" +#include namespace topi { namespace detail { @@ -50,7 +51,7 @@ inline Array GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) { auto pad_top = indexdiv(pad_h + 1, 2); auto pad_left = indexdiv(pad_w + 1, 2); - return { pad_top, pad_left, pad_h - pad_top, pad_w - pad_left }; + return {pad_top, pad_left, pad_h - pad_top, pad_w - pad_left}; } } // namespace detail diff --git a/topi/include/topi/detail/ravel_unravel.h b/topi/include/topi/detail/ravel_unravel.h index ca46da0..c87f2c9 100644 --- a/topi/include/topi/detail/ravel_unravel.h +++ b/topi/include/topi/detail/ravel_unravel.h @@ -18,9 +18,9 @@ */ /*! -* \file ravel_unravel.h -* \brief Index ravel and unraval operations -*/ + * \file ravel_unravel.h + * \brief Index ravel and unraval operations + */ #ifndef TOPI_DETAIL_RAVEL_UNRAVEL_H_ #define TOPI_DETAIL_RAVEL_UNRAVEL_H_ @@ -34,13 +34,13 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Flatten the indices to 1D -* -* \param indices The input coordinates -* \param shape Shape of the tensor -* -* \return The index after flattening -*/ + * \brief Flatten the indices to 1D + * + * \param indices The input coordinates + * \param shape Shape of the tensor + * + * \return The index after flattening + */ inline PrimExpr RavelIndex(Array indices, Array shape) { CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size"; CHECK_GT(indices.size(), 0) << "indices must not be empty"; @@ -56,13 +56,13 @@ inline PrimExpr RavelIndex(Array indices, Array shape) { } /*! -* \brief Convert flattened index to coordinate array -* -* \param idx The 1D index -* \param shape Shape of the tensor -* -* \return The coordinate corresponding to the 1D index -*/ + * \brief Convert flattened index to coordinate array + * + * \param idx The 1D index + * \param shape Shape of the tensor + * + * \return The coordinate corresponding to the 1D index + */ inline Array UnravelIndex(PrimExpr idx, Array shape) { std::vector indices; diff --git a/topi/include/topi/detail/tensor_utils.h b/topi/include/topi/detail/tensor_utils.h index 6ac3982..d144c75 100644 --- a/topi/include/topi/detail/tensor_utils.h +++ b/topi/include/topi/detail/tensor_utils.h @@ -24,7 +24,6 @@ #ifndef TOPI_DETAIL_TENSOR_UTILS_H_ #define TOPI_DETAIL_TENSOR_UTILS_H_ - #include namespace topi { @@ -63,7 +62,7 @@ inline bool is_empty_shape(const Array& x) { * \return The interpolated value in the given index. */ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& indices, - const PrimExpr max_y, const PrimExpr max_x) { + const PrimExpr max_y, const PrimExpr max_x) { auto in_y = indices[2]; auto yf = tvm::floor(in_y); auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y)); @@ -85,9 +84,7 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& auto C = input(indices[0], indices[1], y1, x0); auto D = input(indices[0], indices[1], y1, x1); - return A * ( 1 - x_lerp) * ( 1 - y_lerp) + - B * x_lerp * (1 - y_lerp) + - C * (1 - x_lerp) * y_lerp + + return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp + D * x_lerp * y_lerp; } diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index dfcf83f..70daac2 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -24,10 +24,12 @@ #ifndef TOPI_ELEMWISE_H_ #define TOPI_ELEMWISE_H_ -#include #include +#include + #include #include + #include "broadcast.h" namespace topi { @@ -35,13 +37,11 @@ using namespace tvm; using namespace tvm::te; // Unary intrinsic operators -#define TOPI_DECLARE_UNARY_OP(OpName) \ - inline Tensor OpName(const Tensor& x, \ - std::string name = "T_" #OpName, \ - std::string tag = kElementWise) { \ - return compute(x->shape, [&](const Array& i) { \ - return ::tvm::OpName(x(i)); \ - }, name, tag); \ +#define TOPI_DECLARE_UNARY_OP(OpName) \ + inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName, \ + std::string tag = kElementWise) { \ + return compute( \ + x->shape, [&](const Array& i) { return ::tvm::OpName(x(i)); }, name, tag); \ } TOPI_DECLARE_UNARY_OP(exp); @@ -76,9 +76,7 @@ TOPI_DECLARE_UNARY_OP(isinf); * \brief Fast_tanh_float implementation from Eigen * https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26 */ -inline Tensor fast_tanh_float(const Tensor& in, - std::string name, - std::string tag) { +inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string tag) { // Clamp the inputs to the range [-9, 9] since anything outside // this range is +/-1.0f in single-precision. auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0)); @@ -98,178 +96,171 @@ inline Tensor fast_tanh_float(const Tensor& in, auto beta_4 = make_const(in->dtype, 1.18534705686654e-04); auto beta_6 = make_const(in->dtype, 1.19825839466702e-06); - return compute(x->shape, - [&](const Array& i) { - auto x2 = x(i) * x(i); - auto p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x(i) * p; - - auto q = x2 * beta_6 + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - return p / q; - }, - name, tag); + return compute( + x->shape, + [&](const Array& i) { + auto x2 = x(i) * x(i); + auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x(i) * p; + + auto q = x2 * beta_6 + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + return p / q; + }, + name, tag); } /*! -* \brief Creates an operation that returns hyperbolic tanh of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is tanh -*/ -inline Tensor fast_tanh(const Tensor& x, - std::string name = "T_fast_tanh", + * \brief Creates an operation that returns hyperbolic tanh of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is tanh + */ +inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh", std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { // invoke fast_tanh_float implementation return fast_tanh_float(x, name, tag); } else { // fallback to default implementation - return compute(x->shape, [&](const Array& i) { - return ::tvm::tanh(x(i)); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ::tvm::tanh(x(i)); }, name, tag); } } /*! -* \brief Creates an operation that returns identity of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the identity operation -*/ -inline Tensor identity(const Tensor& x, - std::string name = "T_identity", + * \brief Creates an operation that returns identity of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the identity operation + */ +inline Tensor identity(const Tensor& x, std::string name = "T_identity", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return x(i); }, name, tag); } /*! -* \brief Creates an operation that returns the negation of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the negation operation -*/ -inline Tensor negative(const Tensor& x, - std::string name = "T_negative", + * \brief Creates an operation that returns the negation of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the negation operation + */ +inline Tensor negative(const Tensor& x, std::string name = "T_negative", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return -x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return -x(i); }, name, tag); } /*! -* \brief Creates an operation that returns the logical NOT of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the logical NOT operation -*/ -inline Tensor logical_not(const Tensor& x, - std::string name = "T_logical_not", + * \brief Creates an operation that returns the logical NOT of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the logical NOT operation + */ +inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return !x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return !x(i); }, name, tag); } /*! -* \brief Creates an operation that returns the bitwise NOT of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the bitwise NOT operation -*/ -inline Tensor bitwise_not(const Tensor& x, - std::string name = "T_bitwise_not", + * \brief Creates an operation that returns the bitwise NOT of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the bitwise NOT operation + */ +inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return ~x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ~x(i); }, name, tag); } /*! -* \brief Returns the sign of the tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the sign -*/ -inline Tensor sign(const Tensor& x, - std::string name = "T_sign", - std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - PrimExpr zero = make_zero(x->dtype); - PrimExpr one = make_const(x->dtype, 1); - PrimExpr minus_one = make_const(x->dtype, -1); - auto s1 = tvm::tir::SelectNode::make((x(i) < zero), minus_one, zero); - auto s2 = tvm::tir::SelectNode::make((x(i) > zero), one, s1); - return s2; - }, name, tag); + * \brief Returns the sign of the tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the sign + */ +inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag = kElementWise) { + return compute( + x->shape, + [&](const Array& i) { + PrimExpr zero = make_zero(x->dtype); + PrimExpr one = make_const(x->dtype, 1); + PrimExpr minus_one = make_const(x->dtype, -1); + auto s1 = tvm::tir::SelectNode::make((x(i) < zero), minus_one, zero); + auto s2 = tvm::tir::SelectNode::make((x(i) > zero), one, s1); + return s2; + }, + name, tag); } /*! -* \brief Creates an operation that returns rsqrt of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the rsqrt operation -*/ -inline Tensor rsqrt(const Tensor& x, - std::string name = "tensor", - std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - PrimExpr one = make_const(x->dtype, 1); - return one/tvm::sqrt(x(i)); - }, name, tag); + * \brief Creates an operation that returns rsqrt of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the rsqrt operation + */ +inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string tag = kElementWise) { + return compute( + x->shape, + [&](const Array& i) { + PrimExpr one = make_const(x->dtype, 1); + return one / tvm::sqrt(x(i)); + }, + name, tag); } /*! -* \brief Creates an operation that clips each element of a tensor to -* the interval [a_min, a_max] -* -* \param x The input tensor -* \param a_min The inclusive lower bound of the interval -* \param a_max The inclusive upper bound of the interval -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the clip operation -*/ -inline Tensor clip(const Tensor& x, - const PrimExpr& a_min, - const PrimExpr& a_max, - std::string name = "T_clip", - std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - auto min_val = tvm::cast(x->dtype, a_min); - auto max_val = tvm::cast(x->dtype, a_max); - return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) - }, name, tag); + * \brief Creates an operation that clips each element of a tensor to + * the interval [a_min, a_max] + * + * \param x The input tensor + * \param a_min The inclusive lower bound of the interval + * \param a_max The inclusive upper bound of the interval + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the clip operation + */ +inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max, + std::string name = "T_clip", std::string tag = kElementWise) { + return compute( + x->shape, + [&](const Array& i) { + auto min_val = tvm::cast(x->dtype, a_min); + auto max_val = tvm::cast(x->dtype, a_max); + return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) + }, + name, tag); } /*! @@ -284,22 +275,23 @@ inline Tensor clip(const Tensor& x, * * \return A Tensor whose op member is the cast operation */ -inline Tensor cast(const Tensor& x, - DataType type, - std::string name = "T_cast", +inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - auto expr = x(i); - if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { - if (expr.dtype().lanes() == type.lanes()) { - return expr; - } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { - return tvm::tir::BroadcastNode::make(expr, type.lanes()); - } - } - - return tvm::cast(type, x(i)); - }, name, tag); + return compute( + x->shape, + [&](const Array& i) { + auto expr = x(i); + if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { + if (expr.dtype().lanes() == type.lanes()) { + return expr; + } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { + return tvm::tir::BroadcastNode::make(expr, type.lanes()); + } + } + + return tvm::cast(type, x(i)); + }, + name, tag); } /*! @@ -314,12 +306,13 @@ inline Tensor cast(const Tensor& x, */ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor", std::string tag = kElementWise) { - return compute(x->shape, - [&](const Array& i) { - return tvm::tir::CallNode::make(type, "reinterpret", {x(i)}, - tvm::tir::CallNode::PureIntrinsic); - }, - name, tag); + return compute( + x->shape, + [&](const Array& i) { + return tvm::tir::CallNode::make(type, "reinterpret", {x(i)}, + tvm::tir::CallNode::PureIntrinsic); + }, + name, tag); } /*! @@ -331,63 +324,58 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te * * \return A Tensor whose op member is the sum operation */ -inline Tensor elemwise_sum(const Array& xs, - std::string name = "T_elemwise_sum", +inline Tensor elemwise_sum(const Array& xs, std::string name = "T_elemwise_sum", std::string tag = kElementWise) { CHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor."; - return compute(xs[0]->shape, [&](const Array& i) { - auto sum_expr = xs[0](i); - for (size_t j = 1; j < xs.size(); j++) { - sum_expr = sum_expr + xs[j](i); - } - return sum_expr; - }, name, tag); + return compute( + xs[0]->shape, + [&](const Array& i) { + auto sum_expr = xs[0](i); + for (size_t j = 1; j < xs.size(); j++) { + sum_expr = sum_expr + xs[j](i); + } + return sum_expr; + }, + name, tag); } /*! -* \brief Creates an operation that fill a tensor with fill_value -* -* \param shape The shape of a tensor -* \param dtype The Type of fill_value -* \param fill_value The value to be filled -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the full operation -*/ -inline Tensor full(const Array& shape, - DataType dtype, - const PrimExpr fill_value, - std::string name = "T_full", - std::string tag = kElementWise) { + * \brief Creates an operation that fill a tensor with fill_value + * + * \param shape The shape of a tensor + * \param dtype The Type of fill_value + * \param fill_value The value to be filled + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the full operation + */ +inline Tensor full(const Array& shape, DataType dtype, const PrimExpr fill_value, + std::string name = "T_full", std::string tag = kElementWise) { PrimExpr ev = cast(dtype, fill_value); if (!ev.defined()) { LOG(ERROR) << "Can't cast fill_value to " << dtype; } - return compute(shape, [&](const Array& i) { - return ev; - }, name, tag); + return compute( + shape, [&](const Array& i) { return ev; }, name, tag); } /*! -* \brief Creates an operation that construct a tensor with same shape as input tensor, -* then fill a tensor with fill_value -* -* \param x The input tensor -* \param fill_value The value to be filled -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op memeber is the full_like operation -*/ -inline Tensor full_like(const Tensor& x, - const PrimExpr fill_value, - std::string name = "T_full_like", - std::string tag = kElementWise) { + * \brief Creates an operation that construct a tensor with same shape as input tensor, + * then fill a tensor with fill_value + * + * \param x The input tensor + * \param fill_value The value to be filled + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op memeber is the full_like operation + */ +inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, + std::string name = "T_full_like", std::string tag = kElementWise) { PrimExpr ev = cast(x->dtype, fill_value); - return compute(x->shape, [&](const Array& i) { - return ev; - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ev; }, name, tag); } /*! @@ -411,9 +399,7 @@ inline Tensor full_like(const Tensor& x, * Approximation for fractional part: * y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2)) */ -inline Tensor fast_exp_float32(const Tensor& _x, - std::string name, - std::string tag) { +inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) { auto x_hi = make_const(DataType::Float(32), 88.3762626647950f); auto x_lo = make_const(DataType::Float(32), -88.3762626647949f); auto log2e = make_const(DataType::Float(32), 1.44269504088896341f); @@ -428,25 +414,25 @@ inline Tensor fast_exp_float32(const Tensor& _x, auto one_half = make_const(DataType::Float(32), 0.5f); auto b = make_const(DataType::Float(32), 127.0f); - return compute(_x->shape, - [&](const Array& i) { - // clamp x - auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo); - // integer part - auto n = ::tvm::floor(x * log2e + one_half); - // fractional part - auto f = x - n * ln2; - auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3])* f+ p[4]) * f - + p[5]) * f * f + f + one; - // Return 2^m * exp(r). - auto ef = tvm::reinterpret(DataType::Float(32), - ::tvm::cast(DataType::Int(32), n + b) << 23); - return ::tvm::max(ef * y, _x(i)); // NOLINT(*) - }, - name, tag); + return compute( + _x->shape, + [&](const Array& i) { + // clamp x + auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo); + // integer part + auto n = ::tvm::floor(x * log2e + one_half); + // fractional part + auto f = x - n * ln2; + auto y = + (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f + p[5]) * f * f + f + one; + // Return 2^m * exp(r). + auto ef = + tvm::reinterpret(DataType::Float(32), ::tvm::cast(DataType::Int(32), n + b) << 23); + return ::tvm::max(ef * y, _x(i)); // NOLINT(*) + }, + name, tag); } - /*! * \brief Fast exponential function implementation * @@ -457,16 +443,14 @@ inline Tensor fast_exp_float32(const Tensor& _x, * \return A Tensor whose op member is exponent operation * */ -inline Tensor fast_exp(const Tensor& x, - std::string name = "T_fast_exp", - std::string tag = kElementWise) { +inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", + std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { auto ret = fast_exp_float32(x, name, tag); return ret; } else { - return compute(x->shape, [&](const Array& i) { - return ::tvm::exp(x(i)); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ::tvm::exp(x(i)); }, name, tag); } } @@ -474,9 +458,7 @@ inline Tensor fast_exp(const Tensor& x, * \brief Fast_tanh_float implementation from Eigen * https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290 */ -inline Tensor fast_erf_float32(const Tensor& data, - std::string name, - std::string tag) { +inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) { auto plus_4 = make_const(DataType::Float(32), 4.f); auto minus_4 = make_const(DataType::Float(32), -4.f); @@ -496,28 +478,31 @@ inline Tensor fast_erf_float32(const Tensor& data, auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f); auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f); - return compute(data->shape, [&](const Array &i) { - // clamp x - auto x = tvm::max(tvm::min(data(i), plus_4), minus_4); - auto x2 = x * x; - - // Evaluate the numerator polynomial p. - auto p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x * p; - - // Evaluate the denominator polynomial p. - auto q = x2 * beta_8 + beta_6; - q = x2 * q + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - - return p / q; - }, name, tag); + return compute( + data->shape, + [&](const Array& i) { + // clamp x + auto x = tvm::max(tvm::min(data(i), plus_4), minus_4); + auto x2 = x * x; + + // Evaluate the numerator polynomial p. + auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x * p; + + // Evaluate the denominator polynomial p. + auto q = x2 * beta_8 + beta_6; + q = x2 * q + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + + return p / q; + }, + name, tag); } /*! @@ -529,8 +514,7 @@ inline Tensor fast_erf_float32(const Tensor& data, * * \return A Tensor whose op member is erf operation */ -inline Tensor fast_erf(const Tensor& x, - std::string name = "T_fast_erf", +inline Tensor fast_erf(const Tensor& x, std::string name = "T_fast_erf", std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { auto ret = fast_erf_float32(x, name, tag); diff --git a/topi/include/topi/generic/default.h b/topi/include/topi/generic/default.h index 640ab95..c44bc69 100644 --- a/topi/include/topi/generic/default.h +++ b/topi/include/topi/generic/default.h @@ -24,11 +24,11 @@ #ifndef TOPI_GENERIC_DEFAULT_H_ #define TOPI_GENERIC_DEFAULT_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -36,13 +36,13 @@ using namespace tvm::te; namespace generic { /*! -* \brief Create a generic default schedule for the given output tensors. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a generic default schedule for the given output tensors. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ inline Schedule default_schedule(const Target& target, Array outs) { Array out_ops; for (auto t : outs) { @@ -53,14 +53,14 @@ inline Schedule default_schedule(const Target& target, Array outs) { } /*! -* \brief Create a generic default schedule for the given output tensors, and apply -* auto inline -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a generic default schedule for the given output tensors, and apply + * auto inline + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ inline Schedule default_schedule_auto_inline(const Target& target, Array outs) { Array out_ops; for (auto t : outs) { diff --git a/topi/include/topi/generic/extern.h b/topi/include/topi/generic/extern.h index e08158f..6e5507b 100644 --- a/topi/include/topi/generic/extern.h +++ b/topi/include/topi/generic/extern.h @@ -24,12 +24,12 @@ #ifndef TOPI_GENERIC_EXTERN_H_ #define TOPI_GENERIC_EXTERN_H_ -#include -#include -#include -#include #include #include +#include +#include +#include +#include namespace topi { using namespace tvm; @@ -37,13 +37,13 @@ using namespace tvm::te; namespace generic { /*! -* \brief Schedule an extern op followed by injective operations -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the op. -*/ + * \brief Schedule an extern op followed by injective operations + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the op. + */ inline Schedule schedule_extern(const Target& target, Array outs) { Array out_ops; for (auto t : outs) { diff --git a/topi/include/topi/generic/injective.h b/topi/include/topi/generic/injective.h index 7a5aff7..69962dc 100644 --- a/topi/include/topi/generic/injective.h +++ b/topi/include/topi/generic/injective.h @@ -24,11 +24,11 @@ #ifndef TOPI_GENERIC_INJECTIVE_H_ #define TOPI_GENERIC_INJECTIVE_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -57,7 +57,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out * * \return A schedule for the given ops. */ -inline Schedule schedule_injective(const Target &target, const Array& outs) { +inline Schedule schedule_injective(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 7569bb0..7fbe7eb 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -24,12 +24,12 @@ #ifndef TOPI_NN_H_ #define TOPI_NN_H_ -#include #include +#include #include +#include #include #include -#include #include #include @@ -62,43 +62,38 @@ tvm::PrimExpr Map(const tvm::Array& exprs, T op) { * \return A Tensor whose op member is the relu operation */ template -inline tvm::te::Tensor relu(const tvm::te::Tensor& t, - T threshold = static_cast(0), - std::string name = "T_relu", - std::string tag = kElementWise) { +inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast(0), + std::string name = "T_relu", std::string tag = kElementWise) { return tvm::te::compute( t->shape, [&](const tvm::Array& i) { auto threshold_const = tvm::tir::make_const(t->dtype, threshold); return tvm::max(t(i), threshold_const); }, - name, - tag); + name, tag); } /*! -* \brief Creates an operation that performs a leaky rectified linear unit -* -* \param t The input tensor -* \param alpha The slope for the small gradient when t < 0 -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the leaky relu operation -*/ -inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, - double alpha = 0.1, - std::string name = "T_leaky_relu", - std::string tag = kElementWise) { + * \brief Creates an operation that performs a leaky rectified linear unit + * + * \param t The input tensor + * \param alpha The slope for the small gradient when t < 0 + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the leaky relu operation + */ +inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, + std::string name = "T_leaky_relu", + std::string tag = kElementWise) { return tvm::te::compute( - t->shape, - [&](const tvm::Array& i) { - auto value = t(i); - auto calpha = tvm::tir::make_const(value.dtype(), alpha); - return tvm::tir::SelectNode::make(value > 0, value, value * calpha); - }, - name, - tag); + t->shape, + [&](const tvm::Array& i) { + auto value = t(i); + auto calpha = tvm::tir::make_const(value.dtype(), alpha); + return tvm::tir::SelectNode::make(value > 0, value, value * calpha); + }, + name, tag); } /*! @@ -112,27 +107,20 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, * * \return A Tensor whose op member is the parametric relu operation */ -inline tvm::te::Tensor prelu(const tvm::te::Tensor &x, - const tvm::te::Tensor &slope, - const int axis = 1, - std::string name = "T_prelu", - std::string tag = kBroadcast) { - CHECK((size_t)axis < x->shape.size()) << - "Wrong axis (" << axis << ")value. "; - CHECK(topi::detail::GetConstInt(slope->shape[0]) == - topi::detail::GetConstInt(x->shape[axis])) - << "Wrong slope shape received."; +inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope, + const int axis = 1, std::string name = "T_prelu", + std::string tag = kBroadcast) { + CHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. "; + CHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis])) + << "Wrong slope shape received."; - return tvm::te::compute(x->shape, - [&](const tvm::Array &indices) { - auto xval = x(indices); - return tvm::tir::SelectNode::make( - xval > 0, - xval, - xval * slope(indices[axis])); - }, - name, - tag); + return tvm::te::compute( + x->shape, + [&](const tvm::Array& indices) { + auto xval = x(indices); + return tvm::tir::SelectNode::make(xval > 0, xval, xval * slope(indices[axis])); + }, + name, tag); } /*! @@ -172,13 +160,10 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor &x, * * */ -inline tvm::te::Tensor pad(const tvm::te::Tensor& t, - const tvm::Array& pad_before, - tvm::Array pad_after = tvm::Array(), - PrimExpr pad_value = PrimExpr(), - std::string name = "T_pad", - std::string tag = kElementWise, - std::string pad_mode = "constant") { +inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array& pad_before, + tvm::Array pad_after = tvm::Array(), + PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", + std::string tag = kElementWise, std::string pad_mode = "constant") { if (pad_after.size() < pad_before.size()) { for (size_t i = pad_after.size(); i < pad_before.size(); ++i) { pad_after.push_back(pad_before[i]); @@ -190,10 +175,10 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, tvm::Array output_shape; tvm::Array pad_before_int32; tvm::Array pad_after_int32; - for (const auto &ele : pad_before) { + for (const auto& ele : pad_before) { pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); } - for (const auto &ele : pad_after) { + for (const auto& ele : pad_after) { pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); } for (size_t i = 0; i < t->shape.size(); ++i) { @@ -228,28 +213,23 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); } if (pad_mode == "edge") { - pad_idx.push_back(tvm::if_then_else( - ovars[i] < pad_before[i], - 0, - tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], - t->shape[i] - 1, - ovars[i] - pad_before[i]))); + pad_idx.push_back( + tvm::if_then_else(ovars[i] < pad_before[i], 0, + tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], + t->shape[i] - 1, ovars[i] - pad_before[i]))); } else if (pad_mode == "reflect") { - pad_idx.push_back(tvm::if_then_else( - ovars[i] < pad_before[i], - pad_before[i] - ovars[i], - tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], - t->shape[i] * 2 - ovars[i] + pad_before[i] - 2, - ovars[i] - pad_before[i]))); + pad_idx.push_back( + tvm::if_then_else(ovars[i] < pad_before[i], pad_before[i] - ovars[i], + tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], + t->shape[i] * 2 - ovars[i] + pad_before[i] - 2, + ovars[i] - pad_before[i]))); } } if (sel.size() != 0) { if (pad_mode == "constant") { - return tvm::if_then_else( - detail::Map(sel, tvm::tir::AndNode::make), t(indices), pad_value); + return tvm::if_then_else(detail::Map(sel, tvm::tir::AndNode::make), t(indices), pad_value); } else if (pad_mode == "edge" || pad_mode == "reflect") { - return tvm::if_then_else( - detail::Map(sel, tvm::tir::AndNode::make), t(indices), t(pad_idx)); + return tvm::if_then_else(detail::Map(sel, tvm::tir::AndNode::make), t(indices), t(pad_idx)); } } return t(indices); @@ -277,34 +257,27 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, * \return A Tensor whose op member is the 2-D convolution operation (NCHW * layout) */ -inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_conv2d_nchw", - std::string tag = kConv2dNCHW) { +inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, + std::string name = "T_conv2d_nchw", + std::string tag = kConv2dNCHW) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; tvm::Array output_shape{ - I->shape[0], // B - W->shape[0], // O - indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H - indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W + I->shape[0], // B + W->shape[0], // O + indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H + indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); - auto T = (pad_h == 0 && pad_w == 0) - ? I - : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); + auto T = + (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { - return tvm::sum( - T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), - {i, kh, kw}); + return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } @@ -328,14 +301,10 @@ inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, * \return A Tensor whose op member is the 2-D convolution operation * (HWCN layout) */ -inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_conv2d_hwcn", - std::string tag = kConv2dHWCN) { +inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, + std::string name = "T_conv2d_hwcn", + std::string tag = kConv2dHWCN) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; @@ -343,22 +312,19 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, tvm::Array output_shape{ indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W - I->shape[2], // B - W->shape[3] // O + I->shape[2], // B + W->shape[3] // O }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w}); auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { - return tvm::sum( - T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), - {i, kh, kw}); + return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } - /*! * \brief Creates an operation that performs a 2-D depthwise convolution with * an NCHW-layout @@ -379,67 +345,59 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, * \return A Tensor whose op member is the 2-D depthwise convolution operation * (NCHW layout) */ -inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_depthwise_conv2d_nchw", - std::string tag = kDepthwiseConv2dNCHW) { +inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, + int stride_w = 1, + std::string name = "T_depthwise_conv2d_nchw", + std::string tag = kDepthwiseConv2dNCHW) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; auto pCM = W->shape[1]; // channel_multiplier tvm::Array output_shape{ - I->shape[0], // B - W->shape[1], // O + I->shape[0], // B + W->shape[1], // O indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); - auto T = (pad_h == 0 && pad_w == 0) - ? I - : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); + auto T = + (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) * - W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw), + W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } -inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_depthwise_conv2d_nhwc", - std::string tag = kDepthwiseConv2dNHWC) { +inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, + int stride_w = 1, + std::string name = "T_depthwise_conv2d_nhwc", + std::string tag = kDepthwiseConv2dNHWC) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[1]; auto pW = I->shape[2]; auto pCM = W->shape[1]; // channel_multiplier tvm::Array output_shape{ - I->shape[0], // B + I->shape[0], // B indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H - indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W - W->shape[3], // O + indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W + W->shape[3], // O }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); - auto T = (pad_h == 0 && pad_w == 0) - ? I - : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)}); + auto T = + (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)}); auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) { return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) * - W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)), + W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)), {kh, kw, i}); }; return tvm::te::compute(output_shape, l, name, tag); @@ -465,22 +423,19 @@ inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, * \return A Tensor whose op member is the 2-D groupconvolution operation * (NCHW layout) */ -inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_group_conv2d_ngchw", - std::string tag = kGroupConv2d) { +inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, + int stride_w = 1, + std::string name = "T_group_conv2d_ngchw", + std::string tag = kGroupConv2d) { CHECK_EQ(5, I->shape.size()); CHECK_EQ(5, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; tvm::Array output_shape{ - I->shape[0], // B - I->shape[1], // G - W->shape[2], // O + I->shape[0], // B + I->shape[1], // G + W->shape[2], // O indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W }; @@ -497,9 +452,8 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, tvm::tir::Var o = args[2]; tvm::tir::Var h = args[3]; tvm::tir::Var w = args[4]; - return tvm::sum( - I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw), - {i, kh, kw}); + return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw), + {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } diff --git a/topi/include/topi/nn/batch_matmul.h b/topi/include/topi/nn/batch_matmul.h index 12075e6..80525c4 100644 --- a/topi/include/topi/nn/batch_matmul.h +++ b/topi/include/topi/nn/batch_matmul.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_BATCH_MATMUL_H_ #define TOPI_NN_BATCH_MATMUL_H_ -#include #include +#include #include @@ -35,15 +35,14 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Creates an operation that calculates matrix multiplication in batch. -* -* \param x Tensor with shape [batch, M, K] -* \param y Tensor with shape [batch, N, K] -* -* \return Tensor with shape [batch, M, N] -*/ -inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, - const tvm::te::Tensor& y) { + * \brief Creates an operation that calculates matrix multiplication in batch. + * + * \param x Tensor with shape [batch, M, K] + * \param y Tensor with shape [batch, N, K] + * + * \return Tensor with shape [batch, M, N] + */ +inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, const tvm::te::Tensor& y) { CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data"; CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data"; @@ -54,10 +53,8 @@ inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, auto k = tvm::te::reduce_axis(Range(0, K), "k"); auto result = tvm::te::compute( - { batch, M, N }, - [&](Var b, Var i, Var j) { - return tvm::sum(x(b, i, k) * y(b, j, k), { k }); - }, "tensor", "batch_matmul"); + {batch, M, N}, [&](Var b, Var i, Var j) { return tvm::sum(x(b, i, k) * y(b, j, k), {k}); }, + "tensor", "batch_matmul"); return result; } diff --git a/topi/include/topi/nn/bias_add.h b/topi/include/topi/nn/bias_add.h index 209c30c..18e95de 100644 --- a/topi/include/topi/nn/bias_add.h +++ b/topi/include/topi/nn/bias_add.h @@ -24,10 +24,10 @@ #ifndef TOPI_NN_BIAS_ADD_H_ #define TOPI_NN_BIAS_ADD_H_ -#include -#include #include +#include #include +#include #include @@ -35,16 +35,15 @@ namespace topi { namespace nn { /*! -* \brief Creates an operation that calculates data + bias -* -* \param data Tensor with shape [batch, in_dim] -* \param bias Tensor with shape [batch]. -* \param axis The axis to add the bias to. -* \return Tensor with shape [batch, in_dim] -*/ -inline tvm::te::Tensor bias_add(const tvm::te::Tensor& data, - const tvm::te::Tensor& bias, - int axis) { + * \brief Creates an operation that calculates data + bias + * + * \param data Tensor with shape [batch, in_dim] + * \param bias Tensor with shape [batch]. + * \param axis The axis to add the bias to. + * \return Tensor with shape [batch, in_dim] + */ +inline tvm::te::Tensor bias_add(const tvm::te::Tensor& data, const tvm::te::Tensor& bias, + int axis) { int data_ndim = data->shape.size(); if (axis < 0) { axis += data_ndim; diff --git a/topi/include/topi/nn/bnn.h b/topi/include/topi/nn/bnn.h index c69fc54..c0626cd 100644 --- a/topi/include/topi/nn/bnn.h +++ b/topi/include/topi/nn/bnn.h @@ -24,10 +24,10 @@ #ifndef TOPI_NN_BNN_H_ #define TOPI_NN_BNN_H_ -#include -#include -#include #include +#include +#include +#include #include @@ -37,71 +37,67 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Binarization and bit-packing along a certain axis. -* -* \param data N-D tensor, can be any layout -* \param axis The axis along which to do binarization and bit-packing. This axis -* must have a size equal to an integer multiple of 32. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return Output tensor with dtype uint32 -*/ -inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, - int axis, - std::string name = "PackedInput", - std::string tag = "binarize_pack") { + * \brief Binarization and bit-packing along a certain axis. + * + * \param data N-D tensor, can be any layout + * \param axis The axis along which to do binarization and bit-packing. This axis + * must have a size equal to an integer multiple of 32. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return Output tensor with dtype uint32 + */ +inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, + std::string name = "PackedInput", + std::string tag = "binarize_pack") { auto ishape = data->shape; CHECK_EQ(GetConstInt(ishape[axis]) % 32, 0) - << "binarize_pack: axis size must be a multiple of 32"; + << "binarize_pack: axis size must be a multiple of 32"; arith::Analyzer analyzer; auto n = ishape.size(); Array oshape; for (size_t i = 0; i < n; ++i) { - oshape.push_back(i == static_cast(axis) ? - analyzer.Simplify(indexdiv(ishape[i], 32)) : - ishape[i]); + oshape.push_back(i == static_cast(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32)) + : ishape[i]); } return tvm::te::compute( - oshape, - [&](const Array& indices) { - Array start_idx; - for (size_t i = 0; i < n; ++i) { - start_idx.push_back(i == static_cast(axis) ? - indices[i] * 32 : - static_cast(indices[i])); - } - auto packed = make_const(DataType::UInt(32), 0); - for (size_t j = 0; j < 32; ++j) { - Array idx; + oshape, + [&](const Array& indices) { + Array start_idx; for (size_t i = 0; i < n; ++i) { - idx.push_back(i == static_cast(axis) ? - start_idx[i] + static_cast(j) : - start_idx[i]); + start_idx.push_back(i == static_cast(axis) ? indices[i] * 32 + : static_cast(indices[i])); } - auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0); - packed = (packed | sign); - if (j == 31) { - return packed; + auto packed = make_const(DataType::UInt(32), 0); + for (size_t j = 0; j < 32; ++j) { + Array idx; + for (size_t i = 0; i < n; ++i) { + idx.push_back(i == static_cast(axis) ? start_idx[i] + static_cast(j) + : start_idx[i]); + } + auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0); + packed = (packed | sign); + if (j == 31) { + return packed; + } + packed = packed << 1; } - packed = packed << 1; - } - return packed; // never reached, but suppress compiler warning - }, name, tag); + return packed; // never reached, but suppress compiler warning + }, + name, tag); } /*! -* \brief Binary matrix multiplication using xor and bit-count -* -* \param data Tensor with shape [batch, in_dim], dtype is uint32 -* \param weight Tensor with shape [out_dim, in_dim], dtype is uint32 -* -* \return Tensor with shape [batch, out_dim], dtype is float32 -*/ -inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, - const tvm::te::Tensor& weight) { + * \brief Binary matrix multiplication using xor and bit-count + * + * \param data Tensor with shape [batch, in_dim], dtype is uint32 + * \param weight Tensor with shape [out_dim, in_dim], dtype is uint32 + * + * \return Tensor with shape [batch, out_dim], dtype is float32 + */ +inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) { CHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight"; CHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data"; @@ -113,16 +109,13 @@ inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, auto k = tvm::te::reduce_axis(Range(0, in_dim), "k"); auto matmul = tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return tvm::sum(popcount(data(i, k) ^ weight(j, k)), { k }); - }, "tensor", "binary_dense"); + {batch, out_dim}, + [&](Var i, Var j) { return tvm::sum(popcount(data(i, k) ^ weight(j, k)), {k}); }, "tensor", + "binary_dense"); return tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return 32 * in_dim - 2.0f * matmul(i, j); - }, "tensor", kElementWise); + {batch, out_dim}, [&](Var i, Var j) { return 32 * in_dim - 2.0f * matmul(i, j); }, "tensor", + kElementWise); } } // namespace nn diff --git a/topi/include/topi/nn/dense.h b/topi/include/topi/nn/dense.h index 57f071a..4ee36c2 100644 --- a/topi/include/topi/nn/dense.h +++ b/topi/include/topi/nn/dense.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_DENSE_H_ #define TOPI_NN_DENSE_H_ -#include #include +#include #include @@ -35,19 +35,17 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Creates an operation that calculates data * weight^T + bias -* -* \param data Tensor with shape [batch, in_dim] -* \param weight Tensor with shape [out_dim, in_dim] -* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() -* \param out_dtype Output data type. Used for mixed precision. -* -* \return Tensor with shape [batch, out_dim] -*/ -inline tvm::te::Tensor dense(const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { + * \brief Creates an operation that calculates data * weight^T + bias + * + * \param data Tensor with shape [batch, in_dim] + * \param weight Tensor with shape [out_dim, in_dim] + * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() + * \param out_dtype Output data type. Used for mixed precision. + * + * \return Tensor with shape [batch, out_dim] + */ +inline tvm::te::Tensor dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight, + const tvm::te::Tensor& bias, const DataType& out_dtype) { CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { @@ -60,18 +58,17 @@ inline tvm::te::Tensor dense(const tvm::te::Tensor& data, auto k = tvm::te::reduce_axis(Range(0, in_dim), "k"); auto matmul = tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return tvm::sum(tvm::cast(out_dtype, data(i, k)) * - tvm::cast(out_dtype, weight(j, k)), { k }); - }, "tensor", "dense"); + {batch, out_dim}, + [&](Var i, Var j) { + return tvm::sum(tvm::cast(out_dtype, data(i, k)) * tvm::cast(out_dtype, weight(j, k)), {k}); + }, + "tensor", "dense"); if (bias.defined()) { matmul = tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return matmul(i, j) + tvm::cast(out_dtype, bias(j)); - }, "tensor", kBroadcast); + {batch, out_dim}, + [&](Var i, Var j) { return matmul(i, j) + tvm::cast(out_dtype, bias(j)); }, "tensor", + kBroadcast); } return matmul; diff --git a/topi/include/topi/nn/dilate.h b/topi/include/topi/nn/dilate.h index 32ee139..0d3ab89 100644 --- a/topi/include/topi/nn/dilate.h +++ b/topi/include/topi/nn/dilate.h @@ -24,9 +24,9 @@ #ifndef TOPI_NN_DILATE_H_ #define TOPI_NN_DILATE_H_ -#include -#include #include +#include +#include #include @@ -36,13 +36,13 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Create a new expression of the logical and of all -* conditions in the arguments. -* -* \param args The arguments to find the logical conjunction of -* -* \return The logical conjunction expression -*/ + * \brief Create a new expression of the logical and of all + * conditions in the arguments. + * + * \param args The arguments to find the logical conjunction of + * + * \return The logical conjunction expression + */ PrimExpr all(Array args) { CHECK_GT(args.size(), 0) << "all requires at least one argument"; @@ -54,53 +54,50 @@ PrimExpr all(Array args) { } /*! -* \brief Dilate data with zeros -* -* \param x The input tensor, this can have any number of -* dimensions and any layout. -* \param strides Dilation stride for each dimension. Stride 1 -* means no dilation. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return The output tensor. -*/ -inline Tensor dilate(const Tensor& x, - Array strides, - std::string name = "tensor", + * \brief Dilate data with zeros + * + * \param x The input tensor, this can have any number of + * dimensions and any layout. + * \param strides Dilation stride for each dimension. Stride 1 + * means no dilation. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The output tensor. + */ +inline Tensor dilate(const Tensor& x, Array strides, std::string name = "tensor", std::string tag = kInjective) { auto n = x->shape.size(); - CHECK_EQ(n, strides.size()) - << "strides size (" << strides.size() - << ") must match dimension of x (" << n << ")"; + CHECK_EQ(n, strides.size()) << "strides size (" << strides.size() + << ") must match dimension of x (" << n << ")"; Array out_shape; arith::Analyzer analyzer; for (size_t i = 0; i < n; ++i) { - out_shape.push_back(analyzer.Simplify( - (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1))); + out_shape.push_back( + analyzer.Simplify((x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1))); } return tvm::te::compute( - out_shape, - [&](const Array& indices) { - Array not_zero; - Array index_tuple; - for (size_t i = 0; i < n; ++i) { - if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) { - index_tuple.push_back(indices[i]); - } else { - index_tuple.push_back(indexdiv(indices[i], strides[i])); - not_zero.push_back((indexmod(indices[i], strides[i])) == 0); + out_shape, + [&](const Array& indices) { + Array not_zero; + Array index_tuple; + for (size_t i = 0; i < n; ++i) { + if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) { + index_tuple.push_back(indices[i]); + } else { + index_tuple.push_back(indexdiv(indices[i], strides[i])); + not_zero.push_back((indexmod(indices[i], strides[i])) == 0); + } + } + if (not_zero.size() > 0) { + auto all_not_zero = all(not_zero); + return tvm::if_then_else(all_not_zero, x(index_tuple), make_const(x->dtype, 0)); } - } - if (not_zero.size() > 0) { - auto all_not_zero = all(not_zero); - return tvm::if_then_else( - all_not_zero, x(index_tuple), make_const(x->dtype, 0)); - } - return x(index_tuple); - }, name, tag); + return x(index_tuple); + }, + name, tag); } } // namespace nn diff --git a/topi/include/topi/nn/flatten.h b/topi/include/topi/nn/flatten.h index 81cef2e..1ac5de4 100644 --- a/topi/include/topi/nn/flatten.h +++ b/topi/include/topi/nn/flatten.h @@ -24,9 +24,9 @@ #ifndef TOPI_NN_FLATTEN_H_ #define TOPI_NN_FLATTEN_H_ -#include -#include #include +#include +#include #include #include @@ -37,25 +37,23 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. -* This requires the input tensor to have constant sized dimensions. -* -* \param x The input tensor. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A 2-D tensor. -*/ -inline Tensor flatten(const Tensor& x, - std::string name = "tensor", - std::string tag = kInjective) { + * \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. + * This requires the input tensor to have constant sized dimensions. + * + * \param x The input tensor. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A 2-D tensor. + */ +inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string tag = kInjective) { auto ishape = x->shape; PrimExpr dim = 1; for (size_t i = 1; i < ishape.size(); ++i) { dim = dim * ishape[i]; } - Array oshape({ ishape[0], dim }); + Array oshape({ishape[0], dim}); std::vector extra_shape; for (size_t i = 1; i < ishape.size(); ++i) { @@ -64,17 +62,19 @@ inline Tensor flatten(const Tensor& x, std::reverse(extra_shape.begin(), extra_shape.end()); return tvm::te::compute( - oshape, [&](Var i, Var j) { - PrimExpr idx = j; - std::vector index; - for (auto s : extra_shape) { - index.push_back(indexmod(idx, s)); - idx = indexdiv(idx, s); - } - index.push_back(i); - std::reverse(index.begin(), index.end()); - return x(index); - }, name, tag); + oshape, + [&](Var i, Var j) { + PrimExpr idx = j; + std::vector index; + for (auto s : extra_shape) { + index.push_back(indexmod(idx, s)); + idx = indexdiv(idx, s); + } + index.push_back(i); + std::reverse(index.begin(), index.end()); + return x(index); + }, + name, tag); } } // namespace nn diff --git a/topi/include/topi/nn/local_response_norm.h b/topi/include/topi/nn/local_response_norm.h index 14dec39..4e8dfd9 100644 --- a/topi/include/topi/nn/local_response_norm.h +++ b/topi/include/topi/nn/local_response_norm.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_LOCAL_RESPONSE_NORM_H_ #define TOPI_NN_LOCAL_RESPONSE_NORM_H_ -#include #include +#include #include @@ -35,60 +35,45 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Local response normalization inference operator -* -* \param data The input tensor. 4-D shape NCHW or NHWC -* \param size Integer to define normalisation window size -* \param axis Input data layout channel axis -* \param alpha Float scaling factor -* \param beta Exponent value -* \param bias Offset to avoid dividing by zero -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the Local response normalization operation -*/ -inline Tensor lrn(const Tensor& data, - int size, - int axis = 1, - float alpha = 0.0001, - float beta = 0.75, - float bias = 2, - std::string name = "tensor", + * \brief Local response normalization inference operator + * + * \param data The input tensor. 4-D shape NCHW or NHWC + * \param size Integer to define normalisation window size + * \param axis Input data layout channel axis + * \param alpha Float scaling factor + * \param beta Exponent value + * \param bias Offset to avoid dividing by zero + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the Local response normalization operation + */ +inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.0001, + float beta = 0.75, float bias = 2, std::string name = "tensor", std::string tag = kBroadcast) { CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input"; CHECK_EQ(size % 2, 1) << "size should be odd number"; CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; auto input_shape = data->shape; - Array pad_before{ 0, 0, 0, 0}; - Array pad_after{ 0, 0, 0, 0}; - pad_before.Set(axis, static_cast(size/2)); - pad_after.Set(axis, static_cast(size/2)); + Array pad_before{0, 0, 0, 0}; + Array pad_after{0, 0, 0, 0}; + pad_before.Set(axis, static_cast(size / 2)); + pad_after.Set(axis, static_cast(size / 2)); auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data"); auto rxs = tvm::te::reduce_axis(Range(0, size), "rxs"); Tensor sqr_sum; if (axis == 1) { - sqr_sum = tvm::te::compute(input_shape, - [&](Var i, Var l, Var j, Var k) { - return tvm::sum(pad_data(i, l + rxs, j, k) * - pad_data(i, l + rxs, j, k), - {rxs}); - }); + sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) { + return tvm::sum(pad_data(i, l + rxs, j, k) * pad_data(i, l + rxs, j, k), {rxs}); + }); } else if (axis == 3) { - sqr_sum = tvm::te::compute(input_shape, - [&](Var i, Var l, Var j, Var k) { - return tvm::sum(pad_data(i, l, j, k + rxs) * - pad_data(i, l, j, k + rxs), - {rxs}); - }); + sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) { + return tvm::sum(pad_data(i, l, j, k + rxs) * pad_data(i, l, j, k + rxs), {rxs}); + }); } - auto sqrt_sum_up = tvm::te::compute( - input_shape, - [&](Var i, Var j, Var k, Var l) { - return tvm::pow(bias + - (div(alpha * sqr_sum(i, j, k, l), size)), - beta); - }); + auto sqrt_sum_up = tvm::te::compute(input_shape, [&](Var i, Var j, Var k, Var l) { + return tvm::pow(bias + (div(alpha * sqr_sum(i, j, k, l), size)), beta); + }); return topi::divide(data, sqrt_sum_up); } } // namespace nn diff --git a/topi/include/topi/nn/mapping.h b/topi/include/topi/nn/mapping.h index 17d1404..d4a3a47 100644 --- a/topi/include/topi/nn/mapping.h +++ b/topi/include/topi/nn/mapping.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_MAPPING_H_ #define TOPI_NN_MAPPING_H_ -#include #include +#include #include @@ -35,49 +35,39 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Scale and shift with NCHW order -* -* \param x The input tensor. -* \param scale Scale tensor, 1-D of size channel -* \param shift Shift tensor, 1-D of size channel -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the scale shift operation -*/ -inline Tensor scale_shift_nchw(const Tensor& x, - const Tensor& scale, - const Tensor& shift, - std::string name = "ScaleShift", - std::string tag = kBroadcast) { + * \brief Scale and shift with NCHW order + * + * \param x The input tensor. + * \param scale Scale tensor, 1-D of size channel + * \param shift Shift tensor, 1-D of size channel + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the scale shift operation + */ +inline Tensor scale_shift_nchw(const Tensor& x, const Tensor& scale, const Tensor& shift, + std::string name = "ScaleShift", std::string tag = kBroadcast) { return tvm::te::compute( - x->shape, - [&](Var b, Var c, Var h, Var w) { - return x(b, c, h, w) * scale(c) + shift(w); - }, name, tag); + x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(w); }, + name, tag); } /*! -* \brief Scale and shift with NHWC order -* -* \param x The input tensor. -* \param scale Scale tensor, 1-D of size channel -* \param shift Shift tensor, 1-D of size channel -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the scale shift operation -*/ -inline Tensor scale_shift_nhwc(const Tensor& x, - const Tensor& scale, - const Tensor& shift, - std::string name = "ScaleShift", - std::string tag = kBroadcast) { + * \brief Scale and shift with NHWC order + * + * \param x The input tensor. + * \param scale Scale tensor, 1-D of size channel + * \param shift Shift tensor, 1-D of size channel + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the scale shift operation + */ +inline Tensor scale_shift_nhwc(const Tensor& x, const Tensor& scale, const Tensor& shift, + std::string name = "ScaleShift", std::string tag = kBroadcast) { return tvm::te::compute( - x->shape, - [&](Var b, Var h, Var w, Var c) { - return x(b, h, w, c) * scale(c) + shift(w); - }, name, tag); + x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(w); }, + name, tag); } } // namespace nn diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index 324ecad..ffc4f98 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -45,31 +45,25 @@ enum PoolType : int { kMaxPool, }; - /*! -* \brief Perform pooling on height and width dimension of data. -* -* \param x The input tensor -* \param kernel_size Vector of two ints: {kernel_height, kernel_width} -* \param stride_size Vector of two ints: {stride_height, stride_width} -* \param padding_size Vector of two ints: {padding_height, padding_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param height_axis index of the height dimension -* \param width_axis index of the width dimension -* \param count_include_pad Whether include padding in the calculation -* -* \return The output tensor in same layout order -*/ -inline Tensor pool_impl(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const size_t height_axis, - const size_t width_axis, - bool count_include_pad) { + * \brief Perform pooling on height and width dimension of data. + * + * \param x The input tensor + * \param kernel_size Vector of two ints: {kernel_height, kernel_width} + * \param stride_size Vector of two ints: {stride_height, stride_width} + * \param padding_size Vector of two ints: {padding_height, padding_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param height_axis index of the height dimension + * \param width_axis index of the width dimension + * \param count_include_pad Whether include padding in the calculation + * + * \return The output tensor in same layout order + */ +inline Tensor pool_impl(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const size_t height_axis, + const size_t width_axis, bool count_include_pad) { CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; @@ -103,10 +97,10 @@ inline Tensor pool_impl(const Tensor& x, pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); arith::Analyzer analyzer; - auto out_height = analyzer.Simplify( - indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1); - auto out_width = analyzer.Simplify( - indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); + auto out_height = + analyzer.Simplify(indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1); + auto out_width = + analyzer.Simplify(indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); auto dheight = tvm::te::reduce_axis(Range(0, kernel_height)); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width)); @@ -115,69 +109,72 @@ inline Tensor pool_impl(const Tensor& x, out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); - const int64_t *padding_h0 = as_const_int(pad_top); - const int64_t *padding_w0 = as_const_int(pad_left); - const int64_t *padding_h1 = as_const_int(pad_bottom); - const int64_t *padding_w1 = as_const_int(pad_right); + const int64_t* padding_h0 = as_const_int(pad_top); + const int64_t* padding_w0 = as_const_int(pad_left); + const int64_t* padding_h1 = as_const_int(pad_bottom); + const int64_t* padding_w1 = as_const_int(pad_right); const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) || ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); if (pool_type == kMaxPool) { - auto temp = do_pad ? pad( - x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - indices.Set(height_axis, output[height_axis] * stride_height + dheight); - indices.Set(width_axis, output[width_axis] * stride_width + dwidth); - return tvm::max(temp(indices), { dheight, dwidth }); - }, "tensor", "pool_max"); + auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + indices.Set(height_axis, output[height_axis] * stride_height + dheight); + indices.Set(width_axis, output[width_axis] * stride_width + dwidth); + return tvm::max(temp(indices), {dheight, dwidth}); + }, + "tensor", "pool_max"); } else if (pool_type == kAvgPool) { // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; // TVM compute for summing the pooling window. - auto pool_sum = tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - indices.Set(height_axis, output[height_axis] * stride_height + dheight); - indices.Set(width_axis, output[width_axis] * stride_width + dwidth); - return tvm::sum(temp(indices), { dheight, dwidth }); - }, "tensor", "pool_sum"); + auto pool_sum = tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + indices.Set(height_axis, output[height_axis] * stride_height + dheight); + indices.Set(width_axis, output[width_axis] * stride_width + dwidth); + return tvm::sum(temp(indices), {dheight, dwidth}); + }, + "tensor", "pool_sum"); // TVM compute for dividing the reduced window sum by kernel size. - return tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - if (count_include_pad) { - return div(pool_sum(indices), (kernel_height * kernel_width)); - } else { - PrimExpr h_start = output[height_axis] * stride_height - pad_top; - PrimExpr w_start = output[width_axis] * stride_width - pad_left; - PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height); - PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); - h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0)); - w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0)); - PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start), - make_const(DataType::DataType::Int(32), 1)); - return div(pool_sum(indices), divide_factor); - } - }, "tensor", kElementWise); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + if (count_include_pad) { + return div(pool_sum(indices), (kernel_height * kernel_width)); + } else { + PrimExpr h_start = output[height_axis] * stride_height - pad_top; + PrimExpr w_start = output[width_axis] * stride_width - pad_left; + PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height); + PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); + h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0)); + w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0)); + PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start), + make_const(DataType::DataType::Int(32), 1)); + return div(pool_sum(indices), divide_factor); + } + }, + "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x; } } -inline Tensor pool_grad_impl(const Tensor& out_grad, - const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, bool ceil_mode, - const size_t height_axis, const size_t width_axis, +inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, + const Array& kernel_size, const Array& stride_size, + const Array& padding_size, PoolType pool_type, + bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad) { CHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)"; CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; @@ -237,38 +234,35 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom); ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right); - auto windowh = tvm::te::reduce_axis( - Range(0, (kernel_height + stride_height - 1) / stride_height)); - auto windoww = tvm::te::reduce_axis( - Range(0, (kernel_width + stride_width - 1) / stride_width)); + auto windowh = + tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height)); + auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width)); auto argmax = MakeArgmaxReducer(); - auto pad_x = do_pad ? pad( - x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - - auto mp_argmax = - tvm::te::compute( - out_shape, - [&](const Array& inds) { - Array window_inds{inds.begin(), inds.end()}; - window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight); - window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth); - auto idx = detail::RavelIndex(window_inds, ravel_shape); - return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr); - }, - "maxpool_grad_argmax", kCommReduceIdx); + auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + + auto mp_argmax = tvm::te::compute( + out_shape, + [&](const Array& inds) { + Array window_inds{inds.begin(), inds.end()}; + window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight); + window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth); + auto idx = detail::RavelIndex(window_inds, ravel_shape); + return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr); + }, + "maxpool_grad_argmax", kCommReduceIdx); auto mp_inds = mp_argmax[0]; return tvm::te::compute( x->shape, [&](const Array& inds) { - Array pad_inds {inds.begin(), inds.end()}; + Array pad_inds{inds.begin(), inds.end()}; pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top); pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left); auto idx = detail::RavelIndex(pad_inds, ravel_shape); - Array out_idx {inds.begin(), inds.end()}; + Array out_idx{inds.begin(), inds.end()}; out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh); out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); @@ -280,19 +274,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, (pad_inds[width_axis] - kernel_width) / stride_width + 1); return tvm::sum( - tvm::if_then_else(tir::AndNode::make( - tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, - out_idx[width_axis] >= out_idx_lower_w), - mp_inds(out_idx) == idx), + tvm::if_then_else( + tir::AndNode::make(tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, + out_idx[width_axis] >= out_idx_lower_w), + mp_inds(out_idx) == idx), out_grad(out_idx), make_const(x->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_max"); } else if (pool_type == kAvgPool) { - auto windowh = tvm::te::reduce_axis( - Range(0, (kernel_height + stride_height - 1) / stride_height)); - auto windoww = tvm::te::reduce_axis( - Range(0, (kernel_width + stride_width - 1) / stride_width)); + auto windowh = + tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height)); + auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width)); return tvm::te::compute( x->shape, [&](const Array& inds) { @@ -304,12 +297,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh)); out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); - PrimExpr out_idx_lower_h = tir::SelectNode::make( - pad_h_idx < kernel_height, make_const(DataType::Int(32), 0), - (pad_h_idx - kernel_height) / stride_height + 1); - PrimExpr out_idx_lower_w = tir::SelectNode::make( - pad_w_idx < kernel_width, make_const(DataType::Int(32), 0), - (pad_w_idx - kernel_width) / stride_width + 1); + PrimExpr out_idx_lower_h = + tir::SelectNode::make(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0), + (pad_h_idx - kernel_height) / stride_height + 1); + PrimExpr out_idx_lower_w = + tir::SelectNode::make(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0), + (pad_w_idx - kernel_width) / stride_width + 1); PrimExpr divide_factor; // number of pooled elements if (count_include_pad) { @@ -321,17 +314,16 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); h_start = tir::MaxNode::make(h_start, make_const(DataType::Int(32), 0)); w_start = tir::MaxNode::make(w_start, make_const(DataType::Int(32), 0)); - divide_factor = - tir::MaxNode::make((h_end - h_start) * (w_end - w_start), - make_const(DataType::Int(32), 1)); + divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start), + make_const(DataType::Int(32), 1)); } - return tvm::sum(tvm::if_then_else( - tir::AndNode::make( - tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, - out_idx[height_axis] < out_height), - tir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w, - out_idx[width_axis] < out_width)), - out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), + return tvm::sum( + tvm::if_then_else( + tir::AndNode::make(tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, + out_idx[height_axis] < out_height), + tir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w, + out_idx[width_axis] < out_width)), + out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_avg"); @@ -341,15 +333,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, } } -inline bool find_depth_height_width(const std::string& layout, - int* depth_axis, - int* height_axis, +inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis, int* width_axis) { *depth_axis = -1, *height_axis = -1, *width_axis = -1; int curr_idx = 0; for (size_t i = 0; i < layout.size(); ++i) { - if ((layout[i] >= 'A' && layout[i] <= 'Z') || - (layout[i] >= 'a' && layout[i] <= 'z')) { + if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) { if (layout[i] == 'D') { if (*depth_axis != -1) return false; *depth_axis = curr_idx; @@ -370,21 +359,18 @@ inline bool find_depth_height_width(const std::string& layout, return true; } -inline bool find_height_width(const std::string& layout, - int* height_axis, - int* width_axis) { +inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) { int dummy; - CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false); + CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false); if (*height_axis != -1 && *width_axis != -1) { return true; } return false; } -inline bool find_width(const std::string& layout, - int* width_axis) { +inline bool find_width(const std::string& layout, int* width_axis) { int dummy; - CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false); + CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false); if (*width_axis != -1) { return true; } @@ -392,48 +378,42 @@ inline bool find_width(const std::string& layout, } /*! -* \brief Perform pooling on height and width dimension of data. -* It decides the height and width dimension according to the layout string, -* in which 'W' and 'H' means width and height respectively. -* Width and height dimension cannot be split. -* For example, NCHW, NCHW16c, etc. are valid for pool, -* while NCHW16w, NCHW16h are not. -* See \a layout for more information of the layout string convention. -* \param x The input tensor. -* \param kernel_size Vector of two ints: {kernel_height, kernel_width} -* \param stride_size Vector of two ints: {stride_height, stride_width} -* \param padding_size Vector of two ints: {padding_height, padding_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCHW16c can describe a 5-D tensor of -* [batch_size, channel, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `H` and `W`, one can pass `NCHWc` as well. -* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' -* -* -* \return The output tensor in the same layout -*/ -inline Tensor pool(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::string& layout = "NCHW", + * \brief Perform pooling on height and width dimension of data. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. + * For example, NCHW, NCHW16c, etc. are valid for pool, + * while NCHW16w, NCHW16h are not. + * See \a layout for more information of the layout string convention. + * \param x The input tensor. + * \param kernel_size Vector of two ints: {kernel_height, kernel_width} + * \param stride_size Vector of two ints: {stride_height, stride_width} + * \param padding_size Vector of two ints: {padding_height, padding_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `H` and `W`, one can pass `NCHWc` as well. + * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' + * + * + * \return The output tensor in the same layout + */ +inline Tensor pool(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; - CHECK(find_height_width(layout, &height_axis, &width_axis)) - << "Unsupported layout " << layout; - return pool_impl(x, kernel_size, stride_size, padding_size, - pool_type, ceil_mode, height_axis, width_axis, - count_include_pad); + CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; + return pool_impl(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, height_axis, + width_axis, count_include_pad); } /*! @@ -476,34 +456,27 @@ inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array& output_size, - PoolType pool_type, - const std::vector& axes) { + * \brief Perform adaptive pooling on N dimensional data + * + * \param x The input tensor + * \param output_size int vector of size in each dimension + * \param pool_type The type of pooling operator + * \param axes indices of each dimension + * + * \return The output tensor in same layout order + */ +inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_size, + PoolType pool_type, const std::vector& axes) { const auto n_dim = output_size.size(); CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension"; @@ -533,32 +506,41 @@ inline Tensor adaptive_pool_impl(const Tensor& x, }; if (pool_type == kMaxPool) { - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - Array reduce_axes; - std::tie(indices, reduce_axes) = get_iter_vars(output, true); - return tvm::max(x(indices), reduce_axes); // NOLINT(*) - }, "tensor", "adaptive_pool_max"); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + Array reduce_axes; + std::tie(indices, reduce_axes) = get_iter_vars(output, true); + return tvm::max(x(indices), reduce_axes); // NOLINT(*) + }, + "tensor", "adaptive_pool_max"); } else if (pool_type == kAvgPool) { - auto pool_sum = tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - Array reduce_axes; - std::tie(indices, reduce_axes) = get_iter_vars(output, true); - return tvm::sum(x(indices), reduce_axes); - }, "tensor", "adaptive_pool_sum"); - - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - Array reduce_axes; - std::tie(indices, reduce_axes) = get_iter_vars(output, false); - - PrimExpr divide_factor = tvm::cast(x->dtype, 1); - for (size_t i = 0; i < n_dim; ++i) { - divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent); - } + auto pool_sum = tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + Array reduce_axes; + std::tie(indices, reduce_axes) = get_iter_vars(output, true); + return tvm::sum(x(indices), reduce_axes); + }, + "tensor", "adaptive_pool_sum"); - return div(pool_sum(indices), divide_factor); - }, "tensor", kElementWise); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + Array reduce_axes; + std::tie(indices, reduce_axes) = get_iter_vars(output, false); + + PrimExpr divide_factor = tvm::cast(x->dtype, 1); + for (size_t i = 0; i < n_dim; ++i) { + divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent); + } + + return div(pool_sum(indices), divide_factor); + }, + "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x; @@ -566,118 +548,107 @@ inline Tensor adaptive_pool_impl(const Tensor& x, } /*! -* \brief Adaptively perform pooling on height and width dimension of data. -* The pooling kernel and stride sizes are automatically chosen for desired output sizes. -* It decides the height and width dimension according to the layout string, -* in which 'W' and 'H' means width and height respectively. -* Width and height dimension cannot be split. -* For example, NCHW, NCHW16c, etc. are valid for pool, -* while NCHW16w, NCHW16h are not. -* See \a layout for more information of the layout string convention. -* -* \param x The input tensor -* \param output_size Vector of two ints: {output_height, output_width} -* \param pool_type The type of pooling operator -* \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCHW16c can describe a 5-D tensor of -* [batch_size, channel, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `H` and `W`, one can pass `NCHWc` as well. -* -* \return The output tensor in same layout order -*/ -inline Tensor adaptive_pool(const Tensor& x, - const Array& output_size, - PoolType pool_type, + * \brief Adaptively perform pooling on height and width dimension of data. + * The pooling kernel and stride sizes are automatically chosen for desired output sizes. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. + * For example, NCHW, NCHW16c, etc. are valid for pool, + * while NCHW16w, NCHW16h are not. + * See \a layout for more information of the layout string convention. + * + * \param x The input tensor + * \param output_size Vector of two ints: {output_height, output_width} + * \param pool_type The type of pooling operator + * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `H` and `W`, one can pass `NCHWc` as well. + * + * \return The output tensor in same layout order + */ +inline Tensor adaptive_pool(const Tensor& x, const Array& output_size, PoolType pool_type, const std::string& layout = "NCHW") { int height_axis = -1, width_axis = -1; - CHECK(find_height_width(layout, &height_axis, &width_axis)) - << "Unsupported layout " << layout; + CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis}); } /*! -* \brief Adaptively perform pooling on three dimensional data. -* See the two dimensional version above for details. -* \param x The input tensor -* \param output_size Vector of three ints: {output_depth, output_height, output_width} -* \param pool_type The type of pooling operator -* \param layout The input layout. The default is "NCDHW". -*/ -inline Tensor adaptive_pool3d(const Tensor& x, - const Array& output_size, - PoolType pool_type, - const std::string& layout = "NCDHW") { + * \brief Adaptively perform pooling on three dimensional data. + * See the two dimensional version above for details. + * \param x The input tensor + * \param output_size Vector of three ints: {output_depth, output_height, output_width} + * \param pool_type The type of pooling operator + * \param layout The input layout. The default is "NCDHW". + */ +inline Tensor adaptive_pool3d(const Tensor& x, const Array& output_size, + PoolType pool_type, const std::string& layout = "NCDHW") { int depth_axis = -1, height_axis = -1, width_axis = -1; CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) - << "Unsupported layout " << layout; + << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis}); } /*! -* \brief Perform global pooling on height and width dimension of data. -* It decides the height and width dimension according to the layout string, -* in which 'W' and 'H' means width and height respectively. -* Width and height dimension cannot be split. -* For example, NCHW, NCHW16c, ... are valid for global_pool, -* while NCHW16w, NCHW16h are not. -* See \a layout for more information of the layout string convention. -* -* \param x The input tensor represent as layout -* \param pool_type The type of pooling operator -* \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the sub-dimension. -* For example, `NCHW16c` can describe a 5-D tensor of -* [batch_size, channel, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of -* dimensions other than `H` and `W`, one can pass `NCHWc` as well. -* -* \return The output tensor in same layout with height and width dimension size of 1. -* e.g., for NCHW, the output shape will be [batch, channel, 1, 1] -*/ -inline Tensor global_pool(const Tensor& x, - PoolType pool_type, - const std::string& layout = "NCHW") { + * \brief Perform global pooling on height and width dimension of data. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. + * For example, NCHW, NCHW16c, ... are valid for global_pool, + * while NCHW16w, NCHW16h are not. + * See \a layout for more information of the layout string convention. + * + * \param x The input tensor represent as layout + * \param pool_type The type of pooling operator + * \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the sub-dimension. + * For example, `NCHW16c` can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of + * dimensions other than `H` and `W`, one can pass `NCHWc` as well. + * + * \return The output tensor in same layout with height and width dimension size of 1. + * e.g., for NCHW, the output shape will be [batch, channel, 1, 1] + */ +inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") { return adaptive_pool(x, Array{1, 1}, pool_type, layout); } /*! -* \brief Perform pooling on N-dimension of data. -* -* \param x The input tensor -* \param kernel_size Vector of N ints -* \param stride_size Vector of N ints -* \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ..., -* head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN] -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param axis Vector of indices for the N dimensions -* \param count_include_pad Whether include padding in the calculation -* -* \return The output tensor in same layout order -*/ -inline Tensor pool_impl_nd(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::vector& axis, + * \brief Perform pooling on N-dimension of data. + * + * \param x The input tensor + * \param kernel_size Vector of N ints + * \param stride_size Vector of N ints + * \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ..., + * head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN] + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param axis Vector of indices for the N dimensions + * \param count_include_pad Whether include padding in the calculation + * + * \return The output tensor in same layout order + */ +inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::vector& axis, bool count_include_pad) { int k_size = kernel_size.size(); int x_size = x->shape.size(); CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel"; CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must has double elements of" - " kernel"; + " kernel"; CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel"; Array daxis; @@ -696,8 +667,8 @@ inline Tensor pool_impl_nd(const Tensor& x, stride[i] = cast(DataType::Int(32), stride_size[i]); pad_head[i] = cast(DataType::Int(32), padding_size[i]); pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]); - const int64_t *padding0 = as_const_int(pad_head[i]); - const int64_t *padding1 = as_const_int(pad_tail[i]); + const int64_t* padding0 = as_const_int(pad_head[i]); + const int64_t* padding1 = as_const_int(pad_tail[i]); do_pad = (do_pad) ? do_pad : ((padding0 && *padding0) || (padding1 && *padding1)); if (ceil_mode) { @@ -713,69 +684,76 @@ inline Tensor pool_impl_nd(const Tensor& x, arith::Analyzer analyzer; auto out_dim = analyzer.Simplify( - indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); + indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); out_shape.Set(ii, out_dim); } if (pool_type == kMaxPool) { - auto temp = do_pad ? pad( - x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - - for (int i = 0; i < k_size; i++) { - int ii = axis[i]; - indices.Set(ii, output[ii] * stride[i] + daxis[i]); - } + auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + indices.Set(ii, output[ii] * stride[i] + daxis[i]); + } - return tvm::max(temp(indices), daxis); - }, "tensor", "pool_max"); + return tvm::max(temp(indices), daxis); + }, + "tensor", "pool_max"); } else if (pool_type == kAvgPool) { // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; // TVM compute for summing the pooling window. - auto pool_sum = tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - - for (int i = 0; i < k_size; i++) { - int ii = axis[i]; - indices.Set(ii, output[ii] * stride[i] + daxis[i]); - } - return tvm::sum(temp(indices), daxis); - }, "tensor", "pool_sum"); + auto pool_sum = tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + indices.Set(ii, output[ii] * stride[i] + daxis[i]); + } + return tvm::sum(temp(indices), daxis); + }, + "tensor", "pool_sum"); // TVM compute for dividing the reduced window sum by kernel size. - return tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - if (count_include_pad) { - auto kernel_size = make_const(DataType::Int(32), 1); - for (int i = 0; i < k_size; i++) { - kernel_size *= kernel[i]; - } - return div(pool_sum(indices), kernel_size); - } else { - std::vector start(k_size); - std::vector end(k_size); - auto kernel_size = make_const(DataType::Int(32), 1); - for (int i = 0; i < k_size; i++) { - int ii = axis[i]; - start[i] = output[ii] * stride[i] - pad_head[i]; - end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]); - start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0)); - kernel_size *= (end[i] - start[i]); - } - - PrimExpr divide_factor = tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1)); - return div(pool_sum(indices), divide_factor); - } - }, "tensor", kElementWise); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + if (count_include_pad) { + auto kernel_size = make_const(DataType::Int(32), 1); + for (int i = 0; i < k_size; i++) { + kernel_size *= kernel[i]; + } + return div(pool_sum(indices), kernel_size); + } else { + std::vector start(k_size); + std::vector end(k_size); + auto kernel_size = make_const(DataType::Int(32), 1); + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + start[i] = output[ii] * stride[i] - pad_head[i]; + end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]); + start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0)); + kernel_size *= (end[i] - start[i]); + } + + PrimExpr divide_factor = + tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1)); + return div(pool_sum(indices), divide_factor); + } + }, + "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x; @@ -783,94 +761,85 @@ inline Tensor pool_impl_nd(const Tensor& x, } /*! -* \brief Perform pooling on the width dimension of data. -* Width axis is determined by the layout string -* in which 'W' means width. -* Width dimension cannot be split. -* For example, NCW, NCW16c, etc. are valid for pool, -* while NCW16w is not. -* See \a layout for more information of the layout string convention. -* \param x The input tensor. -* \param kernel_size Vector of three ints: {kernel_width} -* \param stride_size Vector of three ints: {stride_width} -* \param padding_size Vector of six ints: {head_pad_width, tail_pad_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param layout The input layout. Pooling supports any layout as long as 'W' appears. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCW16c can describe a 4-D tensor of -* [batch_size, channel, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `W`, one can pass `NCWc` as well. -* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' -* -* -* \return The output tensor in the same layout -*/ -inline Tensor pool1d(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::string& layout = "NCW", + * \brief Perform pooling on the width dimension of data. + * Width axis is determined by the layout string + * in which 'W' means width. + * Width dimension cannot be split. + * For example, NCW, NCW16c, etc. are valid for pool, + * while NCW16w is not. + * See \a layout for more information of the layout string convention. + * \param x The input tensor. + * \param kernel_size Vector of three ints: {kernel_width} + * \param stride_size Vector of three ints: {stride_width} + * \param padding_size Vector of six ints: {head_pad_width, tail_pad_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param layout The input layout. Pooling supports any layout as long as 'W' appears. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCW16c can describe a 4-D tensor of + * [batch_size, channel, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `W`, one can pass `NCWc` as well. + * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' + * + * + * \return The output tensor in the same layout + */ +inline Tensor pool1d(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW", bool count_include_pad = true) { int width_axis = -1; - CHECK(find_width(layout, &width_axis)) - << "Unsupported layout " << layout; + CHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; std::vector axis = {width_axis}; - return pool_impl_nd(x, kernel_size, stride_size, padding_size, - pool_type, ceil_mode, axis, count_include_pad); + return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis, + count_include_pad); } /*! -* \brief Perform pooling on depth, height and width dimension of data. -* It decides the depth, height and width dimension according to the layout string, -* in which 'D', 'W' and 'H' means depth, width and height respectively. -* Depth, Width and height dimension cannot be split. -* For example, NCDHW, NCDHW16c, etc. are valid for pool, -* while NCDHW16d, NCDHW16w or NCDHW16h are not. -* See \a layout for more information of the layout string convention. -* \param x The input tensor. -* \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width} -* \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width} -* \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width, -* tail_pad_depth, tail_pad_height, tail_pad_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCDHW16c can describe a 6-D tensor of -* [batch_size, channel, depth, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `D`, `H` and `W`, one can pass `NCDHWc` as well. -* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' -* -* -* \return The output tensor in the same layout -*/ -inline Tensor pool3d(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::string& layout = "NCDHW", + * \brief Perform pooling on depth, height and width dimension of data. + * It decides the depth, height and width dimension according to the layout string, + * in which 'D', 'W' and 'H' means depth, width and height respectively. + * Depth, Width and height dimension cannot be split. + * For example, NCDHW, NCDHW16c, etc. are valid for pool, + * while NCDHW16d, NCDHW16w or NCDHW16h are not. + * See \a layout for more information of the layout string convention. + * \param x The input tensor. + * \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width} + * \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width} + * \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width, + * tail_pad_depth, tail_pad_height, tail_pad_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCDHW16c can describe a 6-D tensor of + * [batch_size, channel, depth, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `D`, `H` and `W`, one can pass `NCDHWc` as well. + * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' + * + * + * \return The output tensor in the same layout + */ +inline Tensor pool3d(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW", bool count_include_pad = true) { int depth_axis = -1, height_axis = -1, width_axis = -1; CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) - << "Unsupported layout " << layout; + << "Unsupported layout " << layout; std::vector axis = {depth_axis, height_axis, width_axis}; - return pool_impl_nd(x, kernel_size, stride_size, padding_size, - pool_type, ceil_mode, axis, count_include_pad); + return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis, + count_include_pad); } } // namespace nn diff --git a/topi/include/topi/nn/softmax.h b/topi/include/topi/nn/softmax.h index dc76a9e..9ae9d6a 100644 --- a/topi/include/topi/nn/softmax.h +++ b/topi/include/topi/nn/softmax.h @@ -24,9 +24,9 @@ #ifndef TOPI_NN_SOFTMAX_H_ #define TOPI_NN_SOFTMAX_H_ -#include #include #include +#include #include #include @@ -37,18 +37,16 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Softmax activation -* -* \param x The input tensor. Can be any dimension -* \param axis The channel axis along which softmax is performed -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the softmax operation -*/ -inline Tensor softmax(const Tensor &x, - int axis = -1, - std::string name = "tensor", + * \brief Softmax activation + * + * \param x The input tensor. Can be any dimension + * \param axis The channel axis along which softmax is performed + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the softmax operation + */ +inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor", std::string tag = "softmax_output") { auto input_shape = x->shape; auto ndim = input_shape.size(); @@ -64,8 +62,7 @@ inline Tensor softmax(const Tensor &x, tvm::Map attrs; attrs.Set("axis", Integer(axis)); - auto insert_reduce_index = [axis, ndim](const Array &indices, - const IterVar &reduce_index) { + auto insert_reduce_index = [axis, ndim](const Array& indices, const IterVar& reduce_index) { Array eval_range; int arg_counter = 0; for (size_t i = 0; i < ndim; ++i) { @@ -77,61 +74,54 @@ inline Tensor softmax(const Tensor &x, return eval_range; }; - auto get_non_reduce_indices = [axis, ndim](const Array &indices) { + auto get_non_reduce_indices = [axis, ndim](const Array& indices) { Array non_reduce_indices; for (size_t i = 0; i < ndim; ++i) { - if (static_cast(i) != axis) - non_reduce_indices.push_back(indices[i]); + if (static_cast(i) != axis) non_reduce_indices.push_back(indices[i]); } return non_reduce_indices; }; - auto _compute_max = [&](const Array &indices) { + auto _compute_max = [&](const Array& indices) { auto eval_range = insert_reduce_index(indices, k1); return topi::MaxOp(x(eval_range), {k1}); }; - auto _compute_exp = [&](const Tensor &max_elem, - const Array &indices) { + auto _compute_exp = [&](const Tensor& max_elem, const Array& indices) { auto non_reduce_indices = get_non_reduce_indices(indices); return tvm::exp(x(indices) - max_elem(non_reduce_indices)); }; - auto _compute_expsum = [&](const Tensor &exp, - const Array &indices) { + auto _compute_expsum = [&](const Tensor& exp, const Array& indices) { auto eval_range = insert_reduce_index(indices, k2); return tvm::sum(exp(eval_range), {k2}); }; - auto _normalize = [&](const Tensor &exp, const Tensor &expsum, - const Array &indices) { + auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array& indices) { auto non_reduce_indices = get_non_reduce_indices(indices); return exp(indices) / expsum(non_reduce_indices); }; auto max_elem = tvm::te::compute(reduced_shape, _compute_max); - auto exp = tvm::te::compute(input_shape, [&](const Array &indices) { - return _compute_exp(max_elem, indices); - }); - auto expsum = tvm::te::compute(reduced_shape, [&](const Array &indices) { - return _compute_expsum(exp, indices); - }); - return tvm::te::compute(input_shape, [&](const Array &indices) { - return _normalize(exp, expsum, indices); - }, name, tag, attrs); + auto exp = tvm::te::compute( + input_shape, [&](const Array& indices) { return _compute_exp(max_elem, indices); }); + auto expsum = tvm::te::compute( + reduced_shape, [&](const Array& indices) { return _compute_expsum(exp, indices); }); + return tvm::te::compute( + input_shape, [&](const Array& indices) { return _normalize(exp, expsum, indices); }, + name, tag, attrs); } /*! -* \brief Log softmax activation -* -* \param x The input tensor. 2-D where log softmax is performed along the second dimension -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the log softmax operation -*/ -inline Tensor log_softmax(const Tensor& x, - std::string name = "tensor", + * \brief Log softmax activation + * + * \param x The input tensor. 2-D where log softmax is performed along the second dimension + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the log softmax operation + */ +inline Tensor log_softmax(const Tensor& x, std::string name = "tensor", std::string tag = "log_softmax_output") { CHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input"; @@ -139,19 +129,16 @@ inline Tensor log_softmax(const Tensor& x, PrimExpr n = x->shape[1]; auto k = tvm::te::reduce_axis(Range(0, n), "k"); - auto max_elem = tvm::te::compute( - { m }, [&](Var i) { - return tvm::max(x(i, k), Array{ k }); }); + auto max_elem = + tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array{k}); }); k = tvm::te::reduce_axis(Range(0, n), "k"); - auto expsum = tvm::te::compute( - { m }, [&](Var i) { - return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), { k }); }); + auto expsum = + tvm::te::compute({m}, [&](Var i) { return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), {k}); }); return tvm::te::compute( - x->shape, [&](Var i, Var j) { - return x(i, j) - max_elem(i) - tvm::log(expsum(i)); - }, name, tag); + x->shape, [&](Var i, Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name, + tag); } } // namespace nn diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h index 81c6963..c45bb50 100644 --- a/topi/include/topi/reduction.h +++ b/topi/include/topi/reduction.h @@ -24,18 +24,18 @@ #ifndef TOPI_REDUCTION_H_ #define TOPI_REDUCTION_H_ -#include #include +#include +#include #include #include #include -#include -#include +#include #include +#include #include #include -#include namespace topi { using namespace tvm; @@ -45,21 +45,21 @@ using namespace tvm::te; using FReduce = std::function& axis)>; /*! \brief The operation to use for CommReduceIdx */ -using FCommReduce = std::function< - Array(Array exprs, const Array& axis, PrimExpr* condition)>; +using FCommReduce = std::function(Array exprs, const Array& axis, + PrimExpr* condition)>; /*! -* \brief Convert a reduction axis which could be empty or have negative -* elements into a real axis with valid dimension indices. -* -* \param ndim Number of dimensions in the target. -* \param axis The axis parameter. -* -* \return A non-empty sorted array of valid dimension indices, with no duplicates. -* If the input axis is empty, the result will be an axis including all dimensions. -* If any input element is negative, it will be treated as an offset from the -* last dimension (same as python indexing rules). -*/ + * \brief Convert a reduction axis which could be empty or have negative + * elements into a real axis with valid dimension indices. + * + * \param ndim Number of dimensions in the target. + * \param axis The axis parameter. + * + * \return A non-empty sorted array of valid dimension indices, with no duplicates. + * If the input axis is empty, the result will be an axis including all dimensions. + * If any input element is negative, it will be treated as an offset from the + * last dimension (same as python indexing rules). + */ inline std::vector GetRealAxis(int ndim, const Array& axis) { std::vector real_axis; if (!axis.defined() || axis.size() == 0) { @@ -78,8 +78,7 @@ inline std::vector GetRealAxis(int ndim, const Array& axis) { real_axis.push_back(static_cast(val)); } std::sort(real_axis.begin(), real_axis.end()); - real_axis.resize( - std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin()); + real_axis.resize(std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin()); } return real_axis; } @@ -89,17 +88,14 @@ inline Array MakeReduceAxes(const std::vector& real_axis, const Te Array reduce_axes; for (auto i : real_axis) { std::string name = "k" + std::to_string(i); - reduce_axes.push_back( - tvm::te::reduce_axis(Range(0, data->shape[i]), name)); + reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name)); } return reduce_axes; } /*! \brief Calculate the target shape for a reduce op */ -inline Array MakeReduceTargetShape(const std::vector& real_axis, - const Tensor& data, - bool keepdims, - bool atleast1d) { +inline Array MakeReduceTargetShape(const std::vector& real_axis, const Tensor& data, + bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); Array target_shape; if (keepdims) { @@ -137,9 +133,7 @@ inline Array MakeReduceTargetShape(const std::vector& real_axis, * * \return The result tensor. */ -inline Tensor DoCommReduce(const Tensor& data, - FReduce func, - const Array& target_shape, +inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array& target_shape, const std::vector& reduce_axes, const std::vector& squeeze_axes) { auto r_axes = MakeReduceAxes(reduce_axes, data); @@ -182,45 +176,39 @@ inline Tensor DoCommReduce(const Tensor& data, * * \return The result tensor. */ -inline Tensor CommReduce(const Tensor& data, - const Array& axis, - FReduce func, - bool keepdims, - bool atleast1d) { +inline Tensor CommReduce(const Tensor& data, const Array& axis, FReduce func, + bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); return DoCommReduce(data, func, target_shape, real_axis, - keepdims ? std::vector() : real_axis); + keepdims ? std::vector() : real_axis); } /*! -* \brief Create an index reduction operation. -* -* \param data The input tensor. -* \param axis The axes along which the reduction is performed. -* \param func The reduction function -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return The result tensor. -*/ -inline Tensor CommReduceIdx(const Tensor& data, - const Array& axis, - FCommReduce func, - bool keepdims, - bool atleast1d) { + * \brief Create an index reduction operation. + * + * \param data The input tensor. + * \param axis The axes along which the reduction is performed. + * \param func The reduction function + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return The result tensor. + */ +inline Tensor CommReduceIdx(const Tensor& data, const Array& axis, FCommReduce func, + bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto reduce_axes = MakeReduceAxes(real_axis, data); auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); - auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data] - (const Array& indices) { + auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, + &data](const Array& indices) { Array eval_range; Array eval_indices; int arg_counter = 0; @@ -247,18 +235,16 @@ inline Tensor CommReduceIdx(const Tensor& data, ravel_shape.push_back(data->shape[i]); } auto idx = detail::RavelIndex(eval_indices, ravel_shape); - return func({ idx, data(eval_range) }, reduce_axes, nullptr); + return func({idx, data(eval_range)}, reduce_axes, nullptr); }; - auto temp_idx_val = tvm::te::compute(target_shape, compute, - data->op->name + "_red_temp", kCommReduceIdx); + auto temp_idx_val = + tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx); auto temp_idx = temp_idx_val[0]; auto temp_val = temp_idx_val[1]; return tvm::te::compute( - target_shape, - [&temp_idx](const Array& indices) { return temp_idx(indices); }, - data->op->name + "_red", - kCommReduceIdx); + target_shape, [&temp_idx](const Array& indices) { return temp_idx(indices); }, + data->op->name + "_red", kCommReduceIdx); } /*! \brief A combiner function for a reduction */ @@ -276,11 +262,10 @@ using FIdentity = std::function(std::vector types)>; * * \return A reducer function which creates a reduce expression over an axis. */ -inline FCommReduce MakeCommReducer(FCombine fcombine, - FIdentity fidentity, +inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, std::string name = "reduce") { - return [fcombine, fidentity, name] - (Array exprs, const Array& axis, PrimExpr* condition) { + return [fcombine, fidentity, name](Array exprs, const Array& axis, + PrimExpr* condition) { Array lhs, rhs; std::vector dtypes; @@ -299,16 +284,14 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, Array outputs; for (size_t i = 0; i < exprs.size(); ++i) { outputs.push_back( - tvm::tir::ReduceNode::make(combiner, exprs, axis, cond, static_cast(i))); + tvm::tir::ReduceNode::make(combiner, exprs, axis, cond, static_cast(i))); } return outputs; }; } /*! \brief Wrap tvm::min to ensure we get the correct overload */ -inline PrimExpr MinOp(PrimExpr source, Array axis) { - return tvm::min(source, axis); -} +inline PrimExpr MinOp(PrimExpr source, Array axis) { return tvm::min(source, axis); } /*! \brief Wrap tvm::max to ensure we get the correct overload */ inline PrimExpr MaxOp(PrimExpr source, Array axis) { @@ -321,21 +304,19 @@ inline PrimExpr ProdOp(PrimExpr source, Array axis) { } /*! -* \brief Creates an operation that sums array elements over a given axis -* -* \param data The input tensor -* \param axis The axis to sum over. If axis is empty, the operation will -* sum over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the sum operation -*/ -inline Tensor sum(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that sums array elements over a given axis + * + * \param data The input tensor + * \param axis The axis to sum over. If axis is empty, the operation will + * sum over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the sum operation + */ +inline Tensor sum(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::sum, keepdims, atleast1d); } @@ -347,8 +328,7 @@ inline Tensor collapse_sum(const Tensor& data, Array target_shape) { std::vector reduce_axes; std::vector squeeze_axes; - for (int i_ax = ishape.size() - 1, - o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) { + for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) { if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) { --o_ax; continue; @@ -369,106 +349,96 @@ inline Tensor collapse_sum(const Tensor& data, Array target_shape) { } /*! -* \brief Creates an operation that computes the logical AND of elements -* over a given axis -* -* \param data The input boolean tensor -* \param axis The axes to reduce. If axis is empty, the operation will -* perform logical AND over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the all operation -*/ -inline Tensor all(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that computes the logical AND of elements + * over a given axis + * + * \param data The input boolean tensor + * \param axis The axes to reduce. If axis is empty, the operation will + * perform logical AND over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the all operation + */ +inline Tensor all(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::all, keepdims, atleast1d); } /*! -* \brief Creates an operation that computes the logical OR of elements -* over a given axis -* -* \param data The input boolean tensor -* \param axis The axes to reduce. If axis is empty, the operation will -* perform logical OR over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the all operation -*/ -inline Tensor any(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that computes the logical OR of elements + * over a given axis + * + * \param data The input boolean tensor + * \param axis The axes to reduce. If axis is empty, the operation will + * perform logical OR over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the all operation + */ +inline Tensor any(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); } /*! -* \brief Creates an operation that finds the minimum of elements over -* a given axis. -* -* \param data The input tensor -* \param axis The axis to find the minimum over. If axis is empty, the -* operation will find the minimum over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the min operation -*/ -inline Tensor min(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that finds the minimum of elements over + * a given axis. + * + * \param data The input tensor + * \param axis The axis to find the minimum over. If axis is empty, the + * operation will find the minimum over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the min operation + */ +inline Tensor min(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, MinOp, keepdims, atleast1d); } /*! -* \brief Creates an operation that finds the maximum of elements over -* a given axis. -* -* \param data The input tensor -* \param axis The axis to find the maximum over. If axis is empty, the -* operation will find the maximum over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the max operation -*/ -inline Tensor max(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that finds the maximum of elements over + * a given axis. + * + * \param data The input tensor + * \param axis The axis to find the maximum over. If axis is empty, the + * operation will find the maximum over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the max operation + */ +inline Tensor max(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } /*! -* \brief Creates an operation that finds the indices of the minimum -* values over a given axis. -* -* \param data The input tensor -* \param axis The axis along which the argmin is performed. If axis is empty, -* the operation will find the minimum index over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the argmin operation -*/ -inline Tensor argmin(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that finds the indices of the minimum + * values over a given axis. + * + * \param data The input tensor + * \param axis The axis along which the argmin is performed. If axis is empty, + * the operation will find the minimum index over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the argmin operation + */ +inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { auto fcombine = [](Array lhs, Array rhs) { Array result; @@ -479,7 +449,7 @@ inline Tensor argmin(const Tensor& data, auto fidentity = [](std::vector types) { Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::max_value(types[1])); // val + result.push_back(tvm::max_value(types[1])); // val return result; }; auto func = MakeCommReducer(fcombine, fidentity, "argmin"); @@ -496,50 +466,46 @@ inline FCommReduce MakeArgmaxReducer() { auto fidentity = [](std::vector types) { Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::min_value(types[1])); // val + result.push_back(tvm::min_value(types[1])); // val return result; }; return MakeCommReducer(fcombine, fidentity, "argmax"); } /*! -* \brief Creates an operation that finds the indices of the maximum -* values over a given axis. -* -* \param data The input tensor -* \param axis The axis along which the argmax is performed. If axis is empty, -* the operation will find the maximum index over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the argmax operation -*/ -inline Tensor argmax(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that finds the indices of the maximum + * values over a given axis. + * + * \param data The input tensor + * \param axis The axis along which the argmax is performed. If axis is empty, + * the operation will find the maximum index over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the argmax operation + */ +inline Tensor argmax(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { auto reducer = MakeArgmaxReducer(); return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } /*! -* \brief Creates product operation over given axis. -* -* \param data The input tensor -* \param axis The axis to do product over. If axis is empty, the -* operation will do the product over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the prod operation -*/ -inline Tensor prod(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates product operation over given axis. + * + * \param data The input tensor + * \param axis The axis to do product over. If axis is empty, the + * operation will do the product over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the prod operation + */ +inline Tensor prod(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, ProdOp, keepdims, atleast1d); } diff --git a/topi/include/topi/rocm/dense.h b/topi/include/topi/rocm/dense.h index 629b34e..72f8ee6 100644 --- a/topi/include/topi/rocm/dense.h +++ b/topi/include/topi/rocm/dense.h @@ -24,14 +24,15 @@ #ifndef TOPI_ROCM_DENSE_H_ #define TOPI_ROCM_DENSE_H_ -#include -#include #include -#include "topi/detail/array_utils.h" -#include "topi/nn/dense.h" +#include +#include + #include "topi/contrib/rocblas.h" -#include "topi/generic/extern.h" #include "topi/cuda/dense.h" +#include "topi/detail/array_utils.h" +#include "topi/generic/extern.h" +#include "topi/nn/dense.h" namespace topi { using namespace tvm; @@ -39,21 +40,19 @@ using namespace tvm::te; namespace rocm { /*! -* \brief Implementation of dense for rocm backend -* -* \param target The target device -* \param data Tensor with shape [batch, in_dim] -* \param weight Tensor with shape [out_dim, in_dim] -* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() -* \param out_dtype Output data type. Used for mixed precision. -* -* \return Tensor with shape [batch, out_dim] -*/ -inline tvm::te::Tensor dense_rocm(const Target& target, - const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { + * \brief Implementation of dense for rocm backend + * + * \param target The target device + * \param data Tensor with shape [batch, in_dim] + * \param weight Tensor with shape [out_dim, in_dim] + * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() + * \param out_dtype Output data type. Used for mixed precision. + * + * \return Tensor with shape [batch, out_dim] + */ +inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& data, + const tvm::te::Tensor& weight, const tvm::te::Tensor& bias, + const DataType& out_dtype) { CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { @@ -68,10 +67,8 @@ inline tvm::te::Tensor dense_rocm(const Target& target, CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported."; auto mm = topi::contrib::rocblas_matmul(data, weight, false, true); if (bias.defined()) { - mm = tvm::te::compute({ batch, out_dim }, - [&](Var i, Var j) { - return mm(i, j) + bias(j); - }, "tensor", kBroadcast); + mm = tvm::te::compute( + {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast); } return mm; @@ -81,16 +78,15 @@ inline tvm::te::Tensor dense_rocm(const Target& target, } /*! -* \brief Create a rocm schedule for dense -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_dense(const Target &target, const Array& outs) { - if (target->target_name == "rocm" && - target->libs().count("rocblas")) { + * \brief Create a rocm schedule for dense + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_dense(const Target& target, const Array& outs) { + if (target->target_name == "rocm" && target->libs().count("rocblas")) { return topi::generic::schedule_extern(target, outs); } diff --git a/topi/include/topi/rocm/injective.h b/topi/include/topi/rocm/injective.h index f3a3f3b..e7415bf 100644 --- a/topi/include/topi/rocm/injective.h +++ b/topi/include/topi/rocm/injective.h @@ -24,10 +24,10 @@ #ifndef TOPI_ROCM_INJECTIVE_H_ #define TOPI_ROCM_INJECTIVE_H_ -#include #include -#include +#include #include +#include #include "topi/cuda/injective.h" @@ -57,7 +57,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out * * \return A schedule for the given ops. */ -inline Schedule schedule_injective(const Target &target, const Array& outs) { +inline Schedule schedule_injective(const Target& target, const Array& outs) { return topi::cuda::schedule_injective(target, outs); } diff --git a/topi/include/topi/rocm/normalization.h b/topi/include/topi/rocm/normalization.h index 303f4a8..8328683 100644 --- a/topi/include/topi/rocm/normalization.h +++ b/topi/include/topi/rocm/normalization.h @@ -24,22 +24,20 @@ #ifndef TOPI_ROCM_NORMALIZATION_H_ #define TOPI_ROCM_NORMALIZATION_H_ -#include -#include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::te; namespace rocm { /*! -* \brief Create a rocm schedule for LRN -* \param outs The output tensors. -* \return A schedule for the given ops. -*/ -inline Schedule schedule_lrn(const Array& outs) { - return topi::cuda::schedule_lrn(outs); -} + * \brief Create a rocm schedule for LRN + * \param outs The output tensors. + * \return A schedule for the given ops. + */ +inline Schedule schedule_lrn(const Array& outs) { return topi::cuda::schedule_lrn(outs); } } // namespace rocm } // namespace topi diff --git a/topi/include/topi/rocm/pooling.h b/topi/include/topi/rocm/pooling.h index 7d1f36f..0b68a0a 100644 --- a/topi/include/topi/rocm/pooling.h +++ b/topi/include/topi/rocm/pooling.h @@ -24,12 +24,12 @@ #ifndef TOPI_ROCM_POOLING_H_ #define TOPI_ROCM_POOLING_H_ -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include namespace topi { using namespace tvm; @@ -38,26 +38,26 @@ using namespace tvm::te; namespace rocm { /*! -* \brief Create a rocm schedule for pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_pool(const Target &target, const Array& outs) { + * \brief Create a rocm schedule for pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_pool(const Target& target, const Array& outs) { return topi::cuda::schedule_pool(target, outs); } /*! -* \brief Create a rocm schedule for global_pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_global_pool(const Target &target, const Array& outs) { + * \brief Create a rocm schedule for global_pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_global_pool(const Target& target, const Array& outs) { return topi::cuda::schedule_global_pool(target, outs); } diff --git a/topi/include/topi/rocm/reduction.h b/topi/include/topi/rocm/reduction.h index ea4b656..512bf20 100644 --- a/topi/include/topi/rocm/reduction.h +++ b/topi/include/topi/rocm/reduction.h @@ -24,10 +24,10 @@ #ifndef TOPI_ROCM_REDUCTION_H_ #define TOPI_ROCM_REDUCTION_H_ -#include #include -#include +#include #include +#include #include "topi/cuda/reduction.h" @@ -37,13 +37,13 @@ using namespace tvm::te; namespace rocm { /*! -* \brief Create a rocm schedule for a reduce operation. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a rocm schedule for a reduce operation. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ Schedule schedule_reduce(const Target& target, Array outs) { return topi::cuda::schedule_reduce(target, outs); } diff --git a/topi/include/topi/rocm/softmax.h b/topi/include/topi/rocm/softmax.h index 63a0304..de05c4c 100644 --- a/topi/include/topi/rocm/softmax.h +++ b/topi/include/topi/rocm/softmax.h @@ -24,10 +24,10 @@ #ifndef TOPI_ROCM_SOFTMAX_H_ #define TOPI_ROCM_SOFTMAX_H_ -#include #include -#include +#include #include +#include #include "topi/cuda/softmax.h" @@ -45,7 +45,7 @@ namespace rocm { * * \return A schedule for the given ops. */ -inline Schedule schedule_softmax(const Target &target, const Array& outs) { +inline Schedule schedule_softmax(const Target& target, const Array& outs) { return topi::cuda::schedule_softmax(target, outs); } diff --git a/topi/include/topi/tags.h b/topi/include/topi/tags.h index 8d353b9..1e9ec44 100644 --- a/topi/include/topi/tags.h +++ b/topi/include/topi/tags.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -43,16 +43,12 @@ constexpr auto kDepthwiseConv2dBackWeightNHWC = "depthwise_conv2d_back_weight_nh constexpr auto kGroupConv2d = "group_conv2d"; inline bool is_broadcast(std::string tag) { - return - tag.rfind(kElementWise, 0) == 0 || - tag.rfind(kBroadcast, 0) == 0; + return tag.rfind(kElementWise, 0) == 0 || tag.rfind(kBroadcast, 0) == 0; } inline bool is_injective(std::string tag) { - return - tag.rfind(kElementWise, 0) == 0 || - tag.rfind(kBroadcast, 0) == 0 || - tag.rfind(kInjective, 0) == 0; + return tag.rfind(kElementWise, 0) == 0 || tag.rfind(kBroadcast, 0) == 0 || + tag.rfind(kInjective, 0) == 0; } } // namespace topi diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 0609020..e21fc2a 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -24,19 +24,19 @@ #ifndef TOPI_TRANSFORM_H_ #define TOPI_TRANSFORM_H_ -#include -#include -#include -#include #include +#include #include +#include +#include +#include -#include -#include -#include #include +#include #include +#include #include +#include namespace topi { using namespace tvm; @@ -44,30 +44,25 @@ using namespace tvm::te; using namespace topi::detail; /*! -* \brief Creates an operation to insert new dimensions of length 1 -* -* \param x The input tensor -* \param axis The index of the first new dimension (allows negative -* indices as offsets from the last dimension) -* \param num_newaxis The number of new dimensions to insert -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the dim expansion operation -*/ -inline Tensor expand_dims(const Tensor& x, - int axis, - int num_newaxis = 1, - std::string name = "T_expand_dims", - std::string tag = kBroadcast) { + * \brief Creates an operation to insert new dimensions of length 1 + * + * \param x The input tensor + * \param axis The index of the first new dimension (allows negative + * indices as offsets from the last dimension) + * \param num_newaxis The number of new dimensions to insert + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the dim expansion operation + */ +inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, + std::string name = "T_expand_dims", std::string tag = kBroadcast) { int ndim = static_cast(x->shape.size()); CHECK(-ndim - 1 <= axis && axis <= ndim) - << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - CHECK(num_newaxis >= 0) - << "expand_dims only accepts `num_newaxis >= 0`" - << ", but got num_newaxis = " << num_newaxis; + << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; + CHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" + << ", but got num_newaxis = " << num_newaxis; if (axis < 0) { // Calculate offset from last dimension axis = ndim + axis + 1; @@ -84,32 +79,32 @@ inline Tensor expand_dims(const Tensor& x, } return compute( - new_shape, [&](const Array& indices) { - Array idx; - for (size_t i = 0; i < static_cast(axis); ++i) { - idx.push_back(indices[i]); - } - for (size_t i = axis + num_newaxis; i < indices.size(); ++i) { - idx.push_back(indices[i]); - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + for (size_t i = axis + num_newaxis; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } + return x(idx); + }, + name, tag); } /*! -* \brief Permute the dimensions of an array -* -* \param x The input tensor -* \param axes The indices of the permutation. If this is empty, -* the dimensions will be reversed. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the transpose operation -*/ -inline Tensor transpose(const Tensor& x, - Array axes, - std::string name = "T_transpose", + * \brief Permute the dimensions of an array + * + * \param x The input tensor + * \param axes The indices of the permutation. If this is empty, + * the dimensions will be reversed. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the transpose operation + */ +inline Tensor transpose(const Tensor& x, Array axes, std::string name = "T_transpose", std::string tag = kInjective) { if (!axes.defined() || axes.size() == 0) { axes = Array(); @@ -127,11 +122,11 @@ inline Tensor transpose(const Tensor& x, axes.Set(i, new_axis); } CHECK((new_axis >= 0) && (new_axis < static_cast(x->shape.size()))) - << "axis=" << axis << " is invalid for the " - << static_cast(x->shape.size()) << "-dimensional input tensor"; + << "axis=" << axis << " is invalid for the " << static_cast(x->shape.size()) + << "-dimensional input tensor"; for (size_t j = 0; j < axes.size(); ++j) { - if (i !=j) { + if (i != j) { CHECK(new_axis != static_cast(axes[j]->value)) << "repeated axis in transpose"; } } @@ -139,33 +134,33 @@ inline Tensor transpose(const Tensor& x, } return compute( - new_shape, [&](const Array& indices) { - std::vector idx; - for (size_t i = 0; i < axes.size(); ++i) { - idx.push_back(1); - } - for (size_t i = 0; i < axes.size(); ++i) { - int axis = static_cast(axes[i]->value); - idx[axis] = indices[i]; - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + std::vector idx; + for (size_t i = 0; i < axes.size(); ++i) { + idx.push_back(1); + } + for (size_t i = 0; i < axes.size(); ++i) { + int axis = static_cast(axes[i]->value); + idx[axis] = indices[i]; + } + return x(idx); + }, + name, tag); } /*! -* \brief flip/reverse elements of an array in a particular axis -* -* \param x The input tensor -* \param axis The axis along which the tensors will be reveresed -* (allows negative indices) -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the reverse operation -*/ -inline Tensor flip(const Tensor& x, - int axis = 0, - std::string name = "T_flip", + * \brief flip/reverse elements of an array in a particular axis + * + * \param x The input tensor + * \param axis The axis along which the tensors will be reveresed + * (allows negative indices) + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the reverse operation + */ +inline Tensor flip(const Tensor& x, int axis = 0, std::string name = "T_flip", std::string tag = kInjective) { size_t src_tensor_dim = x->shape.size(); int axis_inp = axis; @@ -175,42 +170,42 @@ inline Tensor flip(const Tensor& x, } CHECK((0 <= axis) && (axis < static_cast(x->shape.size()))) - << "axis=" << axis_inp << " is invalid for the " - << static_cast(x->shape.size()) << "-dimensional input tensor"; + << "axis=" << axis_inp << " is invalid for the " << static_cast(x->shape.size()) + << "-dimensional input tensor"; // Reverse the Input Tensor in the axis specified return compute( - x->shape, [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - if (i == static_cast(axis)) { - real_indices.push_back(x->shape[i] - indices[i] - 1); - } else { - real_indices.push_back(indices[i]); + x->shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) { + if (i == static_cast(axis)) { + real_indices.push_back(x->shape[i] - indices[i] - 1); + } else { + real_indices.push_back(indices[i]); + } } - } - return x(real_indices); - }, name, tag); + return x(real_indices); + }, + name, tag); } /*! -* \brief Reshape a tensor -* -* \param x The input tensor -* \param newshape The new shape -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the reshape operation -*/ -inline Tensor reshape(const Tensor& x, - Array newshape, - std::string name = "T_reshape", + * \brief Reshape a tensor + * + * \param x The input tensor + * \param newshape The new shape + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the reshape operation + */ +inline Tensor reshape(const Tensor& x, Array newshape, std::string name = "T_reshape", std::string tag = kInjective) { auto x_shape = x->shape; Array target_shape; - for (const auto &ele : newshape) { + for (const auto& ele : newshape) { if (ele.as()) { target_shape.push_back(cast(DataType::Int(32), ele)); } else { @@ -219,16 +214,16 @@ inline Tensor reshape(const Tensor& x, } if (is_empty_shape(target_shape)) { - return compute(target_shape, - [&](const Array &indices) { return tvm::cast(x->dtype, 0); }, - name, tag); + return compute( + target_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); } else { return compute( - target_shape, [&](const Array& indices) { - return x(UnravelIndex( - RavelIndex(Array{indices.begin(), indices.end()}, target_shape), - x_shape)); - }, name, tag); + target_shape, + [&](const Array& indices) { + return x(UnravelIndex( + RavelIndex(Array{indices.begin(), indices.end()}, target_shape), x_shape)); + }, + name, tag); } } @@ -243,9 +238,7 @@ inline Tensor reshape(const Tensor& x, * \return A Tensor of coordinate arrays. */ -inline Tensor unravel_index(const Tensor& x, - const Tensor& shape, - std::string name = "T_unravel", +inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel", std::string tag = kInjective) { auto x_shape = x->shape; auto shape_shape = shape->shape; @@ -281,23 +274,20 @@ inline Tensor unravel_index(const Tensor& x, } /*! -* \brief Remove size 1 dimensions from the shape of a tensor. -* The removed dimensions must have a constant size of 1. -* -* \param x The input tensor -* \param axis Indices of the dimensions to remove. If this is empty, -* all entries with a constant size of 1 will be removed. + * \brief Remove size 1 dimensions from the shape of a tensor. + * The removed dimensions must have a constant size of 1. + * + * \param x The input tensor + * \param axis Indices of the dimensions to remove. If this is empty, + * all entries with a constant size of 1 will be removed. * \param atleast1d Whether the output need to be atleast1d. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the squeeze operation -*/ -inline Tensor squeeze(const Tensor& x, - Array axis, - bool atleast1d = false, - std::string name = "T_squeeze", - std::string tag = kInjective) { + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the squeeze operation + */ +inline Tensor squeeze(const Tensor& x, Array axis, bool atleast1d = false, + std::string name = "T_squeeze", std::string tag = kInjective) { auto ndim = x->shape.size(); std::vector axis_val; if (!axis.defined() || axis.size() == 0) { @@ -312,8 +302,7 @@ inline Tensor squeeze(const Tensor& x, if (val < 0) { val += static_cast(x->shape.size()); } - CHECK_EQ(GetConstInt(x->shape[val]), 1) << - "Dimension " << val << " must have size 1"; + CHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1"; axis_val.push_back(val); } } @@ -331,45 +320,42 @@ inline Tensor squeeze(const Tensor& x, } return compute( - out_shape, [&](const Array& indices) { - Array real_indices; - int flag = 0; - for (size_t i = 0; i < ndim; ++i) { - if (axis_set.count(static_cast(i)) == 0) { - real_indices.push_back(indices[i - flag]); - } else { - real_indices.push_back(0); - flag += 1; + out_shape, + [&](const Array& indices) { + Array real_indices; + int flag = 0; + for (size_t i = 0; i < ndim; ++i) { + if (axis_set.count(static_cast(i)) == 0) { + real_indices.push_back(indices[i - flag]); + } else { + real_indices.push_back(0); + flag += 1; + } } - } - return x(real_indices); - }, name, tag); + return x(real_indices); + }, + name, tag); } /*! -* \brief Join a sequence of tensors along an existing axis -* -* \param inputs The input tensors -* \param axis The axis along which the tensors will be joined -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the concatenate operation -*/ -inline Tensor concatenate(const Array& inputs, - int axis = 0, - std::string name = "T_concat", + * \brief Join a sequence of tensors along an existing axis + * + * \param inputs The input tensors + * \param axis The axis along which the tensors will be joined + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the concatenate operation + */ +inline Tensor concatenate(const Array& inputs, int axis = 0, std::string name = "T_concat", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); - CHECK(-ndim <= axis && axis < ndim) - << "concatenate only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis - << ", and ndim = " << ndim; + CHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; if (axis < 0) { axis += ndim; } - CHECK_LT(axis, inputs[0]->shape.size()) << - "axis out of bounds"; + CHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; Array axis_sizes; for (auto t : inputs) { @@ -387,96 +373,87 @@ inline Tensor concatenate(const Array& inputs, } return compute( - out_shape, [&](const Array& indices) { - auto ret = inputs[0](indices); - auto ind = indices[axis]; - for (size_t i = 0; i < inputs.size() - 1; ++i) { - ind -= axis_sizes[i]; + out_shape, + [&](const Array& indices) { + auto ret = inputs[0](indices); + auto ind = indices[axis]; + for (size_t i = 0; i < inputs.size() - 1; ++i) { + ind -= axis_sizes[i]; + + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + idx.push_back(ind); + for (size_t i = axis + 1; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } - Array idx; - for (size_t i = 0; i < static_cast(axis); ++i) { - idx.push_back(indices[i]); + ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret); } - idx.push_back(ind); - for (size_t i = axis + 1; i < indices.size(); ++i) { - idx.push_back(indices[i]); - } - - ret = tvm::if_then_else(ind >= 0, - inputs[i + 1](idx), - ret); - } - return ret; - }, name, tag); + return ret; + }, + name, tag); } /*! -* \brief Join a sequence of tensors along a new axis. -* -* \param inputs The input tensors -* \param axis The axis along which the tensors will be stacked -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the stack operation -*/ -inline Tensor stack(const Array& inputs, - int axis = 0, - std::string name = "T_stack", + * \brief Join a sequence of tensors along a new axis. + * + * \param inputs The input tensors + * \param axis The axis along which the tensors will be stacked + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the stack operation + */ +inline Tensor stack(const Array& inputs, int axis = 0, std::string name = "T_stack", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); CHECK(-ndim - 1 <= axis && axis <= ndim) - << "stack only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis - << ", and ndim = " << ndim; + << "stack only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; if (axis < 0) { axis += ndim + 1; } - CHECK_LT(axis, inputs[0]->shape.size() + 1) << - "axis out of bounds"; + CHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds"; const int stack_size = static_cast(inputs.size()); Array out_shape; - for (size_t i = 0; i < static_cast(axis); ++i) - out_shape.push_back(inputs[0]->shape[i]); + for (size_t i = 0; i < static_cast(axis); ++i) out_shape.push_back(inputs[0]->shape[i]); out_shape.push_back(stack_size); for (size_t i = static_cast(axis); i < static_cast(ndim); ++i) out_shape.push_back(inputs[0]->shape[i]); return compute( - out_shape, [&](const Array& indices) { - Array idx; - for (size_t i = 0; i < indices.size(); ++i) - if (i != static_cast(axis)) - idx.push_back(indices[i]); - auto ind = indices[axis]; - auto ret = inputs[0](idx); - for (int i = 0; i < static_cast(inputs.size() - 1); ++i) { - ret = tvm::if_then_else(ind == i + 1, - inputs[i + 1](idx), - ret); - } - return ret; - }, name, tag); + out_shape, + [&](const Array& indices) { + Array idx; + for (size_t i = 0; i < indices.size(); ++i) + if (i != static_cast(axis)) idx.push_back(indices[i]); + auto ind = indices[axis]; + auto ret = inputs[0](idx); + for (int i = 0; i < static_cast(inputs.size() - 1); ++i) { + ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret); + } + return ret; + }, + name, tag); } /*! -* \brief Split a tensor into multiple sub-tensors -* -* \param x The input tensor -* \param split_indices The indices to split the input at. This must be in ascending -* order. -* \param axis The axis to split along. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the split operation -*/ -inline Array split(const Tensor& x, - Array split_indices, - int axis, - std::string name = "T_split", - std::string tag = kInjective) { + * \brief Split a tensor into multiple sub-tensors + * + * \param x The input tensor + * \param split_indices The indices to split the input at. This must be in ascending + * order. + * \param axis The axis to split along. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ +inline Array split(const Tensor& x, Array split_indices, int axis, + std::string name = "T_split", std::string tag = kInjective) { if (axis < 0) { axis += static_cast(x->shape.size()); } @@ -488,12 +465,11 @@ inline Array split(const Tensor& x, for (Integer idx : split_indices) { int val = static_cast(idx->value); - CHECK_GT(val, begin_ids.back()) - << "split_indices must be sorted"; + CHECK_GT(val, begin_ids.back()) << "split_indices must be sorted"; begin_ids.push_back(val); } - Array< Array > out_shapes; + Array > out_shapes; for (size_t i = 0; i < begin_ids.size(); ++i) { int out_axis_size; if (i == begin_ids.size() - 1) { @@ -516,9 +492,9 @@ inline Array split(const Tensor& x, Array result; for (size_t i = 0; i < begin_ids.size(); ++i) { - result.push_back( - compute( - out_shapes[i], [&](const Array& indices) { + result.push_back(compute( + out_shapes[i], + [&](const Array& indices) { auto begin = begin_ids[i]; Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { @@ -530,30 +506,28 @@ inline Array split(const Tensor& x, } return x(real_indices); - }, name, tag)); + }, + name, tag)); } return result; } /*! -* \brief strided_slice of a tensor -* -* \param x The input tensor -* \param begin The indices to begin with in the slicing -* \param end Indicies indicating end of the slice -* \param strides Specifies the stride values, it can be negative -* in that case, the input tensor will be reversed in that particular axis -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the split operation -*/ -inline Tensor strided_slice(const Tensor& x, - const Array& begin, - const Array& end, - const Array& strides, - std::string name = "T_strided_slice", + * \brief strided_slice of a tensor + * + * \param x The input tensor + * \param begin The indices to begin with in the slicing + * \param end Indicies indicating end of the slice + * \param strides Specifies the stride values, it can be negative + * in that case, the input tensor will be reversed in that particular axis + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ +inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, + const Array& strides, std::string name = "T_strided_slice", std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); // Setup the ranges. @@ -615,43 +589,43 @@ inline Tensor strided_slice(const Tensor& x, int64_t end_i = index_canonicalization(end_vec[i]); int interval = std::abs(end_i - begin_i); - int slice_size = static_cast((interval - + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); + int slice_size = + static_cast((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); CHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) - << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] - << "] is invalid for axis=" << i; + << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] + << "] is invalid for axis=" << i; begin_expr.push_back(make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back(make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), - stride_vec[i])); + strides_expr.push_back( + make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i])); out_shape.push_back(slice_size); } return compute( - out_shape, [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return x(real_indices); - }, name, tag); + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); + } + return x(real_indices); + }, + name, tag); } /*! -* \brief Split a tensor into a number of sub-tensors -* -* \param x The input tensor -* \param num_sections The number of sections to split the tensor into. -* this must be an integer factor of the size of the axis being split. -* \param axis The axis to split along. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the split operation -*/ -inline Array split_sections(const Tensor& x, - int num_sections, - int axis, + * \brief Split a tensor into a number of sub-tensors + * + * \param x The input tensor + * \param num_sections The number of sections to split the tensor into. + * this must be an integer factor of the size of the axis being split. + * \param axis The axis to split along. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ +inline Array split_sections(const Tensor& x, int num_sections, int axis, std::string name = "T_split_sections", std::string tag = kInjective) { if (axis < 0) { @@ -663,8 +637,8 @@ inline Array split_sections(const Tensor& x, CHECK_GT(num_sections, 0) << "Slice count must be > 0"; CHECK_EQ(src_axis_size % num_sections, 0) - << "num_sections must be an integer factor of the size of axis " << axis - << " (" << src_axis_size << ")"; + << "num_sections must be an integer factor of the size of axis " << axis << " (" + << src_axis_size << ")"; Array split_indices; auto seg_size = src_axis_size / num_sections; @@ -679,22 +653,19 @@ inline Array split_sections(const Tensor& x, } /*! -* \brief Take elements from an flattened input array when axis is None. -* -* \param a The source array. -* \param indices The indices of the values to extract. -* \param mode The mode of the operation. -* \param name The name of the operation. -* \param mode The mode of to handle out of bound indices. -* \param tag The tag to mark the operation. -* -* \return A Tensor whose op member is the take operation -*/ -inline Tensor take(const Tensor& a, - const Tensor& indices, - std::string mode = "clip", - std::string name = "T_take", - std::string tag = kInjective) { + * \brief Take elements from an flattened input array when axis is None. + * + * \param a The source array. + * \param indices The indices of the values to extract. + * \param mode The mode of the operation. + * \param name The name of the operation. + * \param mode The mode of to handle out of bound indices. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the take operation + */ +inline Tensor take(const Tensor& a, const Tensor& indices, std::string mode = "clip", + std::string name = "T_take", std::string tag = kInjective) { Array a_shape = a->shape; Array out_shape = indices->shape; PrimExpr a_size = 1; @@ -704,44 +675,44 @@ inline Tensor take(const Tensor& a, if (mode == "clip") { return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); return a(UnravelIndex(idx, a_shape)); - }, name, tag); + }, + name, tag); } else if (mode == "fast") { LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " "Make sure input indices are in bound"; return compute( - out_shape, [&](const Array& out_index) { - return a(UnravelIndex(indices(out_index), a_shape)); - }, name, tag); + out_shape, + [&](const Array& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); }, + name, tag); } else { // mode == "wrap" return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size); return a(UnravelIndex(idx, a_shape)); - }, name, tag); + }, + name, tag); } } - /*! -* \brief Mask the out-of-boundary elements of each sequence. -* -* \param data The source array. -* \param valid_length The real length of each sequence. -* \param mask_value The masking value. -* \param axis The axis of the temporal dimension of the sequence -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor whose op member is the sequence_mask operation -*/ -inline Tensor sequence_mask(const Tensor& data, - const Tensor& valid_length, - double mask_value, - int axis, - std::string name = "T_sequence_mask", + * \brief Mask the out-of-boundary elements of each sequence. + * + * \param data The source array. + * \param valid_length The real length of each sequence. + * \param mask_value The masking value. + * \param axis The axis of the temporal dimension of the sequence + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the sequence_mask operation + */ +inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value, + int axis, std::string name = "T_sequence_mask", std::string tag = kInjective) { CHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1"; CHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,)."; @@ -749,38 +720,36 @@ inline Tensor sequence_mask(const Tensor& data, auto batch_dim = data->shape[1 - axis]; Array out_shape = data->shape; Tensor out = compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array len_index; auto tid = out_index[axis]; auto bid = out_index[1 - axis]; len_index.push_back(bid); - PrimExpr ret = tvm::if_then_else( - tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), - tvm::tir::make_const(data->dtype, mask_value), data(out_index)); + PrimExpr ret = + tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), + tvm::tir::make_const(data->dtype, mask_value), data(out_index)); return ret; - }, name, tag); + }, + name, tag); return out; } /*! -* \brief Take elements from an array along an axis. -* -* \param a The source array. -* \param indices The indices of the values to extract. -* \param axis The axis over which to select values. By default, -* the flattened input array is used. -* \param mode The mode for handling out of bound indices. -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor whose op member is the take operation -*/ -inline Tensor take(const Tensor& a, - const Tensor& indices, - int axis, - std::string mode = "clip", - std::string name = "T_take", - std::string tag = kInjective) { + * \brief Take elements from an array along an axis. + * + * \param a The source array. + * \param indices The indices of the values to extract. + * \param axis The axis over which to select values. By default, + * the flattened input array is used. + * \param mode The mode for handling out of bound indices. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the take operation + */ +inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string mode = "clip", + std::string name = "T_take", std::string tag = kInjective) { if (axis < 0) { axis += static_cast(a->shape.size()); } @@ -801,30 +770,32 @@ inline Tensor take(const Tensor& a, } if (mode == "clip") { return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array indices_position; - for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } - auto idx = tvm::min(tvm::max(0, indices(indices_position)), - axis_dim - 1); + auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1); real_indices.push_back(idx); for (size_t j = axis + indices_len; j < out_index.size(); ++j) { real_indices.push_back(out_index[j]); } return a(real_indices); - }, name, tag); + }, + name, tag); } else if (mode == "fast") { LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " "Make sure input indices are in bound"; return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array indices_position; - for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } Array real_indices; @@ -836,12 +807,14 @@ inline Tensor take(const Tensor& a, real_indices.push_back(out_index[j]); } return a(real_indices); - }, name, tag); + }, + name, tag); } else { // mode == "wrap" return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array indices_position; - for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } Array real_indices; @@ -854,82 +827,78 @@ inline Tensor take(const Tensor& a, real_indices.push_back(out_index[j]); } return a(real_indices); - }, name, tag); + }, + name, tag); } } /*! -* \brief Return the elements, either from x or y, depending on the condition. -* -* \param condition The condition array. -* \param x First array to be selected. -* \param y Second array to be selected. -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor selected from x or y depending on condition. -*/ -inline Tensor where(const Tensor& condition, - const Tensor& x, - const Tensor& y, - std::string name = "T_where", - std::string tag = kBroadcast) { + * \brief Return the elements, either from x or y, depending on the condition. + * + * \param condition The condition array. + * \param x First array to be selected. + * \param y Second array to be selected. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor selected from x or y depending on condition. + */ +inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, + std::string name = "T_where", std::string tag = kBroadcast) { CHECK_EQ(x->shape.size(), y->shape.size()) - << "x and y must have the same shape.Got different number of dimension: " - << x->shape.size() << " vs " << y->shape.size(); - CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " - << x->dtype << " vs " << y->dtype; + << "x and y must have the same shape.Got different number of dimension: " << x->shape.size() + << " vs " << y->shape.size(); + CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs " + << y->dtype; Array oshape = x->shape; Tensor out; if (condition->shape.size() != 1) { CHECK_EQ(condition->shape.size(), x->shape.size()) - << "condition array must be either have the same shape as x or to be a " - "1-D array.Got different number of dimension: " - << condition->shape.size() << " vs " << x->shape.size(); + << "condition array must be either have the same shape as x or to be a " + "1-D array.Got different number of dimension: " + << condition->shape.size() << " vs " << x->shape.size(); out = compute( - oshape, [&](const Array& indices) { - return tvm::tir::SelectNode::make(condition(indices) != 0, x(indices), y(indices)); - }, name, tag); + oshape, + [&](const Array& indices) { + return tvm::tir::SelectNode::make(condition(indices) != 0, x(indices), y(indices)); + }, + name, tag); } else { CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0])) - << "If condition is 1-D, the first dimension must be the same as x: " - << condition->shape[0] << " vs " << x->shape[0]; + << "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0] + << " vs " << x->shape[0]; out = compute( - oshape, [&](const Array& indices) { - Array condition_idx{indices[0]}; - return tvm::tir::SelectNode::make(condition(condition_idx) != 0, - x(indices), y(indices)); - }, name, tag); + oshape, + [&](const Array& indices) { + Array condition_idx{indices[0]}; + return tvm::tir::SelectNode::make(condition(condition_idx) != 0, x(indices), y(indices)); + }, + name, tag); } return out; } /*! -* \brief Creates an operation to repeat elements of an array -* -* \param x The input tensor -* \param repeats The number of repetitions for each element -* \param axis The axis along which to repeat values (allows -* negative indices as offsets from the last dimension) -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the repeat operation -*/ -inline Tensor repeat(const Tensor& x, - int repeats, - int axis, - std::string name = "T_repeat", + * \brief Creates an operation to repeat elements of an array + * + * \param x The input tensor + * \param repeats The number of repetitions for each element + * \param axis The axis along which to repeat values (allows + * negative indices as offsets from the last dimension) + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the repeat operation + */ +inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat", std::string tag = kBroadcast) { int ndim = static_cast(x->shape.size()); CHECK(-ndim - 1 <= axis && axis <= ndim) - << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - CHECK(repeats >= 1) - << "repeat only accepts `repeats >= 1`" - << ", but got repeats = " << repeats; + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; + CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; if (axis < 0) { // Calculate offset from last dimension axis += ndim; @@ -944,32 +913,32 @@ inline Tensor repeat(const Tensor& x, } return compute( - new_shape, [&](const Array& indices) { - Array idx; - for (size_t i = 0; i < static_cast(axis); ++i) { - idx.push_back(indices[i]); - } - idx.push_back(indexdiv(indices[axis], repeats)); - for (size_t i = axis + 1; i < indices.size(); ++i) { - idx.push_back(indices[i]); - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + idx.push_back(indexdiv(indices[axis], repeats)); + for (size_t i = axis + 1; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } + return x(idx); + }, + name, tag); } /*! -* \brief Creates an operation to tile elements of an array -* -* \param x The input tensor -* \param reps The number of times for repeating the tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the tile operation -*/ -inline Tensor tile(const Tensor& x, - Array reps, - std::string name = "T_tile", + * \brief Creates an operation to tile elements of an array + * + * \param x The input tensor + * \param reps The number of times for repeating the tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the tile operation + */ +inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_tile", std::string tag = kBroadcast) { size_t ndim = x->shape.size(); size_t rdim = reps.size(); @@ -983,56 +952,47 @@ inline Tensor tile(const Tensor& x, reps_shape.push_back(reps[i]); } } else if (ndim > rdim) { - for (size_t i = 0; i < ndim; ++i) - data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < (ndim - rdim); ++i) - reps_shape.push_back(1); - for (size_t i = 0; i < rdim; ++i) - reps_shape.push_back(reps[i]); + for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); + for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); } else { - for (size_t i = 0; i < (rdim - ndim); ++i) - data_shape.push_back(1); - for (size_t i = 0; i < ndim; ++i) - data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < rdim; ++i) - reps_shape.push_back(reps[i]); + for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1); + for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); } - for (size_t i = 0; i < tdim; ++i) - new_shape.push_back(data_shape[i] * reps_shape[i]); + for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]); if (is_empty_shape(new_shape)) { - return compute(new_shape, - [&](const Array& indices) { return tvm::cast(x->dtype, 0);}, - name, tag); + return compute( + new_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); } else { return compute( - new_shape, [&](const Array& indices) { - Array idx; - if (ndim >= rdim) { - for (size_t i = 0; i < ndim; ++i) - idx.push_back(indexmod(indices[i], x->shape[i])); - } else { - for (size_t i = 0; i < ndim; ++i) - idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + Array idx; + if (ndim >= rdim) { + for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i])); + } else { + for (size_t i = 0; i < ndim; ++i) + idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); + } + return x(idx); + }, + name, tag); } } /*! -* \brief Gather elements from a n-dimension array. -* -* \param data The source array. -* \param indices The indices of the values to extract. -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor whose op member is the gather_nd operation -*/ -inline Tensor gather_nd(const Tensor& data, - const Tensor& indices, - std::string name = "T_gather_nd", + * \brief Gather elements from a n-dimension array. + * + * \param data The source array. + * \param indices The indices of the values to extract. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the gather_nd operation + */ +inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string name = "T_gather_nd", std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); @@ -1051,27 +1011,28 @@ inline Tensor gather_nd(const Tensor& data, out_shape.push_back(make_const(DataType::Int(32), 1)); } return compute( - out_shape, [&](const Array& out_index) { - Array indices_position; - indices_position.push_back(0); - for (size_t i = 0; i < ndim_i - 1; ++i) { - indices_position.push_back(out_index[i]); - } - Array real_indices; - for (size_t i = 0; i < indices_dim0; ++i) { - indices_position.Set(0, make_const(DataType::Int(32), i)); - if (indices->dtype.is_int()) { - real_indices.push_back(indices(indices_position)); - } else { - real_indices.push_back( - tvm::cast(tvm::DataType::Int(32), indices(indices_position))); - } - } - for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { - real_indices.push_back(out_index[i]); + out_shape, + [&](const Array& out_index) { + Array indices_position; + indices_position.push_back(0); + for (size_t i = 0; i < ndim_i - 1; ++i) { + indices_position.push_back(out_index[i]); + } + Array real_indices; + for (size_t i = 0; i < indices_dim0; ++i) { + indices_position.Set(0, make_const(DataType::Int(32), i)); + if (indices->dtype.is_int()) { + real_indices.push_back(indices(indices_position)); + } else { + real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position))); } - return data(real_indices); - }, name, tag); + } + for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { + real_indices.push_back(out_index[i]); + } + return data(real_indices); + }, + name, tag); } /*! @@ -1089,18 +1050,13 @@ inline Tensor gather_nd(const Tensor& data, * * \return A Tensor whose op member is the matmul operation */ -inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, - const tvm::te::Tensor& B, - bool trans_a = false, - bool trans_b = false, - std::string name = "T_matmul", - std::string tag = kMatMul) { - tvm::Array output_shape{A->shape[trans_a ? 1 : 0], - B->shape[trans_b ? 0 : 1]}; +inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B, + bool trans_a = false, bool trans_b = false, + std::string name = "T_matmul", std::string tag = kMatMul) { + tvm::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); auto l = [&](tvm::tir::Var i, tvm::tir::Var j) { - return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), - {k}); + return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k}); }; return tvm::te::compute(output_shape, l, name, tag); } @@ -1116,45 +1072,35 @@ inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, * * \return A Tensor computing the result */ -inline Tensor tensordot(const Tensor& A, - const tvm::te::Tensor& B, - int axes = 2, - std::string name = "T_tensordot", - std::string tag = kMatMul) { +inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, + std::string name = "T_tensordot", std::string tag = kMatMul) { CHECK_GE(A->shape.size(), axes); CHECK_GE(B->shape.size(), axes); Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); - for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) - output_shape.push_back(*it); + for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it); Array iter_vars; for (int i = 0; i < axes; ++i) iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i))); - auto func = - [&A, &B, &iter_vars, axes] - (const Array& input_indices) { - Array A_indices( - input_indices.begin(), - input_indices.begin() + (A->shape.size() - axes)); - for (auto& v : iter_vars) - A_indices.push_back(v); - - Array B_indices; - for (auto& v : iter_vars) - B_indices.push_back(v); - - auto it = input_indices.begin() + (A->shape.size() - axes); - for (; it != input_indices.end(); ++it) - B_indices.push_back(*it); - - // Some passes don't like reductions with empty axis, so avoid it here - if (iter_vars.empty()) - return A(A_indices) * B(B_indices); - else - return sum(A(A_indices) * B(B_indices), iter_vars); - }; + auto func = [&A, &B, &iter_vars, axes](const Array& input_indices) { + Array A_indices(input_indices.begin(), + input_indices.begin() + (A->shape.size() - axes)); + for (auto& v : iter_vars) A_indices.push_back(v); + + Array B_indices; + for (auto& v : iter_vars) B_indices.push_back(v); + + auto it = input_indices.begin() + (A->shape.size() - axes); + for (; it != input_indices.end(); ++it) B_indices.push_back(*it); + + // Some passes don't like reductions with empty axis, so avoid it here + if (iter_vars.empty()) + return A(A_indices) * B(B_indices); + else + return sum(A(A_indices) * B(B_indices), iter_vars); + }; return compute(output_shape, func, name, tag); } @@ -1171,11 +1117,8 @@ inline Tensor tensordot(const Tensor& A, * * \return A Tensor computing the result */ -inline Tensor tensordot(const Tensor& A, - const tvm::te::Tensor& B, - Array A_axes, - Array B_axes, - std::string name = "T_tensordot", +inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array A_axes, + Array B_axes, std::string name = "T_tensordot", std::string tag = kMatMul) { CHECK_EQ(A_axes.size(), B_axes.size()); @@ -1191,47 +1134,42 @@ inline Tensor tensordot(const Tensor& A, output_shape.push_back(B->shape[i]); Array iter_vars; - for (unsigned i = 0; i < B_axes_val.size(); ++i) - iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i))); - - auto func = - [&A, &B, &iter_vars, A_axes_val, B_axes_val] - (const Array& input_indices) { - int idx_input = 0; - Array A_indices; - for (unsigned i = 0; i < A->shape.size(); ++i) { - auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i); - if (axes_pos == A_axes_val.end()) - A_indices.push_back(input_indices[idx_input++]); - else - A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]); - } + for (unsigned i = 0; i < B_axes_val.size(); ++i) + iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i))); + + auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array& input_indices) { + int idx_input = 0; + Array A_indices; + for (unsigned i = 0; i < A->shape.size(); ++i) { + auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i); + if (axes_pos == A_axes_val.end()) + A_indices.push_back(input_indices[idx_input++]); + else + A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]); + } - Array B_indices; - for (unsigned i = 0; i < B->shape.size(); ++i) { - auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i); - if (axes_pos == B_axes_val.end()) - B_indices.push_back(input_indices[idx_input++]); - else - B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]); - } - return sum(A(A_indices) * B(B_indices), iter_vars); - }; + Array B_indices; + for (unsigned i = 0; i < B->shape.size(); ++i) { + auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i); + if (axes_pos == B_axes_val.end()) + B_indices.push_back(input_indices[idx_input++]); + else + B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]); + } + return sum(A(A_indices) * B(B_indices), iter_vars); + }; return compute(output_shape, func, name, tag); } -inline Tensor arange(const PrimExpr& start, - const PrimExpr& stop, - const PrimExpr& step, - DataType dtype, - std::string name = "T_arange", - std::string tag = kInjective) { - PrimExpr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil( - tvm::cast(tvm::DataType::Float(32), stop - start) / step)); +inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step, + DataType dtype, std::string name = "T_arange", std::string tag = kInjective) { + PrimExpr num_elem = tvm::cast( + tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step)); Array shape; - return compute({num_elem}, [&](const Array& indices) { - return tvm::cast(dtype, start + step * indices[0]); - }, name, tag); + return compute( + {num_elem}, + [&](const Array& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name, + tag); } /*! @@ -1243,8 +1181,7 @@ inline Tensor arange(const PrimExpr& start, * \param tag output tensor tag. * \return A tensor with shape in \p dst_layout */ -inline Tensor layout_transform(const Tensor& src, - const std::string& src_layout, +inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, const std::string& dst_layout, const std::string name = "T_layout_trans", const std::string tag = kInjective) { @@ -1256,20 +1193,21 @@ inline Tensor layout_transform(const Tensor& src, } CHECK(src_layout_struct.defined() && dst_layout_struct.defined()) - << "cannot convert from/to undefined layout"; + << "cannot convert from/to undefined layout"; auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct); - CHECK(layout_converter.defined()) - << "cannot convert from " << src_layout << " to " << dst_layout; + CHECK(layout_converter.defined()) << "cannot convert from " << src_layout << " to " << dst_layout; Array dst_shape = layout_converter.ForwardShape(src->shape); return compute( - dst_shape, [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); - return src(src_indices); - }, name, tag); + dst_shape, + [&](const Array& dst_indices) { + Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); + return src(src_indices); + }, + name, tag); } /*! @@ -1280,20 +1218,21 @@ inline Tensor layout_transform(const Tensor& src, * \param tag output tensor tag. * \return Tensor of input shape. */ -inline Tensor shape(const Tensor& src, - DataType dtype, - const std::string name = "T_shape", +inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape", const std::string tag = kInjective) { int ndim = static_cast(src->shape.size()); Array out_shape{ndim}; - return compute(out_shape, [&](const Array& indices) { - auto idx = indices[0]; - PrimExpr ret = 0; - for (int i = 0; i < ndim; ++i) { - ret = tvm::if_then_else(idx == i, src->shape[i], ret); - } - return tvm::cast(dtype, ret); - }, name, tag); + return compute( + out_shape, + [&](const Array& indices) { + auto idx = indices[0]; + PrimExpr ret = 0; + for (int i = 0; i < ndim; ++i) { + ret = tvm::if_then_else(idx == i, src->shape[i], ret); + } + return tvm::cast(dtype, ret); + }, + name, tag); } /*! @@ -1304,19 +1243,21 @@ inline Tensor shape(const Tensor& src, * \param tag output tensor tag. * \return Tensor of input shape. */ -inline Tensor ndarray_size(const Tensor& src, - const DataType& dtype, +inline Tensor ndarray_size(const Tensor& src, const DataType& dtype, const std::string& name = "ndarray_size", const std::string& tag = kInjective) { int ndim = static_cast(src->shape.size()); Array out_ndarray_size = {1}; - return compute(out_ndarray_size, [&](const Array& indices) { - PrimExpr ret = 1; - for (int i = 0; i < ndim; ++i) { - ret *= src->shape[i]; - } - return tvm::cast(dtype, ret); - }, name, tag); + return compute( + out_ndarray_size, + [&](const Array& indices) { + PrimExpr ret = 1; + for (int i = 0; i < ndim; ++i) { + ret *= src->shape[i]; + } + return tvm::cast(dtype, ret); + }, + name, tag); } /*! @@ -1332,14 +1273,9 @@ inline Tensor ndarray_size(const Tensor& src, * \param tag output tensor tag. * \return one-hot tensor. */ -inline Tensor one_hot(const Tensor& indices, - const PrimExpr on_value, - const PrimExpr off_value, - int depth, - int axis, - const DataType& dtype, - const std::string name = "T_one_hot", - const std::string tag = kInjective) { +inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, + int depth, int axis, const DataType& dtype, + const std::string name = "T_one_hot", const std::string tag = kInjective) { Array oshape; int ndim = indices->shape.size() + 1; int indices_index = 0; @@ -1354,19 +1290,23 @@ inline Tensor one_hot(const Tensor& indices, PrimExpr on_value_cast = cast(dtype, on_value); PrimExpr off_value_cast = cast(dtype, off_value); - return compute(oshape, [&](const Array& iter_vars) { - Array indices_indices; - for (size_t i = 0; i < iter_vars.size(); i++) { - if (static_cast(i) == true_axis) { - continue; - } + return compute( + oshape, + [&](const Array& iter_vars) { + Array indices_indices; + for (size_t i = 0; i < iter_vars.size(); i++) { + if (static_cast(i) == true_axis) { + continue; + } - indices_indices.push_back(iter_vars[i]); - } + indices_indices.push_back(iter_vars[i]); + } - auto idx = iter_vars[true_axis]; - return tir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, off_value_cast); - }, name, tag); + auto idx = iter_vars[true_axis]; + return tir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, + off_value_cast); + }, + name, tag); } } // namespace topi diff --git a/topi/include/topi/vision/reorg.h b/topi/include/topi/vision/reorg.h index 06931e4..5bd79f6 100644 --- a/topi/include/topi/vision/reorg.h +++ b/topi/include/topi/vision/reorg.h @@ -24,11 +24,11 @@ #ifndef TOPI_VISION_REORG_H_ #define TOPI_VISION_REORG_H_ -#include #include #include #include #include +#include #include #include @@ -39,18 +39,16 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Reorg operation -* -* \param data The input tensor. Can be any dimension -* \param stride The input integer used as stride in reorg operation -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the reorg operation -*/ -inline Tensor reorg(const Tensor &data, - int stride = 1, - std::string name = "tensor", + * \brief Reorg operation + * + * \param data The input tensor. Can be any dimension + * \param stride The input integer used as stride in reorg operation + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the reorg operation + */ +inline Tensor reorg(const Tensor& data, int stride = 1, std::string name = "tensor", std::string tag = "reorg_output") { auto input_shape = data->shape; @@ -60,15 +58,14 @@ inline Tensor reorg(const Tensor &data, int w_in = GetConstInt(input_shape[3]); int out_c = c_in / (stride * stride); - auto out = tvm::te::compute(input_shape, - [&](Var b, Var k, Var j, Var i) { - return data(b * stride * stride, - indexmod(k, out_c) * stride * stride, - (j*stride + indexdiv(indexdiv(k, out_c), stride)) * stride, - (i*stride + indexmod(indexdiv(k, out_c), stride))); - }, - name, - tag); + auto out = tvm::te::compute( + input_shape, + [&](Var b, Var k, Var j, Var i) { + return data(b * stride * stride, indexmod(k, out_c) * stride * stride, + (j * stride + indexdiv(indexdiv(k, out_c), stride)) * stride, + (i * stride + indexmod(indexdiv(k, out_c), stride))); + }, + name, tag); out_c = c_in * stride * stride; int out_h = h_in / stride; diff --git a/topi/include/topi/x86/bnn.h b/topi/include/topi/x86/bnn.h index 53b7a8e..a59d30d 100644 --- a/topi/include/topi/x86/bnn.h +++ b/topi/include/topi/x86/bnn.h @@ -24,10 +24,10 @@ #ifndef TOPI_X86_BNN_H_ #define TOPI_X86_BNN_H_ -#include #include -#include +#include #include +#include namespace topi { using namespace tvm; @@ -35,14 +35,14 @@ using namespace tvm::te; namespace x86 { /*! -* \brief Create a generic schedule for binarize_pack -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_binarize_pack(const Target &target, const Array& outs) { + * \brief Create a generic schedule for binarize_pack + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_binarize_pack(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -67,14 +67,14 @@ inline Schedule schedule_binarize_pack(const Target &target, const Array } /*! -* \brief Create a generic schedule for binary_dense -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_binary_dense(const Target &target, const Array& outs) { + * \brief Create a generic schedule for binary_dense + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_binary_dense(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/x86/default.h b/topi/include/topi/x86/default.h index 9b6efa5..0733781 100644 --- a/topi/include/topi/x86/default.h +++ b/topi/include/topi/x86/default.h @@ -24,11 +24,11 @@ #ifndef TOPI_X86_DEFAULT_H_ #define TOPI_X86_DEFAULT_H_ -#include #include +#include +#include #include #include -#include namespace topi { using namespace tvm; @@ -36,16 +36,15 @@ using namespace tvm::te; namespace x86 { /*! -* \brief Helper to create a default x86 schedule for the given ops. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* \param auto_inline Whether to apply the auto inline step. -* -* \return A schedule for the given ops. -*/ -inline Schedule MakeDefaultSchedule(const Target &target, - const Array& outs, + * \brief Helper to create a default x86 schedule for the given ops. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * \param auto_inline Whether to apply the auto inline step. + * + * \return A schedule for the given ops. + */ +inline Schedule MakeDefaultSchedule(const Target& target, const Array& outs, bool auto_inline) { Array out_ops; for (auto t : outs) { @@ -66,7 +65,7 @@ inline Schedule MakeDefaultSchedule(const Target &target, if (axis.size() == 4) { auto n = axis[0]; auto c = axis[1]; - auto fused = detail::Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h + auto fused = detail::Fuse(s[x], {n, c}); // for nhwc layout, fuse n and h s[x].parallel(fused); } else { s[x].parallel(axis[0]); @@ -76,26 +75,26 @@ inline Schedule MakeDefaultSchedule(const Target &target, } /*! -* \brief Create a default x86 schedule for the given ops. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule default_schedule(const Target &target, const Array& outs) { + * \brief Create a default x86 schedule for the given ops. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule default_schedule(const Target& target, const Array& outs) { return MakeDefaultSchedule(target, outs, false); } /*! -* \brief Create a default x86 schedule for the given ops, with auto inline -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule default_schedule_auto_inline(const Target &target, const Array& outs) { + * \brief Create a default x86 schedule for the given ops, with auto inline + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule default_schedule_auto_inline(const Target& target, const Array& outs) { return MakeDefaultSchedule(target, outs, true); } diff --git a/topi/include/topi/x86/injective.h b/topi/include/topi/x86/injective.h index 182140d..069a971 100644 --- a/topi/include/topi/x86/injective.h +++ b/topi/include/topi/x86/injective.h @@ -24,10 +24,10 @@ #ifndef TOPI_X86_INJECTIVE_H_ #define TOPI_X86_INJECTIVE_H_ -#include #include -#include +#include #include +#include namespace topi { using namespace tvm; @@ -48,7 +48,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out if (axis.size() == 4) { auto n = axis[0]; auto c = axis[1]; - auto fused = detail::Fuse(sch[out], { n, c }); // for nhwc layout, fuse n and h + auto fused = detail::Fuse(sch[out], {n, c}); // for nhwc layout, fuse n and h sch[out].parallel(fused); } else { sch[out].parallel(axis[0]); @@ -57,14 +57,14 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out } /*! -* \brief Create an x86 schedule for the given injective ops. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_injective(const Target &target, const Array& outs) { + * \brief Create an x86 schedule for the given injective ops. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_injective(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/src/broadcast.cc b/topi/src/broadcast.cc index b147545..e13c09e 100644 --- a/topi/src/broadcast.cc +++ b/topi/src/broadcast.cc @@ -18,39 +18,33 @@ */ /*! -* \brief Registration of broadcast operators -* \file broadcast.cc -*/ -#include -#include - + * \brief Registration of broadcast operators + * \file broadcast.cc + */ #include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ - TVM_REGISTER_GLOBAL(OpName) \ - .set_body([](TVMArgs args, TVMRetValue *rv) { \ - bool lhs_is_tensor = args[0].IsObjectRef(); \ - bool rhs_is_tensor = args[1].IsObjectRef(); \ - if (lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::te::Tensor(), \ - args[1].operator tvm::te::Tensor()); \ - } else if (!lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::PrimExpr(), \ - args[1].operator tvm::te::Tensor()); \ - } else if (lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::te::Tensor(), \ - args[1].operator tvm::PrimExpr()); \ - } else if (!lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::PrimExpr(), \ - args[1].operator tvm::PrimExpr()); \ - } \ - }); \ +#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ + TVM_REGISTER_GLOBAL(OpName).set_body([](TVMArgs args, TVMRetValue* rv) { \ + bool lhs_is_tensor = args[0].IsObjectRef(); \ + bool rhs_is_tensor = args[1].IsObjectRef(); \ + if (lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::te::Tensor()); \ + } else if (!lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::te::Tensor()); \ + } else if (lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::PrimExpr()); \ + } else if (!lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::PrimExpr()); \ + } \ + }); TOPI_REGISTER_BCAST_OP("topi.add", topi::add); TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract); @@ -77,9 +71,8 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); -TVM_REGISTER_GLOBAL("topi.broadcast_to") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = broadcast_to(args[0], args[1]); - }); +}); } // namespace topi diff --git a/topi/src/elemwise.cc b/topi/src/elemwise.cc index 2c59994..10ac8f8 100644 --- a/topi/src/elemwise.cc +++ b/topi/src/elemwise.cc @@ -18,187 +18,140 @@ */ /*! -* \brief Registration of elemwise operators -* \file elemwise.cc -*/ + * \brief Registration of elemwise operators + * \file elemwise.cc + */ +#include #include #include -#include - namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.acos") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.acos").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = acos(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.acosh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.acosh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = acosh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.asin") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.asin").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = asin(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.asinh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.asinh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = asinh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.atanh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.atanh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = atanh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.exp") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = exp(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.exp").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = exp(args[0]); }); -TVM_REGISTER_GLOBAL("topi.fast_exp") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.fast_exp").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = fast_exp(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.erf") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = erf(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.erf").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = erf(args[0]); }); -TVM_REGISTER_GLOBAL("topi.fast_erf") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.fast_erf").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = fast_erf(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.tan") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = tan(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.tan").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tan(args[0]); }); -TVM_REGISTER_GLOBAL("topi.cos") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = cos(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.cos").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cos(args[0]); }); -TVM_REGISTER_GLOBAL("topi.cosh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cosh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cosh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sin") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = sin(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.sin").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sin(args[0]); }); -TVM_REGISTER_GLOBAL("topi.sinh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sinh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sinh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.tanh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.tanh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tanh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.fast_tanh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.fast_tanh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = fast_tanh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.atan") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.atan").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = atan(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sigmoid") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sigmoid").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sigmoid(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sqrt") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sqrt").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sqrt(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rsqrt") -.set_body([](TVMArgs args, TVMRetValue *rv) { -*rv = rsqrt(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.rsqrt").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = rsqrt(args[0]); +}); -TVM_REGISTER_GLOBAL("topi.log") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = log(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.log").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = log(args[0]); }); -TVM_REGISTER_GLOBAL("topi.log2") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.log2").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = log2(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.log10") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.log10").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = log10(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.identity") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.identity").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = identity(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.negative") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.negative").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = negative(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.clip") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.clip").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = clip(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cast") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cast").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cast(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.reinterpret") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.reinterpret").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = reinterpret(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.elemwise_sum") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.elemwise_sum").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = elemwise_sum(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sign") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sign").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sign(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.full") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.full").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = full(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.full_like") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.full_like").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = full_like(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.logical_not") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.logical_not").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = logical_not(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.bitwise_not") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.bitwise_not").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = bitwise_not(args[0]); - }); +}); } // namespace topi diff --git a/topi/src/nn.cc b/topi/src/nn.cc index 77b208d..3ec4778 100644 --- a/topi/src/nn.cc +++ b/topi/src/nn.cc @@ -18,23 +18,22 @@ */ /*! -* \brief Registration of NN operators -* \file nn.cc -*/ -#include -#include - + * \brief Registration of NN operators + * \file nn.cc + */ #include +#include #include #include #include #include #include +#include #include #include #include -#include -#include +#include +#include namespace topi { @@ -42,144 +41,113 @@ using namespace tvm; using namespace tvm::runtime; /* Ops from nn.h */ -TVM_REGISTER_GLOBAL("topi.nn.relu") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.relu").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = relu(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.leaky_relu") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.leaky_relu").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = leaky_relu(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.prelu") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.prelu").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = prelu(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.pad") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pad").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = pad(args[0], args[1], args[2], args[3]); - }); +}); /* Ops from nn/dense.h */ -TVM_REGISTER_GLOBAL("topi.nn.dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::dense(args[0], args[1], args[2], args[3]); - }); +}); /* Ops from nn/bias_add.h */ -TVM_REGISTER_GLOBAL("topi.nn.bias_add") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::bias_add(args[0], args[1], args[2]); - }); +}); /* Ops from nn/batch_matmul.h */ -TVM_REGISTER_GLOBAL("topi.nn.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::batch_matmul(args[0], args[1]); - }); +}); /* Ops from nn/dilate.h */ -TVM_REGISTER_GLOBAL("topi.nn.dilate") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::dilate(args[0], args[1]); - }); +}); /* Ops from nn/flatten.h */ -TVM_REGISTER_GLOBAL("topi.nn.flatten") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.flatten").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::flatten(args[0]); - }); +}); /* Ops from nn/mapping.h */ -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::scale_shift_nchw(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::scale_shift_nhwc(args[0], args[1], args[2]); - }); +}); /* Ops from nn/pooling.h */ -TVM_REGISTER_GLOBAL("topi.nn.pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); + static_cast(static_cast(args[4])), args[5], args[6], args[7]); +}); -TVM_REGISTER_GLOBAL("topi.nn.pool_grad") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4], - static_cast(static_cast(args[5])), - args[6], args[7], args[8]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::global_pool(args[0], - static_cast(static_cast(args[1])), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::adaptive_pool(args[0], args[1], - static_cast(static_cast(args[2])), + static_cast(static_cast(args[5])), args[6], args[7], + args[8]); +}); + +TVM_REGISTER_GLOBAL("topi.nn.global_pool").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::global_pool(args[0], static_cast(static_cast(args[1])), args[2]); +}); + +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::adaptive_pool(args[0], args[1], static_cast(static_cast(args[2])), args[3]); }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::adaptive_pool3d(args[0], args[1], - static_cast(static_cast(args[2])), +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::adaptive_pool3d(args[0], args[1], static_cast(static_cast(args[2])), args[3]); }); -TVM_REGISTER_GLOBAL("topi.nn.pool1d") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool1d").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool1d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); + static_cast(static_cast(args[4])), args[5], args[6], args[7]); +}); -TVM_REGISTER_GLOBAL("topi.nn.pool3d") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool3d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); + static_cast(static_cast(args[4])), args[5], args[6], args[7]); +}); /* Ops from nn/softmax.h */ -TVM_REGISTER_GLOBAL("topi.nn.softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::softmax(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.log_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.log_softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::log_softmax(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::lrn(args[0], args[1], args[2], - static_cast(args[3]), - static_cast(args[4]), - static_cast(args[5])); - }); +TVM_REGISTER_GLOBAL("topi.nn.lrn").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::lrn(args[0], args[1], args[2], static_cast(args[3]), + static_cast(args[4]), static_cast(args[5])); +}); /* Ops from nn/bnn.h */ -TVM_REGISTER_GLOBAL("topi.nn.binarize_pack") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::binarize_pack(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.binary_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::binary_dense(args[0], args[1]); - }); +}); } // namespace topi diff --git a/topi/src/reduction.cc b/topi/src/reduction.cc index e1fdada..b981495 100644 --- a/topi/src/reduction.cc +++ b/topi/src/reduction.cc @@ -18,58 +18,49 @@ */ /*! -* \brief Registration of reduction operators -* \file reduction.cc -*/ -#include -#include - + * \brief Registration of reduction operators + * \file reduction.cc + */ #include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.sum") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sum").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::sum(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.min") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.min").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::min(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.max") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.max").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::max(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.argmin") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.argmin").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.argmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.argmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.prod") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.prod").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.all") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.all").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.any") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); } // namespace topi diff --git a/topi/src/schedule.cc b/topi/src/schedule.cc index 936f390..b974aca 100644 --- a/topi/src/schedule.cc +++ b/topi/src/schedule.cc @@ -18,212 +18,181 @@ */ /*! -* \brief Registration of TVM schedules -* \file schedule.cc -*/ + * \brief Registration of TVM schedules + * \file schedule.cc + */ #define TOPI_REDUCE_ATLEAST1D 0 -#include -#include -#include -#include -#include - -#include -#include -#include - #include #include +#include #include #include #include -#include - -#include -#include -#include - +#include +#include +#include +#include #include #include +#include #include #include #include -#include - -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.TEST_create_target") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.TEST_create_target").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tvm::Target::Create(args[0]); - }); +}); /* Generic schedules */ -TVM_REGISTER_GLOBAL("topi.generic.default_schedule") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.generic.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[2]) { *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]); } else { *rv = topi::generic::default_schedule(args[0], args[1]); } - }); +}); -TVM_REGISTER_GLOBAL("topi.generic.schedule_extern") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.generic.schedule_extern").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::generic::schedule_extern(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.generic.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.generic.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::generic::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); + }); /* x86 schedules */ -TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::x86::schedule_binarize_pack(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::x86::schedule_binary_dense(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.x86.default_schedule") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[2]) { *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]); } else { *rv = topi::x86::default_schedule(args[0], args[1]); } - }); +}); -TVM_REGISTER_GLOBAL("topi.x86.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::x86::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]); + }); /* ROCm schedules */ -TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_dense(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]); + }); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_global_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_reduce(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_softmax(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_lrn(args[0]); - }); +}); /* CUDA schedules */ -TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_dense(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]); + }); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_global_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_reduce(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_softmax(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_lrn(args[0]); - }); +}); /* Utility functions */ -TVM_REGISTER_GLOBAL("topi.util.is_empty_shape") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.util.is_empty_shape").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::detail::is_empty_shape(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]); - }); +}); /*! \brief Builder function for instantiating schedules. */ -using FTVMScheduleBuilder = std::function< - tvm::te::Schedule(const tvm::Target& target, const tvm::Array& outs)>; +using FTVMScheduleBuilder = std::function& outs)>; /*! * \brief Helper function for registering generic functions matching the @@ -242,7 +211,7 @@ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { if (argNodeRef->type_index() == outs->type_index()) { outs = args[0]; } else { - outs = Array { args[0] }; + outs = Array{args[0]}; } *ret = builder(target, outs); @@ -250,49 +219,49 @@ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { } TVM_REGISTER_GENERIC_FUNC(schedule_injective) -.set_default(WrapSchedule(topi::generic::schedule_injective)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective)); + .set_default(WrapSchedule(topi::generic::schedule_injective)) + .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_injective)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_injective)); TVM_REGISTER_GENERIC_FUNC(schedule_softmax) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_softmax)); TVM_REGISTER_GENERIC_FUNC(schedule_dense) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense)) -.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_dense)) + .register_func({"rocm"}, WrapSchedule(topi::rocm::schedule_dense)); TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul) -.set_default(WrapSchedule(topi::generic::default_schedule)); + .set_default(WrapSchedule(topi::generic::default_schedule)); TVM_REGISTER_GENERIC_FUNC(schedule_pool) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_pool)); TVM_REGISTER_GENERIC_FUNC(schedule_global_pool) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_global_pool)); TVM_REGISTER_GENERIC_FUNC(schedule_reduce) -.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce)); + .set_default(WrapSchedule(topi::generic::default_schedule_auto_inline)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule_auto_inline)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_reduce)); TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binarize_pack)); TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binary_dense)); /*! \brief Builder function for instantiating schedules from existing schedules. */ -using FTVMScheduleFromExistingBuilder = std::function< - tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>; +using FTVMScheduleFromExistingBuilder = + std::function; /*! * \brief Helper function for registering generic functions matching the @@ -304,33 +273,30 @@ using FTVMScheduleFromExistingBuilder = std::function< * \return The wrapped schedule builder */ inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) { - return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { - *ret = builder(args[0], args[1]); - }); + return PackedFunc( + [builder](TVMArgs args, TVMRetValue* ret) { *ret = builder(args[0], args[1]); }); } TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing) -.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) -.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) -.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting( - topi::cuda::schedule_injective_from_existing)); + .set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) + .register_func({"cpu"}, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) + .register_func({"cuda", "gpu"}, + WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing)); /*! \brief Builder function for instantiating dense ops. */ -using FTVMDenseOpBuilder = std::function; +using FTVMDenseOpBuilder = std::function; /*! -* \brief Helper function for registering dense ops matching the -* FTVMDenseOpBuilder signature. The op builder function is wrapped -* with a PackedFunc suitable for passing to a tvm::GenericFunc. -* -* \param builder The op builder to wrap. -* -* \return The wrapped op builder -*/ + * \brief Helper function for registering dense ops matching the + * FTVMDenseOpBuilder signature. The op builder function is wrapped + * with a PackedFunc suitable for passing to a tvm::GenericFunc. + * + * \param builder The op builder to wrap. + * + * \return The wrapped op builder + */ inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { auto target = Target::Current(false); @@ -344,14 +310,12 @@ inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { } TVM_REGISTER_GENERIC_FUNC(dense) -.set_default(WrapDenseOp([](const Target& target, - const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { - return topi::nn::dense(data, weight, bias, out_dtype); -})) -.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda)) -.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm)); + .set_default(WrapDenseOp([](const Target& target, const tvm::te::Tensor& data, + const tvm::te::Tensor& weight, const tvm::te::Tensor& bias, + const DataType& out_dtype) { + return topi::nn::dense(data, weight, bias, out_dtype); + })) + .register_func({"cuda", "gpu"}, WrapDenseOp(topi::cuda::dense_cuda)) + .register_func({"rocm"}, WrapDenseOp(topi::rocm::dense_rocm)); } // namespace topi diff --git a/topi/src/transform.cc b/topi/src/transform.cc index 4f0d4f8..fa27b99 100644 --- a/topi/src/transform.cc +++ b/topi/src/transform.cc @@ -18,67 +18,56 @@ */ /*! -* \brief Registration of transform operators -* \file transform.cc -*/ -#include -#include - + * \brief Registration of transform operators + * \file transform.cc + */ #include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.expand_dims") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.expand_dims").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = expand_dims(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.transpose") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.transpose").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = transpose(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.flip") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.flip").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = flip(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.reshape") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.reshape").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = reshape(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.squeeze") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.squeeze").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = squeeze(args[0], ArrayOrInt(args[1])); - }); +}); -TVM_REGISTER_GLOBAL("topi.concatenate") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.concatenate").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = concatenate(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.stack") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.stack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = stack(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.shape") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.shape").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = shape(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.ndarray_size") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = ndarray_size(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.split") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { *rv = split_sections(args[0], args[1], args[2]); } else { @@ -86,13 +75,11 @@ TVM_REGISTER_GLOBAL("topi.split") } }); -TVM_REGISTER_GLOBAL("topi.layout_transform") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = layout_transform(args[0], args[1], args[2]); }); -TVM_REGISTER_GLOBAL("topi.take") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) { if (args.size() == 3) { std::string mode = args[2]; *rv = take(args[0], args[1], mode); @@ -101,56 +88,55 @@ TVM_REGISTER_GLOBAL("topi.take") std::string mode = args[3]; *rv = take(args[0], args[1], axis, mode); } - }); +}); -TVM_REGISTER_GLOBAL("topi.sequence_mask") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sequence_mask").set_body([](TVMArgs args, TVMRetValue* rv) { double pad_val = args[2]; int axis = args[3]; *rv = sequence_mask(args[0], args[1], pad_val, axis); }); -TVM_REGISTER_GLOBAL("topi.where") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.where").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = where(args[0], args[1], args[2]); }); -TVM_REGISTER_GLOBAL("topi.arange") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.arange").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = arange(args[0], args[1], args[2], args[3]); }); -TVM_REGISTER_GLOBAL("topi.repeat") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.repeat").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = repeat(args[0], args[1], args[2]); }); -TVM_REGISTER_GLOBAL("topi.tile") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.tile").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tile(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.gather_nd") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = gather_nd(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.unravel_index") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = unravel_index(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.matmul") -.set_body([](TVMArgs args, TVMRetValue *rv) { - switch ( args.size() ) { - case 2: *rv = matmul(args[0], args[1]); break; - case 3: *rv = matmul(args[0], args[1], args[2]); break; - case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break; - default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; - }}); - -TVM_REGISTER_GLOBAL("topi.tensordot") -.set_body([](TVMArgs args, TVMRetValue *rv) { +}); + +TVM_REGISTER_GLOBAL("topi.matmul").set_body([](TVMArgs args, TVMRetValue* rv) { + switch (args.size()) { + case 2: + *rv = matmul(args[0], args[1]); + break; + case 3: + *rv = matmul(args[0], args[1], args[2]); + break; + case 4: + *rv = matmul(args[0], args[1], args[2], args[3]); + break; + default: + CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; + } +}); + +TVM_REGISTER_GLOBAL("topi.tensordot").set_body([](TVMArgs args, TVMRetValue* rv) { if (args.size() == 2) { *rv = tensordot(args[0], args[1]); } else if (args.size() == 3) { @@ -159,19 +145,17 @@ TVM_REGISTER_GLOBAL("topi.tensordot") Array axes = args[3]; *rv = tensordot(args[0], args[1], args[2], axes); } - }); +}); -TVM_REGISTER_GLOBAL("topi.strided_slice") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = strided_slice(args[0], args[1], args[2], args[3]); - }); +}); -TVM_REGISTER_GLOBAL("topi.one_hot") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { int depth = args[3]; int axis = args[4]; DataType dtype = args[5]; *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype); - }); +}); } // namespace topi diff --git a/topi/src/vision.cc b/topi/src/vision.cc index 1a4884e..0485177 100644 --- a/topi/src/vision.cc +++ b/topi/src/vision.cc @@ -18,22 +18,20 @@ */ /*! -* \brief Registration of vision operators -* \file vision.cc -*/ + * \brief Registration of vision operators + * \file vision.cc + */ +#include #include #include -#include - namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.vision.reorg") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.vision.reorg").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = vision::reorg(args[0], args[1]); - }); +}); } // namespace topi diff --git a/vta/runtime/device_api.cc b/vta/runtime/device_api.cc index 047a6fd..298403c 100644 --- a/vta/runtime/device_api.cc +++ b/vta/runtime/device_api.cc @@ -22,12 +22,11 @@ * \brief TVM device API for VTA */ -#include #include +#include -#include "runtime.h" #include "../../src/runtime/workspace_pool.h" - +#include "runtime.h" namespace tvm { namespace runtime { @@ -42,25 +41,14 @@ class VTADeviceAPI final : public DeviceAPI { } } - void* AllocDataSpace(TVMContext ctx, - size_t size, - size_t alignment, - DLDataType type_hint) final { + void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) final { return VTABufferAlloc(size); } - void FreeDataSpace(TVMContext ctx, void* ptr) final { - VTABufferFree(ptr); - } + void FreeDataSpace(TVMContext ctx, void* ptr) final { VTABufferFree(ptr); } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { int kind_mask = 0; if (ctx_from.device_type != kDLCPU) { @@ -69,33 +57,27 @@ class VTADeviceAPI final : public DeviceAPI { if (ctx_to.device_type != kDLCPU) { kind_mask |= 1; } - VTABufferCopy(from, from_offset, - to, to_offset, - size, kind_mask); + VTABufferCopy(from, from_offset, to, to_offset, size, kind_mask); } - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {} void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; void FreeWorkspace(TVMContext ctx, void* data) final; static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } }; struct VTAWorkspacePool : public WorkspacePool { - VTAWorkspacePool() : - WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {} + VTAWorkspacePool() : WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {} }; void* VTADeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { - return dmlc::ThreadLocalStore::Get() - ->AllocWorkspace(ctx, size); + return dmlc::ThreadLocalStore::Get()->AllocWorkspace(ctx, size); } void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { @@ -104,10 +86,10 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { // Register device api with override. static TVM_ATTRIBUTE_UNUSED auto& __register_dev__ = -::tvm::runtime::Registry::Register("device_api.ext_dev", true) -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = VTADeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); + ::tvm::runtime::Registry::Register("device_api.ext_dev", true) + .set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = VTADeviceAPI::Global().get(); + *rv = static_cast(ptr); + }); } // namespace runtime } // namespace tvm diff --git a/vta/runtime/runtime.cc b/vta/runtime/runtime.cc index 0e48e16..49fe9c5 100644 --- a/vta/runtime/runtime.cc +++ b/vta/runtime/runtime.cc @@ -24,24 +24,23 @@ * The runtime depends on specific instruction * stream spec as specified in hw_spec.h */ -#include -#include +#include "runtime.h" + #include #include +#include +#include #include #include #include -#include #include - -#include "runtime.h" +#include namespace vta { // Avoid bad configurations. -static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8, - "VTA_UOP_WIDTH do not match VTAUop size"); +static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8, "VTA_UOP_WIDTH do not match VTAUop size"); /*! \brief Enable coherent access of data buffers between VTA and CPU */ static const bool kBufferCoherent = VTA_COHERENT_ACCESSES; @@ -53,13 +52,9 @@ static const bool kAlwaysCache = true; */ struct DataBuffer { /*! \return Virtual address of the data. */ - void* virt_addr() const { - return data_; - } + void* virt_addr() const { return data_; } /*! \return Physical address of the data. */ - vta_phy_addr_t phy_addr() const { - return phy_addr_; - } + vta_phy_addr_t phy_addr() const { return phy_addr_; } /*! * \brief Invalidate the cache of given location in data buffer. * \param offset The offset to the data. @@ -67,9 +62,7 @@ struct DataBuffer { */ void InvalidateCache(size_t offset, size_t size) { if (!kBufferCoherent && kAlwaysCache) { - VTAInvalidateCache(reinterpret_cast(data_) + offset, - phy_addr_ + offset, - size); + VTAInvalidateCache(reinterpret_cast(data_) + offset, phy_addr_ + offset, size); } } /*! @@ -79,16 +72,14 @@ struct DataBuffer { */ void FlushCache(size_t offset, size_t size) { if (!kBufferCoherent && kAlwaysCache) { - VTAFlushCache(reinterpret_cast(data_) + offset, - phy_addr_ + offset, - size); + VTAFlushCache(reinterpret_cast(data_) + offset, phy_addr_ + offset, size); } } /*! * \brief Performs a copy operation from host memory to buffer allocated with VTAMemAlloc. - * \param dst The desination buffer in FPGA-accessible memory. Has to be allocated with VTAMemAlloc(). - * \param src The source buffer in host memory. - * \param size Size of the region in Bytes. + * \param dst The desination buffer in FPGA-accessible memory. Has to be allocated with + * VTAMemAlloc(). \param src The source buffer in host memory. \param size Size of the region in + * Bytes. */ void MemCopyFromHost(void* dst, const void* src, size_t size) { VTAMemCopyFromHost(dst, src, size); @@ -99,9 +90,7 @@ struct DataBuffer { * \param src The source buffer in FPGA-accessible memory. Has to be allocated with VTAMemAlloc(). * \param size Size of the region in Bytes. */ - void MemCopyToHost(void* dst, const void* src, size_t size) { - VTAMemCopyToHost(dst, src, size); - } + void MemCopyToHost(void* dst, const void* src, size_t size) { VTAMemCopyToHost(dst, src, size); } /*! * \brief Allocate a buffer of a given size. * \param size The size of the buffer. @@ -128,8 +117,7 @@ struct DataBuffer { * \return The corresponding data buffer header. */ static DataBuffer* FromHandle(const void* buffer) { - return const_cast( - reinterpret_cast(buffer)); + return const_cast(reinterpret_cast(buffer)); } private: @@ -157,9 +145,7 @@ class UopKernel { * \param signature The pointer to signature. * \param nbytes Number of bytes. */ - UopKernel(const char* signature, int nbytes) - : signature_(signature, signature + nbytes) { - } + UopKernel(const char* signature, int nbytes) : signature_(signature, signature + nbytes) {} /*! * \brief Verify if the signature is correct. * \param signature Signature ptr. @@ -170,21 +156,13 @@ class UopKernel { return memcmp(signature, signature_.data(), nbytes) == 0; } /*! \return Whether the kernel is cached in SRAM. */ - bool cached() const { - return sram_begin_ != sram_end_; - } + bool cached() const { return sram_begin_ != sram_end_; } /*! \return The length of the micro op sequence. */ - size_t size() const { - return seq_.size(); - } + size_t size() const { return seq_.size(); } /*! \return The micro-op data. */ - const VTAUop* data() const { - return seq_.data(); - } + const VTAUop* data() const { return seq_.data(); } /*! \return The loop structure. */ - const std::vector& loop() const { - return loop_; - } + const std::vector& loop() const { return loop_; } /*! * \brief Declare loop start. * \param extent The loop extent. @@ -192,9 +170,7 @@ class UopKernel { * \param src_factor Loop factor of input index * \param wgt_factor Loop factor of weight index. */ - void PushLoopBegin(uint32_t extent, - uint32_t dst_factor, - uint32_t src_factor, + void PushLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor, uint32_t wgt_factor) { LoopEntry le; le.extent = extent; @@ -209,9 +185,7 @@ class UopKernel { /*! * \brief Declare loop end. */ - void PushLoopEnd() { - --loop_ptr_; - } + void PushLoopEnd() { --loop_ptr_; } /*! * \brief Push micro op into kernel. * \param mode Set to GEMM mode if set to 0, ALU mode is set to 1. @@ -223,14 +197,8 @@ class UopKernel { * \param use_imm Use immediate in ALU mode if set to true. * \param imm_val Immediate value in ALU mode. */ - void Push(uint32_t mode, - uint32_t reset_out, - uint32_t dst_index, - uint32_t src_index, - uint32_t wgt_index, - uint32_t opcode, - uint32_t use_imm, - int32_t imm_val) { + void Push(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index, + uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) { // The loop nest structure VerifyDep(dst_index); VTAUop op; @@ -268,10 +236,7 @@ class UopKernel { uint32_t size = seq_.size(); printf("There are %u uops\n", size); for (uint32_t i = 0; i < size; ++i) { - printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n", - i, - seq_[i].dst_idx, - seq_[i].src_idx, + printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n", i, seq_[i].dst_idx, seq_[i].src_idx, seq_[i].wgt_idx); } printf("\n"); @@ -294,7 +259,7 @@ class UopKernel { } } // The uop buffer - template + template friend class UopQueue; friend class CommandQueue; // SRAM location if begin != end @@ -322,26 +287,21 @@ class BaseQueue { } } /*! \return Content of DRAM buffer. */ - char* dram_buffer() const { - return dram_buffer_; - } + char* dram_buffer() const { return dram_buffer_; } /*! \return Physical address of DRAM. */ vta_phy_addr_t dram_phy_addr() const { CHECK(fpga_buff_phy_); return fpga_buff_phy_; } /*! \return Whether there is pending information. */ - bool pending() const { - return sram_begin_ != sram_end_; - } + bool pending() const { return sram_begin_ != sram_end_; } /*! \brief Initialize the space of the buffer. */ void InitSpace(uint32_t elem_bytes, uint32_t max_bytes, bool coherent, bool always_cache) { coherent_ = coherent; always_cache_ = always_cache; elem_bytes_ = elem_bytes; // Allocate buffer ahead of time - fpga_buff_ = static_cast(VTAMemAlloc( - max_bytes, coherent_ || always_cache_)); + fpga_buff_ = static_cast(VTAMemAlloc(max_bytes, coherent_ || always_cache_)); CHECK(fpga_buff_ != nullptr); fpga_buff_phy_ = VTAMemGetPhyAddr(fpga_buff_); } @@ -379,14 +339,12 @@ class BaseQueue { /*! * \brief Micro op buffer that manages the micro op cache. */ -template +template class UopQueue : public BaseQueue { public: - void InitSpace() { - BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache); - } + void InitSpace() { BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache); } // Push data to the queue - template + template void Push(UopKernel* kernel, FAutoSync fautosync) { // if the micro-op is cached in VTA SRAM, skip if (kernel->cached()) return; @@ -460,9 +418,7 @@ class UopQueue : public BaseQueue { cache_idx_ = 0; BaseQueue::Reset(); } - void AutoReadBarrier() { - ReadBarrier(); - } + void AutoReadBarrier() { ReadBarrier(); } /*! \brief Writer barrier to make sure that data written by CPU is visible to VTA. */ void ReadBarrier() { CHECK(fpga_buff_ != nullptr); @@ -477,18 +433,14 @@ class UopQueue : public BaseQueue { uint32_t offset = 0; for (uint32_t i = 0; i < cache_.size(); ++i) { uint32_t ksize = cache_[i]->size() * kElemBytes; - VTAMemCopyFromHost(static_cast(fpga_buff_) + offset, - cache_[i]->data(), - ksize); + VTAMemCopyFromHost(static_cast(fpga_buff_) + offset, cache_[i]->data(), ksize); // Update offset offset += ksize; } // Flush if we're using a shared memory system // and if interface is non-coherent if (!coherent_ && always_cache_) { - VTAFlushCache(fpga_buff_, - fpga_buff_phy_, - offset); + VTAFlushCache(fpga_buff_, fpga_buff_phy_, offset); } } @@ -507,8 +459,7 @@ class UopQueue : public BaseQueue { class UopKernelMap { public: // Simple hash map - UopKernel** Get(void* signature, - int nbytes) { + UopKernel** Get(void* signature, int nbytes) { uint32_t key = 0; CHECK(nbytes == 0 || nbytes == sizeof(int)); if (nbytes == sizeof(int)) { @@ -526,15 +477,10 @@ class UopKernelMap { std::vector kmap_; }; -enum PipelineStage : int { - kNoneStage = 0, - kLoadStage = 1, - kComputeStage = 2, - kStoreStage = 3 -}; +enum PipelineStage : int { kNoneStage = 0, kLoadStage = 1, kComputeStage = 2, kStoreStage = 3 }; // Instruction Queue -template +template class InsnQueue : public BaseQueue { public: /*! \brief Initialize the space. */ @@ -545,13 +491,9 @@ class InsnQueue : public BaseQueue { std::fill(pending_pop_next_, pending_pop_next_ + 4, 0); } /*! \return The data pointer. */ - VTAGenericInsn* data() { - return dram_buffer_.data(); - } + VTAGenericInsn* data() { return dram_buffer_.data(); } /*! \return Number of instructions. */ - uint32_t count() { - return dram_buffer_.size(); - } + uint32_t count() { return dram_buffer_.size(); } // Insert dependency push of load void DepPop(int from, int to) { // NOTE: This instruction executes on queue[to] @@ -579,10 +521,12 @@ class InsnQueue : public BaseQueue { if (GetPipelineStage(mptr) == from) { if (from < to && !mptr->push_next_dep) { // push(LD->C) or push(C->ST) - mptr->push_next_dep = true; return; + mptr->push_next_dep = true; + return; } else if (from > to && !mptr->push_prev_dep) { // push(C->LD) or push(ST->C) - mptr->push_prev_dep = true; return; + mptr->push_prev_dep = true; + return; } } } @@ -595,25 +539,15 @@ class InsnQueue : public BaseQueue { } } // Create a new instruction for a GEMM stage - VTAGemInsn* CreateGemInsn() { - return reinterpret_cast( - Create(kComputeStage)); - } + VTAGemInsn* CreateGemInsn() { return reinterpret_cast(Create(kComputeStage)); } // Create a new instruction for a ALU stage - VTAAluInsn* CreateAluInsn() { - return reinterpret_cast( - Create(kComputeStage)); - } + VTAAluInsn* CreateAluInsn() { return reinterpret_cast(Create(kComputeStage)); } // Create a new instruction for a memory stage VTAMemInsn* CreateMemInsn(int memory_type) { - return reinterpret_cast( - Create(GetMemPipelineStage(memory_type))); + return reinterpret_cast(Create(GetMemPipelineStage(memory_type))); } // create a new instruction for a store stage - VTAMemInsn* CreateStoreInsn() { - return reinterpret_cast( - Create(kStoreStage)); - } + VTAMemInsn* CreateStoreInsn() { return reinterpret_cast(Create(kStoreStage)); } // Rewrite instruction stream to force serial execution void RewriteForceSerial() { int insn_count = count(); @@ -663,7 +597,7 @@ class InsnQueue : public BaseQueue { } CommitPendingPop(kComputeStage); } else { - pending_pop_next_[kComputeStage] = 0; + pending_pop_next_[kComputeStage] = 0; } DepPush(kComputeStage, kLoadStage); DepPop(kLoadStage, kComputeStage); @@ -676,30 +610,30 @@ class InsnQueue : public BaseQueue { } // Helper function: Get Opcode string const char* getOpcodeString(int opcode, bool use_imm) { - // The string name - if (opcode == VTA_ALU_OPCODE_MIN) { - if (use_imm) { - return "min imm"; - } else { - return "min"; - } - } else if (opcode == VTA_ALU_OPCODE_MAX) { - if (use_imm) { - return "max imm"; - } else { - return "max"; - } - } else if (opcode == VTA_ALU_OPCODE_ADD) { - if (use_imm) { - return "add imm"; - } else { - return "add"; - } - } else if (opcode == VTA_ALU_OPCODE_SHR) { - return "shr"; + // The string name + if (opcode == VTA_ALU_OPCODE_MIN) { + if (use_imm) { + return "min imm"; + } else { + return "min"; } + } else if (opcode == VTA_ALU_OPCODE_MAX) { + if (use_imm) { + return "max imm"; + } else { + return "max"; + } + } else if (opcode == VTA_ALU_OPCODE_ADD) { + if (use_imm) { + return "add imm"; + } else { + return "add"; + } + } else if (opcode == VTA_ALU_OPCODE_SHR) { + return "shr"; + } - return "unknown op"; + return "unknown op"; } // Dump instructions in the queue void DumpInsn() { @@ -728,10 +662,8 @@ class InsnQueue : public BaseQueue { printf("NOP-MEMORY-STAGE\n"); } printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); // Count status in queues if (c.mem.opcode == VTA_OPCODE_STORE) { CHECK(c.mem.pop_next_dep == false); @@ -739,8 +671,7 @@ class InsnQueue : public BaseQueue { if (c.mem.pop_prev_dep) g2s_queue--; if (c.mem.push_prev_dep) s2g_queue++; } else if (c.mem.opcode == VTA_OPCODE_LOAD && - (c.mem.memory_type == VTA_MEM_ID_INP || - c.mem.memory_type == VTA_MEM_ID_WGT) ) { + (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) { CHECK(c.mem.pop_prev_dep == false); CHECK(c.mem.push_prev_dep == false); if (c.mem.pop_next_dep) g2l_queue--; @@ -767,65 +698,44 @@ class InsnQueue : public BaseQueue { printf("STORE:\n"); } printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); - printf("\tDRAM: 0x%08x, SRAM:0x%04x\n", - static_cast(c.mem.dram_base), + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); + printf("\tDRAM: 0x%08x, SRAM:0x%04x\n", static_cast(c.mem.dram_base), static_cast(c.mem.sram_base)); - printf("\ty: size=%d, pad=[%d, %d]\n", - static_cast(c.mem.y_size), - static_cast(c.mem.y_pad_0), - static_cast(c.mem.y_pad_1)); - printf("\tx: size=%d, stride=%d, pad=[%d, %d]\n", - static_cast(c.mem.x_size), - static_cast(c.mem.x_stride), - static_cast(c.mem.x_pad_0), + printf("\ty: size=%d, pad=[%d, %d]\n", static_cast(c.mem.y_size), + static_cast(c.mem.y_pad_0), static_cast(c.mem.y_pad_1)); + printf("\tx: size=%d, stride=%d, pad=[%d, %d]\n", static_cast(c.mem.x_size), + static_cast(c.mem.x_stride), static_cast(c.mem.x_pad_0), static_cast(c.mem.x_pad_1)); } else if (c.mem.opcode == VTA_OPCODE_GEMM) { // Print instruction field information printf("GEMM\n"); printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); printf("\treset_out: %d\n", static_cast(c.gemm.reset_reg)); - printf("\trange (%d, %d)\n", - static_cast(c.gemm.uop_bgn), + printf("\trange (%d, %d)\n", static_cast(c.gemm.uop_bgn), static_cast(c.gemm.uop_end)); printf("\touter loop - iter: %d, wgt: %d, inp: %d, acc: %d\n", - static_cast(c.gemm.iter_out), - static_cast(c.gemm.wgt_factor_out), - static_cast(c.gemm.src_factor_out), - static_cast(c.gemm.dst_factor_out)); + static_cast(c.gemm.iter_out), static_cast(c.gemm.wgt_factor_out), + static_cast(c.gemm.src_factor_out), static_cast(c.gemm.dst_factor_out)); printf("\tinner loop - iter: %d, wgt: %d, inp: %d, acc: %d\n", - static_cast(c.gemm.iter_in), - static_cast(c.gemm.wgt_factor_in), - static_cast(c.gemm.src_factor_in), - static_cast(c.gemm.dst_factor_in)); + static_cast(c.gemm.iter_in), static_cast(c.gemm.wgt_factor_in), + static_cast(c.gemm.src_factor_in), static_cast(c.gemm.dst_factor_in)); } else if (c.mem.opcode == VTA_OPCODE_ALU) { // Print instruction field information printf("ALU - %s\n", getOpcodeString(c.alu.alu_opcode, c.alu.use_imm)); printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); printf("\treset_out: %d\n", static_cast(c.alu.reset_reg)); - printf("\trange (%d, %d)\n", - static_cast(c.alu.uop_bgn), + printf("\trange (%d, %d)\n", static_cast(c.alu.uop_bgn), static_cast(c.alu.uop_end)); - printf("\touter loop - iter: %d, dst: %d, src: %d\n", - static_cast(c.alu.iter_out), - static_cast(c.alu.dst_factor_out), - static_cast(c.alu.src_factor_out)); - printf("\tinner loop - iter: %d, dst: %d, src: %d\n", - static_cast(c.alu.iter_in), - static_cast(c.alu.dst_factor_in), - static_cast(c.alu.src_factor_in)); + printf("\touter loop - iter: %d, dst: %d, src: %d\n", static_cast(c.alu.iter_out), + static_cast(c.alu.dst_factor_out), static_cast(c.alu.src_factor_out)); + printf("\tinner loop - iter: %d, dst: %d, src: %d\n", static_cast(c.alu.iter_in), + static_cast(c.alu.dst_factor_in), static_cast(c.alu.src_factor_in)); } else if (c.mem.opcode == VTA_OPCODE_FINISH) { printf("FINISH\n"); } @@ -833,25 +743,23 @@ class InsnQueue : public BaseQueue { // Count status in queues if (c.mem.opcode == VTA_OPCODE_LOAD || c.mem.opcode == VTA_OPCODE_STORE) { if (c.mem.opcode == VTA_OPCODE_STORE) { - CHECK(c.mem.pop_next_dep == false); - CHECK(c.mem.push_next_dep == false); - if (c.mem.pop_prev_dep) g2s_queue--; - if (c.mem.push_prev_dep) s2g_queue++; + CHECK(c.mem.pop_next_dep == false); + CHECK(c.mem.push_next_dep == false); + if (c.mem.pop_prev_dep) g2s_queue--; + if (c.mem.push_prev_dep) s2g_queue++; } else if (c.mem.opcode == VTA_OPCODE_LOAD && - (c.mem.memory_type == VTA_MEM_ID_INP || - c.mem.memory_type == VTA_MEM_ID_WGT) ) { - CHECK(c.mem.pop_prev_dep == false); - CHECK(c.mem.push_prev_dep == false); - if (c.mem.pop_next_dep) g2l_queue--; - if (c.mem.push_next_dep) l2g_queue++; + (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) { + CHECK(c.mem.pop_prev_dep == false); + CHECK(c.mem.push_prev_dep == false); + if (c.mem.pop_next_dep) g2l_queue--; + if (c.mem.push_next_dep) l2g_queue++; } else { - if (c.mem.pop_prev_dep) l2g_queue--; - if (c.mem.push_prev_dep) g2l_queue++; - if (c.mem.pop_next_dep) s2g_queue--; - if (c.mem.push_next_dep) g2s_queue++; + if (c.mem.pop_prev_dep) l2g_queue--; + if (c.mem.push_prev_dep) g2l_queue++; + if (c.mem.pop_next_dep) s2g_queue--; + if (c.mem.push_next_dep) g2s_queue++; } - } else if (c.mem.opcode == VTA_OPCODE_GEMM || - c.mem.opcode == VTA_OPCODE_ALU) { + } else if (c.mem.opcode == VTA_OPCODE_GEMM || c.mem.opcode == VTA_OPCODE_ALU) { // Print instruction field information if (c.gemm.pop_prev_dep) l2g_queue--; if (c.gemm.push_prev_dep) g2l_queue++; @@ -867,11 +775,8 @@ class InsnQueue : public BaseQueue { // Handle the LD<->compute queue // NOTE: pop executes on target(stage) CHECK(stage > 0 && stage < 4); - if (pending_pop_prev_[stage] || - pending_pop_next_[stage]) { - PushNoop(stage, false, false, - pending_pop_prev_[stage], - pending_pop_next_[stage]); + if (pending_pop_prev_[stage] || pending_pop_next_[stage]) { + PushNoop(stage, false, false, pending_pop_prev_[stage], pending_pop_next_[stage]); pending_pop_prev_[stage] = 0; pending_pop_next_[stage] = 0; } @@ -888,9 +793,7 @@ class InsnQueue : public BaseQueue { } return false; } - void AutoReadBarrier() { - ReadBarrier(); - } + void AutoReadBarrier() { ReadBarrier(); } /*! \brief Writer barrier to make sure that data written by CPU is visible to VTA. */ void ReadBarrier() { CHECK(fpga_buff_ != nullptr); @@ -898,15 +801,11 @@ class InsnQueue : public BaseQueue { uint32_t buff_size = dram_buffer_.size() * elem_bytes_; CHECK(buff_size <= kMaxBytes); // Copy contents of DRAM buffer to FPGA buff - VTAMemCopyFromHost(fpga_buff_, - dram_buffer_.data(), - buff_size); + VTAMemCopyFromHost(fpga_buff_, dram_buffer_.data(), buff_size); // Flush if we're using a shared memory system // and if interface is non-coherent if (!coherent_ && always_cache_) { - VTAFlushCache(fpga_buff_, - fpga_buff_phy_, - buff_size); + VTAFlushCache(fpga_buff_, fpga_buff_phy_, buff_size); } } @@ -957,15 +856,14 @@ class InsnQueue : public BaseQueue { // Get stage of memory and computation static PipelineStage GetPipelineStageAll(VTAMemInsn* insn) { - PipelineStage stage = GetPipelineStage(insn); - if (stage != kNoneStage) return stage; - return GetMemPipelineStage(insn->memory_type); + PipelineStage stage = GetPipelineStage(insn); + if (stage != kNoneStage) return stage; + return GetMemPipelineStage(insn->memory_type); } // Push no-op - void PushNoop(int stage, - bool push_prev_dep, bool push_next_dep, - bool pop_prev_dep, bool pop_next_dep) { + void PushNoop(int stage, bool push_prev_dep, bool push_next_dep, bool pop_prev_dep, + bool pop_next_dep) { VTAMemInsn* insn = reinterpret_cast(NextInsn()); insn->opcode = (stage == kStoreStage ? VTA_OPCODE_STORE : VTA_OPCODE_LOAD); insn->push_prev_dep = push_prev_dep; @@ -997,9 +895,7 @@ class InsnQueue : public BaseQueue { */ class CommandQueue { public: - CommandQueue() { - this->InitSpace(); - } + CommandQueue() { this->InitSpace(); } void InitSpace() { uop_queue_.InitSpace(); insn_queue_.InitSpace(); @@ -1007,31 +903,29 @@ class CommandQueue { CHECK(device_ != nullptr); } - ~CommandQueue() { - VTADeviceFree(device_); - } + ~CommandQueue() { VTADeviceFree(device_); } uint32_t GetElemBytes(uint32_t memory_id) { uint32_t elem_bytes = 0; switch (memory_id) { case VTA_MEM_ID_UOP: - elem_bytes = VTA_UOP_ELEM_BYTES; - break; + elem_bytes = VTA_UOP_ELEM_BYTES; + break; case VTA_MEM_ID_INP: - elem_bytes = VTA_INP_ELEM_BYTES; - break; + elem_bytes = VTA_INP_ELEM_BYTES; + break; case VTA_MEM_ID_WGT: - elem_bytes = VTA_WGT_ELEM_BYTES; - break; + elem_bytes = VTA_WGT_ELEM_BYTES; + break; case VTA_MEM_ID_ACC: - elem_bytes = VTA_ACC_ELEM_BYTES; - break; + elem_bytes = VTA_ACC_ELEM_BYTES; + break; case VTA_MEM_ID_OUT: - elem_bytes = VTA_OUT_ELEM_BYTES; - break; + elem_bytes = VTA_OUT_ELEM_BYTES; + break; default: - LOG(FATAL) << "Memory id not recognized:" << memory_id; - break; + LOG(FATAL) << "Memory id not recognized:" << memory_id; + break; } /* * elements size should not larger than VTA_PAGE_BYTES. @@ -1041,16 +935,9 @@ class CommandQueue { return elem_bytes; } - void LoadBuffer2D(void* src_dram_addr, - uint32_t src_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride, - uint32_t x_pad_before, - uint32_t y_pad_before, - uint32_t x_pad_after, - uint32_t y_pad_after, - uint32_t dst_sram_index, + void LoadBuffer2D(void* src_dram_addr, uint32_t src_elem_offset, uint32_t x_size, uint32_t y_size, + uint32_t x_stride, uint32_t x_pad_before, uint32_t y_pad_before, + uint32_t x_pad_after, uint32_t y_pad_after, uint32_t dst_sram_index, uint32_t dst_memory_type) { VTAMemInsn* insn = insn_queue_.CreateMemInsn(dst_memory_type); insn->opcode = VTA_OPCODE_LOAD; @@ -1068,12 +955,8 @@ class CommandQueue { this->CheckInsnOverFlow(); } - void StoreBuffer2D(uint32_t src_sram_index, - uint32_t src_memory_type, - void* dst_dram_addr, - uint32_t dst_elem_offset, - uint32_t x_size, - uint32_t y_size, + void StoreBuffer2D(uint32_t src_sram_index, uint32_t src_memory_type, void* dst_dram_addr, + uint32_t dst_elem_offset, uint32_t x_size, uint32_t y_size, uint32_t x_stride) { VTAMemInsn* insn = insn_queue_.CreateStoreInsn(); insn->opcode = VTA_OPCODE_STORE; @@ -1091,27 +974,21 @@ class CommandQueue { this->CheckInsnOverFlow(); } - void DepPush(int from_qid, int to_qid) { - insn_queue_.DepPush(from_qid, to_qid); - } + void DepPush(int from_qid, int to_qid) { insn_queue_.DepPush(from_qid, to_qid); } - void DepPop(int from_qid, int to_qid) { - insn_queue_.DepPop(from_qid, to_qid); - } + void DepPop(int from_qid, int to_qid) { insn_queue_.DepPop(from_qid, to_qid); } void ReadBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { if (!(debug_flag_ & VTA_DEBUG_SKIP_READ_BARRIER)) { uint32_t elem_bytes = (elem_bits + 8 - 1) / 8; - DataBuffer::FromHandle(buffer)->FlushCache( - elem_bytes * start, elem_bytes * extent); + DataBuffer::FromHandle(buffer)->FlushCache(elem_bytes * start, elem_bytes * extent); } } void WriteBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { if (!(debug_flag_ & VTA_DEBUG_SKIP_WRITE_BARRIER)) { uint32_t elem_bytes = (elem_bits + 8 - 1) / 8; - DataBuffer::FromHandle(buffer)->InvalidateCache( - elem_bytes * start, elem_bytes * extent); + DataBuffer::FromHandle(buffer)->InvalidateCache(elem_bytes * start, elem_bytes * extent); } } @@ -1141,16 +1018,13 @@ class CommandQueue { insn_queue_.DumpInsn(); } // Make sure that the last instruction is a finish instruction - CHECK(reinterpret_cast( - insn_queue_.data())[insn_queue_.count()-1].opcode == VTA_OPCODE_FINISH); + CHECK(reinterpret_cast(insn_queue_.data())[insn_queue_.count() - 1].opcode == + VTA_OPCODE_FINISH); // Make sure that we don't exceed contiguous physical memory limits CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER); - int timeout = VTADeviceRun( - device_, - insn_queue_.dram_phy_addr(), - insn_queue_.count(), - wait_cycles); + int timeout = + VTADeviceRun(device_, insn_queue_.dram_phy_addr(), insn_queue_.count(), wait_cycles); CHECK_EQ(timeout, 0); // Reset buffers uop_queue_.Reset(); @@ -1164,14 +1038,9 @@ class CommandQueue { } // Set debug flag - void SetDebugFlag(int debug_flag) { - debug_flag_ = debug_flag; - } + void SetDebugFlag(int debug_flag) { debug_flag_ = debug_flag; } - void PushGEMMOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { + void PushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { UopKernelMap** uptr = reinterpret_cast(uop_handle); if (uptr[0] == nullptr) { uptr[0] = new UopKernelMap(); @@ -1190,10 +1059,7 @@ class CommandQueue { this->CheckInsnOverFlow(); } - void PushALUUop(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { + void PushALUUop(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { UopKernelMap** uptr = reinterpret_cast(uop_handle); if (uptr[0] == nullptr) { uptr[0] = new UopKernelMap(); @@ -1213,23 +1079,19 @@ class CommandQueue { } static std::shared_ptr& ThreadLocal() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); if (inst == nullptr) { inst = std::make_shared(); } return inst; } - static void Shutdown() { - ThreadLocal().reset(); - } + static void Shutdown() { ThreadLocal().reset(); } private: // Push GEMM uop to the command buffer void PushGEMMOp(UopKernel* kernel) { - uop_queue_.Push(kernel, - [this]() { this->AutoSync(); }); + uop_queue_.Push(kernel, [this]() { this->AutoSync(); }); if (uop_queue_.pending()) { VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP); insn->opcode = VTA_OPCODE_LOAD; @@ -1240,7 +1102,7 @@ class CommandQueue { insn->reset_reg = kernel->reset_out_; insn->uop_bgn = kernel->sram_begin_; insn->uop_end = kernel->sram_end_; - const std::vector &loop = kernel->loop(); + const std::vector& loop = kernel->loop(); if (loop.size() > 0) { insn->iter_out = loop[0].extent; insn->wgt_factor_out = loop[0].wgt_factor; @@ -1267,8 +1129,7 @@ class CommandQueue { // Push ALU uop to the command buffer void PushALUUop(UopKernel* kernel) { - uop_queue_.Push(kernel, - [this]() { this->AutoSync(); }); + uop_queue_.Push(kernel, [this]() { this->AutoSync(); }); if (uop_queue_.pending()) { VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP); insn->opcode = VTA_OPCODE_LOAD; @@ -1282,7 +1143,7 @@ class CommandQueue { insn->alu_opcode = kernel->opcode_; insn->use_imm = kernel->use_imm_; insn->imm = kernel->imm_val_; - const std::vector &loop = kernel->loop(); + const std::vector& loop = kernel->loop(); if (loop.size() == 0) { insn->iter_out = 1; insn->dst_factor_out = 0; @@ -1315,9 +1176,7 @@ class CommandQueue { } } // Auto sync when instruction overflow - void AutoSync() { - this->Synchronize(1 << 31); - } + void AutoSync() { this->Synchronize(1 << 31); } // Internal debug flag int debug_flag_{0}; @@ -1333,19 +1192,11 @@ class CommandQueue { } // namespace vta -void* VTABufferAlloc(size_t size) { - return vta::DataBuffer::Alloc(size); -} +void* VTABufferAlloc(size_t size) { return vta::DataBuffer::Alloc(size); } -void VTABufferFree(void* buffer) { - vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); -} +void VTABufferFree(void* buffer) { vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); } -void VTABufferCopy(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, +void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, int kind_mask) { vta::DataBuffer* from_buffer = nullptr; vta::DataBuffer* to_buffer = nullptr; @@ -1363,143 +1214,87 @@ void VTABufferCopy(const void* from, // This is an FPGA to host mem transfer from_buffer->InvalidateCache(from_offset, size); from_buffer->MemCopyToHost(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); + static_cast(from) + from_offset, size); } else if (to_buffer) { // This is a host to FPGA mem transfer to_buffer->MemCopyFromHost(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); + static_cast(from) + from_offset, size); to_buffer->FlushCache(to_offset, size); } } -VTACommandHandle VTATLSCommandHandle() { - return vta::CommandQueue::ThreadLocal().get(); -} +VTACommandHandle VTATLSCommandHandle() { return vta::CommandQueue::ThreadLocal().get(); } -void VTARuntimeShutdown() { - vta::CommandQueue::Shutdown(); -} +void VTARuntimeShutdown() { vta::CommandQueue::Shutdown(); } void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) { - static_cast(cmd)-> - SetDebugFlag(debug_flag); + static_cast(cmd)->SetDebugFlag(debug_flag); } void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer) { return vta::DataBuffer::FromHandle(buffer)->virt_addr(); } -void VTAWriteBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { - static_cast(cmd)-> - WriteBarrier(buffer, elem_bits, start, extent); + static_cast(cmd)->WriteBarrier(buffer, elem_bits, start, extent); } -void VTAReadBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +void VTAReadBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { - static_cast(cmd)-> - ReadBarrier(buffer, elem_bits, start, extent); + static_cast(cmd)->ReadBarrier(buffer, elem_bits, start, extent); } -void VTALoadBuffer2D(VTACommandHandle cmd, - void* src_dram_addr, - uint32_t src_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride, - uint32_t x_pad_before, - uint32_t y_pad_before, - uint32_t x_pad_after, - uint32_t y_pad_after, - uint32_t dst_sram_index, - uint32_t dst_memory_type) { - static_cast(cmd)-> - LoadBuffer2D(src_dram_addr, src_elem_offset, - x_size, y_size, x_stride, - x_pad_before, y_pad_before, - x_pad_after, y_pad_after, - dst_sram_index, dst_memory_type); +void VTALoadBuffer2D(VTACommandHandle cmd, void* src_dram_addr, uint32_t src_elem_offset, + uint32_t x_size, uint32_t y_size, uint32_t x_stride, uint32_t x_pad_before, + uint32_t y_pad_before, uint32_t x_pad_after, uint32_t y_pad_after, + uint32_t dst_sram_index, uint32_t dst_memory_type) { + static_cast(cmd)->LoadBuffer2D( + src_dram_addr, src_elem_offset, x_size, y_size, x_stride, x_pad_before, y_pad_before, + x_pad_after, y_pad_after, dst_sram_index, dst_memory_type); } -void VTAStoreBuffer2D(VTACommandHandle cmd, - uint32_t src_sram_index, - uint32_t src_memory_type, - void* dst_dram_addr, - uint32_t dst_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride) { - static_cast(cmd)-> - StoreBuffer2D(src_sram_index, src_memory_type, - dst_dram_addr, dst_elem_offset, - x_size, y_size, x_stride); +void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t src_sram_index, uint32_t src_memory_type, + void* dst_dram_addr, uint32_t dst_elem_offset, uint32_t x_size, + uint32_t y_size, uint32_t x_stride) { + static_cast(cmd)->StoreBuffer2D( + src_sram_index, src_memory_type, dst_dram_addr, dst_elem_offset, x_size, y_size, x_stride); } -void VTAUopPush(uint32_t mode, - uint32_t reset_out, - uint32_t dst_index, - uint32_t src_index, - uint32_t wgt_index, - uint32_t opcode, - uint32_t use_imm, - int32_t imm_val) { - vta::CommandQueue::ThreadLocal()->record_kernel() - ->Push(mode, reset_out, dst_index, src_index, - wgt_index, opcode, use_imm, imm_val); +void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index, + uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) { + vta::CommandQueue::ThreadLocal()->record_kernel()->Push(mode, reset_out, dst_index, src_index, + wgt_index, opcode, use_imm, imm_val); } -void VTAUopLoopBegin(uint32_t extent, - uint32_t dst_factor, - uint32_t src_factor, +void VTAUopLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor, uint32_t wgt_factor) { - vta::CommandQueue::ThreadLocal()->record_kernel() - ->PushLoopBegin(extent, dst_factor, src_factor, wgt_factor); + vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopBegin(extent, dst_factor, src_factor, + wgt_factor); } -void VTAUopLoopEnd() { - vta::CommandQueue::ThreadLocal()->record_kernel() - ->PushLoopEnd(); -} +void VTAUopLoopEnd() { vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopEnd(); } -int VTAPushGEMMOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { - vta::CommandQueue::ThreadLocal()-> - PushGEMMOp(uop_handle, finit, signature, nbytes); +int VTAPushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { + vta::CommandQueue::ThreadLocal()->PushGEMMOp(uop_handle, finit, signature, nbytes); return 0; } -int VTAPushALUOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { - vta::CommandQueue::ThreadLocal()-> - PushALUUop(uop_handle, finit, signature, nbytes); +int VTAPushALUOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { + vta::CommandQueue::ThreadLocal()->PushALUUop(uop_handle, finit, signature, nbytes); return 0; } int VTADepPush(VTACommandHandle cmd, int from_qid, int to_qid) { - static_cast(cmd)-> - DepPush(from_qid, to_qid); + static_cast(cmd)->DepPush(from_qid, to_qid); return 0; } int VTADepPop(VTACommandHandle cmd, int from_qid, int to_qid) { - static_cast(cmd)-> - DepPop(from_qid, to_qid); + static_cast(cmd)->DepPop(from_qid, to_qid); return 0; } void VTASynchronize(VTACommandHandle cmd, uint32_t wait_cycles) { - static_cast(cmd)-> - Synchronize(wait_cycles); + static_cast(cmd)->Synchronize(wait_cycles); } diff --git a/vta/runtime/runtime.h b/vta/runtime/runtime.h index bb16d3a..24ebb8e 100644 --- a/vta/runtime/runtime.h +++ b/vta/runtime/runtime.h @@ -64,12 +64,8 @@ TVM_DLL void VTABufferFree(void* buffer); * \param size Size of copy. * \param kind_mask The memory copy kind. */ -TVM_DLL void VTABufferCopy(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - int kind_mask); +TVM_DLL void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t size, int kind_mask); /*! \brief VTA command handle */ typedef void* VTACommandHandle; @@ -99,10 +95,7 @@ TVM_DLL void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer); * \param start The start of the region (in elements). * \param extent The end of the region (in elements). */ -TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent); /*! @@ -113,10 +106,7 @@ TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, * \param start The start of the region (in elements). * \param extent The end of the region (in elements). */ -TVM_DLL void VTAReadBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +TVM_DLL void VTAReadBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent); /*! @@ -142,17 +132,10 @@ TVM_DLL void VTASetDebugMode(VTACommandHandle cmd, int debug_flag); * \param dst_sram_index Destination SRAM index. * \param dst_memory_type Destination memory type. */ -TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, - void* src_dram_addr, - uint32_t src_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride, - uint32_t x_pad_before, - uint32_t y_pad_before, - uint32_t x_pad_after, - uint32_t y_pad_after, - uint32_t dst_sram_index, +TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, void* src_dram_addr, uint32_t src_elem_offset, + uint32_t x_size, uint32_t y_size, uint32_t x_stride, + uint32_t x_pad_before, uint32_t y_pad_before, uint32_t x_pad_after, + uint32_t y_pad_after, uint32_t dst_sram_index, uint32_t dst_memory_type); /*! @@ -167,13 +150,9 @@ TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, * \param y_size The number of rows. * \param x_stride The x axis stride. */ -TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, - uint32_t src_sram_index, - uint32_t src_memory_type, - void* dst_dram_addr, - uint32_t dst_elem_offset, - uint32_t x_size, - uint32_t y_size, +TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t src_sram_index, + uint32_t src_memory_type, void* dst_dram_addr, + uint32_t dst_elem_offset, uint32_t x_size, uint32_t y_size, uint32_t x_stride); /*! @@ -207,14 +186,8 @@ TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, * \param use_imm Use immediate in ALU mode if set to true. * \param imm_val Immediate value in ALU mode. */ -TVM_DLL void VTAUopPush(uint32_t mode, - uint32_t reset_out, - uint32_t dst_index, - uint32_t src_index, - uint32_t wgt_index, - uint32_t opcode, - uint32_t use_imm, - int32_t imm_val); +TVM_DLL void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index, + uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val); /*! * \brief Mark start of a micro op loop. @@ -223,9 +196,7 @@ TVM_DLL void VTAUopPush(uint32_t mode, * \param src_factor The input factor. * \param wgt_factor The weight factor. */ -TVM_DLL void VTAUopLoopBegin(uint32_t extent, - uint32_t dst_factor, - uint32_t src_factor, +TVM_DLL void VTAUopLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor, uint32_t wgt_factor); /*! @@ -241,10 +212,7 @@ TVM_DLL void VTAUopLoopEnd(); * \param nbytes Number of bytes to in the closure arguments. * \return 0 if success. */ -TVM_DLL int VTAPushGEMMOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes); +TVM_DLL int VTAPushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes); /*! * \brief Push ALU uop kernel into the command handle. @@ -254,10 +222,7 @@ TVM_DLL int VTAPushGEMMOp(void** uop_handle, * \param nbytes Number of bytes to in the closure arguments. * \return 0 if success. */ -TVM_DLL int VTAPushALUOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes); +TVM_DLL int VTAPushALUOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes); /*! * \brief Push dependence token. diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 9ea65d0..6abd122 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -31,12 +31,12 @@ #define DMLC_LOG_NODATE 1 #define DMLC_LOG_FATAL_THROW 0 - #include -#include -#include #include #include +#include +#include + #include "../../src/runtime/rpc/rpc_local_session.h" extern "C" { @@ -61,8 +61,7 @@ TVM_DLL void TVMWasmFreeSpace(void* data); * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer 3A * \return 0 if success. */ -TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, - TVMFunctionHandle *out); +TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out); // --- APIs to be implemented by the frontend. --- /*! @@ -75,10 +74,7 @@ TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, * \param resource_handle The handle additional resouce handle from fron-end. * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. */ -extern int TVMWasmPackedCFunc(TVMValue* args, - int* type_codes, - int num_args, - TVMRetValueHandle ret, +extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, void* resource_handle); /*! @@ -88,24 +84,18 @@ extern int TVMWasmPackedCFunc(TVMValue* args, extern void TVMWasmPackedCFuncFinalizer(void* resource_handle); } // extern "C" - void* TVMWasmAllocSpace(int size) { int num_count = (size + 7) / 8; return new int64_t[num_count]; } -void TVMWasmFreeSpace(void* arr) { - delete[] static_cast(arr); -} +void TVMWasmFreeSpace(void* arr) { delete[] static_cast(arr); } -int TVMWasmFuncCreateFromCFunc(void* resource_handle, - TVMFunctionHandle *out) { - return TVMFuncCreateFromCFunc( - TVMWasmPackedCFunc, resource_handle, - TVMWasmPackedCFuncFinalizer, out); +int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) { + return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer, + out); } - namespace tvm { namespace runtime { @@ -113,8 +103,7 @@ namespace runtime { // functions in the JS runtime. class AsyncLocalSession : public LocalSession { public: - AsyncLocalSession() { - } + AsyncLocalSession() {} PackedFuncHandle GetFunction(const std::string& name) final { if (name == "runtime.RPCTimeEvaluator") { @@ -122,7 +111,7 @@ class AsyncLocalSession : public LocalSession { } else if (auto* fp = tvm::runtime::Registry::Get(name)) { // return raw handle because the remote need to explicitly manage it. return new PackedFunc(*fp); - } else if(auto* fp = tvm::runtime::Registry::Get("__async." + name)) { + } else if (auto* fp = tvm::runtime::Registry::Get("__async." + name)) { auto* rptr = new PackedFunc(*fp); async_func_set_.insert(rptr); return rptr; @@ -143,20 +132,16 @@ class AsyncLocalSession : public LocalSession { } } - void AsyncCallFunc(PackedFuncHandle func, - const TVMValue* arg_values, - const int* arg_type_codes, - int num_args, - FAsyncCallback callback) final { + void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, + int num_args, FAsyncCallback callback) final { auto it = async_func_set_.find(func); if (it != async_func_set_.end()) { PackedFunc packed_callback([callback, this](TVMArgs args, TVMRetValue*) { int code = args[0]; TVMRetValue rv; rv = args[1]; - this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { - callback(RPCCode::kReturn, encoded_args); - }); + this->EncodeReturn(std::move(rv), + [&](TVMArgs encoded_args) { callback(RPCCode::kReturn, encoded_args); }); }); TVMRetValue temp; @@ -175,8 +160,8 @@ class AsyncLocalSession : public LocalSession { // special handle time evaluator. try { TVMArgs args(arg_values, arg_type_codes, num_args); - PackedFunc retfunc = this->GetTimeEvaluator( - args[0], args[1], args[2], args[3], args[4], args[5], args[6]); + PackedFunc retfunc = + this->GetTimeEvaluator(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); TVMRetValue rv; rv = retfunc; this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { @@ -192,53 +177,39 @@ class AsyncLocalSession : public LocalSession { } } - void AsyncCopyToRemote(void* local_from, - size_t local_from_offset, - void* remote_to, - size_t remote_to_offset, - size_t nbytes, - TVMContext remote_ctx_to, - DLDataType type_hint, - FAsyncCallback on_complete) final { + void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, + DLDataType type_hint, FAsyncCallback on_complete) final { TVMContext cpu_ctx; cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; try { - this->GetDeviceAPI(remote_ctx_to)->CopyDataFromTo( - local_from, local_from_offset, - remote_to, remote_to_offset, - nbytes, cpu_ctx, remote_ctx_to, type_hint, nullptr); + this->GetDeviceAPI(remote_ctx_to) + ->CopyDataFromTo(local_from, local_from_offset, remote_to, remote_to_offset, nbytes, + cpu_ctx, remote_ctx_to, type_hint, nullptr); this->AsyncStreamWait(remote_ctx_to, nullptr, on_complete); } catch (const std::runtime_error& e) { this->SendException(on_complete, e.what()); } } - void AsyncCopyFromRemote(void* remote_from, - size_t remote_from_offset, - void* local_to, - size_t local_to_offset, - size_t nbytes, - TVMContext remote_ctx_from, - DLDataType type_hint, - FAsyncCallback on_complete) final { + void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from, + DLDataType type_hint, FAsyncCallback on_complete) final { TVMContext cpu_ctx; cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; try { - this->GetDeviceAPI(remote_ctx_from)->CopyDataFromTo( - remote_from, remote_from_offset, - local_to, local_to_offset, - nbytes, remote_ctx_from, cpu_ctx, type_hint, nullptr); + this->GetDeviceAPI(remote_ctx_from) + ->CopyDataFromTo(remote_from, remote_from_offset, local_to, local_to_offset, nbytes, + remote_ctx_from, cpu_ctx, type_hint, nullptr); this->AsyncStreamWait(remote_ctx_from, nullptr, on_complete); } catch (const std::runtime_error& e) { this->SendException(on_complete, e.what()); } } - void AsyncStreamWait(TVMContext ctx, - TVMStreamHandle stream, - FAsyncCallback on_complete) final { + void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_complete) final { if (ctx.device_type == kDLCPU) { TVMValue value; int32_t tcode = kTVMNullptr; @@ -259,9 +230,7 @@ class AsyncLocalSession : public LocalSession { } } - bool IsAsync() const final { - return true; - } + bool IsAsync() const final { return true; } private: std::unordered_set async_func_set_; @@ -269,13 +238,8 @@ class AsyncLocalSession : public LocalSession { const PackedFunc* async_wait_{nullptr}; // time evaluator - PackedFunc GetTimeEvaluator(Optional opt_mod, - std::string name, - int device_type, - int device_id, - int number, - int repeat, - int min_repeat_ms) { + PackedFunc GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, + int device_id, int number, int repeat, int min_repeat_ms) { TVMContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; @@ -283,24 +247,18 @@ class AsyncLocalSession : public LocalSession { if (opt_mod.defined()) { Module m = opt_mod.value(); std::string tkey = m->type_key(); - return WrapWasmTimeEvaluator( - m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); + return WrapWasmTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); } else { auto* pf = runtime::Registry::Get(name); CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; - return WrapWasmTimeEvaluator( - *pf, ctx, number, repeat, min_repeat_ms); + return WrapWasmTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms); } } // time evaluator - PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, - TVMContext ctx, - int number, - int repeat, + PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat, int min_repeat_ms) { - auto ftimer = [pf, ctx, number, repeat, min_repeat_ms]( - TVMArgs args, TVMRetValue *rv) { + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) { // the function is a async function. PackedFunc on_complete = args[args.size() - 1]; // keep argument alive in finvoke so that they @@ -317,15 +275,14 @@ class AsyncLocalSession : public LocalSession { }; auto* time_exec = runtime::Registry::Get("__async.wasm.TimeExecution"); CHECK(time_exec != nullptr) << "Cannot find wasm.GetTimer in the global function"; - (*time_exec)(TypedPackedFunc(finvoke), - ctx, number, repeat, min_repeat_ms, on_complete); + (*time_exec)(TypedPackedFunc(finvoke), ctx, number, repeat, min_repeat_ms, + on_complete); }; return PackedFunc(ftimer); } }; -TVM_REGISTER_GLOBAL("wasm.LocalSession") -.set_body_typed([]() { +TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() { return CreateRPCSessionModule(std::make_shared()); }); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 6ff652c..a67b4c3 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -29,64 +29,51 @@ #define DMLC_LOG_NODATE 1 #define DMLC_LOG_FATAL_THROW 0 -#include #include +#include #include "src/runtime/c_runtime_api.cc" #include "src/runtime/cpu_device_api.cc" -#include "src/runtime/workspace_pool.cc" +#include "src/runtime/file_util.cc" +#include "src/runtime/graph/graph_runtime.cc" #include "src/runtime/library_module.cc" -#include "src/runtime/system_library.cc" - #include "src/runtime/module.cc" #include "src/runtime/ndarray.cc" #include "src/runtime/object.cc" #include "src/runtime/registry.cc" -#include "src/runtime/file_util.cc" -#include "src/runtime/graph/graph_runtime.cc" -#include "src/runtime/rpc/rpc_session.cc" +#include "src/runtime/rpc/rpc_channel.cc" #include "src/runtime/rpc/rpc_endpoint.cc" #include "src/runtime/rpc/rpc_event_impl.cc" -#include "src/runtime/rpc/rpc_channel.cc" #include "src/runtime/rpc/rpc_local_session.cc" #include "src/runtime/rpc/rpc_module.cc" - +#include "src/runtime/rpc/rpc_session.cc" +#include "src/runtime/system_library.cc" +#include "src/runtime/workspace_pool.cc" // --- Implementations of backend and wasm runtime API. --- -int TVMBackendParallelLaunch(FTVMParallelLambda flambda, - void* cdata, - int num_task) { +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { TVMParallelGroupEnv env; env.num_task = 1; flambda(0, &env, cdata); return 0; } -int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { - return 0; -} +int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { return 0; } // --- Environment PackedFuncs for testing --- -namespace tvm { +namespace tvm { namespace runtime { -TVM_REGISTER_GLOBAL("testing.echo") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0]; }); -TVM_REGISTER_GLOBAL("testing.add_one") -.set_body_typed([](int x) { - return x + 1; -}); +TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; }); -TVM_REGISTER_GLOBAL("testing.wrap_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PackedFunc pf = args[0]; - *ret = runtime::TypedPackedFunc([pf](){ - pf(); - }); - }); +TVM_REGISTER_GLOBAL("testing.wrap_callback").set_body([](TVMArgs args, TVMRetValue* ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf]() { pf(); }); +}); } // namespace runtime } // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 537ab18..7f0b0d9 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -31,12 +31,13 @@ #include #include +#include #include #include -#include + #include "../../src/runtime/meta_data.h" -#include "../../src/runtime/workspace_pool.h" #include "../../src/runtime/vulkan/vulkan_shader.h" +#include "../../src/runtime/workspace_pool.h" namespace tvm { namespace runtime { @@ -52,7 +53,6 @@ class WebGPUThreadEntry { static WebGPUThreadEntry* ThreadLocal(); }; - // All the implementations are redirectly to the JS side. class WebGPUDeviceAPI : public DeviceAPI { public: @@ -67,32 +67,23 @@ class WebGPUDeviceAPI : public DeviceAPI { copy_within_gpu_ = getter("deviceCopyWithinGPU"); } - void SetDevice(TVMContext ctx) final { - } + void SetDevice(TVMContext ctx) final {} void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { if (kind == kExist) { *rv = 1; } } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { - double ptr_number = alloc_space_(nbytes); return reinterpret_cast(static_cast(ptr_number)); } - void FreeDataSpace(TVMContext ctx, void* ptr) final { - return free_space_(ptr); - } + void FreeDataSpace(TVMContext ctx, void* ptr) final { return free_space_(ptr); } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, size_t to_offset, size_t size, - TVMContext ctx_from, - TVMContext ctx_to, DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { if (static_cast(ctx_from.device_type) == kDLWebGPU && static_cast(ctx_to.device_type) == kDLWebGPU) { @@ -126,9 +117,7 @@ class WebGPUDeviceAPI : public DeviceAPI { return; } - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - LOG(FATAL) << "Not implemented"; - } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; } void SetStream(TVMContext ctx, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; @@ -144,8 +133,7 @@ class WebGPUDeviceAPI : public DeviceAPI { } static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } @@ -155,27 +143,22 @@ class WebGPUDeviceAPI : public DeviceAPI { TypedPackedFunc free_space_; TypedPackedFunc copy_to_gpu_; TypedPackedFunc copy_from_gpu_; - TypedPackedFunc copy_within_gpu_; + TypedPackedFunc + copy_within_gpu_; }; - typedef dmlc::ThreadLocalStore WebGPUThreadStore; WebGPUThreadEntry::WebGPUThreadEntry() - : pool(static_cast(kDLWebGPU), WebGPUDeviceAPI::Global()) { -} - -WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { - return WebGPUThreadStore::Get(); -} + : pool(static_cast(kDLWebGPU), WebGPUDeviceAPI::Global()) {} +WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore::Get(); } class WebGPUModuleNode final : public runtime::ModuleNode { public: explicit WebGPUModuleNode(std::unordered_map smap, - std::unordered_map fmap, - std::string source) + std::unordered_map fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) { auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUCreateShader"); CHECK(fp != nullptr); @@ -184,8 +167,7 @@ class WebGPUModuleNode final : public runtime::ModuleNode { const char* type_key() const final { return "webgpu"; } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { auto it = smap_.find(name); if (it != smap_.end()) { FunctionInfo info = fmap_.at(name); @@ -206,9 +188,7 @@ class WebGPUModuleNode final : public runtime::ModuleNode { LOG(FATAL) << "Not implemented"; } - void SaveToBinary(dmlc::Stream* stream) final { - LOG(FATAL) << "Not implemented"; - } + void SaveToBinary(dmlc::Stream* stream) final { LOG(FATAL) << "Not implemented"; } std::string GetSource(const std::string& format) final { // can only return source code. @@ -226,7 +206,6 @@ class WebGPUModuleNode final : public runtime::ModuleNode { TypedPackedFunc create_shader_; }; - Module WebGPUModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); std::unordered_map smap; @@ -240,11 +219,9 @@ Module WebGPUModuleLoadBinary(void* strm) { } // for now webgpu is hosted via a vulkan module. -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan") -.set_body_typed(WebGPUModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary); -TVM_REGISTER_GLOBAL("device_api.webgpu") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = WebGPUDeviceAPI::Global().get(); *rv = static_cast(ptr); }); -- 2.7.4