Fixed transformations looking for FusedBatchNorm operation to look for FBNV2 and...
authorEvgeny Lazarev <evgeny.lazarev@intel.com>
Thu, 12 Nov 2020 04:33:39 +0000 (07:33 +0300)
committerGitHub <noreply@github.com>
Thu, 12 Nov 2020 04:33:39 +0000 (07:33 +0300)
* Fixed transformations looking for FusedBatchNorm operation to consider FusedBatchNormV2 and FusedBatchNormV3 also.

* Updated unit test for FusedBatchNormTraining

* Fixed unit test

model-optimizer/extensions/front/tf/mvn.py
model-optimizer/extensions/middle/FusedBatchNormNonConstant.py
model-optimizer/extensions/middle/FusedBatchNormTraining.py
model-optimizer/extensions/middle/FusedBatchNormTraining_test.py

index adb94c8..edd401f 100644 (file)
@@ -35,7 +35,7 @@ class MVNReplacer(FrontReplacementSubgraph):
                 ('variance', dict(op='ReduceMean')),
                 ('squeeze_mean', dict(op='Squeeze')),
                 ('squeeze_variance', dict(op='Squeeze')),
-                ('fbn', dict(op='FusedBatchNorm')),
+                ('fbn', dict(op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'])),
             ],
             edges=[
                 ('mean', 'stop_grad', {'in': 0}),
index 543578b..3993d7a 100644 (file)
@@ -40,7 +40,8 @@ class FusedBatchNormNonConstant(MiddleReplacementPattern):
     def pattern(self):
         return dict(
             nodes=[
-                ('op', dict(kind='op', op='FusedBatchNorm'))],
+                ('op', dict(kind='op', op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2',
+                                                            'FusedBatchNormV3']))],
             edges=[]
         )
 
index e4acee3..3640a59 100644 (file)
@@ -44,7 +44,8 @@ class FusedBatchNormTraining(MiddleReplacementPattern):
     def pattern(self):
         return dict(
             nodes=[
-                ('op', dict(kind='op', op='FusedBatchNorm', is_training=True))],
+                ('op', dict(kind='op', op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'],
+                 is_training=True))],
             edges=[]
         )
 
index c010eab..6f8e399 100644 (file)
@@ -17,6 +17,7 @@
 import unittest
 
 import numpy as np
+from generator import generator, generate
 
 from extensions.middle.FusedBatchNormTraining import FusedBatchNormTraining
 from mo.front.common.partial_infer.utils import int64_array
@@ -74,8 +75,12 @@ nodes_attributes = {
 }
 
 
+@generator
 class FusedBatchNormTrainingTest(unittest.TestCase):
-    def test_transformation(self):
+    @generate(*[
+        'FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
+    ])
+    def test_transformation(self, op: str):
         graph = build_graph(nodes_attributes,
                             [('placeholder', 'placeholder_data', {}),
                              ('scale', 'scale_data'),
@@ -91,7 +96,7 @@ class FusedBatchNormTrainingTest(unittest.TestCase):
                              ('batchnorm_data', 'result'),
                              ],
                             {}, nodes_with_edges_only=True)
-
+        graph.nodes['batchnorm']['op'] = op
         graph_ref = build_graph(nodes_attributes,
                                 [('placeholder', 'placeholder_data', {}),
                                  ('scale', 'scale_data'),
@@ -125,6 +130,8 @@ class FusedBatchNormTrainingTest(unittest.TestCase):
         FusedBatchNormTraining().find_and_replace_pattern(graph)
         shape_inference(graph)
 
+        graph_ref.nodes['batchnorm']['op'] = op
+
         (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
         self.assertTrue(flag, resp)