[CPU] Interpolate operation improvements (#2366)
authorChenhu Wang <chenhu.wang@intel.com>
Tue, 17 Nov 2020 07:42:34 +0000 (15:42 +0800)
committerGitHub <noreply@github.com>
Tue, 17 Nov 2020 07:42:34 +0000 (10:42 +0300)
* interpolate improvement

* JITTED cubic mode

* fix 'code is too big' when JIT

* extend test to cover tail code path

* transformation of interpolate1 to interpolate4

* add low precision transformation for interpolate4

14 files changed:
inference-engine/src/low_precision_transformations/src/common/interpolate.cpp
inference-engine/src/low_precision_transformations/src/common/transformer.cpp
inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp
inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.h
inference-engine/src/transformations/include/transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/lp_transformations/interpolate_transformation.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp [new file with mode: 0644]
inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/interpolate.cpp
inference-engine/tests/functional/plugin/cpu/single_layer_tests/interpolate.cpp
inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/interpolate_function.hpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/interpolate_function.cpp
ngraph/core/reference/include/ngraph/runtime/reference/interpolate.hpp

index f0bef07..c28eab9 100644 (file)
@@ -20,6 +20,16 @@ void InterpolateTransformation::registerMatcherIn(GraphRewrite& pass, Transforma
         pass,
         context,
         make_op_pattern<opset1::Interpolate>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Constant>() }));
+    addPattern(
+        pass,
+        context,
+        make_op_pattern<opset4::Interpolate>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Constant>(),
+            make_op_label<opset1::Constant>(), make_op_label<opset1::Constant>() }));
+    addPattern(
+        pass,
+        context,
+        make_op_pattern<opset4::Interpolate>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Constant>(),
+            make_op_label<opset1::Constant>() }));
 }
 
 bool InterpolateTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
@@ -33,9 +43,19 @@ bool InterpolateTransformation::transform(TransformationContext &context, ngraph
 }
 
 bool InterpolateTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
-    std::shared_ptr<opset1::Interpolate> interpolate = as_type_ptr<opset1::Interpolate>(layer);
-    const auto attrs = interpolate->get_attrs();
-    return attrs.mode == "nearest";
+    std::shared_ptr<opset1::Interpolate> interpolate1 = as_type_ptr<opset1::Interpolate>(layer);
+    if (interpolate1) {
+        const auto attrs = interpolate1->get_attrs();
+        return attrs.mode == "nearest";
+    }
+
+    std::shared_ptr<opset4::Interpolate> interpolate4 = as_type_ptr<opset4::Interpolate>(layer);
+    if (interpolate4) {
+        const auto attrs = interpolate4->get_attrs();
+        return attrs.mode == op::v4::Interpolate::InterpolateMode::nearest;
+    }
+
+    return false;
 }
 
 bool InterpolateTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
@@ -49,19 +69,46 @@ bool InterpolateTransformation::canBeTransformed(const TransformationContext& co
     if (dequantization.empty()) {
         return false;
     }
-    const auto interpolate = as_type_ptr<opset1::Interpolate>(layer);
-    const auto interpAttrs = interpolate->get_attrs();
 
-    if (interpAttrs.axes.count(0) || interpAttrs.axes.count(1)) {
-        return false;
+    const auto interpolate1 = as_type_ptr<opset1::Interpolate>(layer);
+    if (interpolate1) {
+        const auto interpAttrs = interpolate1->get_attrs();
+        if (interpAttrs.axes.count(0) || interpAttrs.axes.count(1)) {
+            return false;
+        }
+        if (interpAttrs.mode != "nearest") {
+            return false;
+        }
+        if (interpAttrs.pads_begin[0] != 0 || interpAttrs.pads_end[0] != 0 || interpAttrs.align_corners) {
+            return false;
+        }
     }
 
-    if (interpAttrs.mode != "nearest") {
-        return false;
-    }
+    const auto interpolate4 = as_type_ptr<opset4::Interpolate>(layer);
+    if (interpolate4) {
+        const auto interpAttrs = interpolate4->get_attrs();
 
-    if (interpAttrs.pads_begin[0] != 0 || interpAttrs.pads_end[0] != 0 || interpAttrs.align_corners) {
-        return false;
+        if (interpAttrs.mode != op::v4::Interpolate::InterpolateMode::nearest) {
+            return false;
+        }
+
+        auto pads_begin = interpAttrs.pads_begin;
+        for (int i = 0; i < pads_begin.size(); ++i) {
+            if (pads_begin[i] != 0) {
+                return false;
+            }
+        }
+
+        auto pads_end = interpAttrs.pads_end;
+        for (int i = 0; i < pads_end.size(); ++i) {
+            if (pads_end[i] != 0) {
+                return false;
+            }
+        }
+
+        if (interpAttrs.coordinate_transformation_mode == op::v4::Interpolate::CoordinateTransformMode::align_corners) {
+            return false;
+        }
     }
 
     return true;
index e152cfe..66ae78c 100644 (file)
@@ -242,6 +242,7 @@ LowPrecisionTransformations LowPrecisionTransformer::getAllTransformations(const
         add<SqueezeTransformation, opset1::Squeeze>(params).
         add<TransposeTransformation, opset1::Transpose>(params).
         add<UnsqueezeTransformation, opset1::Unsqueeze>(params).
+        add<InterpolateTransformation, opset4::Interpolate>(params).
 
         addCleanup<FuseConvertTransformation, opset1::Multiply>(params).
 
@@ -341,6 +342,7 @@ TypeRelaxedReplacer::TypeRelaxedReplacer() {
     make_matcher_type_relaxed<opset1::Multiply>(this);
     make_matcher_type_relaxed<op::MVN>(this);
     make_matcher_type_relaxed<opset1::NormalizeL2>(this);
+    make_matcher_type_relaxed<opset4::Interpolate>(this);
 }
 
 LowPrecisionTransformer::LowPrecisionTransformer(const LowPrecisionTransformations& transformations)
index 799f2da..a2be1f5 100644 (file)
@@ -61,8 +61,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 
         this->preamble();
 
-        mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
-        mov(reg_dst, ptr[reg_params + GET_OFF(dst)]);
         if (attr_.post_ops_.len_ != 0)
             mov(reg_oc_off, ptr[reg_params + GET_OFF(oc_off)]);
         if (isa == cpu::avx512_common)
@@ -70,8 +68,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 
         switch (jcp_.mode) {
             case InterpolateMode::nearest: {
-                mov(reg_src, ptr[reg_params + GET_OFF(src)]);
+                mov(reg_dst, ptr[reg_params + GET_OFF(dst)]);
+                mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]);
                 mov(reg_index, ptr[reg_params + GET_OFF(index)]);
+                mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
 
                 switch (jcp_.layout) {
                     case InterpolateLayoutType::planar: {
@@ -94,28 +94,11 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
             case InterpolateMode::linear_onnx: {
                 switch (jcp_.layout) {
                     case InterpolateLayoutType::planar: {
-                        mov(reg_src, ptr[reg_params + GET_OFF(src)]);
-                        mov(reg_index, ptr[reg_params + GET_OFF(index)]);
-                        mov(reg_src_aux, ptr[reg_params + GET_OFF(weight)]);
-
                         linear_onnx_planar();
                         break;
                     }
                     case InterpolateLayoutType::block:
                     case InterpolateLayoutType::by_channel: {
-                        mov(reg_src, ptr[reg_params + GET_OFF(weight)]);
-                        mov(reg_src_aux, ptr[reg_params + GET_OFF(weightR)]);
-                        mov(reg_src_aux1, ptr[reg_params + GET_OFF(weightT)]);
-                        mov(reg_src_aux2, ptr[reg_params + GET_OFF(weightB)]);
-                        uni_vbroadcastss(vmm_weightL, ptr[reg_src]);
-                        uni_vbroadcastss(vmm_weightR, ptr[reg_src_aux]);
-                        uni_vbroadcastss(vmm_weightT, ptr[reg_src_aux1]);
-                        uni_vbroadcastss(vmm_weightB, ptr[reg_src_aux2]);
-                        mov(reg_src, ptr[reg_params + GET_OFF(src)]);
-                        mov(reg_src_aux, ptr[reg_params + GET_OFF(srcTR)]);
-                        mov(reg_src_aux1, ptr[reg_params + GET_OFF(srcBL)]);
-                        mov(reg_src_aux2, ptr[reg_params + GET_OFF(srcBR)]);
-
                         linear_onnx_c_gathered();
                         break;
                     }
@@ -124,8 +107,23 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
                 }
                 break;
             }
-            case InterpolateMode::linear:
             case InterpolateMode::cubic: {
+                switch (jcp_.layout) {
+                    case InterpolateLayoutType::planar: {
+                        cubic_planar();
+                        break;
+                    }
+                    case InterpolateLayoutType::block:
+                    case InterpolateLayoutType::by_channel: {
+                        cubic_c_gathered();
+                        break;
+                    }
+                    default:
+                        assert(!"unsupported memory layout for interpolate layer with cubic mode.");
+                }
+                break;
+            }
+            case InterpolateMode::linear: {
                 assert(!"unsupported mode for interpolate layer with JITTED implimentation.");
                 break;
             }
@@ -138,6 +136,9 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 
         for (auto& inj : eltwise_injectors)
             inj->prepare_table();
+        if ((jcp_.mode == InterpolateMode::cubic) && (jcp_.layout == InterpolateLayoutType::planar)) {
+            prepare_cubic_planar_table();
+        }
 
         ker_ = (decltype(ker_)) this->getCode();
     }
@@ -164,7 +165,12 @@ private:
     Xbyak::Reg64 reg_oc_off = rax;
     Xbyak::Reg64 reg_d_weights = rbx;
     Xbyak::Reg64 reg_d_bias = rcx;
-    Xbyak::Reg32 reg_index_oc = edx;
+    Xbyak::Reg32 reg_index_offset = edx;
+
+    // for cubic planar
+    Xbyak::Reg64 reg_tbl_y = rsi;
+    Xbyak::Reg64 reg_tbl_x = rbp;
+    Xbyak::Reg64 reg_table = rdx;   // do not need reg_index_offset in this mode, so use rdx
 
     Vmm vmm_val = Vmm(0);
     Xmm xmm_val = Xmm(0);
@@ -174,6 +180,7 @@ private:
     Vmm vmm_d_weights = Vmm(4);
     Vmm vmm_d_bias = Vmm(5);
 
+    // for linear
     Vmm vmm_weightT = Vmm(15);
     Xmm xmm_weightT = Xmm(15);
     Vmm vmm_weightB = Vmm(14);
@@ -191,6 +198,32 @@ private:
     Vmm vmm_valBR = Vmm(8);
     Xmm xmm_valBR = Xmm(8);
 
+    // for cubic
+    Vmm vmm_src = Vmm(6);
+    Xmm xmm_src = Xmm(6);
+    Vmm vmm_dstX = Vmm(7);
+
+    Vmm vmm_weightX0 = vmm_weightT;
+    Vmm vmm_weightX1 = vmm_weightB;
+    Vmm vmm_weightX2 = vmm_weightL;
+    Vmm vmm_weightX3 = vmm_weightR;
+    Vmm vmm_weightY0 = vmm_valTL;
+    Vmm vmm_weightY1 = Vmm(10);  // vmm_valTR is vmm_val, need reserved
+    Vmm vmm_weightY2 = vmm_valBL;
+    Vmm vmm_weightY3 = vmm_valBR;
+    // cubic planar
+    Vmm vmm_one = vmm_index;
+    Vmm vmm_weightY = vmm_weightY0;
+    Vmm vmm_index_y_itr = vmm_weightY1;
+    Vmm vmm_index_x_itr = vmm_weightY2;
+    Vmm vmm_tbl_y = vmm_weightY3;
+    // temporally used. when post ops, value in vmm_d_weights and vmm_d_bias is re-loaded(init) each time.
+    Vmm vmm_index_in_y = vmm_d_weights;
+    Vmm vmm_index_in_x = vmm_d_bias;
+
+    Xbyak::Label l_table_constant;
+    Opmask k_mask = Xbyak::Opmask(1);
+
     std::vector<std::shared_ptr<jit_uni_eltwise_injector_f32<isa>>> eltwise_injectors;
     std::vector<std::shared_ptr<jit_uni_depthwise_injector_f32<isa>>> depthwise_injectors;
     std::vector<std::shared_ptr<jit_uni_quantization_injector_f32<isa>>> quantization_injectors;
@@ -221,8 +254,8 @@ private:
             Xbyak::Reg64 reg_src_h = rsi;
             mov(reg_src_h, reg_src);
             // index_h * IW * dataSize done when built to avoid redundent compute
-            mov(reg_index_oc, dword[reg_index_h]);
-            add(reg_src_h, reg_index_oc);  // reg_src_h now point to begin of row
+            mov(reg_index_offset, dword[reg_index_h]);
+            add(reg_src_h, reg_index_offset);  // reg_src_h now point to begin of row
 
             // reset index_w, index_w * dataSize done when built to avoid redundent compute
             mov(reg_index, reg_index_w);
@@ -260,8 +293,8 @@ private:
                 jl(nn_tail_loop_end_label, T_NEAR);
 
                 mov(reg_src_aux, reg_src_h);
-                mov(reg_index_oc, dword[reg_index]);
-                add(reg_src_aux, reg_index_oc);
+                mov(reg_index_offset, dword[reg_index]);
+                add(reg_src_aux, reg_index_offset);
 
                 load_scalar(xmm_val, ptr[reg_src_aux], jcp_.src_dt);
                 if (attr_.post_ops_.len_ != 0)
@@ -298,8 +331,8 @@ private:
             jle(nn_loop_end_label, T_NEAR);
 
             mov(reg_src_aux, reg_src);
-            mov(reg_index_oc, dword[reg_index]);
-            add(reg_src_aux, reg_index_oc);
+            mov(reg_index_offset, dword[reg_index]);
+            add(reg_src_aux, reg_index_offset);
 
             load_vector(vmm_val, ptr[reg_src_aux], jcp_.src_dt);
             if (attr_.post_ops_.len_ != 0)
@@ -353,8 +386,8 @@ private:
             // dst and index address is continous, advanced each interator.
             mov(reg_src_aux, reg_src);
             // index*C*dataSize done when built to avoid redundent compute
-            mov(reg_index_oc, dword[reg_index]);
-            add(reg_src_aux, reg_index_oc);
+            mov(reg_index_offset, dword[reg_index]);
+            add(reg_src_aux, reg_index_offset);
 
             mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
             if (attr_.post_ops_.len_ != 0)
@@ -409,11 +442,30 @@ private:
     }
 
     void linear_onnx_c_gathered() {
+        mov(reg_dst, ptr[reg_params + GET_OFF(dst)]);
+
+        mov(reg_src, ptr[reg_params + GET_OFF(weight_ptr[0])]);
+        mov(reg_src_aux, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]);
+        mov(reg_src_aux1, ptr[reg_params + GET_OFF(weight_ptr[0]) + 2 * sizeof(size_t)]);
+        mov(reg_src_aux2, ptr[reg_params + GET_OFF(weight_ptr[0]) + 3 * sizeof(size_t)]);
+        uni_vbroadcastss(vmm_weightL, ptr[reg_src]);
+        uni_vbroadcastss(vmm_weightR, ptr[reg_src_aux]);
+        uni_vbroadcastss(vmm_weightT, ptr[reg_src_aux1]);
+        uni_vbroadcastss(vmm_weightB, ptr[reg_src_aux2]);
+
+        mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]);
+        mov(reg_src_aux, ptr[reg_params + GET_OFF(src_ptr[0]) + sizeof(size_t)]);
+        mov(reg_src_aux1, ptr[reg_params + GET_OFF(src_ptr[0]) + 2 * sizeof(size_t)]);
+        mov(reg_src_aux2, ptr[reg_params + GET_OFF(src_ptr[0]) + 3 * sizeof(size_t)]);
+        mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
+
         int step = vlen / sizeof(float);
         int blk = (isa == cpu::sse42) ? (2 * step) : step;
 
         Xbyak::Label main_loop_label;
         Xbyak::Label main_loop_end_label;
+        Xbyak::Label blk_tail_loop_label;
+        Xbyak::Label blk_tail_loop_end_label;
         Xbyak::Label tail_loop_label;
         Xbyak::Label tail_loop_end_label;
         L(main_loop_label);
@@ -431,14 +483,7 @@ private:
             load_vector(vmm_valBL, ptr[reg_src_aux1], jcp_.src_dt);
             load_vector(vmm_valBR, ptr[reg_src_aux2], jcp_.src_dt);
 
-            // weightT * (srcTL * weightL + srcTR * weightR) +
-            // weightB * (srcBL * weightL + srcBR * weightR);
-            uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR);
-            uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR);
-            uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL);
-            uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL);
-            uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT);
-            uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB);
+            linear_onnx_worker();
 
             if (attr_.post_ops_.len_ != 0) {
                 apply_post_ops(jcp_.dst_dt, false);  // vmm_val is vmm_valTR
@@ -453,12 +498,7 @@ private:
                 load_vector(vmm_valBL, ptr[reg_src_aux1 + sse42_offset * jcp_.src_data_size], jcp_.src_dt);
                 load_vector(vmm_valBR, ptr[reg_src_aux2 + sse42_offset * jcp_.src_data_size], jcp_.src_dt);
 
-                uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR);
-                uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR);
-                uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL);
-                uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL);
-                uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT);
-                uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB);
+                linear_onnx_worker();
 
                 if (attr_.post_ops_.len_ != 0) {
                     apply_post_ops(jcp_.dst_dt, false);
@@ -477,7 +517,7 @@ private:
                 sub(reg_work_amount, step);    // work_amount is c
             } else {
                 int dst_stride = blk * jcp_.OW * jcp_.OH * jcp_.dst_data_size;
-                int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size;;
+                int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size;
                 add(reg_dst, dst_stride);
                 add(reg_src, src_stride);
                 add(reg_src_aux, src_stride);
@@ -490,6 +530,37 @@ private:
         }
         L(main_loop_end_label);
 
