bool fuse_type_to_topk(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
+bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_fuse {
{opset4::Parameter::type_info, fuse_type_to_parameter},
{opset4::TopK::type_info, fuse_type_to_topk},
{opset4::NonZero::type_info, fuse_type_to_nonzero},
{opset4::Bucketize::type_info, fuse_type_to_bucketize},
+ {NodeTypeInfo("GenericIE", 1), fuse_type_to_generic_ie},
};
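// Illustrative sketch (an assumption for this write-up, not part of the patch):
// how a callback registered in type_to_fuse is typically looked up and invoked
// for one node. The helper name example_fuse_node_outputs and the from/to
// parameters are hypothetical.
static bool example_fuse_node_outputs(std::shared_ptr<ngraph::Node> & node,
                                      ngraph::element::Type from, ngraph::element::Type to) {
    bool fused = false;
    for (auto output : node->outputs()) {
        if (output.get_element_type() != from) continue;  // only touch mismatching outputs
        auto it = type_to_fuse.find(node->get_type_info());
        if (it != type_to_fuse.end()) {
            // e.g. calls fuse_type_to_topk or fuse_type_to_generic_ie for this output index
            fused |= it->second(node, to, output.get_index());
        }
    }
    return fused;
}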
bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f) {
std::map<std::shared_ptr<Node>, std::vector<Input<Node>>> const_to_internal_output;
std::function<void(const std::shared_ptr<Function> &)> register_constants =
- [&const_to_internal_output](const std::shared_ptr<Function> & f) {
+ [&const_to_internal_output, &register_constants](const std::shared_ptr<Function> & f) {
for (auto & node : f->get_ordered_ops()) {
+ // Recursively run for TensorIterator body function
+ if (auto ti = std::dynamic_pointer_cast<opset4::TensorIterator>(node)) {
+ register_constants(ti->get_body()->to_function());
+ }
for (auto & input : node->inputs()) {
if (auto const_node = std::dynamic_pointer_cast<opset4::Constant>(input.get_source_output().get_node_shared_ptr())) {
const_to_internal_output[const_node].emplace_back(input);
            }
        }
    }
};
std::function<void(const std::shared_ptr<Function> &)> convert_function_precision =
- [this, &const_to_internal_output, &convert_node_precision](const std::shared_ptr<Function> & f) {
+ [this, &const_to_internal_output, &convert_node_precision, &convert_function_precision](const std::shared_ptr<Function> & f) {
// Iterate over all nodes in topological order and then over each node's outputs.
// If an output type does not match the given type, we try to fuse the new type into
// the operation itself; otherwise we insert a Convert operation
// (an illustrative sketch of this Convert fallback follows run_on_function below).
for (auto &node : f->get_ordered_ops()) {
+ // Recursively run for TensorIterator body function
+ if (auto ti = std::dynamic_pointer_cast<opset4::TensorIterator>(node)) {
+ convert_function_precision(ti->get_body()->to_function());
+ }
convert_node_precision(node);
}
};
return false;
}
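// Illustrative sketch (assumption, not part of the patch) of the fallback path
// described in the comment inside convert_function_precision above: when no fuse
// callback handles a mismatching output, a Convert is inserted between that output
// and its consumers. The helper name example_insert_convert is hypothetical.
static void example_insert_convert(ngraph::Output<ngraph::Node> output, ngraph::element::Type to) {
    // Capture consumers first so the new Convert's own input is not rewired onto itself.
    auto consumers = output.get_target_inputs();
    auto convert = std::make_shared<ngraph::opset4::Convert>(output, to);
    for (auto input : consumers) {
        input.replace_source_output(convert->output(0));  // consumers now read the converted value
    }
}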
+bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+ node->set_output_type(idx, to, node->output(idx).get_partial_shape());
+ return true;
+}
+
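// Minimal usage sketch (assumption, not part of the patch): any handler from
// type_to_fuse can be exercised in isolation. A Parameter stands in for the node
// here purely to show the in-place output retyping contract;
// example_exercise_handler is a hypothetical name.
static bool example_exercise_handler() {
    std::shared_ptr<ngraph::Node> node =
        std::make_shared<ngraph::opset4::Parameter>(ngraph::element::i64, ngraph::Shape{1, 3});
    fuse_type_to_generic_ie(node, ngraph::element::i32, 0 /*output idx*/);
    return node->get_output_element_type(0) == ngraph::element::i32;
}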
template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant> & constant) {
using src_type = typename element_type_traits<PREC_FROM>::value_type;
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
ASSERT_FALSE(has_type<ngraph::element::Type_t::i64>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_TIBody) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+ {
+ auto X = std::make_shared<opset4::Parameter>(element::f16, Shape{2, 1, 16});
+ auto Y = std::make_shared<opset4::Parameter>(element::f16, Shape{1, 128});
+
+ auto Xi = std::make_shared<opset4::Parameter>(element::f16, Shape{1, 1, 16});
+ auto Yi = std::make_shared<opset4::Parameter>(element::f16, Shape{1, 128});
+
+ // Body
+ auto axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
+ auto squeeze = std::make_shared<opset4::Squeeze>(Xi, axis);
+
+ auto w_val = std::vector<float>(384 * 16, 0);
+ auto r_val = std::vector<float>(384 * 128, 0);
+ auto b_val = std::vector<float>(384, 0);
+ auto W = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{384, 16}, w_val);
+ auto R = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{384, 128}, r_val);
+ auto B = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{384}, b_val);
+
+ auto gru_cell = std::make_shared<opset4::GRUCell>(squeeze, Yi, W, R, B, 128);
+ auto res_1 = std::make_shared<opset4::Result>(gru_cell);
+ auto unsqueeze = std::make_shared<opset4::Unsqueeze>(gru_cell, axis);
+ auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
+ auto body = std::make_shared<opset4::TensorIterator::BodyLambda>(OutputVector{res_1, res_2},
+ ParameterVector{Xi, Yi});
+
+ auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
+ tensor_iterator->set_body(body);
+
+ tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 0);
+ tensor_iterator->set_merged_input(Yi, Y, res_1);
+
+ auto out0 = tensor_iterator->get_iter_value(res_1, -1);
+ auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0);
+
+ auto res_ti_1 = std::make_shared<opset4::Result>(tensor_iterator->output(1));
+ //auto res_ti_2 = std::make_shared<opset4::Result>(tensor_iterator->output(0));
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1},
+ ngraph::ParameterVector{X, Y});
+
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+ manager.run_passes(f);
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::i64>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(tensor_iterator->get_body()->to_function()));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::i64>(tensor_iterator->get_body()->to_function()));
+ }
}
\ No newline at end of file
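// Hedged sketch (assumption, not part of the patch): the has_type<>() helper used
// in the assertions above is not shown in this excerpt; a minimal version that
// matches its usage here could look like this.
template <ngraph::element::Type_t T>
bool has_type(const std::shared_ptr<ngraph::Function> & f) {
    for (auto & node : f->get_ordered_ops()) {
        for (auto & input : node->inputs()) {
            if (input.get_element_type() == ngraph::element::Type(T)) return true;
        }
        for (auto & output : node->outputs()) {
            if (output.get_element_type() == ngraph::element::Type(T)) return true;
        }
    }
    return false;
}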