Fine-Grain Transformation pipeline tuning (#2547)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Fri, 9 Oct 2020 12:33:19 +0000 (15:33 +0300)
committerGitHub <noreply@github.com>
Fri, 9 Oct 2020 12:33:19 +0000 (15:33 +0300)
* Initial version of transformation callback refactoring

* Improved fine-grain tuning for transformation pipeline

* Check disabled matchers in GraphRewrite

* Avoid deprecated classes inside PassConfig

* Enabled DepthToSpace fusion by default

* Removed doulbe search in map

* Moved back pass_config.hpp; Added doxygen documentation for new class and methods

* Added doxygen comment for Manager and GraphRewrite new mthods

20 files changed:
inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp
inference-engine/src/transformations/src/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.cpp
inference-engine/src/transformations/src/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.cpp
inference-engine/src/transformations/src/transformations/convert_pad_to_group_conv.cpp
inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp
ngraph/core/include/ngraph/pass/graph_rewrite.hpp
ngraph/core/include/ngraph/pass/manager.hpp
ngraph/core/include/ngraph/pass/pass.hpp
ngraph/core/include/ngraph/pass/pass_config.hpp
ngraph/core/src/pass/graph_rewrite.cpp
ngraph/core/src/pass/manager.cpp
ngraph/core/src/pass/pass.cpp
ngraph/core/src/pass/pass_config.cpp
ngraph/test/graph_rewrite.cpp
ngraph/test/runtime/CMakeLists.txt
ngraph/test/runtime/backend.cpp
ngraph/test/runtime/backend.hpp

index 57d23d3..25a2140 100644 (file)
 #include <transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp>
 #include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
 #include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp>
+#include <transformations/convert_depth_to_space.hpp>
+#include <transformations/convert_space_to_depth.hpp>
 #include <transformations/init_node_info.hpp>
 #include <transformations/convert_precision.hpp>
+#include <transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp>
+#include <transformations/convert_gelu.hpp>
+#include <transformations/depth_to_space_fusion.hpp>
+#include <transformations/convert_batch_to_space.hpp>
+#include <transformations/convert_extract_image_patches_to_reorg_yolo.hpp>
+#include <transformations/hswish_decomposition.hpp>
+#include <transformations/reduce_l1_decomposition.hpp>
+#include <transformations/reduce_l2_decomposition.hpp>
+#include <transformations/convert_space_to_batch.hpp>
+#include <transformations/softplus_decomposition.hpp>
+#include <transformations/convert_pad_to_group_conv.hpp>
 #include <transformations/rt_info/fused_names_attribute.hpp>
 #include <ngraph/opsets/opset2.hpp>
 #include <ngraph/opsets/opset3.hpp>
@@ -63,31 +76,6 @@ Engine::~Engine() {
 static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
     OV_ITT_SCOPED_TASK(MKLDNNPlugin::itt::domains::MKLDNNPlugin, "Transformation");
 
-    const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
-        // DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
-        if (auto dtsOp = std::dynamic_pointer_cast<const ::ngraph::opset3::DepthToSpace>(node)) {
-            return dtsOp->input_value(0).get_shape().size() <= 5lu && dtsOp->input_value(0).get_shape().size() == dtsOp->get_output_shape(0).size();
-        }
-
-        // SpaceToDepth node implementation supports only equal input/output tensors with rank <= 5
-        if (auto stdOp = std::dynamic_pointer_cast<const ::ngraph::opset3::SpaceToDepth>(node)) {
-            return stdOp->input_value(0).get_shape().size() <= 5lu && stdOp->input_value(0).get_shape().size() == stdOp->get_output_shape(0).size();
-        }
-
-        if (auto fc_op = std::dynamic_pointer_cast<const ngraph::op::FullyConnected>(node)) {
-            return fc_op->input_value(0).get_shape().size() == 3ul;
-        }
-
-        return std::dynamic_pointer_cast<const ngraph::opset2::Gelu>(node) ||
-               std::dynamic_pointer_cast<const ngraph::opset2::BatchToSpace>(node) ||
-               std::dynamic_pointer_cast<const ngraph::opset2::SpaceToBatch>(node) ||
-               std::dynamic_pointer_cast<const ngraph::opset3::ExtractImagePatches>(node) ||
-               std::dynamic_pointer_cast<const ngraph::opset4::HSwish>(node) ||
-               std::dynamic_pointer_cast<const ngraph::opset4::ReduceL1>(node) ||
-               std::dynamic_pointer_cast<const ngraph::opset4::ReduceL2>(node) ||
-               std::dynamic_pointer_cast<const ngraph::opset4::SoftPlus>(node) ||
-               std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
-    };
     auto nGraphFunc = clonedNetwork->getFunction();
     // Disable shape inference (WA for generic operations)
     ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
@@ -116,7 +104,41 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
     manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
     manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
 
-    manager.set_callback(transformations_callback);
+    auto pass_config = manager.get_pass_config();
+
+    using const_node_ptr = const std::shared_ptr<const ngraph::Node>;
+
+    // SpaceToDepth/ DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
+    pass_config->set_callback<ngraph::pass::ConvertSpaceToDepth,
+                              ngraph::pass::ConvertDepthToSpace>(
+            [](const_node_ptr &node) -> bool {
+                return node->input_value(0).get_shape().size() <= 5lu &&
+                       node->input_value(0).get_shape().size() == node->get_output_shape(0).size();
+            });
+
+    // Disable FC reshaping for 3D case
+    pass_config->set_callback<ngraph::pass::ReshapeFullyConnected>(
+            [](const_node_ptr &node) -> bool {
+                return node->input_value(0).get_shape().size() == 3ul;
+            });
+
+    pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
+                              ngraph::pass::ConvertSpaceToBatch>(
+            [](const_node_ptr &node) -> bool {
+                const auto & rank = node->input(0).get_partial_shape().rank().get_length();
+                return rank == 4lu || rank == 5lu;
+            });
+
+    // List of enabled/disabled transformations
+    pass_config->disable<ngraph::pass::ConvertGELU>();
+    pass_config->disable<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
+    pass_config->disable<ngraph::pass::HSwishDecomposition>();
+    pass_config->disable<ngraph::pass::ReduceL1Decomposition>();
+    pass_config->disable<ngraph::pass::ReduceL2Decomposition>();
+    pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
+
+    pass_config->enable<ngraph::pass::ConvertPadToGroupConvolution>();
+
     manager.run_passes(nGraphFunc);
 
     clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
index fae8de5..4989409 100644 (file)
@@ -68,7 +68,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
     manager.register_pass<ngraph::pass::SwishFusion>();
     manager.register_pass<ngraph::pass::HSwishFusion>();
-    manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
+    manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
     manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
     manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
     manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
@@ -111,7 +111,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
     fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
 
-    manager.set_callback(m_transformation_callback);
+    // Propagate local PassConfig to internal pass::Manager
+    manager.set_pass_config(get_pass_config());
     manager.run_passes(f);
     return true;
 }
