ExtractImagePatches MO implementation and nGraph transformation (#739)
authorAnton Chetverikov <Anton.Chetverikov@intel.com>
Wed, 19 Aug 2020 10:23:41 +0000 (13:23 +0300)
committerGitHub <noreply@github.com>
Wed, 19 Aug 2020 10:23:41 +0000 (13:23 +0300)
inference-engine/src/cldnn_engine/cldnn_engine.cpp
inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
inference-engine/src/transformations/include/transformations/convert_extract_image_patches_to_reorg_yolo.hpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/convert_extract_image_patches_to_reorg_yolo.cpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_extract_image_patches_to_reorg_yolo_test.cpp [new file with mode: 0644]
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/front/tf/extract_image_patches_ext.py
model-optimizer/extensions/ops/ExtractImagePatches.py [new file with mode: 0644]
model-optimizer/extensions/ops/ExtractImagePatches_test.py [new file with mode: 0644]
model-optimizer/mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py [new file with mode: 0644]

index 1d3dcf5..ecde611 100644 (file)
@@ -91,6 +91,7 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const In
                    std::dynamic_pointer_cast<const ::ngraph::opset3::ShuffleChannels>(node) ||
                    std::dynamic_pointer_cast<const ::ngraph::opset2::BatchToSpace>(node) ||
                    std::dynamic_pointer_cast<const ::ngraph::opset2::SpaceToBatch>(node) ||
+                   std::dynamic_pointer_cast<const ::ngraph::opset3::ExtractImagePatches>(node) ||
                    std::dynamic_pointer_cast<const ::ngraph::opset4::HSwish>(node) ||
                    std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL1>(node) ||
                    std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL2>(node);
index dea6792..707058d 100644 (file)
@@ -80,6 +80,7 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
         return std::dynamic_pointer_cast<const ngraph::opset2::Gelu>(node) ||
                std::dynamic_pointer_cast<const ngraph::opset2::BatchToSpace>(node) ||
                std::dynamic_pointer_cast<const ngraph::opset2::SpaceToBatch>(node) ||
+               std::dynamic_pointer_cast<const ngraph::opset3::ExtractImagePatches>(node) ||
                std::dynamic_pointer_cast<const ngraph::opset4::ReduceL1>(node) ||
                std::dynamic_pointer_cast<const ngraph::opset4::ReduceL2>(node) ||
                std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
diff --git a/inference-engine/src/transformations/include/transformations/convert_extract_image_patches_to_reorg_yolo.hpp b/inference-engine/src/transformations/include/transformations/convert_extract_image_patches_to_reorg_yolo.hpp
new file mode 100644 (file)
index 0000000..5a857db
--- /dev/null
@@ -0,0 +1,28 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <memory>
+
+#include <transformations_visibility.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API ConvertExtractImagePatchesToReorgYolo;
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief ConvertExtractImagePatchesToReorgYolo transformation replaces ExtractImagePatches with a ReorgYolo op.
+ */
+class ngraph::pass::ConvertExtractImagePatchesToReorgYolo : public ngraph::pass::MatcherPass {
+public:
+    ConvertExtractImagePatchesToReorgYolo();
+};
diff --git a/inference-engine/src/transformations/src/transformations/convert_extract_image_patches_to_reorg_yolo.cpp b/inference-engine/src/transformations/src/transformations/convert_extract_image_patches_to_reorg_yolo.cpp
new file mode 100644 (file)
index 0000000..4b5f112
--- /dev/null
@@ -0,0 +1,87 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/convert_extract_image_patches_to_reorg_yolo.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+ngraph::pass::ConvertExtractImagePatchesToReorgYolo::ConvertExtractImagePatchesToReorgYolo() {
+    auto image = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{1, 1, 1, 1});
+    auto eip = std::make_shared<ngraph::opset3::ExtractImagePatches>(image, ngraph::Shape{1, 1}, ngraph::Strides{1, 1}, ngraph::Shape{1, 1},
+            ngraph::op::PadType::VALID);
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+        auto &pattern_to_output = m.get_pattern_value_map();
+        auto extract_image_patches =  std::dynamic_pointer_cast<ngraph::opset3::ExtractImagePatches>(m.get_match_root());
+
+        /*
+         * In this transformation we raplace ExtractImagePatches operation to ReorgYolo operation
+         * if ExtractImagePatches operation attributes obey the following conditions:
+         *
+         * EIP.sizes = EIP.strides
+         * EIP.rates = {1, 1}
+         * EIP.PadType = VALID
+         * Spatial dimensions of input tensor must be divisible by EIP.strides
+         *
+         */
+
+        if (!extract_image_patches || m_transformation_callback(extract_image_patches)) {
+            return false;
+        }
+
+        if (extract_image_patches->get_auto_pad() != ngraph::op::PadType::VALID) {
+            return false;
+        }
+
+        if (extract_image_patches->get_strides() != extract_image_patches->get_sizes()) {
+            return false;
+        }
+
+        auto p_shape_input = extract_image_patches->get_input_partial_shape(0);
+        auto sizes = extract_image_patches->get_sizes();
+        auto strides = extract_image_patches->get_strides();
+        auto rates = extract_image_patches->get_rates();
+
+        // Check that ExtractImagePatches input have static shape and rank == 4
+        if (!p_shape_input.rank().is_static() || p_shape_input.rank().get_length() != 4) {
+            return false;
+        }
+
+        // Check that ExtractImagePatches input spatial dimensions are not dynamic
+        if (p_shape_input[2].is_dynamic() || p_shape_input[3].is_dynamic()) {
+            return false;
+        }
+
+        // Check that ExtractImagePatches input spatial dimensions are divisible by EIP.strides
+        if (p_shape_input[2].get_length() % strides[0] != 0 || p_shape_input[3].get_length() % strides[1] != 0) {
+            return false;
+        }
+
+        // Check that EIP.sizes = EIP.strides
+        if (sizes[0] != strides[0] || sizes[1] != strides[1]) {
+            return false;
+        }
+
+        // Check that EIP.rates = {1, 1}
+        if (rates[0] != 1 || rates[1] != 1) {
+            return false;
+        }
+
+        auto reorg_yolo = std::make_shared<ngraph::opset3::ReorgYolo>(extract_image_patches->input(0).get_source_output(),
+            ngraph::Strides{extract_image_patches->get_strides()});
+
+        reorg_yolo->set_friendly_name(extract_image_patches->get_friendly_name());
+        ngraph::copy_runtime_info(extract_image_patches, reorg_yolo);
+        ngraph::replace_node(extract_image_patches, reorg_yolo);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(eip, "ConvertExtractImagePatchesToReorgYolo");
+    register_matcher(m, callback);
+}
index 9d957cc..8b8873d 100644 (file)
@@ -9,6 +9,7 @@
 #include "transformations/convert_opset3_to_opset2/convert_shapeof3.hpp"
 #include "transformations/convert_opset3_to_opset2/convert_shuffle_channels3.hpp"
 #include "transformations/convert_opset3_to_opset2/convert_topk3.hpp"
