using namespace std;
using namespace ngraph;
-constexpr NodeTypeInfo op::CTCLoss::type_info;
-
-op::CTCLoss::CTCLoss(const Output<Node>& logits,
- const Output<Node>& logit_length,
- const Output<Node>& labels,
- const Output<Node>& label_length,
- const Output<Node>& blank_index,
- const bool preprocess_collapse_repeated,
- const bool ctc_merge_repeated,
- const bool unique)
+constexpr NodeTypeInfo op::v4::CTCLoss::type_info;
+
+op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
+ const Output<Node>& logit_length,
+ const Output<Node>& labels,
+ const Output<Node>& label_length,
+ const bool preprocess_collapse_repeated,
+ const bool ctc_merge_repeated,
+ const bool unique)
+ : Op({logits, logit_length, labels, label_length})
+ , preprocess_collapse_repeated_(preprocess_collapse_repeated)
+ , ctc_merge_repeated_(ctc_merge_repeated)
+ , unique_(unique)
+{
+ constructor_validate_and_infer_types();
+}
+
+op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
+ const Output<Node>& logit_length,
+ const Output<Node>& labels,
+ const Output<Node>& label_length,
+ const Output<Node>& blank_index,
+ const bool preprocess_collapse_repeated,
+ const bool ctc_merge_repeated,
+ const bool unique)
: Op({logits, logit_length, labels, label_length, blank_index})
, preprocess_collapse_repeated_(preprocess_collapse_repeated)
, ctc_merge_repeated_(ctc_merge_repeated)
constructor_validate_and_infer_types();
}
-void op::CTCLoss::validate_and_infer_types()
+void op::v4::CTCLoss::validate_and_infer_types()
{
// check types of input tensors
const auto& logits_type = get_input_element_type(0);
const auto& logit_length_type = get_input_element_type(1);
const auto& labels_type = get_input_element_type(2);
const auto& label_length_type = get_input_element_type(3);
- const auto& blank_index_type = get_input_element_type(4);
NODE_VALIDATION_CHECK(this,
logits_type.is_real(),
"The label length type is expected to be an integer type. Got: ",
label_length_type);
- NODE_VALIDATION_CHECK(this,
- blank_index_type.is_integral_number(),
- "The blank index type is expected to be an integer type. Got: ",
- blank_index_type);
+ // check optional input type: blank index
+ if (get_input_size() == 5)
+ {
+ const auto& blank_index_type = get_input_element_type(4);
+ NODE_VALIDATION_CHECK(this,
+ blank_index_type.is_integral_number(),
+ "The blank index type is expected to be an integer type. Got: ",
+ blank_index_type);
+ }
// check ranks of input tensors
const auto& logits_pshape = get_input_partial_shape(0);
const auto& logit_length_pshape = get_input_partial_shape(1);
const auto& labels_pshape = get_input_partial_shape(2);
const auto& label_length_pshape = get_input_partial_shape(3);
- const auto& blank_index_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
logits_pshape.rank().compatible(3),
"Expected a 1D tensor for label length. Got: ",
label_length_pshape);
- NODE_VALIDATION_CHECK(this,
- blank_index_pshape.rank().compatible(0),
- "Expected a scalar for blank index. Got: ",
- blank_index_pshape);
+ // check optional input shape: blank index
+ if (get_input_size() == 5)
+ {
+ const auto& blank_index_pshape = get_input_partial_shape(4);
+ NODE_VALIDATION_CHECK(this,
+ blank_index_pshape.rank().compatible(0),
+ "Expected a scalar for blank index. Got: ",
+ blank_index_pshape);
+ }
// check shapes of input tensors
size_t batch_size = 1;
}
}
-bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
+bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_);
visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_);
return true;
}
-shared_ptr<Node> op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
+shared_ptr<Node> op::v4::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
- return make_shared<CTCLoss>(new_args.at(0),
- new_args.at(1),
- new_args.at(2),
- new_args.at(3),
- new_args.at(4),
- preprocess_collapse_repeated_,
- ctc_merge_repeated_,
- unique_);
+ if (new_args.size() == 4)
+ {
+ return make_shared<CTCLoss>(new_args.at(0),
+ new_args.at(1),
+ new_args.at(2),
+ new_args.at(3),
+ preprocess_collapse_repeated_,
+ ctc_merge_repeated_,
+ unique_);
+ }
+ else if (new_args.size() == 5)
+ {
+ return make_shared<CTCLoss>(new_args.at(0),
+ new_args.at(1),
+ new_args.at(2),
+ new_args.at(3),
+ new_args.at(4),
+ preprocess_collapse_repeated_,
+ ctc_merge_repeated_,
+ unique_);
+ }
+ else
+ {
+ throw ngraph_error("Incorrect number of arguments");
+ }
}