index 8543019..e2d9de3 100644 (file)
@@ -18,7 +18,7 @@ ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
-        if (!dts_node || m_transformation_callback(dts_node)) {
+        if (!dts_node || transformation_callback(dts_node)) {
             return false;
         }
 
index b678426..77eec04 100644 (file)
@@ -154,7 +154,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
 
     manager.register_pass<ngraph::pass::ConstantFolding>();
 
-    manager.set_callback(m_transformation_callback);
+    manager.set_pass_config(get_pass_config());
     manager.run_passes(f);
     return true;
 }
index 80b66b3..ded5783 100644 (file)
@@ -24,7 +24,7 @@ bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph
     manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
     manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
 
-    manager.set_callback(m_transformation_callback);
+    manager.set_pass_config(get_pass_config());
     manager.run_passes(f);
     return true;
 }
index 72db0fb..9689256 100644 (file)
@@ -33,7 +33,7 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph
     manager.register_pass<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
     manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
 
-    manager.set_callback(m_transformation_callback);
+    manager.set_pass_config(get_pass_config());
     manager.run_passes(f);
     return true;
 }
index 9576aa7..b987e96 100644 (file)
@@ -19,7 +19,7 @@ ngraph::pass::ConvertPadToGroupConvolution::ConvertPadToGroupConvolution() {
 
     ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
         auto pad = std::dynamic_pointer_cast<ngraph::opset4::Pad> (m.get_match_root());
-        if (!pad || !m_transformation_callback(pad) /* disabled by default */) {
+        if (!pad) {
             return false;
         }
 
index 75f38a1..411bda6 100644 (file)
@@ -155,11 +155,6 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
                 std::make_shared<ngraph::opset3::DepthToSpace>(reshape_before->input_value(0), mode, block_size);
         depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
         ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
-
-        if (!m_transformation_callback(depth_to_space)) {
-            return false;
-        }
-
         ngraph::replace_node(reshape_after, depth_to_space);
         return true;
     };
index dbfb470..4b19bb5 100644 (file)
@@ -126,12 +126,34 @@ public:
         m_matchers.push_back(pass);
     }
 
