From 3f1c8094707f695cf0cf51c795b18093e0a3ab86 Mon Sep 17 00:00:00 2001
From: Ansha Yu
Date: Thu, 26 Aug 2021 23:17:42 -0700
Subject: [PATCH] [static runtime] port c2 argmin kernel (#63632)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63632

Local benchmarking with 1 input repeated for 10k iterations on the 290331537_4 local net. This reduces argmin runtime by about 80% and local net execution time by roughly 0.71-0.77 ms.

Before:
```
I0826 17:25:53.972786 1104614 PyTorchPredictorBenchLib.cpp:313] PyTorch run finished. Milliseconds per iter: 7.37599. Iters per second: 135.57
```
```
Static runtime ms per iter: 8.22086. Iters per second: 121.642

Time per node type:
     4.13527 ms.    50.9157%. fb::sigrid_transforms_torch_bind (1 nodes, out variant)
    0.868506 ms.    10.6935%. aten::argmin (1 nodes, out variant)
...
```

After:
```
I0826 17:17:54.165174 1064079 PyTorchPredictorBenchLib.cpp:313] PyTorch run finished. Milliseconds per iter: 6.66724. Iters per second: 149.987
```
```
Static runtime ms per iter: 7.68172. Iters per second: 130.179

Time per node type:
      4.1452 ms.    54.0612%. fb::sigrid_transforms_torch_bind (1 nodes, out variant)
    0.656778 ms.    8.56562%. fb::quantized_linear (8 nodes)
    0.488229 ms.    6.36741%. static_runtime::to_copy (827 nodes, out variant)
    0.372678 ms.    4.86042%. aten::argmin (1 nodes, out variant)
...

Time per node type:
     3.39387 ms.    53.5467%. fb::sigrid_transforms_torch_bind (1 nodes, out variant)
    0.636216 ms.    10.0379%. fb::quantized_linear (8 nodes, out variant)
    0.410535 ms.    6.47721%. fb::clip_ranges_to_gather_to_offsets (304 nodes, out variant)
    0.212721 ms.     3.3562%. fb::clip_ranges_gather_sigrid_hash_precompute_v3 (157 nodes, out variant)
    0.173736 ms.    2.74111%. aten::matmul (1 nodes, out variant)
    0.150514 ms.    2.37474%. aten::argmin (1 nodes, out variant)
```

P447422384

Test Plan: Test with local replayer sending traffic to `ansha_perf_test_0819.test`, and compare outputs to the jit interpreter.
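In essence, that comparison runs the same graph through the jit interpreter and through static runtime and checks that the outputs agree within tolerance. A minimal sketch of such a check (illustrative only: the toy module, shapes, and the direct `StaticModule` usage are stand-ins; the actual verification uses the replayer setup below):
```
#include <torch/script.h>
#include <torch/csrc/jit/runtime/static/impl.h>

#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Toy module with a single argmin; the real test exercises the full local net.
  torch::jit::Module mod("m");
  mod.define(R"JIT(
    def forward(self, x):
        return torch.argmin(x, 1, False)
  )JIT");
  mod.eval();

  std::vector<c10::IValue> args{torch::rand({4, 8})};

  // Reference output from the jit interpreter.
  auto expected = mod.forward(args).toTensor();

  // Same module through static runtime.
  torch::jit::StaticModule smod(mod);
  std::unordered_map<std::string, c10::IValue> kwargs;
  auto actual = smod(args, kwargs).toTensor();

  // 1e-3 tolerance as in the test plan; argmin indices should in fact match exactly.
  TORCH_CHECK(torch::allclose(
      expected.to(torch::kDouble),
      actual.to(torch::kDouble),
      /*rtol=*/0.0,
      /*atol=*/1e-3));
  return 0;
}
```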
Start compute tier:
```
RUN_UUID=ansha_perf_test_0819.test.storage JOB_EXPIRE_TIME=864000 MODEL_ID=290331537_4 PREDICTOR_TAG= PREDICTOR_VERSION=405 PREDICTOR_TYPE=CPU ADDITIONAL_FLAGS="--enable_disagg_file_split=true --enable_adx=false --load_remote_file_locally=true --pytorch_predictor_static_runtime_whitelist_by_id=290331537" GFLAGS_CONFIG_PATH=sigrid/predictor/gflags/predictor_gflags_ads_perf_cpu_pyper SMC_TIER_NAME=sigrid.predictor.perf.ansha_per_test_0819.test.storage CLUSTER=tsp_rva ENTITLEMENT_NAME=ads_ranking_infra_test_t6 PREDICTOR_LOCAL_DIRECTORY= ICET_CONFIG_PATH= NNPI_COMPILATION_CONFIG_FILE= NUM_TASKS=1 NNPI_NUM_WORKERS=0 tw job start /data/users/ansha/fbsource/fbcode/tupperware/config/admarket/sigrid/predictor/predictor_perf_canary.tw
```

Start nnpi tier:
```
RUN_UUID=ansha_perf_test_0819.test JOB_EXPIRE_TIME=247200 MODEL_ID=290331537_4 PREDICTOR_TAG= PREDICTOR_VERSION=343 PREDICTOR_TYPE=NNPI_TWSHARED ADDITIONAL_FLAGS="--torch_glow_min_fusion_group_size=30 --pytorch_storage_tier_replayer_sr_connection_options=overall_timeout:1000000,processing_timeout:1000000 --predictor_storage_smc_tier=sigrid.predictor.perf.ansha_perf_test_0819.test.storage --pytorch_predictor_static_runtime_whitelist_by_id=290331537" GFLAGS_CONFIG_PATH=sigrid/predictor/gflags/predictor_gflags_ads_perf_glow_nnpi_pyper_v1 SMC_TIER_NAME=sigrid.predictor.perf.ansha_perf_test_0819.test CLUSTER=tsp_rva ENTITLEMENT_NAME=ads_ranking_infra_test_t17 PREDICTOR_LOCAL_DIRECTORY= ICET_CONFIG_PATH= NNPI_COMPILATION_CONFIG_FILE= NUM_TASKS=1 NNPI_NUM_WORKERS=0 tw job start /data/users/ansha/fbsource/fbcode/tupperware/config/admarket/sigrid/predictor/predictor_perf_canary.tw
```

```
buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- StaticRuntime.IndividualOps_Argmin --print-passing-details
```

Compared outputs to the jit interpreter to check for no differences greater than 1e-3 (with nnc on): https://www.internalfb.com/intern/diff/view-version/136824794/

Reviewed By: hlu1

Differential Revision: D30445635

fbshipit-source-id: 048de8867ac72f764132295d1ebfa843cde2fa27
---
 torch/csrc/jit/runtime/static/ops.cpp | 94 ++++++++++++++++++++++++++++++++++-
 1 file changed, 93 insertions(+), 1 deletion(-)

diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 4d34ed9..484c4b0 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -178,6 +179,94 @@ Tensor& linear_out(
   return output;
 }
 
+Tensor& c2_argmin_out(
+    Tensor& output,
+    const Tensor& input,
+    const int64_t dim,
+    const bool keepdim) {
+  const auto ndim = input.dim();
+  int64_t dim_ = maybe_wrap_dim(dim, ndim);
+  TORCH_CHECK(dim_ >= 0 && dim_ < ndim);
+
+  const auto in_dims = input.sizes();
+
+  c10::SmallVector<int64_t, 5> out_dims;
+  out_dims.reserve(ndim);
+  int prev_size = 1;
+  int next_size = 1;
+  for (int i = 0; i < dim_; ++i) {
+    out_dims.push_back(in_dims[i]);
+    prev_size *= in_dims[i];
+  }
+  if (keepdim) {
+    out_dims.push_back(1);
+  }
+  for (auto i = dim_ + 1; i < ndim; ++i) {
+    out_dims.push_back(in_dims[i]);
+    next_size *= in_dims[i];
+  }
+  at::native::resize_(output, out_dims, c10::nullopt);
+
+  const auto n = in_dims[dim_];
+
+  if (next_size == 1) {
+    AT_DISPATCH_ALL_TYPES_AND2(
+        kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
+          const auto in_ptr = input.data_ptr<scalar_t>();
+          const auto out_ptr = output.data_ptr<int64_t>();
+          // input is a [prev_size, n] tensor.
+          // output is a [prev_size,] tensor.
+          // Thus, access is contiguous/coalesced.
+          for (int i = 0; i < prev_size; ++i) {
+            auto v = std::min_element(
+                in_ptr + i * n,
+                in_ptr + (i + 1) * n,
+                [](scalar_t a, scalar_t b) {
+                  // If a is NaN, then a is *less* than b with LessOrNan
+                  // semantics.
+                  if (at::_isnan(a)) {
+                    return true;
+                  }
+                  // If a is not NaN and b is NaN, then a is not less than b
+                  // with LessOrNan semantics; otherwise, act normally. If `b`
+                  // is NaN then a < b will always return false, so this is
+                  // equivalent to the first snippet.
+                  return a < b;
+                });
+            out_ptr[i] = std::distance(in_ptr + i * n, v);
+          }
+        });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND2(
+        kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
+          const auto less_or_nan = native::detail::LessOrNan<scalar_t>{};
+
+          const auto in_ptr = input.data_ptr<scalar_t>();
+          const auto out_ptr = output.data_ptr<int64_t>();
+
+          std::memset(out_ptr, 0, prev_size * next_size * sizeof(int64_t));
+
+          for (int i = 0; i < prev_size; ++i) {
+            const scalar_t* cur_in_ptr = in_ptr + i * n * next_size + next_size;
+            for (int k = 1; k < n; ++k) {
+              for (int j = 0; j < next_size; ++j) {
+                int64_t* cur_out_ptr = out_ptr + i * next_size + j;
+                if (less_or_nan(
+                        *cur_in_ptr,
+                        in_ptr
+                            [i * n * next_size + *cur_out_ptr * next_size + j],
+                        *cur_out_ptr,
+                        k)) {
+                  *cur_out_ptr = k;
+                }
+                ++cur_in_ptr;
+              }
+            }
+          }
+        });
+  }
+  return output;
+}
 } // namespace native
 } // namespace at
@@ -1209,6 +1298,10 @@ REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator {
     } else {
       auto& out_t = p_node->Output(0).toTensor();
       fastResizeToZero(out_t);
+      if (in0_t.is_contiguous() && dim.has_value()) {
+        at::native::c2_argmin_out(out_t, in0_t, dim.value(), keepdim);
+        return;
+      }
       at::cpu::argmin_out(out_t, in0_t, dim, keepdim);
     }
   };
@@ -1533,6 +1626,5 @@ REGISTER_OPERATOR_FUNCTOR(
       }
     };
   });
-
 } // namespace jit
 } // namespace torch
--
2.7.4
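For readers unfamiliar with the Caffe2 kernel's layout, the reduction views the input as [prev_size, n, next_size] around the reduced dimension: when next_size == 1 each length-n row is scanned with std::min_element, otherwise the inner loops keep one running argmin per (i, j) slice, with NaN treated as smaller than any value. A standalone sketch of that traversal (plain C++, not the patch itself; the tie/NaN index choice is simplified relative to LessOrNan):
```
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch of the [prev_size, n, next_size] argmin traversal described above.
// NaN is treated as smaller than every non-NaN value; the exact index choice
// for ties and repeated NaNs is simplified relative to the kernel's LessOrNan.
std::vector<int64_t> argmin_over_dim(
    const std::vector<float>& in,
    int64_t prev_size,
    int64_t n,
    int64_t next_size) {
  std::vector<int64_t> out(prev_size * next_size, 0);
  auto nan_less = [](float a, float b) {
    if (std::isnan(a)) {
      return !std::isnan(b);  // NaN is "smaller" than any non-NaN value.
    }
    return !std::isnan(b) && a < b;
  };
  for (int64_t i = 0; i < prev_size; ++i) {
    for (int64_t k = 1; k < n; ++k) {
      for (int64_t j = 0; j < next_size; ++j) {
        int64_t& best = out[i * next_size + j];
        const float cand = in[(i * n + k) * next_size + j];
        const float cur = in[(i * n + best) * next_size + j];
        if (nan_less(cand, cur)) {
          best = k;
        }
      }
    }
  }
  return out;
}

// Example: a [2, 3, 4] input reduced over dim=1 has prev_size=2, n=3,
// next_size=4, and produces 2 * 4 = 8 indices, each in [0, 3).
```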