Fix for CTCLoss in NGraph (#1563)
author Roman Kazantsev <roman.kazantsev@intel.com>
Fri, 31 Jul 2020 08:57:29 +0000 (11:57 +0300)
committer GitHub <noreply@github.com>
Fri, 31 Jul 2020 08:57:29 +0000 (11:57 +0300)
The blank index is an optional input and must be handled appropriately.

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
inference-engine/src/readers/ir_reader/ie_ir_parser.cpp
ngraph/src/ngraph/op/ctc_loss.cpp
ngraph/src/ngraph/op/ctc_loss.hpp
ngraph/test/type_prop/ctc_loss.cpp

index e3ddad0..6e5c856 100644 (file)
@@ -369,7 +369,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
 
     // Check that operation in default opsets
     auto isDefaultOpSet = [](const std::string& version) -> bool {
-        for (size_t i = 1; i <= 3; i++) {
+        for (size_t i = 1; i <= 4; i++) {
             std::string opset_name = "opset" + std::to_string(i);
             if (version == opset_name)
                 return true;
index 05d6b90..d003663 100644 (file)
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::CTCLoss::type_info;
-
-op::CTCLoss::CTCLoss(const Output<Node>& logits,
-                     const Output<Node>& logit_length,
-                     const Output<Node>& labels,
-                     const Output<Node>& label_length,
-                     const Output<Node>& blank_index,
-                     const bool preprocess_collapse_repeated,
-                     const bool ctc_merge_repeated,
-                     const bool unique)
+constexpr NodeTypeInfo op::v4::CTCLoss::type_info;
+
+op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
+                         const Output<Node>& logit_length,
+                         const Output<Node>& labels,
+                         const Output<Node>& label_length,
+                         const bool preprocess_collapse_repeated,
+                         const bool ctc_merge_repeated,
+                         const bool unique)
+    : Op({logits, logit_length, labels, label_length})
+    , preprocess_collapse_repeated_(preprocess_collapse_repeated)
+    , ctc_merge_repeated_(ctc_merge_repeated)
+    , unique_(unique)
+{
+    constructor_validate_and_infer_types();
+}
+
+op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
+                         const Output<Node>& logit_length,
+                         const Output<Node>& labels,
+                         const Output<Node>& label_length,
+                         const Output<Node>& blank_index,
+                         const bool preprocess_collapse_repeated,
+                         const bool ctc_merge_repeated,
+                         const bool unique)
     : Op({logits, logit_length, labels, label_length, blank_index})
     , preprocess_collapse_repeated_(preprocess_collapse_repeated)
     , ctc_merge_repeated_(ctc_merge_repeated)
@@ -37,14 +52,13 @@ op::CTCLoss::CTCLoss(const Output<Node>& logits,
     constructor_validate_and_infer_types();
 }
 
-void op::CTCLoss::validate_and_infer_types()
+void op::v4::CTCLoss::validate_and_infer_types()
 {
     // check types of input tensors
     const auto& logits_type = get_input_element_type(0);
     const auto& logit_length_type = get_input_element_type(1);
     const auto& labels_type = get_input_element_type(2);
     const auto& label_length_type = get_input_element_type(3);
-    const auto& blank_index_type = get_input_element_type(4);
 
     NODE_VALIDATION_CHECK(this,
                           logits_type.is_real(),
@@ -66,17 +80,21 @@ void op::CTCLoss::validate_and_infer_types()
                           "The label length type is expected to be an integer type. Got: ",
                           label_length_type);
 
-    NODE_VALIDATION_CHECK(this,
-                          blank_index_type.is_integral_number(),
-                          "The blank index type is expected to be an integer type. Got: ",
-                          blank_index_type);
+    // check optional input type: blank index
+    if (get_input_size() == 5)
+    {
+        const auto& blank_index_type = get_input_element_type(4);
+        NODE_VALIDATION_CHECK(this,
+                              blank_index_type.is_integral_number(),
+                              "The blank index type is expected to be an integer type. Got: ",
+                              blank_index_type);
+    }
 
     // check ranks of input tensors
     const auto& logits_pshape = get_input_partial_shape(0);
     const auto& logit_length_pshape = get_input_partial_shape(1);
     const auto& labels_pshape = get_input_partial_shape(2);
     const auto& label_length_pshape = get_input_partial_shape(3);
-    const auto& blank_index_pshape = get_input_partial_shape(4);
 
     NODE_VALIDATION_CHECK(this,
                           logits_pshape.rank().compatible(3),
@@ -98,10 +116,15 @@ void op::CTCLoss::validate_and_infer_types()
                           "Expected a 1D tensor for label length. Got: ",
                           label_length_pshape);
 
-    NODE_VALIDATION_CHECK(this,
-                          blank_index_pshape.rank().compatible(0),
-                          "Expected a scalar for blank index. Got: ",
-                          blank_index_pshape);
+    // check optional input shape: blank index
+    if (get_input_size() == 5)
+    {
+        const auto& blank_index_pshape = get_input_partial_shape(4);
+        NODE_VALIDATION_CHECK(this,
+                              blank_index_pshape.rank().compatible(0),
+                              "Expected a scalar for blank index. Got: ",
+                              blank_index_pshape);
+    }
 
     // check shapes of input tensors
     size_t batch_size = 1;
@@ -204,7 +227,7 @@ void op::CTCLoss::validate_and_infer_types()
     }
 }
 
-bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
+bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor)
 {
     visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_);
     visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_);