+        step = 4;
+        L(blk_tail_loop_label);
+        {
+            cmp(reg_work_amount, step);
+            jl(blk_tail_loop_end_label, T_NEAR);
+
+            // use xmm for 4s in tails
+            load_xmm(xmm_valTL, ptr[reg_src], jcp_.src_dt);
+            load_xmm(xmm_valTR, ptr[reg_src_aux], jcp_.src_dt);
+            load_xmm(xmm_valBL, ptr[reg_src_aux1], jcp_.src_dt);
+            load_xmm(xmm_valBR, ptr[reg_src_aux2], jcp_.src_dt);
+
+            linear_onnx_worker();
+
+            if (attr_.post_ops_.len_ != 0) {
+                apply_post_ops(jcp_.dst_dt, false);  // vmm_val is vmm_valTR
+                add(reg_oc_off, step * sizeof(float));
+            }
+            store_xmm(ptr[reg_dst], xmm_valTR, jcp_.dst_dt);
+
+            add(reg_dst, step * jcp_.dst_data_size);
+            add(reg_src, step * jcp_.src_data_size);
+            add(reg_src_aux, step * jcp_.src_data_size);
+            add(reg_src_aux1, step * jcp_.src_data_size);
+            add(reg_src_aux2, step * jcp_.src_data_size);
+            sub(reg_work_amount, step);
+
+            jmp(blk_tail_loop_label, T_NEAR);
+        }
+        L(blk_tail_loop_end_label);
+
         step = 1;
         L(tail_loop_label);
         {
@@ -502,12 +573,7 @@ private:
             load_scalar(xmm_valBL, ptr[reg_src_aux1], jcp_.src_dt);
             load_scalar(xmm_valBR, ptr[reg_src_aux2], jcp_.src_dt);
 
-            uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR);
-            uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR);
-            uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL);
-            uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL);
-            uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT);
-            uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB);
+            linear_onnx_worker();
 
             if (attr_.post_ops_.len_ != 0) {
                 apply_post_ops(jcp_.dst_dt, false);  // vmm_val is vmm_valTR
@@ -528,6 +594,12 @@ private:
     }
 
     void linear_onnx_planar() {
+        mov(reg_dst, ptr[reg_params + GET_OFF(dst)]);
+        mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]);
+        mov(reg_index, ptr[reg_params + GET_OFF(index)]);
+        mov(reg_src_aux, ptr[reg_params + GET_OFF(weight_ptr[0])]);
+        mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
+
         int step = vlen / sizeof(float);
         int index_stride = jcp_.OW * jcp_.OH * jcp_.indices_size;
         int weight_stride = jcp_.OW * jcp_.OH * sizeof(float);
@@ -563,14 +635,7 @@ private:
             load_vector(vmm_weightT, ptr[reg_src_aux + 2 * weight_stride], memory::f32);
             load_vector(vmm_weightB, ptr[reg_src_aux + 3 * weight_stride], memory::f32);
 
