#include <transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp>
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp>
+#include <transformations/convert_depth_to_space.hpp>
+#include <transformations/convert_space_to_depth.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/convert_precision.hpp>
+#include <transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp>
+#include <transformations/convert_gelu.hpp>
+#include <transformations/depth_to_space_fusion.hpp>
+#include <transformations/convert_batch_to_space.hpp>
+#include <transformations/convert_extract_image_patches_to_reorg_yolo.hpp>
+#include <transformations/hswish_decomposition.hpp>
+#include <transformations/reduce_l1_decomposition.hpp>
+#include <transformations/reduce_l2_decomposition.hpp>
+#include <transformations/convert_space_to_batch.hpp>
+#include <transformations/softplus_decomposition.hpp>
+#include <transformations/convert_pad_to_group_conv.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
OV_ITT_SCOPED_TASK(MKLDNNPlugin::itt::domains::MKLDNNPlugin, "Transformation");
- const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
- // DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
- if (auto dtsOp = std::dynamic_pointer_cast<const ::ngraph::opset3::DepthToSpace>(node)) {
- return dtsOp->input_value(0).get_shape().size() <= 5lu && dtsOp->input_value(0).get_shape().size() == dtsOp->get_output_shape(0).size();
- }
-
- // SpaceToDepth node implementation supports only equal input/output tensors with rank <= 5
- if (auto stdOp = std::dynamic_pointer_cast<const ::ngraph::opset3::SpaceToDepth>(node)) {
- return stdOp->input_value(0).get_shape().size() <= 5lu && stdOp->input_value(0).get_shape().size() == stdOp->get_output_shape(0).size();
- }
-
- if (auto fc_op = std::dynamic_pointer_cast<const ngraph::op::FullyConnected>(node)) {
- return fc_op->input_value(0).get_shape().size() == 3ul;
- }
-
- return std::dynamic_pointer_cast<const ngraph::opset2::Gelu>(node) ||
- std::dynamic_pointer_cast<const ngraph::opset2::BatchToSpace>(node) ||
- std::dynamic_pointer_cast<const ngraph::opset2::SpaceToBatch>(node) ||
- std::dynamic_pointer_cast<const ngraph::opset3::ExtractImagePatches>(node) ||
- std::dynamic_pointer_cast<const ngraph::opset4::HSwish>(node) ||
- std::dynamic_pointer_cast<const ngraph::opset4::ReduceL1>(node) ||
- std::dynamic_pointer_cast<const ngraph::opset4::ReduceL2>(node) ||
- std::dynamic_pointer_cast<const ngraph::opset4::SoftPlus>(node) ||
- std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
- };
auto nGraphFunc = clonedNetwork->getFunction();
// Disable shape inference (WA for generic operations)
ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
- manager.set_callback(transformations_callback);
+ auto pass_config = manager.get_pass_config();
+
+ using const_node_ptr = const std::shared_ptr<const ngraph::Node>;
+
+    // SpaceToDepth / DepthToSpace node implementations support only equal input/output tensors with rank <= 5
+ pass_config->set_callback<ngraph::pass::ConvertSpaceToDepth,
+ ngraph::pass::ConvertDepthToSpace>(
+ [](const_node_ptr &node) -> bool {
+ return node->input_value(0).get_shape().size() <= 5lu &&
+ node->input_value(0).get_shape().size() == node->get_output_shape(0).size();
+ });
+
+ // Disable FC reshaping for 3D case
+ pass_config->set_callback<ngraph::pass::ReshapeFullyConnected>(
+ [](const_node_ptr &node) -> bool {
+ return node->input_value(0).get_shape().size() == 3ul;
+ });
+
+ pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
+ ngraph::pass::ConvertSpaceToBatch>(
+ [](const_node_ptr &node) -> bool {
+ const auto & rank = node->input(0).get_partial_shape().rank().get_length();
+ return rank == 4lu || rank == 5lu;
+ });
+
+ // List of enabled/disabled transformations
+ pass_config->disable<ngraph::pass::ConvertGELU>();
+ pass_config->disable<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
+ pass_config->disable<ngraph::pass::HSwishDecomposition>();
+ pass_config->disable<ngraph::pass::ReduceL1Decomposition>();
+ pass_config->disable<ngraph::pass::ReduceL2Decomposition>();
+ pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
+
+ pass_config->enable<ngraph::pass::ConvertPadToGroupConvolution>();
+
manager.run_passes(nGraphFunc);
clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
manager.register_pass<ngraph::pass::SwishFusion>();
manager.register_pass<ngraph::pass::HSwishFusion>();
- manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
+ manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
- manager.set_callback(m_transformation_callback);
+ // Propagate local PassConfig to internal pass::Manager
+ manager.set_pass_config(get_pass_config());
manager.run_passes(f);
return true;
}
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
- if (!dts_node || m_transformation_callback(dts_node)) {
+ if (!dts_node || transformation_callback(dts_node)) {
return false;
}
manager.register_pass<ngraph::pass::ConstantFolding>();
- manager.set_callback(m_transformation_callback);
+ manager.set_pass_config(get_pass_config());
manager.run_passes(f);
return true;
}
manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
- manager.set_callback(m_transformation_callback);
+ manager.set_pass_config(get_pass_config());
manager.run_passes(f);
return true;
}
manager.register_pass<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
- manager.set_callback(m_transformation_callback);
+ manager.set_pass_config(get_pass_config());
manager.run_passes(f);
return true;
}
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto pad = std::dynamic_pointer_cast<ngraph::opset4::Pad> (m.get_match_root());
- if (!pad || !m_transformation_callback(pad) /* disabled by default */) {
+ if (!pad) {
return false;
}
std::make_shared<ngraph::opset3::DepthToSpace>(reshape_before->input_value(0), mode, block_size);
depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
-
- if (!m_transformation_callback(depth_to_space)) {
- return false;
- }
-
ngraph::replace_node(reshape_after, depth_to_space);
return true;
};
m_matchers.push_back(pass);
}
-    template <typename T, class... Args>
+    /// \brief Register given transformation class type to GraphRewrite execution list
+    /// All registered transformations will be executed in a single graph traversal.
+    /// Example below shows the basic usage of pass::GraphRewrite
+    ///
+    /// pass::Manager manager;
+    /// auto anchor = manager.register_pass<GraphRewrite>();
+    /// anchor->add_matcher<MatcherPassA>();
+    /// anchor->add_matcher<MatcherPassB>();
+    /// anchor->set_name("CommonMatchers");
+    /// manager.run_passes(f);
+    ///
+    /// For some purposes transformation can be registered and disabled by default.
+    ///
+    /// anchor->add_matcher<MatcherPassB, false>();
+    ///
+    /// \return shared_ptr to the transformation instance
+    template <typename T, bool Enabled = true, class... Args>
    std::shared_ptr<T> add_matcher(Args&&... args)
    {
        static_assert(std::is_base_of<pass::MatcherPass, T>::value,
                      "pass not derived from MatcherPass");
        auto pass = std::make_shared<T>(std::forward<Args>(args)...);
+        // Share this GraphRewrite's PassConfig with the nested pass so that
+        // enable/disable state and callbacks are visible to it.
+        auto pass_config = get_pass_config();
+        pass->set_pass_config(pass_config);
+        // A pass registered with Enabled = false is recorded as disabled in the
+        // shared PassConfig; it can be re-enabled later via pass_config->enable<T>().
+        if (!Enabled)
+        {
+            pass_config->disable<T>();
+        }
        m_matchers.push_back(pass);
        return pass;
    }