+#include "transformations/convert_extract_image_patches_to_reorg_yolo.hpp"
 #include "transformations/itt.hpp"
 
 #include <memory>
@@ -26,6 +27,7 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph
     manager.register_pass<ngraph::pass::ConvertShapeOf3>();
     manager.register_pass<ngraph::pass::ConvertShuffleChannels3>();
     manager.register_pass<ngraph::pass::ConvertTopK3>();
+    manager.register_pass<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
 
     manager.set_callback(m_transformation_callback);
     manager.run_passes(f);
diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_extract_image_patches_to_reorg_yolo_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_extract_image_patches_to_reorg_yolo_test.cpp
new file mode 100644 (file)
index 0000000..5cccbd8
--- /dev/null
@@ -0,0 +1,168 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/convert_extract_image_patches_to_reorg_yolo.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, ConvertExtractImagePatchesToReorgYoloTests1) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 10, 10});
+
+        auto sizes = ngraph::Shape{5, 5};
+        auto strides = ngraph::Strides{5, 5};
+        auto rates = ngraph::Shape{1, 1};
+        ngraph::op::PadType auto_pad = ngraph::op::PadType::VALID;
+
+        auto eip = std::make_shared<ngraph::opset3::ExtractImagePatches>(input, sizes, strides, rates, auto_pad);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{eip}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 10, 10});
+        auto strides = ngraph::Strides{5, 5};
+        auto reorg_yolo = std::make_shared<ngraph::opset3::ReorgYolo>(input, strides);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reorg_yolo}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertExtractImagePatchesToReorgYoloTestsNegative1) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
+            ngraph::PartialShape{1, 3, ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic()});
+
+        auto sizes = ngraph::Shape{5, 5};
+        auto strides = ngraph::Strides{5, 5};
+        auto rates = ngraph::Shape{1, 1};
+        ngraph::op::PadType auto_pad = ngraph::op::PadType::VALID;
+
+        auto eip = std::make_shared<ngraph::opset3::ExtractImagePatches>(input, sizes, strides, rates, auto_pad);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{eip}, ngraph::ParameterVector{input});
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
+            ngraph::PartialShape{1, 3, ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic()});
+
+        auto sizes = ngraph::Shape{5, 5};
+        auto strides = ngraph::Strides{5, 5};
+        auto rates = ngraph::Shape{1, 1};
+        ngraph::op::PadType auto_pad = ngraph::op::PadType::VALID;
+
+        auto eip = std::make_shared<ngraph::opset3::ExtractImagePatches>(input, sizes, strides, rates, auto_pad);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{eip}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertExtractImagePatchesToReorgYoloTestsNegative2) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 10, 10});
+
+        auto sizes = ngraph::Shape{5, 5};
+        auto strides = ngraph::Strides{5, 5};
+        auto rates = ngraph::Shape{1, 1};
+        ngraph::op::PadType auto_pad = ngraph::op::PadType::SAME_LOWER;
+
+        auto eip = std::make_shared<ngraph::opset3::ExtractImagePatches>(input, sizes, strides, rates, auto_pad);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{eip}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 10, 10});
+
+        auto sizes = ngraph::Shape{5, 5};
+        auto strides = ngraph::Strides{5, 5};
+        auto rates = ngraph::Shape{1, 1};
+        ngraph::op::PadType auto_pad = ngraph::op::PadType::SAME_LOWER;
+
+        auto eip = std::make_shared<ngraph::opset3::ExtractImagePatches>(input, sizes, strides, rates, auto_pad);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{eip}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertExtractImagePatchesToReorgYoloTestsNegative3) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 10, 10});
+
+        auto sizes = ngraph::Shape{3, 3};
+        auto strides = ngraph::Strides{5, 5};
+        auto rates = ngraph::Shape{1, 1};
+        ngraph::op::PadType auto_pad = ngraph::op::PadType::VALID;
+
+        auto eip = std::make_shared<ngraph::opset3::ExtractImagePatches>(input, sizes, strides, rates, auto_pad);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{eip}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 10, 10});
+
+        auto sizes = ngraph::Shape{3, 3};
+        auto strides = ngraph::Strides{5, 5};
+        auto rates = ngraph::Shape{1, 1};
+        ngraph::op::PadType auto_pad = ngraph::op::PadType::VALID;
+
+        auto eip = std::make_shared<ngraph::opset3::ExtractImagePatches>(input, sizes, strides, rates, auto_pad);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{eip}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
index 7c98a0e..8e9a907 100644 (file)
@@ -606,6 +606,7 @@ extensions/ops/elementwise.py
 extensions/ops/embedding_bag.py
 extensions/ops/Enter.py
 extensions/ops/Exit.py