@@ -212,15 +235,32 @@ bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
     return true;
 }
 
-shared_ptr<Node> op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
+shared_ptr<Node> op::v4::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
 {
     check_new_args_count(this, new_args);
-    return make_shared<CTCLoss>(new_args.at(0),
-                                new_args.at(1),
-                                new_args.at(2),
-                                new_args.at(3),
-                                new_args.at(4),
-                                preprocess_collapse_repeated_,
-                                ctc_merge_repeated_,
-                                unique_);
+    if (new_args.size() == 4)
+    {
+        return make_shared<CTCLoss>(new_args.at(0),
+                                    new_args.at(1),
+                                    new_args.at(2),
+                                    new_args.at(3),
+                                    preprocess_collapse_repeated_,
+                                    ctc_merge_repeated_,
+                                    unique_);
+    }
+    else if (new_args.size() == 5)
+    {
+        return make_shared<CTCLoss>(new_args.at(0),
+                                    new_args.at(1),
+                                    new_args.at(2),
+                                    new_args.at(3),
+                                    new_args.at(4),
+                                    preprocess_collapse_repeated_,
+                                    ctc_merge_repeated_,
+                                    unique_);
+    }
+    else
+    {
+        throw ngraph_error("Incorrect number of arguments");
+    }
 }
index 82518e7..0fa2f5c 100644 (file)
@@ -50,6 +50,14 @@ namespace ngraph
                         const Output<Node>& logit_length,
                         const Output<Node>& labels,
                         const Output<Node>& label_length,
+                        const bool preprocess_collapse_repeated = false,
+                        const bool ctc_merge_repeated = true,
+                        const bool unique = false);
+
+                CTCLoss(const Output<Node>& logits,
+                        const Output<Node>& logit_length,
+                        const Output<Node>& labels,
+                        const Output<Node>& label_length,
                         const Output<Node>& blank_index,
                         const bool preprocess_collapse_repeated = false,
                         const bool ctc_merge_repeated = true,
@@ -72,6 +80,5 @@ namespace ngraph
                 bool unique_;
             };
         }
-        using v4::CTCLoss;
     }
 }
index 07333a5..2b2cc6f 100644 (file)
@@ -39,6 +39,22 @@ TEST(type_prop, ctc_loss)
     EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
 }
 
+TEST(type_prop, ctc_loss_no_blank_index)
+{
+    // create inputs
+    auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
+    auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
+    auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
+    auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
+
+    // create CTCLoss node
+    auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length);
+
+    // check type and shape infer
+    EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
+    EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
+}
+
 TEST(type_prop, ctc_loss_output_type)
 {
     // create inputs