Added transformation config to support automl efficientdet models (#2894)
authorEvgeny Lazarev <evgeny.lazarev@intel.com>
Mon, 2 Nov 2020 16:21:05 +0000 (19:21 +0300)
committerGitHub <noreply@github.com>
Mon, 2 Nov 2020 16:21:05 +0000 (19:21 +0300)
* Added transformation config to support automl efficientdet-4 model

* Added configuration file to convert Automl EfficientDet model

* Updated unit test for Pack

* Added instruction on how to convert EfficientDet Tensorflow model

* Updated documentation on how to convert EfficientDet model

* Updated a documentation with instruction on how to convert Automl EfficientDet.

docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_EfficientDet_Models.md [new file with mode: 0644]
docs/doxygen/ie_docs.xml
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/front/Pack.py
model-optimizer/extensions/front/Pack_test.py
model-optimizer/extensions/front/tf/AutomlEfficientDet.py [new file with mode: 0644]
model-optimizer/extensions/front/tf/automl_efficientdet.json [new file with mode: 0644]
model-optimizer/mo/ops/unsqueeze.py

diff --git a/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_EfficientDet_Models.md b/docs/MO_DG/prepare_model/convert_model/tf_specific/Convert_EfficientDet_Models.md
new file mode 100644 (file)
index 0000000..c58de18
--- /dev/null
@@ -0,0 +1,96 @@
+# Converting EfficientDet Models from TensorFlow {#openvino_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_EfficientDet_Models}
+
+This tutorial explains how to convert detection EfficientDet\* public models to the Intermediate Representation (IR). 
+
+## <a name="efficientdet-to-ir"></a>Convert EfficientDet Model to IR
+
+On GitHub*, you can find several public versions of EfficientDet model implementation. This tutorial explains how to 
+convert models from the [https://github.com/google/automl/tree/master/efficientdet](https://github.com/google/automl/tree/master/efficientdet) 
+repository (commit 96e1fee) to IR.
+
+### Get Frozen TensorFlow\* Model
+
+Follow the instructions below to get frozen TensorFlow EfficientDet model. We use EfficientDet-D4 model as an example:
+
+1. Clone the repository:<br>
+```sh
+git clone https://github.com/google/automl
+cd automl/efficientdet
+```
+2. (Optional) Checkout to the commit that the conversion was tested on:<br>
+```sh
+git checkout 96e1fee
+```
+3. Install required dependencies:<br>
+```sh
+python3 -m pip install --upgrade pip
+python3 -m pip install -r automl/efficientdet/requirements.txt
+```
+4. Download and extract the model checkpoint [efficientdet-d4.tar.gz](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.tar.gz)
+referenced in the "Pretrained EfficientDet Checkpoints" section of the model repository:<br>
+```sh
+wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.tar.gz
+tar zxvf efficientdet-d4.tar.gz
+```
+5. Freeze the model:<br>
+```sh
+python3 model_inspect.py --runmode=saved_model --model_name=efficientdet-d4  --ckpt_path=efficientdet-d4 --saved_model_dir=savedmodeldir
+```
+As a result the frozen model file `savedmodeldir/efficientdet-d4_frozen.pb` will be generated.
+
+> **NOTE:** If you see an error `AttributeError: module 'tensorflow_core.python.keras.api._v2.keras.initializers' has no attribute 'variance_scaling'` apply the fix from the [patch](https://github.com/google/automl/pull/846).
+
+### Convert EfficientDet TensorFlow Model to the IR
+
+To generate the IR of the EfficientDet TensorFlow model, run:<br>
+```sh
+python3 $MO_ROOT/mo.py \
+--input_model savedmodeldir/efficientdet-d4_frozen.pb \
+--tensorflow_use_custom_operations_config $MO_ROOT/extensions/front/tf/automl_efficientdet.json \
+--input_shape [1,$IMAGE_SIZE,$IMAGE_SIZE,3] \
+--reverse_input_channels
+```
+
+Where `$IMAGE_SIZE` is the size that the input image of the original TensorFlow model will be resized to. Different
+EfficientDet models were trained with different input image sizes. To determine the right one refer to the `efficientdet_model_param_dict`
+dictionary in the [hparams_config.py](https://github.com/google/automl/blob/96e1fee/efficientdet/hparams_config.py#L304) file.
+The attribute `image_size` specifies the shape to be specified for the model conversion.
+
+The `tensorflow_use_custom_operations_config` command line parameter specifies the configuration json file containing hints
+to the Model Optimizer on how to convert the model and trigger transformations implemented in the 
+`$MO_ROOT/extensions/front/tf/AutomlEfficientDet.py`. The json file contains some parameters which must be changed if you
+train the model yourself and modified the `hparams_config` file or the parameters are different from the ones used for EfficientDet-D4.
+The attribute names are self-explanatory or match the name in the `hparams_config` file.
+
+> **NOTE:** The color channel order (RGB or BGR) of an input data should match the channel order of the model training dataset. If they are different, perform the `RGB<->BGR` conversion specifying the command-line parameter: `--reverse_input_channels`. Otherwise, inference results may be incorrect. For more information about the parameter, refer to **When to Reverse Input Channels** section of [Converting a Model Using General Conversion Parameters](../Converting_Model_General.md).
+
+OpenVINO&trade; toolkit provides samples that can be used to infer EfficientDet model. For more information, refer to 
+[Object Detection for SSD C++ Sample](@ref openvino_inference_engine_samples_object_detection_sample_ssd_README) and 
+[Object Detection for SSD Python Sample](@ref openvino_inference_engine_ie_bridges_python_sample_object_detection_sample_ssd_README).
+
+## <a name="efficientdet-ir-results-interpretation"></a>Interpreting Results of the TensorFlow Model and the IR
+
+The TensorFlow model produces as output a list of 7-element tuples: `[image_id, y_min, x_min, y_max, x_max, confidence, class_id]`, where:
+* `image_id` -- image batch index.
+* `y_min` -- absolute `y` coordinate of the lower left corner of the detected object.
+* `x_min` -- absolute `x` coordinate of the lower left corner of the detected object.
+* `y_max` -- absolute `y` coordinate of the upper right corner of the detected object.
+* `x_max` -- absolute `x` coordinate of the upper right corner of the detected object.
+* `confidence` -- is the confidence of the detected object.
+* `class_id` -- is the id of the detected object class counted from 1.
+
+The output of the IR is a list of 7-element tuples: `[image_id, class_id, confidence, x_min, y_min, x_max, y_max]`, where:
+* `image_id` -- image batch index.
+* `class_id` -- is the id of the detected object class counted from 0.
+* `confidence` -- is the confidence of the detected object.
+* `x_min` -- normalized `x` coordinate of the lower left corner of the detected object.
+* `y_min` -- normalized `y` coordinate of the lower left corner of the detected object.
+* `x_max` -- normalized `x` coordinate of the upper right corner of the detected object.
+* `y_max` -- normalized `y` coordinate of the upper right corner of the detected object.
+
+The first element with `image_id = -1` means end of data.
+
+---
+## See Also
+
+* [Sub-Graph Replacement in Model Optimizer](../../customize_model_optimizer/Subgraph_Replacement_Model_Optimizer.md)
index 40fd30d..996a43c 100644 (file)
@@ -22,6 +22,7 @@
                             <tab type="user" title="Converting BERT from TensorFlow" url="@ref openvino_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_BERT_From_Tensorflow"/>
                             <tab type="user" title="Convert TensorFlow* XLNet Model to the Intermediate Representation" url="@ref openvino_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_XLNet_From_Tensorflow"/>
                             <tab type="user" title="Converting TensorFlow* Wide and Deep Models from TensorFlow" url="@ref openvino_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_WideAndDeep_Family_Models"/>
+                            <tab type="user" title="Converting EfficientDet Models from TensorFlow" url="@ref openvino_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_EfficientDet_Models"/>
                         </tab>
                         <tab type="usergroup" title="Converting a MXNet* Model" url="@ref openvino_docs_MO_DG_prepare_model_convert_model_Convert_Model_From_MxNet">
                             <tab type="user" title="Converting a Style Transfer Model from MXNet" url="@ref openvino_docs_MO_DG_prepare_model_convert_model_mxnet_specific_Convert_Style_Transfer_From_MXNet"/>
index a58cf4d..847f3bd 100644 (file)
@@ -343,6 +343,8 @@ extensions/front/tf/__init__.py
 extensions/front/tf/activation_ext.py
 extensions/front/tf/argmax_ext.py
 extensions/front/tf/assign_elimination.py
+extensions/front/tf/automl_efficientdet.json
+extensions/front/tf/AutomlEfficientDet.py
 extensions/front/tf/basic_lstm_cell.py
 extensions/front/tf/batch_to_space_ext.py
 extensions/front/tf/BatchMatMul_ext.py
index 9c285b4..4070ec4 100644 (file)
 """
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node, Graph
+from mo.front.tf.graph_utils import create_op_with_const_inputs
+from mo.graph.graph import Node, Graph, rename_nodes
 from mo.ops.concat import Concat
-from mo.ops.expand_dims import ExpandDims
+from mo.ops.unsqueeze import Unsqueeze
 
 
 class Pack(FrontReplacementOp):
@@ -25,15 +26,15 @@ class Pack(FrontReplacementOp):
     enabled = True
 
     def replace_op(self, graph: Graph, node: Node):
-        out_node = Concat(graph, {'axis': node.axis, 'in_ports_count': len(node.in_ports()),
-                                  'name': node.name + '/Concat_', }).create_node()
+        out_node = Concat(graph, {'axis': node.axis, 'in_ports_count': len(node.in_ports())}).create_node()
+        pack_name = node.soft_get('name', node.id)
 
         for ind in node.in_ports():
-            expand_dims_node = ExpandDims(graph, {'expand_axis': int64_array([node.axis]),
-                                                  'name': node.name + '/ExpandDims_'}).create_node()
-            node.in_port(ind).get_connection().set_destination(expand_dims_node.in_port(0))
-            expand_dims_node.out_port(0).connect(out_node.in_port(ind))
-        # Replace edge from out port 0 of the matched node with a edge from node out_node.id with port 0.
-        # The "explicit" version of the return value is: [(out_node.id, 0)])
+            unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array([node.axis])},
+                                                         {'name': node.soft_get('name', node.id) + '/Unsqueeze'})
+            node.in_port(ind).get_connection().set_destination(unsqueeze_node.in_port(0))
+            unsqueeze_node.out_port(0).connect(out_node.in_port(ind))
+
+        rename_nodes([(node, pack_name + '/TBR'), (out_node, pack_name)])
         return [out_node.id]
 
index 663d1ce..7471b8e 100644 (file)
@@ -20,6 +20,7 @@ import numpy as np
 from generator import generator, generate
 
 from extensions.front.Pack import Pack
+from mo.front.common.partial_infer.utils import int64_array
 from mo.utils.ir_engine.compare_graphs import compare_graphs
 from mo.utils.unittest.graph import build_graph
 
@@ -32,12 +33,16 @@ nodes_attributes = {
     'pack': {'axis': None, 'type': None, 'kind': 'op', 'op': 'Pack'},
     # Test operation
     'last': {'type': None, 'value': None, 'kind': 'op', 'op': None},
-    # ExpandDims, Concat and Const operations
+    # Unsqueeze, Concat and Const operations
     'const_1': {'value': None, 'type': None, 'kind': 'op', 'op': 'Const'},
-    'ExpandDims_0': {'expand_axis': None, 'type': None, 'kind': 'op', 'op': 'ExpandDims'},
-    'ExpandDims_1': {'expand_axis': None, 'type': None, 'kind': 'op', 'op': 'ExpandDims'},
-    'ExpandDims_2': {'expand_axis': None, 'type': None, 'kind': 'op', 'op': 'ExpandDims'},
-    'ExpandDims_3': {'expand_axis': None, 'type': None, 'kind': 'op', 'op': 'ExpandDims'},
+    'Unsqueeze_0': {'type': 'Unsqueeze', 'kind': 'op', 'op': 'Unsqueeze'},
+    'Unsqueeze_1': {'type': 'Unsqueeze', 'kind': 'op', 'op': 'Unsqueeze'},
+    'Unsqueeze_2': {'type': 'Unsqueeze', 'kind': 'op', 'op': 'Unsqueeze'},
+    'Unsqueeze_3': {'type': 'Unsqueeze', 'kind': 'op', 'op': 'Unsqueeze'},
+    'Unsqueeze_0_axis': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': None, 'value': None},
+    'Unsqueeze_1_axis': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': None, 'value': None},
+    'Unsqueeze_2_axis': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': None, 'value': None},
+    'Unsqueeze_3_axis': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': None, 'value': None},
     'concat_1': {'axis': None, 'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
 }
 
@@ -65,15 +70,17 @@ class PackTest(unittest.TestCase):
         graph_ref_edges = []
         for i in range(num_inputs - num_placeholders + 1):
             for j in range(num_placeholders):
-                graph_ref_edges.append(('placeholder_{}'.format(j), 'ExpandDims_{}'.format(i + j)))
-                graph_ref_edges.append(('ExpandDims_{}'.format(i + j), 'concat_1'))
+                graph_ref_edges.append(('placeholder_{}'.format(j), 'Unsqueeze_{}'.format(i + j)))
+                graph_ref_edges.append(('Unsqueeze_{}'.format(i + j), 'concat_1'))
         graph_ref_edges.append(('concat_1', 'last'))
 
         update_graph_ref_attributes = {}
         for i in range(num_placeholders):
             update_graph_ref_attributes['placeholder_{}'.format(i)] = {'shape': np.array([1, 227, 227, 3])}
         for i in range(num_inputs):
-            update_graph_ref_attributes['ExpandDims_{}'.format(i)] = {'expand_axis': np.array([axis])}
+            graph_ref_edges.append(('Unsqueeze_{}_axis'.format(i), 'Unsqueeze_{}'.format(i)))
+            update_graph_ref_attributes['Unsqueeze_{}_axis'.format(i)] = {'shape': int64_array([1]),
+                                                                          'value': int64_array([axis])}
         update_graph_ref_attributes['concat_1'] = {'axis': axis}
 
         graph_ref = build_graph(nodes_attributes, graph_ref_edges, update_graph_ref_attributes,
diff --git a/model-optimizer/extensions/front/tf/AutomlEfficientDet.py b/model-optimizer/extensions/front/tf/AutomlEfficientDet.py
new file mode 100644 (file)
index 0000000..f9af872
--- /dev/null
@@ -0,0 +1,140 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+
+from extensions.front.Pack import Pack
+from extensions.front.TransposeOrderNormalizer import TransposeOrderNormalizer
+from extensions.front.eltwise_n import EltwiseNReplacement
+from extensions.front.tf.pad_tf_to_pad import PadTFToPad
+from extensions.ops.DetectionOutput import DetectionOutput
+from extensions.ops.activation_ops import Sigmoid
+from extensions.ops.priorbox_clustered import PriorBoxClusteredOp
+from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.replacement import FrontReplacementFromConfigFileGeneral
+from mo.graph.graph import Graph, Node
+from mo.middle.passes.convert_data_type import data_type_str_to_np
+from mo.ops.concat import Concat
+from mo.ops.const import Const
+from mo.ops.reshape import Reshape
+from mo.ops.result import Result
+
+
+class EfficientDet(FrontReplacementFromConfigFileGeneral):
+    replacement_id = 'AutomlEfficientDet'
+
+    def run_before(self):
+        from extensions.front.ExpandDimsToUnsqueeze import ExpandDimsToUnsqueeze
+        return [ExpandDimsToUnsqueeze, Pack, TransposeOrderNormalizer, PadTFToPad, EltwiseNReplacement]
+
+    class AnchorGenerator:
+        def __init__(self, min_level, aspect_ratios, num_scales, anchor_scale):
+            self.min_level = min_level
+            self.aspect_ratios = aspect_ratios
+            self.anchor_scale = anchor_scale
+            self.scales = [2 ** (float(s) / num_scales) for s in range(num_scales)]
+
+        def get(self, layer_id):
+            widths = []
+            heights = []
+            for s in self.scales:
+                for a in self.aspect_ratios:
+                    base_anchor_size = 2 ** (self.min_level + layer_id) * self.anchor_scale
+                    heights.append(base_anchor_size * s * a[1])
+                    widths.append(base_anchor_size * s * a[0])
+            return widths, heights
+
+    def transform_graph(self, graph: Graph, replacement_descriptions: dict):
+        parameter_node = graph.get_op_nodes(op='Parameter')[0]
+        parameter_node['data_type'] = data_type_str_to_np(parameter_node.graph.graph['cmd_params'].data_type)
+        parameter_node.out_port(0).disconnect()
+
+        # remove existing Result operations to remove unsupported sub-graph
+        graph.remove_nodes_from([node.id for node in graph.get_op_nodes(op='Result')] + ['detections'])
+
+        # determine if the op which is a input/final result of mean value and scale applying to the input tensor
+        # then connect it to the input of the first convolution of the model, so we remove the image pre-processing
+        # which includes padding and resizing from the model
+        preprocessing_input_node_id = replacement_descriptions['preprocessing_input_node']
+        assert preprocessing_input_node_id in graph.nodes, 'The node with name "{}" is not found in the graph. This ' \
+                                                           'node should provide scaled image output and is specified' \
+                                                           ' in the json file.'.format(preprocessing_input_node_id)
+        preprocessing_input_node = Node(graph, preprocessing_input_node_id)
+        preprocessing_input_node.in_port(0).get_connection().set_source(parameter_node.out_port(0))
+
+        preprocessing_output_node_id = replacement_descriptions['preprocessing_output_node']
+        assert preprocessing_output_node_id in graph.nodes, 'The node with name "{}" is not found in the graph. This ' \
+                                                            'node should provide scaled image output and is specified' \
+                                                            ' in the json file.'.format(preprocessing_output_node_id)
+        preprocessing_output_node = Node(graph, preprocessing_output_node_id)
+        preprocessing_output_node.out_port(0).disconnect()
+
+        convolution_nodes = [n for n in graph.pseudo_topological_sort() if n.soft_get('type') == 'Convolution']
+        convolution_nodes[0].in_port(0).get_connection().set_source(preprocessing_output_node.out_port(0))
+
+        # create prior boxes (anchors) generator
+        aspect_ratios = replacement_descriptions['aspect_ratios']
+        assert len(aspect_ratios) % 2 == 0
+        aspect_ratios = list(zip(aspect_ratios[::2], aspect_ratios[1::2]))
+        priors_generator = self.AnchorGenerator(min_level=int(replacement_descriptions['min_level']),
+                                                aspect_ratios=aspect_ratios,
+                                                num_scales=int(replacement_descriptions['num_scales']),
+                                                anchor_scale=replacement_descriptions['anchor_scale'])
+
+        prior_boxes = []
+        for i in range(100):
+            inp_name = 'box_net/box-predict{}/BiasAdd'.format('_%d' % i if i else '')
+            if inp_name not in graph:
+                break
+            widths, heights = priors_generator.get(i)
+            prior_box_op = PriorBoxClusteredOp(graph, {'width': np.array(widths),
+                                                       'height': np.array(heights),
+                                                       'clip': 0, 'flip': 0,
+                                                       'variance': replacement_descriptions['variance'],
+                                                       'offset': 0.5})
+            prior_boxes.append(prior_box_op.create_node([Node(graph, inp_name), parameter_node]))
+
+        # concatenate prior box operations
+        concat_prior_boxes = Concat(graph, {'axis': -1}).create_node()
+        for idx, node in enumerate(prior_boxes):
+            concat_prior_boxes.add_input_port(idx)
+            concat_prior_boxes.in_port(idx).connect(node.out_port(0))
+
+        conf = Sigmoid(graph, dict(name='concat/sigmoid')).create_node([Node(graph, 'concat')])
+        reshape_size_node = Const(graph, {'value': int64_array([0, -1])}).create_node([])
+        logits = Reshape(graph, dict(name=conf.name + '/Flatten')).create_node([conf, reshape_size_node])
+        deltas = Reshape(graph, dict(name='concat_1/Flatten')).create_node([Node(graph, 'concat_1'), reshape_size_node])
+
+        # revert convolution boxes prediction weights from yxYX to xyXY (convolutions share weights and bias)
+        weights = Node(graph, 'box_net/box-predict/pointwise_kernel')
+        weights.value = weights.value.reshape(-1, 4)[:, [1, 0, 3, 2]].reshape(weights.shape)
+        bias = Node(graph, 'box_net/box-predict/bias')
+        bias.value = bias.value.reshape(-1, 4)[:, [1, 0, 3, 2]].reshape(bias.shape)
+
+        detection_output_node = DetectionOutput(graph, dict(
+            name='detections',
+            num_classes=int(replacement_descriptions['num_classes']),
+            share_location=1,
+            background_label_id=int(replacement_descriptions['num_classes']) + 1,
+            nms_threshold=replacement_descriptions['nms_threshold'],
+            confidence_threshold=replacement_descriptions['confidence_threshold'],
+            top_k=100,
+            keep_top_k=100,
+            code_type='caffe.PriorBoxParameter.CENTER_SIZE',
+        )).create_node([deltas, logits, concat_prior_boxes])
+
+        output_op = Result(graph, dict(name='output'))
+        output_op.create_node([detection_output_node])
diff --git a/model-optimizer/extensions/front/tf/automl_efficientdet.json b/model-optimizer/extensions/front/tf/automl_efficientdet.json
new file mode 100644 (file)
index 0000000..19eb112
--- /dev/null
@@ -0,0 +1,18 @@
+[
+  {
+    "id": "AutomlEfficientDet",
+    "custom_attributes": {
+      "preprocessing_input_node": "convert_image",
+      "preprocessing_output_node": "truediv",
+      "aspect_ratios": [1.0, 1.0, 1.4, 0.7, 0.7, 1.4],
+      "variance": [1.0, 1.0, 1.0, 1.0],
+      "min_level": 3,
+      "num_scales": 3,
+      "anchor_scale": 4.0,
+      "num_classes": 90,
+      "nms_threshold": 0.6,
+      "confidence_threshold": 0.2
+    },
+    "match_kind": "general"
+  }
+]
index 946dc7b..6ac1ed7 100644 (file)
@@ -32,14 +32,14 @@ class Unsqueeze(Op):
 
     def __init__(self, graph, attrs: dict):
         super().__init__(graph, {
-            'op': __class__.op,
-            'type': __class__.op,
+            'op': self.op,
+            'type': self.op,
             'version': 'opset1',
             'unsqueeze_dims': None,
             'reinterp_shape': True,
             'in_ports_count': 2,
             'out_ports_count': 1,
-            'infer': __class__.infer
+            'infer': self.infer
         }, attrs)
 
     @staticmethod