('variance', dict(op='ReduceMean')),
('squeeze_mean', dict(op='Squeeze')),
('squeeze_variance', dict(op='Squeeze')),
- ('fbn', dict(op='FusedBatchNorm')),
+ ('fbn', dict(op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'])),
],
edges=[
('mean', 'stop_grad', {'in': 0}),
def pattern(self):
return dict(
nodes=[
- ('op', dict(kind='op', op='FusedBatchNorm'))],
+ ('op', dict(kind='op', op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2',
+ 'FusedBatchNormV3']))],
edges=[]
)
def pattern(self):
return dict(
nodes=[
- ('op', dict(kind='op', op='FusedBatchNorm', is_training=True))],
+ ('op', dict(kind='op', op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'],
+ is_training=True))],
edges=[]
)
import unittest
import numpy as np
+from generator import generator, generate
from extensions.middle.FusedBatchNormTraining import FusedBatchNormTraining
from mo.front.common.partial_infer.utils import int64_array
}
+@generator
class FusedBatchNormTrainingTest(unittest.TestCase):
- def test_transformation(self):
+ @generate(*[
+ 'FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
+ ])
+ def test_transformation(self, op: str):
graph = build_graph(nodes_attributes,
[('placeholder', 'placeholder_data', {}),
('scale', 'scale_data'),
('batchnorm_data', 'result'),
],
{}, nodes_with_edges_only=True)
-
+ graph.nodes['batchnorm']['op'] = op
graph_ref = build_graph(nodes_attributes,
[('placeholder', 'placeholder_data', {}),
('scale', 'scale_data'),
FusedBatchNormTraining().find_and_replace_pattern(graph)
shape_inference(graph)
+ graph_ref.nodes['batchnorm']['op'] = op
+
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)