[static runtime] port c2 argmin kernel (#63632)
authorAnsha Yu <ansha@fb.com>
Fri, 27 Aug 2021 06:17:42 +0000 (23:17 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 06:19:19 +0000 (23:19 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63632

Local benchmarking with 1 input repeated 10k iter on 290331537_4 local net. Reduces argmin runtime by about 80% and and local net execution by about ~0.71-0.77ms.

Before:
```
I0826 17:25:53.972786 1104614 PyTorchPredictorBenchLib.cpp:313] PyTorch run finished. Milliseconds per iter: 7.37599. Iters per second: 135.57
```
```
Static runtime ms per iter: 8.22086. Iters per second: 121.642
Time per node type:
        4.13527 ms.    50.9157%. fb::sigrid_transforms_torch_bind (1 nodes, out variant)
       0.868506 ms.    10.6935%. aten::argmin (1 nodes, out variant)
...
```

After:
```
I0826 17:17:54.165174 1064079 PyTorchPredictorBenchLib.cpp:313] PyTorch run finished. Milliseconds per iter: 6.66724. Iters per second: 149.987
```
```
Static runtime ms per iter: 7.68172. Iters per second: 130.179
Time per node type:
         4.1452 ms.    54.0612%. fb::sigrid_transforms_torch_bind (1 nodes, out variant)
       0.656778 ms.    8.56562%. fb::quantized_linear (8 nodes)
       0.488229 ms.    6.36741%. static_runtime::to_copy (827 nodes, out variant)
       0.372678 ms.    4.86042%. aten::argmin (1 nodes, out variant)
...Time per node type:
        3.39387 ms.    53.5467%. fb::sigrid_transforms_torch_bind (1 nodes, out variant)
       0.636216 ms.    10.0379%. fb::quantized_linear (8 nodes, out variant)
       0.410535 ms.    6.47721%. fb::clip_ranges_to_gather_to_offsets (304 nodes, out variant)
       0.212721 ms.     3.3562%. fb::clip_ranges_gather_sigrid_hash_precompute_v3 (157 nodes, out variant)
       0.173736 ms.    2.74111%. aten::matmul (1 nodes, out variant)
       0.150514 ms.    2.37474%. aten::argmin (1 nodes, out variant)
```
P447422384

Test Plan:
Test with local replayer sending traffic to `ansha_perf_test_0819.test`, and compare outputs to jit interpreter.

Start compute tier:
```
RUN_UUID=ansha_perf_test_0819.test.storage JOB_EXPIRE_TIME=864000 MODEL_ID=290331537_4 PREDICTOR_TAG= PREDICTOR_VERSION=405 PREDICTOR_TYPE=CPU ADDITIONAL_FLAGS="--enable_disagg_file_split=true --enable_adx=false --load_remote_file_locally=true --pytorch_predictor_static_runtime_whitelist_by_id=290331537" GFLAGS_CONFIG_PATH=sigrid/predictor/gflags/predictor_gflags_ads_perf_cpu_pyper SMC_TIER_NAME=sigrid.predictor.perf.ansha_per_test_0819.test.storage CLUSTER=tsp_rva ENTITLEMENT_NAME=ads_ranking_infra_test_t6 PREDICTOR_LOCAL_DIRECTORY= ICET_CONFIG_PATH= NNPI_COMPILATION_CONFIG_FILE= NUM_TASKS=1 NNPI_NUM_WORKERS=0 tw job start /data/users/ansha/fbsource/fbcode/tupperware/config/admarket/sigrid/predictor/predictor_perf_canary.tw
```

Start nnpi tier:
```
RUN_UUID=ansha_perf_test_0819.test JOB_EXPIRE_TIME=247200 MODEL_ID=290331537_4 PREDICTOR_TAG= PREDICTOR_VERSION=343 PREDICTOR_TYPE=NNPI_TWSHARED ADDITIONAL_FLAGS="--torch_glow_min_fusion_group_size=30 --pytorch_storage_tier_replayer_sr_connection_options=overall_timeout:1000000,processing_timeout:1000000 --predictor_storage_smc_tier=sigrid.predictor.perf.ansha_perf_test_0819.test.storage --pytorch_predictor_static_runtime_whitelist_by_id=290331537" GFLAGS_CONFIG_PATH=sigrid/predictor/gflags/predictor_gflags_ads_perf_glow_nnpi_pyper_v1 SMC_TIER_NAME=sigrid.predictor.perf.ansha_perf_test_0819.test CLUSTER=tsp_rva ENTITLEMENT_NAME=ads_ranking_infra_test_t17 PREDICTOR_LOCAL_DIRECTORY= ICET_CONFIG_PATH= NNPI_COMPILATION_CONFIG_FILE= NUM_TASKS=1 NNPI_NUM_WORKERS=0 tw job start /data/users/ansha/fbsource/fbcode/tupperware/config/admarket/sigrid/predictor/predictor_perf_canary.tw
```

```buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- StaticRuntime.IndividualOps_Argmin --print-passing-details```

Compared outputs to jit interpreter to check for no differences greater than 1e-3 (with nnc on) https://www.internalfb.com/intern/diff/view-version/136824794/

Reviewed By: hlu1

Differential Revision: D30445635

fbshipit-source-id: 048de8867ac72f764132295d1ebfa843cde2fa27

torch/csrc/jit/runtime/static/ops.cpp

index 4d34ed9..484c4b0 100644 (file)
@@ -9,6 +9,7 @@
 #include <ATen/native/Fill.h>
 #include <ATen/native/IndexingUtils.h>
 #include <ATen/native/Resize.h>
+#include <ATen/native/SharedReduceOps.h>
 #include <ATen/native/TensorAdvancedIndexing.h>
 #include <ATen/native/layer_norm.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
@@ -178,6 +179,94 @@ Tensor& linear_out(
   return output;
 }
 
+Tensor& c2_argmin_out(
+    Tensor& output,
+    const Tensor& input,
+    const int64_t dim,
+    const bool keepdim) {
+  const auto ndim = input.dim();
+  int64_t dim_ = maybe_wrap_dim(dim, ndim);
+  TORCH_CHECK(dim_ >= 0 && dim_ < ndim);
+
+  const auto in_dims = input.sizes();
+
+  c10::SmallVector<int64_t, 5> out_dims;
+  out_dims.reserve(ndim);
+  int prev_size = 1;
+  int next_size = 1;
+  for (int i = 0; i < dim_; ++i) {
+    out_dims.push_back(in_dims[i]);
+    prev_size *= in_dims[i];
+  }
+  if (keepdim) {
+    out_dims.push_back(1);
+  }
+  for (auto i = dim_ + 1; i < ndim; ++i) {
+    out_dims.push_back(in_dims[i]);
+    next_size *= in_dims[i];
+  }
+  at::native::resize_(output, out_dims, c10::nullopt);
+
+  const auto n = in_dims[dim_];
+
+  if (next_size == 1) {
+    AT_DISPATCH_ALL_TYPES_AND2(
+        kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
+          const auto in_ptr = input.data_ptr<scalar_t>();
+          const auto out_ptr = output.data_ptr<int64_t>();
+          // input is a [prev_size, n] tensor.
+          // output is a [prev_size,] tensor.
+          // Thus, access is contiguous/coalesced.
+          for (int i = 0; i < prev_size; ++i) {
+            auto v = std::min_element(
+                in_ptr + i * n,
+                in_ptr + (i + 1) * n,
+                [](scalar_t a, scalar_t b) {
+                  // if a is nan, then a is *less* than b with LessOrNan
+                  // semantics
+                  if (at::_isnan(a)) {
+                    return true;
+                  }
+                  // if a is not nan and b is nan, then a is not less than b
+                  // with LessOrNan semantics otherwise, act normally. If `b` is
+                  // NaN then a < b will always return false, so this is
+                  // equivalent to the first snippet.
+                  return a < b;
+                });
+            out_ptr[i] = std::distance(in_ptr + i * n, v);
+          }
+        });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND2(
+        kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
+          const auto less_or_nan = native::detail::LessOrNan<scalar_t>{};
+
+          const auto in_ptr = input.data_ptr<scalar_t>();
+          const auto out_ptr = output.data_ptr<int64_t>();
+
+          std::memset(out_ptr, 0, prev_size * next_size * sizeof(int64_t));
+
+          for (int i = 0; i < prev_size; ++i) {
+            const scalar_t* cur_in_ptr = in_ptr + i * n * next_size + next_size;
+            for (int k = 1; k < n; ++k) {
+              for (int j = 0; j < next_size; ++j) {
+                int64_t* cur_out_ptr = out_ptr + i * next_size + j;
+                if (less_or_nan(
+                        *cur_in_ptr,
+                        in_ptr
+                            [i * n * next_size + *cur_out_ptr * next_size + j],
+                        *cur_out_ptr,
+                        k)) {
+                  *cur_out_ptr = k;
+                }
+                ++cur_in_ptr;
+              }
+            }
+          }
+        });
+  }
+  return output;
+}
 } // namespace native
 } // namespace at
 
@@ -1209,6 +1298,10 @@ REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator {
     } else {
       auto& out_t = p_node->Output(0).toTensor();
       fastResizeToZero(out_t);
+      if (in0_t.is_contiguous() && dim.has_value()) {
+        at::native::c2_argmin_out(out_t, in0_t, dim.value(), keepdim);
+        return;
+      }
       at::cpu::argmin_out(out_t, in0_t, dim, keepdim);
     }
   };
@@ -1533,6 +1626,5 @@ REGISTER_OPERATOR_FUNCTOR(
         }
       };
     });
-
 } // namespace jit
 } // namespace torch