-    template <typename T, class... Args>
+    /// \brief Register given transformation class type to GraphRewrite execution list
+    /// All registered transformations will be executed in a single graph traversal.
+    /// Example below show the basic usage of pass::GraphRewrite
+    ///
+    ///     pass::Manager manager;
+    ///     auto anchor = manager.register_pass<GraphRewrite>();
+    ///     anchor->add_matcher<MatcherPassA>();
+    ///     anchor->add_matcher<MatcherPassB>();
+    ///     anchor->set_name("CommonMathcers");
+    ///     manager.run_passes(f);
+    ///
+    /// For some purposes transformation can be registered and disabled by default.
+    ///
+    ///     anchor->add_matcher<MatcherPassB, false>();
+    ///
+    /// \return shared_ptr to the transformation instance
+    template <typename T, bool Enabled = true, class... Args>
     std::shared_ptr<T> add_matcher(Args&&... args)
     {
         static_assert(std::is_base_of<pass::MatcherPass, T>::value,
                       "pass not derived from MatcherPass");
         auto pass = std::make_shared<T>(std::forward<Args>(args)...);
+        auto pass_config = get_pass_config();
+        pass->set_pass_config(pass_config);
+        if (!Enabled)
+        {
+            pass_config->disable<T>();
+        }
         m_matchers.push_back(pass);
         return pass;
     }
index 5438a6c..6a0060f 100644 (file)
@@ -22,7 +22,6 @@
 #include <vector>
 
 #include "ngraph/pass/pass.hpp"
-#include "ngraph/pass/pass_config.hpp"
 #include "ngraph/pass/validate.hpp"
 
 namespace ngraph
@@ -39,14 +38,31 @@ public:
     Manager();
     ~Manager();
 
-    template <typename T, class... Args>
+    /// \brief Register given transformation class type to execution list
+    /// Example below show the basic usage of pass::Manager
+    ///
+    ///     pass::Manager manager;
+    ///     manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
+    ///     manager.run_passes(f);
+    ///
+    /// For some purposes transformation can be registered and disabled by default.
+    ///
+    ///     manager.register_pass<MyTransformation, false>();
+    ///
+    /// \return shared_ptr to the transformation instance
+    template <typename T, bool Enable = true, class... Args>
     std::shared_ptr<T> register_pass(Args&&... args)
     {
         auto rc = push_pass<T>(std::forward<Args>(args)...);
+        rc->set_pass_config(m_pass_config);
         if (m_per_pass_validation)
         {
             push_pass<Validate>();
         }
+        if (!Enable)
+        {
+            m_pass_config->disable<T>();
+        }
         return rc;
     }
 
@@ -59,8 +75,10 @@ public:
     void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
     /// \brief Callback is a lambda function that can be used by registered transformations.
     /// The main purpose of this callback is to provide a way for plugins to disable/enable
-    /// transformations. In some cases plugins may want not to execute some transformations.
-    /// For example plugin can disable unpleasant decompositions because of performance reasons.
+    /// transformations based on some conditions. In some cases plugins may want not to execute some
+    /// transformations.
+    /// For example plugin can disable unpleasant decompositions because of performance reasons for
+    /// some cases.
     /// Callback example:
     /// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
     ///     return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
@@ -69,15 +87,22 @@ public:
     /// decomposition pass will check is this decomposition needed or plugin can execute this
     /// operation directly. And of course on transformation side we need to have a response for this
     /// callback.
-    /// if (m_transformation_callback(batch_to_space)) {
+    /// if (transformation_callback(batch_to_space)) {
     ///     return false;
     /// }
     /// \param callback lamda function that returns true in case if node is supported by plugin and
     /// transformation is not needed
-    void set_callback(param_callback callback)
+    NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
+    void set_callback(const param_callback& callback) { m_pass_config->set_callback(callback); }
+    /// \return PassConfig shared object. This object is used for transformations pipeline
+    /// configuration.
+    /// This object allows to disable/enable transformations execution, set callback to particular
+    /// transformation. For mo details see PassConfig class.
+    std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
+    /// \brief Set external PassConfig object.
+    void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
     {
-        m_transformation_callback = callback;
-        m_has_default_callback = false;
+        *m_pass_config = *pass_config;
     }
 
 protected:
