#include <ATen/native/Fill.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/Resize.h>
+#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
return output;
}
+Tensor& c2_argmin_out(
+ Tensor& output,
+ const Tensor& input,
+ const int64_t dim,
+ const bool keepdim) {
+ const auto ndim = input.dim();
+ int64_t dim_ = maybe_wrap_dim(dim, ndim);
+ TORCH_CHECK(dim_ >= 0 && dim_ < ndim);
+
+ const auto in_dims = input.sizes();
+
+ c10::SmallVector<int64_t, 5> out_dims;
+ out_dims.reserve(ndim);
+ int prev_size = 1;
+ int next_size = 1;
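+  // prev_size collects the product of the dimensions before `dim_` and
+  // next_size the product of those after it, so the input can be treated
+  // below as a [prev_size, in_dims[dim_], next_size] view for the reduction.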
+ for (int i = 0; i < dim_; ++i) {
+ out_dims.push_back(in_dims[i]);
+ prev_size *= in_dims[i];
+ }
+ if (keepdim) {
+ out_dims.push_back(1);
+ }
+ for (auto i = dim_ + 1; i < ndim; ++i) {
+ out_dims.push_back(in_dims[i]);
+ next_size *= in_dims[i];
+ }
+ at::native::resize_(output, out_dims, c10::nullopt);
+
+ const auto n = in_dims[dim_];
+
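+  // Two cases: if the reduced dimension is the innermost one (next_size == 1),
+  // each output element is the argmin of a contiguous run of n inputs;
+  // otherwise a strided pass keeps a running argmin per output slot.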
+ if (next_size == 1) {
+ AT_DISPATCH_ALL_TYPES_AND2(
+ kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
+ const auto in_ptr = input.data_ptr<scalar_t>();
+ const auto out_ptr = output.data_ptr<int64_t>();
+ // input is a [prev_size, n] tensor.
+ // output is a [prev_size,] tensor.
+ // Thus, access is contiguous/coalesced.
+ for (int i = 0; i < prev_size; ++i) {
+ auto v = std::min_element(
+ in_ptr + i * n,
+ in_ptr + (i + 1) * n,
+ [](scalar_t a, scalar_t b) {
+ // if a is nan, then a is *less* than b with LessOrNan
+ // semantics
+ if (at::_isnan(a)) {
+ return true;
+ }
+                  // `a` is not NaN here. If `b` is NaN, then under LessOrNan
+                  // semantics `a` is not less than `b`, and `a < b` already
+                  // evaluates to false in that case; otherwise the plain
+                  // comparison applies, so `a < b` is correct either way.
+ return a < b;
+ });
+ out_ptr[i] = std::distance(in_ptr + i * n, v);
+ }
+ });
+ } else {
+ AT_DISPATCH_ALL_TYPES_AND2(
+ kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
+ const auto less_or_nan = native::detail::LessOrNan<scalar_t>{};
+
+ const auto in_ptr = input.data_ptr<scalar_t>();
+ const auto out_ptr = output.data_ptr<int64_t>();
+
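+          // Initialize every output index to 0: element k == 0 of each slice
+          // is the initial argmin candidate.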
+ std::memset(out_ptr, 0, prev_size * next_size * sizeof(int64_t));
+
+ for (int i = 0; i < prev_size; ++i) {
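+            // Start at k == 1 (k == 0 is already recorded by the memset
+            // above); cur_in_ptr then walks the input contiguously across
+            // k and j.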
+ const scalar_t* cur_in_ptr = in_ptr + i * n * next_size + next_size;
+ for (int k = 1; k < n; ++k) {
+ for (int j = 0; j < next_size; ++j) {
+ int64_t* cur_out_ptr = out_ptr + i * next_size + j;
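+                // *cur_out_ptr holds the index of the current best element
+                // for output slot (i, j); replace it if the value at k is
+                // smaller under LessOrNan.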
+ if (less_or_nan(
+ *cur_in_ptr,
+ in_ptr
+ [i * n * next_size + *cur_out_ptr * next_size + j],
+ *cur_out_ptr,
+ k)) {
+ *cur_out_ptr = k;
+ }
+ ++cur_in_ptr;
+ }
+ }
+ }
+ });
+ }
+ return output;
+}
} // namespace native
} // namespace at
} else {
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
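+      // Take the fast path only when the input is contiguous and an explicit
+      // dim is given; c2_argmin_out assumes both. Otherwise fall back to the
+      // generic argmin_out kernel below.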
+ if (in0_t.is_contiguous() && dim.has_value()) {
+ at::native::c2_argmin_out(out_t, in0_t, dim.value(), keepdim);
+ return;
+ }
at::cpu::argmin_out(out_t, in0_t, dim, keepdim);
}
};
}
};
});
-
} // namespace jit
} // namespace torch