From 004f414b89445c4b3c33e163fdcba9d0ce3edfa4 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Mon, 1 Jun 2020 13:48:06 +0500 Subject: [PATCH] Fix SparseWeightedSum transform for Wide and Deep (#698) WhereDecomposition transform is applied to Where operation in for-garbage sub-graph remained after SparseWeightedSum transform. Signed-off-by: Roman Kazantsev --- model-optimizer/extensions/front/tf/sparse_weighted_sum.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/model-optimizer/extensions/front/tf/sparse_weighted_sum.py b/model-optimizer/extensions/front/tf/sparse_weighted_sum.py index 2732212..7ea0bb0 100644 --- a/model-optimizer/extensions/front/tf/sparse_weighted_sum.py +++ b/model-optimizer/extensions/front/tf/sparse_weighted_sum.py @@ -76,9 +76,10 @@ class ExperimentalSparseWeightedSumFrontReplacer(FrontReplacementSubgraph): gather0_2 = match['gather0_2'] greaterequal0 = match['greaterequal0'] sparse_fill_empty_rows = match['sparse_fill_empty_rows'] - + where0 = match['where0'] gather = match['gather'] select = match['select'] + log.debug('Found ExperimentalSparseWeightedSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op, sparse_fill_empty_rows.name)) sparse_weighted_sum = ExperimentalSparseWeightedSum(graph, {'name': sparse_fill_empty_rows.name + '/ExperimentalSparseWeightedSum_'}).create_node() @@ -96,7 +97,7 @@ class ExperimentalSparseWeightedSumFrontReplacer(FrontReplacementSubgraph): gather.in_port(0).disconnect() select.out_port(0).get_connection().set_source(sparse_weighted_sum.out_port(0)) - graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id]) + graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id]) class ExperimentalSparseWeightedSumFrontReplacer2(FrontReplacementSubgraph): @@ -160,6 +161,8 @@ class ExperimentalSparseWeightedSumFrontReplacer2(FrontReplacementSubgraph): sparse_fill_empty_rows = match['sparse_fill_empty_rows'] gather = match['gather'] select = match['select'] + where0 = match['where0'] + log.debug('Found ExperimentalSparseWeightedSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op, sparse_fill_empty_rows.name)) sparse_weighted_sum = ExperimentalSparseWeightedSum(graph, {'name': sparse_fill_empty_rows.name + '/ExperimentalSparseWeightedSum_'}).create_node() @@ -177,4 +180,4 @@ class ExperimentalSparseWeightedSumFrontReplacer2(FrontReplacementSubgraph): gather.in_port(0).disconnect() select.out_port(0).get_connection().set_source(sparse_weighted_sum.out_port(0)) - graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id]) + graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id]) -- 2.7.4