@@ -91,11 +116,7 @@ protected:
         return pass;
     }
 
-    param_callback m_transformation_callback = [](const std::shared_ptr<const Node>&) -> bool {
-        return false;
-    };
-    bool m_has_default_callback = true;
-
+    std::shared_ptr<PassConfig> m_pass_config;
     std::vector<std::shared_ptr<PassBase>> m_pass_list;
     bool m_visualize = false;
     bool m_per_pass_validation = true;
index ae50394..100d454 100644 (file)
@@ -23,6 +23,7 @@
 #include "ngraph/deprecated.hpp"
 #include "ngraph/function.hpp"
 #include "ngraph/node.hpp"
+#include "ngraph/pass/pass_config.hpp"
 #include "ngraph/util.hpp"
 
 namespace ngraph
@@ -39,7 +40,6 @@ namespace ngraph
 
         typedef EnumMask<PassProperty> PassPropertyMask;
         const PassPropertyMask all_pass_property_off;
-        using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
 
         class NGRAPH_API PassBase
         {
@@ -54,8 +54,40 @@ namespace ngraph
             void set_name(const std::string& name) { m_name = name; }
             std::string get_name() const;
 
+            /// \brief Set callback for particular transformation type.
+            /// This method set global callback. For more details see PassConfig class
+            /// documentation.
+            /// \param callback lambda function that takes node and returns bool
             void set_callback(const param_callback& callback);
 
+            /// \brief Set PassConfig for particular transformation instance
+            /// \param pass_config is a PassConfig shared_ptr
+            void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
+            {
+                m_pass_config = pass_config;
+            }
+
+            /// \brief Allows to access PassConfig shared instance
+            /// \return Shared instance of PassConfig class
+            std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
+            /// \brief Applies callback for given node. By default callback returns false.
+            /// This method remains here only for backward compatibility and will be removed
+            /// after all transformations are moved to transformation_callback() method.
+            /// \return result of callback execution for given node
+            NGRAPH_DEPRECATED("Please use transformation_callback method instead")
+            bool m_transformation_callback(const std::shared_ptr<const Node>& node)
+            {
+                return m_pass_config->get_callback(get_type_info())(node);
+            }
+
+            /// \brief Applies callback for given node. By default callback returns false.
+            /// \param node which will be used inside callback
+            /// \return result of callback execution for given node
+            bool transformation_callback(const std::shared_ptr<const Node>& node)
+            {
+                return m_pass_config->get_callback(get_type_info())(node);
+            }
+
             using type_info_t = DiscreteTypeInfo;
 
             virtual const type_info_t& get_type_info() const = 0;
@@ -63,13 +95,11 @@ namespace ngraph
         protected:
             void set_property(const PassPropertyMask& prop, bool value);
 
-            param_callback m_transformation_callback =
-                [](const std::shared_ptr<const Node>&) -> bool { return false; };
-            bool m_has_default_callback = true;
-
         private:
             PassPropertyMask m_property;
+
             std::string m_name;
+            std::shared_ptr<PassConfig> m_pass_config;
         };
 
         class NGRAPH_API FunctionPass : public PassBase
index a592d87..7d85d1a 100644 (file)
 
 #pragma once
 
-#include <map>
-#include <string>
+#include <list>
+#include <memory>
+#include <vector>
 
-#include <ngraph/ngraph_visibility.hpp>
+#include "ngraph/deprecated.hpp"
+#include "ngraph/function.hpp"
+#include "ngraph/node.hpp"
+#include "ngraph/util.hpp"
 
 namespace ngraph
 {
     namespace pass
     {
-        class PassConfig;
-    }
-}
+        using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
+        using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>;
 