-            // weightT * (srcTL * weightL + srcTR * weightR) +
-            // weightB * (srcBL * weightL + srcBR * weightR);
-            uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR);
-            uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR);
-            uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL);
-            uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL);
-            uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT);
-            uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB);
+            linear_onnx_worker();
 
             if (attr_.post_ops_.len_ != 0) {
                 apply_post_ops(jcp_.dst_dt, true);  // vmm_val is vmm_valTR, broadcase is true
@@ -594,23 +659,23 @@ private:
 
             // still use xmm on avx2/avx512
             mov(reg_src_aux1, reg_src);
-            mov(reg_index_oc, dword[reg_index]);
-            add(reg_src_aux1, reg_index_oc);
+            mov(reg_index_offset, dword[reg_index]);
+            add(reg_src_aux1, reg_index_offset);
             load_scalar(xmm_valTL, ptr[reg_src_aux1], jcp_.src_dt);
 
             mov(reg_src_aux1, reg_src);
-            mov(reg_index_oc, dword[reg_index + index_stride]);
-            add(reg_src_aux1, reg_index_oc);
+            mov(reg_index_offset, dword[reg_index + index_stride]);
+            add(reg_src_aux1, reg_index_offset);
             load_scalar(xmm_valTR, ptr[reg_src_aux1], jcp_.src_dt);
 
             mov(reg_src_aux1, reg_src);
-            mov(reg_index_oc, dword[reg_index + 2 * index_stride]);
-            add(reg_src_aux1, reg_index_oc);
+            mov(reg_index_offset, dword[reg_index + 2 * index_stride]);
+            add(reg_src_aux1, reg_index_offset);
             load_scalar(xmm_valBL, ptr[reg_src_aux1], jcp_.src_dt);
 
             mov(reg_src_aux1, reg_src);
-            mov(reg_index_oc, dword[reg_index + 3 * index_stride]);
-            add(reg_src_aux1, reg_index_oc);
+            mov(reg_index_offset, dword[reg_index + 3 * index_stride]);
+            add(reg_src_aux1, reg_index_offset);
             load_scalar(xmm_valBR, ptr[reg_src_aux1], jcp_.src_dt);
 
             load_scalar(xmm_weightL, ptr[reg_src_aux], memory::f32);
@@ -618,12 +683,7 @@ private:
             load_scalar(xmm_weightT, ptr[reg_src_aux + 2 * weight_stride], memory::f32);
             load_scalar(xmm_weightB, ptr[reg_src_aux + 3 * weight_stride], memory::f32);
 
-            uni_vmulps(xmm_valTR, xmm_valTR, xmm_weightR);
-            uni_vmulps(xmm_valBR, xmm_valBR, xmm_weightR);
-            uni_vfmadd231ps(xmm_valTR, xmm_valTL, xmm_weightL);
-            uni_vfmadd231ps(xmm_valBR, xmm_valBL, xmm_weightL);
-            uni_vmulps(xmm_valTR, xmm_valTR, xmm_weightT);
-            uni_vfmadd231ps(xmm_valTR, xmm_valBR, xmm_weightB);
+            linear_onnx_worker();
 
             if (attr_.post_ops_.len_ != 0) {
                 apply_post_ops(jcp_.dst_dt, true);  // process on vmm_val, vmm_val is vmm_valTR, and bc
@@ -640,6 +700,473 @@ private:
         L(tail_loop_end_label);
     }
 
+    // weightT * (srcTL * weightL + srcTR * weightR) +
+    // weightB * (srcBL * weightL + srcBR * weightR)
+    inline void linear_onnx_worker() {
+        uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightR);
+        uni_vmulps(vmm_valBR, vmm_valBR, vmm_weightR);
+        uni_vfmadd231ps(vmm_valTR, vmm_valTL, vmm_weightL);
+        uni_vfmadd231ps(vmm_valBR, vmm_valBL, vmm_weightL);
+        uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightT);
+        uni_vfmadd231ps(vmm_valTR, vmm_valBR, vmm_weightB);
+    }
+
+    void cubic_c_gathered() {
+        mov(reg_dst, ptr[reg_params + GET_OFF(dst)]);
+        mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]);
+        mov(reg_index, ptr[reg_params + GET_OFF(index)]);
+        mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
+
+        // weight_ptr[0] point to weightX
+        mov(reg_src_aux1, ptr[reg_params + GET_OFF(weight_ptr[0])]);
+        uni_vbroadcastss(vmm_weightX0, ptr[reg_src_aux1]);
+        uni_vbroadcastss(vmm_weightX1, ptr[reg_src_aux1 + 1 * sizeof(float)]);
+        uni_vbroadcastss(vmm_weightX2, ptr[reg_src_aux1 + 2 * sizeof(float)]);
+        uni_vbroadcastss(vmm_weightX3, ptr[reg_src_aux1 + 3 * sizeof(float)]);
+
+        // weight_ptr[1] point to weightY
+        mov(reg_src_aux1, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]);
+        uni_vbroadcastss(vmm_weightY0, ptr[reg_src_aux1]);
+        uni_vbroadcastss(vmm_weightY1, ptr[reg_src_aux1 + 1 * sizeof(float)]);
+        uni_vbroadcastss(vmm_weightY2, ptr[reg_src_aux1 + 2 * sizeof(float)]);
+        uni_vbroadcastss(vmm_weightY3, ptr[reg_src_aux1 + 3 * sizeof(float)]);
+
+        int step = vlen / sizeof(float);
+        int blk = (isa == cpu::sse42) ? (2 * step) : step;
+
+        Xbyak::Label main_loop_label;
+        Xbyak::Label main_loop_end_label;
+        Xbyak::Label tail_loop_label;
+        Xbyak::Label tail_loop_end_label;
+        L(main_loop_label);
+        {
+            if (jcp_.layout == InterpolateLayoutType::by_channel) {
+                cmp(reg_work_amount, step);
+                jl(main_loop_end_label, T_NEAR);
+            } else {
+                cmp(reg_work_amount, 1);
+                jl(tail_loop_end_label, T_NEAR);
+            }
+
+            uni_vpxor(vmm_val, vmm_val, vmm_val);
+
+            cubic_c_gathered_matrix(false);
+
+            if (attr_.post_ops_.len_ != 0) {
+                apply_post_ops(jcp_.dst_dt, false);     // vmm_val is default dst value to post_ops and store
+                add(reg_oc_off, step * sizeof(float));
+            }
+            store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt);
+
+            if ((isa == cpu::sse42) && (jcp_.layout == InterpolateLayoutType::block)) {
+                int sse42_offset = 4;  // vmm is xmm here
+                add(reg_src, sse42_offset * jcp_.src_data_size);
+                add(reg_dst, sse42_offset * jcp_.dst_data_size);
+
+                uni_vpxor(vmm_val, vmm_val, vmm_val);
+
+                cubic_c_gathered_matrix(false);
+
+                if (attr_.post_ops_.len_ != 0) {
+                    apply_post_ops(jcp_.dst_dt, false);
+                    add(reg_oc_off, step * sizeof(float));  // second step for one blk
+                }
+                store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt);
+
+                sub(reg_src, sse42_offset * jcp_.src_data_size);
+                sub(reg_dst, sse42_offset * jcp_.dst_data_size);
+            }
+            if (jcp_.layout == InterpolateLayoutType::by_channel) {
+                int dst_stride = step * jcp_.dst_data_size;
+                int src_stride = step * jcp_.src_data_size;
+                add(reg_dst, dst_stride);
+                add(reg_src, src_stride);
+                sub(reg_work_amount, step);    // work_amount is c
+            } else {
+                int dst_stride = blk * jcp_.OW * jcp_.OH * jcp_.dst_data_size;
+                int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size;
+                add(reg_dst, dst_stride);
+                add(reg_src, src_stride);
+                sub(reg_work_amount, 1);  // work_amount = div_up(c, blk), no tails
+            }
+
+            jmp(main_loop_label, T_NEAR);
+        }
+        L(main_loop_end_label);
+
+        // only for by_channel layout for tails.
+        step = 1;
+        L(tail_loop_label);
+        {
+            cmp(reg_work_amount, 1);
+            jl(tail_loop_end_label, T_NEAR);
+
+            // store final computed value
+            uni_vpxor(vmm_val, vmm_val, vmm_val);
+
+            cubic_c_gathered_matrix(true);
+
+            if (attr_.post_ops_.len_ != 0) {
+                apply_post_ops(jcp_.dst_dt, false);     // vmm_val is default dst value
+                add(reg_oc_off, step * sizeof(float));
+            }
+            store_scalar(ptr[reg_dst], xmm_val, jcp_.dst_dt);
+
+            int dst_stride = step * jcp_.dst_data_size;
+            int src_stride = step * jcp_.src_data_size;
+            add(reg_dst, dst_stride);
+            add(reg_src, src_stride);
+            sub(reg_work_amount, step);    // work_amount is c
+
+            jmp(tail_loop_label, T_NEAR);
+        }
+        L(tail_loop_end_label);
+    }
+
+    inline void cubic_c_gathered_matrix(bool is_scalar) {
+        // y0:  (x0 * weightX0 + x1 * weightX1 + x2 * weightX2 + x3 * weightX3) * weightY0
+        cubic_c_gathered_line(0, vmm_weightY0, is_scalar);
+        // y1
+        cubic_c_gathered_line(4, vmm_weightY1, is_scalar);
+        // y2
+        cubic_c_gathered_line(8, vmm_weightY2, is_scalar);
+        // y3
+        cubic_c_gathered_line(12, vmm_weightY3, is_scalar);
+    }
+
+    inline void cubic_c_gathered_line(int index_start, Vmm vmm_weight, bool is_scalar) {
+        uni_vpxor(vmm_dstX, vmm_dstX, vmm_dstX);
+        cubic_c_gathered_pixel(index_start, vmm_weightX0, is_scalar);
+        cubic_c_gathered_pixel(index_start + 1, vmm_weightX1, is_scalar);
+        cubic_c_gathered_pixel(index_start + 2, vmm_weightX2, is_scalar);
+        cubic_c_gathered_pixel(index_start + 3, vmm_weightX3, is_scalar);
+        uni_vfmadd231ps(vmm_val, vmm_dstX, vmm_weight);
+    }
+
+    inline void cubic_c_gathered_pixel(int i, Vmm vmm_weight, bool is_scalar) {
+        mov(reg_src_aux, reg_src);
+        mov(reg_index_offset, dword[reg_index + i * jcp_.indices_size]);
+        add(reg_src_aux, reg_index_offset);
+        if (!is_scalar) {
+            load_vector(vmm_src, ptr[reg_src_aux], jcp_.src_dt);
+        } else {
+            load_scalar(xmm_src, ptr[reg_src_aux], jcp_.src_dt);
+        }
+        uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weight);
+    }
+
+    void cubic_planar() {
+        mov(reg_table, l_table_constant);
+        // src_ptr[2] for oh sequence, src_ptr[3] for ow sequence
+        mov(reg_tbl_y, ptr[reg_params + GET_OFF(src_ptr[0]) + 2 * sizeof(size_t)]);
+        mov(reg_tbl_x, ptr[reg_params + GET_OFF(src_ptr[0]) + 3 * sizeof(size_t)]);
+        uni_vmovdqu(vmm_one, cubic_planar_table_val(0));
+        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
+
+        mov(reg_dst, ptr[reg_params + GET_OFF(dst)]);
+        mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]);
+        // index_OW
+        mov(reg_index, ptr[reg_params + GET_OFF(index)]);
+        // index_OH from src_ptr[1]
+        Xbyak::Reg64 reg_index_y = reg_src_aux;
+        mov(reg_index_y, ptr[reg_params + GET_OFF(src_ptr[0]) + sizeof(size_t)]);
+        // weight_OW
+        Xbyak::Reg64 reg_weight_x = reg_src_aux1;
+        mov(reg_weight_x, ptr[reg_params + GET_OFF(weight_ptr[0])]);
+        // weight_OH
+        Xbyak::Reg64 reg_weight_y = reg_src_aux2;
+        mov(reg_weight_y, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]);
+        mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
+
+        int step = vlen / sizeof(float);
+        int grid_len = 4;
+
+        // 0   1   2   3   4   5   6   7   8   9   10   11   12   13   14   15   16   17   18   19
+        // 20  21  22  23  24  25  26  27  28  29  30   31   32   33   34   35   36   37   38   39
+        // for 3th step(8): 16  17  18  19  20  21  22  23
+        //               y: 0   0   0   0   1   1   1   1
+        //               x: 16  17  18  19  0   1   2   3
+
+        Xbyak::Label main_loop_label;
+        Xbyak::Label main_loop_end_label;
+        Xbyak::Label tail_loop_label;
+        Xbyak::Label tail_loop_end_label;
+        L(main_loop_label);
+        {
+            cmp(reg_work_amount, step);
+            jl(main_loop_end_label, T_NEAR);
+
+            // vmm_tbl_y: (0 0 0 0 1 1 1 1 * index_size) --> (0 0 0 0 4 4 4 4)
+            uni_vmovdqu(vmm_tbl_y, ptr[reg_tbl_y]);
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            // vmm_index_in_y: 0 0 0 0 2 2 2 2
+            vpgatherdd(vmm_index_in_y, ptr[reg_index_y + vmm_tbl_y], vmm_mask);
+
+            // use vmm_val temporally for value in reg_tbl_x: 16  17  18  19  0   1   2   3
+            uni_vmovdqu(vmm_val, ptr[reg_tbl_x]);
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            // e.g. vmm_index_in_x: 32 34 36 38 0 2 4 6, now save src index.
+            vpgatherdd(vmm_index_in_x, ptr[reg_index + vmm_val], vmm_mask);
+
+            // build weightX used in y0-y3
+            // weight format: w0_0 w1_0 w2_0 w3_0 w0_1 w1_1 w2_1 w3_1 ...
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            vgatherdps(vmm_weightX0, ptr[reg_weight_x + vmm_val * grid_len], vmm_mask);  // 4 in vmm_val for weight_size, another 4 for grid_len
+
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            // shift weight_size then gather second weight
+            vgatherdps(vmm_weightX1, ptr[reg_weight_x + sizeof(float) + (vmm_val * grid_len)], vmm_mask);
+
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            vgatherdps(vmm_weightX2, ptr[reg_weight_x + 2 * sizeof(float) + (vmm_val * grid_len)], vmm_mask);
+
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            vgatherdps(vmm_weightX3, ptr[reg_weight_x + 3 * sizeof(float) + (vmm_val * grid_len)], vmm_mask);
+            // vmm_val is now relieved and used for dst_value
+
+            uni_vpxor(vmm_val, vmm_val, vmm_val);
+            // y0
+            vpsubd(vmm_index_y_itr, vmm_index_in_y, vmm_one);
+            // crop to [0, IH - 1]
+            vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1));
+            vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero);
+
+            // weight y0
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            vgatherdps(vmm_weightY, ptr[reg_weight_y + (vmm_tbl_y * grid_len)], vmm_mask);
+            cubic_planar_line(false);
+
+            // y1
+            // crop to [0, IH - 1]
+            vpminsd(vmm_index_y_itr, vmm_index_in_y, cubic_planar_table_val(1));
+            vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero);
+            // weight y1: shift weight_size
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            vgatherdps(vmm_weightY, ptr[reg_weight_y + sizeof(float) + (vmm_tbl_y * grid_len)], vmm_mask);
+            cubic_planar_line(false);
+
+            // y2
+            vpaddd(vmm_index_y_itr, vmm_index_in_y, vmm_one);
+            // crop to [0, IH - 1]
+            vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1));
+            vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero);
+            // weight y2
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            vgatherdps(vmm_weightY, ptr[reg_weight_y + 2 * sizeof(float) + (vmm_tbl_y * grid_len)], vmm_mask);
+            cubic_planar_line(false);
+
+            // y3
+            vpaddd(vmm_index_y_itr, vmm_index_in_y, vmm_one);
+            vpaddd(vmm_index_y_itr, vmm_index_y_itr, vmm_one);
+            // crop to [0, IH - 1]
+            vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1));
+            vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero);
+            // weight y3
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            vgatherdps(vmm_weightY, ptr[reg_weight_y + 3 * sizeof(float) + (vmm_tbl_y * grid_len)], vmm_mask);
+            cubic_planar_line(false);
+
+            if (attr_.post_ops_.len_ != 0) {
+                apply_post_ops(jcp_.dst_dt, true);  // oc_off is broadcast and always the same value for this channel
+            }
+            store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt);
+
+            add(reg_tbl_y, step * sizeof(int));  // sizeof(int): sequence by dd()
+            add(reg_tbl_x, step * sizeof(int));
+            add(reg_dst, step * jcp_.dst_data_size);
+
+            sub(reg_work_amount, step);
+
+            jmp(main_loop_label, T_NEAR);
+        }
+        L(main_loop_end_label);
+
+        step = 1;
+        L(tail_loop_label);
+        {
+            cmp(reg_work_amount, 1);
+            jl(tail_loop_end_label, T_NEAR);
+
+            // get idx for input
+            movss(Xmm(vmm_tbl_y.getIdx()), ptr[reg_tbl_y]);
+            gather_i32_indices(vmm_index_in_y, reg_index_y, 0, vmm_tbl_y, 1, memory::s32, true);
+
+            movss(Xmm(vmm_val.getIdx()), ptr[reg_tbl_x]);
+            gather_i32_indices(vmm_index_in_x, reg_index, 0, vmm_val, 1, memory::s32, true);
+            // gather weightX by input idx, used in y0-y3
+            gather_i32_indices(vmm_weightX0, reg_weight_x, 0, vmm_val, grid_len, memory::f32, true);
+            gather_i32_indices(vmm_weightX1, reg_weight_x, sizeof(float), vmm_val, grid_len, memory::f32, true);
+            gather_i32_indices(vmm_weightX2, reg_weight_x, 2 * sizeof(float), vmm_val, grid_len, memory::f32, true);
+            gather_i32_indices(vmm_weightX3, reg_weight_x, 3 * sizeof(float), vmm_val, grid_len, memory::f32, true);
+            // vmm_val is now relieved and used for dst_value
+
+            uni_vpxor(vmm_val, vmm_val, vmm_val);
+            // y0
+            vpsubd(vmm_index_y_itr, vmm_index_in_y, vmm_one);
+            // crop to [0, IH - 1]
+            vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1));
+            vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero);
+
+            gather_i32_indices(vmm_weightY, reg_weight_y, 0, vmm_tbl_y, grid_len, memory::f32, true);
+            cubic_planar_line(true);
+
+            // y1
+            // crop to [0, IH - 1]
+            vpminsd(vmm_index_y_itr, vmm_index_in_y, cubic_planar_table_val(1));
+            vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero);
+            // weight y1: shift weight_size
+            gather_i32_indices(vmm_weightY, reg_weight_y, sizeof(float), vmm_tbl_y, grid_len, memory::f32, true);
+            cubic_planar_line(true);
+
+            // y2
+            vpaddd(vmm_index_y_itr, vmm_index_in_y, vmm_one);
+            // crop to [0, IH - 1]
+            vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1));
+            vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero);
+            // weight y2
+            gather_i32_indices(vmm_weightY, reg_weight_y, 2 * sizeof(float), vmm_tbl_y, grid_len, memory::f32, true);
+            cubic_planar_line(true);
+
+            // y3
+            vpaddd(vmm_index_y_itr, vmm_index_in_y, vmm_one);
+            vpaddd(vmm_index_y_itr, vmm_index_y_itr, vmm_one);
+            // crop to [0, IH - 1]
+            vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1));
+            vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero);
+            // weight y3
+            gather_i32_indices(vmm_weightY, reg_weight_y, 3 * sizeof(float), vmm_tbl_y, grid_len, memory::f32, true);
+            cubic_planar_line(true);
+
+            if (attr_.post_ops_.len_ != 0) {
+                apply_post_ops(jcp_.dst_dt, true);  // oc_off is broadcast and always the same value for this channel
+            }
+            store_scalar(ptr[reg_dst], Xmm(vmm_val.getIdx()), jcp_.dst_dt);
+
+            add(reg_tbl_y, step * sizeof(int));  // sizeof(int): sequence with dd()
+            add(reg_tbl_x, step * sizeof(int));
+            add(reg_dst, step * jcp_.dst_data_size);
+
+            sub(reg_work_amount, step);
+
+            jmp(tail_loop_label, T_NEAR);
+        }
+        L(tail_loop_end_label);
+    }
+
+    inline void cubic_planar_line(bool is_scalar) {
+        uni_vpxor(vmm_dstX, vmm_dstX, vmm_dstX);
+        cubic_planar_pixel(0, is_scalar);
+        cubic_planar_pixel(1, is_scalar);
+        cubic_planar_pixel(2, is_scalar);
+        cubic_planar_pixel(3, is_scalar);
+        uni_vfmadd231ps(vmm_val, vmm_dstX, vmm_weightY);
+    }
+
+    inline void cubic_planar_pixel(int itr, bool is_scalar) {
+        // vmm_index_in_x have index for src
+        if (itr == 0) {
+            vpsubd(vmm_index_x_itr, vmm_index_in_x, vmm_one);
+        } else if (itr == 1) {
+            vpaddd(vmm_index_x_itr, vmm_index_in_x, vmm_zero);
+        } else if (itr == 2) {
+            vpaddd(vmm_index_x_itr, vmm_index_in_x, vmm_one);
+        } else if (itr == 3) {
+            vpaddd(vmm_index_x_itr, vmm_index_in_x, vmm_one);
+            vpaddd(vmm_index_x_itr, vmm_index_x_itr, vmm_one);
+        }
+
+        // crop to [0, IW - 1]
+        vpminsd(vmm_index_x_itr, vmm_index_x_itr, cubic_planar_table_val(2));
+        vpmaxsd(vmm_index_x_itr, vmm_index_x_itr, vmm_zero);
+
+        // value
+        // index is: ptr[reg_src + (vmm_index_y_itr * jcp_.IW + vmm_index_x_itr) * jcp_.src_data_size]
+        uni_vmovdqu(vmm_mask, cubic_planar_table_val(2));
+        vpaddd(vmm_mask, vmm_mask, vmm_one);  // (IW - 1) + 1 = IW
+        uni_vpmulld(vmm_mask, vmm_mask, vmm_index_y_itr);
+        uni_vpaddd(vmm_index_x_itr, vmm_index_x_itr, vmm_mask);
+        gather_i32_indices(vmm_src, reg_src, 0, vmm_index_x_itr, jcp_.src_data_size, memory::f32, is_scalar);
+
+        if (itr == 0) {
+            uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weightX0);
+        } else if (itr == 1) {
+            uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weightX1);
+        } else if (itr == 2) {
+            uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weightX2);
+        } else if (itr == 3) {
+            uni_vfmadd231ps(vmm_dstX, vmm_src, vmm_weightX3);
+        }
+    }
+
+    inline void prepare_cubic_planar_table() {
+        auto broadcast_int = [&](int val) {
+            for (size_t d = 0; d < vlen / sizeof(int); ++d) {
+                dd(val);
+            }
+        };
+
+        align(64);
+        L(l_table_constant);
+        broadcast_int(vals_for_cubic_planar.int_one);
+        broadcast_int(jcp_.IH - 1);
+        broadcast_int(jcp_.IW - 1);
+        dd(vals_for_cubic_planar.mask_gather_avx512);
+    }
+
+    struct vals_for_cubic_planar_type {
+        int int_one = 0x00000001;
+        int mask_gather_avx512 = 0x0000ffff;  // 00000000000000001111111111111111
+    } vals_for_cubic_planar;
+
+    inline Xbyak::Address cubic_planar_table_val(int index) {
+        return ptr[reg_table + index * vlen];
+    }
+
+    // always gather to Vmm, compute with Vmm, store with Xmm if scalar
+    inline void gather_i32_indices(Vmm vmm_src, const Xbyak::Reg64 &base, int offset, Vmm vmm_indices, int scale,
+                                memory::data_type src_dt, bool is_scalar) {
+        Xbyak::Address table_idx = ptr[base + offset + vmm_indices * scale];
+        if ((isa == cpu::avx512_common) && !is_scalar) {
+            // [0-15] bit of int to mask
+            kmovw(k_mask, cubic_planar_table_val(3));
+            if (src_dt == memory::f32) {
+                vgatherdps(vmm_src | k_mask, table_idx);  // dword index, packed single data
+            } else if (src_dt == memory::s32) {
+                vpgatherdd(vmm_src | k_mask, table_idx);  // dword index, dword data
+            }
+        } else if ((isa == cpu::avx2) && !is_scalar) {
+            uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
+            if (src_dt == memory::f32) {
+                vgatherdps(vmm_src, table_idx, vmm_mask);
+            } else if (src_dt == memory::s32) {
+                vpgatherdd(vmm_src, table_idx, vmm_mask);
+            }
+        } else {
+            const int gpr_size = 8;
+            sub(rsp, gpr_size);
+            // move content in register to content in address(ptr[])
+            mov(ptr[rsp], reg_tmp_64);
+
+            // replace index with value in stack
+            sub(rsp, vlen);
+            uni_vmovdqu(ptr[rsp], vmm_indices);
+
+            int repeats = is_scalar ? 1 : vlen / sizeof(float);
+            for (size_t i = 0; i < repeats; ++i) {
+                mov(reg_tmp_64.cvt32(), ptr[rsp + i * sizeof(int)]);       // sizeof(int)  index_size
+                table_idx = ptr[base + offset + reg_tmp_64 * scale];       // scale: sizeof(float)   value_size
+                mov(reg_tmp_64.cvt32(), table_idx);
+                mov(ptr[rsp + i * sizeof(int)], reg_tmp_64.cvt32());
+            }
+
+            uni_vmovups(vmm_src, ptr[rsp]);
+            add(rsp, vlen);
+            // restore GPR state
+            mov(reg_tmp_64, ptr[rsp]);
+            add(rsp, gpr_size);
+        }
+    }
+
     inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, memory::data_type src_dt) {
         switch (src_dt) {
             case memory::f32:
@@ -653,10 +1180,9 @@ private:
                 uni_vpmovzxbd(vmm_src, op);
                 break;
             case memory::bf16:
-                if (isa != cpu::sse42) {
-                    vpmovzxwd(vmm_src, op);
-                }
+                uni_vpmovzxwd(vmm_src, op);
                 uni_vpslld(vmm_src, vmm_src, 16);
+                break;
             default:
                 assert(!"unknown dst_dt");
         }
@@ -665,6 +1191,30 @@ private:
             uni_vcvtdq2ps(vmm_src, vmm_src);
     }
 
