From ab5311c40944ef4e4de826f7c0927a05535a1668 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=98=A4=ED=98=95=EC=84=9D/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Tue, 2 Oct 2018 15:53:20 +0900 Subject: [PATCH] Fix custom op's output resize (#2873) Fix custom op TensorFlowMax & TensorFlowSum implementation Output resize policy: check dimension info Signed-off-by: Hyeongseok Oh --- libs/support/tflite/src/kernels/TensorFlowMax.cpp | 45 ++++++++++++++++------- libs/support/tflite/src/kernels/TensorFlowSum.cpp | 45 ++++++++++++++++------- 2 files changed, 62 insertions(+), 28 deletions(-) diff --git a/libs/support/tflite/src/kernels/TensorFlowMax.cpp b/libs/support/tflite/src/kernels/TensorFlowMax.cpp index abc6fda..4280b88 100644 --- a/libs/support/tflite/src/kernels/TensorFlowMax.cpp +++ b/libs/support/tflite/src/kernels/TensorFlowMax.cpp @@ -70,7 +70,7 @@ TfLiteStatus ResizeTempAxis(TfLiteContext *context, TensorFlowMaxOp *op_context, TfLiteStatus ResizeOutputTensor(TfLiteContext *context, TensorFlowMaxOp *op_context) { size_t num_axis = tflite::NumElements(op_context->axis); - const TfLiteIntArray *input_dims = op_context->input->dims; + TfLiteIntArray *input_dims = op_context->input->dims; int input_num_dims = tflite::NumDimensions(op_context->input); const int *axis = op_context->axis->data.i32; @@ -100,26 +100,43 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext *context, TensorFlowMaxOp *op_cont } } // Determines output dimensions. - TfLiteIntArray *output_dims = TfLiteIntArrayCreate(input_num_dims - num_reduce_axis); - int num_skip_axis = 0; - for (int idx = 0; idx < input_num_dims; ++idx) + int output_num_dims = tflite::NumDimensions(op_context->output); + TF_LITE_ENSURE(context, (input_num_dims == output_num_dims) || + (input_num_dims - num_reduce_axis == output_num_dims)); + + if (input_num_dims == output_num_dims) { - bool is_axis = false; + TfLiteIntArray *output_dims = TfLiteIntArrayCopy(input_dims); for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { - if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) - { - ++num_skip_axis; - is_axis = true; - break; - } + int current = axis[axis_idx]; + output_dims->data[current] = 1; } - if (!is_axis) + return context->ResizeTensor(context, op_context->output, output_dims); + } + else + { + TfLiteIntArray *output_dims = TfLiteIntArrayCreate(output_num_dims); + int num_skip_axis = 0; + for (int idx = 0; idx < input_num_dims; ++idx) { - output_dims->data[idx - num_skip_axis] = input_dims->data[idx]; + bool is_axis = false; + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) + { + if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) + { + ++num_skip_axis; + is_axis = true; + break; + } + } + if (!is_axis) + { + output_dims->data[idx - num_skip_axis] = input_dims->data[idx]; + } } + return context->ResizeTensor(context, op_context->output, output_dims); } - return context->ResizeTensor(context, op_context->output, output_dims); } } diff --git a/libs/support/tflite/src/kernels/TensorFlowSum.cpp b/libs/support/tflite/src/kernels/TensorFlowSum.cpp index 632981e..c7a3803 100644 --- a/libs/support/tflite/src/kernels/TensorFlowSum.cpp +++ b/libs/support/tflite/src/kernels/TensorFlowSum.cpp @@ -70,7 +70,7 @@ TfLiteStatus ResizeTempAxis(TfLiteContext *context, TensorFlowSumOp *op_context, TfLiteStatus ResizeOutputTensor(TfLiteContext *context, TensorFlowSumOp *op_context) { size_t num_axis = tflite::NumElements(op_context->axis); - const TfLiteIntArray *input_dims = op_context->input->dims; + TfLiteIntArray *input_dims = op_context->input->dims; int input_num_dims = tflite::NumDimensions(op_context->input); const int *axis = op_context->axis->data.i32; @@ -100,26 +100,43 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext *context, TensorFlowSumOp *op_cont } } // Determines output dimensions. - TfLiteIntArray *output_dims = TfLiteIntArrayCreate(input_num_dims - num_reduce_axis); - int num_skip_axis = 0; - for (int idx = 0; idx < input_num_dims; ++idx) + int output_num_dims = tflite::NumDimensions(op_context->output); + TF_LITE_ENSURE(context, (input_num_dims == output_num_dims) || + (input_num_dims - num_reduce_axis == output_num_dims)); + + if (input_num_dims == output_num_dims) { - bool is_axis = false; + TfLiteIntArray *output_dims = TfLiteIntArrayCopy(input_dims); for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { - if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) - { - ++num_skip_axis; - is_axis = true; - break; - } + int current = axis[axis_idx]; + output_dims->data[current] = 1; } - if (!is_axis) + return context->ResizeTensor(context, op_context->output, output_dims); + } + else + { + TfLiteIntArray *output_dims = TfLiteIntArrayCreate(output_num_dims); + int num_skip_axis = 0; + for (int idx = 0; idx < input_num_dims; ++idx) { - output_dims->data[idx - num_skip_axis] = input_dims->data[idx]; + bool is_axis = false; + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) + { + if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) + { + ++num_skip_axis; + is_axis = true; + break; + } + } + if (!is_axis) + { + output_dims->data[idx - num_skip_axis] = input_dims->data[idx]; + } } + return context->ResizeTensor(context, op_context->output, output_dims); } - return context->ResizeTensor(context, op_context->output, output_dims); } } -- 2.7.4