-class NGRAPH_API ngraph::pass::PassConfig
-{
-public:
-    PassConfig();
-    const std::map<std::string, bool>& get_enables() const { return m_pass_enables; }
-    void set_pass_enable(const std::string& name, bool enable);
-    bool get_pass_enable(const std::string& name) const;
-    const std::map<std::string, bool>& get_pass_attributes() const { return m_pass_attributes; }
-    void set_pass_attribute(const std::string& name, bool enable);
-    bool get_pass_attribute(const std::string& name) const;
-
-private:
-    std::map<std::string, bool> m_pass_enables;
-    std::map<std::string, bool> m_pass_attributes;
-};
+        /// \brief Class representing a transformations config that is used for disabling/enabling
+        /// transformations registered inside pass::Manager and also allows to set callback for all
+        /// transformations or for particular transformation.
+        ///
+        /// When pass::Manager is created all passes registered inside this manager including nested
+        /// passes will share the same instance of PassConfig class.
+        /// To work with this class first you need to get shared instance of this class by calling
+        /// manager.get_pass_config() method. Then you will be able to disable/enable passes based
+        /// on transformations type_info. For example:
+        ///
+        ///     pass::Manager manager;
+        ///     manager.register_pass<CommonOptimizations>();
+        ///     auto pass_config = manager.get_pass_config();
+        ///     pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
+        ///                                          // CommonOptimizations pipeline
+        ///     manager.run_passes(f);
+        ///
+        /// Sometimes it is needed to call transformation inside other transformation manually. And
+        /// for that case before running transformation you need manually check that this pass is
+        /// not disabled and then you need to set current PassConfig instance to this
+        /// transformation. For example:
+        ///
+        ///     // Inside MatcherPass callback or inside FunctionPass run_on_function() method
+        ///     // you need to call get_pass_config() method to get shared instance of PassConfig
+        ///     auto pass_config = get_pass_config();
+        ///
+        ///     // Before running nested transformation you need to check is it disabled or not
+        ///     if (!pass_config->is_disabled<ConvertGELU>()) {
+        ///         auto pass = ConvertGELU();
+        ///         pass->set_pass_config(pass_config);
+        ///         pass.apply(node);
+        ///     }
+        ///
+        /// Following this logic inside your transformations you will guaranty that transformations
+        /// will be executed in a right way.
+        class NGRAPH_API PassConfig
+        {
+        public:
+            /// \brief Disable transformation by its type_info
+            /// \param type_info Transformation type_info
+            void disable(const DiscreteTypeInfo& type_info) { m_disabled.insert(type_info); }
+            /// \brief Disable transformation by its class type (based on type_info)
+            template <typename T>
+            void disable()
+            {
+                disable(T::type_info);
+            }
+
+            /// \brief Enable transformation by its type_info
+            /// \param type_info Transformation type_info
+            void enable(const DiscreteTypeInfo& type_info) { m_disabled.erase(type_info); }
+            /// \brief Enable transformation by its class type (based on type_info)
+            template <typename T>
+            void enable()
+            {
+                enable(T::type_info);
+            }
+
+            /// \brief Set callback for all kind of transformations
+            void set_callback(const param_callback& callback) { m_callback = callback; }
+            template <typename... Args>
+            typename std::enable_if<sizeof...(Args) == 0>::type
+                set_callback(const param_callback& callback)
+            {
+            }
+
+            /// \brief Set callback for particular transformation class types
+            ///
+            /// Example below show how to set callback for one or multiple passes using this method.
+            ///
+            ///     pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
+            ///                               ngraph::pass::ConvertSpaceToBatch>(
+            ///              [](const_node_ptr &node) -> bool {
+            ///                   // Disable transformations for cases when input shape rank is not
+            ///                   equal to 4
+            ///                   const auto input_shape_rank =
+            ///                   node->get_output_partial_shape(0).rank().get_length();
+            ///                   if (input_shape_rank != 4) {
+            ///                       return false;
+            ///                   }
+            ///                   return true;
+            ///               });
+            ///
+            /// Note that inside transformations you must provide code that work with this callback.
+            /// See example below:
+            ///
+            ///     if (transformation_callback(node)) {
+            ///         return false; // exit from transformation
+            ///     }
+            ///
+            template <typename T, class... Args>
+            void set_callback(const param_callback& callback)
+            {
+                m_callback_map[T::type_info] = callback;
+                set_callback<Args...>(callback);
+            }
+
+            /// \brief Get callback for given transformation type_info
+            /// \param type_info Transformation type_info
+            ///
+            /// In case if callback wasn't set for given transformation type then global callback
+            /// will be returned. But if even global callback wasn't set then default callback will
+            /// be returned.
+            param_callback get_callback(const DiscreteTypeInfo& type_info) const;
+
+            /// \brief Get callback for given transformation class type
+            /// \return callback lambda function
+            template <typename T>
+            param_callback get_callback() const
+            {
+                return get_callback(T::type_info);
+            }
+
+            /// \brief Check either transformation type is disabled or not
+            /// \param type_info Transformation type_info
+            /// \return true if transformation type was disabled and false otherwise
+            bool is_disabled(const DiscreteTypeInfo& type_info) const
+            {
+                return m_disabled.count(type_info);
+            }
+
+            /// \brief Check either transformation class type is disabled or not
+            /// \return true if transformation type was disabled and false otherwise
+            template <typename T>
+            bool is_disabled() const
+            {
+                return is_disabled(T::type_info);
+            }
+
+        private:
+            param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
+                return false;
+            };
+            param_callback_map m_callback_map;
+            std::unordered_set<DiscreteTypeInfo> m_disabled;
+        };
+    }
+}
\ No newline at end of file
index ef65eea..66993eb 100644 (file)
@@ -72,6 +72,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
     OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
 
     bool rewritten = false;
