const ::internal::tflite::operand::Index k_index{node.param().k_index};
// Currently, we only support the vector input.
- assert(_ctx.at(inputData_index).shape().rank() == 1);
+ assert(_ctx.at(inputData_index).shape().rank() == 1 ||
+ _ctx.at(inputData_index).shape().rank() == 2);
- const auto outputValues_shape = _ctx.at(outputValues_index).shape().asVector();
- const auto outputIndices_shape = _ctx.at(outputIndices_index).shape().asVector();
- const auto inputData_shape = _ctx.at(inputData_index).shape().asVector();
+ const auto outputValues_shape = _ctx.at(outputValues_index).shape().asTensor();
+ const auto outputIndices_shape = _ctx.at(outputIndices_index).shape().asTensor();
+ const auto inputData_shape = _ctx.at(inputData_index).shape().asTensor();
const int32_t k = _ctx.at(k_index).asScalar<int32_t>();
// Set shape constraints