+    inline void load_xmm(Xmm xmm_src, const Xbyak::Address &op, memory::data_type src_dt) {
+        switch (src_dt) {
+            case memory::f32:
+            case memory::s32:
+                uni_vmovups(xmm_src, op);
+                break;
+            case memory::s8:
+                uni_vpmovsxbd(xmm_src, op);
+                break;
+            case memory::u8:
+                uni_vpmovzxbd(xmm_src, op);
+                break;
+            case memory::bf16:
+                uni_vpmovzxwd(xmm_src, op);
+                uni_vpslld(xmm_src, xmm_src, 16);
+                break;
+            default:
+                assert(!"unknown dst_dt");
+        }
+
+        if (src_dt != memory::f32 && src_dt != data_type::bf16)
+            uni_vcvtdq2ps(xmm_src, xmm_src);
+    }
+
     inline void load_scalar(Xmm xmm_src, const Xbyak::Address &op, memory::data_type src_dt) {
         switch (src_dt) {
             case memory::f32:
@@ -682,6 +1232,7 @@ private:
             case memory::bf16:
                 pinsrw(xmm_src, op, 0x0);
                 uni_vpslld(xmm_src, xmm_src, 16);
+                break;
             default:
                 assert(!"unknown dst_dt");
         }
@@ -736,6 +1287,37 @@ private:
         }
     }
 
