Add CTCLoss op to nGraph Python API (#1642)
author Roman Kazantsev <roman.kazantsev@intel.com>
Thu, 6 Aug 2020 12:03:39 +0000 (15:03 +0300)
committer GitHub <noreply@github.com>
Thu, 6 Aug 2020 12:03:39 +0000 (15:03 +0300)
ngraph/python/src/ngraph/__init__.py
ngraph/python/src/ngraph/opset4/__init__.py
ngraph/python/src/ngraph/opset4/ops.py
ngraph/python/tests/test_ngraph/test_ctc_loss.py [new file with mode: 0644]

diff --git a/ngraph/python/src/ngraph/__init__.py b/ngraph/python/src/ngraph/__init__.py
index ebb8554..0945ff9 100644 (file)
@@ -51,6 +51,7 @@ from ngraph.opset4 import convolution_backprop_data
 from ngraph.opset4 import cos
 from ngraph.opset4 import cosh
 from ngraph.opset4 import ctc_greedy_decoder
+from ngraph.opset4 import ctc_loss
 from ngraph.opset4 import cum_sum
 from ngraph.opset4 import cum_sum as cumsum
 from ngraph.opset4 import deformable_convolution
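diff --git a/ngraph/python/src/ngraph/opset4/__init__.py b/ngraph/python/src/ngraph/opset4/__init__.py
index 31c89ba..c7f5d2b 100644 (file)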
@@ -40,6 +40,7 @@ from ngraph.opset1.ops import convolution_backprop_data
 from ngraph.opset1.ops import cos
 from ngraph.opset1.ops import cosh
 from ngraph.opset1.ops import ctc_greedy_decoder
+from ngraph.opset4.ops import ctc_loss
 from ngraph.opset3.ops import cum_sum
 from ngraph.opset3.ops import cum_sum as cumsum
 from ngraph.opset1.ops import deformable_convolution
diff --git a/ngraph/python/src/ngraph/opset4/ops.py b/ngraph/python/src/ngraph/opset4/ops.py
index 69c3808..f988822 100644 (file)
@@ -59,6 +59,44 @@ _get_node_factory_opset4 = partial(_get_node_factory, "opset4")
 
 
 @nameable_op
+def ctc_loss(
+    logits: NodeInput,
+    logit_length: NodeInput,
+    labels: NodeInput,
+    label_length: NodeInput,
+    blank_index: Optional[NodeInput] = None,  # optional fifth input; left out of the node inputs when None
+    preprocess_collapse_repeated: bool = False,
+    ctc_merge_repeated: bool = True,
+    unique: bool = False,
+    name: Optional[str] = None,  # not read in this body; presumably consumed by the @nameable_op decorator
+) -> Node:
+    """Return a node which performs CTCLoss.
+
+    :param logits:                        3-D tensor of logits.
+    :param logit_length:                  1-D tensor of lengths for each object from a batch.
+    :param labels:                        2-D tensor of labels for which likelihood is estimated using logits.
+    :param label_length:                  1-D tensor of length for each label sequence.
+    :param blank_index:                   Scalar used to mark a blank index.
+    :param preprocess_collapse_repeated:  Flag for preprocessing labels before loss calculation.
+    :param ctc_merge_repeated:            Flag for merging repeated characters in a potential alignment.
+    :param unique:                        Flag to find unique elements in a target.
+    :return: The new node which performs CTCLoss
+    """
+    if blank_index is not None:  # pass blank_index as a fifth input only when the caller supplied one
+        inputs = as_nodes(logits, logit_length, labels, label_length, blank_index)
+    else:  # four-input form: blank_index is not passed to the node at all
+        inputs = as_nodes(logits, logit_length, labels, label_length)
+
+    attributes = {  # the three boolean flags travel as node attributes, not as inputs
+        "preprocess_collapse_repeated": preprocess_collapse_repeated,
+        "ctc_merge_repeated": ctc_merge_repeated,
+        "unique": unique,
+    }
+
+    return _get_node_factory_opset4().create("CTCLoss", inputs, attributes)  # node is built by the opset4 factory
+
+
+@nameable_op
 def non_max_suppression(
     boxes: NodeInput,
     scores: NodeInput,
diff --git a/ngraph/python/tests/test_ngraph/test_ctc_loss.py b/ngraph/python/tests/test_ngraph/test_ctc_loss.py
new file mode 100644 (file)
index 0000000..3b53f25
--- /dev/null
@@ -0,0 +1,39 @@
+# ******************************************************************************
+# Copyright 2017-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+
+import ngraph as ng
+from ngraph.impl import Type
+
+
+def test_ctc_loss_props():  # builds a CTCLoss node from parameters and checks the properties it reports
+    ind_dtype = np.int32  # dtype shared by the length, label and blank_index parameters
+    float_dtype = np.float32  # dtype of the logits parameter
+    logits = ng.parameter([2, 100, 80], dtype=float_dtype, name="logits")
+    logit_length = ng.parameter([2], dtype=ind_dtype, name="logit_length")
+    labels = ng.parameter([2, 100], dtype=ind_dtype, name="labels")
+    label_length = ng.parameter([2], dtype=ind_dtype, name="label_length")
+    blank_index = ng.parameter([], dtype=ind_dtype, name="blank_index")  # empty shape list
+    preprocess_collapse_repeated = False
+    ctc_merge_repeated = True
+    unique = False
+
+    node = ng.ctc_loss(logits, logit_length, labels, label_length, blank_index,  # five-input form
+                       preprocess_collapse_repeated, ctc_merge_repeated, unique)  # flags passed positionally
+    assert node.get_type_name() == "CTCLoss"
+    assert node.get_output_size() == 1  # a single output is expected
+    assert list(node.get_output_shape(0)) == [2]  # equals the leading dimension of logits (2)
+    assert node.get_output_element_type(0) == Type.f32  # logits were created with np.float32