Implement support for opset3 EmbeddingBag ops (#546)
author Maxim Vafin <maxim.vafin@intel.com>
Mon, 8 Jun 2020 15:06:40 +0000 (18:06 +0300)
committer GitHub <noreply@github.com>
Mon, 8 Jun 2020 15:06:40 +0000 (18:06 +0300)
commit f1811ad0602fc283d7750a4910d84f14824ee449
tree 38e951490fd62d7ac78acae64fd816cf52feb8ca
parent d15548357380f43c9f33dd825f18b20b65202426

* [MO] Implement EmbeddingBag_3

* Transform dynamic sub-graph of Wide and Deep into EmbeddingSegmentsSum

- Expressed SparseWeightedSum sub-graph through EmbeddingSegmentsSum
- Removed experimental SparseWeightedSum layer
- Implemented tests for the transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix EmbeddingBag shape infer

* Fix EmbeddingSegmentsSum transformation for Wide and Deep

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix EmbeddingSegmentsSum replacer after the ports swap

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Update package_BOM.txt

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add unit tests for EmbeddingXXX shape infer

* Fix ATen resolver

* Remove deleted files from BOM

* Add opset version to embedding_bag

* Use base class for EmbeddingBag

* Fix per_sample_weights case

* Fix EmbeddingSegmentsSum transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix EmbeddingBag checks

* Fix ATen front transformation and merge conflicts

* Fix BOM

* Work around the I64 input limitation of the W&D model

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Clean up Where operation to fix the effect of the WhereDecomposition transform

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix BOM

* Correct EmbeddingSegmentsSum transform for Wide and Deep

Add casting of segment IDs to i32 and remove the ConstToResult sub-graph.

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Update BOM with RemoveConstToResult transform

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add more comments for RemoveConstToResult transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Remove unnecessary logging from the EmbeddingSegmentsSum transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Small fixes

* Move EmbeddingBag resolving back to front phase

* Improve error messages

* Fix typo in unit tests

* Reimplement sparse_reshape middle transform

Avoid deprecated API.

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Clean up the graph after the sparse_reshape and ConstToResult transformations

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix clean-up for transformations

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix clean-up for transformation #2

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
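
The Wide and Deep change expresses the SparseWeightedSum sub-graph through EmbeddingSegmentsSum. The plain-Python sketch below shows the reduction that op performs, assuming a 2-D embedding table and the opset3 input order (emb_table, indices, segment_ids, num_segments, optional default_index, optional per_sample_weights). It is not the code in embedding_segments_sum.py. The helper `sparse_to_segment_inputs` is hypothetical and shows one plausible mapping from a COO-form sparse input (indices, values, dense_shape) to segment inputs. It takes dense_shape[0] as num_segments, which is an assumption. Plain Python ints are used throughout, so the i32 cast of segment IDs mentioned above has no counterpart here.

```python
from typing import List, Optional, Sequence, Tuple


def embedding_segments_sum(
    emb_table: Sequence[Sequence[float]],
    indices: Sequence[int],
    segment_ids: Sequence[int],
    num_segments: int,
    default_index: Optional[int] = None,
    per_sample_weights: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Reference sketch of EmbeddingSegmentsSum for a 2-D embedding table."""
    if len(indices) != len(segment_ids):
        raise ValueError("indices and segment_ids must have the same length")
    if per_sample_weights is not None and len(per_sample_weights) != len(indices):
        raise ValueError("per_sample_weights must have the same length as indices")
    emb_dim = len(emb_table[0])
    out = [[0.0] * emb_dim for _ in range(num_segments)]
    filled = [False] * num_segments
    for pos, (idx, seg) in enumerate(zip(indices, segment_ids)):
        if not 0 <= seg < num_segments:
            raise ValueError(f"segment id {seg} is outside [0, {num_segments})")
        # Without per_sample_weights every looked-up row has weight 1.
        weight = 1.0 if per_sample_weights is None else per_sample_weights[pos]
        row = emb_table[idx]
        for d in range(emb_dim):
            out[seg][d] += weight * row[d]
        filled[seg] = True
    # Segments that received no indices take the default_index row, else stay zero.
    if default_index is not None:
        for seg in range(num_segments):
            if not filled[seg]:
                out[seg] = [float(v) for v in emb_table[default_index]]
    return out


def sparse_to_segment_inputs(
    sparse_indices: Sequence[Sequence[int]],
    sparse_values: Sequence[int],
    dense_shape: Sequence[int],
) -> Tuple[List[int], List[int], int]:
    """Hypothetical mapping of a 2-D sparse tensor (COO form) to segment inputs.

    Row coordinates become segment IDs, values become embedding-table indices,
    and the row count from dense_shape becomes num_segments (an assumption).
    """
    segment_ids = [int(coord[0]) for coord in sparse_indices]
    indices = [int(v) for v in sparse_values]
    return indices, segment_ids, int(dense_shape[0])


if __name__ == "__main__":
    table = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    idx, seg, n = sparse_to_segment_inputs([[0, 0], [0, 1], [2, 0]], [0, 2, 1], [3, 2])
    print(embedding_segments_sum(table, idx, seg, n, per_sample_weights=[0.5, 1.0, 2.0]))
    # [[5.5, 7.0], [0.0, 0.0], [6.0, 8.0]]
```

Row 1 of the sparse input has no entries, so segment 1 stays zero unless default_index is supplied. In that case it is filled with that table row.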
19 files changed:
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/back/SpecialNodesFinalization.py
model-optimizer/extensions/front/ATenToEmbeddingBag.py
model-optimizer/extensions/front/ATenToEmbeddingBag_test.py
model-optimizer/extensions/front/onnx/aten_ext.py
model-optimizer/extensions/front/tf/WhereDecomposition.py
model-optimizer/extensions/front/tf/embedding_segments_sum.py [new file with mode: 0644]
model-optimizer/extensions/front/tf/embedding_segments_sum_test.py [new file with mode: 0644]
model-optimizer/extensions/front/tf/sparse_weighted_sum.py [deleted file]
model-optimizer/extensions/front/tf/sparse_weighted_sum_test.py [deleted file]
model-optimizer/extensions/middle/EmbeddingBagResolver.py [deleted file]
model-optimizer/extensions/middle/EmbeddingBagResolver_test.py [deleted file]
model-optimizer/extensions/middle/sparse_reshape.py
model-optimizer/extensions/ops/aten.py
model-optimizer/extensions/ops/embedding_bag.py
model-optimizer/extensions/ops/embedding_bag_test.py [new file with mode: 0644]
model-optimizer/extensions/ops/sparse_reshape.py
model-optimizer/extensions/ops/sparse_weighted_sum.py [deleted file]
model-optimizer/extensions/ops/sparse_weighted_sum_test.py [deleted file]
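
Several bullets above touch EmbeddingBag shape inference (embedding_bag.py, embedding_bag_test.py). The sketch below summarizes the output-shape rules as they are understood for the three opset3 ops. The batch dimension comes from the length of offsets (EmbeddingBagOffsetsSum), from the first dimension of indices (EmbeddingBagPackedSum), or from num_segments (EmbeddingSegmentsSum). The remaining dimensions are copied from emb_table. The functions are hypothetical standalone helpers, not the Model Optimizer `Op` infer API. The per_sample_weights check assumes that input must have exactly the shape of indices.

```python
from typing import List, Optional, Sequence


def _check_weights(indices_shape: Sequence[int],
                   per_sample_weights_shape: Optional[Sequence[int]]) -> None:
    # Assumed rule: per_sample_weights, when given, has exactly the shape of indices.
    if per_sample_weights_shape is not None and list(per_sample_weights_shape) != list(indices_shape):
        raise ValueError("per_sample_weights shape must match indices shape")


def embedding_bag_offsets_sum_shape(table_shape: Sequence[int],
                                    indices_shape: Sequence[int],
                                    offsets_shape: Sequence[int],
                                    per_sample_weights_shape: Optional[Sequence[int]] = None) -> List[int]:
    # 1-D indices and offsets; one output row per offset, i.e. per bag.
    if len(indices_shape) != 1 or len(offsets_shape) != 1:
        raise ValueError("indices and offsets must be 1-D")
    _check_weights(indices_shape, per_sample_weights_shape)
    return [offsets_shape[0]] + list(table_shape[1:])


def embedding_bag_packed_sum_shape(table_shape: Sequence[int],
                                   indices_shape: Sequence[int],
                                   per_sample_weights_shape: Optional[Sequence[int]] = None) -> List[int]:
    # 2-D indices of shape [batch, indices_per_bag]; one output row per bag.
    if len(indices_shape) != 2:
        raise ValueError("indices must be 2-D")
    _check_weights(indices_shape, per_sample_weights_shape)
    return [indices_shape[0]] + list(table_shape[1:])


def embedding_segments_sum_shape(table_shape: Sequence[int],
                                 indices_shape: Sequence[int],
                                 segment_ids_shape: Sequence[int],
                                 num_segments: int,
                                 per_sample_weights_shape: Optional[Sequence[int]] = None) -> List[int]:
    # 1-D indices with one segment ID per index; num_segments must be a known value.
    if len(indices_shape) != 1 or list(segment_ids_shape) != list(indices_shape):
        raise ValueError("indices and segment_ids must be 1-D with equal length")
    _check_weights(indices_shape, per_sample_weights_shape)
    return [num_segments] + list(table_shape[1:])


if __name__ == "__main__":
    print(embedding_bag_offsets_sum_shape([10, 4], [7], [3]))      # [3, 4]
    print(embedding_bag_packed_sum_shape([10, 4], [3, 2]))         # [3, 4]
    print(embedding_segments_sum_shape([10, 4, 2], [7], [7], 5))   # [5, 4, 2]
```

For EmbeddingSegmentsSum, the output shape depends on the value of num_segments, not just on input shapes. That is why a shape-infer step needs that input to be constant (or otherwise known).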