#include <vector>
#include "ngraph/pass/pass.hpp"
-#include "ngraph/pass/pass_config.hpp"
#include "ngraph/pass/validate.hpp"
namespace ngraph
Manager();
~Manager();
-    template <typename T, class... Args>
+    /// \brief Register given transformation class type to execution list
+    /// Example below shows the basic usage of pass::Manager
+    ///
+    /// pass::Manager manager;
+    /// manager.register_pass<MyTransformation>(/*transformation constructor args*/);
+    /// manager.run_passes(f);
+    ///
+    /// For some purposes transformation can be registered and disabled by default.
+    ///
+    /// manager.register_pass<MyTransformation, false>();
+    ///
+    /// \return shared_ptr to the transformation instance
+    template <typename T, bool Enable = true, class... Args>
    std::shared_ptr<T> register_pass(Args&&... args)
    {
        auto rc = push_pass<T>(std::forward<Args>(args)...);
+        // Every registered pass shares the Manager's PassConfig instance.
+        rc->set_pass_config(m_pass_config);
        if (m_per_pass_validation)
        {
            push_pass<Validate>();
        }
+        // NOTE(review): this template parameter is named 'Enable' while
+        // GraphRewrite::add_matcher uses 'Enabled' — consider unifying the names.
+        if (!Enable)
+        {
+            m_pass_config->disable<T>();
+        }
        return rc;
    }
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
/// \brief Callback is a lambda function that can be used by registered transformations.
/// The main purpose of this callback is to provide a way for plugins to disable/enable
- /// transformations. In some cases plugins may want not to execute some transformations.
- /// For example plugin can disable unpleasant decompositions because of performance reasons.
+ /// transformations based on some conditions. In some cases plugins may want not to execute some
+ /// transformations.
+ /// For example plugin can disable unpleasant decompositions because of performance reasons for
+ /// some cases.
/// Callback example:
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
/// decomposition pass will check is this decomposition needed or plugin can execute this
/// operation directly. And of course on transformation side we need to have a response for this
/// callback.
- /// if (m_transformation_callback(batch_to_space)) {
+ /// if (transformation_callback(batch_to_space)) {
/// return false;
/// }
/// \param callback lamda function that returns true in case if node is supported by plugin and
/// transformation is not needed
- void set_callback(param_callback callback)
+ NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
+ void set_callback(const param_callback& callback) { m_pass_config->set_callback(callback); }
+ /// \return PassConfig shared object. This object is used for transformations pipeline
+ /// configuration.
+    /// This object allows disabling/enabling transformation execution and setting a callback for a
+    /// particular transformation. For more details see the PassConfig class.
+ std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
+ /// \brief Set external PassConfig object.
+ void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
{
- m_transformation_callback = callback;
- m_has_default_callback = false;
+ *m_pass_config = *pass_config;
}
protected:
return pass;
}
- param_callback m_transformation_callback = [](const std::shared_ptr<const Node>&) -> bool {
- return false;
- };
- bool m_has_default_callback = true;
-
+ std::shared_ptr<PassConfig> m_pass_config;
std::vector<std::shared_ptr<PassBase>> m_pass_list;
bool m_visualize = false;
bool m_per_pass_validation = true;
#include "ngraph/deprecated.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
+#include "ngraph/pass/pass_config.hpp"
#include "ngraph/util.hpp"
namespace ngraph
typedef EnumMask<PassProperty> PassPropertyMask;
const PassPropertyMask all_pass_property_off;
- using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
class NGRAPH_API PassBase
{
void set_name(const std::string& name) { m_name = name; }
std::string get_name() const;
+ /// \brief Set callback for particular transformation type.
+ /// This method set global callback. For more details see PassConfig class
+ /// documentation.
+ /// \param callback lambda function that takes node and returns bool
void set_callback(const param_callback& callback);
+ /// \brief Set PassConfig for particular transformation instance
+ /// \param pass_config is a PassConfig shared_ptr
+ void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
+ {
+ m_pass_config = pass_config;
+ }
+
+ /// \brief Allows to access PassConfig shared instance
+ /// \return Shared instance of PassConfig class
+ std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
+ /// \brief Applies callback for given node. By default callback returns false.
+ /// This method remains here only for backward compatibility and will be removed
+ /// after all transformations are moved to transformation_callback() method.
+ /// \return result of callback execution for given node
+ NGRAPH_DEPRECATED("Please use transformation_callback method instead")
+ bool m_transformation_callback(const std::shared_ptr<const Node>& node)
+ {
+ return m_pass_config->get_callback(get_type_info())(node);
+ }
+
+ /// \brief Applies callback for given node. By default callback returns false.
+ /// \param node which will be used inside callback
+ /// \return result of callback execution for given node
+ bool transformation_callback(const std::shared_ptr<const Node>& node)
+ {
+ return m_pass_config->get_callback(get_type_info())(node);
+ }
+
using type_info_t = DiscreteTypeInfo;
virtual const type_info_t& get_type_info() const = 0;
protected:
void set_property(const PassPropertyMask& prop, bool value);
- param_callback m_transformation_callback =
- [](const std::shared_ptr<const Node>&) -> bool { return false; };
- bool m_has_default_callback = true;
-
private:
PassPropertyMask m_property;
+
std::string m_name;
+ std::shared_ptr<PassConfig> m_pass_config;
};
class NGRAPH_API FunctionPass : public PassBase
#pragma once
-#include <map>
-#include <string>
+#include <list>
+#include <memory>
+#include <vector>
-#include <ngraph/ngraph_visibility.hpp>
+#include "ngraph/deprecated.hpp"
+#include "ngraph/function.hpp"
+#include "ngraph/node.hpp"
+#include "ngraph/util.hpp"
namespace ngraph
{
namespace pass
{
- class PassConfig;
- }
-}
+ using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
+ using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>;
-class NGRAPH_API ngraph::pass::PassConfig
-{
-public:
- PassConfig();
- const std::map<std::string, bool>& get_enables() const { return m_pass_enables; }
- void set_pass_enable(const std::string& name, bool enable);
- bool get_pass_enable(const std::string& name) const;
- const std::map<std::string, bool>& get_pass_attributes() const { return m_pass_attributes; }
- void set_pass_attribute(const std::string& name, bool enable);
- bool get_pass_attribute(const std::string& name) const;
-
-private:
- std::map<std::string, bool> m_pass_enables;
- std::map<std::string, bool> m_pass_attributes;
-};
+ /// \brief Class representing a transformations config that is used for disabling/enabling
+ /// transformations registered inside pass::Manager and also allows to set callback for all
+ /// transformations or for particular transformation.
+ ///
+ /// When pass::Manager is created all passes registered inside this manager including nested
+ /// passes will share the same instance of PassConfig class.
+ /// To work with this class first you need to get shared instance of this class by calling
+ /// manager.get_pass_config() method. Then you will be able to disable/enable passes based
+ /// on transformations type_info. For example:
+ ///
+ /// pass::Manager manager;
+ /// manager.register_pass<CommonOptimizations>();
+ /// auto pass_config = manager.get_pass_config();
+ /// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
+ /// // CommonOptimizations pipeline
+ /// manager.run_passes(f);
+ ///
+ /// Sometimes it is needed to call transformation inside other transformation manually. And
+ /// for that case before running transformation you need manually check that this pass is
+ /// not disabled and then you need to set current PassConfig instance to this
+ /// transformation. For example:
+ ///
+ /// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
+ /// // you need to call get_pass_config() method to get shared instance of PassConfig
+ /// auto pass_config = get_pass_config();
+ ///
+        /// // Before running nested transformation you need to check whether it is disabled or not
+ /// if (!pass_config->is_disabled<ConvertGELU>()) {
+ /// auto pass = ConvertGELU();
+ /// pass->set_pass_config(pass_config);
+ /// pass.apply(node);
+ /// }
+ ///
+        /// Following this logic inside your transformations you can guarantee that transformations
+        /// will be executed in the right way.
+        class NGRAPH_API PassConfig
+        {
+        public:
+            /// \brief Disable transformation by its type_info
+            /// \param type_info Transformation type_info
+            void disable(const DiscreteTypeInfo& type_info) { m_disabled.insert(type_info); }
+            /// \brief Disable transformation by its class type (based on type_info)
+            template <typename T>
+            void disable()
+            {
+                disable(T::type_info);
+            }
+
+            /// \brief Enable transformation by its type_info
+            /// \param type_info Transformation type_info
+            void enable(const DiscreteTypeInfo& type_info) { m_disabled.erase(type_info); }
+            /// \brief Enable transformation by its class type (based on type_info)
+            template <typename T>
+            void enable()
+            {
+                enable(T::type_info);
+            }
+
+            /// \brief Set callback for all kinds of transformations
+            void set_callback(const param_callback& callback) { m_callback = callback; }
+            /// Recursion terminator for the variadic set_callback<T, Args...> overload
+            /// below; intentionally a no-op once the parameter pack is exhausted.
+            template <typename... Args>
+            typename std::enable_if<sizeof...(Args) == 0>::type
+            set_callback(const param_callback& callback)
+            {
+            }
+
+            /// \brief Set callback for particular transformation class types
+            ///
+            /// Example below shows how to set callback for one or multiple passes using this method.
+            ///
+            /// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
+            /// ngraph::pass::ConvertSpaceToBatch>(
+            /// [](const_node_ptr &node) -> bool {
+            /// // Disable transformations for cases when input shape rank is not
+            /// equal to 4
+            /// const auto input_shape_rank =
+            /// node->get_output_partial_shape(0).rank().get_length();
+            /// if (input_shape_rank != 4) {
+            /// return false;
+            /// }
+            /// return true;
+            /// });
+            ///
+            /// Note that inside transformations you must provide code that works with this callback.
+            /// See example below:
+            ///
+            /// if (transformation_callback(node)) {
+            /// return false; // exit from transformation
+            /// }
+            ///
+            template <typename T, class... Args>
+            void set_callback(const param_callback& callback)
+            {
+                m_callback_map[T::type_info] = callback;
+                set_callback<Args...>(callback);
+            }
+
+            /// \brief Get callback for given transformation type_info
+            /// \param type_info Transformation type_info
+            ///
+            /// If no callback was set for the given transformation type, the global
+            /// callback is returned. If no global callback was set either, the default
+            /// callback is returned.
+            param_callback get_callback(const DiscreteTypeInfo& type_info) const;
+
+            /// \brief Get callback for given transformation class type
+            /// \return callback lambda function
+            template <typename T>
+            param_callback get_callback() const
+            {
+                return get_callback(T::type_info);
+            }
+
+            /// \brief Check whether transformation type is disabled or not
+            /// \param type_info Transformation type_info
+            /// \return true if transformation type was disabled and false otherwise
+            bool is_disabled(const DiscreteTypeInfo& type_info) const
+            {
+                return m_disabled.count(type_info);
+            }
+
+            /// \brief Check whether transformation class type is disabled or not
+            /// \return true if transformation type was disabled and false otherwise
+            template <typename T>
+            bool is_disabled() const
+            {
+                return is_disabled(T::type_info);
+            }
+
+        private:
+            // Global fallback callback; defaults to "not supported by plugin" (false).
+            param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
+                return false;
+            };
+            // Per-transformation-type callbacks; take precedence over m_callback.
+            param_callback_map m_callback_map;
+            // Set of transformation type_infos that are currently disabled.
+            std::unordered_set<DiscreteTypeInfo> m_disabled;
+        };
+ }
+}
\ No newline at end of file
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
bool rewritten = false;
+ const auto& pass_config = get_pass_config();
// Initialize execution queue with nodes in topological order
deque<std::shared_ptr<Node>> nodes_to_run;
std::unordered_map<NodeTypeInfo, std::vector<size_t>> type_to_matcher;
for (size_t matcher_index = 0; matcher_index < m_matchers.size(); ++matcher_index)
{
+ // Skip passes that are disabled
+ if (pass_config->is_disabled(m_matchers[matcher_index]->get_type_info()))
+ continue;
+
auto matcher = m_matchers[matcher_index]->get_matcher();
if (!matcher)
{
return false;
}
- if (!m_has_default_callback)
- {
- m_pass->set_callback(m_transformation_callback);
- }
-
// Apply MatcherPass. In case if it returns true no other MatcherPasses will apply
// to this node
bool status = m_pass->apply(node);
{
for (auto& m_pass : m_matchers)
{
+ // Skip passes that are disabled
+ if (pass_config->is_disabled(m_pass->get_type_info()))
+ continue;
+
if (run_matcher_pass(m_pass, node))
{
rewritten = true;
pass::Manager::Manager()
: m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING"))
+ , m_pass_config(std::make_shared<PassConfig>())
{
}
bool function_changed = false;
for (auto& pass : m_pass_list)
{
- pass_timer.start();
- if (!m_has_default_callback)
+ if (m_pass_config->is_disabled(pass->get_type_info()))
{
- pass->set_callback(m_transformation_callback);
+ NGRAPH_DEBUG << "Pass " << pass->get_name() << " is disabled";
+ continue;
}
+ pass_timer.start();
+
NGRAPH_SUPPRESS_DEPRECATED_START
if (auto matcher_pass = dynamic_pointer_cast<MatcherPass>(pass))
{
pass::PassBase::PassBase()
: m_property{all_pass_property_off}
+ , m_pass_config(std::make_shared<PassConfig>())
{
}
void pass::PassBase::set_callback(const param_callback& callback)
{
- m_transformation_callback = callback;
- m_has_default_callback = false;
+ m_pass_config->set_callback(callback);
}
// The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM
//*****************************************************************************
#include "ngraph/pass/pass_config.hpp"
-#include "ngraph/env_util.hpp"
-#include "ngraph/except.hpp"
-#include "ngraph/log.hpp"
-#include "ngraph/util.hpp"
-using namespace std;
using namespace ngraph;
-// TODO: Add file-based configuration support
-pass::PassConfig::PassConfig()
+pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const
{
- //
- // Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES
- // and returns the pass names and whether they should be enabled or disabled in the
- // provided unordered_map. Implementation of pass selection is up to the backend
- // E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would
- // set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims
- //
- string pass_enables = getenv_string("NGRAPH_PASS_ENABLES");
- if (!pass_enables.empty())
- {
- stringstream ss;
- ss << pass_enables;
- while (ss.good())
- {
- string substr;
- getline(ss, substr, ';');
- auto split_str = split(substr, ':', false);
- switch (split_str.size())
- {
- case 1: m_pass_enables.emplace(split_str[0], true); break;
- case 2: m_pass_enables.emplace(split_str[0], parse_string<bool>(split_str[1])); break;
- default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ENABLES: " + substr);
- }
- }
- }
- //
- // Parses the semi-colon separated environment string passed through NGRAPH_PASS_ATTRIBUTES
- // and returns the pass attributes and whether they should be enabled or disabled in the
- // provided unordered_map. Naming of pass attributes is up to the backends.
- //
- // For example:
- // NGRAPH_PASS_ATTRIBUTES="OptimizeForMemory=0;MemoryAssignment::ReuseMemory=1;UseDefaultLayouts"
- // would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on
- // "UseDefaultLayouts"
- //
- static const string pass_attributes = getenv_string("NGRAPH_PASS_ATTRIBUTES");
- if (!pass_attributes.empty())
- {
- stringstream ss;
- ss << pass_attributes;
- while (ss.good())
- {
- string substr;
- getline(ss, substr, ';');
- auto split_str = split(substr, '=', false);
- switch (split_str.size())
- {
- case 1: m_pass_attributes.emplace(split_str[0], true); break;
- case 2:
- m_pass_attributes.emplace(split_str[0], parse_string<bool>(split_str[1]));
- break;
- default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ATTRIBUTES: " + substr);
- }
- }
- }
-}
-
-void pass::PassConfig::set_pass_enable(const string& name, bool enable)
-{
- m_pass_enables[name] = enable;
-}
-
-bool pass::PassConfig::get_pass_enable(const string& name) const
-{
- auto it = m_pass_enables.find(name);
- if (it != m_pass_enables.end())
+ const auto& it = m_callback_map.find(type_info);
+ if (it != m_callback_map.end())
{
return it->second;
}
- return false;
-}
-
-void pass::PassConfig::set_pass_attribute(const string& name, bool enable)
-{
- m_pass_attributes[name] = enable;
-}
-
-bool pass::PassConfig::get_pass_attribute(const string& name) const
-{
- auto it = m_pass_attributes.find(name);
- if (it != m_pass_attributes.end())
+ else
{
- return it->second;
+ return m_callback;
}
- return false;
}
class TestPass : public ngraph::pass::MatcherPass
{
public:
+ NGRAPH_RTTI_DECLARATION;
TestPass()
: MatcherPass()
{
class Anchor : public ngraph::pass::GraphRewrite
{
public:
+ NGRAPH_RTTI_DECLARATION;
Anchor()
: GraphRewrite()
{
}
};
+NGRAPH_RTTI_DEFINITION(TestPass, "TestPass", 0);
+NGRAPH_RTTI_DEFINITION(Anchor, "Anchor", 0);
+
std::shared_ptr<Function> get_function()
{
auto data =
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
-TEST(GraphRewriteTest, ManagerCallback)
+TEST(GraphRewriteTest, ManagerCallbackDeprecated)
{
auto f = get_function();
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
+TEST(GraphRewriteTest, ManagerCallback)
+{
+    // Global callback set through Manager's PassConfig must reach a matcher
+    // registered inside a nested GraphRewrite (Anchor).
+    auto f = get_function();
+
+    pass::Manager manager;
+    auto anchor = manager.register_pass<Anchor>();
+    anchor->add_matcher<TestPass>();
+    auto pass_config = manager.get_pass_config();
+    pass_config->set_callback(get_callback());
+    manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
+
TEST(GraphRewriteTest, ManagerCallback2)
{
auto f = get_function();
anchor.run_on_function(f);
ASSERT_EQ(count_ops_of_type<opset3::Tanh>(f), 1);
+}
+
+TEST(PassConfigTest, Test1)
+{
+    {
+        // Global callback via Manager's own PassConfig reaches a directly
+        // registered pass.
+        auto f = get_function();
+
+        pass::Manager manager;
+        manager.register_pass<TestPass>();
+
+        auto pass_config = manager.get_pass_config();
+        pass_config->set_callback(get_callback());
+
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        // Per-transformation-type callback (set_callback<TestPass>) works the
+        // same way as the global callback for the targeted pass.
+        auto f = get_function();
+
+        pass::Manager manager;
+        manager.register_pass<TestPass>();
+
+        auto pass_config = manager.get_pass_config();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        // An externally created PassConfig can be injected with
+        // Manager::set_pass_config and affects already registered passes.
+        auto f = get_function();
+
+        pass::Manager manager;
+        manager.register_pass<TestPass>();
+
+        auto pass_config = std::make_shared<ngraph::pass::PassConfig>();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        manager.set_pass_config(pass_config);
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        // Global callback set on a nested GraphRewrite's PassConfig is seen by
+        // its matchers (the config is shared with the parent Manager).
+        auto f = get_function();
+
+        pass::Manager manager;
+        auto anchor = manager.register_pass<Anchor>();
+        anchor->add_matcher<TestPass>();
+
+        auto pass_config = anchor->get_pass_config();
+        pass_config->set_callback(get_callback());
+
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        // Same as above but with a per-type callback on the nested config.
+        auto f = get_function();
+
+        pass::Manager manager;
+        auto anchor = manager.register_pass<Anchor>();
+        anchor->add_matcher<TestPass>();
+
+        auto pass_config = anchor->get_pass_config();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        manager.run_passes(f);
+
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        // set_pass_config copies the config state into the Manager's own
+        // instance, so the external shared_ptr gains no additional owners.
+        auto pass_config = std::make_shared<pass::PassConfig>();
+
+        pass::Manager manager1;
+        pass::Manager manager2;
+        manager1.set_pass_config(pass_config);
+        manager2.set_pass_config(pass_config);
+        ASSERT_EQ(pass_config.use_count(), 1);
+    }
+
+    {
+        // disable<T>() skips the pass entirely; enable<T>() restores it.
+        auto f = get_function();
+
+        pass::Manager manager;
+        manager.register_pass<TestPass>();
+
+        auto pass_config = manager.get_pass_config();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        pass_config->disable<TestPass>();
+        manager.run_passes(f);
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
+
+        pass_config->enable<TestPass>();
+        manager.run_passes(f);
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
+
+    {
+        // disable/enable also applies to matchers nested inside a GraphRewrite.
+        auto f = get_function();
+
+        pass::Manager manager;
+        auto anchor = manager.register_pass<Anchor>();
+        anchor->add_matcher<TestPass>();
+
+        auto pass_config = manager.get_pass_config();
+        pass_config->set_callback<TestPass>(get_callback());
+
+        pass_config->disable<TestPass>();
+        manager.run_passes(f);
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
+
+        pass_config->enable<TestPass>();
+        manager.run_passes(f);
+        ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+    }
}
\ No newline at end of file
pass/shape_relevance.hpp
)
+disable_deprecated_warnings()
+
add_library(ngraph_backend SHARED ${SRC})
target_compile_definitions(ngraph_backend
PRIVATE
throw std::invalid_argument("This backend does not support dynamic tensors");
}
-std::shared_ptr<runtime::Executable>
- runtime::Backend::compile(std::shared_ptr<Function> func,
- ngraph::pass::PassConfig& /* pass_config */,
- bool enable_performance_data)
-{
- return compile(func, enable_performance_data);
-}
-
bool runtime::Backend::is_supported(const Node& /* node */) const
{
// The default behavior is that a backend does not support any ops. If this is not the case
#include "backend_visibility.hpp"
#include "executable.hpp"
#include "ngraph/function.hpp"
-#include "ngraph/pass/pass_config.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
bool enable_performance_data = false) = 0;
- /// \brief Compiles a Function.
- /// \param func The function to compile
- /// \param pass_config Configuration object for defining compilation options
- /// \returns compiled function or nullptr on failure
- virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
- ngraph::pass::PassConfig& pass_config,
- bool enable_performance_data = false);
-
/// \brief Loads a previously saved Executable object from a stream.
/// \param input_stream the opened input stream containing the saved Executable
/// \returns A compiled function or throws an exception on error