+    const auto& pass_config = get_pass_config();
 
     // Initialize execution queue with nodes in topological order
     deque<std::shared_ptr<Node>> nodes_to_run;
@@ -85,6 +86,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
     std::unordered_map<NodeTypeInfo, std::vector<size_t>> type_to_matcher;
     for (size_t matcher_index = 0; matcher_index < m_matchers.size(); ++matcher_index)
     {
+        // Skip passes that are disabled
+        if (pass_config->is_disabled(m_matchers[matcher_index]->get_type_info()))
+            continue;
+
         auto matcher = m_matchers[matcher_index]->get_matcher();
         if (!matcher)
         {
@@ -139,11 +144,6 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
             return false;
         }
 
-        if (!m_has_default_callback)
-        {
-            m_pass->set_callback(m_transformation_callback);
-        }
-
         // Apply MatcherPass. In case if it returns true no other MatcherPasses will apply
         // to this node
         bool status = m_pass->apply(node);
@@ -224,6 +224,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
         {
             for (auto& m_pass : m_matchers)
             {
+                // Skip passes that are disabled
+                if (pass_config->is_disabled(m_pass->get_type_info()))
+                    continue;
+
                 if (run_matcher_pass(m_pass, node))
                 {
                     rewritten = true;
index 67045a7..e4c7044 100644 (file)
@@ -36,6 +36,7 @@ using namespace ngraph;
 
 pass::Manager::Manager()
     : m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING"))
+    , m_pass_config(std::make_shared<PassConfig>())
 {
 }
 
@@ -56,12 +57,14 @@ void pass::Manager::run_passes(shared_ptr<Function> func)
     bool function_changed = false;
     for (auto& pass : m_pass_list)
     {
-        pass_timer.start();
-        if (!m_has_default_callback)
+        if (m_pass_config->is_disabled(pass->get_type_info()))
         {
-            pass->set_callback(m_transformation_callback);
+            NGRAPH_DEBUG << "Pass " << pass->get_name() << " is disabled";
+            continue;
         }
 
+        pass_timer.start();
+
         NGRAPH_SUPPRESS_DEPRECATED_START
         if (auto matcher_pass = dynamic_pointer_cast<MatcherPass>(pass))
         {
index bd4d3b0..4229e5a 100644 (file)
@@ -33,6 +33,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0);
 
 pass::PassBase::PassBase()
     : m_property{all_pass_property_off}
+    , m_pass_config(std::make_shared<PassConfig>())
 {
 }
 
@@ -73,8 +74,7 @@ std::string pass::PassBase::get_name() const
 
 void pass::PassBase::set_callback(const param_callback& callback)
 {
-    m_transformation_callback = callback;
-    m_has_default_callback = false;
+    m_pass_config->set_callback(callback);
 }
 
 // The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM
index 1d2435a..c123d4b 100644 (file)
 //*****************************************************************************
 
 #include "ngraph/pass/pass_config.hpp"
-#include "ngraph/env_util.hpp"
-#include "ngraph/except.hpp"
-#include "ngraph/log.hpp"
-#include "ngraph/util.hpp"
 
-using namespace std;
 using namespace ngraph;
 
-// TODO: Add file-based configuration support
-pass::PassConfig::PassConfig()
+pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const
 {
-    //
-    // Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES
-    // and returns the pass names and whether they should be enabled or disabled in the
-    // provided unordered_map. Implementation of pass selection is up to the backend
-    // E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would
-    //       set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims
-    //
-    string pass_enables = getenv_string("NGRAPH_PASS_ENABLES");
-    if (!pass_enables.empty())
-    {
-        stringstream ss;
-        ss << pass_enables;
-        while (ss.good())
-        {
-            string substr;
-            getline(ss, substr, ';');
-            auto split_str = split(substr, ':', false);
-            switch (split_str.size())
-            {
-            case 1: m_pass_enables.emplace(split_str[0], true); break;
-            case 2: m_pass_enables.emplace(split_str[0], parse_string<bool>(split_str[1])); break;
-            default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ENABLES: " + substr);
-            }
-        }
-    }
-    //
-    //  Parses the semi-colon separated environment string passed through NGRAPH_PASS_ATTRIBUTES
-    //  and returns the pass attributes and whether they should be enabled or disabled in the
-    //  provided unordered_map. Naming of pass attributes is up to the backends.
-    //
-    //  For example:
-    //  NGRAPH_PASS_ATTRIBUTES="OptimizeForMemory=0;MemoryAssignment::ReuseMemory=1;UseDefaultLayouts"
-    //  would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on
-    //  "UseDefaultLayouts"
-    //
-    static const string pass_attributes = getenv_string("NGRAPH_PASS_ATTRIBUTES");
-    if (!pass_attributes.empty())
-    {
-        stringstream ss;
-        ss << pass_attributes;
-        while (ss.good())
-        {
-            string substr;
-            getline(ss, substr, ';');
-            auto split_str = split(substr, '=', false);
-            switch (split_str.size())
-            {
-            case 1: m_pass_attributes.emplace(split_str[0], true); break;
-            case 2:
-                m_pass_attributes.emplace(split_str[0], parse_string<bool>(split_str[1]));
-                break;
-            default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ATTRIBUTES: " + substr);
-            }
-        }
-    }
-}
-
-void pass::PassConfig::set_pass_enable(const string& name, bool enable)
-{
-    m_pass_enables[name] = enable;
-}
-
-bool pass::PassConfig::get_pass_enable(const string& name) const
-{
-    auto it = m_pass_enables.find(name);
-    if (it != m_pass_enables.end())
+    const auto& it = m_callback_map.find(type_info);
+    if (it != m_callback_map.end())
     {
         return it->second;
     }
-    return false;
-}
-
-void pass::PassConfig::set_pass_attribute(const string& name, bool enable)
-{
-    m_pass_attributes[name] = enable;
-}
-
-bool pass::PassConfig::get_pass_attribute(const string& name) const
-{
-    auto it = m_pass_attributes.find(name);
-    if (it != m_pass_attributes.end())
+    else
     {
-        return it->second;
+        return m_callback;
     }
-    return false;
 }
index 5f3b13c..e5c435d 100644 (file)
@@ -15,6 +15,7 @@ using namespace ngraph;
 class TestPass : public ngraph::pass::MatcherPass
 {
 public:
+    NGRAPH_RTTI_DECLARATION;
     TestPass()
         : MatcherPass()
     {
@@ -39,12 +40,16 @@ public:
 class Anchor : public ngraph::pass::GraphRewrite
 {
 public:
+    NGRAPH_RTTI_DECLARATION;
     Anchor()
         : GraphRewrite()
     {
     }
 };
 
+NGRAPH_RTTI_DEFINITION(TestPass, "TestPass", 0);
+NGRAPH_RTTI_DEFINITION(Anchor, "Anchor", 0);
+
 std::shared_ptr<Function> get_function()
 {
     auto data =
@@ -93,7 +98,7 @@ TEST(GraphRewriteTest, GraphRewriteCallback)
     ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
 }
 
-TEST(GraphRewriteTest, ManagerCallback)
+TEST(GraphRewriteTest, ManagerCallbackDeprecated)
 {
     auto f = get_function();
 
@@ -106,6 +111,20 @@ TEST(GraphRewriteTest, ManagerCallback)
     ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
 }
 
+TEST(GraphRewriteTest, ManagerCallback)
+{
+    auto f = get_function();
+
+    pass::Manager manager;
+    auto anchor = manager.register_pass<Anchor>();
+    anchor->add_matcher<TestPass>();
+    auto pass_config = manager.get_pass_config();
+    pass_config->set_callback(get_callback());
+    manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
+
 TEST(GraphRewriteTest, ManagerCallback2)
 {
     auto f = get_function();
@@ -244,4 +263,127 @@ TEST(GraphRewriteTest, TypeBasedMatcherPassOrder2)
     anchor.run_on_function(f);
 
     ASSERT_EQ(count_ops_of_type<opset3::Tanh>(f), 1);
+}
+
+TEST(PassConfigTest, Test1)
+{
+    {
+        auto f = get_function();
+
+        pass::Manager manager;
+        manager.register_pass<TestPass>();
+
+        auto pass_config = manager.get_pass_config();
+        pass_config->set_callback(get_callback());
+
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        auto f = get_function();
+
+        pass::Manager manager;
+        manager.register_pass<TestPass>();
+
+        auto pass_config = manager.get_pass_config();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        auto f = get_function();
+
+        pass::Manager manager;
+        manager.register_pass<TestPass>();
+
+        auto pass_config = std::make_shared<ngraph::pass::PassConfig>();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        manager.set_pass_config(pass_config);
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        auto f = get_function();
+
+        pass::Manager manager;
+        auto anchor = manager.register_pass<Anchor>();
+        anchor->add_matcher<TestPass>();
+
+        auto pass_config = anchor->get_pass_config();
+        pass_config->set_callback(get_callback());
+
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        auto f = get_function();
+
+        pass::Manager manager;
+        auto anchor = manager.register_pass<Anchor>();
+        anchor->add_matcher<TestPass>();
+
+        auto pass_config = anchor->get_pass_config();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        auto pass_config = std::make_shared<pass::PassConfig>();
+
+        pass::Manager manager1;
+        pass::Manager manager2;
+        manager1.set_pass_config(pass_config);
+        manager2.set_pass_config(pass_config);
+        ASSERT_EQ(pass_config.use_count(), 1);
+    }
+
+    {
+        auto f = get_function();
+
+        pass::Manager manager;
+        manager.register_pass<TestPass>();
+
+        auto pass_config = manager.get_pass_config();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        pass_config->disable<TestPass>();
+        manager.run_passes(f);
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
+
+        pass_config->enable<TestPass>();
+        manager.run_passes(f);
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        auto f = get_function();
+
+        pass::Manager manager;
+        auto anchor = manager.register_pass<Anchor>();
+        anchor->add_matcher<TestPass>();
+
+        auto pass_config = manager.get_pass_config();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        pass_config->disable<TestPass>();
+        manager.run_passes(f);
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
+
+        pass_config->enable<TestPass>();
+        manager.run_passes(f);
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
 }
\ No newline at end of file
index e37aba8..27cf6b0 100644 (file)
@@ -51,6 +51,8 @@ set (SRC
     pass/shape_relevance.hpp
     )
 
+disable_deprecated_warnings()
+
 add_library(ngraph_backend SHARED ${SRC})
 target_compile_definitions(ngraph_backend
     PRIVATE
index da5a7ba..2a2444a 100644 (file)
@@ -102,14 +102,6 @@ std::shared_ptr<ngraph::runtime::Tensor>
     throw std::invalid_argument("This backend does not support dynamic tensors");
 }
 
-std::shared_ptr<runtime::Executable>
-    runtime::Backend::compile(std::shared_ptr<Function> func,
-                              ngraph::pass::PassConfig& /* pass_config */,
-                              bool enable_performance_data)
-{
-    return compile(func, enable_performance_data);
-}
-
 bool runtime::Backend::is_supported(const Node& /* node */) const
 {
     // The default behavior is that a backend does not support any ops. If this is not the case
index b5dd257..b875781 100644 (file)
@@ -22,7 +22,6 @@
 #include "backend_visibility.hpp"
 #include "executable.hpp"
 #include "ngraph/function.hpp"
-#include "ngraph/pass/pass_config.hpp"
 #include "ngraph/shape.hpp"
 #include "ngraph/type/element_type.hpp"
 #include "ngraph/util.hpp"
@@ -111,14 +110,6 @@ public:
     virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
                                                 bool enable_performance_data = false) = 0;
 
-    /// \brief Compiles a Function.
-    /// \param func The function to compile
-    /// \param pass_config Configuration object for defining compilation options
-    /// \returns compiled function or nullptr on failure
-    virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
-                                                ngraph::pass::PassConfig& pass_config,
-                                                bool enable_performance_data = false);
-
     /// \brief Loads a previously saved Executable object from a stream.
     /// \param input_stream the opened input stream containing the saved Executable
     /// \returns A compiled function or throws an exception on error