From be044a70035fef3204ec00c34f308c89caacd621 Mon Sep 17 00:00:00 2001 From: Chenhu Wang Date: Tue, 17 Nov 2020 15:42:34 +0800 Subject: [PATCH] [CPU] Interpolate operation improvements (#2366) * interpolate improvement * JITTED cubic mode * fix 'code is too big' when JIT * extend test to cover tail code path * transformation of interpolate1 to interpolate4 * add low precision transformation for interpolate4 --- .../src/common/interpolate.cpp | 71 +- .../src/common/transformer.cpp | 2 + .../nodes/mkldnn_interpolate_node.cpp | 1056 +++++++++++++++++--- .../mkldnn_plugin/nodes/mkldnn_interpolate_node.h | 26 +- .../convert_interpolate1_to_interpolate4.hpp | 31 + .../common_optimizations/common_optimizations.cpp | 2 + .../convert_interpolate1_to_interpolate4.cpp | 69 ++ .../interpolate_transformation.cpp | 256 ++++- .../convert_interpolate1_to_interpolate4_test.cpp | 112 +++ .../single_layer_tests/interpolate.cpp | 60 +- .../plugin/cpu/single_layer_tests/interpolate.cpp | 36 +- .../interpolate_function.hpp | 26 + .../interpolate_function.cpp | 66 ++ .../ngraph/runtime/reference/interpolate.hpp | 4 + 14 files changed, 1590 insertions(+), 227 deletions(-) create mode 100644 inference-engine/src/transformations/include/transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp create mode 100644 inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp create mode 100644 inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp diff --git a/inference-engine/src/low_precision_transformations/src/common/interpolate.cpp b/inference-engine/src/low_precision_transformations/src/common/interpolate.cpp index f0bef07..c28eab9 100644 --- a/inference-engine/src/low_precision_transformations/src/common/interpolate.cpp +++ b/inference-engine/src/low_precision_transformations/src/common/interpolate.cpp @@ -20,6 +20,16 @@ void InterpolateTransformation::registerMatcherIn(GraphRewrite& pass, Transforma pass, context, make_op_pattern({ make_op_label(), make_op_label() })); + addPattern( + pass, + context, + make_op_pattern({ make_op_label(), make_op_label(), + make_op_label(), make_op_label() })); + addPattern( + pass, + context, + make_op_pattern({ make_op_label(), make_op_label(), + make_op_label() })); } bool InterpolateTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const { @@ -33,9 +43,19 @@ bool InterpolateTransformation::transform(TransformationContext &context, ngraph } bool InterpolateTransformation::isPrecisionPreserved(std::shared_ptr layer) const noexcept { - std::shared_ptr interpolate = as_type_ptr(layer); - const auto attrs = interpolate->get_attrs(); - return attrs.mode == "nearest"; + std::shared_ptr interpolate1 = as_type_ptr(layer); + if (interpolate1) { + const auto attrs = interpolate1->get_attrs(); + return attrs.mode == "nearest"; + } + + std::shared_ptr interpolate4 = as_type_ptr(layer); + if (interpolate4) { + const auto attrs = interpolate4->get_attrs(); + return attrs.mode == op::v4::Interpolate::InterpolateMode::nearest; + } + + return false; } bool InterpolateTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { @@ -49,19 +69,46 @@ bool InterpolateTransformation::canBeTransformed(const TransformationContext& co if (dequantization.empty()) { return false; } - const auto interpolate = as_type_ptr(layer); - const auto interpAttrs = interpolate->get_attrs(); - if 
(interpAttrs.axes.count(0) || interpAttrs.axes.count(1)) { - return false; + const auto interpolate1 = as_type_ptr(layer); + if (interpolate1) { + const auto interpAttrs = interpolate1->get_attrs(); + if (interpAttrs.axes.count(0) || interpAttrs.axes.count(1)) { + return false; + } + if (interpAttrs.mode != "nearest") { + return false; + } + if (interpAttrs.pads_begin[0] != 0 || interpAttrs.pads_end[0] != 0 || interpAttrs.align_corners) { + return false; + } } - if (interpAttrs.mode != "nearest") { - return false; - } + const auto interpolate4 = as_type_ptr(layer); + if (interpolate4) { + const auto interpAttrs = interpolate4->get_attrs(); - if (interpAttrs.pads_begin[0] != 0 || interpAttrs.pads_end[0] != 0 || interpAttrs.align_corners) { - return false; + if (interpAttrs.mode != op::v4::Interpolate::InterpolateMode::nearest) { + return false; + } + + auto pads_begin = interpAttrs.pads_begin; + for (int i = 0; i < pads_begin.size(); ++i) { + if (pads_begin[i] != 0) { + return false; + } + } + + auto pads_end = interpAttrs.pads_end; + for (int i = 0; i < pads_end.size(); ++i) { + if (pads_end[i] != 0) { + return false; + } + } + + if (interpAttrs.coordinate_transformation_mode == op::v4::Interpolate::CoordinateTransformMode::align_corners) { + return false; + } } return true; diff --git a/inference-engine/src/low_precision_transformations/src/common/transformer.cpp b/inference-engine/src/low_precision_transformations/src/common/transformer.cpp index e152cfe..66ae78c 100644 --- a/inference-engine/src/low_precision_transformations/src/common/transformer.cpp +++ b/inference-engine/src/low_precision_transformations/src/common/transformer.cpp @@ -242,6 +242,7 @@ LowPrecisionTransformations LowPrecisionTransformer::getAllTransformations(const add(params). add(params). add(params). + add(params). addCleanup(params). 
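For orientation, the hunks above only let the low precision transformation pass through opset4 Interpolate under a narrow set of attributes, and isPrecisionPreserved stays true only for "nearest" mode (nearest-neighbor merely gathers existing values, so quantized data and its dequantization scales survive unchanged). A condensed restatement of the new canBeTransformed checks as plain C++ over a simplified attribute struct; the struct and function names here are illustrative stand-ins, not code from the patch:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Simplified stand-in for op::v4::Interpolate::InterpolateAttrs (illustration only).
    struct InterpAttrs4 {
        enum class Mode { nearest, linear, linear_onnx, cubic };
        enum class CoordMode { half_pixel, pytorch_half_pixel, asymmetric,
                               tf_half_pixel_for_nn, align_corners };
        Mode mode;
        CoordMode coordinate_transformation_mode;
        std::vector<size_t> pads_begin;
        std::vector<size_t> pads_end;
    };

    // Mirrors the eligibility checks added to canBeTransformed above: nearest mode,
    // all-zero pads, and any coordinate transformation mode except align_corners.
    bool canHandleInterpolate4(const InterpAttrs4& a) {
        auto all_zero = [](const std::vector<size_t>& v) {
            return std::all_of(v.begin(), v.end(), [](size_t p) { return p == 0; });
        };
        return a.mode == InterpAttrs4::Mode::nearest &&
               all_zero(a.pads_begin) && all_zero(a.pads_end) &&
               a.coordinate_transformation_mode != InterpAttrs4::CoordMode::align_corners;
    }
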
@@ -341,6 +342,7 @@ TypeRelaxedReplacer::TypeRelaxedReplacer() { make_matcher_type_relaxed(this); make_matcher_type_relaxed(this); make_matcher_type_relaxed(this); + make_matcher_type_relaxed(this); } LowPrecisionTransformer::LowPrecisionTransformer(const LowPrecisionTransformations& transformations) diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp index 799f2da..a2be1f5 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp @@ -61,8 +61,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi this->preamble(); - mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); - mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); if (attr_.post_ops_.len_ != 0) mov(reg_oc_off, ptr[reg_params + GET_OFF(oc_off)]); if (isa == cpu::avx512_common) @@ -70,8 +68,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi switch (jcp_.mode) { case InterpolateMode::nearest: { - mov(reg_src, ptr[reg_params + GET_OFF(src)]); + mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); + mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]); mov(reg_index, ptr[reg_params + GET_OFF(index)]); + mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); switch (jcp_.layout) { case InterpolateLayoutType::planar: { @@ -94,28 +94,11 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi case InterpolateMode::linear_onnx: { switch (jcp_.layout) { case InterpolateLayoutType::planar: { - mov(reg_src, ptr[reg_params + GET_OFF(src)]); - mov(reg_index, ptr[reg_params + GET_OFF(index)]); - mov(reg_src_aux, ptr[reg_params + GET_OFF(weight)]); - linear_onnx_planar(); break; } case InterpolateLayoutType::block: case InterpolateLayoutType::by_channel: { - mov(reg_src, ptr[reg_params + GET_OFF(weight)]); - mov(reg_src_aux, ptr[reg_params + GET_OFF(weightR)]); - mov(reg_src_aux1, ptr[reg_params + GET_OFF(weightT)]); - mov(reg_src_aux2, ptr[reg_params + GET_OFF(weightB)]); - uni_vbroadcastss(vmm_weightL, ptr[reg_src]); - uni_vbroadcastss(vmm_weightR, ptr[reg_src_aux]); - uni_vbroadcastss(vmm_weightT, ptr[reg_src_aux1]); - uni_vbroadcastss(vmm_weightB, ptr[reg_src_aux2]); - mov(reg_src, ptr[reg_params + GET_OFF(src)]); - mov(reg_src_aux, ptr[reg_params + GET_OFF(srcTR)]); - mov(reg_src_aux1, ptr[reg_params + GET_OFF(srcBL)]); - mov(reg_src_aux2, ptr[reg_params + GET_OFF(srcBR)]); - linear_onnx_c_gathered(); break; } @@ -124,8 +107,23 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } break; } - case InterpolateMode::linear: case InterpolateMode::cubic: { + switch (jcp_.layout) { + case InterpolateLayoutType::planar: { + cubic_planar(); + break; + } + case InterpolateLayoutType::block: + case InterpolateLayoutType::by_channel: { + cubic_c_gathered(); + break; + } + default: + assert(!"unsupported memory layout for interpolate layer with cubic mode."); + } + break; + } + case InterpolateMode::linear: { assert(!"unsupported mode for interpolate layer with JITTED implimentation."); break; } @@ -138,6 +136,9 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi for (auto& inj : eltwise_injectors) inj->prepare_table(); + if ((jcp_.mode == InterpolateMode::cubic) && (jcp_.layout == InterpolateLayoutType::planar)) { + prepare_cubic_planar_table(); + } ker_ = (decltype(ker_)) this->getCode(); } @@ 
-164,7 +165,12 @@ private: Xbyak::Reg64 reg_oc_off = rax; Xbyak::Reg64 reg_d_weights = rbx; Xbyak::Reg64 reg_d_bias = rcx; - Xbyak::Reg32 reg_index_oc = edx; + Xbyak::Reg32 reg_index_offset = edx; + + // for cubic planar + Xbyak::Reg64 reg_tbl_y = rsi; + Xbyak::Reg64 reg_tbl_x = rbp; + Xbyak::Reg64 reg_table = rdx; // do not need reg_index_offset in this mode, so use rdx Vmm vmm_val = Vmm(0); Xmm xmm_val = Xmm(0); @@ -174,6 +180,7 @@ private: Vmm vmm_d_weights = Vmm(4); Vmm vmm_d_bias = Vmm(5); + // for linear Vmm vmm_weightT = Vmm(15); Xmm xmm_weightT = Xmm(15); Vmm vmm_weightB = Vmm(14); @@ -191,6 +198,32 @@ private: Vmm vmm_valBR = Vmm(8); Xmm xmm_valBR = Xmm(8); + // for cubic + Vmm vmm_src = Vmm(6); + Xmm xmm_src = Xmm(6); + Vmm vmm_dstX = Vmm(7); + + Vmm vmm_weightX0 = vmm_weightT; + Vmm vmm_weightX1 = vmm_weightB; + Vmm vmm_weightX2 = vmm_weightL; + Vmm vmm_weightX3 = vmm_weightR; + Vmm vmm_weightY0 = vmm_valTL; + Vmm vmm_weightY1 = Vmm(10); // vmm_valTR is vmm_val, need reserved + Vmm vmm_weightY2 = vmm_valBL; + Vmm vmm_weightY3 = vmm_valBR; + // cubic planar + Vmm vmm_one = vmm_index; + Vmm vmm_weightY = vmm_weightY0; + Vmm vmm_index_y_itr = vmm_weightY1; + Vmm vmm_index_x_itr = vmm_weightY2; + Vmm vmm_tbl_y = vmm_weightY3; + // temporally used. when post ops, value in vmm_d_weights and vmm_d_bias is re-loaded(init) each time. + Vmm vmm_index_in_y = vmm_d_weights; + Vmm vmm_index_in_x = vmm_d_bias; + + Xbyak::Label l_table_constant; + Opmask k_mask = Xbyak::Opmask(1); + std::vector>> eltwise_injectors; std::vector>> depthwise_injectors; std::vector>> quantization_injectors; @@ -221,8 +254,8 @@ private: Xbyak::Reg64 reg_src_h = rsi; mov(reg_src_h, reg_src); // index_h * IW * dataSize done when built to avoid redundent compute - mov(reg_index_oc, dword[reg_index_h]); - add(reg_src_h, reg_index_oc); // reg_src_h now point to begin of row + mov(reg_index_offset, dword[reg_index_h]); + add(reg_src_h, reg_index_offset); // reg_src_h now point to begin of row // reset index_w, index_w * dataSize done when built to avoid redundent compute mov(reg_index, reg_index_w); @@ -260,8 +293,8 @@ private: jl(nn_tail_loop_end_label, T_NEAR); mov(reg_src_aux, reg_src_h); - mov(reg_index_oc, dword[reg_index]); - add(reg_src_aux, reg_index_oc); + mov(reg_index_offset, dword[reg_index]); + add(reg_src_aux, reg_index_offset); load_scalar(xmm_val, ptr[reg_src_aux], jcp_.src_dt); if (attr_.post_ops_.len_ != 0) @@ -298,8 +331,8 @@ private: jle(nn_loop_end_label, T_NEAR); mov(reg_src_aux, reg_src); - mov(reg_index_oc, dword[reg_index]); - add(reg_src_aux, reg_index_oc); + mov(reg_index_offset, dword[reg_index]); + add(reg_src_aux, reg_index_offset); load_vector(vmm_val, ptr[reg_src_aux], jcp_.src_dt); if (attr_.post_ops_.len_ != 0) @@ -353,8 +386,8 @@ private: // dst and index address is continous, advanced each interator. 
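The register rename above (reg_index_oc -> reg_index_offset) reflects how the nearest-neighbor kernels consume their index table: each entry is a ready-made byte offset (the source index premultiplied by the data size, and by IW or C where applicable, at table-build time), so the kernel only adds it to the source base pointer. In scalar terms each gather iteration is roughly the following sketch (illustrative, not code from the patch):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // byte_offsets[i] = src_index[i] * dataSize, with the layout stride already
    // folded in when the table was built, as the comments in the hunk note.
    static void nn_gather_ref(const uint8_t* src, uint8_t* dst, const int* byte_offsets,
                              size_t work_amount, size_t data_size) {
        for (size_t i = 0; i < work_amount; ++i) {
            // JIT equivalent: mov(reg_src_aux, reg_src); add(reg_src_aux, reg_index_offset);
            std::memcpy(dst + i * data_size, src + byte_offsets[i], data_size);
        }
    }
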
mov(reg_src_aux, reg_src); // index*C*dataSize done when built to avoid redundent compute - mov(reg_index_oc, dword[reg_index]); - add(reg_src_aux, reg_index_oc); + mov(reg_index_offset, dword[reg_index]); + add(reg_src_aux, reg_index_offset); mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); if (attr_.post_ops_.len_ != 0) @@ -409,11 +442,30 @@ private: } void linear_onnx_c_gathered() { + mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); + + mov(reg_src, ptr[reg_params + GET_OFF(weight_ptr[0])]); + mov(reg_src_aux, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]); + mov(reg_src_aux1, ptr[reg_params + GET_OFF(weight_ptr[0]) + 2 * sizeof(size_t)]); + mov(reg_src_aux2, ptr[reg_params + GET_OFF(weight_ptr[0]) + 3 * sizeof(size_t)]); + uni_vbroadcastss(vmm_weightL, ptr[reg_src]); + uni_vbroadcastss(vmm_weightR, ptr[reg_src_aux]); + uni_vbroadcastss(vmm_weightT, ptr[reg_src_aux1]); + uni_vbroadcastss(vmm_weightB, ptr[reg_src_aux2]); + + mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]); + mov(reg_src_aux, ptr[reg_params + GET_OFF(src_ptr[0]) + sizeof(size_t)]); + mov(reg_src_aux1, ptr[reg_params + GET_OFF(src_ptr[0]) + 2 * sizeof(size_t)]); + mov(reg_src_aux2, ptr[reg_params + GET_OFF(src_ptr[0]) + 3 * sizeof(size_t)]); + mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); + int step = vlen / sizeof(float); int blk = (isa == cpu::sse42) ? (2 * step) : step; Xbyak::Label main_loop_label; Xbyak::Label main_loop_end_label; + Xbyak::Label blk_tail_loop_label; + Xbyak::Label blk_tail_loop_end_label; Xbyak::Label tail_loop_label; Xbyak::Label tail_loop_end_label; L(main_loop_label); @@ -431,14 +483,7 @@ private: load_vector(vmm_valBL, ptr[reg_src_aux1], jcp_.src_dt); load_vector(vmm_valBR, ptr[reg_src_aux2], jcp_.src_dt); - // weightT * (srcTL * weightL + srcTR * weightR) + - // weightB * (srcBL * weightL + srcBR * weightR); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR); - uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR); - uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL); - uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT); - uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB); + linear_onnx_worker(); if (attr_.post_ops_.len_ != 0) { apply_post_ops(jcp_.dst_dt, false); // vmm_val is vmm_valTR @@ -453,12 +498,7 @@ private: load_vector(vmm_valBL, ptr[reg_src_aux1 + sse42_offset * jcp_.src_data_size], jcp_.src_dt); load_vector(vmm_valBR, ptr[reg_src_aux2 + sse42_offset * jcp_.src_data_size], jcp_.src_dt); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR); - uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR); - uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL); - uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT); - uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB); + linear_onnx_worker(); if (attr_.post_ops_.len_ != 0) { apply_post_ops(jcp_.dst_dt, false); @@ -477,7 +517,7 @@ private: sub(reg_work_amount, step); // work_amount is c } else { int dst_stride = blk * jcp_.OW * jcp_.OH * jcp_.dst_data_size; - int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size;; + int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size; add(reg_dst, dst_stride); add(reg_src, src_stride); add(reg_src_aux, src_stride); @@ -490,6 +530,37 @@ private: } L(main_loop_end_label); + step = 4; + L(blk_tail_loop_label); + { + cmp(reg_work_amount, step); + jl(blk_tail_loop_end_label, T_NEAR); + + // use xmm for 4s in tails + load_xmm(xmm_valTL, ptr[reg_src], jcp_.src_dt); + load_xmm(xmm_valTR, 
ptr[reg_src_aux], jcp_.src_dt); + load_xmm(xmm_valBL, ptr[reg_src_aux1], jcp_.src_dt); + load_xmm(xmm_valBR, ptr[reg_src_aux2], jcp_.src_dt); + + linear_onnx_worker(); + + if (attr_.post_ops_.len_ != 0) { + apply_post_ops(jcp_.dst_dt, false); // vmm_val is vmm_valTR + add(reg_oc_off, step * sizeof(float)); + } + store_xmm(ptr[reg_dst], xmm_valTR, jcp_.dst_dt); + + add(reg_dst, step * jcp_.dst_data_size); + add(reg_src, step * jcp_.src_data_size); + add(reg_src_aux, step * jcp_.src_data_size); + add(reg_src_aux1, step * jcp_.src_data_size); + add(reg_src_aux2, step * jcp_.src_data_size); + sub(reg_work_amount, step); + + jmp(blk_tail_loop_label, T_NEAR); + } + L(blk_tail_loop_end_label); + step = 1; L(tail_loop_label); { @@ -502,12 +573,7 @@ private: load_scalar(xmm_valBL, ptr[reg_src_aux1], jcp_.src_dt); load_scalar(xmm_valBR, ptr[reg_src_aux2], jcp_.src_dt); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR); - uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR); - uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL); - uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT); - uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB); + linear_onnx_worker(); if (attr_.post_ops_.len_ != 0) { apply_post_ops(jcp_.dst_dt, false); // vmm_val is vmm_valTR @@ -528,6 +594,12 @@ private: } void linear_onnx_planar() { + mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); + mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]); + mov(reg_index, ptr[reg_params + GET_OFF(index)]); + mov(reg_src_aux, ptr[reg_params + GET_OFF(weight_ptr[0])]); + mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); + int step = vlen / sizeof(float); int index_stride = jcp_.OW * jcp_.OH * jcp_.indices_size; int weight_stride = jcp_.OW * jcp_.OH * sizeof(float); @@ -563,14 +635,7 @@ private: load_vector(vmm_weightT, ptr[reg_src_aux + 2 * weight_stride], memory::f32); load_vector(vmm_weightB, ptr[reg_src_aux + 3 * weight_stride], memory::f32); - // weightT * (srcTL * weightL + srcTR * weightR) + - // weightB * (srcBL * weightL + srcBR * weightR); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR); - uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR); - uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL); - uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT); - uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB); + linear_onnx_worker(); if (attr_.post_ops_.len_ != 0) { apply_post_ops(jcp_.dst_dt, true); // vmm_val is vmm_valTR, broadcase is true @@ -594,23 +659,23 @@ private: // still use xmm on avx2/avx512 mov(reg_src_aux1, reg_src); - mov(reg_index_oc, dword[reg_index]); - add(reg_src_aux1, reg_index_oc); + mov(reg_index_offset, dword[reg_index]); + add(reg_src_aux1, reg_index_offset); load_scalar(xmm_valTL, ptr[reg_src_aux1], jcp_.src_dt); mov(reg_src_aux1, reg_src); - mov(reg_index_oc, dword[reg_index + index_stride]); - add(reg_src_aux1, reg_index_oc); + mov(reg_index_offset, dword[reg_index + index_stride]); + add(reg_src_aux1, reg_index_offset); load_scalar(xmm_valTR, ptr[reg_src_aux1], jcp_.src_dt); mov(reg_src_aux1, reg_src); - mov(reg_index_oc, dword[reg_index + 2 * index_stride]); - add(reg_src_aux1, reg_index_oc); + mov(reg_index_offset, dword[reg_index + 2 * index_stride]); + add(reg_src_aux1, reg_index_offset); load_scalar(xmm_valBL, ptr[reg_src_aux1], jcp_.src_dt); mov(reg_src_aux1, reg_src); - mov(reg_index_oc, dword[reg_index + 3 * index_stride]); - add(reg_src_aux1, reg_index_oc); + mov(reg_index_offset, dword[reg_index + 3 * index_stride]); 
+ add(reg_src_aux1, reg_index_offset); load_scalar(xmm_valBR, ptr[reg_src_aux1], jcp_.src_dt); load_scalar(xmm_weightL, ptr[reg_src_aux], memory::f32); @@ -618,12 +683,7 @@ private: load_scalar(xmm_weightT, ptr[reg_src_aux + 2 * weight_stride], memory::f32); load_scalar(xmm_weightB, ptr[reg_src_aux + 3 * weight_stride], memory::f32); - uni_vmulps(xmm_valTR, xmm_valTR, xmm_weightR); - uni_vmulps(xmm_valBR, xmm_valBR, xmm_weightR); - uni_vfmadd231ps(xmm_valTR, xmm_valTL, xmm_weightL); - uni_vfmadd231ps(xmm_valBR, xmm_valBL, xmm_weightL); - uni_vmulps(xmm_valTR, xmm_valTR, xmm_weightT); - uni_vfmadd231ps(xmm_valTR, xmm_valBR, xmm_weightB); + linear_onnx_worker(); if (attr_.post_ops_.len_ != 0) { apply_post_ops(jcp_.dst_dt, true); // process on vmm_val, vmm_val is vmm_valTR, and bc @@ -640,6 +700,473 @@ private: L(tail_loop_end_label); } + // weightT * (srcTL * weightL + srcTR * weightR) + + // weightB * (srcBL * weightL + srcBR * weightR) + inline void linear_onnx_worker() { + uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR); + uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR); + uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL); + uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL); + uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT); + uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB); + } + + void cubic_c_gathered() { + mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); + mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]); + mov(reg_index, ptr[reg_params + GET_OFF(index)]); + mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); + + // weight_ptr[0] point to weightX + mov(reg_src_aux1, ptr[reg_params + GET_OFF(weight_ptr[0])]); + uni_vbroadcastss(vmm_weightX0, ptr[reg_src_aux1]); + uni_vbroadcastss(vmm_weightX1, ptr[reg_src_aux1 + 1 * sizeof(float)]); + uni_vbroadcastss(vmm_weightX2, ptr[reg_src_aux1 + 2 * sizeof(float)]); + uni_vbroadcastss(vmm_weightX3, ptr[reg_src_aux1 + 3 * sizeof(float)]); + + // weight_ptr[1] point to weightY + mov(reg_src_aux1, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]); + uni_vbroadcastss(vmm_weightY0, ptr[reg_src_aux1]); + uni_vbroadcastss(vmm_weightY1, ptr[reg_src_aux1 + 1 * sizeof(float)]); + uni_vbroadcastss(vmm_weightY2, ptr[reg_src_aux1 + 2 * sizeof(float)]); + uni_vbroadcastss(vmm_weightY3, ptr[reg_src_aux1 + 3 * sizeof(float)]); + + int step = vlen / sizeof(float); + int blk = (isa == cpu::sse42) ? 
(2 * step) : step; + + Xbyak::Label main_loop_label; + Xbyak::Label main_loop_end_label; + Xbyak::Label tail_loop_label; + Xbyak::Label tail_loop_end_label; + L(main_loop_label); + { + if (jcp_.layout == InterpolateLayoutType::by_channel) { + cmp(reg_work_amount, step); + jl(main_loop_end_label, T_NEAR); + } else { + cmp(reg_work_amount, 1); + jl(tail_loop_end_label, T_NEAR); + } + + uni_vpxor(vmm_val, vmm_val, vmm_val); + + cubic_c_gathered_matrix(false); + + if (attr_.post_ops_.len_ != 0) { + apply_post_ops(jcp_.dst_dt, false); // vmm_val is default dst value to post_ops and store + add(reg_oc_off, step * sizeof(float)); + } + store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt); + + if ((isa == cpu::sse42) && (jcp_.layout == InterpolateLayoutType::block)) { + int sse42_offset = 4; // vmm is xmm here + add(reg_src, sse42_offset * jcp_.src_data_size); + add(reg_dst, sse42_offset * jcp_.dst_data_size); + + uni_vpxor(vmm_val, vmm_val, vmm_val); + + cubic_c_gathered_matrix(false); + + if (attr_.post_ops_.len_ != 0) { + apply_post_ops(jcp_.dst_dt, false); + add(reg_oc_off, step * sizeof(float)); // second step for one blk + } + store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt); + + sub(reg_src, sse42_offset * jcp_.src_data_size); + sub(reg_dst, sse42_offset * jcp_.dst_data_size); + } + if (jcp_.layout == InterpolateLayoutType::by_channel) { + int dst_stride = step * jcp_.dst_data_size; + int src_stride = step * jcp_.src_data_size; + add(reg_dst, dst_stride); + add(reg_src, src_stride); + sub(reg_work_amount, step); // work_amount is c + } else { + int dst_stride = blk * jcp_.OW * jcp_.OH * jcp_.dst_data_size; + int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size; + add(reg_dst, dst_stride); + add(reg_src, src_stride); + sub(reg_work_amount, 1); // work_amount = div_up(c, blk), no tails + } + + jmp(main_loop_label, T_NEAR); + } + L(main_loop_end_label); + + // only for by_channel layout for tails. 
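For reference, the value each cubic_c_gathered iteration accumulates (via cubic_c_gathered_matrix / _line / _pixel below) is a separable 4x4 bicubic sum over a precomputed neighborhood. A scalar sketch, assuming an f32 source for brevity (the kernel also handles s8/u8/bf16 through load_vector); this helper is illustrative and not part of the patch:

    #include <cstdint>
    #include <cstring>

    // byte_offsets[16]: precomputed offsets of the 4x4 source neighborhood
    // (built per output pixel in cubicCGathered further below);
    // wx[4]/wy[4]: cubic weights along X and Y for that output pixel.
    static float cubic_c_gathered_ref(const uint8_t* src, const int* byte_offsets,
                                      const float* wx, const float* wy) {
        float val = 0.f;                              // vmm_val
        for (int y = 0; y < 4; ++y) {                 // cubic_c_gathered_matrix
            float line = 0.f;                         // vmm_dstX in cubic_c_gathered_line
            for (int x = 0; x < 4; ++x) {             // cubic_c_gathered_pixel
                float s;
                std::memcpy(&s, src + byte_offsets[y * 4 + x], sizeof(float));
                line += s * wx[x];
            }
            val += line * wy[y];                      // fmadd with weightY in the kernel
        }
        return val;
    }
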
+ step = 1; + L(tail_loop_label); + { + cmp(reg_work_amount, 1); + jl(tail_loop_end_label, T_NEAR); + + // store final computed value + uni_vpxor(vmm_val, vmm_val, vmm_val); + + cubic_c_gathered_matrix(true); + + if (attr_.post_ops_.len_ != 0) { + apply_post_ops(jcp_.dst_dt, false); // vmm_val is default dst value + add(reg_oc_off, step * sizeof(float)); + } + store_scalar(ptr[reg_dst], xmm_val, jcp_.dst_dt); + + int dst_stride = step * jcp_.dst_data_size; + int src_stride = step * jcp_.src_data_size; + add(reg_dst, dst_stride); + add(reg_src, src_stride); + sub(reg_work_amount, step); // work_amount is c + + jmp(tail_loop_label, T_NEAR); + } + L(tail_loop_end_label); + } + + inline void cubic_c_gathered_matrix(bool is_scalar) { + // y0: (x0 * weightX0 + x1 * weightX1 + x2 * weightX2 + x3 * weightX3) * weightY0 + cubic_c_gathered_line(0, vmm_weightY0, is_scalar); + // y1 + cubic_c_gathered_line(4, vmm_weightY1, is_scalar); + // y2 + cubic_c_gathered_line(8, vmm_weightY2, is_scalar); + // y3 + cubic_c_gathered_line(12, vmm_weightY3, is_scalar); + } + + inline void cubic_c_gathered_line(int index_start, Vmm vmm_weight, bool is_scalar) { + uni_vpxor(vmm_dstX, vmm_dstX, vmm_dstX); + cubic_c_gathered_pixel(index_start, vmm_weightX0, is_scalar); + cubic_c_gathered_pixel(index_start + 1, vmm_weightX1, is_scalar); + cubic_c_gathered_pixel(index_start + 2, vmm_weightX2, is_scalar); + cubic_c_gathered_pixel(index_start + 3, vmm_weightX3, is_scalar); + uni_vfmadd231ps(vmm_val, vmm_dstX, vmm_weight); + } + + inline void cubic_c_gathered_pixel(int i, Vmm vmm_weight, bool is_scalar) { + mov(reg_src_aux, reg_src); + mov(reg_index_offset, dword[reg_index + i * jcp_.indices_size]); + add(reg_src_aux, reg_index_offset); + if (!is_scalar) { + load_vector(vmm_src, ptr[reg_src_aux], jcp_.src_dt); + } else { + load_scalar(xmm_src, ptr[reg_src_aux], jcp_.src_dt); + } + uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weight); + } + + void cubic_planar() { + mov(reg_table, l_table_constant); + // src_ptr[2] for oh sequence, src_ptr[3] for ow sequence + mov(reg_tbl_y, ptr[reg_params + GET_OFF(src_ptr[0]) + 2 * sizeof(size_t)]); + mov(reg_tbl_x, ptr[reg_params + GET_OFF(src_ptr[0]) + 3 * sizeof(size_t)]); + uni_vmovdqu(vmm_one, cubic_planar_table_val(0)); + uni_vpxor(vmm_zero, vmm_zero, vmm_zero); + + mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); + mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]); + // index_OW + mov(reg_index, ptr[reg_params + GET_OFF(index)]); + // index_OH from src_ptr[1] + Xbyak::Reg64 reg_index_y = reg_src_aux; + mov(reg_index_y, ptr[reg_params + GET_OFF(src_ptr[0]) + sizeof(size_t)]); + // weight_OW + Xbyak::Reg64 reg_weight_x = reg_src_aux1; + mov(reg_weight_x, ptr[reg_params + GET_OFF(weight_ptr[0])]); + // weight_OH + Xbyak::Reg64 reg_weight_y = reg_src_aux2; + mov(reg_weight_y, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]); + mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); + + int step = vlen / sizeof(float); + int grid_len = 4; + + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 + // 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 + // for 3th step(8): 16 17 18 19 20 21 22 23 + // y: 0 0 0 0 1 1 1 1 + // x: 16 17 18 19 0 1 2 3 + + Xbyak::Label main_loop_label; + Xbyak::Label main_loop_end_label; + Xbyak::Label tail_loop_label; + Xbyak::Label tail_loop_end_label; + L(main_loop_label); + { + cmp(reg_work_amount, step); + jl(main_loop_end_label, T_NEAR); + + // vmm_tbl_y: (0 0 0 0 1 1 1 1 * index_size) --> (0 0 0 0 4 4 4 4) + uni_vmovdqu(vmm_tbl_y, 
ptr[reg_tbl_y]); + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + // vmm_index_in_y: 0 0 0 0 2 2 2 2 + vpgatherdd(vmm_index_in_y, ptr[reg_index_y + vmm_tbl_y], vmm_mask); + + // use vmm_val temporally for value in reg_tbl_x: 16 17 18 19 0 1 2 3 + uni_vmovdqu(vmm_val, ptr[reg_tbl_x]); + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + // e.g. vmm_index_in_x: 32 34 36 38 0 2 4 6, now save src index. + vpgatherdd(vmm_index_in_x, ptr[reg_index + vmm_val], vmm_mask); + + // build weightX used in y0-y3 + // weight format: w0_0 w1_0 w2_0 w3_0 w0_1 w1_1 w2_1 w3_1 ... + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + vgatherdps(vmm_weightX0, ptr[reg_weight_x + vmm_val * grid_len], vmm_mask); // 4 in vmm_val for weight_size, another 4 for grid_len + + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + // shift weight_size then gather second weight + vgatherdps(vmm_weightX1, ptr[reg_weight_x + sizeof(float) + (vmm_val * grid_len)], vmm_mask); + + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + vgatherdps(vmm_weightX2, ptr[reg_weight_x + 2 * sizeof(float) + (vmm_val * grid_len)], vmm_mask); + + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + vgatherdps(vmm_weightX3, ptr[reg_weight_x + 3 * sizeof(float) + (vmm_val * grid_len)], vmm_mask); + // vmm_val is now relieved and used for dst_value + + uni_vpxor(vmm_val, vmm_val, vmm_val); + // y0 + vpsubd(vmm_index_y_itr, vmm_index_in_y, vmm_one); + // crop to [0, IH - 1] + vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1)); + vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); + + // weight y0 + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + vgatherdps(vmm_weightY, ptr[reg_weight_y + (vmm_tbl_y * grid_len)], vmm_mask); + cubic_planar_line(false); + + // y1 + // crop to [0, IH - 1] + vpminsd(vmm_index_y_itr, vmm_index_in_y, cubic_planar_table_val(1)); + vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); + // weight y1: shift weight_size + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + vgatherdps(vmm_weightY, ptr[reg_weight_y + sizeof(float) + (vmm_tbl_y * grid_len)], vmm_mask); + cubic_planar_line(false); + + // y2 + vpaddd(vmm_index_y_itr, vmm_index_in_y, vmm_one); + // crop to [0, IH - 1] + vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1)); + vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); + // weight y2 + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + vgatherdps(vmm_weightY, ptr[reg_weight_y + 2 * sizeof(float) + (vmm_tbl_y * grid_len)], vmm_mask); + cubic_planar_line(false); + + // y3 + vpaddd(vmm_index_y_itr, vmm_index_in_y, vmm_one); + vpaddd(vmm_index_y_itr, vmm_index_y_itr, vmm_one); + // crop to [0, IH - 1] + vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1)); + vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); + // weight y3 + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + vgatherdps(vmm_weightY, ptr[reg_weight_y + 3 * sizeof(float) + (vmm_tbl_y * grid_len)], vmm_mask); + cubic_planar_line(false); + + if (attr_.post_ops_.len_ != 0) { + apply_post_ops(jcp_.dst_dt, true); // oc_off is broadcast and always the same value for this channel + } + store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt); + + add(reg_tbl_y, step * sizeof(int)); // sizeof(int): sequence by dd() + add(reg_tbl_x, step * sizeof(int)); + add(reg_dst, step * jcp_.dst_data_size); + + sub(reg_work_amount, step); + + jmp(main_loop_label, T_NEAR); + } + L(main_loop_end_label); + + step = 1; + L(tail_loop_label); + { + cmp(reg_work_amount, 1); + jl(tail_loop_end_label, T_NEAR); + + // get idx for input + movss(Xmm(vmm_tbl_y.getIdx()), 
ptr[reg_tbl_y]); + gather_i32_indices(vmm_index_in_y, reg_index_y, 0, vmm_tbl_y, 1, memory::s32, true); + + movss(Xmm(vmm_val.getIdx()), ptr[reg_tbl_x]); + gather_i32_indices(vmm_index_in_x, reg_index, 0, vmm_val, 1, memory::s32, true); + // gather weightX by input idx, used in y0-y3 + gather_i32_indices(vmm_weightX0, reg_weight_x, 0, vmm_val, grid_len, memory::f32, true); + gather_i32_indices(vmm_weightX1, reg_weight_x, sizeof(float), vmm_val, grid_len, memory::f32, true); + gather_i32_indices(vmm_weightX2, reg_weight_x, 2 * sizeof(float), vmm_val, grid_len, memory::f32, true); + gather_i32_indices(vmm_weightX3, reg_weight_x, 3 * sizeof(float), vmm_val, grid_len, memory::f32, true); + // vmm_val is now relieved and used for dst_value + + uni_vpxor(vmm_val, vmm_val, vmm_val); + // y0 + vpsubd(vmm_index_y_itr, vmm_index_in_y, vmm_one); + // crop to [0, IH - 1] + vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1)); + vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); + + gather_i32_indices(vmm_weightY, reg_weight_y, 0, vmm_tbl_y, grid_len, memory::f32, true); + cubic_planar_line(true); + + // y1 + // crop to [0, IH - 1] + vpminsd(vmm_index_y_itr, vmm_index_in_y, cubic_planar_table_val(1)); + vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); + // weight y1: shift weight_size + gather_i32_indices(vmm_weightY, reg_weight_y, sizeof(float), vmm_tbl_y, grid_len, memory::f32, true); + cubic_planar_line(true); + + // y2 + vpaddd(vmm_index_y_itr, vmm_index_in_y, vmm_one); + // crop to [0, IH - 1] + vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1)); + vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); + // weight y2 + gather_i32_indices(vmm_weightY, reg_weight_y, 2 * sizeof(float), vmm_tbl_y, grid_len, memory::f32, true); + cubic_planar_line(true); + + // y3 + vpaddd(vmm_index_y_itr, vmm_index_in_y, vmm_one); + vpaddd(vmm_index_y_itr, vmm_index_y_itr, vmm_one); + // crop to [0, IH - 1] + vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1)); + vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); + // weight y3 + gather_i32_indices(vmm_weightY, reg_weight_y, 3 * sizeof(float), vmm_tbl_y, grid_len, memory::f32, true); + cubic_planar_line(true); + + if (attr_.post_ops_.len_ != 0) { + apply_post_ops(jcp_.dst_dt, true); // oc_off is broadcast and always the same value for this channel + } + store_scalar(ptr[reg_dst], Xmm(vmm_val.getIdx()), jcp_.dst_dt); + + add(reg_tbl_y, step * sizeof(int)); // sizeof(int): sequence with dd() + add(reg_tbl_x, step * sizeof(int)); + add(reg_dst, step * jcp_.dst_data_size); + + sub(reg_work_amount, step); + + jmp(tail_loop_label, T_NEAR); + } + L(tail_loop_end_label); + } + + inline void cubic_planar_line(bool is_scalar) { + uni_vpxor(vmm_dstX, vmm_dstX, vmm_dstX); + cubic_planar_pixel(0, is_scalar); + cubic_planar_pixel(1, is_scalar); + cubic_planar_pixel(2, is_scalar); + cubic_planar_pixel(3, is_scalar); + uni_vfmadd231ps(vmm_val, vmm_dstX, vmm_weightY); + } + + inline void cubic_planar_pixel(int itr, bool is_scalar) { + // vmm_index_in_x have index for src + if (itr == 0) { + vpsubd(vmm_index_x_itr, vmm_index_in_x, vmm_one); + } else if (itr == 1) { + vpaddd(vmm_index_x_itr, vmm_index_in_x, vmm_zero); + } else if (itr == 2) { + vpaddd(vmm_index_x_itr, vmm_index_in_x, vmm_one); + } else if (itr == 3) { + vpaddd(vmm_index_x_itr, vmm_index_in_x, vmm_one); + vpaddd(vmm_index_x_itr, vmm_index_x_itr, vmm_one); + } + + // crop to [0, IW - 1] + vpminsd(vmm_index_x_itr, vmm_index_x_itr, cubic_planar_table_val(2)); + 
vpmaxsd(vmm_index_x_itr, vmm_index_x_itr, vmm_zero); + + // value + // index is: ptr[reg_src + (vmm_index_y_itr * jcp_.IW + vmm_index_x_itr) * jcp_.src_data_size] + uni_vmovdqu(vmm_mask, cubic_planar_table_val(2)); + vpaddd(vmm_mask, vmm_mask, vmm_one); // (IW - 1) + 1 = IW + uni_vpmulld(vmm_mask, vmm_mask, vmm_index_y_itr); + uni_vpaddd(vmm_index_x_itr, vmm_index_x_itr, vmm_mask); + gather_i32_indices(vmm_src, reg_src, 0, vmm_index_x_itr, jcp_.src_data_size, memory::f32, is_scalar); + + if (itr == 0) { + uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weightX0); + } else if (itr == 1) { + uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weightX1); + } else if (itr == 2) { + uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weightX2); + } else if (itr == 3) { + uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weightX3); + } + } + + inline void prepare_cubic_planar_table() { + auto broadcast_int = [&](int val) { + for (size_t d = 0; d < vlen / sizeof(int); ++d) { + dd(val); + } + }; + + align(64); + L(l_table_constant); + broadcast_int(vals_for_cubic_planar.int_one); + broadcast_int(jcp_.IH - 1); + broadcast_int(jcp_.IW - 1); + dd(vals_for_cubic_planar.mask_gather_avx512); + } + + struct vals_for_cubic_planar_type { + int int_one = 0x00000001; + int mask_gather_avx512 = 0x0000ffff; // 00000000000000001111111111111111 + } vals_for_cubic_planar; + + inline Xbyak::Address cubic_planar_table_val(int index) { + return ptr[reg_table + index * vlen]; + } + + // always gather to Vmm, compute with Vmm, store with Xmm if scalar + inline void gather_i32_indices(Vmm vmm_src, const Xbyak::Reg64 &base, int offset, Vmm vmm_indices, int scale, + memory::data_type src_dt, bool is_scalar) { + Xbyak::Address table_idx = ptr[base + offset + vmm_indices * scale]; + if ((isa == cpu::avx512_common) && !is_scalar) { + // [0-15] bit of int to mask + kmovw(k_mask, cubic_planar_table_val(3)); + if (src_dt == memory::f32) { + vgatherdps(vmm_src | k_mask, table_idx); // dword index, packed single data + } else if (src_dt == memory::s32) { + vpgatherdd(vmm_src | k_mask, table_idx); // dword index, dword data + } + } else if ((isa == cpu::avx2) && !is_scalar) { + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + if (src_dt == memory::f32) { + vgatherdps(vmm_src, table_idx, vmm_mask); + } else if (src_dt == memory::s32) { + vpgatherdd(vmm_src, table_idx, vmm_mask); + } + } else { + const int gpr_size = 8; + sub(rsp, gpr_size); + // move content in register to content in address(ptr[]) + mov(ptr[rsp], reg_tmp_64); + + // replace index with value in stack + sub(rsp, vlen); + uni_vmovdqu(ptr[rsp], vmm_indices); + + int repeats = is_scalar ? 
1 : vlen / sizeof(float); + for (size_t i = 0; i < repeats; ++i) { + mov(reg_tmp_64.cvt32(), ptr[rsp + i * sizeof(int)]); // sizeof(int) index_size + table_idx = ptr[base + offset + reg_tmp_64 * scale]; // scale: sizeof(float) value_size + mov(reg_tmp_64.cvt32(), table_idx); + mov(ptr[rsp + i * sizeof(int)], reg_tmp_64.cvt32()); + } + + uni_vmovups(vmm_src, ptr[rsp]); + add(rsp, vlen); + // restore GPR state + mov(reg_tmp_64, ptr[rsp]); + add(rsp, gpr_size); + } + } + inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, memory::data_type src_dt) { switch (src_dt) { case memory::f32: @@ -653,10 +1180,9 @@ private: uni_vpmovzxbd(vmm_src, op); break; case memory::bf16: - if (isa != cpu::sse42) { - vpmovzxwd(vmm_src, op); - } + uni_vpmovzxwd(vmm_src, op); uni_vpslld(vmm_src, vmm_src, 16); + break; default: assert(!"unknown dst_dt"); } @@ -665,6 +1191,30 @@ private: uni_vcvtdq2ps(vmm_src, vmm_src); } + inline void load_xmm(Xmm xmm_src, const Xbyak::Address &op, memory::data_type src_dt) { + switch (src_dt) { + case memory::f32: + case memory::s32: + uni_vmovups(xmm_src, op); + break; + case memory::s8: + uni_vpmovsxbd(xmm_src, op); + break; + case memory::u8: + uni_vpmovzxbd(xmm_src, op); + break; + case memory::bf16: + uni_vpmovzxwd(xmm_src, op); + uni_vpslld(xmm_src, xmm_src, 16); + break; + default: + assert(!"unknown dst_dt"); + } + + if (src_dt != memory::f32 && src_dt != data_type::bf16) + uni_vcvtdq2ps(xmm_src, xmm_src); + } + inline void load_scalar(Xmm xmm_src, const Xbyak::Address &op, memory::data_type src_dt) { switch (src_dt) { case memory::f32: @@ -682,6 +1232,7 @@ private: case memory::bf16: pinsrw(xmm_src, op, 0x0); uni_vpslld(xmm_src, xmm_src, 16); + break; default: assert(!"unknown dst_dt"); } @@ -736,6 +1287,37 @@ private: } } + inline void store_xmm(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt) { + if (dst_dt != memory::f32 && dst_dt != memory::bf16) { + uni_vcvtps2dq(xmm_dst, xmm_dst); + } + + switch (dst_dt) { + case memory::f32: + case memory::s32: + uni_vmovups(op, xmm_dst); + break; + case memory::s8: + uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst); + uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); + movd(op, xmm_dst); + break; + case memory::u8: + uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst); + uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); + movd(op, xmm_dst); + break; + case memory::bf16: + pshuflw(xmm_dst, xmm_dst, 0x0d); // 01 01 01 01 --> 01 01 11 00 imm=0b00001101 + pshufhw(xmm_dst, xmm_dst, 0x0d); // 01 01 11 00 --> 11 00 11 00 + pshufd(xmm_dst, xmm_dst, 0x08); // 11 00 11 00 --> 11 11 00 00 imm=0b00001000 + vmovq(op, xmm_dst); + break; + default: + assert(!"unknown dst_dt"); + } + } + inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt) { if (dst_dt != data_type::f32 && dst_dt != data_type::bf16) { uni_vcvtps2dq(xmm_dst, xmm_dst); @@ -761,6 +1343,7 @@ private: case memory::bf16: uni_vpsrld(xmm_dst, xmm_dst, 16); pextrw(op, xmm_dst, 0x0); + break; default: assert(!"unknown dst_dt"); } @@ -1056,7 +1639,7 @@ void MKLDNNInterpolateNode::initSupportedPrimitiveDescriptors() { supportedPrimitiveDescriptors.push_back({config, implDetail, dataFormat}); }; - if (mode == InterpolateMode::nearest || mode == InterpolateMode::linear_onnx) { + if (mode != InterpolateMode::linear) { // blk and by_channel JIT kernel on sse42 or above machine if (mayiuse(cpu::sse42)) { if (getParentEdgeAt(DATA_ID)->getDims().ndims() == 4) { @@ -1091,7 +1674,7 @@ void MKLDNNInterpolateNode::initSupportedPrimitiveDescriptors() { if (mayiuse(cpu::avx2) 
&& inputPrec == Precision::FP32) { pushDesc(MKLDNNMemory::GetPlainFormat(getParentEdgeAt(DATA_ID)->getDims()), jit_avx2); } - } else if (mode == InterpolateMode::linear || mode == InterpolateMode::cubic) { + } else { pushDesc(MKLDNNMemory::GetPlainFormat(getParentEdgeAt(DATA_ID)->getDims()), ref); } } @@ -1129,8 +1712,8 @@ void MKLDNNInterpolateNode::createPrimitive() { size_t dimSize = dstDim.size(); jcp.OW = dstDim[dimSize - 1]; jcp.OH = dstDim[dimSize - 2]; - jcp.IW = srcDim[dimSize - 1]; - jcp.IH = srcDim[dimSize - 2]; + jcp.IW = srcDimPad[dimSize - 1]; + jcp.IH = srcDimPad[dimSize - 2]; if (MKLDNNMemory::GetPlainLayout(getChildEdgeAt(0)->getDims()) == selected_layout) { jcp.layout = InterpolateLayoutType::planar; @@ -1140,7 +1723,7 @@ void MKLDNNInterpolateNode::createPrimitive() { jcp.layout = InterpolateLayoutType::block; } - if (mode == InterpolateMode::nearest || mode == InterpolateMode::linear_onnx) { + if (mode == InterpolateMode::nearest || mode == InterpolateMode::linear_onnx || mode == InterpolateMode::cubic) { if (jcp.layout != InterpolateLayoutType::planar) { if (mayiuse(cpu::avx512_common)) { interpolateKernel.reset(new jit_uni_interpolate_kernel_f32(jcp, *attr.get())); @@ -1175,11 +1758,11 @@ void MKLDNNInterpolateNode::createPrimitive() { break; } case InterpolateMode::linear: { - buidTblLinear(srcDimPad5d, dstDim5d, dataScales, LINEAR_KERNEL, antialias); + buildTblLinear(srcDimPad5d, dstDim5d, dataScales, LINEAR_KERNEL, antialias); break; } case InterpolateMode::cubic: { - buidTblCubic(srcDimPad5d, dstDim5d, dataScales, cubeCoeff); + buildTblCubic(srcDimPad5d, dstDim5d, dataScales, cubeCoeff, jcp.layout); break; } default: { @@ -1338,7 +1921,7 @@ static inline float triangleCoeff(float x) { // wd .........wd, wh............wh, ww.............ww, id...........id, ih............ih, iw..............iw // | | // wh0.....wh_diameter ih0.....ih_diameter -void MKLDNNInterpolateNode::buidTblLinear(SizeVector& srcDimPad5d, SizeVector& dstDim5d, +void MKLDNNInterpolateNode::buildTblLinear(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector& dataScales, int kernel_width, bool antialias) { int dimSize = srcDim.size(); float fz = (dimSize == 5) ? 
dataScales[dimSize - 3] : 1.f; @@ -1429,7 +2012,8 @@ std::vector MKLDNNInterpolateNode::getCubicCoeffs(float mantissa, float a // table layout: // OW OW OW OW OW OH OH OH OH OH // x_idx x_weight0 x_weight1 x_weight2 x_weight3 y_idx y_weight0 y_weight1 y_weight2 y_weight3 -void MKLDNNInterpolateNode::buidTblCubic(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector& dataScales, float cubicCoeff) { +void MKLDNNInterpolateNode::buildTblCubic(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector& dataScales, + float cubicCoeff, InterpolateLayoutType layout) { int dimSize = srcDim.size(); float fy = dataScales[dimSize - 2]; float fx = dataScales[dimSize - 1]; @@ -1438,9 +2022,18 @@ void MKLDNNInterpolateNode::buidTblCubic(SizeVector& srcDimPad5d, SizeVector& ds // idxNum for index, CUBIC_GRID_LEN for weight const int idxNum = 1; - indexTable.resize((CUBIC_GRID_LEN + idxNum) * OW + (CUBIC_GRID_LEN + idxNum) * OH); - int *xOrigin = static_cast(&indexTable[0]); - float *xFactor = reinterpret_cast(&indexTable[OW]); + size_t idxWeightSize = (CUBIC_GRID_LEN + idxNum) * OW + (CUBIC_GRID_LEN + idxNum) * OH; + if (layout != InterpolateLayoutType::planar) { + indexTable.resize(idxWeightSize); + } else { + size_t sequenceSize = 2 * OH * OW; + indexTable.resize(idxWeightSize + sequenceSize); + } + + int tblAdvance = 0; + int *xOrigin = static_cast(&indexTable[tblAdvance]); + tblAdvance += OW; + float *xFactor = reinterpret_cast(&indexTable[tblAdvance]); for (int ox = 0; ox < OW; ox++) { float ix = coordTransToInput(ox, fx, IW, OW); int ix_r = static_cast(std::floor(ix)); @@ -1453,8 +2046,10 @@ void MKLDNNInterpolateNode::buidTblCubic(SizeVector& srcDimPad5d, SizeVector& ds xFactor[CUBIC_GRID_LEN * ox + 3] = coffes[3]; } - int *yOrigin = static_cast(&indexTable[(CUBIC_GRID_LEN + idxNum) * OW]); - float *yFactor = reinterpret_cast(&indexTable[(CUBIC_GRID_LEN + idxNum) * OW + OH]); + tblAdvance += CUBIC_GRID_LEN * OW; + int *yOrigin = static_cast(&indexTable[tblAdvance]); + tblAdvance += OH; + float *yFactor = reinterpret_cast(&indexTable[tblAdvance]); for (int oy = 0; oy < OH; oy++) { float iy = coordTransToInput(oy, fy, IH, OH); int iy_r = static_cast(std::floor(iy)); @@ -1466,6 +2061,20 @@ void MKLDNNInterpolateNode::buidTblCubic(SizeVector& srcDimPad5d, SizeVector& ds yFactor[CUBIC_GRID_LEN * oy + 2] = coffes[2]; yFactor[CUBIC_GRID_LEN * oy + 3] = coffes[3]; } + + if (layout == InterpolateLayoutType::planar) { + tblAdvance += CUBIC_GRID_LEN * OH; + int *sequenceOH = static_cast(&indexTable[tblAdvance]); + tblAdvance += OH * OW; + int *sequenceOW = static_cast(&indexTable[tblAdvance]); + for (int h = 0; h < OH; ++h) { + int offset = h * OW; + for (int w = 0; w < OW; ++w) { + sequenceOH[offset + w] = h * sizeof(int); + sequenceOW[offset + w] = w * sizeof(int); + } + } + } } void MKLDNNInterpolateNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) { @@ -1531,6 +2140,16 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) { auto srcDimPad5d = to5Dim(srcDimPad); auto dstDim5d = to5Dim(dstDim); + InterpolateLayoutType layout; + Layout selected_layout = getParentEdgeAt(DATA_ID)->getDesc().getLayout(); + if (MKLDNNMemory::GetPlainLayout(getChildEdgeAt(0)->getDims()) == selected_layout) { + layout = InterpolateLayoutType::planar; + } else if ((selected_layout == NHWC) || (selected_layout == NDHWC)) { + layout = InterpolateLayoutType::by_channel; + } else { + layout = InterpolateLayoutType::block; + } + uint8_t *src_data = nullptr; std::vector srcPadded; if (hasPad) { @@ -1542,16 
+2161,53 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) { SizeVector inShapeBlock = getBlockND(srcDim5d); SizeVector inShapePadBlock = getBlockND(srcDimPad5d); - srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0); - uint8_t *src_data_pad = static_cast(&srcPadded[0]); - - parallel_for4d(srcDim5d[0], srcDim5d[1], srcDim5d[2], srcDim5d[3], [&](int n, int c, int d, int h) { - uint8_t *src = src_data_origin + (inShapeBlock[1] * n + inShapeBlock[2] * c + inShapeBlock[3] * d + inShapeBlock[4] * h) * srcDataSize; - uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + inShapePadBlock[2] * (c + padB1) + - inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + padB4) * srcDataSize; - cpu_memcpy(srcPad, src, srcDim5d[4] * srcDataSize); - }); - src_data = src_data_pad; + + if (layout == InterpolateLayoutType::planar) { + srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0); + uint8_t *src_data_pad = static_cast(&srcPadded[0]); + parallel_for4d(srcDim5d[0], srcDim5d[1], srcDim5d[2], srcDim5d[3], [&](int n, int c, int d, int h) { + uint8_t *src = src_data_origin + (inShapeBlock[1] * n + inShapeBlock[2] * c + inShapeBlock[3] * d + inShapeBlock[4] * h) * srcDataSize; + uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + inShapePadBlock[2] * (c + padB1) + + inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + padB4) * srcDataSize; + cpu_memcpy(srcPad, src, srcDim5d[4] * srcDataSize); + }); + src_data = src_data_pad; + } else if (layout == InterpolateLayoutType::by_channel) { + srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0); + uint8_t *src_data_pad = static_cast(&srcPadded[0]); + parallel_for4d(srcDim5d[0], srcDim5d[2], srcDim5d[3], srcDim5d[4], [&](int n, int d, int h, int w) { + uint8_t *src = src_data_origin + (inShapeBlock[1] * n + + (inShapeBlock[3] * d + inShapeBlock[4] * h + inShapeBlock[5] * w) * srcDim5d[1]) * srcDataSize; + uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + (inShapePadBlock[3] * (d + padB2) + + inShapePadBlock[4] * (h + padB3) + inShapePadBlock[5] * (w + padB4)) * srcDimPad5d[1] + padB1) * srcDataSize; + cpu_memcpy(srcPad, src, srcDim5d[1] * srcDataSize); + }); + src_data = src_data_pad; + } else if (layout == InterpolateLayoutType::block) { + size_t blkSize = mayiuse(cpu::avx512_common) ? 
16 : 8; + size_t CB = div_up(srcDimPad5d[1], blkSize); + size_t eltsTotal = srcDimPad5d[0] * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize; + srcPadded.resize(eltsTotal * srcDataSize, 0x0); + uint8_t *src_data_pad = static_cast(&srcPadded[0]); + if ((srcDim5d[0] != srcDimPad5d[0]) || (srcDim5d[1] != srcDimPad5d[1])) { + THROW_IE_EXCEPTION << "Interpolate layer with name '" << getName() << + "' does not support padding on batch and channel dimensions"; + } + parallel_for5d(srcDim5d[0], CB, srcDim5d[2], srcDim5d[3], srcDim5d[4], [&](int n, int cb, int d, int h, int w) { + uint8_t *src = src_data_origin + (n * CB * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize + + (cb * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize + + (d * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize + + (h * srcDim5d[4] * blkSize) * srcDataSize + + (w * blkSize) * srcDataSize; + uint8_t *srcPad = src_data_pad + (n * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize + + (cb * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize + + ((d + padB2) * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize + + ((h + padB3) * srcDimPad5d[4] * blkSize) * srcDataSize + + ((w + padB4) * blkSize) * srcDataSize; + cpu_memcpy(srcPad, src, blkSize * srcDataSize); + }); + src_data = src_data_pad; + } } else { src_data = src_data_origin; } @@ -1562,13 +2218,11 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) { if (dimSize > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f)) { THROW_IE_EXCEPTION << "Interpolate layer only supports resize on spatial dimensions(depth, height and width)"; } - Layout layout = getParentEdgeAt(DATA_ID)->getDesc().getLayout(); - bool isPlanar = (layout == NC || layout == NCHW || layout == NCDHW) ? true : false; switch (mode) { case InterpolateMode::nearest: { if (interpolateKernel) { - if (isPlanar) { + if (layout == InterpolateLayoutType::planar) { NNPlanar(src_data, dst_data, N, C, ID, IH, IW, OD, OH, OW); } else { NNCGathered(src_data, dst_data, N, C, ID, IH, IW, OD, OH, OW); @@ -1578,19 +2232,9 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) { } break; } - case InterpolateMode::linear: { - float fz = (dimSize == 5) ? dataScales[dimSize - 3] : 1.f; - float fy = dataScales[dimSize - 2]; - float fx = dataScales[dimSize - 1]; - - bool isDownsample = (fx < 1.f) || (fy < 1.f) || (fz < 1.f); - int kernel_width = 2; - linearInterpolation(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias); - break; - } case InterpolateMode::linear_onnx: { if (interpolateKernel) { - if (isPlanar) { + if (layout == InterpolateLayoutType::planar) { linearOnnxPlanar(src_data, dst_data, N, C, IH, IW, OH, OW); } else { linearOnnxCGathered(src_data, dst_data, N, C, IH, IW, OH, OW); @@ -1601,7 +2245,25 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) { break; } case InterpolateMode::cubic: { - cubic(src_data, dst_data, N, C, IH, IW, OH, OW, cubeCoeff); + if (interpolateKernel) { + if (layout == InterpolateLayoutType::planar) { + cubicPlanar(src_data, dst_data, N, C, IH, IW, OH, OW); + } else { + cubicCGathered(src_data, dst_data, N, C, IH, IW, OH, OW); + } + } else { + cubicRef(src_data, dst_data, N, C, IH, IW, OH, OW); + } + break; + } + case InterpolateMode::linear: { + float fz = (dimSize == 5) ? 
dataScales[dimSize - 3] : 1.f; + float fy = dataScales[dimSize - 2]; + float fx = dataScales[dimSize - 1]; + + bool isDownsample = (fx < 1.f) || (fy < 1.f) || (fz < 1.f); + int kernel_width = 2; + linearInterpolation(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias); break; } default: { @@ -1634,7 +2296,7 @@ void MKLDNNInterpolateNode::NNCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr const uint8_t *in_ptr_dh = in_ptr + (C * IW * IH * index_d[d] + C * IW * index_h[h]) * srcDataSize; auto arg = jit_interpolate_call_args(); arg.dst = out_ptr_dh; - arg.src = in_ptr_dh; + arg.src_ptr[0] = in_ptr_dh; arg.index = static_cast(&(index_w_kernel[0])); arg.work_amount = C; arg.oc_off = 0; @@ -1655,7 +2317,7 @@ void MKLDNNInterpolateNode::NNCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr auto arg = jit_interpolate_call_args(); for (int h = 0; h < OH; h++) { // kernel for blk_size * OW arg.dst = out_ptr_cbd + blk_size * OW * h * dstDataSize; - arg.src = in_ptr_cbd + blk_size * IW * index_h[h] * srcDataSize; + arg.src_ptr[0] = in_ptr_cbd + blk_size * IW * index_h[h] * srcDataSize; arg.index = static_cast(&(index_w_kernel[0])); arg.work_amount = static_cast(OW); arg.oc_off = cb * blk_size; @@ -1686,7 +2348,7 @@ void MKLDNNInterpolateNode::NNPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, uint8_t *out_ptr = out_ptr_ + (OW * OH * OD * C * b + OW * OH * OD * c + OW * OH * od) * dstDataSize; auto arg = jit_interpolate_call_args(); - arg.src = in_ptr; + arg.src_ptr[0] = in_ptr; arg.dst = out_ptr; arg.index = static_cast(&index_kernel[0]); // need index_h and index_w in kernel, it's in continous memory so one param arg.oc_off = static_cast(c); @@ -1724,9 +2386,9 @@ void MKLDNNInterpolateNode::linearOnnxPlanar(const uint8_t *in_ptr_, uint8_t *ou uint8_t *out_ptr_nc = out_ptr_ + (OH * OW * C * b + OH * OW * c) * dstDataSize; const uint8_t *in_ptr_nc = in_ptr_ + (IH * IW * C * b + IH * IW * c) * srcDataSize; auto arg = jit_interpolate_call_args(); - arg.src = in_ptr_nc; + arg.src_ptr[0] = in_ptr_nc; arg.index = static_cast(&index[0]); - arg.weight = static_cast(&weight[0]); + arg.weight_ptr[0] = static_cast(&weight[0]); arg.dst = out_ptr_nc; arg.work_amount = OW * OH; arg.oc_off = c; @@ -1750,45 +2412,33 @@ void MKLDNNInterpolateNode::linearOnnxCGathered(const uint8_t *in_ptr_, uint8_t Layout layout = getParentEdgeAt(0)->getDesc().getLayout(); bool isByChannel = (layout == NHWC) ? true : false; - if (isByChannel) { - parallel_for3d(B, OH, OW, [&](size_t b, size_t h, size_t w) { - uint8_t *out_ptr_nhw = out_ptr_ + (OH * OW * C * b + OW * C * h + C * w) * dstDataSize; - const uint8_t *in_ptr_n = in_ptr_ + (IH * IW * C * b) * srcDataSize; - auto arg = jit_interpolate_call_args(); - arg.src = in_ptr_n + (indexTop[h] * IW * C + indexLeft[w] * C) * srcDataSize; - arg.srcTR = in_ptr_n + (indexTop[h] * IW * C + indexRight[w] * C) * srcDataSize; - arg.srcBL = in_ptr_n + (indexBottom[h] * IW * C + indexLeft[w] * C) * srcDataSize; - arg.srcBR = in_ptr_n + (indexBottom[h] * IW * C + indexRight[w] * C) * srcDataSize; - arg.weight = static_cast(&weightLeft[w]); - arg.weightR = static_cast(&weightRight[w]); - arg.weightT = static_cast(&weightTop[h]); - arg.weightB = static_cast(&weightBottom[h]); - arg.dst = out_ptr_nhw; - arg.work_amount = C; - arg.oc_off = 0; - (*interpolateKernel)(&arg); - }); - } else { - size_t blkSize = mayiuse(cpu::avx512_common) ? 
16 : 8; - size_t CB = div_up(C, blkSize); - parallel_for3d(B, OH, OW, [&](size_t b, size_t h, size_t w) { - uint8_t *out_ptr_nhw = out_ptr_ + (CB * OH * OW * blkSize * b + OW * blkSize * h + blkSize * w) * dstDataSize; - const uint8_t *in_ptr_n = in_ptr_ + (CB * IH * IW * blkSize * b) * srcDataSize; - auto arg = jit_interpolate_call_args(); - arg.src = in_ptr_n + (indexTop[h] * IW * blkSize + indexLeft[w] * blkSize) * srcDataSize; - arg.srcTR = in_ptr_n + (indexTop[h] * IW * blkSize + indexRight[w] * blkSize) * srcDataSize; - arg.srcBL = in_ptr_n + (indexBottom[h] * IW * blkSize + indexLeft[w] * blkSize) * srcDataSize; - arg.srcBR = in_ptr_n + (indexBottom[h] * IW * blkSize + indexRight[w] * blkSize) * srcDataSize; - arg.weight = static_cast(&weightLeft[w]); - arg.weightR = static_cast(&weightRight[w]); - arg.weightT = static_cast(&weightTop[h]); - arg.weightB = static_cast(&weightBottom[h]); + int blkSize = mayiuse(cpu::avx512_common) ? 16 : 8; + int CB = div_up(C, blkSize); + int CSize = isByChannel ? C : blkSize * CB; + int CGatherLen = isByChannel ? C : blkSize; + int workAmount = isByChannel ? C : CB; + parallel_for2d(B, OH, [&](size_t b, size_t h) { + uint8_t *out_ptr_nh = out_ptr_ + (OH * OW * CSize * b + OW * CGatherLen * h) * dstDataSize; + const uint8_t *in_ptr_n = in_ptr_ + (IH * IW * CSize * b) * srcDataSize; + const uint8_t *in_ptr_nh_t = in_ptr_n + (indexTop[h] * IW * CGatherLen) * srcDataSize; + const uint8_t *in_ptr_nh_b = in_ptr_n + (indexBottom[h] * IW * CGatherLen) * srcDataSize; + auto arg = jit_interpolate_call_args(); + for (int w = 0; w < OW; ++w) { + uint8_t *out_ptr_nhw = out_ptr_nh + CGatherLen * w * dstDataSize; + arg.src_ptr[0] = in_ptr_nh_t + (indexLeft[w] * CGatherLen) * srcDataSize; + arg.src_ptr[1] = in_ptr_nh_t + (indexRight[w] * CGatherLen) * srcDataSize; + arg.src_ptr[2] = in_ptr_nh_b + (indexLeft[w] * CGatherLen) * srcDataSize; + arg.src_ptr[3] = in_ptr_nh_b + (indexRight[w] * CGatherLen) * srcDataSize; + arg.weight_ptr[0] = static_cast(&weightLeft[w]); + arg.weight_ptr[1] = static_cast(&weightRight[w]); + arg.weight_ptr[2] = static_cast(&weightTop[h]); + arg.weight_ptr[3] = static_cast(&weightBottom[h]); arg.dst = out_ptr_nhw; - arg.work_amount = CB; + arg.work_amount = workAmount; arg.oc_off = 0; (*interpolateKernel)(&arg); - }); - } + } + }); } void MKLDNNInterpolateNode::linearOnnxRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) { @@ -1938,7 +2588,90 @@ void MKLDNNInterpolateNode::linearInterpolation(const uint8_t *in_ptr_, uint8_t }); } -void MKLDNNInterpolateNode::cubic(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW, float a) { +void MKLDNNInterpolateNode::cubicCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) { + const int idxNum = 1; + int *xOrigin = static_cast(&indexTable[0]); + float *xFactor = reinterpret_cast(&indexTable[OW]); + int *yOrigin = static_cast(&indexTable[(CUBIC_GRID_LEN + idxNum) * OW]); + float *yFactor = reinterpret_cast(&indexTable[(CUBIC_GRID_LEN + idxNum) * OW + OH]); + + Layout layout = getParentEdgeAt(0)->getDesc().getLayout(); + bool isByChannel = (layout == NHWC) ? true : false; + + int blkSize = mayiuse(cpu::avx512_common) ? 16 : 8; + int CB = div_up(C, blkSize); + int CSize = isByChannel ? C : blkSize * CB; + int CGatherLen = isByChannel ? C : blkSize; + int workAmount = isByChannel ? 
C : CB; + + parallel_for3d(B, OH, OW, [&](size_t b, size_t h, size_t w) { + uint8_t *out_ptr_nhw = out_ptr_ + (OH * OW * CSize * b + OW * CGatherLen * h + CGatherLen * w) * dstDataSize; + const uint8_t *in_ptr_n = in_ptr_ + (IH * IW * CSize * b) * srcDataSize; + + std::vector kernelIndex(CUBIC_GRID_LEN * CUBIC_GRID_LEN); // 16 address offset to src(batch) or src(CB) + int iy = yOrigin[h]; + int ix = xOrigin[w]; + for (int y = iy - 1, i = 0; y <= iy + 2; y++, i++) { + int yInRange = std::max(0, std::min(y, IH - 1)); + yInRange = yInRange * CGatherLen * IW * srcDataSize; + for (int x = ix - 1, j = 0; x <= ix + 2; x++, j++) { + int xInRange = std::max(0, std::min(x, IW - 1)); + xInRange = yInRange + xInRange * CGatherLen * srcDataSize; + kernelIndex[i * CUBIC_GRID_LEN + j] = xInRange; + } + } + auto arg = jit_interpolate_call_args(); + arg.dst = out_ptr_nhw; + arg.src_ptr[0] = in_ptr_n; + arg.index = static_cast(&kernelIndex[0]); + // 0 for weight_W, 1 for weight_H + arg.weight_ptr[0] = static_cast(&xFactor[w * CUBIC_GRID_LEN]); + arg.weight_ptr[1] = static_cast(&yFactor[h * CUBIC_GRID_LEN]); + + // for by channel, src + step, dst + step, process next step on continuous memory + // for blk, src + IW*IH*blkSize, dst + OW*OH*blkSize, process the blkSize on next CB + arg.work_amount = workAmount; + arg.oc_off = 0; + (*interpolateKernel)(&arg); + }); +} + +void MKLDNNInterpolateNode::cubicPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) { + const int idxNum = 1; + int tblAdvance = 0; + int *xOrigin = static_cast(&indexTable[tblAdvance]); + tblAdvance += OW; + float *xFactor = reinterpret_cast(&indexTable[tblAdvance]); + tblAdvance += CUBIC_GRID_LEN * OW; + int *yOrigin = static_cast(&indexTable[tblAdvance]); + tblAdvance += OH; + float *yFactor = reinterpret_cast(&indexTable[tblAdvance]); + + tblAdvance += CUBIC_GRID_LEN * OH; + int *sequenceOH = static_cast(&indexTable[tblAdvance]); + tblAdvance += OW * OH; + int *sequenceOW = static_cast(&indexTable[tblAdvance]); + + parallel_for2d(B, C, [&](size_t n, size_t c) { + const uint8_t *in_ptr_nc = in_ptr_ + (IW * IH * C * n + IW * IH * c) * srcDataSize; + uint8_t *out_ptr_nc = out_ptr_ + (OW * OH * C * n + OW * OH * c) * dstDataSize; + + auto arg = jit_interpolate_call_args(); + arg.dst = out_ptr_nc; + arg.src_ptr[0] = in_ptr_nc; + arg.index = xOrigin; + arg.src_ptr[1] = yOrigin; + arg.src_ptr[2] = static_cast(&sequenceOH[0]); + arg.src_ptr[3] = static_cast(&sequenceOW[0]); + arg.weight_ptr[0] = xFactor; + arg.weight_ptr[1] = yFactor; + arg.work_amount = static_cast(OW * OH); + arg.oc_off = static_cast(C); + (*interpolateKernel)(&arg); + }); +} + +void MKLDNNInterpolateNode::cubicRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) { const int idxNum = 1; int *xOrigin = static_cast(&indexTable[0]); float *xFactor = reinterpret_cast(&indexTable[OW]); @@ -2027,9 +2760,12 @@ void MKLDNNInterpolateNode::setValue(uint8_t *base, size_t offset, float value, } // scale is float(outShape) / float(inShape) -// strictly consistent with onnx calc manner(div scale, not multiply inverse), +// strictly consistent with onnx calc manner(div scale, not multiply inverse), given this is done offline // the slight precison diff can produce obvious wrong value due to "nearest round" behavior for NN mode inline float MKLDNNInterpolateNode::coordTransToInput(int outCoord, float scale, int inShape, int outShape) { + if (scale == 1.0f || (inShape == outShape)) { + return outCoord; + } 
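    // Fast path above (added in this patch): when the scale is exactly 1.0f or the axis
    // length is unchanged, the coordinate maps to itself, so the per-mode math below is
    // skipped. This avoids the float rounding noted in the comment above, which can
    // otherwise shift a nearest-neighbour index by one.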
switch (coordTransMode) { case InterpolateCoordTransMode::half_pixel: { return (outCoord + 0.5f) / scale - 0.5f; @@ -2123,7 +2859,9 @@ bool MKLDNNInterpolateNode::canFuse(const MKLDNNNodePtr& node) const { if (eltwiseNode == nullptr) THROW_IE_EXCEPTION << "Cannot get eltwise node " << node->getName(); return isOneOf(eltwiseNode->getOpType(), {MulAdd, Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp, - Tanh, Swish, Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt}); + Tanh, Swish, Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt}) || + ((eltwiseNode->getOpType() == MulAdd && eltwiseNode->getCnnLayer()->blobs.size() == 2) || + (eltwiseNode->getOpType() == Prelu)); } return false; diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.h index 2e2e1fd..a526eab 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.h @@ -10,6 +10,8 @@ #include #include +#define MAX_INPUT_INTERPOLATE 4 + using namespace InferenceEngine; namespace MKLDNNPlugin { @@ -54,14 +56,8 @@ struct jit_interpolate_config_params { }; struct jit_interpolate_call_args { - const void *src; - const void *srcTR; - const void *srcBL; - const void *srcBR; - const float *weight; - const float *weightR; - const float *weightT; - const float *weightB; + const void *src_ptr[MAX_INPUT_INTERPOLATE]; + const void *weight_ptr[MAX_INPUT_INTERPOLATE]; const int *index; void *dst; size_t work_amount; @@ -110,18 +106,20 @@ private: void linearOnnxCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW); void linearOnnxRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW); + // cubic + std::vector getCubicCoeffs(float mantissa, float a); + void cubicPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW); + void cubicCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW); + void cubicRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW); + // linear void linearInterpolation(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int ID, int IH, int IW, float fx, float fy, float fz, int OD, int OH, int OW, int kernel_width, bool antialias); - // cubic - std::vector getCubicCoeffs(float mantissa, float a); - void cubic(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW, float a); - void buildTblNN(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector& dataScales, InterpolateLayoutType layout); void buildTblLinearOnnx(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector& dataScales, InterpolateLayoutType layout); - void buidTblLinear(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector& dataScales, int kernel_width, bool antialias); - void buidTblCubic(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector& dataScales, float cubicCoeff); + void buildTblLinear(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector& dataScales, int kernel_width, bool antialias); + void buildTblCubic(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector& dataScales, float cubicCoeff, InterpolateLayoutType layout); void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false); diff --git 
a/inference-engine/src/transformations/include/transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp b/inference-engine/src/transformations/include/transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp new file mode 100644 index 0000000..35d802e --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include + +#include + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API ConvertInterpolate1ToInterpolate4; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief ConvertInterpolate1ToInterpolate4 covert v0:interpolate into v4::Interpolate. + */ +class ngraph::pass::ConvertInterpolate1ToInterpolate4: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + ConvertInterpolate1ToInterpolate4(); +}; \ No newline at end of file diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index b3b5801..50b3b45 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -38,6 +38,7 @@ #include "transformations/op_conversions/convert_space_to_depth.hpp" #include "transformations/op_conversions/convert_broadcast_to_tiles.hpp" #include "transformations/op_conversions/convert_gelu.hpp" +#include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp" #include "transformations/op_conversions/batch_norm_decomposition.hpp" #include "transformations/op_conversions/reduce_l1_decomposition.hpp" #include "transformations/op_conversions/reduce_l2_decomposition.hpp" @@ -110,6 +111,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp new file mode 100644 index 0000000..52e2fa5 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp @@ -0,0 +1,69 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp" + +#include +#include + +#include +#include +#include +#include + +NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertInterpolate1ToInterpolate4, "ConvertInterpolate1ToInterpolate4", 0); + +ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolate4() { + auto interpolate1 = ngraph::pattern::wrap_type({pattern::any_input(pattern::has_static_rank()), pattern::any_input()}); + ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) { + auto interpolate1 = std::dynamic_pointer_cast(m.get_match_root()); + if (!interpolate1) + return false; + + auto interpolate_attrs = interpolate1->get_attrs(); + auto input_shape_rank = 
interpolate1->input(0).get_partial_shape().rank().get_length(); + + // attrs + auto mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode(); + if (interpolate_attrs.mode == "nearest") { + mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::nearest; + } else if (interpolate_attrs.mode == "cubic") { + mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::cubic; + } else if (interpolate_attrs.mode == "linear") { + if (input_shape_rank < 5) { + mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx; + } else if (input_shape_rank == 5) { + mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::linear; + } else { + return false; + } + } else { + return false; + } + auto nearest_mode_v4 = ngraph::op::v4::Interpolate::NearestMode::floor; + auto shape_calculation_mode_v4 = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes; + auto coordinate_transformation_mode_v4 = interpolate_attrs.align_corners ? ngraph::op::v4::Interpolate::CoordinateTransformMode::align_corners : + ngraph::op::v4::Interpolate::CoordinateTransformMode::asymmetric; + auto interpolate4_attr = ngraph::op::v4::Interpolate::InterpolateAttrs(mode_v4, shape_calculation_mode_v4, + interpolate_attrs.pads_begin, interpolate_attrs.pads_end, + coordinate_transformation_mode_v4, nearest_mode_v4, interpolate_attrs.antialias, -0.75); + + // input + auto axes = interpolate_attrs.axes.to_vector(); + auto axes_node = ngraph::opset4::Constant::create(element::i64, {axes.size()}, axes); + auto default_scales = std::vector(axes.size(), 1.f); + auto scales_node = ngraph::opset4::Constant::create(element::f32, {axes.size()}, default_scales); + + auto interpolate4 = std::make_shared(interpolate1->input_value(0), interpolate1->input_value(1), + scales_node, axes_node, interpolate4_attr); + + interpolate4->set_friendly_name(interpolate1->get_friendly_name()); + ngraph::copy_runtime_info(interpolate1, interpolate4); + ngraph::replace_node(interpolate1, interpolate4); + return true; + }; + + auto m = std::make_shared(interpolate1, "ConvertInterpolate1ToInterpolate4"); + this->register_matcher(m, callback); +} \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/interpolate_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/interpolate_transformation.cpp index 56f5f72..6b71f2a 100644 --- a/inference-engine/tests/functional/inference_engine/lp_transformations/interpolate_transformation.cpp +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/interpolate_transformation.cpp @@ -21,6 +21,7 @@ using namespace testing; using namespace ngraph::pass; +using namespace ngraph; using namespace ngraph::builder::subgraph; class interpAttributes { @@ -32,6 +33,8 @@ public: std::vector pads_begin; std::vector pads_end; + interpAttributes() = default; + interpAttributes(const ngraph::AxisSet& axes, const std::string& mode, const bool& align_corners, @@ -42,6 +45,23 @@ public: antialias(antialias), pads_begin(pads_begin), pads_end(pads_end) {} }; +class interp4Attributes { +public: + op::v4::Interpolate::InterpolateMode mode; + op::v4::Interpolate::CoordinateTransformMode coordinate_transformation_mode; + std::vector pads_begin; + std::vector pads_end; + + interp4Attributes() = default; + + interp4Attributes(const op::v4::Interpolate::InterpolateMode mode, + const op::v4::Interpolate::CoordinateTransformMode coordinate_transformation_mode, + const std::vector& pads_begin, + const std::vector& pads_end) : + mode(mode), 
coordinate_transformation_mode(coordinate_transformation_mode), + pads_begin(pads_begin), pads_end(pads_end) {} +}; + class InterpolateTransformationTestValues { public: class Actual { @@ -60,9 +80,11 @@ public: ngraph::Shape inputShape; ngraph::Shape outputShape; + ngraph::Shape scalesShape; ngraph::pass::low_precision::LayerTransformation::Params params; - //ngraph::op::InterpolateAttrs interpAttrs; interpAttributes interpAttrs; + interp4Attributes interp4Attrs; + int opset_version; Actual actual; Expected expected; }; @@ -85,60 +107,107 @@ public: void SetUp() override { const InterpolateTransformationTestValues testValues = GetParam(); - ngraph::op::InterpolateAttrs interpAttrs; - interpAttrs.axes = testValues.interpAttrs.axes; - interpAttrs.mode = testValues.interpAttrs.mode; - interpAttrs.align_corners = testValues.interpAttrs.align_corners; - interpAttrs.antialias = testValues.interpAttrs.antialias; - interpAttrs.pads_begin = testValues.interpAttrs.pads_begin; - interpAttrs.pads_end = testValues.interpAttrs.pads_end; - - actualFunction = ngraph::builder::subgraph::InterpolateFunction::getOriginal( - testValues.inputShape, - testValues.outputShape, - interpAttrs, - testValues.actual.precisionBeforeDequantization, - testValues.actual.dequantization); - - SimpleLowPrecisionTransformer transformer; - transformer.add(testValues.params); - transformer.transform(actualFunction); - - referenceFunction = ngraph::builder::subgraph::InterpolateFunction::getReference( - testValues.inputShape, - testValues.outputShape, - interpAttrs, - testValues.expected.precisionBeforeDequantization, - testValues.expected.dequantizationBefore, - testValues.expected.precisionAfterOperation, - testValues.expected.dequantizationAfter); + if (testValues.opset_version == 1) { + ngraph::op::InterpolateAttrs interpAttrs; + interpAttrs.axes = testValues.interpAttrs.axes; + interpAttrs.mode = testValues.interpAttrs.mode; + interpAttrs.align_corners = testValues.interpAttrs.align_corners; + interpAttrs.antialias = testValues.interpAttrs.antialias; + interpAttrs.pads_begin = testValues.interpAttrs.pads_begin; + interpAttrs.pads_end = testValues.interpAttrs.pads_end; + + actualFunction = ngraph::builder::subgraph::InterpolateFunction::getOriginal( + testValues.inputShape, + testValues.outputShape, + interpAttrs, + testValues.actual.precisionBeforeDequantization, + testValues.actual.dequantization); + + SimpleLowPrecisionTransformer transformer; + transformer.add(testValues.params); + transformer.transform(actualFunction); + + referenceFunction = ngraph::builder::subgraph::InterpolateFunction::getReference( + testValues.inputShape, + testValues.outputShape, + interpAttrs, + testValues.expected.precisionBeforeDequantization, + testValues.expected.dequantizationBefore, + testValues.expected.precisionAfterOperation, + testValues.expected.dequantizationAfter); + } else if (testValues.opset_version == 4) { + ngraph::op::v4::Interpolate::InterpolateAttrs interp4Attrs; + interp4Attrs.mode = testValues.interp4Attrs.mode; + interp4Attrs.coordinate_transformation_mode = testValues.interp4Attrs.coordinate_transformation_mode; + interp4Attrs.pads_begin = testValues.interp4Attrs.pads_begin; + interp4Attrs.pads_end = testValues.interp4Attrs.pads_end; + + actualFunction = ngraph::builder::subgraph::InterpolateFunction::getOriginal( + testValues.inputShape, + testValues.outputShape, + testValues.scalesShape, + interp4Attrs, + testValues.actual.precisionBeforeDequantization, + testValues.actual.dequantization); + + 
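            // Same flow as the opset1 branch above: run the low-precision pipeline on the
            // actual graph, then build the reference graph with the expected dequantization
            // placement (for v4 nearest mode the Subtract/Multiply are expected to move
            // after the Interpolate; in the non-transformed cases they stay in front of it).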
SimpleLowPrecisionTransformer transformer; + transformer.add(testValues.params); + transformer.transform(actualFunction); + + referenceFunction = ngraph::builder::subgraph::InterpolateFunction::getReference( + testValues.inputShape, + testValues.outputShape, + testValues.scalesShape, + interp4Attrs, + testValues.expected.precisionBeforeDequantization, + testValues.expected.dequantizationBefore, + testValues.expected.precisionAfterOperation, + testValues.expected.dequantizationAfter); + } } static std::string getTestCaseName(testing::TestParamInfo obj) { const InterpolateTransformationTestValues testValues = obj.param; std::ostringstream result; - result << + if (testValues.opset_version == 1) { + result << testValues.inputShape << "_" << testValues.outputShape << "_" << - testValues.interpAttrs.align_corners << - testValues.interpAttrs.antialias << - testValues.interpAttrs.axes << - testValues.interpAttrs.mode << - testValues.interpAttrs.pads_begin << - testValues.interpAttrs.pads_end << + testValues.opset_version << "_" << + testValues.interpAttrs.align_corners << "_" << + testValues.interpAttrs.antialias << "_" << + testValues.interpAttrs.axes << "_" << + testValues.interpAttrs.mode << "_" << + testValues.interpAttrs.pads_begin << "_" << + testValues.interpAttrs.pads_end << "_" << testValues.actual.precisionBeforeDequantization << "_" << testValues.actual.dequantization << "_" << testValues.expected.dequantizationBefore; + } else if (testValues.opset_version == 4) { + result << + testValues.inputShape << "_" << + testValues.outputShape << "_" << + testValues.opset_version << "_" << + testValues.interp4Attrs.mode << "_" << + testValues.interp4Attrs.coordinate_transformation_mode << "_" << + testValues.interp4Attrs.pads_begin << "_" << + testValues.interp4Attrs.pads_end << "_" << + testValues.actual.precisionBeforeDequantization << "_" << + testValues.actual.dequantization << "_" << + testValues.expected.dequantizationBefore; + } return result.str(); } }; const std::vector testValues { + // opset1 // nearest mode - move dequantization { ngraph::Shape{ 1, 4, 16, 16 }, ngraph::Shape{ 1, 4, 32, 32 }, + ngraph::Shape{}, LayerTransformation::createParamsU8I8(), interpAttributes( ngraph::AxisSet{2, 3}, @@ -147,6 +216,8 @@ const std::vector testValues { false, {0}, {0}), + interp4Attributes(), + 1, { ngraph::element::u8, {{ngraph::element::f32}, {-0.32f}, {0.1f}} @@ -163,6 +234,7 @@ const std::vector testValues { { ngraph::Shape{ 1, 4, 16, 16 }, ngraph::Shape{ 1, 4, 32, 32 }, + ngraph::Shape{}, LayerTransformation::createParamsU8I8(), interpAttributes( ngraph::AxisSet{2, 3}, @@ -171,6 +243,8 @@ const std::vector testValues { false, {0}, {0}), + interp4Attributes(), + 1, { ngraph::element::u8, {{ngraph::element::f32}, {-0.32f}, {0.1f}} @@ -187,6 +261,7 @@ const std::vector testValues { { ngraph::Shape{ 1, 4, 16, 16 }, ngraph::Shape{ 1, 8, 32, 32 }, + ngraph::Shape{}, LayerTransformation::createParamsU8I8(), interpAttributes( ngraph::AxisSet{1, 2, 3}, @@ -195,6 +270,8 @@ const std::vector testValues { false, {0}, {0}), + interp4Attributes(), + 1, { ngraph::element::u8, {{ngraph::element::f32}, {-0.32f}, {0.1f}} @@ -211,6 +288,7 @@ const std::vector testValues { { ngraph::Shape{ 1, 4, 16, 16 }, ngraph::Shape{ 1, 4, 32, 32 }, + ngraph::Shape{}, LayerTransformation::createParamsU8I8(), interpAttributes( ngraph::AxisSet{2, 3}, @@ -219,6 +297,8 @@ const std::vector testValues { false, {0}, {0}), + interp4Attributes(), + 1, { ngraph::element::u8, {{ngraph::element::f32}, {-0.32f}, {0.1f}} @@ -235,6 
+315,7 @@ const std::vector testValues { { ngraph::Shape{ 1, 4, 16, 16 }, ngraph::Shape{ 1, 4, 32, 32 }, + ngraph::Shape{}, LayerTransformation::createParamsU8I8(), interpAttributes( ngraph::AxisSet{2, 3}, @@ -243,6 +324,8 @@ const std::vector testValues { false, {1}, {1}), + interp4Attributes(), + 1, { ngraph::element::u8, {{ngraph::element::f32}, {-0.32f}, {0.1f}} @@ -253,7 +336,108 @@ const std::vector testValues { ngraph::element::u8, {{}, {}, {}} } - } + }, + + // v4::Interpolate + // nearest mode - move dequantization + { + ngraph::Shape{ 1, 4, 16, 16 }, + ngraph::Shape{ 1, 4, 32, 32 }, + ngraph::Shape{ 1, 1, 2, 2 }, + LayerTransformation::createParamsU8I8(), + interpAttributes(), + interp4Attributes( + ngraph::op::v4::Interpolate::InterpolateMode::nearest, + ngraph::op::v4::Interpolate::CoordinateTransformMode::half_pixel, + {0, 0, 0, 0}, + {0, 0, 0, 0}), + 4, + { + ngraph::element::i8, + {{ngraph::element::f32}, {-0.32f}, {0.1f}} + }, + { + ngraph::element::i8, + {{}, {}, {}}, + ngraph::element::i8, + {{ngraph::element::f32}, {-0.32f}, {0.1f}} + } + }, + + // mode is not nearest - not transformed + { + ngraph::Shape{ 1, 4, 16, 16 }, + ngraph::Shape{ 1, 4, 32, 32 }, + ngraph::Shape{ 1, 1, 2, 2 }, + LayerTransformation::createParamsU8I8(), + interpAttributes(), + interp4Attributes( + ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx, + ngraph::op::v4::Interpolate::CoordinateTransformMode::half_pixel, + {0, 0, 0, 0}, + {0, 0, 0, 0}), + 4, + { + ngraph::element::i8, + {{ngraph::element::f32}, {-0.32f}, {0.1f}} + }, + { + ngraph::element::i8, + {{ngraph::element::f32}, {-0.32f}, {0.1f}}, + ngraph::element::i8, + {{}, {}, {}} + } + }, + + // align_corners set to true - not transformed + { + ngraph::Shape{ 1, 4, 16, 16 }, + ngraph::Shape{ 1, 4, 32, 32 }, + ngraph::Shape{ 1, 1, 2, 2 }, + LayerTransformation::createParamsU8I8(), + interpAttributes(), + interp4Attributes( + ngraph::op::v4::Interpolate::InterpolateMode::nearest, + ngraph::op::v4::Interpolate::CoordinateTransformMode::align_corners, + {0, 0, 0, 0}, + {0, 0, 0, 0}), + 4, + { + ngraph::element::i8, + {{ngraph::element::f32}, {-0.32f}, {0.1f}} + }, + { + ngraph::element::i8, + {{ngraph::element::f32}, {-0.32f}, {0.1f}}, + ngraph::element::i8, + {{}, {}, {}} + } + }, + + // have pads - not transformed + { + ngraph::Shape{ 1, 4, 16, 16 }, + ngraph::Shape{ 1, 4, 32, 32 }, + ngraph::Shape{ 1, 1, 2, 2 }, + LayerTransformation::createParamsU8I8(), + interpAttributes(), + interp4Attributes( + ngraph::op::v4::Interpolate::InterpolateMode::nearest, + ngraph::op::v4::Interpolate::CoordinateTransformMode::half_pixel, + {0, 0, 0, 1}, + {0, 0, 1, 0}), + 4, + { + ngraph::element::i8, + {{ngraph::element::f32}, {-0.32f}, {0.1f}} + }, + { + ngraph::element::i8, + {{ngraph::element::f32}, {-0.32f}, {0.1f}}, + ngraph::element::i8, + {{}, {}, {}} + } + }, }; TEST_P(InterpolateTransformation, CompareFunctions) { diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp new file mode 100644 index 0000000..1ecd6ca --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp @@ -0,0 +1,112 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + 
+#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; +using namespace ngraph; + +TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_node = std::make_shared(element::f32, Shape{2, 4, 30, 30}); + auto out_shape_node = opset1::Constant::create(element::i32, Shape{4}, {2, 4, 40, 40}); + + auto interpolate1_attr = op::v0::InterpolateAttrs(); + interpolate1_attr.axes = AxisSet(std::vector{0, 1, 2, 3}); + interpolate1_attr.mode = "nearest"; + interpolate1_attr.align_corners = false; + interpolate1_attr.antialias = false; + interpolate1_attr.pads_begin = std::vector{0, 0, 0, 0}; + interpolate1_attr.pads_end = std::vector{0, 0, 0, 0}; + + auto interpolate1 = std::make_shared(data_node, out_shape_node, interpolate1_attr); + + f = std::make_shared(NodeVector{interpolate1}, ParameterVector{data_node}); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data_node = std::make_shared(element::f32, Shape{2, 4, 30, 30}); + auto out_shape_node = opset1::Constant::create(element::i32, Shape{4}, {2, 4, 40, 40}); + auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{4}, {1.f, 1.f, 1.f, 1.f}); + auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{4}, {0, 1, 2, 3}); + + auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest, + opset4::Interpolate::ShapeCalcMode::sizes, std::vector{0, 0, 0, 0}, std::vector{0, 0, 0, 0}, + opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::floor, + false, -0.75); + + auto interpolate4 = std::make_shared(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr); + + f_ref = std::make_shared(NodeVector{interpolate4}, ParameterVector{data_node}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertInterpolate1ToInterpolate4_1) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data_node = std::make_shared(element::f32, Shape{2, 4, 30, 30}); + auto out_shape_node = opset1::Constant::create(element::i32, Shape{2}, {40, 40}); + + auto interpolate1_attr = op::v0::InterpolateAttrs(); + interpolate1_attr.axes = AxisSet(std::vector{2, 3}); + interpolate1_attr.mode = "linear"; + interpolate1_attr.align_corners = false; + interpolate1_attr.antialias = true; + interpolate1_attr.pads_begin = std::vector{0, 0, 0, 0}; + interpolate1_attr.pads_end = std::vector{0, 0, 0, 0}; + + auto interpolate1 = std::make_shared(data_node, out_shape_node, interpolate1_attr); + + f = std::make_shared(NodeVector{interpolate1}, ParameterVector{data_node}); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data_node = std::make_shared(element::f32, Shape{2, 4, 30, 30}); + auto out_shape_node = opset1::Constant::create(element::i32, Shape{2}, {40, 40}); + auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {1.f, 1.f}); + auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3}); + + auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear_onnx, + opset4::Interpolate::ShapeCalcMode::sizes, std::vector{0, 0, 0, 0}, std::vector{0, 0, 0, 0}, + opset4::Interpolate::CoordinateTransformMode::align_corners, 
opset4::Interpolate::NearestMode::floor, + false, -0.75); + + auto interpolate4 = std::make_shared(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr); + + f_ref = std::make_shared(NodeVector{interpolate4}, ParameterVector{data_node}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/interpolate.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/interpolate.cpp index 75b79f1..a2fceb3 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/interpolate.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/interpolate.cpp @@ -17,11 +17,11 @@ const std::vector netPrecisions = { }; const std::vector> inShapes = { - {1, 1, 30, 30}, + {1, 4, 30, 30}, }; const std::vector> targetShapes = { - {1, 1, 40, 40}, + {1, 4, 40, 40}, }; const std::vector modesWithoutNearest = { @@ -130,4 +130,60 @@ INSTANTIATE_TEST_CASE_P(smoke_Interpolate_Nearest, InterpolateLayerTest, ::testi ::testing::Values(CommonTestUtils::DEVICE_CPU)), InterpolateLayerTest::getTestCaseName); +const std::vector> targetShapesTailTest = { + {1, 4, 10, 41}, // 10 * 41 is not multipler of 4, cover tail process code path +}; + +const std::vector> defaultScalesTailTest = { + {0.33333f, 1.36666f} +}; + +const auto interpolateCasesWithoutNearestTail = ::testing::Combine( + ::testing::ValuesIn(modesWithoutNearest), + ::testing::ValuesIn(shapeCalculationMode), + ::testing::ValuesIn(coordinateTransformModes), + ::testing::ValuesIn(defaultNearestMode), + ::testing::ValuesIn(antialias), + ::testing::ValuesIn(pads), + ::testing::ValuesIn(pads), + ::testing::ValuesIn(cubeCoefs), + ::testing::ValuesIn(defaultAxes), + ::testing::ValuesIn(defaultScalesTailTest)); + +const auto interpolateCasesTail = ::testing::Combine( + ::testing::ValuesIn(nearestMode), + ::testing::ValuesIn(shapeCalculationMode), + ::testing::ValuesIn(coordinateTransformModes), + ::testing::ValuesIn(nearestModes), + ::testing::ValuesIn(antialias), + ::testing::ValuesIn(pads), + ::testing::ValuesIn(pads), + ::testing::ValuesIn(cubeCoefs), + ::testing::ValuesIn(defaultAxes), + ::testing::ValuesIn(defaultScalesTailTest)); + +INSTANTIATE_TEST_CASE_P(smoke_Interpolate_Basic_2, InterpolateLayerTest, ::testing::Combine( + interpolateCasesWithoutNearestTail, + ::testing::ValuesIn(netPrecisions), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::ValuesIn(inShapes), + ::testing::ValuesIn(targetShapesTailTest), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + InterpolateLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_Interpolate_Nearest_2, InterpolateLayerTest, ::testing::Combine( + interpolateCasesTail, + ::testing::ValuesIn(netPrecisions), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::ValuesIn(inShapes), + ::testing::ValuesIn(targetShapesTailTest), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + InterpolateLayerTest::getTestCaseName); + } // namespace diff --git 
a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/interpolate.cpp b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/interpolate.cpp index 16ac595..3c3891c 100644 --- a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/interpolate.cpp +++ b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/interpolate.cpp @@ -149,6 +149,7 @@ const std::vector defNearestModes = { const std::vector> pads = { {0, 0, 0, 0}, + {0, 0, 1, 1}, }; const std::vector antialias = { @@ -191,6 +192,18 @@ const auto interpolateCasesLinearOnnx = ::testing::Combine( ::testing::ValuesIn(defaultAxes), ::testing::ValuesIn(defaultScales)); +const auto interpolateCasesCubic = ::testing::Combine( + ::testing::Values(ngraph::op::v4::Interpolate::InterpolateMode::cubic), + ::testing::ValuesIn(shapeCalculationMode), + ::testing::ValuesIn(coordinateTransformModes), + ::testing::ValuesIn(defNearestModes), + ::testing::ValuesIn(antialias), + ::testing::ValuesIn(pads), + ::testing::ValuesIn(pads), + ::testing::ValuesIn(cubeCoefs), + ::testing::ValuesIn(defaultAxes), + ::testing::ValuesIn(defaultScales)); + INSTANTIATE_TEST_CASE_P(smoke_InterpolateNN_Layout_Test, InterpolateLayerCPUTest, ::testing::Combine( ::testing::Combine( @@ -200,8 +213,8 @@ INSTANTIATE_TEST_CASE_P(smoke_InterpolateNN_Layout_Test, InterpolateLayerCPUTest ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), ::testing::Values(InferenceEngine::Layout::ANY), ::testing::Values(InferenceEngine::Layout::ANY), - ::testing::Values(std::vector({1, 1, 40, 40})), - ::testing::Values(std::vector({1, 1, 50, 60})), + ::testing::Values(std::vector({1, 21, 40, 40})), + ::testing::Values(std::vector({1, 21, 50, 60})), ::testing::Values(CommonTestUtils::DEVICE_CPU)), ::testing::ValuesIn(filterCPUInfoForDevice())), InterpolateLayerCPUTest::getTestCaseName); @@ -215,8 +228,23 @@ INSTANTIATE_TEST_CASE_P(smoke_InterpolateLinearOnnx_Layout_Test, InterpolateLaye ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), ::testing::Values(InferenceEngine::Layout::ANY), ::testing::Values(InferenceEngine::Layout::ANY), - ::testing::Values(std::vector({1, 1, 40, 40})), - ::testing::Values(std::vector({1, 1, 50, 60})), + ::testing::Values(std::vector({1, 21, 40, 40})), + ::testing::Values(std::vector({1, 21, 50, 60})), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ::testing::ValuesIn(filterCPUInfoForDevice())), + InterpolateLayerCPUTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_InterpolateCubic_Layout_Test, InterpolateLayerCPUTest, + ::testing::Combine( + ::testing::Combine( + interpolateCasesCubic, + ::testing::ValuesIn(netPrecisions), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(std::vector({1, 21, 40, 40})), + ::testing::Values(std::vector({1, 21, 50, 60})), ::testing::Values(CommonTestUtils::DEVICE_CPU)), ::testing::ValuesIn(filterCPUInfoForDevice())), InterpolateLayerCPUTest::getTestCaseName); diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/interpolate_function.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/interpolate_function.hpp index d3aa6cc..6b45a18 100644 --- a/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/interpolate_function.hpp +++ 
b/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/interpolate_function.hpp @@ -35,6 +35,32 @@ public: const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore, const ngraph::element::Type precisionAfterOperation, const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter); + + // v4::Interpolate + static std::shared_ptr getOriginal( + const ngraph::Shape& inputShape, + const ngraph::Shape& outputShape, + const ngraph::Shape& scalesShape, + const ngraph::op::v4::Interpolate::InterpolateAttrs& interp4Attrs, + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantization); + + static std::shared_ptr getOriginal( + const ngraph::element::Type precision, + const ngraph::Shape& inputShape, + const ngraph::Shape& outputShape, + const ngraph::Shape& scalesShape, + const ngraph::op::v4::Interpolate::InterpolateAttrs& interp4Attrs); + + static std::shared_ptr getReference( + const ngraph::Shape& inputShape, + const ngraph::Shape& outputShape, + const ngraph::Shape& scalesShape, + const ngraph::op::v4::Interpolate::InterpolateAttrs& interp4Attrs, + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore, + const ngraph::element::Type precisionAfterOperation, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter); }; } // namespace subgraph diff --git a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/interpolate_function.cpp b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/interpolate_function.cpp index a71bb90..44cccef 100644 --- a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/interpolate_function.cpp +++ b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/interpolate_function.cpp @@ -70,6 +70,72 @@ std::shared_ptr InterpolateFunction::getReference( return std::make_shared(results, ngraph::ParameterVector{ input }, "InterpolateFunction"); } +// v4:interpolate +std::shared_ptr InterpolateFunction::getOriginal( + const ngraph::Shape& inputShape, + const ngraph::Shape& outputShape, + const ngraph::Shape& scalesShape, + const ngraph::op::v4::Interpolate::InterpolateAttrs& interpAttrs, + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantization) { + const std::shared_ptr input = std::make_shared( + precisionBeforeDequantization, + ngraph::Shape(inputShape)); + + const auto dequantizationOp = makeDequantization(input, dequantization); + const auto outShape = std::make_shared(ngraph::element::i64, ngraph::Shape{ outputShape.size() }, outputShape); + const auto scales = std::make_shared(ngraph::element::f32, ngraph::Shape{ scalesShape.size() }, scalesShape); + const auto interpolate = std::make_shared(dequantizationOp, outShape, scales, interpAttrs); + interpolate->set_friendly_name("output"); + + ngraph::ResultVector results{ std::make_shared(interpolate) }; + return std::make_shared(results, ngraph::ParameterVector{ input }, "InterpolateFunction"); +} + +std::shared_ptr InterpolateFunction::getOriginal( + const ngraph::element::Type precision, + const ngraph::Shape& inputShape, + const ngraph::Shape& outputShape, + const ngraph::Shape& scalesShape, + const ngraph::op::v4::Interpolate::InterpolateAttrs& interpAttrs) { + float k = 50.f; + + const auto input = 
std::make_shared(precision, inputShape); + const auto fakeQuantizeOnActivations = ngraph::builder::makeFakeQuantize( + input, precision, 256ul, { 1ul }, + { 0.f }, { 255.f / k }, { 10.f }, { 255.f / k }); + const auto outShape = std::make_shared(ngraph::element::i64, ngraph::Shape{ outputShape.size() }, outputShape); + const auto scales = std::make_shared(ngraph::element::f32, ngraph::Shape{ scalesShape.size() }, scalesShape); + const auto interpolate = std::make_shared(fakeQuantizeOnActivations, outShape, scales, interpAttrs); + + ngraph::ResultVector results{ std::make_shared(interpolate) }; + return std::make_shared(results, ngraph::ParameterVector{ input }, "InterpolateFunction"); +} + +std::shared_ptr InterpolateFunction::getReference( + const ngraph::Shape& inputShape, + const ngraph::Shape& outputShape, + const ngraph::Shape& scalesShape, + const ngraph::op::v4::Interpolate::InterpolateAttrs& interpAttrs, + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore, + const ngraph::element::Type precisionAfterOperation, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) { + const std::shared_ptr input = std::make_shared( + precisionBeforeDequantization, + ngraph::Shape(inputShape)); + + const std::shared_ptr quantizationOpBefore = makeDequantization(input, dequantizationBefore); + const auto outShape = std::make_shared(ngraph::element::i64, ngraph::Shape{ outputShape.size() }, outputShape); + const auto scales = std::make_shared(ngraph::element::f32, ngraph::Shape{ scalesShape.size() }, scalesShape); + const auto interpolate = std::make_shared(quantizationOpBefore, outShape, scales, interpAttrs); + const std::shared_ptr quantizationOpAfter = makeDequantization(interpolate, dequantizationAfter); + quantizationOpAfter->set_friendly_name("output"); + + ngraph::ResultVector results{ std::make_shared(quantizationOpAfter) }; + return std::make_shared(results, ngraph::ParameterVector{ input }, "InterpolateFunction"); +} + } // namespace subgraph } // namespace builder } // namespace ngraph diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/interpolate.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/interpolate.hpp index 703ffea..738b3fb 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/interpolate.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/interpolate.hpp @@ -152,6 +152,10 @@ namespace ngraph float length_resized, float length_original) const { + if (x_scale == 1.0f || (length_resized == length_original)) + { + return x_resized; + } return m_func(x_resized, x_scale, length_resized, length_original); } -- 2.7.4
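As a quick reference, below is a condensed sketch of the v0-to-v4 attribute mapping performed by ConvertInterpolate1ToInterpolate4, reconstructed from the transformation source in this patch. The helper name mapInterpolateAttrs and the include path are illustrative assumptions, not part of the patch.

#include <ngraph/op/interpolate.hpp>

// Illustrative only: mirrors the attribute translation done inside the matcher callback.
ngraph::op::v4::Interpolate::InterpolateAttrs mapInterpolateAttrs(
        const ngraph::op::v0::InterpolateAttrs& v0_attrs, int64_t input_rank) {
    using I4 = ngraph::op::v4::Interpolate;

    // mode: "nearest" / "cubic" map directly; "linear" becomes linear_onnx below rank 5
    // and linear at rank 5; any other mode (or linear above rank 5) is left unconverted.
    I4::InterpolateMode mode = I4::InterpolateMode::nearest;
    if (v0_attrs.mode == "cubic") {
        mode = I4::InterpolateMode::cubic;
    } else if (v0_attrs.mode == "linear") {
        mode = (input_rank < 5) ? I4::InterpolateMode::linear_onnx : I4::InterpolateMode::linear;
    }

    return I4::InterpolateAttrs(
        mode,
        I4::ShapeCalcMode::sizes,                 // output size still comes from input 1
        v0_attrs.pads_begin,
        v0_attrs.pads_end,
        v0_attrs.align_corners ? I4::CoordinateTransformMode::align_corners
                               : I4::CoordinateTransformMode::asymmetric,
        I4::NearestMode::floor,
        v0_attrs.antialias,
        -0.75f);                                  // default cube_coeff used by the pass
}

// The created v4 node additionally takes two new constant inputs: per-axis scales filled
// with 1.0f and the v0 axes attribute converted to an i64 constant.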