+    inline void store_xmm(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt) {
+        if (dst_dt != memory::f32 && dst_dt != memory::bf16) {
+            uni_vcvtps2dq(xmm_dst, xmm_dst);
+        }
+
+        switch (dst_dt) {
+            case memory::f32:
+            case memory::s32:
+                uni_vmovups(op, xmm_dst);
+                break;
+            case memory::s8:
+                uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
+                uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
+                movd(op, xmm_dst);
+                break;
+            case memory::u8:
+                uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
+                uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
+                movd(op, xmm_dst);
+                break;
+            case memory::bf16:
+                pshuflw(xmm_dst, xmm_dst, 0x0d);  // 01 01 01 01 --> 01 01 11 00  imm=0b00001101
+                pshufhw(xmm_dst, xmm_dst, 0x0d);  // 01 01 11 00 --> 11 00 11 00
+                pshufd(xmm_dst, xmm_dst, 0x08);   // 11 00 11 00 --> 11 11 00 00  imm=0b00001000
+                vmovq(op, xmm_dst);
+                break;
+            default:
+                assert(!"unknown dst_dt");
+        }
+    }
+
     inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt) {
         if (dst_dt != data_type::f32 && dst_dt != data_type::bf16) {
             uni_vcvtps2dq(xmm_dst, xmm_dst);
@@ -761,6 +1343,7 @@ private:
             case memory::bf16:
                 uni_vpsrld(xmm_dst, xmm_dst, 16);
                 pextrw(op, xmm_dst, 0x0);
+                break;
             default:
                 assert(!"unknown dst_dt");
         }
@@ -1056,7 +1639,7 @@ void MKLDNNInterpolateNode::initSupportedPrimitiveDescriptors() {
         supportedPrimitiveDescriptors.push_back({config, implDetail, dataFormat});
     };
 
-    if (mode == InterpolateMode::nearest || mode == InterpolateMode::linear_onnx) {
+    if (mode != InterpolateMode::linear) {
         // blk and by_channel JIT kernel on sse42 or above machine
         if (mayiuse(cpu::sse42)) {
             if (getParentEdgeAt(DATA_ID)->getDims().ndims() == 4) {
@@ -1091,7 +1674,7 @@ void MKLDNNInterpolateNode::initSupportedPrimitiveDescriptors() {
         if (mayiuse(cpu::avx2) && inputPrec == Precision::FP32) {
             pushDesc(MKLDNNMemory::GetPlainFormat(getParentEdgeAt(DATA_ID)->getDims()), jit_avx2);
         }
-    } else if (mode == InterpolateMode::linear || mode == InterpolateMode::cubic) {
+    } else {
         pushDesc(MKLDNNMemory::GetPlainFormat(getParentEdgeAt(DATA_ID)->getDims()), ref);
     }
 }
@@ -1129,8 +1712,8 @@ void MKLDNNInterpolateNode::createPrimitive() {
     size_t dimSize = dstDim.size();
     jcp.OW = dstDim[dimSize - 1];
     jcp.OH = dstDim[dimSize - 2];
-    jcp.IW = srcDim[dimSize - 1];
-    jcp.IH = srcDim[dimSize - 2];
+    jcp.IW = srcDimPad[dimSize - 1];
+    jcp.IH = srcDimPad[dimSize - 2];
 
     if (MKLDNNMemory::GetPlainLayout(getChildEdgeAt(0)->getDims()) == selected_layout) {
         jcp.layout = InterpolateLayoutType::planar;
@@ -1140,7 +1723,7 @@ void MKLDNNInterpolateNode::createPrimitive() {
         jcp.layout = InterpolateLayoutType::block;
     }
 
-    if (mode == InterpolateMode::nearest || mode == InterpolateMode::linear_onnx) {
+    if (mode == InterpolateMode::nearest || mode == InterpolateMode::linear_onnx || mode == InterpolateMode::cubic) {
         if (jcp.layout != InterpolateLayoutType::planar) {
             if (mayiuse(cpu::avx512_common)) {
                 interpolateKernel.reset(new jit_uni_interpolate_kernel_f32<cpu::avx512_common>(jcp, *attr.get()));
@@ -1175,11 +1758,11 @@ void MKLDNNInterpolateNode::createPrimitive() {
             break;
         }
         case InterpolateMode::linear: {
-            buidTblLinear(srcDimPad5d, dstDim5d, dataScales, LINEAR_KERNEL, antialias);
+            buildTblLinear(srcDimPad5d, dstDim5d, dataScales, LINEAR_KERNEL, antialias);
             break;
         }
         case InterpolateMode::cubic: {
-            buidTblCubic(srcDimPad5d, dstDim5d, dataScales, cubeCoeff);
+            buildTblCubic(srcDimPad5d, dstDim5d, dataScales, cubeCoeff, jcp.layout);
             break;
         }
         default: {
@@ -1338,7 +1921,7 @@ static inline float triangleCoeff(float x) {
 // wd .........wd, wh............wh, ww.............ww, id...........id, ih............ih, iw..............iw
 //                        |                                                      |
 //                   wh0.....wh_diameter                                    ih0.....ih_diameter
-void MKLDNNInterpolateNode::buidTblLinear(SizeVector& srcDimPad5d, SizeVector& dstDim5d,
+void MKLDNNInterpolateNode::buildTblLinear(SizeVector& srcDimPad5d, SizeVector& dstDim5d,
                                             std::vector<float>& dataScales, int kernel_width, bool antialias) {
     int dimSize = srcDim.size();
     float fz = (dimSize == 5) ? dataScales[dimSize - 3] : 1.f;
@@ -1429,7 +2012,8 @@ std::vector<float> MKLDNNInterpolateNode::getCubicCoeffs(float mantissa, float a
 // table layout:
 // OW      OW         OW         OW         OW          OH       OH           OH           OH           OH
 // x_idx   x_weight0  x_weight1  x_weight2  x_weight3   y_idx    y_weight0    y_weight1    y_weight2    y_weight3
-void MKLDNNInterpolateNode::buidTblCubic(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector<float>& dataScales, float cubicCoeff) {
+void MKLDNNInterpolateNode::buildTblCubic(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector<float>& dataScales,
+                                        float cubicCoeff, InterpolateLayoutType layout) {
     int dimSize = srcDim.size();
     float fy = dataScales[dimSize - 2];
     float fx = dataScales[dimSize - 1];
@@ -1438,9 +2022,18 @@ void MKLDNNInterpolateNode::buidTblCubic(SizeVector& srcDimPad5d, SizeVector& ds
 
     // idxNum for index, CUBIC_GRID_LEN for weight
     const int idxNum = 1;
-    indexTable.resize((CUBIC_GRID_LEN + idxNum) * OW + (CUBIC_GRID_LEN + idxNum) * OH);
-    int *xOrigin = static_cast<int*>(&indexTable[0]);
-    float *xFactor = reinterpret_cast<float*>(&indexTable[OW]);
+    size_t idxWeightSize = (CUBIC_GRID_LEN + idxNum) * OW + (CUBIC_GRID_LEN + idxNum) * OH;
+    if (layout != InterpolateLayoutType::planar) {
+        indexTable.resize(idxWeightSize);
+    } else {
+        size_t sequenceSize = 2 * OH * OW;
+        indexTable.resize(idxWeightSize + sequenceSize);
+    }
+
+    int tblAdvance = 0;
+    int *xOrigin = static_cast<int*>(&indexTable[tblAdvance]);
+    tblAdvance += OW;
+    float *xFactor = reinterpret_cast<float*>(&indexTable[tblAdvance]);
     for (int ox = 0; ox < OW; ox++) {
         float ix = coordTransToInput(ox, fx, IW, OW);
         int ix_r = static_cast<int>(std::floor(ix));
@@ -1453,8 +2046,10 @@ void MKLDNNInterpolateNode::buidTblCubic(SizeVector& srcDimPad5d, SizeVector& ds
         xFactor[CUBIC_GRID_LEN * ox + 3] = coffes[3];
     }
 
-    int *yOrigin = static_cast<int*>(&indexTable[(CUBIC_GRID_LEN + idxNum) * OW]);
-    float *yFactor = reinterpret_cast<float*>(&indexTable[(CUBIC_GRID_LEN + idxNum) * OW + OH]);
+    tblAdvance += CUBIC_GRID_LEN * OW;
+    int *yOrigin = static_cast<int*>(&indexTable[tblAdvance]);
+    tblAdvance += OH;
+    float *yFactor = reinterpret_cast<float*>(&indexTable[tblAdvance]);
     for (int oy = 0; oy < OH; oy++) {
         float iy = coordTransToInput(oy, fy, IH, OH);
         int iy_r = static_cast<int>(std::floor(iy));
@@ -1466,6 +2061,20 @@ void MKLDNNInterpolateNode::buidTblCubic(SizeVector& srcDimPad5d, SizeVector& ds
         yFactor[CUBIC_GRID_LEN * oy + 2] = coffes[2];
         yFactor[CUBIC_GRID_LEN * oy + 3] = coffes[3];
     }
+
+    if (layout == InterpolateLayoutType::planar) {
+        tblAdvance += CUBIC_GRID_LEN * OH;
+        int *sequenceOH = static_cast<int*>(&indexTable[tblAdvance]);
+        tblAdvance += OH * OW;
+        int *sequenceOW = static_cast<int*>(&indexTable[tblAdvance]);
+        for (int h = 0; h < OH; ++h) {
+            int offset = h * OW;
+            for (int w = 0; w < OW; ++w) {
+                sequenceOH[offset + w] = h * sizeof(int);
+                sequenceOW[offset + w] = w * sizeof(int);
+            }
+        }
+    }
 }
 
 void MKLDNNInterpolateNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) {
@@ -1531,6 +2140,16 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) {
     auto srcDimPad5d = to5Dim(srcDimPad);
     auto dstDim5d = to5Dim(dstDim);
 
+    InterpolateLayoutType layout;
+    Layout selected_layout = getParentEdgeAt(DATA_ID)->getDesc().getLayout();
+    if (MKLDNNMemory::GetPlainLayout(getChildEdgeAt(0)->getDims()) == selected_layout) {
+        layout = InterpolateLayoutType::planar;
+    } else if ((selected_layout == NHWC) || (selected_layout == NDHWC)) {
+        layout = InterpolateLayoutType::by_channel;
+    } else {
+        layout = InterpolateLayoutType::block;
+    }
+
     uint8_t *src_data = nullptr;
     std::vector<uint8_t> srcPadded;
     if (hasPad) {
@@ -1542,16 +2161,53 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) {
 
         SizeVector inShapeBlock = getBlockND(srcDim5d);
         SizeVector inShapePadBlock = getBlockND(srcDimPad5d);
-        srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0);
-        uint8_t *src_data_pad = static_cast<uint8_t *>(&srcPadded[0]);
-
-        parallel_for4d(srcDim5d[0], srcDim5d[1], srcDim5d[2], srcDim5d[3], [&](int n, int c, int d, int h) {
-            uint8_t *src = src_data_origin + (inShapeBlock[1] * n + inShapeBlock[2] * c + inShapeBlock[3] * d + inShapeBlock[4] * h) * srcDataSize;
-            uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + inShapePadBlock[2] * (c + padB1) +
-                           inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + padB4) * srcDataSize;
-            cpu_memcpy(srcPad, src, srcDim5d[4] * srcDataSize);
-        });
-        src_data = src_data_pad;
+
+        if (layout == InterpolateLayoutType::planar) {
+            srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0);
+            uint8_t *src_data_pad = static_cast<uint8_t *>(&srcPadded[0]);
+            parallel_for4d(srcDim5d[0], srcDim5d[1], srcDim5d[2], srcDim5d[3], [&](int n, int c, int d, int h) {
+                uint8_t *src = src_data_origin + (inShapeBlock[1] * n + inShapeBlock[2] * c + inShapeBlock[3] * d + inShapeBlock[4] * h) * srcDataSize;
+                uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + inShapePadBlock[2] * (c + padB1) +
+                               inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + padB4) * srcDataSize;
+                cpu_memcpy(srcPad, src, srcDim5d[4] * srcDataSize);
+            });
+            src_data = src_data_pad;
+        } else if (layout == InterpolateLayoutType::by_channel) {
+            srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0);
+            uint8_t *src_data_pad = static_cast<uint8_t *>(&srcPadded[0]);
+            parallel_for4d(srcDim5d[0], srcDim5d[2], srcDim5d[3], srcDim5d[4], [&](int n, int d, int h, int w) {
+                uint8_t *src = src_data_origin + (inShapeBlock[1] * n +
+                                (inShapeBlock[3] * d + inShapeBlock[4] * h + inShapeBlock[5] * w) * srcDim5d[1]) * srcDataSize;
+                uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + (inShapePadBlock[3] * (d + padB2) +
+                                inShapePadBlock[4] * (h + padB3) + inShapePadBlock[5] * (w + padB4)) * srcDimPad5d[1] + padB1) * srcDataSize;
+                cpu_memcpy(srcPad, src, srcDim5d[1] * srcDataSize);
+            });
+            src_data = src_data_pad;
+        } else if (layout == InterpolateLayoutType::block) {
+            size_t blkSize = mayiuse(cpu::avx512_common) ? 16 : 8;
+            size_t CB = div_up(srcDimPad5d[1], blkSize);
+            size_t eltsTotal = srcDimPad5d[0] * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize;
+            srcPadded.resize(eltsTotal * srcDataSize, 0x0);
+            uint8_t *src_data_pad = static_cast<uint8_t *>(&srcPadded[0]);
+            if ((srcDim5d[0] != srcDimPad5d[0]) || (srcDim5d[1] != srcDimPad5d[1])) {
+                THROW_IE_EXCEPTION << "Interpolate layer with name '" << getName() <<
+                "' does not support padding on batch and channel dimensions";
+            }
+            parallel_for5d(srcDim5d[0], CB, srcDim5d[2], srcDim5d[3], srcDim5d[4], [&](int n, int cb, int d, int h, int w) {
+                uint8_t *src = src_data_origin + (n * CB * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize
+                                               + (cb * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize
+                                               + (d * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize
+                                               + (h * srcDim5d[4] * blkSize) * srcDataSize
+                                               + (w * blkSize) * srcDataSize;
+                uint8_t *srcPad = src_data_pad + (n * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize
+                                               + (cb * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize
+                                               + ((d + padB2) * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize
+                                               + ((h + padB3) * srcDimPad5d[4] * blkSize) * srcDataSize
+                                               + ((w + padB4) * blkSize) * srcDataSize;
+                cpu_memcpy(srcPad, src, blkSize * srcDataSize);
+            });
+            src_data = src_data_pad;
+        }
     } else {
         src_data = src_data_origin;
     }
@@ -1562,13 +2218,11 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) {
     if (dimSize > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f)) {
         THROW_IE_EXCEPTION << "Interpolate layer only supports resize on spatial dimensions(depth, height and width)";
     }
-    Layout layout = getParentEdgeAt(DATA_ID)->getDesc().getLayout();
-    bool isPlanar = (layout == NC || layout == NCHW || layout == NCDHW) ? true : false;
 
     switch (mode) {
         case InterpolateMode::nearest: {
             if (interpolateKernel) {
-                if (isPlanar) {
+                if (layout == InterpolateLayoutType::planar) {
                     NNPlanar(src_data, dst_data, N, C, ID, IH, IW, OD, OH, OW);
                 } else {
                     NNCGathered(src_data, dst_data, N, C, ID, IH, IW, OD, OH, OW);
@@ -1578,19 +2232,9 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) {
             }
             break;
         }
-        case InterpolateMode::linear: {
-            float fz = (dimSize == 5) ? dataScales[dimSize - 3] : 1.f;
-            float fy = dataScales[dimSize - 2];
-            float fx = dataScales[dimSize - 1];
-
-            bool isDownsample = (fx < 1.f) || (fy < 1.f) || (fz < 1.f);
-            int kernel_width = 2;
-            linearInterpolation(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias);
-            break;
-        }
         case InterpolateMode::linear_onnx: {
             if (interpolateKernel) {
-                if (isPlanar) {
+                if (layout == InterpolateLayoutType::planar) {
                     linearOnnxPlanar(src_data, dst_data, N, C, IH, IW, OH, OW);
                 } else {
                     linearOnnxCGathered(src_data, dst_data, N, C, IH, IW, OH, OW);
@@ -1601,7 +2245,25 @@ void MKLDNNInterpolateNode::execute(mkldnn::stream strm) {
             break;
         }
         case InterpolateMode::cubic: {
-            cubic(src_data, dst_data, N, C, IH, IW, OH, OW, cubeCoeff);
+            if (interpolateKernel) {
+                if (layout == InterpolateLayoutType::planar) {
+                    cubicPlanar(src_data, dst_data, N, C, IH, IW, OH, OW);
+                } else {
+                    cubicCGathered(src_data, dst_data, N, C, IH, IW, OH, OW);
+                }
+            } else {
+                cubicRef(src_data, dst_data, N, C, IH, IW, OH, OW);
+            }
+            break;
+        }
+        case InterpolateMode::linear: {
+            float fz = (dimSize == 5) ? dataScales[dimSize - 3] : 1.f;
+            float fy = dataScales[dimSize - 2];
+            float fx = dataScales[dimSize - 1];
+
+            bool isDownsample = (fx < 1.f) || (fy < 1.f) || (fz < 1.f);
+            int kernel_width = 2;
+            linearInterpolation(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias);
             break;
         }
         default: {
@@ -1634,7 +2296,7 @@ void MKLDNNInterpolateNode::NNCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr
                 const uint8_t *in_ptr_dh = in_ptr + (C * IW * IH * index_d[d] + C * IW * index_h[h]) * srcDataSize;
                 auto arg = jit_interpolate_call_args();
                 arg.dst = out_ptr_dh;
-                arg.src = in_ptr_dh;
+                arg.src_ptr[0] = in_ptr_dh;
                 arg.index = static_cast<int*>(&(index_w_kernel[0]));
                 arg.work_amount = C;
                 arg.oc_off = 0;
@@ -1655,7 +2317,7 @@ void MKLDNNInterpolateNode::NNCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr
                 auto arg = jit_interpolate_call_args();
                 for (int h = 0; h < OH; h++) {  // kernel for blk_size * OW
                     arg.dst = out_ptr_cbd + blk_size * OW * h * dstDataSize;
-                    arg.src = in_ptr_cbd + blk_size * IW * index_h[h] * srcDataSize;
+                    arg.src_ptr[0] = in_ptr_cbd + blk_size * IW * index_h[h] * srcDataSize;
                     arg.index = static_cast<int*>(&(index_w_kernel[0]));
                     arg.work_amount = static_cast<size_t>(OW);
                     arg.oc_off = cb * blk_size;
@@ -1686,7 +2348,7 @@ void MKLDNNInterpolateNode::NNPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_,
         uint8_t *out_ptr = out_ptr_ + (OW * OH * OD * C * b + OW * OH * OD * c + OW * OH * od) * dstDataSize;
 
         auto arg = jit_interpolate_call_args();
-        arg.src = in_ptr;
+        arg.src_ptr[0] = in_ptr;
         arg.dst = out_ptr;
         arg.index = static_cast<int*>(&index_kernel[0]);  // need index_h and index_w in kernel, it's in continous memory so one param
         arg.oc_off = static_cast<size_t>(c);
@@ -1724,9 +2386,9 @@ void MKLDNNInterpolateNode::linearOnnxPlanar(const uint8_t *in_ptr_, uint8_t *ou
         uint8_t *out_ptr_nc = out_ptr_ + (OH * OW * C * b + OH * OW * c) * dstDataSize;
         const uint8_t *in_ptr_nc = in_ptr_ + (IH * IW * C * b + IH * IW * c) * srcDataSize;
         auto arg = jit_interpolate_call_args();
-        arg.src = in_ptr_nc;
+        arg.src_ptr[0] = in_ptr_nc;
         arg.index = static_cast<int*>(&index[0]);
-        arg.weight = static_cast<float*>(&weight[0]);
+        arg.weight_ptr[0] = static_cast<float*>(&weight[0]);
         arg.dst = out_ptr_nc;
         arg.work_amount = OW * OH;
         arg.oc_off = c;
@@ -1750,45 +2412,33 @@ void MKLDNNInterpolateNode::linearOnnxCGathered(const uint8_t *in_ptr_, uint8_t
     Layout layout = getParentEdgeAt(0)->getDesc().getLayout();
     bool isByChannel = (layout == NHWC) ? true : false;
 
-    if (isByChannel) {
-        parallel_for3d(B, OH, OW, [&](size_t b, size_t h, size_t w) {
-            uint8_t *out_ptr_nhw = out_ptr_ + (OH * OW * C * b + OW * C * h + C * w) * dstDataSize;
-            const uint8_t *in_ptr_n = in_ptr_ + (IH * IW * C * b) * srcDataSize;
-            auto arg = jit_interpolate_call_args();
-            arg.src = in_ptr_n + (indexTop[h] * IW * C + indexLeft[w] * C) * srcDataSize;
-            arg.srcTR = in_ptr_n + (indexTop[h] * IW * C + indexRight[w] * C) * srcDataSize;
-            arg.srcBL = in_ptr_n + (indexBottom[h] * IW * C + indexLeft[w] * C) * srcDataSize;
-            arg.srcBR = in_ptr_n + (indexBottom[h] * IW * C + indexRight[w] * C) * srcDataSize;
-            arg.weight = static_cast<float*>(&weightLeft[w]);
-            arg.weightR = static_cast<float*>(&weightRight[w]);
-            arg.weightT = static_cast<float*>(&weightTop[h]);
-            arg.weightB = static_cast<float*>(&weightBottom[h]);
-            arg.dst = out_ptr_nhw;
-            arg.work_amount = C;
-            arg.oc_off = 0;
-            (*interpolateKernel)(&arg);
-        });
-    } else {
-        size_t blkSize = mayiuse(cpu::avx512_common) ? 16 : 8;
-        size_t CB = div_up(C, blkSize);
-        parallel_for3d(B, OH, OW, [&](size_t b, size_t h, size_t w) {
-            uint8_t *out_ptr_nhw = out_ptr_ + (CB * OH * OW * blkSize * b + OW * blkSize * h + blkSize * w) * dstDataSize;
-            const uint8_t *in_ptr_n = in_ptr_ + (CB * IH * IW * blkSize * b) * srcDataSize;
-            auto arg = jit_interpolate_call_args();
-            arg.src = in_ptr_n + (indexTop[h] * IW * blkSize + indexLeft[w] * blkSize) * srcDataSize;
-            arg.srcTR = in_ptr_n + (indexTop[h] * IW * blkSize + indexRight[w] * blkSize) * srcDataSize;
-            arg.srcBL = in_ptr_n + (indexBottom[h] * IW * blkSize + indexLeft[w] * blkSize) * srcDataSize;
-            arg.srcBR = in_ptr_n + (indexBottom[h] * IW * blkSize + indexRight[w] * blkSize) * srcDataSize;
-            arg.weight = static_cast<float*>(&weightLeft[w]);
-            arg.weightR = static_cast<float*>(&weightRight[w]);
-            arg.weightT = static_cast<float*>(&weightTop[h]);
-            arg.weightB = static_cast<float*>(&weightBottom[h]);
+    int blkSize = mayiuse(cpu::avx512_common) ? 16 : 8;
+    int CB = div_up(C, blkSize);
+    int CSize = isByChannel ? C : blkSize * CB;
+    int CGatherLen = isByChannel ? C : blkSize;
+    int workAmount = isByChannel ? C : CB;
+    parallel_for2d(B, OH, [&](size_t b, size_t h) {
+        uint8_t *out_ptr_nh = out_ptr_ + (OH * OW * CSize * b + OW * CGatherLen * h) * dstDataSize;
+        const uint8_t *in_ptr_n = in_ptr_ + (IH * IW * CSize * b) * srcDataSize;
+        const uint8_t *in_ptr_nh_t = in_ptr_n + (indexTop[h] * IW * CGatherLen) * srcDataSize;
+        const uint8_t *in_ptr_nh_b = in_ptr_n + (indexBottom[h] * IW * CGatherLen) * srcDataSize;
+        auto arg = jit_interpolate_call_args();
+        for (int w = 0; w < OW; ++w) {
+            uint8_t *out_ptr_nhw = out_ptr_nh + CGatherLen * w * dstDataSize;
+            arg.src_ptr[0] = in_ptr_nh_t + (indexLeft[w] * CGatherLen) * srcDataSize;
+            arg.src_ptr[1] = in_ptr_nh_t + (indexRight[w] * CGatherLen) * srcDataSize;
+            arg.src_ptr[2] = in_ptr_nh_b + (indexLeft[w] * CGatherLen) * srcDataSize;
+            arg.src_ptr[3] = in_ptr_nh_b + (indexRight[w] * CGatherLen) * srcDataSize;
+            arg.weight_ptr[0] = static_cast<float*>(&weightLeft[w]);
+            arg.weight_ptr[1] = static_cast<float*>(&weightRight[w]);
+            arg.weight_ptr[2] = static_cast<float*>(&weightTop[h]);
+            arg.weight_ptr[3] = static_cast<float*>(&weightBottom[h]);
             arg.dst = out_ptr_nhw;
-            arg.work_amount = CB;
+            arg.work_amount = workAmount;
             arg.oc_off = 0;
             (*interpolateKernel)(&arg);
-        });
-    }
+        }
+    });
 }
 
 void MKLDNNInterpolateNode::linearOnnxRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) {
@@ -1938,7 +2588,90 @@ void MKLDNNInterpolateNode::linearInterpolation(const uint8_t *in_ptr_, uint8_t
     });
 }
 
-void MKLDNNInterpolateNode::cubic(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW, float a) {
+void MKLDNNInterpolateNode::cubicCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) {
+    const int idxNum = 1;
+    int *xOrigin = static_cast<int*>(&indexTable[0]);
+    float *xFactor = reinterpret_cast<float*>(&indexTable[OW]);
+    int *yOrigin = static_cast<int*>(&indexTable[(CUBIC_GRID_LEN + idxNum) * OW]);
+    float *yFactor = reinterpret_cast<float*>(&indexTable[(CUBIC_GRID_LEN + idxNum) * OW + OH]);
+
+    Layout layout = getParentEdgeAt(0)->getDesc().getLayout();
+    bool isByChannel = (layout == NHWC) ? true : false;
+
+    int blkSize = mayiuse(cpu::avx512_common) ? 16 : 8;
+    int CB = div_up(C, blkSize);
+    int CSize = isByChannel ? C : blkSize * CB;
+    int CGatherLen = isByChannel ? C : blkSize;
+    int workAmount = isByChannel ? C : CB;
+
+    parallel_for3d(B, OH, OW, [&](size_t b, size_t h, size_t w) {
+        uint8_t *out_ptr_nhw = out_ptr_ + (OH * OW * CSize * b + OW * CGatherLen * h + CGatherLen * w) * dstDataSize;
+        const uint8_t *in_ptr_n = in_ptr_ + (IH * IW * CSize * b) * srcDataSize;
+
+        std::vector<int> kernelIndex(CUBIC_GRID_LEN * CUBIC_GRID_LEN);  // 16 address offset to src(batch) or src(CB)
+        int iy = yOrigin[h];
+        int ix = xOrigin[w];
+        for (int y = iy - 1, i = 0; y <= iy + 2; y++, i++) {
+            int yInRange = std::max(0, std::min(y, IH - 1));
+            yInRange = yInRange * CGatherLen * IW * srcDataSize;
+            for (int x = ix - 1, j = 0; x <= ix + 2; x++, j++) {
+                int xInRange = std::max(0, std::min(x, IW - 1));
+                xInRange = yInRange + xInRange * CGatherLen * srcDataSize;
+                kernelIndex[i * CUBIC_GRID_LEN + j] = xInRange;
+            }
+        }
+        auto arg = jit_interpolate_call_args();
+            arg.dst = out_ptr_nhw;
+            arg.src_ptr[0] = in_ptr_n;
+            arg.index = static_cast<int*>(&kernelIndex[0]);
+            // 0 for weight_W, 1 for weight_H
+            arg.weight_ptr[0] = static_cast<float*>(&xFactor[w * CUBIC_GRID_LEN]);
+            arg.weight_ptr[1] = static_cast<float*>(&yFactor[h * CUBIC_GRID_LEN]);
+
+            // for by channel, src + step, dst + step, process next step on continuous memory
+            // for blk, src + IW*IH*blkSize, dst + OW*OH*blkSize, process the blkSize on next CB
+            arg.work_amount = workAmount;
+            arg.oc_off = 0;
+            (*interpolateKernel)(&arg);
+    });
+}
+
+void MKLDNNInterpolateNode::cubicPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) {
+    const int idxNum = 1;
+    int tblAdvance = 0;
+    int *xOrigin = static_cast<int*>(&indexTable[tblAdvance]);
+    tblAdvance += OW;
+    float *xFactor = reinterpret_cast<float*>(&indexTable[tblAdvance]);
+    tblAdvance += CUBIC_GRID_LEN * OW;
+    int *yOrigin = static_cast<int*>(&indexTable[tblAdvance]);
+    tblAdvance += OH;
+    float *yFactor = reinterpret_cast<float*>(&indexTable[tblAdvance]);
+
+    tblAdvance += CUBIC_GRID_LEN * OH;
+    int *sequenceOH = static_cast<int*>(&indexTable[tblAdvance]);
+    tblAdvance += OW * OH;
+    int *sequenceOW = static_cast<int*>(&indexTable[tblAdvance]);
+
+    parallel_for2d(B, C, [&](size_t n, size_t c) {
+        const uint8_t *in_ptr_nc = in_ptr_ + (IW * IH * C * n + IW * IH * c) * srcDataSize;
+        uint8_t *out_ptr_nc = out_ptr_ + (OW * OH * C * n + OW * OH * c) * dstDataSize;
+
+        auto arg = jit_interpolate_call_args();
+        arg.dst = out_ptr_nc;
+        arg.src_ptr[0] = in_ptr_nc;
+        arg.index = xOrigin;
+        arg.src_ptr[1] = yOrigin;
+        arg.src_ptr[2] = static_cast<int*>(&sequenceOH[0]);
+        arg.src_ptr[3] = static_cast<int*>(&sequenceOW[0]);
+        arg.weight_ptr[0] = xFactor;
+        arg.weight_ptr[1] = yFactor;
+        arg.work_amount = static_cast<size_t>(OW * OH);
+        arg.oc_off = static_cast<size_t>(C);
+        (*interpolateKernel)(&arg);
+    });
+}
+
+void MKLDNNInterpolateNode::cubicRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) {
     const int idxNum = 1;
     int *xOrigin = static_cast<int*>(&indexTable[0]);
     float *xFactor = reinterpret_cast<float*>(&indexTable[OW]);
@@ -2027,9 +2760,12 @@ void MKLDNNInterpolateNode::setValue(uint8_t *base, size_t offset, float value,
 }
 
 // scale is float(outShape) / float(inShape)
-// strictly consistent with onnx calc manner(div scale, not multiply inverse),
+// strictly consistent with onnx calc manner(div scale, not multiply inverse), given this is done offline
 // the slight precison diff can produce obvious wrong value due to "nearest round" behavior for NN mode
 inline float MKLDNNInterpolateNode::coordTransToInput(int outCoord, float scale, int inShape, int outShape) {
+    if (scale == 1.0f || (inShape == outShape)) {
+        return outCoord;
+    }
     switch (coordTransMode) {
         case InterpolateCoordTransMode::half_pixel: {
             return (outCoord + 0.5f) / scale - 0.5f;
@@ -2123,7 +2859,9 @@ bool MKLDNNInterpolateNode::canFuse(const MKLDNNNodePtr& node) const {
         if (eltwiseNode == nullptr)
             THROW_IE_EXCEPTION << "Cannot get eltwise node " << node->getName();
         return isOneOf(eltwiseNode->getOpType(), {MulAdd, Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp,
-                                                  Tanh, Swish, Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt});
+                                                  Tanh, Swish, Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt}) ||
+                ((eltwiseNode->getOpType() == MulAdd && eltwiseNode->getCnnLayer()->blobs.size() == 2) ||
+                 (eltwiseNode->getOpType() == Prelu));
     }
 
     return false;
index 2e2e1fd..a526eab 100644 (file)
@@ -10,6 +10,8 @@
 #include <memory>
 #include <vector>
 
+#define MAX_INPUT_INTERPOLATE 4
+
 using namespace InferenceEngine;
 
 namespace MKLDNNPlugin {
@@ -54,14 +56,8 @@ struct jit_interpolate_config_params {
 };
 
 struct jit_interpolate_call_args {
-    const void *src;
-    const void *srcTR;
-    const void *srcBL;
-    const void *srcBR;
-    const float *weight;
-    const float *weightR;
-    const float *weightT;
-    const float *weightB;
+    const void *src_ptr[MAX_INPUT_INTERPOLATE];
+    const void *weight_ptr[MAX_INPUT_INTERPOLATE];
     const int *index;
     void *dst;
     size_t work_amount;
@@ -110,18 +106,20 @@ private:
     void linearOnnxCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW);
     void linearOnnxRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW);
 
+    // cubic
+    std::vector<float> getCubicCoeffs(float mantissa, float a);
+    void cubicPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW);
+    void cubicCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW);
+    void cubicRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW);
+
     // linear
     void linearInterpolation(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int ID, int IH, int IW,
                                           float fx, float fy, float fz, int OD, int OH, int OW, int kernel_width, bool antialias);
 
-    // cubic
-    std::vector<float> getCubicCoeffs(float mantissa, float a);
-    void cubic(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW, float a);
-
     void buildTblNN(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector<float>& dataScales, InterpolateLayoutType layout);
     void buildTblLinearOnnx(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector<float>& dataScales, InterpolateLayoutType layout);
-    void buidTblLinear(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector<float>& dataScales, int kernel_width, bool antialias);
-    void buidTblCubic(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector<float>& dataScales, float cubicCoeff);
+    void buildTblLinear(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector<float>& dataScales, int kernel_width, bool antialias);
+    void buildTblCubic(SizeVector& srcDimPad5d, SizeVector& dstDim5d, std::vector<float>& dataScales, float cubicCoeff, InterpolateLayoutType layout);
 
     void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false);
 
diff --git a/inference-engine/src/transformations/include/transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp b/inference-engine/src/transformations/include/transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp
new file mode 100644 (file)
index 0000000..35d802e
--- /dev/null
@@ -0,0 +1,31 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <utility>
+#include <memory>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API ConvertInterpolate1ToInterpolate4;
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief ConvertInterpolate1ToInterpolate4 covert v0:interpolate into v4::Interpolate.
+ */
+class ngraph::pass::ConvertInterpolate1ToInterpolate4: public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    ConvertInterpolate1ToInterpolate4();
+};
\ No newline at end of file
index b3b5801..50b3b45 100644 (file)
@@ -38,6 +38,7 @@
 #include "transformations/op_conversions/convert_space_to_depth.hpp"
 #include "transformations/op_conversions/convert_broadcast_to_tiles.hpp"
 #include "transformations/op_conversions/convert_gelu.hpp"
+#include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp"
 #include "transformations/op_conversions/batch_norm_decomposition.hpp"
 #include "transformations/op_conversions/reduce_l1_decomposition.hpp"
 #include "transformations/op_conversions/reduce_l2_decomposition.hpp"
@@ -110,6 +111,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     manager.register_pass<ngraph::pass::ConvolutionBackpropDataMultiplyFusion>();
     manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
     manager.register_pass<ngraph::pass::ConstantFolding>();
+    manager.register_pass<ngraph::pass::ConvertInterpolate1ToInterpolate4, false>();
 
     manager.register_pass<ngraph::pass::ConvertPreviousNMSToNMS5>();
 
diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp
new file mode 100644 (file)
index 0000000..52e2fa5
--- /dev/null
@@ -0,0 +1,69 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertInterpolate1ToInterpolate4, "ConvertInterpolate1ToInterpolate4", 0);
+
+ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolate4() {
+    auto interpolate1 = ngraph::pattern::wrap_type<ngraph::opset1::Interpolate>({pattern::any_input(pattern::has_static_rank()), pattern::any_input()});
+    ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
+        auto interpolate1 = std::dynamic_pointer_cast<ngraph::opset1::Interpolate>(m.get_match_root());
+        if (!interpolate1)
+            return false;
+
+        auto interpolate_attrs = interpolate1->get_attrs();
+        auto input_shape_rank = interpolate1->input(0).get_partial_shape().rank().get_length();
+
+        // attrs
+        auto mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode();
+        if (interpolate_attrs.mode == "nearest") {
+            mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::nearest;
+        } else if (interpolate_attrs.mode == "cubic") {
+            mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::cubic;
+        } else if (interpolate_attrs.mode == "linear") {
+            if (input_shape_rank < 5) {
+                mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx;
+            } else if (input_shape_rank == 5) {
+                mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::linear;
+            } else {
+                return false;
+            }
+        } else {
+            return false;
+        }
+        auto nearest_mode_v4 = ngraph::op::v4::Interpolate::NearestMode::floor;
+        auto shape_calculation_mode_v4 = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes;
+        auto coordinate_transformation_mode_v4 = interpolate_attrs.align_corners ? ngraph::op::v4::Interpolate::CoordinateTransformMode::align_corners :
+                                                ngraph::op::v4::Interpolate::CoordinateTransformMode::asymmetric;
+        auto interpolate4_attr = ngraph::op::v4::Interpolate::InterpolateAttrs(mode_v4, shape_calculation_mode_v4,
+            interpolate_attrs.pads_begin, interpolate_attrs.pads_end,
+            coordinate_transformation_mode_v4, nearest_mode_v4, interpolate_attrs.antialias, -0.75);
+
+        // input
+        auto axes = interpolate_attrs.axes.to_vector();
+        auto axes_node = ngraph::opset4::Constant::create(element::i64, {axes.size()}, axes);
+        auto default_scales = std::vector<float>(axes.size(), 1.f);
+        auto scales_node = ngraph::opset4::Constant::create(element::f32, {axes.size()}, default_scales);
+
+        auto interpolate4 = std::make_shared<ngraph::opset4::Interpolate>(interpolate1->input_value(0), interpolate1->input_value(1),
+            scales_node, axes_node, interpolate4_attr);
+
+        interpolate4->set_friendly_name(interpolate1->get_friendly_name());
+        ngraph::copy_runtime_info(interpolate1, interpolate4);
+        ngraph::replace_node(interpolate1, interpolate4);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(interpolate1, "ConvertInterpolate1ToInterpolate4");
+    this->register_matcher(m, callback);
+}
\ No newline at end of file
index 56f5f72..6b71f2a 100644 (file)
@@ -21,6 +21,7 @@
 
 using namespace testing;
 using namespace ngraph::pass;
+using namespace ngraph;
 using namespace ngraph::builder::subgraph;
 
 class interpAttributes {
@@ -32,6 +33,8 @@ public:
     std::vector<size_t> pads_begin;
     std::vector<size_t> pads_end;
 
+    interpAttributes() = default;
+
     interpAttributes(const ngraph::AxisSet& axes,
         const std::string& mode,
         const bool& align_corners,
@@ -42,6 +45,23 @@ public:
         antialias(antialias), pads_begin(pads_begin), pads_end(pads_end) {}
 };
 
+class interp4Attributes {
+public:
+    op::v4::Interpolate::InterpolateMode mode;
+    op::v4::Interpolate::CoordinateTransformMode coordinate_transformation_mode;
+    std::vector<size_t> pads_begin;
+    std::vector<size_t> pads_end;
+
+    interp4Attributes() = default;
+
+    interp4Attributes(const op::v4::Interpolate::InterpolateMode mode,
+        const op::v4::Interpolate::CoordinateTransformMode coordinate_transformation_mode,
+        const std::vector<size_t>& pads_begin,
+        const std::vector<size_t>& pads_end) :
+        mode(mode), coordinate_transformation_mode(coordinate_transformation_mode),
+        pads_begin(pads_begin), pads_end(pads_end) {}
+};
+
 class InterpolateTransformationTestValues {
 public:
     class Actual {
@@ -60,9 +80,11 @@ public:
 
     ngraph::Shape inputShape;
     ngraph::Shape outputShape;
+    ngraph::Shape scalesShape;
     ngraph::pass::low_precision::LayerTransformation::Params params;
-    //ngraph::op::InterpolateAttrs interpAttrs;
     interpAttributes interpAttrs;
+    interp4Attributes interp4Attrs;
+    int opset_version;
     Actual actual;
     Expected expected;
 };
@@ -85,60 +107,107 @@ public:
     void SetUp() override {
         const InterpolateTransformationTestValues testValues = GetParam();
 
-        ngraph::op::InterpolateAttrs interpAttrs;
-        interpAttrs.axes = testValues.interpAttrs.axes;
-        interpAttrs.mode = testValues.interpAttrs.mode;
-        interpAttrs.align_corners = testValues.interpAttrs.align_corners;
-        interpAttrs.antialias = testValues.interpAttrs.antialias;
-        interpAttrs.pads_begin = testValues.interpAttrs.pads_begin;
-        interpAttrs.pads_end = testValues.interpAttrs.pads_end;
-
-        actualFunction = ngraph::builder::subgraph::InterpolateFunction::getOriginal(
-            testValues.inputShape,
-            testValues.outputShape,
-            interpAttrs,
-            testValues.actual.precisionBeforeDequantization,
-            testValues.actual.dequantization);
-
-        SimpleLowPrecisionTransformer transformer;
-        transformer.add<ngraph::pass::low_precision::InterpolateTransformation, ngraph::opset1::Interpolate>(testValues.params);
-        transformer.transform(actualFunction);
-
-        referenceFunction = ngraph::builder::subgraph::InterpolateFunction::getReference(
-            testValues.inputShape,
-            testValues.outputShape,
-            interpAttrs,
-            testValues.expected.precisionBeforeDequantization,
-            testValues.expected.dequantizationBefore,
-            testValues.expected.precisionAfterOperation,
-            testValues.expected.dequantizationAfter);
+        if (testValues.opset_version == 1) {
+            ngraph::op::InterpolateAttrs interpAttrs;
+            interpAttrs.axes = testValues.interpAttrs.axes;
+            interpAttrs.mode = testValues.interpAttrs.mode;
+            interpAttrs.align_corners = testValues.interpAttrs.align_corners;
+            interpAttrs.antialias = testValues.interpAttrs.antialias;
+            interpAttrs.pads_begin = testValues.interpAttrs.pads_begin;
+            interpAttrs.pads_end = testValues.interpAttrs.pads_end;
+
+            actualFunction = ngraph::builder::subgraph::InterpolateFunction::getOriginal(
+                testValues.inputShape,
+                testValues.outputShape,
+                interpAttrs,
+                testValues.actual.precisionBeforeDequantization,
+                testValues.actual.dequantization);
+
+            SimpleLowPrecisionTransformer transformer;
+            transformer.add<ngraph::pass::low_precision::InterpolateTransformation, ngraph::opset1::Interpolate>(testValues.params);
+            transformer.transform(actualFunction);
+
+            referenceFunction = ngraph::builder::subgraph::InterpolateFunction::getReference(
+                testValues.inputShape,
+                testValues.outputShape,
+                interpAttrs,
+                testValues.expected.precisionBeforeDequantization,
+                testValues.expected.dequantizationBefore,
+                testValues.expected.precisionAfterOperation,
+                testValues.expected.dequantizationAfter);
+        } else if (testValues.opset_version == 4) {
+            ngraph::op::v4::Interpolate::InterpolateAttrs interp4Attrs;
+            interp4Attrs.mode = testValues.interp4Attrs.mode;
+            interp4Attrs.coordinate_transformation_mode = testValues.interp4Attrs.coordinate_transformation_mode;
+            interp4Attrs.pads_begin = testValues.interp4Attrs.pads_begin;
+            interp4Attrs.pads_end = testValues.interp4Attrs.pads_end;
+
+            actualFunction = ngraph::builder::subgraph::InterpolateFunction::getOriginal(
+                testValues.inputShape,
+                testValues.outputShape,
+                testValues.scalesShape,
+                interp4Attrs,
+                testValues.actual.precisionBeforeDequantization,
+                testValues.actual.dequantization);
+
+            SimpleLowPrecisionTransformer transformer;
+            transformer.add<ngraph::pass::low_precision::InterpolateTransformation, ngraph::opset4::Interpolate>(testValues.params);
+            transformer.transform(actualFunction);
+
+            referenceFunction = ngraph::builder::subgraph::InterpolateFunction::getReference(
+                testValues.inputShape,
+                testValues.outputShape,
+                testValues.scalesShape,
+                interp4Attrs,
+                testValues.expected.precisionBeforeDequantization,
+                testValues.expected.dequantizationBefore,
+                testValues.expected.precisionAfterOperation,
+                testValues.expected.dequantizationAfter);
+        }
     }
 
     static std::string getTestCaseName(testing::TestParamInfo<InterpolateTransformationTestValues> obj) {
         const InterpolateTransformationTestValues testValues = obj.param;
 
         std::ostringstream result;
-        result <<
+        if (testValues.opset_version == 1) {
+            result <<
             testValues.inputShape << "_" <<
             testValues.outputShape << "_" <<
-            testValues.interpAttrs.align_corners <<
-            testValues.interpAttrs.antialias <<
-            testValues.interpAttrs.axes <<
-            testValues.interpAttrs.mode <<
-            testValues.interpAttrs.pads_begin <<
-            testValues.interpAttrs.pads_end <<
+            testValues.opset_version << "_" <<
+            testValues.interpAttrs.align_corners << "_" <<
+            testValues.interpAttrs.antialias << "_" <<
+            testValues.interpAttrs.axes << "_" <<
+            testValues.interpAttrs.mode << "_" <<
+            testValues.interpAttrs.pads_begin << "_" <<
+            testValues.interpAttrs.pads_end << "_" <<
             testValues.actual.precisionBeforeDequantization << "_" <<
             testValues.actual.dequantization << "_" <<
             testValues.expected.dequantizationBefore;
+        } else if (testValues.opset_version == 4) {
+            result <<
+            testValues.inputShape << "_" <<
+            testValues.outputShape << "_" <<
+            testValues.opset_version << "_" <<
+            testValues.interp4Attrs.mode << "_" <<
+            testValues.interp4Attrs.coordinate_transformation_mode << "_" <<
+            testValues.interp4Attrs.pads_begin << "_" <<
+            testValues.interp4Attrs.pads_end << "_" <<
+            testValues.actual.precisionBeforeDequantization << "_" <<
+            testValues.actual.dequantization << "_" <<
+            testValues.expected.dequantizationBefore;
+        }
         return result.str();
     }
 };
 
 const std::vector<InterpolateTransformationTestValues> testValues {
+    // opset1
     // nearest mode - move dequantization
     {
         ngraph::Shape{ 1, 4, 16, 16 },
         ngraph::Shape{ 1, 4, 32, 32 },
+        ngraph::Shape{},
         LayerTransformation::createParamsU8I8(),
         interpAttributes(
             ngraph::AxisSet{2, 3},
@@ -147,6 +216,8 @@ const std::vector<InterpolateTransformationTestValues> testValues {
             false,
             {0},
             {0}),
+        interp4Attributes(),
+        1,
         {
             ngraph::element::u8,
             {{ngraph::element::f32}, {-0.32f}, {0.1f}}
@@ -163,6 +234,7 @@ const std::vector<InterpolateTransformationTestValues> testValues {
     {
         ngraph::Shape{ 1, 4, 16, 16 },
         ngraph::Shape{ 1, 4, 32, 32 },
+        ngraph::Shape{},
         LayerTransformation::createParamsU8I8(),
         interpAttributes(
             ngraph::AxisSet{2, 3},
@@ -171,6 +243,8 @@ const std::vector<InterpolateTransformationTestValues> testValues {
             false,
             {0},
             {0}),
+        interp4Attributes(),
+        1,
         {
             ngraph::element::u8,
             {{ngraph::element::f32}, {-0.32f}, {0.1f}}
@@ -187,6 +261,7 @@ const std::vector<InterpolateTransformationTestValues> testValues {
     {
         ngraph::Shape{ 1, 4, 16, 16 },
         ngraph::Shape{ 1, 8, 32, 32 },
+        ngraph::Shape{},
         LayerTransformation::createParamsU8I8(),
         interpAttributes(
             ngraph::AxisSet{1, 2, 3},
@@ -195,6 +270,8 @@ const std::vector<InterpolateTransformationTestValues> testValues {
             false,
             {0},
             {0}),
+        interp4Attributes(),
+        1,
         {
             ngraph::element::u8,
             {{ngraph::element::f32}, {-0.32f}, {0.1f}}
@@ -211,6 +288,7 @@ const std::vector<InterpolateTransformationTestValues> testValues {
     {
         ngraph::Shape{ 1, 4, 16, 16 },
         ngraph::Shape{ 1, 4, 32, 32 },
+        ngraph::Shape{},
         LayerTransformation::createParamsU8I8(),
         interpAttributes(
             ngraph::AxisSet{2, 3},
@@ -219,6 +297,8 @@ const std::vector<InterpolateTransformationTestValues> testValues {
             false,
             {0},
             {0}),
+        interp4Attributes(),
+        1,
         {
             ngraph::element::u8,
             {{ngraph::element::f32}, {-0.32f}, {0.1f}}
@@ -235,6 +315,7 @@ const std::vector<InterpolateTransformationTestValues> testValues {
     {
         ngraph::Shape{ 1, 4, 16, 16 },
         ngraph::Shape{ 1, 4, 32, 32 },
+        ngraph::Shape{},
         LayerTransformation::createParamsU8I8(),
         interpAttributes(
             ngraph::AxisSet{2, 3},
@@ -243,6 +324,8 @@ const std::vector<InterpolateTransformationTestValues> testValues {
             false,
             {1},
             {1}),
+        interp4Attributes(),
+        1,
         {
             ngraph::element::u8,
             {{ngraph::element::f32}, {-0.32f}, {0.1f}}
@@ -253,7 +336,108 @@ const std::vector<InterpolateTransformationTestValues> testValues {
             ngraph::element::u8,
             {{}, {}, {}}
         }
-    }
+    },
+
+    // v4::Interpolate
+    // nearest mode - move dequantization
+    {
+        ngraph::Shape{ 1, 4, 16, 16 },
+        ngraph::Shape{ 1, 4, 32, 32 },
+        ngraph::Shape{ 1, 1, 2, 2 },
+        LayerTransformation::createParamsU8I8(),
+        interpAttributes(),
+        interp4Attributes(
+            ngraph::op::v4::Interpolate::InterpolateMode::nearest,
+            ngraph::op::v4::Interpolate::CoordinateTransformMode::half_pixel,
+            {0, 0, 0, 0},
+            {0, 0, 0, 0}),
+        4,
+        {
+            ngraph::element::i8,
+            {{ngraph::element::f32}, {-0.32f}, {0.1f}}
+        },
+        {
+            ngraph::element::i8,
+            {{}, {}, {}},
+            ngraph::element::i8,
+            {{ngraph::element::f32}, {-0.32f}, {0.1f}}
+        }
+    },
+
+    // mode is not nearest - not transformed
+    {
+        ngraph::Shape{ 1, 4, 16, 16 },
+        ngraph::Shape{ 1, 4, 32, 32 },
+        ngraph::Shape{ 1, 1, 2, 2 },
+        LayerTransformation::createParamsU8I8(),
+        interpAttributes(),
+        interp4Attributes(
+            ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx,
+            ngraph::op::v4::Interpolate::CoordinateTransformMode::half_pixel,
+            {0, 0, 0, 0},
+            {0, 0, 0, 0}),
+        4,
+        {
+            ngraph::element::i8,
+            {{ngraph::element::f32}, {-0.32f}, {0.1f}}
+        },
+        {
+            ngraph::element::i8,
+            {{ngraph::element::f32}, {-0.32f}, {0.1f}},
+            ngraph::element::i8,
+            {{}, {}, {}}
+        }
+    },
+
+    // align_corners set to true - not transformed
+    {
+        ngraph::Shape{ 1, 4, 16, 16 },
+        ngraph::Shape{ 1, 4, 32, 32 },
+        ngraph::Shape{ 1, 1, 2, 2 },
+        LayerTransformation::createParamsU8I8(),
+        interpAttributes(),
+        interp4Attributes(
+            ngraph::op::v4::Interpolate::InterpolateMode::nearest,
+            ngraph::op::v4::Interpolate::CoordinateTransformMode::align_corners,
+            {0, 0, 0, 0},
+            {0, 0, 0, 0}),
+        4,
+        {
+            ngraph::element::i8,
+            {{ngraph::element::f32}, {-0.32f}, {0.1f}}
+        },
+        {
+            ngraph::element::i8,
+            {{ngraph::element::f32}, {-0.32f}, {0.1f}},
+            ngraph::element::i8,
+            {{}, {}, {}}
+        }
+    },
+
+    // have pads - not transformed
+    {
+        ngraph::Shape{ 1, 4, 16, 16 },
+        ngraph::Shape{ 1, 4, 32, 32 },
+        ngraph::Shape{ 1, 1, 2, 2 },
+        LayerTransformation::createParamsU8I8(),
+        interpAttributes(),
+        interp4Attributes(
+            ngraph::op::v4::Interpolate::InterpolateMode::nearest,
+            ngraph::op::v4::Interpolate::CoordinateTransformMode::half_pixel,
+            {0, 0, 0, 1},
+            {0, 0, 1, 0}),
+        4,
+        {
+            ngraph::element::i8,
+            {{ngraph::element::f32}, {-0.32f}, {0.1f}}
+        },
+        {
+            ngraph::element::i8,
+            {{ngraph::element::f32}, {-0.32f}, {0.1f}},
+            ngraph::element::i8,
+            {{}, {}, {}}
+        }
+    },
 };
 
 TEST_P(InterpolateTransformation, CompareFunctions) {
diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp
new file mode 100644 (file)
index 0000000..1ecd6ca
--- /dev/null
@@ -0,0 +1,112 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset4.hpp>
+#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+#include <ngraph/pass/manager.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+using namespace ngraph;
+
+TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_node = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 4, 30, 30});
+        auto out_shape_node = opset1::Constant::create(element::i32, Shape{4}, {2, 4, 40, 40});
+
+        auto interpolate1_attr = op::v0::InterpolateAttrs();
+        interpolate1_attr.axes = AxisSet(std::vector<size_t>{0, 1, 2, 3});
+        interpolate1_attr.mode = "nearest";
+        interpolate1_attr.align_corners = false;
+        interpolate1_attr.antialias = false;
+        interpolate1_attr.pads_begin = std::vector<size_t>{0, 0, 0, 0};
+        interpolate1_attr.pads_end = std::vector<size_t>{0, 0, 0, 0};
+
+        auto interpolate1 = std::make_shared<opset1::Interpolate>(data_node, out_shape_node, interpolate1_attr);
+
+        f = std::make_shared<Function>(NodeVector{interpolate1}, ParameterVector{data_node});
+
+        pass::Manager m;
+        m.register_pass<pass::InitNodeInfo>();
+        m.register_pass<pass::ConvertInterpolate1ToInterpolate4>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto data_node = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 4, 30, 30});
+        auto out_shape_node = opset1::Constant::create(element::i32, Shape{4}, {2, 4, 40, 40});
+        auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{4}, {1.f, 1.f, 1.f, 1.f});
+        auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{4}, {0, 1, 2, 3});
+
+        auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest,
+            opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
+            opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::floor,
+            false, -0.75);
+
+        auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);
+
+        f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertInterpolate1ToInterpolate4_1) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data_node = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 4, 30, 30});
+        auto out_shape_node = opset1::Constant::create(element::i32, Shape{2}, {40, 40});
+
+        auto interpolate1_attr = op::v0::InterpolateAttrs();
+        interpolate1_attr.axes = AxisSet(std::vector<size_t>{2, 3});
+        interpolate1_attr.mode = "linear";
+        interpolate1_attr.align_corners = false;
+        interpolate1_attr.antialias = true;
+        interpolate1_attr.pads_begin = std::vector<size_t>{0, 0, 0, 0};
+        interpolate1_attr.pads_end = std::vector<size_t>{0, 0, 0, 0};
+
+        auto interpolate1 = std::make_shared<opset1::Interpolate>(data_node, out_shape_node, interpolate1_attr);
+
+        f = std::make_shared<Function>(NodeVector{interpolate1}, ParameterVector{data_node});
+
+        pass::Manager m;
+        m.register_pass<pass::InitNodeInfo>();
+        m.register_pass<pass::ConvertInterpolate1ToInterpolate4>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto data_node = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 4, 30, 30});
+        auto out_shape_node = opset1::Constant::create(element::i32, Shape{2}, {40, 40});
+        auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {1.f, 1.f});
+        auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3});
+
+        auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear_onnx,
+            opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
+            opset4::Interpolate::CoordinateTransformMode::align_corners, opset4::Interpolate::NearestMode::floor,
+            false, -0.75);
+
+        auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);
+
+        f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
index 75b79f1..a2fceb3 100644 (file)
@@ -17,11 +17,11 @@ const std::vector<InferenceEngine::Precision> netPrecisions = {
 };
 
 const std::vector<std::vector<size_t>> inShapes = {
-        {1, 1, 30, 30},
+        {1, 4, 30, 30},
 };
 
 const std::vector<std::vector<size_t>> targetShapes = {
-        {1, 1, 40, 40},
+        {1, 4, 40, 40},
 };
 
 const  std::vector<ngraph::op::v4::Interpolate::InterpolateMode> modesWithoutNearest = {
@@ -130,4 +130,60 @@ INSTANTIATE_TEST_CASE_P(smoke_Interpolate_Nearest, InterpolateLayerTest, ::testi
         ::testing::Values(CommonTestUtils::DEVICE_CPU)),
     InterpolateLayerTest::getTestCaseName);
 
+const std::vector<std::vector<size_t>> targetShapesTailTest = {
+        {1, 4, 10, 41},  // 10 * 41 is not multipler of 4, cover tail process code path
+};
+
+const std::vector<std::vector<float>> defaultScalesTailTest = {
+    {0.33333f, 1.36666f}
+};
+
+const auto interpolateCasesWithoutNearestTail = ::testing::Combine(
+        ::testing::ValuesIn(modesWithoutNearest),
+        ::testing::ValuesIn(shapeCalculationMode),
+        ::testing::ValuesIn(coordinateTransformModes),
+        ::testing::ValuesIn(defaultNearestMode),
+        ::testing::ValuesIn(antialias),
+        ::testing::ValuesIn(pads),
+        ::testing::ValuesIn(pads),
+        ::testing::ValuesIn(cubeCoefs),
+        ::testing::ValuesIn(defaultAxes),
+        ::testing::ValuesIn(defaultScalesTailTest));
+
+const auto interpolateCasesTail = ::testing::Combine(
+        ::testing::ValuesIn(nearestMode),
+        ::testing::ValuesIn(shapeCalculationMode),
+        ::testing::ValuesIn(coordinateTransformModes),
+        ::testing::ValuesIn(nearestModes),
+        ::testing::ValuesIn(antialias),
+        ::testing::ValuesIn(pads),
+        ::testing::ValuesIn(pads),
+        ::testing::ValuesIn(cubeCoefs),
+        ::testing::ValuesIn(defaultAxes),
+        ::testing::ValuesIn(defaultScalesTailTest));
+
+INSTANTIATE_TEST_CASE_P(smoke_Interpolate_Basic_2, InterpolateLayerTest, ::testing::Combine(
+        interpolateCasesWithoutNearestTail,
+        ::testing::ValuesIn(netPrecisions),
+        ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
+        ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
+        ::testing::Values(InferenceEngine::Layout::ANY),
+        ::testing::Values(InferenceEngine::Layout::ANY),
+        ::testing::ValuesIn(inShapes),
+        ::testing::ValuesIn(targetShapesTailTest),
+        ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+    InterpolateLayerTest::getTestCaseName);
+
+INSTANTIATE_TEST_CASE_P(smoke_Interpolate_Nearest_2, InterpolateLayerTest, ::testing::Combine(
+        interpolateCasesTail,
+        ::testing::ValuesIn(netPrecisions),
+        ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
+        ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
+        ::testing::Values(InferenceEngine::Layout::ANY),
+        ::testing::Values(InferenceEngine::Layout::ANY),
+        ::testing::ValuesIn(inShapes),
+        ::testing::ValuesIn(targetShapesTailTest),
+        ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+    InterpolateLayerTest::getTestCaseName);
+
 } // namespace
index 16ac595..3c3891c 100644 (file)
@@ -149,6 +149,7 @@ const std::vector<ngraph::op::v4::Interpolate::NearestMode> defNearestModes = {
 
 const std::vector<std::vector<size_t>> pads = {
         {0, 0, 0, 0},
+        {0, 0, 1, 1},
 };
 
 const std::vector<bool> antialias = {
@@ -191,6 +192,18 @@ const auto interpolateCasesLinearOnnx = ::testing::Combine(
         ::testing::ValuesIn(defaultAxes),
         ::testing::ValuesIn(defaultScales));
 
+const auto interpolateCasesCubic = ::testing::Combine(
+        ::testing::Values(ngraph::op::v4::Interpolate::InterpolateMode::cubic),
+        ::testing::ValuesIn(shapeCalculationMode),
+        ::testing::ValuesIn(coordinateTransformModes),
+        ::testing::ValuesIn(defNearestModes),
+        ::testing::ValuesIn(antialias),
+        ::testing::ValuesIn(pads),
+        ::testing::ValuesIn(pads),
+        ::testing::ValuesIn(cubeCoefs),
+        ::testing::ValuesIn(defaultAxes),
+        ::testing::ValuesIn(defaultScales));
+
 INSTANTIATE_TEST_CASE_P(smoke_InterpolateNN_Layout_Test, InterpolateLayerCPUTest,
         ::testing::Combine(
             ::testing::Combine(
@@ -200,8 +213,8 @@ INSTANTIATE_TEST_CASE_P(smoke_InterpolateNN_Layout_Test, InterpolateLayerCPUTest
                 ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
                 ::testing::Values(InferenceEngine::Layout::ANY),
                 ::testing::Values(InferenceEngine::Layout::ANY),
-                ::testing::Values(std::vector<size_t>({1, 1, 40, 40})),
-                ::testing::Values(std::vector<size_t>({1, 1, 50, 60})),
+                ::testing::Values(std::vector<size_t>({1, 21, 40, 40})),
+                ::testing::Values(std::vector<size_t>({1, 21, 50, 60})),
                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
             ::testing::ValuesIn(filterCPUInfoForDevice())),
     InterpolateLayerCPUTest::getTestCaseName);
@@ -215,8 +228,23 @@ INSTANTIATE_TEST_CASE_P(smoke_InterpolateLinearOnnx_Layout_Test, InterpolateLaye
                 ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
                 ::testing::Values(InferenceEngine::Layout::ANY),
                 ::testing::Values(InferenceEngine::Layout::ANY),
-                ::testing::Values(std::vector<size_t>({1, 1, 40, 40})),
-                ::testing::Values(std::vector<size_t>({1, 1, 50, 60})),
+                ::testing::Values(std::vector<size_t>({1, 21, 40, 40})),
+                ::testing::Values(std::vector<size_t>({1, 21, 50, 60})),
+                ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+            ::testing::ValuesIn(filterCPUInfoForDevice())),
+    InterpolateLayerCPUTest::getTestCaseName);
+
+INSTANTIATE_TEST_CASE_P(smoke_InterpolateCubic_Layout_Test, InterpolateLayerCPUTest,
+        ::testing::Combine(
+            ::testing::Combine(
+                interpolateCasesCubic,
+                ::testing::ValuesIn(netPrecisions),
+                ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
+                ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
+                ::testing::Values(InferenceEngine::Layout::ANY),
+                ::testing::Values(InferenceEngine::Layout::ANY),
+                ::testing::Values(std::vector<size_t>({1, 21, 40, 40})),
+                ::testing::Values(std::vector<size_t>({1, 21, 50, 60})),
                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
             ::testing::ValuesIn(filterCPUInfoForDevice())),
     InterpolateLayerCPUTest::getTestCaseName);
index d3aa6cc..6b45a18 100644 (file)
@@ -35,6 +35,32 @@ public:
         const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
         const ngraph::element::Type precisionAfterOperation,
         const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter);
+
+    // v4::Interpolate
+    static std::shared_ptr<ngraph::Function> getOriginal(
+        const ngraph::Shape& inputShape,
+        const ngraph::Shape& outputShape,
+        const ngraph::Shape& scalesShape,
+        const ngraph::op::v4::Interpolate::InterpolateAttrs& interp4Attrs,
+        const ngraph::element::Type precisionBeforeDequantization,
+        const ngraph::builder::subgraph::DequantizationOperations& dequantization);
+
+    static std::shared_ptr<ngraph::Function> getOriginal(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::Shape& outputShape,
+        const ngraph::Shape& scalesShape,
+        const ngraph::op::v4::Interpolate::InterpolateAttrs& interp4Attrs);
+
+    static std::shared_ptr<ngraph::Function> getReference(
+        const ngraph::Shape& inputShape,
+        const ngraph::Shape& outputShape,
+        const ngraph::Shape& scalesShape,
+        const ngraph::op::v4::Interpolate::InterpolateAttrs& interp4Attrs,
+        const ngraph::element::Type precisionBeforeDequantization,
+        const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
+        const ngraph::element::Type precisionAfterOperation,
+        const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter);
 };
 
 }  // namespace subgraph
index a71bb90..44cccef 100644 (file)
@@ -70,6 +70,72 @@ std::shared_ptr<ngraph::Function> InterpolateFunction::getReference(
     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "InterpolateFunction");
 }
 
+// v4:interpolate
+std::shared_ptr<ngraph::Function> InterpolateFunction::getOriginal(
+    const ngraph::Shape& inputShape,
+    const ngraph::Shape& outputShape,
+    const ngraph::Shape& scalesShape,
+    const ngraph::op::v4::Interpolate::InterpolateAttrs& interpAttrs,
+    const ngraph::element::Type precisionBeforeDequantization,
+    const ngraph::builder::subgraph::DequantizationOperations& dequantization) {
+    const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
+        precisionBeforeDequantization,
+        ngraph::Shape(inputShape));
+
+    const auto dequantizationOp = makeDequantization(input, dequantization);
+    const auto outShape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{ outputShape.size() }, outputShape);
+    const auto scales = std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32, ngraph::Shape{ scalesShape.size() }, scalesShape);
+    const auto interpolate = std::make_shared<ngraph::op::v4::Interpolate>(dequantizationOp, outShape, scales, interpAttrs);
+    interpolate->set_friendly_name("output");
+
+    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(interpolate) };
+    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "InterpolateFunction");
+}
+
+std::shared_ptr<ngraph::Function> InterpolateFunction::getOriginal(
+    const ngraph::element::Type precision,
+    const ngraph::Shape& inputShape,
+    const ngraph::Shape& outputShape,
+    const ngraph::Shape& scalesShape,
+    const ngraph::op::v4::Interpolate::InterpolateAttrs& interpAttrs) {
+    float k = 50.f;
+
+    const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
+    const auto fakeQuantizeOnActivations = ngraph::builder::makeFakeQuantize(
+        input, precision, 256ul, { 1ul },
+        { 0.f }, { 255.f / k }, { 10.f }, { 255.f / k });
+    const auto outShape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{ outputShape.size() }, outputShape);
+    const auto scales = std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32, ngraph::Shape{ scalesShape.size() }, scalesShape);
+    const auto interpolate = std::make_shared<ngraph::op::v4::Interpolate>(fakeQuantizeOnActivations, outShape, scales, interpAttrs);
+
+    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(interpolate) };
+    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "InterpolateFunction");
+}
+
+std::shared_ptr<ngraph::Function> InterpolateFunction::getReference(
+    const ngraph::Shape& inputShape,
+    const ngraph::Shape& outputShape,
+    const ngraph::Shape& scalesShape,
+    const ngraph::op::v4::Interpolate::InterpolateAttrs& interpAttrs,
+    const ngraph::element::Type precisionBeforeDequantization,
+    const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
+    const ngraph::element::Type precisionAfterOperation,
+    const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
+    const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
+        precisionBeforeDequantization,
+        ngraph::Shape(inputShape));
+
+    const std::shared_ptr<Node> quantizationOpBefore = makeDequantization(input, dequantizationBefore);
+    const auto outShape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{ outputShape.size() }, outputShape);
+    const auto scales = std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32, ngraph::Shape{ scalesShape.size() }, scalesShape);
+    const auto interpolate = std::make_shared<ngraph::op::v4::Interpolate>(quantizationOpBefore, outShape, scales, interpAttrs);
+    const std::shared_ptr<Node> quantizationOpAfter = makeDequantization(interpolate, dequantizationAfter);
+    quantizationOpAfter->set_friendly_name("output");
+
+    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(quantizationOpAfter) };
+    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "InterpolateFunction");
+}
+
 }  // namespace subgraph
 }  // namespace builder
 }  // namespace ngraph
index 703ffea..738b3fb 100644 (file)
@@ -152,6 +152,10 @@ namespace ngraph
                                  float length_resized,
                                  float length_original) const
                 {
+                    if (x_scale == 1.0f || (length_resized == length_original))
+                    {
+                        return x_resized;
+                    }
                     return m_func(x_resized, x_scale, length_resized, length_original);
                 }