Remove ngraph::Lambda class, replace TensorIterator body with ngraph::Function (...
authorIvan Tikhonov <ivan.tikhonov@intel.com>
Wed, 19 Aug 2020 04:09:32 +0000 (07:09 +0300)
committerGitHub <noreply@github.com>
Wed, 19 Aug 2020 04:09:32 +0000 (07:09 +0300)
* remove Lambda class, replace TensorIterator body with ngraph::Function

* Fix passing parameters from parent graph to subgraph

Co-authored-by: mbencer <mateusz.bencer@intel.com>
24 files changed:
inference-engine/src/inference_engine/generic_ie.cpp
inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp
inference-engine/src/plugin_api/generic_ie.hpp
inference-engine/src/readers/ir_reader/ie_ir_parser.cpp
inference-engine/src/transformations/src/transformations/convert_precision.cpp
inference-engine/src/transformations/src/transformations/tensor_iterator_transformations/apply_transformations_to_ti_body.cpp
inference-engine/src/transformations/src/transformations/tensor_iterator_transformations/unroll_tensor_iterator.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp
inference-engine/tests/functional/inference_engine/transformations/unroll_tensor_iterator_test.cpp
inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp
inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp
ngraph/core/include/ngraph/function.hpp
ngraph/core/include/ngraph/lambda.hpp [deleted file]
ngraph/core/include/ngraph/ngraph.hpp
ngraph/core/include/ngraph/node.hpp
ngraph/core/include/ngraph/op/tensor_iterator.hpp
ngraph/core/src/function.cpp
ngraph/core/src/lambda.cpp [deleted file]
ngraph/core/src/op/tensor_iterator.cpp
ngraph/frontend/onnx_import/include/onnx_import/core/graph.hpp
ngraph/frontend/onnx_import/src/core/graph.cpp
ngraph/frontend/onnx_import/src/op/loop.cpp
ngraph/python/src/pyngraph/tensor_iterator_builder.cpp
ngraph/python/src/pyngraph/tensor_iterator_builder.hpp

index 8240230..9e83e41 100644 (file)
@@ -21,7 +21,7 @@
 
 constexpr ::ngraph::NodeTypeInfo ngraph::op::GenericIE::type_info;
 
-void ngraph::op::GenericIE::addExtension(std::shared_ptr<const ngraph::Lambda> func,
+void ngraph::op::GenericIE::addExtension(std::shared_ptr<const ngraph::Function> func,
                                          const InferenceEngine::IShapeInferExtensionPtr& ext) {
     NodeVector nodes;
 
index 7c8565a..8de9a04 100644 (file)
@@ -153,7 +153,7 @@ CNNLayer::Ptr NodeConverter<ngraph::op::TensorIterator>::createLayer(const std::
     // This map will save information about data nodes
     std::map<std::string, std::vector<TensorDesc>> layer_name_to_tensor_desc;
     {
-        CNNNetwork body_net(tensor_iterator->get_body()->to_function());
+        CNNNetwork body_net(tensor_iterator->get_body());
         CNNNetwork net(InferenceEngine::details::convertFunctionToICNNNetwork(body_net.getFunction(), body_net));
         // Paranoid check for cycles
         bool res = CNNNetForestDFS(
index f74725a..078b38b 100644 (file)
@@ -98,7 +98,7 @@ public:
 
     std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
 
-    static void addExtension(std::shared_ptr<const ngraph::Lambda> func, const InferenceEngine::IShapeInferExtensionPtr& ext);
+    static void addExtension(std::shared_ptr<const ngraph::Function> func, const InferenceEngine::IShapeInferExtensionPtr& ext);
     static std::vector<InferenceEngine::IShapeInferExtensionPtr> getExtensions(std::shared_ptr<const ngraph::Function> func);
 
     const std::string& getType() const {
index 764ca79..6a3b42c 100644 (file)
@@ -560,14 +560,14 @@ std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::TensorIterator
         }
     }
 
-    // Create ngraph::Function, convert it to ngraph::BodyLambda and set it as TensorIterator body
+    // Create ngraph::Function and set it as body of TensorIterator layer
     IRParser parser(10);
     auto ngraph_function = parser.parse(node.child("body"), binStream)->getFunction();
     auto parameter_nodes = ngraph_function->get_parameters();
     auto result_nodes = ngraph_function->get_results();
     // Disabled reshape for generic operations in the TI body
     ::ngraph::op::GenericIE::DisableReshape noReshape(ngraph_function);
-    auto body = std::make_shared<ngraph::op::TensorIterator::BodyLambda>(result_nodes, parameter_nodes);
+    auto body = std::make_shared<ngraph::Function>(result_nodes, parameter_nodes);
     tensor_iterator->set_body(body);
 
     // Parse PortMap: inputs
index 250da61..830085f 100644 (file)
@@ -160,7 +160,7 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
         for (auto &node : f->get_ordered_ops()) {
             // Recursively run for TensorIterator body function
             if (auto ti = std::dynamic_pointer_cast<opset4::TensorIterator>(node)) {
-                convert_function_precision(ti->get_body()->to_function());
+                convert_function_precision(ti->get_body());
             }
             convert_node_input_precision(node);
         }
index 6eb22f1..c571aa7 100644 (file)
@@ -19,7 +19,7 @@ ngraph::pass::ApplyTransformationsToTIBody::ApplyTransformationsToTIBody(ngraph:
             return false;
         }
 
-        manager.run_passes(ti->get_body()->to_function());
+        manager.run_passes(ti->get_body());
         return true;
     };
 
index eb758da..ff53614 100644 (file)
@@ -21,7 +21,7 @@ ngraph::pass::UnrollTensorIterator::UnrollTensorIterator() : MatcherPass() {
             return false;
         }
 
-        const auto function = ti->get_body()->to_function();
+        const auto function = ti->get_body();
         auto num_iter = ti->get_num_iterations();
 
         // negative value means inconsistent TI
index 418f460..6671f34 100644 (file)
@@ -261,7 +261,7 @@ TEST(TransformationTests, ConvertPrecision_TIBody) {
         auto res_1 = std::make_shared<opset4::Result>(gru_cell);
         auto unsqueeze = std::make_shared<opset4::Unsqueeze>(gru_cell, axis);
         auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
-        auto body = std::make_shared<opset4::TensorIterator::BodyLambda>(OutputVector{res_1, res_2},
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
                                                                          ParameterVector{Xi, Yi});
 
         auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
@@ -285,8 +285,8 @@ TEST(TransformationTests, ConvertPrecision_TIBody) {
 
         ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
         ASSERT_FALSE(has_type<ngraph::element::Type_t::i64>(f));
-        ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(tensor_iterator->get_body()->to_function()));
-        ASSERT_FALSE(has_type<ngraph::element::Type_t::i64>(tensor_iterator->get_body()->to_function()));
+        ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(tensor_iterator->get_body()));
+        ASSERT_FALSE(has_type<ngraph::element::Type_t::i64>(tensor_iterator->get_body()));
     }
 }
 
index 7d9d0d4..d19024c 100644 (file)
@@ -46,7 +46,7 @@ TEST(TransformationTests, UnrollTensorIteratorGRUCell) {
         auto res_1 = std::make_shared<opset4::Result>(gru_cell);
         auto unsqueeze = std::make_shared<opset4::Unsqueeze>(gru_cell, axis);
         auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
-        auto body = std::make_shared<opset4::TensorIterator::BodyLambda>(OutputVector{res_1, res_2},
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
                                                                          ParameterVector{Xi, Yi});
 
         auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
@@ -128,7 +128,7 @@ TEST(TransformationTests, UnrollTensorIteratorRNNCell) {
         auto res_1 = std::make_shared<opset4::Result>(rnn_cell);
         auto unsqueeze = std::make_shared<opset4::Unsqueeze>(rnn_cell, axis);
         auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
-        auto body = std::make_shared<opset4::TensorIterator::BodyLambda>(OutputVector{res_1, res_2},
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
                                                                          ParameterVector{Xi, Yi});
 
         auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
@@ -212,7 +212,7 @@ TEST(TransformationTests, UnrollTensorIteratorLSTMCell) {
         auto res_1 = std::make_shared<opset4::Result>(lstm_cell);
         auto unsqueeze = std::make_shared<opset4::Unsqueeze>(lstm_cell, axis);
         auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
-        auto body = std::make_shared<opset4::TensorIterator::BodyLambda>(OutputVector{res_1, res_2},
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
                                                                          ParameterVector{Xi, Yi, Zi});
 
         auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
@@ -296,7 +296,7 @@ TEST(TransformationTests, UnrollTensorIteratorGRUCellSingleIteration) {
         auto res_1 = std::make_shared<opset4::Result>(gru_cell);
         auto unsqueeze = std::make_shared<opset4::Unsqueeze>(gru_cell, axis);
         auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
-        auto body = std::make_shared<opset4::TensorIterator::BodyLambda>(OutputVector{res_1, res_2},
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
                                                                          ParameterVector{Xi, Yi});
 
         auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
@@ -372,7 +372,7 @@ TEST(TransformationTests, UnrollTensorIteratorRNNCellSingleIteration) {
         auto res_1 = std::make_shared<opset4::Result>(rnn_cell);
         auto unsqueeze = std::make_shared<opset4::Unsqueeze>(rnn_cell, axis);
         auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
-        auto body = std::make_shared<opset4::TensorIterator::BodyLambda>(OutputVector{res_1, res_2},
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
                                                                          ParameterVector{Xi, Yi});
 
         auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
@@ -449,7 +449,7 @@ TEST(TransformationTests, UnrollTensorIteratorLSTMCellSingleIterationSingleItera
         auto res_1 = std::make_shared<opset4::Result>(lstm_cell);
         auto unsqueeze = std::make_shared<opset4::Unsqueeze>(lstm_cell, axis);
         auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
-        auto body = std::make_shared<opset4::TensorIterator::BodyLambda>(OutputVector{res_1, res_2},
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
                                                                          ParameterVector{Xi, Yi, Zi});
 
         auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
index 9b0c868..8e1630c 100644 (file)
@@ -80,7 +80,7 @@ void Basic_LSTM_S::SetUp() {
     auto C_o = lstm1->output(1);
 
     //TensorIterator [1, 10, 49] [1, 118], [1, 118] -> [1, 118]
-    auto body = std::make_shared<ngraph::opset1::TensorIterator::BodyLambda>(
+    auto body = std::make_shared<ngraph::Function>(
         ngraph::OutputVector{ H_o, C_o }, ngraph::ParameterVector{ X, H_t, C_t });
 
     auto tensor_iterator = std::make_shared<ngraph::opset1::TensorIterator>();
index 87d28b0..57b90b3 100644 (file)
@@ -140,7 +140,7 @@ static std::shared_ptr<ngraph::Function> makeTIwithLSTMcell(InferenceEngine::Pre
     auto constantHo = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{3}, inShape);
     auto H_o = std::make_shared<ngraph::opset1::Reshape>(LSTM_cell->output(0), constantHo, false);
     auto C_o = std::make_shared<ngraph::opset1::Reshape>(LSTM_cell->output(1), constantHo, false);
-    auto body = std::make_shared<ngraph::op::TensorIterator::BodyLambda>(
+    auto body = std::make_shared<ngraph::Function>(
             ngraph::OutputVector{H_o, C_o}, ngraph::ParameterVector{X, H_t, C_t});
 
     auto tensor_iterator = std::make_shared<ngraph::op::TensorIterator>();
index ff866cc..eaa880a 100644 (file)
@@ -23,7 +23,7 @@
 #include <string>
 #include <vector>
 
-#include "ngraph/lambda.hpp"
+#include "ngraph/ngraph_visibility.hpp"
 #include "ngraph/node.hpp"
 #include "ngraph/op/parameter.hpp"
 #include "ngraph/op/result.hpp"
@@ -31,7 +31,7 @@
 namespace ngraph
 {
     /// A user-defined function.
-    class NGRAPH_API Function : public Lambda
+    class NGRAPH_API Function
     {
     public:
         static constexpr DiscreteTypeInfo type_info{"Function", 0};
@@ -120,6 +120,24 @@ namespace ngraph
             const std::vector<std::shared_ptr<Node>>& root_nodes)>;
         void set_topological_sort(topological_sort_t);
 
+        virtual bool visit_attributes(AttributeVisitor& visitor);
+
+        /// Return the function parameters
+        const ParameterVector& get_parameters() const { return m_parameters; };
+        /// Return a list of function's outputs
+        const ResultVector& get_results() const { return m_results; };
+        /// Index for parameter, or -1
+        int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
+
+        /// Index for value or result referencing it, or -1
+        int64_t get_result_index(const Output<Node>& value) const;
+
+        /// \brief Evaluate the function on inputs, putting results in outputs.
+        /// \param outputs Tensors for the outputs to compute. One for each result
+        /// \param inputs Tensors for the inputs. One for each inputs.
+        bool evaluate(const HostTensorVector& output_tensors,
+                      const HostTensorVector& input_tensors) const;
+
     private:
         Function(const Function&) = delete;
         Function(const Function&&) = delete;
@@ -130,5 +148,22 @@ namespace ngraph
         const std::string m_unique_name;
         size_t m_placement{0};
         topological_sort_t m_topological_sorter;
+
+        ResultVector m_results;
+        ParameterVector m_parameters;
+    };
+
+    template <>
+    class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>> : public VisitorAdapter
+    {
+    public:
+        AttributeAdapter(std::shared_ptr<Function>& ref);
+
+        bool visit_attributes(AttributeVisitor& visitor) override;
+
+        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<shared_ptr<Function>>", 0};
+        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
+    protected:
+        std::shared_ptr<Function>& m_ref;
     };
 }
diff --git a/ngraph/core/include/ngraph/lambda.hpp b/ngraph/core/include/ngraph/lambda.hpp
deleted file mode 100644 (file)
index 301dbc7..0000000
+++ /dev/null
@@ -1,69 +0,0 @@
-//*****************************************************************************
-// Copyright 2017-2020 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//*****************************************************************************
-
-#pragma once
-
-#include "ngraph/ngraph_visibility.hpp"
-#include "ngraph/node.hpp"
-#include "ngraph/op/parameter.hpp"
-#include "ngraph/op/result.hpp"
-
-namespace ngraph
-{
-    class NGRAPH_API Lambda
-    {
-    public:
-        virtual ~Lambda() {}
-        static constexpr DiscreteTypeInfo type_info{"Lamdba", 0};
-        const DiscreteTypeInfo& get_type_info() const { return type_info; }
-        /// Return the function parameters
-        virtual bool visit_attributes(AttributeVisitor& visitor);
-        const ParameterVector& get_parameters() const { return m_parameters; };
-        /// Index for parameter, or -1
-        int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
-        /// Return a list of function's outputs
-        const ResultVector& get_results() const { return m_results; };
-        /// Index for value or result referencing it, or -1
-        int64_t get_result_index(const Output<Node>& value) const;
-        /// \brief Evaluate the lambda on inputs, putting results in outputs.
-        /// \param outputs Tensors for the outputs to compute. One for each result
-        /// \param inputs Tensors for the inputs. One for each inputs.
-        bool evaluate(const HostTensorVector& output_tensors,
-                      const HostTensorVector& input_tensors) const;
-
-    protected:
-        Lambda() = default;
-        Lambda(const ResultVector& results, const ParameterVector& parameters);
-        Lambda(const OutputVector& results, const ParameterVector& parameters);
-
-        ResultVector m_results;
-        ParameterVector m_parameters;
-    };
-
-    template <>
-    class NGRAPH_API AttributeAdapter<std::shared_ptr<Lambda>> : public VisitorAdapter
-    {
-    public:
-        AttributeAdapter(std::shared_ptr<Lambda>& ref);
-
-        bool visit_attributes(AttributeVisitor& visitor) override;
-
-        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<shared_ptr<Lambda>>", 0};
-        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
-    protected:
-        std::shared_ptr<Lambda>& m_ref;
-    };
-}
index 4d000f6..82972d1 100644 (file)
@@ -78,7 +78,6 @@ namespace ngraph
 #include "ngraph/except.hpp"
 #include "ngraph/factory.hpp"
 #include "ngraph/function.hpp"
-#include "ngraph/lambda.hpp"
 #include "ngraph/node.hpp"
 #include "ngraph/ops.hpp"
 #include "ngraph/partial_shape.hpp"
index c55de14..8a9c92a 100644 (file)
@@ -100,6 +100,7 @@ namespace ngraph
     NGRAPH_API
     NodeVector as_node_vector(const OutputVector& values);
     /// Returns a ResultVector referencing values.
+    NGRAPH_API
     ResultVector as_result_vector(const OutputVector& values);
 
     /// Alias useful for cloning
index 2debfc6..3384cbb 100644 (file)
@@ -20,7 +20,6 @@
 
 #include "ngraph/factory_adapter.hpp"
 #include "ngraph/function.hpp"
-#include "ngraph/lambda.hpp"
 #include "ngraph/op/parameter.hpp"
 #include "ngraph/op/util/fused_op.hpp"
 
@@ -47,28 +46,6 @@ namespace ngraph
                 TensorIterator() = default;
                 TensorIterator(const OutputVector& values);
 
-                class NGRAPH_API BodyLambda : public Lambda
-                {
-                public:
-                    using type_info_t = DiscreteTypeInfo;
-                    static constexpr type_info_t type_info{"BodyLamdba", 0};
-                    const type_info_t& get_type_info() const { return type_info; }
-                    BodyLambda(const OutputVector& outputs, const ParameterVector& parameters)
-                        : Lambda(outputs, parameters)
-                    {
-                    }
-                    BodyLambda(const ResultVector& results, const ParameterVector& parameters)
-                        : Lambda(results, parameters)
-                    {
-                    }
-                    BodyLambda() = default;
-                    virtual bool visit_attributes(AttributeVisitor& visitor);
-                    std::shared_ptr<Function> to_function()
-                    {
-                        return std::make_shared<Function>(get_results(), get_parameters());
-                    }
-                };
-
                 /// \brief Describes a connection between a TensorIterator input and the body.
                 class InputDescription
                 {
@@ -333,9 +310,9 @@ namespace ngraph
                     clone_with_new_inputs(const OutputVector& new_args) const override;
                 OutputVector decompose_op() const override;
                 /// \return the body of the iteration
-                std::shared_ptr<BodyLambda> get_body() const { return m_body; }
+                std::shared_ptr<Function> get_body() const { return m_body; }
                 /// \param body set the body of the iteration
-                void set_body(const std::shared_ptr<BodyLambda>& body) { m_body = body; }
+                void set_body(const std::shared_ptr<Function>& body) { m_body = body; }
                 /// \return a reference to the input descriptions.
                 const std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions() const
                 {
@@ -374,7 +351,7 @@ namespace ngraph
                 // Find an input corresponding to value, adding one if necessary.
                 Input<Node> input_for_value(const Output<Node>& value);
 
-                std::shared_ptr<BodyLambda> m_body;
+                std::shared_ptr<Function> m_body;
                 std::vector<std::shared_ptr<InputDescription>> m_input_descriptions;
                 std::vector<std::shared_ptr<OutputDescription>> m_output_descriptions;
 
index 3c1d0f3..2be907c 100644 (file)
 #include <memory>
 
 #include "itt.hpp"
+#include "ngraph/factory_adapter.hpp"
 #include "ngraph/function.hpp"
 #include "ngraph/graph_util.hpp"
 #include "ngraph/log.hpp"
 #include "ngraph/op/util/op_types.hpp"
 #include "ngraph/util.hpp"
+#include "ngraph/validation_util.hpp"
 
 using namespace std;
 using namespace ngraph;
@@ -35,7 +37,8 @@ atomic<size_t> Function::m_next_instance_id(0);
 Function::Function(const ResultVector& results,
                    const ParameterVector& parameters,
                    const std::string& name)
-    : Lambda(results, parameters)
+    : m_results(results)
+    , m_parameters(parameters)
     , m_name(name)
     , m_unique_name("Function_" + to_string(m_next_instance_id.fetch_add(1)))
     , m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
@@ -46,7 +49,8 @@ Function::Function(const ResultVector& results,
 Function::Function(const OutputVector& results,
                    const ParameterVector& parameters,
                    const std::string& name)
-    : Lambda(results, parameters)
+    : m_results(as_result_vector(results))
+    , m_parameters(parameters)
     , m_name(name)
     , m_unique_name("Function_" + to_string(m_next_instance_id.fetch_add(1)))
     , m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
@@ -57,7 +61,8 @@ Function::Function(const OutputVector& results,
 Function::Function(const NodeVector& results,
                    const ParameterVector& parameters,
                    const std::string& name)
-    : Lambda(as_output_vector(results), parameters)
+    : m_results(as_result_vector(as_output_vector(results)))
+    , m_parameters(parameters)
     , m_name(name)
     , m_unique_name("Function_" + to_string(m_next_instance_id.fetch_add(1)))
     , m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
@@ -271,3 +276,337 @@ void Function::set_topological_sort(topological_sort_t sorter)
 {
     m_topological_sorter = sorter;
 }
+
+int64_t Function::get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const
+{
+    int64_t pos = 0;
+    for (auto p : get_parameters())
+    {
+        if (p == parameter)
+        {
+            return pos;
+        }
+        pos++;
+    }
+    return -1;
+}
+
+int64_t Function::get_result_index(const Output<Node>& value) const
+{
+    int64_t pos = 0;
+    if (is_type<op::Result>(value.get_node_shared_ptr()))
+    {
+        auto result = value.get_node_shared_ptr();
+        for (auto r : get_results())
+        {
+            if (r == result)
+            {
+                return pos;
+            }
+            pos++;
+        }
+    }
+    else
+    {
+        for (auto r : get_results())
+        {
+            if (r->input_value(0) == value)
+            {
+                return pos;
+            }
+            pos++;
+        }
+    }
+    return -1;
+}
+
+bool Function::evaluate(const HostTensorVector& output_tensors,
+                        const HostTensorVector& input_tensors) const
+{
+    std::map<RawNodeOutput, HostTensorPtr> value_map;
+    for (size_t i = 0; i < m_parameters.size(); ++i)
+    {
+        value_map[m_parameters.at(i)->output(0)] = input_tensors.at(i);
+    }
+    OutputVector outputs;
+    std::map<RawNodeOutput, HostTensorPtr> output_tensor_map;
+    for (size_t i = 0; i < m_results.size(); ++i)
+    {
+        auto result = m_results.at(i)->output(0);
+        output_tensor_map[result] = output_tensors.at(i);
+        outputs.push_back(result);
+    }
+    evaluate_nodes(value_map, output_tensor_map, outputs);
+    return true;
+}
+
+bool Function::visit_attributes(AttributeVisitor& visitor)
+{
+    visitor.on_attribute("parameters", m_parameters);
+    visitor.on_attribute("results", m_results);
+    return true;
+}
+
+constexpr DiscreteTypeInfo AttributeAdapter<shared_ptr<Function>>::type_info;
+
+AttributeAdapter<shared_ptr<Function>>::AttributeAdapter(shared_ptr<Function>& ref)
+    : m_ref(ref)
+{
+}
+
+class NodeAttributeAdapter : public FactoryAttributeAdapter<Node>
+{
+public:
+    using FactoryAttributeAdapter::FactoryAttributeAdapter;
+    bool on_start(AttributeVisitor& visitor) override
+    {
+        // Indicate that there is a node following
+        m_id = visitor.get_registered_node_id(m_ref);
+        m_set_id = (m_ref == nullptr);
+        visitor.on_attribute("id", m_id);
+        return m_ref == nullptr || m_id != AttributeVisitor::invalid_node_id;
+    }
+    bool on_finish(AttributeVisitor&) override
+    {
+        if (m_set_id && m_ref)
+        {
+            m_ref->set_friendly_name(m_id);
+        }
+        return true;
+    }
+    void visit(AttributeVisitor& visitor, const std::string& id)
+    {
+        visitor.start_structure(id);
+        visitor.on_adapter(id, *this);
+        visitor.finish_structure();
+    }
+    static constexpr DiscreteTypeInfo type_info{"Lambda.NodeAttributeAdapter", 0};
+    const DiscreteTypeInfo& get_type_info() const override { return type_info; }
+    string m_id;
+    bool m_set_id;
+};
+
+constexpr DiscreteTypeInfo NodeAttributeAdapter::type_info;
+
+bool AttributeAdapter<shared_ptr<Function>>::visit_attributes(AttributeVisitor& visitor)
+{
+    if (m_ref->get_results().size() > 0)
+    {
+        NodeVector serialized_nodes;
+        {
+            // Start with all nodes not already serialized
+            visitor.start_structure("nodes");
+            NodeVector results;
+            for (auto result : m_ref->get_results())
+            {
+                results.push_back(result);
+            }
+
+            int64_t i = 0;
+            ostringstream index;
+            traverse_nodes(
+                results, [&i, &index, &visitor, &serialized_nodes](shared_ptr<Node> node) -> void {
+                    if (AttributeVisitor::invalid_node_id == visitor.get_registered_node_id(node))
+                    {
+                        // This node hasn't been seen before
+                        visitor.register_node(node);
+                        index.str("");
+                        index << i++;
+                        string id = index.str();
+                        NodeAttributeAdapter adapter(node);
+                        adapter.visit(visitor, id);
+                        serialized_nodes.push_back(node);
+                    }
+                });
+            {
+                // Sentinel at end
+                index.str("");
+                index << i++;
+                string id = index.str();
+                shared_ptr<Node> null_node;
+                NodeAttributeAdapter adapter(null_node);
+                adapter.visit(visitor, id);
+            }
+            visitor.finish_structure();
+        }
+        {
+            // Now do all the edges
+            visitor.start_structure("edges");
+            int64_t i = 0;
+            ostringstream index;
+            for (auto node : serialized_nodes)
+            {
+                for (auto input : node->inputs())
+                {
+                    index.str("");
+                    index << i++;
+                    string id = index.str();
+                    visitor.start_structure(id);
+                    string input_node_id = visitor.get_registered_node_id(node);
+                    uint64_t input_index = input.get_index();
+                    visitor.on_attribute("input_node", input_node_id);
+                    visitor.on_attribute("input_index", input_index);
+                    auto output = input.get_source_output();
+                    string output_node_id =
+                        visitor.get_registered_node_id(output.get_node_shared_ptr());
+                    uint64_t output_index = output.get_index();
+                    visitor.on_attribute("output_node", output_node_id);
+                    visitor.on_attribute("output_index", output_index);
+                    visitor.finish_structure();
+                }
+            }
+            {
+                // Add a sentinel
+                index.str("");
+                index << i++;
+                string id = index.str();
+                visitor.start_structure(id);
+                string input_node_id = AttributeVisitor::invalid_node_id;
+                visitor.on_attribute("input_node", input_node_id);
+                visitor.finish_structure();
+            }
+            visitor.finish_structure();
+        }
+        {
+            // Control dependencies
+            visitor.start_structure("control");
+            int64_t i = 0;
+            ostringstream index;
+            for (auto node : serialized_nodes)
+            {
+                for (auto control : node->get_control_dependencies())
+                {
+                    index.str("");
+                    index << i++;
+                    string id = index.str();
+                    visitor.start_structure(id);
+                    string node_id = visitor.get_registered_node_id(node);
+                    string dependency_id = visitor.get_registered_node_id(control);
+                    visitor.on_attribute("node", node_id);
+                    visitor.on_attribute("dependency", dependency_id);
+                    visitor.finish_structure();
+                }
+            }
+            {
+                // Add a sentinel
+                index.str("");
+                index << i++;
+                string id = index.str();
+                visitor.start_structure(id);
+                string node_id = AttributeVisitor::invalid_node_id;
+                visitor.on_attribute("node", node_id);
+                visitor.finish_structure();
+            }
+            visitor.finish_structure();
+        }
+    }
+    else
+    {
+        NodeVector deserialized_nodes;
+        {
+            // Read the graph
+            visitor.start_structure("nodes");
+            int64_t i = 0;
+            ostringstream index;
+            while (true)
+            {
+                index.str("");
+                index << i++;
+                string id = index.str();
+                shared_ptr<Node> node;
+                NodeAttributeAdapter adapter(node);
+                adapter.visit(visitor, id);
+                if (node)
+                {
+                    visitor.register_node(node);
+                    deserialized_nodes.push_back(node);
+                }
+                else
+                {
+                    break;
+                }
+            }
+            visitor.finish_structure();
+        }
+        {
+            visitor.start_structure("edges");
+            // Connect the nodes
+            int64_t i = 0;
+            ostringstream index;
+            bool more_edges = true;
+            while (more_edges)
+            {
+                index.str("");
+                index << i++;
+                string id = index.str();
+                visitor.start_structure(id);
+                string input_node_id;
+                visitor.on_attribute("input_node", input_node_id);
+                if (!input_node_id.empty())
+                {
+                    shared_ptr<Node> input_node = visitor.get_registered_node(input_node_id);
+                    NGRAPH_CHECK(input_node, "input node of edge not known");
+                    uint64_t input_index;
+                    string output_node_id;
+                    uint64_t output_index;
+                    visitor.on_attribute("input_index", input_index);
+                    visitor.on_attribute("output_node", output_node_id);
+                    visitor.on_attribute("output_index", output_index);
+                    shared_ptr<Node> output_node = visitor.get_registered_node(output_node_id);
+                    NGRAPH_CHECK(output_node, "output_node of edge not known");
+                    input_node->set_argument(input_index, output_node->output(output_index));
+                }
+                else
+                {
+                    more_edges = false;
+                }
+                visitor.finish_structure();
+            }
+            visitor.finish_structure();
+        }
+        {
+            // Control dependencies
+            visitor.start_structure("control");
+            int64_t i = 0;
+            ostringstream index;
+            bool more_control = true;
+            while (more_control)
+            {
+                index.str("");
+                index << i++;
+                string id = index.str();
+                visitor.start_structure(id);
+                string node_id;
+                visitor.on_attribute("node", node_id);
+                if (!node_id.empty())
+                {
+                    shared_ptr<Node> node = visitor.get_registered_node(node_id);
+                    NGRAPH_CHECK(node, "node of control edge not known");
+                    string dependency_id;
+                    visitor.on_attribute("dependency", dependency_id);
+                    shared_ptr<Node> dependency = visitor.get_registered_node(dependency_id);
+                    NGRAPH_CHECK(dependency, "dependency of control edge not known");
+                    node->add_control_dependency(dependency);
+                }
+                else
+                {
+                    more_control = false;
+                }
+                visitor.finish_structure();
+            }
+            visitor.finish_structure();
+        }
+        for (auto node : topological_sort(deserialized_nodes))
+        {
+            node->validate_and_infer_types();
+        }
+    }
+
+    {
+        // Finally visit the object attributes
+        visitor.start_structure("value");
+        m_ref->visit_attributes(visitor);
+        visitor.finish_structure();
+    }
+    return true;
+}
diff --git a/ngraph/core/src/lambda.cpp b/ngraph/core/src/lambda.cpp
deleted file mode 100644 (file)
index ae8c9d3..0000000
+++ /dev/null
@@ -1,370 +0,0 @@
-//*****************************************************************************
-// Copyright 2017-2020 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//*****************************************************************************
-
-#include "ngraph/lambda.hpp"
-#include "ngraph/factory_adapter.hpp"
-#include "ngraph/graph_util.hpp"
-#include "ngraph/validation_util.hpp"
-
-using namespace std;
-using namespace ngraph;
-
-constexpr DiscreteTypeInfo Lambda::type_info;
-
-Lambda::Lambda(const OutputVector& results, const ParameterVector& parameters)
-    : Lambda(as_result_vector(results), parameters)
-{
-}
-
-Lambda::Lambda(const ResultVector& results, const ParameterVector& parameters)
-    : m_results(results)
-    , m_parameters(parameters)
-{
-}
-
-int64_t Lambda::get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const
-{
-    int64_t pos = 0;
-    for (auto p : get_parameters())
-    {
-        if (p == parameter)
-        {
-            return pos;
-        }
-        pos++;
-    }
-    return -1;
-}
-
-int64_t Lambda::get_result_index(const Output<Node>& value) const
-{
-    int64_t pos = 0;
-    if (is_type<op::Result>(value.get_node_shared_ptr()))
-    {
-        auto result = value.get_node_shared_ptr();
-        for (auto r : get_results())
-        {
-            if (r == result)
-            {
-                return pos;
-            }
-            pos++;
-        }
-    }
-    else
-    {
-        for (auto r : get_results())
-        {
-            if (r->input_value(0) == value)
-            {
-                return pos;
-            }
-            pos++;
-        }
-    }
-    return -1;
-}
-
-bool Lambda::evaluate(const HostTensorVector& output_tensors,
-                      const HostTensorVector& input_tensors) const
-{
-    std::map<RawNodeOutput, HostTensorPtr> value_map;
-    for (size_t i = 0; i < m_parameters.size(); ++i)
-    {
-        value_map[m_parameters.at(i)->output(0)] = input_tensors.at(i);
-    }
-    OutputVector outputs;
-    std::map<RawNodeOutput, HostTensorPtr> output_tensor_map;
-    for (size_t i = 0; i < m_results.size(); ++i)
-    {
-        auto result = m_results.at(i)->output(0);
-        output_tensor_map[result] = output_tensors.at(i);
-        outputs.push_back(result);
-    }
-    evaluate_nodes(value_map, output_tensor_map, outputs);
-    return true;
-}
-
-bool Lambda::visit_attributes(AttributeVisitor& visitor)
-{
-    visitor.on_attribute("parameters", m_parameters);
-    visitor.on_attribute("results", m_results);
-    return true;
-}
-
-constexpr DiscreteTypeInfo AttributeAdapter<shared_ptr<Lambda>>::type_info;
-
-AttributeAdapter<shared_ptr<Lambda>>::AttributeAdapter(shared_ptr<Lambda>& ref)
-    : m_ref(ref)
-{
-}
-
-class NodeAttributeAdapter : public FactoryAttributeAdapter<Node>
-{
-public:
-    using FactoryAttributeAdapter::FactoryAttributeAdapter;
-    bool on_start(AttributeVisitor& visitor) override
-    {
-        // Indicate that there is a node following
-        m_id = visitor.get_registered_node_id(m_ref);
-        m_set_id = (m_ref == nullptr);
-        visitor.on_attribute("id", m_id);
-        return m_ref == nullptr || m_id != AttributeVisitor::invalid_node_id;
-    }
-    bool on_finish(AttributeVisitor&) override
-    {
-        if (m_set_id && m_ref)
-        {
-            m_ref->set_friendly_name(m_id);
-        }
-        return true;
-    }
-    void visit(AttributeVisitor& visitor, const std::string& id)
-    {
-        visitor.start_structure(id);
-        visitor.on_adapter(id, *this);
-        visitor.finish_structure();
-    }
-    static constexpr DiscreteTypeInfo type_info{"Lambda.NodeAttributeAdapter", 0};
-    const DiscreteTypeInfo& get_type_info() const override { return type_info; }
-    string m_id;
-    bool m_set_id;
-};
-
-constexpr DiscreteTypeInfo NodeAttributeAdapter::type_info;
-
-bool AttributeAdapter<shared_ptr<Lambda>>::visit_attributes(AttributeVisitor& visitor)
-{
-    if (m_ref->get_results().size() > 0)
-    {
-        NodeVector serialized_nodes;
-        {
-            // Start with all nodes not already serialized
-            visitor.start_structure("nodes");
-            NodeVector results;
-            for (auto result : m_ref->get_results())
-            {
-                results.push_back(result);
-            }
-
-            int64_t i = 0;
-            ostringstream index;
-            traverse_nodes(
-                results, [&i, &index, &visitor, &serialized_nodes](shared_ptr<Node> node) -> void {
-                    if (AttributeVisitor::invalid_node_id == visitor.get_registered_node_id(node))
-                    {
-                        // This node hasn't been seen before
-                        visitor.register_node(node);
-                        index.str("");
-                        index << i++;
-                        string id = index.str();
-                        NodeAttributeAdapter adapter(node);
-                        adapter.visit(visitor, id);
-                        serialized_nodes.push_back(node);
-                    }
-                });
-            {
-                // Sentinel at end
-                index.str("");
-                index << i++;
-                string id = index.str();
-                shared_ptr<Node> null_node;
-                NodeAttributeAdapter adapter(null_node);
-                adapter.visit(visitor, id);
-            }
-            visitor.finish_structure();
-        }
-        {
-            // Now do all the edges
-            visitor.start_structure("edges");
-            int64_t i = 0;
-            ostringstream index;
-            for (auto node : serialized_nodes)
-            {
-                for (auto input : node->inputs())
-                {
-                    index.str("");
-                    index << i++;
-                    string id = index.str();
-                    visitor.start_structure(id);
-                    string input_node_id = visitor.get_registered_node_id(node);
-                    uint64_t input_index = input.get_index();
-                    visitor.on_attribute("input_node", input_node_id);
-                    visitor.on_attribute("input_index", input_index);
-                    auto output = input.get_source_output();
-                    string output_node_id =
-                        visitor.get_registered_node_id(output.get_node_shared_ptr());
-                    uint64_t output_index = output.get_index();
-                    visitor.on_attribute("output_node", output_node_id);
-                    visitor.on_attribute("output_index", output_index);
-                    visitor.finish_structure();
-                }
-            }
-            {
-                // Add a sentinel
-                index.str("");
-                index << i++;
-                string id = index.str();
-                visitor.start_structure(id);
-                string input_node_id = AttributeVisitor::invalid_node_id;
-                visitor.on_attribute("input_node", input_node_id);
-                visitor.finish_structure();
-            }
-            visitor.finish_structure();
-        }
-        {
-            // Control dependencies
-            visitor.start_structure("control");
-            int64_t i = 0;
-            ostringstream index;
-            for (auto node : serialized_nodes)
-            {
-                for (auto control : node->get_control_dependencies())
-                {
-                    index.str("");
-                    index << i++;
-                    string id = index.str();
-                    visitor.start_structure(id);
-                    string node_id = visitor.get_registered_node_id(node);
-                    string dependency_id = visitor.get_registered_node_id(control);
-                    visitor.on_attribute("node", node_id);
-                    visitor.on_attribute("dependency", dependency_id);
-                    visitor.finish_structure();
-                }
-            }
-            {
-                // Add a sentinel
-                index.str("");
-                index << i++;
-                string id = index.str();
-                visitor.start_structure(id);
-                string node_id = AttributeVisitor::invalid_node_id;
-                visitor.on_attribute("node", node_id);
-                visitor.finish_structure();
-            }
-            visitor.finish_structure();
-        }
-    }
-    else
-    {
-        NodeVector deserialized_nodes;
-        {
-            // Read the graph
-            visitor.start_structure("nodes");
-            int64_t i = 0;
-            ostringstream index;
-            while (true)
-            {
-                index.str("");
-                index << i++;
-                string id = index.str();
-                shared_ptr<Node> node;
-                NodeAttributeAdapter adapter(node);
-                adapter.visit(visitor, id);
-                if (node)
-                {
-                    visitor.register_node(node);
-                    deserialized_nodes.push_back(node);
-                }
-                else
-                {
-                    break;
-                }
-            }
-            visitor.finish_structure();
-        }
-        {
-            visitor.start_structure("edges");
-            // Connect the nodes
-            int64_t i = 0;
-            ostringstream index;
-            bool more_edges = true;
-            while (more_edges)
-            {
-                index.str("");
-                index << i++;
-                string id = index.str();
-                visitor.start_structure(id);
-                string input_node_id;
-                visitor.on_attribute("input_node", input_node_id);
-                if (!input_node_id.empty())
-                {
-                    shared_ptr<Node> input_node = visitor.get_registered_node(input_node_id);
-                    NGRAPH_CHECK(input_node, "input node of edge not known");
-                    uint64_t input_index;
-                    string output_node_id;
-                    uint64_t output_index;
-                    visitor.on_attribute("input_index", input_index);
-                    visitor.on_attribute("output_node", output_node_id);
-                    visitor.on_attribute("output_index", output_index);
-                    shared_ptr<Node> output_node = visitor.get_registered_node(output_node_id);
-                    NGRAPH_CHECK(output_node, "output_node of edge not known");
-                    input_node->set_argument(input_index, output_node->output(output_index));
-                }
-                else
-                {
-                    more_edges = false;
-                }
-                visitor.finish_structure();
-            }
-            visitor.finish_structure();
-        }
-        {
-            // Control dependencies
-            visitor.start_structure("control");
-            int64_t i = 0;
-            ostringstream index;
-            bool more_control = true;
-            while (more_control)
-            {
-                index.str("");
-                index << i++;
-                string id = index.str();
-                visitor.start_structure(id);
-                string node_id;
-                visitor.on_attribute("node", node_id);
-                if (!node_id.empty())
-                {
-                    shared_ptr<Node> node = visitor.get_registered_node(node_id);
-                    NGRAPH_CHECK(node, "node of control edge not known");
-                    string dependency_id;
-                    visitor.on_attribute("dependency", dependency_id);
-                    shared_ptr<Node> dependency = visitor.get_registered_node(dependency_id);
-                    NGRAPH_CHECK(dependency, "dependency of control edge not known");
-                    node->add_control_dependency(dependency);
-                }
-                else
-                {
-                    more_control = false;
-                }
-                visitor.finish_structure();
-            }
-            visitor.finish_structure();
-        }
-        for (auto node : topological_sort(deserialized_nodes))
-        {
-            node->validate_and_infer_types();
-        }
-    }
-
-    {
-        // Finally visit the object attributes
-        visitor.start_structure("value");
-        m_ref->visit_attributes(visitor);
-        visitor.finish_structure();
-    }
-    return true;
-}
index e411df3..4f0168c 100644 (file)
@@ -33,13 +33,6 @@ constexpr DiscreteTypeInfo op::v0::TensorIterator::InvariantInputDescription::ty
 constexpr DiscreteTypeInfo op::v0::TensorIterator::BodyOutputDescription::type_info;
 constexpr DiscreteTypeInfo op::v0::TensorIterator::ConcatOutputDescription::type_info;
 
-constexpr DiscreteTypeInfo op::v0::TensorIterator::BodyLambda::type_info;
-
-bool op::v0::TensorIterator::BodyLambda::visit_attributes(AttributeVisitor& visitor)
-{
-    return true;
-}
-
 op::v0::TensorIterator::TensorIterator(const OutputVector& values)
     : op::util::FusedOp(values)
 {
@@ -310,12 +303,7 @@ namespace ngraph
 
 bool op::v0::TensorIterator::visit_attributes(AttributeVisitor& visitor)
 {
-    if (!m_body)
-    {
-        m_body = make_shared<BodyLambda>();
-    }
-    shared_ptr<Lambda> lambda = m_body;
-    visitor.on_attribute("body", lambda);
+    visitor.on_attribute("body", m_body);
     visitor.on_attribute("input_descriptions", m_input_descriptions);
     visitor.on_attribute("output_descriptions", m_output_descriptions);
 
@@ -663,8 +651,7 @@ std::shared_ptr<Node>
     auto func = std::make_shared<Function>(m_body->get_results(), m_body->get_parameters());
     auto spec_func =
         specialize_function(func, types, new_shapes, std::vector<void*>(new_args.size(), nullptr));
-    op->m_body =
-        std::make_shared<BodyLambda>(spec_func->get_results(), spec_func->get_parameters());
+    op->m_body = std::make_shared<Function>(spec_func->get_results(), spec_func->get_parameters());
 
     for (auto& input_description : m_input_descriptions)
     {
index dd02412..9697323 100644 (file)
@@ -64,13 +64,15 @@ namespace ngraph
             void add_provenance_tags(const Node& onnx_node,
                                      const OutputVector& ng_node_vector) const;
 
+        protected:
+            ParameterVector m_parameters;
+
         private:
             const ONNX_NAMESPACE::GraphProto* m_graph_proto;
             std::unique_ptr<GraphCache> m_cache;
             std::vector<Node> m_nodes;
             std::vector<ValueInfo> m_inputs;
             std::vector<ValueInfo> m_outputs;
-            ParameterVector m_parameters;
             Model* m_model;
         };
 
index fd7cd63..7934398 100644 (file)
@@ -20,6 +20,7 @@
 #include <sstream>
 
 #include "ngraph/log.hpp"
+#include "ngraph/node.hpp"
 #include "ngraph/provenance.hpp"
 #include "onnx_import/core/graph.hpp"
 #include "onnx_import/core/node.hpp"
@@ -291,6 +292,37 @@ namespace ngraph
                   model,
                   std::unique_ptr<SubgraphCache>(new SubgraphCache(parent_graph.get_graph_cache())))
         {
+            std::vector<std::shared_ptr<ngraph::Node>> subgraph_root_nodes;
+            const auto& outputs = as_result_vector(get_ng_outputs());
+            for (auto& out : outputs)
+            {
+                subgraph_root_nodes.push_back(out);
+            }
+            const auto& params = get_ng_parameters();
+            for (auto& param : params)
+            {
+                subgraph_root_nodes.push_back(param);
+            }
+            const auto subgraph_nodes = topological_sort(subgraph_root_nodes);
+
+            const auto& parent_graph_parameters = parent_graph.get_ng_parameters();
+            for (const auto& node : subgraph_nodes)
+            {
+                if (op::is_parameter(node))
+                {
+                    const auto sub_it = std::find(m_parameters.begin(), m_parameters.end(), node);
+                    // not present as subgraph parameter
+                    if (sub_it == m_parameters.end())
+                    {
+                        const auto parent_it = std::find(
+                            parent_graph_parameters.begin(), parent_graph_parameters.end(), node);
+                        if (parent_it != m_parameters.end())
+                        {
+                            m_parameters.push_back(*parent_it);
+                        }
+                    }
+                }
+            }
         }
 
     } // namespace onnx_import
index 8c1ba80..f456307 100644 (file)
@@ -116,15 +116,15 @@ namespace ngraph
                     const auto& graph_outputs = body_graph.get_ng_outputs();
                     const auto& graph_inputs = body_graph.get_ng_parameters();
 
-                    CHECK_VALID_NODE(
-                        node,
-                        graph_inputs.size() == loop_carried_dependencies.size() + 2,
-                        "The provided loop body graph inputs size (",
-                        graph_inputs.size(),
-                        "), is not equal to the sum of loop carried dependencies and two mandatory"
-                        " inputs (",
-                        loop_carried_dependencies.size() + 2,
-                        ")");
+                    CHECK_VALID_NODE(node,
+                                     graph_inputs.size() >= loop_carried_dependencies.size() + 2,
+                                     "The provided loop body graph inputs size (",
+                                     graph_inputs.size(),
+                                     "), is not greater than the sum of loop carried dependencies "
+                                     "and two mandatory"
+                                     " inputs (",
+                                     loop_carried_dependencies.size() + 2,
+                                     ")");
 
                     CHECK_VALID_NODE(node,
                                      graph_outputs.size() >= loop_carried_dependencies.size() + 1,
@@ -144,8 +144,8 @@ namespace ngraph
                         default_opset::Constant::create(element::boolean, Shape{}, {true});
 
                     // create the loop body
-                    const auto body = std::make_shared<ngraph::op::TensorIterator::BodyLambda>(
-                        graph_outputs, graph_inputs);
+                    const auto body =
+                        std::make_shared<ngraph::Function>(graph_outputs, graph_inputs);
                     auto tensor_iterator = std::make_shared<ngraph::op::TensorIterator>();
                     tensor_iterator->set_body(body);
 
index 642659c..728b4ce 100644 (file)
@@ -108,8 +108,7 @@ void util::TensorIteratorBuilder::get_graph_body()
 
     m_body_outputs = as_output_vector(body_attrs["results"].cast<ngraph::NodeVector>());
     m_body_parameters = body_attrs["parameters"].cast<ngraph::ParameterVector>();
-    m_body =
-        std::make_shared<ngraph::op::TensorIterator::BodyLambda>(m_body_outputs, m_body_parameters);
+    m_body = std::make_shared<ngraph::Function>(m_body_outputs, m_body_parameters);
 }
 
 void util::TensorIteratorBuilder::set_tensor_iterator_sliced_inputs(
index dc15a3e..fd6fa45 100644 (file)
@@ -126,7 +126,7 @@ namespace util
         const py::dict& m_attributes;
         ngraph::OutputVector m_body_outputs;
         ngraph::ParameterVector m_body_parameters;
-        std::shared_ptr<ngraph::op::TensorIterator::BodyLambda> m_body;
+        std::shared_ptr<ngraph::Function> m_body;
         py::list m_slice_input_desc;
         py::list m_merged_input_desc;
         py::list m_invariant_input_desc;