+extensions/ops/ExtractImagePatches.py
 extensions/ops/fake_output.py
 extensions/ops/fakequantize.py
 extensions/ops/gather.py
@@ -938,6 +939,7 @@ mo/utils/ir_reader/extenders/convert_extender.py
 mo/utils/ir_reader/extenders/deconvolution_extender.py
 mo/utils/ir_reader/extenders/deformable_convolution_extender.py
 mo/utils/ir_reader/extenders/experimental_extender.py
+mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py
 mo/utils/ir_reader/extenders/fakequantize_extender.py
 mo/utils/ir_reader/extenders/GRUCell_extender.py
 mo/utils/ir_reader/extenders/interpolate_extender.py
index 8967092..14ec62b 100644 (file)
  See the License for the specific language governing permissions and
  limitations under the License.
 """
-import numpy as np
 
-from extensions.ops.reorgyolo import ReorgYoloOp
+from extensions.ops.ExtractImagePatches import ExtractImagePatches
+from mo.front.common.partial_infer.utils import convert_tf_padding_to_str
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.extractor import FrontExtractorOp
-
+from mo.front.tf.extractors.utils import tf_int_list
 
 class ExtractImagePatchesExtractor(FrontExtractorOp):
     op = 'ExtractImagePatches'
@@ -25,8 +26,13 @@ class ExtractImagePatchesExtractor(FrontExtractorOp):
 
     @classmethod
     def extract(cls, node):
-        node['batch_dims'] = 0
-        node['channel_dims'] = 3
-        node['spatial_dims'] = [1, 2]
-        ReorgYoloOp.update_node_stat(node, {'stride': np.array(node.pb.attr['strides'].list.i[1])})
-        return cls.enabled
+
+        attrs = {
+            'spatial_dims': int64_array([1, 2]),
+            'sizes': tf_int_list(node.pb.attr['ksizes'].list),
+            'strides': tf_int_list(node.pb.attr['strides'].list),
+            'rates': tf_int_list(node.pb.attr['rates'].list),
+            'auto_pad': convert_tf_padding_to_str(node.pb.attr['padding'].s.decode()),
+        }
+
+        ExtractImagePatches.update_node_stat(node, attrs)
diff --git a/model-optimizer/extensions/ops/ExtractImagePatches.py b/model-optimizer/extensions/ops/ExtractImagePatches.py
new file mode 100644 (file)
index 0000000..69cd325
--- /dev/null
@@ -0,0 +1,82 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+
+from mo.front.common.layout import shape_for_layout, get_batch_dim, get_features_dim
+from mo.front.common.partial_infer.utils import int64_array, tf_window_op_pad_infer
+from mo.graph.graph import Node, Graph
+from mo.ops.op import Op
+
+
+class ExtractImagePatches(Op):
+    op = "ExtractImagePatches"
+
+    def __init__(self, graph: Graph, attrs: dict):
+        assert 'spatial_dims' in attrs, \
+            'ExtractImagePatches operation should have `spatial_dims` parameter set during creation'
+
+        super().__init__(graph, {
+            'type': self.op,
+            'op': self.op,
+            'version': 'opset3',
+            'infer': self.infer,
+            'in_ports_count': 1,
+            'out_ports_count': 1,
+        }, attrs)
+
+    def backend_attrs(self):
+        return [
+            ('sizes', lambda node: ','.join(map(str, node['sizes'][node.spatial_dims]))),
+            ('strides', lambda node: ','.join(map(str, node['strides'][node.spatial_dims]))),
+            ('rates', lambda node: ','.join(map(str, node['rates'][node.spatial_dims]))),
+            'auto_pad',
+        ]
+
+    @staticmethod
+    def infer(node: Node):
+        assert [port.idx for port in node.in_ports().values() if not port.disconnected()] == [0], \
+            'Wrong input nodes number for node {} with type ExtractImagePatches'.format(node.soft_get('name', node.id))
+        input_shape = node.in_port(0).data.get_shape()
+        name = node.soft_get('name', node.id)
+        assert input_shape is not None, 'Input shape is not set for node {} with type ExtractImagePatches'.format(name)
+
+        assert len(input_shape) == 4, 'ExtractImagePatches operation supports only 4D tensors'
+
+        layout = node.graph.graph['layout']
+        N = input_shape[get_batch_dim(layout, 4)]
+        C = input_shape[get_features_dim(layout, 4)]
+
+        size_spatial = int64_array(node.sizes)[node.spatial_dims]
+
+        input_spatial_shape = input_shape[node.spatial_dims]
+        stride_spatial_shape = node.strides[node.spatial_dims]
+
+        size_extent = node.rates[node.spatial_dims] * (size_spatial - 1) + 1
+
+        pad_spatial_shape, output_spatial_shape = tf_window_op_pad_infer(input_spatial_shape,
+                                                                         size_extent,
+                                                                         stride_spatial_shape,
+                                                                         node.auto_pad,
+                                                                         False)
+
+        out_shape = shape_for_layout(layout,
+                                     batch=N,
+                                     features=C * np.prod(size_spatial),
+                                     height=output_spatial_shape[0],
+                                     width=output_spatial_shape[1])
+
+        node.out_port(0).data.set_shape(int64_array(out_shape))
diff --git a/model-optimizer/extensions/ops/ExtractImagePatches_test.py b/model-optimizer/extensions/ops/ExtractImagePatches_test.py
new file mode 100644 (file)
index 0000000..2d23c16
--- /dev/null
@@ -0,0 +1,81 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+from generator import generator, generate
+
+from extensions.ops.ExtractImagePatches import ExtractImagePatches
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Node
+from mo.utils.unittest.graph import build_graph
+
+nodes = {
+    'input': {'op': 'Parameter', 'kind': 'op', 'shape': None},
+    'input_data': {'value': None, 'kind': 'data', 'shape': None},
+    'EIP': {'op': 'ExtractImagePatches', 'kind': 'op', 'sizes': None, 'strides': None, 'rates': None, 'auto_pad': None},
+    'EIP_data': {'value': None, 'kind': 'data', 'shape': None},
+    'output': {'op': 'Result', 'kind': 'op', 'shape': None},
+}
+
+edges = [
+    ('input', 'input_data'),
+    ('input_data', 'EIP'),
+    ('EIP', 'EIP_data'),
+    ('EIP_data', 'output'),
+]
+
+@generator
+class TestExtractImagePatchesPartialInfer(unittest.TestCase):
+    @generate(*[
+        ([1, 10, 10, 3], [1, 3, 3, 1], [1, 5, 5, 1], [1, 1, 1, 1], 'valid', 'NHWC', [1, 2, 2, 27]),
+        ([1, 10, 10, 3], [1, 3, 3, 1], [1, 5, 5, 1], [1, 2, 2, 1], 'valid', 'NHWC', [1, 2, 2, 27]),
+        ([1, 10, 10, 3], [1, 4, 4, 1], [1, 8, 8, 1], [1, 1, 1, 1], 'valid', 'NHWC', [1, 1, 1, 48]),
+        ([1, 10, 10, 3], [1, 4, 4, 1], [1, 8, 8, 1], [1, 1, 1, 1], 'same_upper', 'NHWC', [1, 2, 2, 48]),
+        ([1, 10, 10, 3], [1, 4, 4, 1], [1, 9, 9, 1], [1, 1, 1, 1], 'same_upper', 'NHWC', [1, 2, 2, 48]),
+        ([1, 10, 10, 3], [1, 4, 4, 1], [1, 9, 9, 1], [1, 1, 1, 1], 'same_lower', 'NHWC', [1, 2, 2, 48]),
+        ([1, 64, 64, 3], [1, 3, 3, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'valid', 'NHWC', [1, 62, 62, 27]),
+        ([1, 64, 64, 3], [1, 3, 3, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'same_upper', 'NHWC', [1, 64, 64, 27]),
+
+        ([1, 3, 10, 10], [1, 1, 3, 3], [1, 1, 5, 5], [1, 1, 1, 1], 'valid', 'NCHW', [1, 27, 2, 2]),
+        ([1, 3, 10, 10], [1, 1, 4, 4], [1, 1, 8, 8], [1, 1, 1, 1], 'valid', 'NCHW', [1, 48, 1, 1]),
+
+        ([1, 3, 10, 10], [1, 1, 4, 4], [1, 1, 9, 9], [1, 1, 1, 1], 'same_upper', 'NCHW', [1, 48, 2, 2]),
+        ([1, 3, 10, 10], [1, 1, 4, 4], [1, 1, 9, 9], [1, 1, 1, 1], 'same_lower', 'NCHW', [1, 48, 2, 2]),
+
+    ])
+
+
+    def test_eip_infer(self, input_shape, sizes, strides, rates, auto_pad, layout, output_shape):
+        graph = build_graph(
+            nodes_attrs=nodes,
+            edges=edges,
+            update_attributes={
+                'input': {'shape': int64_array(input_shape)},
+                'input_data': {'shape': int64_array(input_shape)},
+                'EIP': {'spatial_dims': int64_array([1, 2]) if layout == 'NHWC' else int64_array([2, 3]),
+                        'sizes': int64_array(sizes), 'strides': int64_array(strides), 'rates': int64_array(rates),
+                        'auto_pad': auto_pad},
+            }
+        )
+
+        graph.graph['layout'] = layout
+
+        eip_node = Node(graph, 'EIP')
+        ExtractImagePatches.infer(eip_node)
+
+        self.assertTrue(np.array_equal(eip_node.out_port(0).data.get_shape(), output_shape))
diff --git a/model-optimizer/mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py b/model-optimizer/mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py
new file mode 100644 (file)
index 0000000..73f5909
--- /dev/null
@@ -0,0 +1,31 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.graph import Node
+from mo.utils.ir_reader.extender import Extender
+
+
+class ExtractImagePatches(Extender):
+    op = 'ExtractImagePatches'
+
+    @staticmethod
+    def extend(op: Node):
+        op['sizes'] = int64_array([1, 1] + op.sizes)
+        op['strides'] = int64_array([1, 1] + op.strides)
+        op['rates'] = int64_array([1, 1] + op.rates)
+
+        op['spatial_dims'] = int64_array([2, 3])