#include <ngraph/graph_util.hpp>
#include <ngraph/op/result.hpp>
#include <ngraph/op/parameter.hpp>
+#include <ngraph/op/util/op_types.hpp>
#include <ngraph/rt_info.hpp>
using namespace InferenceEngine;
auto orderedOps = function->get_ordered_ops();
orderedOps.erase(
std::remove_if(std::begin(orderedOps), std::end(orderedOps), [] (const std::shared_ptr<ngraph::Node>& node) {
- return node->is_constant();
+ return ngraph::op::is_constant(node);
}),
std::end(orderedOps));
bool allEmpty = true;
auto NoConstants = [] (std::vector<ngraph::Input<ngraph::Node>>&& inputs) {
std::vector<ngraph::Input<ngraph::Node>> result;
for (auto&& input : inputs) {
- if (!(input.get_source_output().get_node()->is_constant())) {
+ if (!(ngraph::op::is_constant(input.get_source_output().get_node()))) {
result.emplace_back(std::move(input));
}
}
InputSet subgraphInputs;
// Get all subgraph inputs using just node affinities. Also collect transitive closure
for (auto&& node : orderedOps) {
- if (node->is_parameter()) {
+ if (ngraph::op::is_parameter(node)) {
graphInputNodes.insert(node.get());
subgraphInputs.insert(Input{node.get(), 0});
nodeInputDependencies[node.get()].insert(Input{node.get(), 0});
}
auto& nodeSubgraphCyclicInputDependency = nodeSubgraphCyclicInputDependencies[node.get()];
for (auto&& subgraphInput : allNodeSubgraphInputs) {
- if (!subgraphInput.get_node()->is_parameter() && subgraphIds[node.get()] == subgraphIds[InputNode(subgraphInput)]) {
+ if (!ngraph::op::is_parameter(subgraphInput.get_node()) &&
+ subgraphIds[node.get()] == subgraphIds[InputNode(subgraphInput)]) {
nodeSubgraphCyclicInputDependency.emplace(subgraphInput);
}
}
NodeMap<ngraph::Node*> subgraphParameterToPrevResult;
std::vector<std::shared_ptr<ngraph::op::Result>> results;
for (auto&& input : subgraphInputs) {
- if (!(input.get_node()->is_parameter())) {
+ if (!ngraph::op::is_parameter(input.get_node())) {
auto output = input.get_source_output();
output.remove_target_input(input);
auto result = std::make_shared<ngraph::op::Result>(output);
for (auto&& subgraphIdPtrValue : subgraphIds) {
auto node = subgraphIdPtrValue.first;
auto& subgraph = subgraphs[subgraphIdPtrValue.second];
- if (node->is_output()) {
+ if (ngraph::op::is_output(node)) {
subgraph._results.emplace_back(
std::dynamic_pointer_cast<ngraph::op::v0::Result>(node->shared_from_this()));
- } else if (node->is_parameter()) {
+ } else if (ngraph::op::is_parameter(node)) {
subgraph._parameters.emplace_back(
std::dynamic_pointer_cast<ngraph::op::v0::Parameter>(node->shared_from_this()));
}
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/op/fused/gelu.hpp>
+#include <ngraph/op/util/op_types.hpp>
#include "ngraph_ops/fully_connected.hpp"
#if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
if (function != nullptr) {
std::unordered_set<std::string> originalOps;
for (auto&& node : function->get_ops()) {
- if (!node->is_constant() && !node->is_parameter() && !node->is_output()) {
+ if (!ngraph::op::is_constant(node) && !ngraph::op::is_parameter(node) && !ngraph::op::is_output(node)) {
originalOps.emplace(node->get_friendly_name());
}
}
#include <transformations_visibility.hpp>
+#include <ngraph/op/util/op_types.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/validation_util.hpp>
auto input = reduce->input_value(0);
auto axes_node = reduce->input_value(1).get_node_shared_ptr();
- if (!axes_node->is_constant()) {
+ if (!ngraph::op::is_constant(axes_node)) {
return false;
}
#include <graph_tools.hpp>
#include <functional_test_utils/plugin_cache.hpp>
#include <multi-device/multi_device_config.hpp>
+#include <ngraph/op/util/op_types.hpp>
#include "common_test_utils/file_utils.hpp"
#include "common_test_utils/unicode_utils.hpp"
ASSERT_NE(nullptr, function);
std::unordered_set<std::string> expectedLayers;
for (auto &&node : function->get_ops()) {
- if (!node->is_constant() && !node->is_parameter() && !node->is_output()) {
+ if (!ngraph::op::is_constant(node) && !ngraph::op::is_parameter(node) && !ngraph::op::is_output(node)) {
expectedLayers.emplace(node->get_friendly_name());
}
}
ASSERT_NE(nullptr, function);
std::unordered_set<std::string> expectedLayers;
for (auto &&node : function->get_ops()) {
- if (!node->is_constant() && !node->is_parameter() && !node->is_output()) {
+ if (!ngraph::op::is_constant(node) && !ngraph::op::is_parameter(node) && !ngraph::op::is_output(node)) {
expectedLayers.emplace(node->get_friendly_name());
}
}
//
#include "hetero/query_network.hpp"
+#include <ngraph/op/util/op_types.hpp>
#include <ngraph/variant.hpp>
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
ASSERT_NE(nullptr, cnnNetwork.getFunction());
std::set<std::string> expectedLayers;
for (auto&& node : function->get_ops()) {
- if (!node->is_parameter() && !node->is_constant() && !node->is_output()) {
+ if (!ngraph::op::is_parameter(node) &&
+ !ngraph::op::is_constant(node) &&
+ !ngraph::op::is_output(node)) {
expectedLayers.insert(node->get_friendly_name());
}
}
}
ASSERT_EQ(expectedLayers, actualLayers);
}
-} // namespace HeteroTests
\ No newline at end of file
+} // namespace HeteroTests
//
#include "hetero/synthetic.hpp"
+#include <ngraph/op/util/op_types.hpp>
#include <ngraph/variant.hpp>
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
for (auto&& builder : builders) {
auto function = builder();
for (auto&& node : function->get_ordered_ops()) {
- if (!(node->is_constant()) && !(node->is_parameter()) && !(node->is_output())) {
+ if (!ngraph::op::is_constant(node) &&
+ !(ngraph::op::is_parameter(node)) &&
+ !(ngraph::op::is_output(node))) {
result.push_back(FunctionParameter{{node->get_friendly_name()}, function});
}
}
for (std::size_t i = 0; i < ordered_ops.size(); ++i) {
std::unordered_set<std::string> majorPluginNodeIds;
for (auto&& node : ordered_ops) {
- if (!(node->is_constant()) && !(node->is_parameter()) && !(node->is_output()) && d(e)) {
+ if (!(ngraph::op::is_constant(node)) &&
+ !(ngraph::op::is_parameter(node)) &&
+ !(ngraph::op::is_output(node)) && d(e)) {
majorPluginNodeIds.emplace(node->get_friendly_name());
}
}
auto& pluginParameters = std::get<Plugin>(param);
affinities += "\n{\n";
for (auto&& node : std::get<Function>(param)._function->get_ordered_ops()) {
- if (!(node->is_constant()) && !(node->is_parameter()) && !(node->is_output())) {
+ if (!ngraph::op::is_constant(node) &&
+ !(ngraph::op::is_parameter(node)) &&
+ !(ngraph::op::is_output(node))) {
std::string affinity;
if (std::get<Function>(param)._majorPluginNodeIds.end() !=
std::get<Function>(param)._majorPluginNodeIds.find(node->get_friendly_name())) {
ASSERT_NE(nullptr, cnnNetwork.getFunction());
}
-} // namespace HeteroTests
\ No newline at end of file
+} // namespace HeteroTests
#include <assert.h>
#include <ngraph/function.hpp>
+#include <ngraph/op/util/op_types.hpp>
#include <ngraph/pass/visualize_tree.hpp>
std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngraph::Function> & f1, const std::shared_ptr<ngraph::Function> & f2) {
std::ostringstream err_log;
for (auto & op : f->get_ops()) {
- if (op->is_constant()) continue;
+ if (ngraph::op::is_constant(op)) continue;
const auto & rt_info = op->get_rt_info();
for (const auto & attr_name : attrs_to_check) {
void visualize_function(std::shared_ptr<ngraph::Function> f, const std::string & file_name) {
std::vector<std::shared_ptr<ngraph::Function> > g{f};
ngraph::pass::VisualizeTree(file_name).run_on_module(g);
-}
\ No newline at end of file
+}
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/op/util/op_types.hpp>
#include <ngraph/specialize_function.hpp>
#include <ngraph_functions/utils/ngraph_helpers.hpp>
const auto &foldedFunc = specialize_function(function, paramElementTypes, paramShapes, inBuffers, true, true);
for (const auto &op : foldedFunc->get_ops()) {
- NGRAPH_CHECK(op->is_constant() || op->is_output() || op->is_parameter(),
+ NGRAPH_CHECK(op::is_constant(op) || op::is_output(op) || op::is_parameter(op),
"Function was not fully folded to constant state!\n",
"At least one non constant node with type ", op->get_type_name(),
" present in function.");
const auto &output = function->output(i).get_node_shared_ptr();
NGRAPH_CHECK(output->inputs().size() == 1);
auto parrentNode = output->input_value(0).get_node_shared_ptr();
- NGRAPH_CHECK(parrentNode->is_constant(), "Function was not fully folded to constant state!\n",
+ NGRAPH_CHECK(op::is_constant(parrentNode), "Function was not fully folded to constant state!\n",
"Parent node of one of results is not constant and has type ", parrentNode->get_type_name());
const auto data = std::dynamic_pointer_cast<opset1::Constant>(parrentNode)->get_data_ptr<std::uint8_t>();
#include "ngraph/check.hpp"
#include "ngraph/except.hpp"
#include "ngraph/node.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset.hpp"
#include "node_factory.hpp"
#include "tensor_iterator_builder.hpp"
std::shared_ptr<ngraph::Node>(m_opset.create(op_type_name));
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
- NGRAPH_CHECK(!op_node->is_constant(),
+ NGRAPH_CHECK(!ngraph::op::is_constant(op_node),
"Currently NodeFactory doesn't support Constant node: ",
op_type_name);
op/util/binary_elementwise_logical.hpp
op/util/broadcast_base.cpp
op/util/broadcast_base.hpp
+ op/util/elementwise_args.cpp
+ op/util/elementwise_args.hpp
op/util/embeddingbag_packed_base.cpp
op/util/embeddingbag_packed_base.hpp
op/util/embeddingbag_offsets_base.cpp
op/util/unary_elementwise_arithmetic.cpp
op/util/unary_elementwise_arithmetic.hpp
op/util/variable.hpp
+ op/util/op_types.cpp
+ op/util/op_types.hpp
ops.hpp
opsets/opset.cpp
partial_shape.cpp
pattern/op/skip.hpp
pattern/op/true.cpp
pattern/op/true.hpp
- placement.cpp
- placement.hpp
provenance.cpp
provenance.hpp
rank.hpp
}
} // namespace onnx_import
} // namespace ngraph
+
+bool ngraph::op::is_null(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::onnx_import::NullNode*>(node) != nullptr;
+}
+
+bool ngraph::op::is_null(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_null(node.get());
+}
#include <memory>
#include "ngraph/node.hpp"
+#include "utils/onnx_importer_visibility.hpp"
namespace ngraph
{
+ namespace op
+ {
+ ONNX_IMPORTER_API
+ bool is_null(const ngraph::Node* node);
+ ONNX_IMPORTER_API
+ bool is_null(const std::shared_ptr<ngraph::Node>& node);
+ }
namespace onnx_import
{
/// \brief Represents a missing optional input or output of an ONNX node
const NodeTypeInfo& get_type_info() const override { return type_info; }
NullNode() = default;
- bool is_null() const final override { return true; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
#include "clip.hpp"
#include "default_opset.hpp"
#include "ngraph/builder/make_constant.hpp"
+#include "ngraph/frontend/onnx_import/core/null_node.hpp"
namespace ngraph
{
// If second input is provided, assign to min input, otherwise set lowest
// numeric limit of double as min input.
- if (inputs.size() > 1 && !inputs.at(1)->is_null())
+ if (inputs.size() > 1 && !ngraph::op::is_null(inputs.at(1)))
{
min = inputs.at(1);
}
// If third input is provided, assign to max input, otherwise set maximum
// numeric limit of double as max input.
- if (inputs.size() == 3 && !inputs.at(2)->is_null())
+ if (inputs.size() == 3 && !ngraph::op::is_null(inputs.at(2)))
{
max = inputs.at(2);
}
#include "dequantize_linear.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp"
+#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/shape.hpp"
{
std::shared_ptr<ngraph::Node> get_zero_point(const NodeVector& inputs)
{
- if (inputs.size() == 3 && !inputs[2]->is_null())
+ if (inputs.size() == 3 && !ngraph::op::is_null(inputs[2]))
{
auto zero_point = inputs[2];
#include "default_opset.hpp"
#include "gru.hpp"
#include "ngraph/builder/split.hpp"
+#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/shape.hpp"
#include "utils/recurrent.hpp"
const auto& ng_inputs = node.get_ng_inputs();
const auto el_type = ng_inputs.at(0)->get_output_element_type(0);
- if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
+ if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
{
auto bias = ng_inputs.at(3);
// gates_count * 2 since B is: [Wb, Rb]
#include "core/graph.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
+#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/function.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "utils/reshape.hpp"
namespace ngraph
const std::shared_ptr<ngraph::Node>& body_cond)
{
bool loop_cond_value = false;
- if (loop_cond->is_constant() &&
+ if (ngraph::op::is_constant(loop_cond) &&
loop_cond->get_element_type() == element::boolean)
{
loop_cond_value = as_type_ptr<default_opset::Constant>(loop_cond)
}
// According to ONNX skipped cond input (is_null) means
// that is has true value
- bool is_loop_cond_true = loop_cond->is_null() || loop_cond_value == true;
+ bool is_loop_cond_true =
+ ngraph::op::is_null(loop_cond) || loop_cond_value == true;
if (!is_loop_cond_true)
{
{
const auto second_input =
body_cond->input_value(1).get_node_shared_ptr();
- if (second_input->is_constant() &&
+ if (ngraph::op::is_constant(second_input) &&
second_input->get_element_type() == element::boolean &&
as_type_ptr<default_opset::Constant>(second_input)
->cast_vector<bool>()
// At this moment nGraph TensorIterator doesn't have support for conditional
// termination of iterations.
CHECK_VALID_NODE(node,
- !trip_count->is_null(),
+ !ngraph::op::is_null(trip_count),
"Currently nGraph requires trip count input to be provided.");
const OutputVector loop_carried_dependencies{std::next(ng_inputs.begin(), 2),
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/enum_names.hpp"
+#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/frontend/onnx_import/op/lstm.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
// ------ Optional inputs ------
// The bias tensor for input gate. Shape [num_directions, 4*hidden_size]
- if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
+ if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
{
auto bias = ng_inputs.at(3);
auto split_bias = builder::opset1::split(bias, 2, 1);
0.f));
}
// The lengths of the sequences in a batch. Shape [batch_size]
- if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
+ if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
{
m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ng_inputs.at(4);
}
}
// The initial value of the hidden.
// Shape [num_directions, batch_size, hidden_size]
- if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
+ if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
{
m_map[LSTMInput::LSTM_INPUT_INIT_H] =
builder::opset1::reorder_axes(ng_inputs.at(5), {1, 0, 2});
}
// The initial value of the cell.
// Shape [num_directions, batch_size, hidden_size]
- if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null())
+ if (ng_inputs.size() > 6 && !ngraph::op::is_null(ng_inputs.at(6)))
{
m_map[LSTMInput::LSTM_INPUT_INIT_C] =
builder::opset1::reorder_axes(ng_inputs.at(6), {1, 0, 2});
std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
}
// The weight tensor for peepholes. Shape [num_directions, 3*hidde_size]
- if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null())
+ if (ng_inputs.size() > 7 && !ngraph::op::is_null(ng_inputs.at(7)))
{
m_map[LSTMInput::LSTM_INPUT_P] = ng_inputs.at(7);
}
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/pad.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
#include "pad.hpp"
#include "utils/convpool.hpp"
data->get_element_type(), ngraph::Shape{}, {0});
}
- if (pads->is_constant())
+ if (ngraph::op::is_constant(pads))
{
std::vector<std::int64_t> pads_vector =
ngraph::as_type_ptr<default_opset::Constant>(pads)
#include "resize.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
+#include "ngraph/op/util/op_types.hpp"
namespace ngraph
{
attrs.mode = mode;
attrs.align_corners = false;
- if (scales->is_constant() && data_shape.is_static())
+ if (ngraph::op::is_constant(scales) && data_shape.is_static())
{
const auto scales_const =
as_type_ptr<default_opset::Constant>(scales->shared_from_this());
#include "gather.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "utils/common.hpp"
namespace
if (inputs.size() >= 4) // axes input provided
{
axes = inputs.at(3);
- CHECK_VALID_NODE(node, axes->is_constant(), "Axes input must be constant");
+ CHECK_VALID_NODE(
+ node, ngraph::op::is_constant(axes), "Axes input must be constant");
}
else
{
#include "default_opset.hpp"
#include "exceptions.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "upsample.hpp"
namespace ngraph
attrs.axes.insert(ax);
}
- if (scales->is_constant() && data_shape.is_static())
+ if (ngraph::op::is_constant(scales) && data_shape.is_static())
{
const auto scales_const =
as_type_ptr<default_opset::Constant>(scales->shared_from_this());
#include "ngraph/builder/split.hpp"
#include "ngraph/check.hpp"
#include "ngraph/enum_names.hpp"
+#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "recurrent.hpp"
namespace ngraph
const std::size_t batch_size = m_map[OpInput::X]->get_shape().at(1);
const std::size_t num_directions = m_map[OpInput::W]->get_shape().front();
- if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
+ if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
{
auto bias = ng_inputs.at(3);
auto split_bias = builder::opset1::split(bias, 2, 1);
m_map[OpInput::B] = std::make_shared<default_opset::Constant>(
el_type, Shape{num_directions, gates_count * hidden_size}, 0.f);
}
- if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
+ if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
{
m_map[OpInput::SEQ_LENGTHS] = ng_inputs.at(4);
}
element::i32, Shape{batch_size}, m_map[OpInput::X]->get_shape().at(0));
}
// The initial value of the hidden.
- if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
+ if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
{
m_map[OpInput::INIT_H] = ng_inputs.at(5);
}
#include "default_opset.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/reshape.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
#include "reshape.hpp"
node_shape);
// If node is a Constant, recreate as Constant with Shape{}
- if (node->is_constant())
+ if (ngraph::op::is_constant(node))
{
const auto value =
ngraph::as_type_ptr<default_opset::Constant>(node)->get_data_ptr();
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/util.hpp"
using namespace std;
node->revalidate_and_infer_types();
// If we find a parameter make sure it is in the list of parameters of the function
- if (node->is_parameter())
+ if (op::is_parameter(node))
{
auto it = std::find(m_parameters.begin(), m_parameters.end(), node);
if (it == m_parameters.end())
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/provenance.hpp"
std::shared_ptr<Node> replacement,
const std::vector<int64_t>& output_order)
{
- if (target->is_output())
+ if (ngraph::op::is_output(target))
{
throw ngraph_error("Result nodes cannot be replaced.");
}
void ngraph::replace_node(const std::shared_ptr<Node>& target,
const OutputVector& replacement_values)
{
- if (target->is_output())
+ if (ngraph::op::is_output(target))
{
throw ngraph_error("Result nodes cannot be replaced.");
}
{
ngraph::Node* curr = stack.top();
visited.insert(curr);
- if (curr->is_output())
+ if (ngraph::op::is_output(curr))
{
return false;
}
// Make parameter node
shared_ptr<op::Parameter> par_node = make_shared<op::Parameter>(
src_node->get_output_element_type(0), src_node->get_output_shape(0));
- par_node->set_placement(dst_node->get_placement());
// Fix input / output among src, dst and par
std::vector<Input<Node>> dst_inputs = get_inputs_from(*src_node, *dst_node);
// Add res node
// Add [4], [5], [6], [7]
shared_ptr<op::Result> res_node = make_shared<op::Result>(src_node);
- res_node->set_placement(src_node->get_placement());
return make_pair(res_node, par_node);
}
ngraph::Node* n = stack.top();
if (instances_seen.count(n) == 0)
{
- if (n->is_output())
+ if (ngraph::op::is_output(n))
{
return true;
}
{
for (auto& input : output.get_target_inputs())
{
- if (input.get_node()->is_op())
+ if (op::is_op(input.get_node()))
{
auto op = static_cast<ngraph::op::Op*>(input.get_node());
if (auto op_annotations = op->get_op_annotations())
bool ngraph::compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2)
{
- if (!(n1->is_constant() && n2->is_constant()))
+ if (!(op::is_constant(n1) && op::is_constant(n2)))
{
return false;
}
#include "ngraph/check.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
-#include "ngraph/placement.hpp"
namespace ngraph
{
NGRAPH_API
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func);
- // Assert that nodes in the function is colocated and return that placement
- NGRAPH_API
- Placement get_colocated_function_placement(std::shared_ptr<Function> func);
-
NGRAPH_API
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>>
insert_result_parameter_split(const std::shared_ptr<Node>& src_node,
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/pattern/matcher.hpp"
-#include "ngraph/placement.hpp"
using namespace std;
using namespace ngraph;
return m_outputs;
}
-bool Node::is_output() const
-{
- return false;
-}
-
-bool Node::is_constant() const
-{
- return false;
-}
-
const std::string& Node::description() const
{
// Terrible transitional kludge to keep description working while we change
m_friendly_name = name;
}
-Placement Node::get_placement() const
-{
- return m_placement;
-}
-
-void Node::set_placement(Placement placement)
-{
- m_placement = placement;
-}
-
void Node::add_provenance_group_member(const shared_ptr<Node>& node)
{
m_provenance_group.insert(node);
return result;
}
-std::tuple<element::Type, PartialShape>
- Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob)
-{
- element::Type element_type = get_input_element_type(0);
- PartialShape pshape = get_input_partial_shape(0);
-
- if (get_input_size() > 1)
- {
- for (size_t i = 1; i < get_input_size(); ++i)
- {
- NODE_VALIDATION_CHECK(
- this,
- element::Type::merge(element_type, element_type, get_input_element_type(i)),
- "Argument element types are inconsistent.");
-
- if (autob.m_type == op::AutoBroadcastType::NONE)
- {
- NODE_VALIDATION_CHECK(this,
- PartialShape::merge_into(pshape, get_input_partial_shape(i)),
- "Argument shapes are inconsistent.");
- }
- else if (autob.m_type == op::AutoBroadcastType::NUMPY ||
- autob.m_type == op::AutoBroadcastType::PDPD)
- {
- NODE_VALIDATION_CHECK(
- this,
- PartialShape::broadcast_merge_into(pshape, get_input_partial_shape(i), autob),
- "Argument shapes are inconsistent.");
- }
- else
- {
- NODE_VALIDATION_CHECK(this, false, "Unsupported auto broadcast specification");
- }
- }
- }
-
- return std::make_tuple(element_type, pshape);
-}
-
-void Node::validate_and_infer_elementwise_arithmetic(const op::AutoBroadcastSpec& autob)
-{
- auto args_et_pshape = validate_and_infer_elementwise_args(autob);
- element::Type& args_et = std::get<0>(args_et_pshape);
- PartialShape& args_pshape = std::get<1>(args_et_pshape);
-
- NODE_VALIDATION_CHECK(this,
- args_et.is_dynamic() || args_et != element::boolean,
- "Arguments cannot have boolean element type (argument element type: ",
- args_et,
- ").");
-
- set_output_type(0, args_et, args_pshape);
-}
-
-void Node::validate_and_infer_elementwise_logical(const op::AutoBroadcastSpec& autob)
-{
- auto args_et_pshape = validate_and_infer_elementwise_args(autob);
- element::Type& args_et = std::get<0>(args_et_pshape);
- PartialShape& args_pshape = std::get<1>(args_et_pshape);
-
- NODE_VALIDATION_CHECK(
- this,
- args_et.is_dynamic() || args_et == element::boolean,
- "Operands for logical operators must have boolean element type but have element type ",
- args_et,
- ".");
-
- set_output_type(0, element::boolean, args_pshape);
-}
-
bool Node::match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/output_vector.hpp"
-#include "ngraph/placement.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type.hpp"
using type_info_t = DiscreteTypeInfo;
protected:
- std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
- const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
- void validate_and_infer_elementwise_arithmetic(
- const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
- void validate_and_infer_elementwise_logical(
- const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
-
/// \brief Construct an unitialized Node
Node() {}
/// \brief Construct an unitialized Node
void safe_delete(NodeVector& nodes, bool recurse);
public:
+ virtual bool is_parameter() const { return false; }
+ virtual bool is_output() const { return false; }
+ virtual bool is_constant() const { return false; }
virtual ~Node();
virtual bool visit_attributes(AttributeVisitor& visitor) { return false; }
- virtual bool is_unary_elementwise_arithmetic() const { return false; }
- virtual bool is_binary_elementwise_arithmetic() const { return false; }
- virtual bool is_binary_elementwise_comparison() const { return false; }
- virtual bool is_binary_elementwise_logical() const { return false; }
- /// \returns true if node supports autobroadcast operations
- virtual bool supports_auto_broadcast() const { return false; }
/// \returns the autobroadcasr spec
virtual const op::AutoBroadcastSpec& get_autob() const;
- /// \returns true if the node can decompose
- virtual bool supports_decompose() const { return false; }
/// \brief Evaluates the op on input_values putting results in output_values
/// \returns true if successful
virtual bool evaluate(const HostTensorVector& output_values,
const element::Type& element_type,
const PartialShape& pshape);
- virtual bool is_parameter() const { return false; }
- virtual bool is_output() const;
- virtual bool is_constant() const;
- virtual bool is_null() const { return false; }
- virtual bool is_op() const { return false; }
- virtual bool is_pattern() const { return false; }
- virtual bool is_commutative() const { return false; }
virtual bool is_dynamic() const;
- virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; }
/// \brief Writes a description of a node to a stream
/// \param os The stream; should be returned
/// True if this and node have one output with same element type and shape
bool has_same_type(std::shared_ptr<const Node> node) const;
- /// Get device placement
- Placement get_placement() const;
-
- /// Set device placement
- void set_placement(Placement placement);
-
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
RTMap& get_rt_info() { return m_rt_info; }
std::set<std::shared_ptr<Node>> m_provenance_group;
std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs;
- Placement m_placement = Placement::DEFAULT;
std::shared_ptr<ngraph::op::util::OpAnnotations> m_op_annotations;
std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
};
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
- virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
public:
static constexpr NodeTypeInfo type_info{"Constant", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
+ bool is_constant() const override { return true; }
Constant() = default;
/// \brief Initialize a constant from tensor
get_data_ptr());
}
- bool is_constant() const override { return true; }
bool get_all_data_elements_bitwise_identical() const
{
return m_all_elements_bitwise_identical;
#include "ngraph/op/constant.hpp"
#include "ngraph/op/crop_and_resize.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
auto& crop_size_et = crop_size.get_element_type();
NODE_VALIDATION_CHECK(this, crop_size_et.is_integral(), "crops_size must be integral");
auto crop_size_node = crop_size.get_node_shared_ptr();
- NODE_VALIDATION_CHECK(this, crop_size_node->is_constant(), "crop_size must be a constant");
+ NODE_VALIDATION_CHECK(
+ this, ngraph::op::is_constant(crop_size_node), "crop_size must be a constant");
auto crop_size_const = static_pointer_cast<op::Constant>(crop_size_node);
if (crop_size_et == element::i8)
{
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
#include "ngraph/node.hpp"
#include "ngraph/op/fused/batch_to_space.hpp"
#include "ngraph/op/reshape.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
using namespace std;
auto block = input_value(1);
auto crops_begin = input_value(2);
auto crops_end = input_value(3);
- NGRAPH_CHECK(block.get_node_shared_ptr()->is_constant(),
+ NGRAPH_CHECK(ngraph::op::is_constant(block.get_node()),
"block_shape input node is expected to be a static constant");
- NGRAPH_CHECK(crops_begin.get_node_shared_ptr()->is_constant(),
+ NGRAPH_CHECK(ngraph::op::is_constant(crops_begin.get_node()),
"crops_begin input node is expected to be a static constant");
- NGRAPH_CHECK(crops_end.get_node_shared_ptr()->is_constant(),
+ NGRAPH_CHECK(ngraph::op::is_constant(crops_end.get_node()),
"crops_end input node is expected to be a static constant");
const auto& data_type = get_input_element_type(0);
#include "ngraph/op/divide.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/multiply.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
const auto& input_rank = input_pshape.rank();
const auto& axes_rank = axes_pshape.rank();
- NODE_VALIDATION_CHECK(this, axes_node->is_constant(), "Input axes must be Constant type");
+ NODE_VALIDATION_CHECK(this, op::is_constant(axes_node), "Input axes must be Constant type");
if (axes_rank.is_static())
{
#include "ngraph/node.hpp"
#include "ngraph/op/fused/space_to_batch.hpp"
#include "ngraph/op/pad.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
using namespace std;
auto block = input_value(1);
auto crops_begin = input_value(2);
auto crops_end = input_value(3);
- NGRAPH_CHECK(block.get_node_shared_ptr()->is_constant(),
+ NGRAPH_CHECK(ngraph::op::is_constant(block.get_node()),
"block_shape input node is expected to be a static constant");
- NGRAPH_CHECK(crops_begin.get_node_shared_ptr()->is_constant(),
+ NGRAPH_CHECK(ngraph::op::is_constant(crops_begin.get_node()),
"crops_begin input node is expected to be a static constant");
- NGRAPH_CHECK(crops_end.get_node_shared_ptr()->is_constant(),
+ NGRAPH_CHECK(ngraph::op::is_constant(crops_end.get_node()),
"crops_end input node is expected to be a static constant");
const auto& data_type = get_input_element_type(0);
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/reshape.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/runtime/reference/copy.hpp"
#include "ngraph/validation_util.hpp"
const auto axes_node = input_value(1).get_node_shared_ptr();
- if (data_rank.is_dynamic() || !axes_node->is_constant())
+ if (data_rank.is_dynamic() || !op::is_constant(axes_node))
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
#include "ngraph/op/non_max_suppression.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
const auto max_output_boxes_per_class = input_value(2).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[1].is_static() &&
- max_output_boxes_per_class->is_constant())
+ op::is_constant(max_output_boxes_per_class))
{
const auto num_boxes = num_boxes_boxes.get_length();
const auto max_output_boxes_per_class = max_boxes_output_from_input();
const auto num_boxes_boxes = boxes_ps[1];
const auto max_output_boxes_per_class_node = input_value(2).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[1].is_static() &&
- max_output_boxes_per_class_node->is_constant())
+ op::is_constant(max_output_boxes_per_class_node))
{
const auto num_boxes = num_boxes_boxes.get_length();
const auto num_classes = scores_ps[1].get_length();
const auto num_boxes_boxes = boxes_ps[1];
const auto max_output_boxes_per_class_node = input_value(2).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[0].is_static() && scores_ps[1].is_static() &&
- max_output_boxes_per_class_node->is_constant())
+ op::is_constant(max_output_boxes_per_class_node))
{
const auto num_boxes = num_boxes_boxes.get_length();
const auto num_classes = scores_ps[1].get_length();
#include "ngraph/op/not.hpp"
#include "ngraph/op/op.hpp"
+#include "ngraph/op/util/elementwise_args.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/not.hpp"
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::v1::LogicalNot::validate_and_infer_types()
{
- auto args_et_pshape = validate_and_infer_elementwise_args();
+ auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::v0::Not::validate_and_infer_types()
{
- auto args_et_pshape = validate_and_infer_elementwise_args();
+ auto args_et_pshape = ngraph::op::util::validate_and_infer_elementwise_args(this);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
#include "ngraph/op/one_hot.hpp"
#include "ngraph/attribute_visitor.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
const auto& depth = input_value(1).get_node_shared_ptr();
PartialShape result_shape{PartialShape::dynamic()};
- if (indices_shape.is_static() && indices_shape.rank().is_static() && depth->is_constant())
+ if (indices_shape.is_static() && indices_shape.rank().is_static() && op::is_constant(depth))
{
const auto indices_rank = indices_shape.rank().get_length();
/// Root of all actual ops
class NGRAPH_API Op : public Node
{
- public:
- virtual bool is_op() const override { return true; }
protected:
Op()
: Node()
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
#include "ngraph/except.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
auto pads_begin_node = input_value(1).get_node_shared_ptr();
auto pads_end_node = input_value(2).get_node_shared_ptr();
- if (arg_shape_rank.is_static() && pads_begin_node->is_constant() &&
- pads_end_node->is_constant())
+ if (arg_shape_rank.is_static() && op::is_constant(pads_begin_node) &&
+ op::is_constant(pads_end_node))
{
const auto implied_rank = pads_begin_coord.size();
std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic());
bool visit_attributes(AttributeVisitor& visitor) override;
- bool is_parameter() const override { return true; }
void validate_and_infer_types() override;
bool get_cacheable() const { return m_cacheable; }
{
m_element_type = element_type;
}
-
+ bool is_parameter() const override { return true; }
protected:
bool m_cacheable;
PartialShape m_partial_shape;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_output() const override { return true; }
void set_needs_default_layout(bool val) { m_needs_default_layout = val; }
bool needs_default_layout() const { return m_needs_default_layout; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
bool constant_fold(OutputVector& output_values,
const OutputVector& inputs_values) override;
-
+ bool is_output() const override { return true; }
private:
bool m_needs_default_layout{false};
};
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reverse.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
const auto rank = input_rank.get_length();
const auto rev_axes_node = input_value(1).get_node_shared_ptr();
- if (rev_axes_node->is_constant())
+ if (op::is_constant(rev_axes_node))
{
const auto rev_axes_constant = as_type_ptr<op::Constant>(rev_axes_node);
#include "ngraph/op/scatter_elements_update.hpp"
#include "ngraph/op/constant.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/runtime/reference/scatter_elements_update.hpp"
#include "ngraph/validation_util.hpp"
" and: ",
updates_shape);
- if (input_value(3).get_node_shared_ptr()->is_constant() && data_shape.rank().is_static())
+ if (ngraph::op::is_constant(input_value(3).get_node()) && data_shape.rank().is_static())
{
const auto axis_input = as_type_ptr<op::v0::Constant>(input_value(3).get_node_shared_ptr());
auto axis = axis_input->cast_vector<int64_t>().at(0);
{
m_auto_broadcast = auto_broadcast;
}
- bool supports_auto_broadcast() const override { return true; }
// TODO: Move all uses of get_autob to get_auto_broadcast() and remove this.
const AutoBroadcastSpec& get_autob() const override { return m_auto_broadcast; }
private:
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/util.hpp"
bool op::v0::Softmax::are_axes_constant() const
{
- return input_value(1).get_node_shared_ptr()->is_constant();
+ return op::is_constant(input_value(1).get_node());
}
const AxisSet op::v0::Softmax::get_axes() const
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/split.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
NODE_VALIDATION_CHECK(this, is_scalar(axis_shape), "The 'axis' input node must be scalar");
const auto axis_node = input_value(1).get_node_shared_ptr();
- NODE_VALIDATION_CHECK(this, axis_node->is_constant(), "The 'axis' input node must be constant");
+ NODE_VALIDATION_CHECK(
+ this, op::is_constant(axis_node), "The 'axis' input node must be constant");
const auto axis_node_const = as_type_ptr<op::Constant>(axis_node);
m_axis = axis_node_const->get_data_ptr<int64_t>()[0];
NODE_VALIDATION_CHECK(
this, axis_et.is_integral(), "The 'axis' input only accepts integral types");
- if (input_value(1).get_node_shared_ptr()->is_constant() && data_ps.is_static())
+ if (op::is_constant(input_value(1).get_node()) && data_ps.is_static())
{
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
auto axis = axis_input->cast_vector<int64_t>()[0];
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/topk.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
this, k_partial_shape.rank().compatible(0), "The 'K' input must be a scalar.");
size_t k = 0;
- if (input_value(1).get_node_shared_ptr()->is_constant())
+ if (op::is_constant(input_value(1).get_node()))
{
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
get_input_element_type(1));
size_t op::v1::TopK::get_k() const
{
size_t k = 0;
- if (input_value(1).get_node_shared_ptr()->is_constant())
+ if (op::is_constant(input_value(1).get_node()))
{
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
get_input_element_type(1));
// 2. get value of k - from constant node or from HT
size_t k = 0;
- if (input_value(1).get_node_shared_ptr()->is_constant())
+ if (op::is_constant(input_value(1).get_node()))
{
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
get_input_element_type(1));
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/attribute_visitor.hpp"
+#include "ngraph/op/util/elementwise_args.hpp"
using namespace std;
using namespace ngraph;
{
}
+void op::util::BinaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic(
+ const op::AutoBroadcastSpec& autob)
+{
+ auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, autob);
+ element::Type& args_et = std::get<0>(args_et_pshape);
+ PartialShape& args_pshape = std::get<1>(args_et_pshape);
+
+ NODE_VALIDATION_CHECK(this,
+ args_et.is_dynamic() || args_et != element::boolean,
+ "Arguments cannot have boolean element type (argument element type: ",
+ args_et,
+ ").");
+
+ set_output_type(0, args_et, args_pshape);
+}
+
void op::util::BinaryElementwiseArithmetic::validate_and_infer_types()
{
validate_and_infer_elementwise_arithmetic(m_autob);
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
- bool is_binary_elementwise_arithmetic() const override { return true; }
- bool supports_auto_broadcast() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private:
AutoBroadcastSpec m_autob;
+ void validate_and_infer_elementwise_arithmetic(const op::AutoBroadcastSpec& autob);
};
}
}
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/attribute_visitor.hpp"
+#include "ngraph/op/util/elementwise_args.hpp"
using namespace std;
using namespace ngraph;
void op::util::BinaryElementwiseComparison::validate_and_infer_types()
{
- auto args_et_pshape = validate_and_infer_elementwise_args(m_autob);
+ auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, m_autob);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, element::boolean, args_pshape);
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
- bool supports_auto_broadcast() const override { return true; }
- bool is_binary_elementwise_comparison() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private:
#include "ngraph/op/util/binary_elementwise_logical.hpp"
#include "ngraph/attribute_visitor.hpp"
+#include "ngraph/op/util/elementwise_args.hpp"
using namespace std;
using namespace ngraph;
{
}
+void op::util::BinaryElementwiseLogical::validate_and_infer_elementwise_logical(
+ const op::AutoBroadcastSpec& autob)
+{
+ auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, autob);
+ element::Type& args_et = std::get<0>(args_et_pshape);
+ PartialShape& args_pshape = std::get<1>(args_et_pshape);
+
+ NODE_VALIDATION_CHECK(
+ this,
+ args_et.is_dynamic() || args_et == element::boolean,
+ "Operands for logical operators must have boolean element type but have element type ",
+ args_et,
+ ".");
+
+ set_output_type(0, element::boolean, args_pshape);
+}
+
void op::util::BinaryElementwiseLogical::validate_and_infer_types()
{
validate_and_infer_elementwise_logical(m_autob);
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
- bool supports_auto_broadcast() const override { return true; }
- bool is_binary_elementwise_logical() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private:
+ void validate_and_infer_elementwise_logical(const op::AutoBroadcastSpec& autob);
AutoBroadcastSpec m_autob;
};
}
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
" doesn't match rank of input tensor ",
arg_shape.size());
- if (shape_constant && input_value(2).get_node_shared_ptr()->is_constant())
+ if (shape_constant && op::is_constant(input_value(2).get_node()))
{
auto target_shape = shape_constant->get_shape_val();
auto axes_mapping_val =
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "elementwise_args.hpp"
+
+using namespace ngraph;
+
+std::tuple<element::Type, PartialShape>
+ ngraph::op::util::validate_and_infer_elementwise_args(Node* node,
+ const op::AutoBroadcastSpec& autob)
+{
+ NGRAPH_CHECK(node != nullptr, "nGraph node is empty! Cannot validate eltwise arguments.");
+ element::Type element_type = node->get_input_element_type(0);
+ PartialShape pshape = node->get_input_partial_shape(0);
+
+ if (node->get_input_size() > 1)
+ {
+ for (size_t i = 1; i < node->get_input_size(); ++i)
+ {
+ NODE_VALIDATION_CHECK(
+ node,
+ element::Type::merge(element_type, element_type, node->get_input_element_type(i)),
+ "Argument element types are inconsistent.");
+
+ if (autob.m_type == op::AutoBroadcastType::NONE)
+ {
+ NODE_VALIDATION_CHECK(
+ node,
+ PartialShape::merge_into(pshape, node->get_input_partial_shape(i)),
+ "Argument shapes are inconsistent.");
+ }
+ else if (autob.m_type == op::AutoBroadcastType::NUMPY ||
+ autob.m_type == op::AutoBroadcastType::PDPD)
+ {
+ NODE_VALIDATION_CHECK(node,
+ PartialShape::broadcast_merge_into(
+ pshape, node->get_input_partial_shape(i), autob),
+ "Argument shapes are inconsistent.");
+ }
+ else
+ {
+ NODE_VALIDATION_CHECK(node, false, "Unsupported auto broadcast specification");
+ }
+ }
+ }
+
+ return std::make_tuple(element_type, pshape);
+}
#pragma once
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
+#include "ngraph/node.hpp"
namespace ngraph
{
- enum class Placement
+ namespace op
{
- DEFAULT,
- INTERPRETER,
- CPU,
- GPU,
- NNP,
- };
-
- std::string placement_to_string(Placement placement);
+ namespace util
+ {
+ std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
+ Node* node, const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
+ }
+ }
}
class NGRAPH_API FusedOp : public Op
{
public:
- bool supports_decompose() const final { return true; }
// Fused op decomposition can be performed in the presence of
// partial shapes
virtual bool can_decompose_with_partial_shapes() { return false; }
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#include "ngraph/op/util/op_types.hpp"
+#include "ngraph/op/add.hpp"
+#include "ngraph/op/and.hpp"
+#include "ngraph/op/constant.hpp"
+#include "ngraph/op/equal.hpp"
+#include "ngraph/op/maximum.hpp"
+#include "ngraph/op/minimum.hpp"
+#include "ngraph/op/multiply.hpp"
+#include "ngraph/op/not_equal.hpp"
+#include "ngraph/op/op.hpp"
+#include "ngraph/op/or.hpp"
+#include "ngraph/op/parameter.hpp"
+#include "ngraph/op/result.hpp"
+#include "ngraph/op/select.hpp"
+#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
+#include "ngraph/op/util/binary_elementwise_comparison.hpp"
+#include "ngraph/op/util/binary_elementwise_logical.hpp"
+#include "ngraph/op/util/fused_op.hpp"
+#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
+#include "ngraph/op/xor.hpp"
+#include "ngraph/type.hpp"
+
+bool ngraph::op::is_unary_elementwise_arithmetic(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::util::UnaryElementwiseArithmetic*>(node) != nullptr;
+}
+
+bool ngraph::op::is_binary_elementwise_arithmetic(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::util::BinaryElementwiseArithmetic*>(node) != nullptr;
+}
+
+bool ngraph::op::is_binary_elementwise_comparison(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::util::BinaryElementwiseComparison*>(node) != nullptr;
+}
+
+bool ngraph::op::is_binary_elementwise_logical(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::util::BinaryElementwiseLogical*>(node) != nullptr;
+}
+
+bool ngraph::op::supports_auto_broadcast(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::v1::Select*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::util::BinaryElementwiseComparison*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::util::BinaryElementwiseLogical*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::util::BinaryElementwiseArithmetic*>(node) != nullptr;
+}
+
+bool ngraph::op::supports_decompose(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::util::FusedOp*>(node) != nullptr;
+}
+
+bool ngraph::op::is_op(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::Op*>(node) != nullptr;
+}
+
+bool ngraph::op::is_parameter(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::Parameter*>(node) != nullptr;
+}
+
+bool ngraph::op::is_output(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::Result*>(node) != nullptr;
+}
+
+bool ngraph::op::is_constant(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::Constant*>(node) != nullptr;
+}
+
+bool ngraph::op::is_commutative(const ngraph::Node* node)
+{
+ return dynamic_cast<const ngraph::op::v0::Add*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v1::Add*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v0::Maximum*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v1::Maximum*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v0::Equal*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v1::Equal*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v0::NotEqual*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v1::NotEqual*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v1::LogicalAnd*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v0::Xor*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v1::LogicalXor*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v0::Minimum*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v1::Minimum*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v0::Multiply*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v1::Multiply*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v0::Or*>(node) != nullptr ||
+ dynamic_cast<const ngraph::op::v1::LogicalOr*>(node) != nullptr;
+}
+
+bool ngraph::op::is_unary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_unary_elementwise_arithmetic(node.get());
+}
+bool ngraph::op::is_binary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_binary_elementwise_arithmetic(node.get());
+}
+bool ngraph::op::is_binary_elementwise_comparison(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_binary_elementwise_comparison(node.get());
+}
+bool ngraph::op::is_binary_elementwise_logical(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_binary_elementwise_logical(node.get());
+}
+
+bool ngraph::op::supports_auto_broadcast(const std::shared_ptr<ngraph::Node>& node)
+{
+ return supports_auto_broadcast(node.get());
+}
+
+bool ngraph::op::supports_decompose(const std::shared_ptr<ngraph::Node>& node)
+{
+ return supports_decompose(node.get());
+}
+
+bool ngraph::op::is_op(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_op(node.get());
+}
+bool ngraph::op::is_parameter(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_parameter(node.get());
+}
+bool ngraph::op::is_output(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_output(node.get());
+}
+bool ngraph::op::is_constant(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_constant(node.get());
+}
+bool ngraph::op::is_commutative(const std::shared_ptr<ngraph::Node>& node)
+{
+ return is_commutative(node.get());
+}
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <memory>
+#include "ngraph/ngraph_visibility.hpp"
+#include "ngraph/node.hpp"
+
+namespace ngraph
+{
+ namespace op
+ {
+ NGRAPH_API
+ bool is_unary_elementwise_arithmetic(const ngraph::Node* node);
+ NGRAPH_API
+ bool is_binary_elementwise_arithmetic(const ngraph::Node* node);
+ NGRAPH_API
+ bool is_binary_elementwise_comparison(const ngraph::Node* node);
+ NGRAPH_API
+ bool is_binary_elementwise_logical(const ngraph::Node* node);
+
+ NGRAPH_API
+ bool supports_auto_broadcast(const ngraph::Node* node);
+
+ NGRAPH_API
+ bool supports_decompose(const ngraph::Node* node);
+
+ NGRAPH_API
+ bool is_op(const ngraph::Node* node);
+ NGRAPH_API
+ bool is_parameter(const ngraph::Node* node);
+ NGRAPH_API
+ bool is_output(const ngraph::Node* node);
+ NGRAPH_API
+ bool is_constant(const ngraph::Node* node);
+ NGRAPH_API
+ bool is_commutative(const ngraph::Node* node);
+
+ NGRAPH_API
+ bool is_unary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node);
+ NGRAPH_API
+ bool is_binary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node);
+ NGRAPH_API
+ bool is_binary_elementwise_comparison(const std::shared_ptr<ngraph::Node>& node);
+ NGRAPH_API
+ bool is_binary_elementwise_logical(const std::shared_ptr<ngraph::Node>& node);
+
+ NGRAPH_API
+ bool supports_auto_broadcast(const std::shared_ptr<ngraph::Node>& node);
+
+ NGRAPH_API
+ bool supports_decompose(const std::shared_ptr<ngraph::Node>& node);
+
+ NGRAPH_API
+ bool is_op(const std::shared_ptr<ngraph::Node>& node);
+ NGRAPH_API
+ bool is_parameter(const std::shared_ptr<ngraph::Node>& node);
+ NGRAPH_API
+ bool is_output(const std::shared_ptr<ngraph::Node>& node);
+ NGRAPH_API
+ bool is_constant(const std::shared_ptr<ngraph::Node>& node);
+ NGRAPH_API
+ bool is_commutative(const std::shared_ptr<ngraph::Node>& node);
+ }
+}
//*****************************************************************************
#include "ngraph/op/util/scatter_base.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
bool compatible = true;
int64_t axis;
- bool is_axis_constant = input_value(AXIS).get_node_shared_ptr()->is_constant();
+ bool is_axis_constant = op::is_constant(input_value(AXIS).get_node());
// Get axis value if possible.
if (is_axis_constant && data_shape.rank().is_static())
//*****************************************************************************
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
+#include "ngraph/op/util/elementwise_args.hpp"
using namespace ngraph;
{
}
+void op::util::UnaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic()
+{
+ auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
+ element::Type& args_et = std::get<0>(args_et_pshape);
+ PartialShape& args_pshape = std::get<1>(args_et_pshape);
+
+ NODE_VALIDATION_CHECK(this,
+ args_et.is_dynamic() || args_et != element::boolean,
+ "Arguments cannot have boolean element type (argument element type: ",
+ args_et,
+ ").");
+
+ set_output_type(0, args_et, args_pshape);
+}
+
void op::util::UnaryElementwiseArithmetic::validate_and_infer_types()
{
validate_and_infer_elementwise_arithmetic();
public:
void validate_and_infer_types() override;
- bool is_unary_elementwise_arithmetic() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
+
+ private:
+ void validate_and_infer_elementwise_arithmetic();
};
}
}
#include <numeric>
#include "ngraph/op/constant.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/validation_util.hpp"
const auto& data_type = data.get_element_type();
set_output_size(num_outputs);
- if (data_shape.rank().is_static() && axis_input->is_constant() &&
- split_lengths_input->is_constant())
+ if (data_shape.rank().is_static() && op::is_constant(axis_input) &&
+ op::is_constant(split_lengths_input))
{
const auto axis_input_constant = as_type_ptr<op::Constant>(axis_input);
auto axis_val = axis_input_constant->cast_vector<int64_t>()[0];
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
#include "ngraph/op/topk.hpp"
#include "ngraph/op/transpose.hpp"
#include "ngraph/op/util/attr_types.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/transpose.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset2.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/pattern/matcher.hpp"
bool replaced = false;
for (auto n : f->get_ordered_ops())
{
- if (n->is_output() || n->is_parameter())
+ if (op::is_output(n) || op::is_parameter(n))
{
continue;
}
#include <sstream>
#include "common_function_collection.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
{
for (const shared_ptr<Node>& n : current_function->get_ordered_ops())
{
- if (n->is_constant() || n->is_parameter())
+ if (op::is_constant(n) || op::is_parameter(n))
{
continue;
}
- if (n->is_op())
+ if (op::is_op(n))
{
auto op = std::static_pointer_cast<op::Op>(n);
auto annotations = op->get_op_annotations();
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
// TODO: Do we need another map, so we could
// specify how to compute hash for each op?
- if (p_this.is_commutative())
+ if (ngraph::op::is_commutative(&p_this))
{
sort(begin(cargs), end(cargs));
}
for (auto n : f->get_ordered_ops())
{
- if (n->is_output() || n->is_parameter())
+ if (op::is_output(n) || op::is_parameter(n))
{
continue;
}
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/provenance.hpp"
using namespace std;
{
bool modified = false;
- if (node->supports_decompose())
+ if (op::supports_decompose(node))
{
if (m_has_direct_support && m_has_direct_support(*node))
{
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/op/util/binary_elementwise_logical.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<Node> node)
{
- if (node->supports_auto_broadcast())
+ if (ngraph::op::supports_auto_broadcast(node))
{
if (node->get_autob().m_type != op::AutoBroadcastType::NONE)
{
NodeVector ngraph::pass::explicit_broadcast(std::shared_ptr<Node>& node)
{
NodeVector rc;
- if (node->supports_auto_broadcast())
+ if (ngraph::op::supports_auto_broadcast(node))
{
auto autob = node->get_autob();
if (autob.m_type == op::AutoBroadcastType::NONE)
#include "ngraph/op/concat.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/slice.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
std::map<descriptor::Tensor*, descriptor::Tensor*> in_place_outputs;
std::set<const descriptor::Tensor*> reused_inputs;
- if (node->is_op())
+ if (op::is_op(node))
{
auto op = std::static_pointer_cast<op::Op>(node);
// concat and slice in_place_oi should be treated differently
if ((node->liveness_free_list.count(input) != 0 ||
is_type<op::GetOutputElement>(node) ||
(m_disable_memory_sharing && !oi_pair.destructive &&
- !input_node->is_parameter() && !input_node->is_constant())) &&
+ !op::is_parameter(input_node) && !op::is_constant(input_node))) &&
node->liveness_new_list.count(output) != 0)
{
#include "ngraph/op/slice.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/sum.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/util.hpp"
#include "nop_elimination.hpp"
if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input))
{
PartialShape data_shape;
- if (input->is_parameter())
+ if (op::is_parameter(input))
{
data_shape = unsqueeze->input(0).get_partial_shape();
}
if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input))
{
PartialShape data_shape;
- if (input->is_parameter())
+ if (op::is_parameter(input))
{
data_shape = squeeze_i->input(0).get_partial_shape();
}
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/op_annotations.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
{
for (auto& node : function->get_ordered_ops())
{
- if (node->is_op())
+ if (op::is_op(node))
{
auto op = static_pointer_cast<op::Op>(node);
NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name();
op_annotations = op_annotations_factory();
op->set_op_annotations(op_annotations);
}
- if (node->is_parameter())
+ if (op::is_parameter(node))
{
auto parameter = static_pointer_cast<op::Parameter>(node);
op_annotations->set_cacheable(parameter->get_cacheable());
{
auto input_value_node = input.get_source_output().get_node_shared_ptr();
NGRAPH_DEBUG << "propagate cacheability: arg is " << *input_value_node;
- if (input_value_node->is_op())
+ if (op::is_op(input_value_node))
{
auto arg_op = static_pointer_cast<op::Op>(input_value_node);
auto arg_op_annotations = arg_op->get_op_annotations();
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"
continue;
}
NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
- if (n->is_unary_elementwise_arithmetic())
+ if (op::is_unary_elementwise_arithmetic(n))
{
Swimmer nsw{n->input(0), csw.reshape};
work_queue.push_back(nsw);
{
NGRAPH_DEBUG << "Start: Processing node " << n->get_name();
// collect all Result nodes for a sanity check
- if (n->is_output())
+ if (ngraph::op::is_output(n))
{
results.push_back(n);
}
{
sink_reshape(reshape, reorders, reshapes_to_delete);
}
- else if (n->is_unary_elementwise_arithmetic())
+ else if (op::is_unary_elementwise_arithmetic(n))
{
sink_unary(n, reorders, reshapes_to_delete);
}
- else if (n->is_binary_elementwise_arithmetic())
+ else if (op::is_binary_elementwise_arithmetic(n))
{
sink_binary(n, reorders, reshapes_to_delete);
}
#include "ngraph/pass/shape_relevance.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace ngraph;
shape_determinants.insert(node);
already_visited.insert(node);
- if (node->is_parameter())
+ if (op::is_parameter(node))
{
auto node_as_param = static_cast<op::Parameter*>(node);
if (!node_as_param->is_relevant_to_shapes())
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
std::string pass::VisualizeTree::get_constant_value(std::shared_ptr<Node> node, size_t max_elements)
{
- if (!node->is_constant())
+ if (!op::is_constant(node))
return {};
std::stringstream ss;
ss << "{" << node->get_element_type().get_type_name() << "}";
vector<string> attributes;
attributes.push_back("shape=box");
- if (node->is_output())
+ if (ngraph::op::is_output(node))
{
attributes.push_back("color=crimson");
attributes.push_back("penwidth=1.5");
#include "ngraph/op/product.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/sum.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/type.hpp"
#include "zero_dim_tensor_elimination.hpp"
set<Output<Node>> zero_length_source_outputs;
for (auto n : f->get_ordered_ops())
{
- if (n->is_output() || n->is_parameter() || n->is_constant() || n->get_output_size() > 1)
+ if (op::is_output(n) || op::is_parameter(n) || op::is_constant(n) ||
+ n->get_output_size() > 1)
{
continue;
}
// if any `GetOutputElement` is zero-length
// we replace it w/ a signalling constant
// so we don't have to deal w/ multi-output nodes directly
- if (n->is_output() || n->is_parameter() || n->get_output_size() > 1)
+ if (op::is_output(n) || op::is_parameter(n) || n->get_output_size() > 1)
{
continue;
}
#include "ngraph/log.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
+#include "ngraph/op/util/op_types.hpp"
namespace ngraph
{
return false;
}
- if (graph_node->is_commutative())
+ if (ngraph::op::is_commutative(graph_node))
{
// TODO: [nikolayk] we don't really have to use lexicographically-based perms,
// heap's algo should be faster
ValuePredicate get_predicate() const;
- bool is_pattern() const override { return true; }
protected:
ValuePredicate m_predicate;
};
+++ /dev/null
-//*****************************************************************************
-// Copyright 2017-2020 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//*****************************************************************************
-
-#include <deque>
-#include <sstream>
-
-#include "ngraph/function.hpp"
-#include "ngraph/graph_util.hpp"
-#include "ngraph/node.hpp"
-#include "ngraph/placement.hpp"
-#include "ngraph/util.hpp"
-
-using namespace std;
-using namespace ngraph;
-
-std::string ngraph::placement_to_string(Placement placement)
-{
- switch (placement)
- {
- case Placement::DEFAULT: return "DEFAULT";
- case Placement::INTERPRETER: return "INTERPRETER";
- case Placement::CPU: return "CPU";
- case Placement::GPU: return "GPU";
- case Placement::NNP: return "NNP";
- }
- throw runtime_error("unhandled placement type");
-}
#include "ngraph/op/assign.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/tensor_iterator.hpp"
+#include "ngraph/op/util/op_types.hpp"
using namespace ngraph;
for (auto old_node : f->get_ordered_ops())
{
- if (old_node->is_parameter())
+ if (op::is_parameter(old_node))
{
continue;
}
const AxisSet broadcast_axes{0, 2};
auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
- EXPECT_TRUE(axes_mapping.get_node()->is_constant());
+ EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), 2);
EXPECT_EQ(axes_mapping_shape, (Shape{1, 3}));
const AxisSet broadcast_axes{0, 1, 2, 3};
auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
- EXPECT_TRUE(axes_mapping.get_node()->is_constant());
+ EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), 0);
EXPECT_EQ(axes_mapping_shape, (Shape{}));
const AxisSet broadcast_axes{};
auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
- EXPECT_TRUE(axes_mapping.get_node()->is_constant());
+ EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), output_shape.size());
EXPECT_EQ(axes_mapping_shape, (Shape{0, 1, 2, 3}));
auto axes_mapping =
builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
- EXPECT_TRUE(axes_mapping.get_node()->is_constant());
+ EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), 2);
EXPECT_EQ(axes_mapping_shape, (Shape{1, 2}));
auto axes_mapping =
builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
- EXPECT_TRUE(axes_mapping.get_node()->is_constant());
+ EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), 0);
EXPECT_EQ(axes_mapping_shape, (Shape{}));
auto axes_mapping =
builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
- EXPECT_TRUE(axes_mapping.get_node()->is_constant());
+ EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), output_shape.size());
EXPECT_EQ(axes_mapping_shape, (Shape{0, 1, 2, 3}));
// clang-format on
#include "gtest/gtest.h"
+#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/frontend/onnx_import/onnx_utils.hpp"
#include "ngraph/frontend/onnx_import/default_opset.hpp"
std::shared_ptr<ngraph::Node> C = ng_inputs.at(2);
A = A * C;
- if (!B->is_null())
+ if (!ngraph::op::is_null(B))
{
B = B / C;
}
for (const auto& ng_input : ng_inputs)
{
- if (!ng_input->is_null())
+ if (!ngraph::op::is_null(ng_input))
{
result = ng_input * result;
}
#include "ngraph/file_util.hpp"
#include "ngraph/frontend/onnx_import/default_opset.hpp"
#include "ngraph/frontend/onnx_import/onnx.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/all_close.hpp"
for (auto ng_node : ng_function->get_ordered_ops())
{
- if (ng_node->is_constant())
+ if (op::is_constant(ng_node))
{
const auto folded_node = as_type_ptr<default_opset::Constant>(ng_node);
const auto output_values = folded_node->cast_vector<T>();
{
auto arg0 = make_shared<op::Parameter>(element::f32, Shape{1});
ASSERT_NE(nullptr, arg0);
- EXPECT_TRUE(arg0->is_parameter());
+ EXPECT_TRUE(op::is_parameter(arg0));
}
TEST(op, is_parameter)
ASSERT_NE(nullptr, arg0);
auto t0 = make_shared<op::Add>(arg0, arg0);
ASSERT_NE(nullptr, t0);
- EXPECT_FALSE(t0->is_parameter());
+ EXPECT_FALSE(op::is_parameter(t0));
}
TEST(op, provenance_tag)
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/validation_util.hpp"
#include "util/test_tools.hpp"
void op_is_Abs()
{
op::Abs node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Acos()
{
op::Acos node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Add()
{
op::Add node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Any()
{
op::Any node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Asin()
{
op::Asin node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Atan()
{
op::Atan node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_AvgPool()
{
op::AvgPool node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_BatchNormInference()
{
op::BatchNormInference node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Broadcast()
{
op::Broadcast node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_BroadcastDistributed()
{
op::BroadcastDistributed node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_BroadcastLike()
{
op::BroadcastLike node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Ceiling()
{
op::Ceiling node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Clamp()
{
op::Clamp node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Concat()
{
op::Concat node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Constant()
{
op::Constant node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Convert()
{
op::Convert node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Convolution()
{
op::Convolution node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_ConvolutionBackpropData()
{
op::ConvolutionBackpropData node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Cos()
{
op::Cos node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Cosh()
{
op::Cosh node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_CropAndResize()
{
op::CropAndResize node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_CumSum()
{
op::CumSum node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_DepthToSpace()
{
op::DepthToSpace node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Dequantize()
{
op::Dequantize node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Divide()
{
op::Divide node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Dot()
{
op::Dot node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Elu()
{
op::Elu node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_EmbeddingBagOffsetsSum()
{
op::EmbeddingBagOffsetsSum node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_EmbeddingBagPackedSum()
{
op::EmbeddingBagPackedSum node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_EmbeddingLookup()
{
op::EmbeddingLookup node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_EmbeddingSegmentsSum()
{
op::EmbeddingSegmentsSum node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Equal()
{
op::Equal node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Erf()
{
op::Erf node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Exp()
{
op::Exp node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_ExtractImagePatches()
{
op::ExtractImagePatches node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_FakeQuantize()
{
op::FakeQuantize node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Floor()
{
op::Floor node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_GRN()
{
op::GRN node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_GRUCell()
{
op::GRUCell node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Gather()
{
op::Gather node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_GatherND()
{
op::GatherND node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Gelu()
{
op::Gelu node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_GetOutputElement()
{
op::GetOutputElement node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Greater()
{
op::Greater node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_GreaterEq()
{
op::GreaterEq node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_GroupConvolution()
{
op::GroupConvolution node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_GroupConvolutionBackpropData()
{
op::GroupConvolutionBackpropData node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_HardSigmoid()
{
op::HardSigmoid node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Interpolate()
{
op::Interpolate node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Less()
{
op::Less node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_LessEq()
{
op::LessEq node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Log()
{
op::Log node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_LRN()
{
op::LRN node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_LSTMCell()
{
op::LSTMCell node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_LSTMSequence()
{
op::LSTMSequence node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_MatMul()
{
op::MatMul node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_NormalizeL2()
{
op::NormalizeL2 node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Max()
{
op::Max node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Maximum()
{
op::Maximum node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Min()
{
op::Min node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Minimum()
{
op::Minimum node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Multiply()
{
op::Multiply node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_MVN()
{
op::MVN node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Negative()
{
op::Negative node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Not()
{
op::Not node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_NotEqual()
{
op::NotEqual node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_OneHot()
{
op::OneHot node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Or()
{
op::Or node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_TRUE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_logical(&node));
}
void op_is_Pad()
{
op::Pad node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Parameter()
{
op::Parameter node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Passthrough()
{
op::Passthrough node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Power()
{
op::Power node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_PRelu()
{
op::PRelu node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Product()
{
op::Product node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Quantize()
{
op::Quantize node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_QuantizedConvolution()
{
op::QuantizedConvolution node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_QuantizedDot()
{
op::QuantizedDot node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Recv()
{
op::Recv node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Range()
{
op::Range node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Relu()
{
op::Relu node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_ReplaceSlice()
{
op::ReplaceSlice node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Reshape()
{
op::Reshape node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Result()
{
op::Result node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Reverse()
{
op::Reverse node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_ReverseSequence()
{
op::ReverseSequence node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_RNNCell()
{
op::RNNCell node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Round()
{
op::Round node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Select()
{
op::Select node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Selu()
{
op::Selu node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Send()
{
op::Send node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_ShapeOf()
{
op::ShapeOf node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_ShuffleChannels()
{
op::ShuffleChannels node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Sigmoid()
{
op::Sigmoid node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Sign()
{
op::Sign node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Sin()
{
op::Sin node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Sinh()
{
op::Sinh node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Slice()
{
op::Slice node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Softmax()
{
op::Softmax node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_SpaceToDepth()
{
op::SpaceToDepth node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Split()
{
op::Split node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Sqrt()
{
op::Sqrt node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_SquaredDifference()
{
op::SquaredDifference node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Squeeze()
{
op::Squeeze node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_StopGradient()
{
op::StopGradient node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Subtract()
{
op::Subtract node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_TRUE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Sum()
{
op::Sum node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Tan()
{
op::Tan node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Tanh()
{
op::Tanh node;
- EXPECT_TRUE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_TRUE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_TensorIterator()
{
op::TensorIterator node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Tile()
{
op::Tile node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_TopK()
{
op::TopK node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Unsqueeze()
{
op::Unsqueeze node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_FALSE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
}
void op_is_Xor()
{
op::Xor node;
- EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
- EXPECT_FALSE(node.is_binary_elementwise_comparison());
- EXPECT_TRUE(node.is_binary_elementwise_logical());
+ EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
+ EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
+ EXPECT_TRUE(op::is_binary_elementwise_logical(&node));
}
}
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "opset0_downgrade.hpp"
#include "opset1_upgrade.hpp"
ASSERT_TRUE(bcast_v1);
EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec());
EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
- ASSERT_TRUE(bcast_v1->input_value(1).get_node()->is_constant());
- ASSERT_TRUE(bcast_v1->input_value(2).get_node()->is_constant());
+ ASSERT_TRUE(op::is_constant(bcast_v1->input_value(1).get_node()));
+ ASSERT_TRUE(op::is_constant(bcast_v1->input_value(2).get_node()));
EXPECT_EQ(
as_type_ptr<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr())->get_shape_val(),
(Shape{3, 5, 4, 6}));
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/sum.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
auto b = make_shared<op::Parameter>(element::i32, shape);
auto is_bea = [](std::shared_ptr<Node> node) -> bool {
- return node->is_binary_elementwise_arithmetic();
+ return op::is_binary_elementwise_arithmetic(node);
};
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
auto add_ab = a + b;
#include "ngraph/cpio.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/core_fusion.hpp"
for (auto op : m_nodes)
{
event::Duration d2(op->description(), "Interpreter");
- if (op->is_parameter())
+ if (op::is_parameter(op))
{
continue;
}
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/attr_types.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/provenance.hpp"
*node);
const auto& arg_shape = arg_pshape.to_shape();
- NGRAPH_CHECK(target_shape_input.get_node_shared_ptr()->is_constant());
+ NGRAPH_CHECK(op::is_constant(target_shape_input.get_node()));
auto target_shape = node->get_output_shape(0);
NGRAPH_CHECK(node->get_broadcast_axes().first);
const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
const auto input_rank = node->get_input_partial_shape(0).rank();
- if (target_shape_input->is_constant() && node->get_output_partial_shape(0).is_static() &&
+ if (op::is_constant(target_shape_input) && node->get_output_partial_shape(0).is_static() &&
input_rank.is_static())
{
const auto output_shape = node->get_output_shape(0);
shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
{
const auto indices = node->input_value(0);
- const auto depth = node->input_value(1).get_node_shared_ptr();
+ const auto depth = node->input_value(1).get_node();
auto on_value = node->input_value(2);
auto off_value = node->input_value(3);
const auto axis = node->get_axis();
- NGRAPH_CHECK(depth->is_constant(), "depth input must be constant", *node);
+ NGRAPH_CHECK(op::is_constant(depth), "depth input must be constant", *node);
const auto output_pshape = node->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape.is_static(), "output shape must be static", *node);
const auto output_shape = output_pshape.to_shape();
shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
{
auto axes_node = node->input_value(1).get_node_shared_ptr();
- NGRAPH_CHECK(axes_node->is_constant(),
+ NGRAPH_CHECK(op::is_constant(axes_node),
"Unable to convert Reverse:v1 to Reverse:v0 "
"if reduction axes are not constant. Node: ",
*node);
const auto data_shape = data_pshape.to_shape();
const auto order_node = node->input_value(1).get_node_shared_ptr();
- NGRAPH_CHECK(order_node->is_constant(),
+ NGRAPH_CHECK(op::is_constant(order_node),
"Unable to convert Transpose:v1 to Reshape:v0 "
"if order node is not constant. Node: ",
*node);
{
const auto split_lengths = node->input_value(2).get_node_shared_ptr();
- NGRAPH_CHECK(split_lengths->is_constant(),
+ NGRAPH_CHECK(op::is_constant(split_lengths),
"Unable to convert VariadicSplit:v1 to Split:v0 "
"if 'split_lengths' input is not constant. Node: ",
*node);
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp"
+#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/provenance.hpp"
#include "op/avg_pool.hpp"
shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
{
- NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
+ NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
"axes parameter is expected to be a static constant");
AxisSet axes = node->get_axes();
shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
{
- NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
+ NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
"parameter k is expected to be a static constant");
- NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant(),
+ NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
"parameter top_k_axis is expected to be a static constant");
const auto k = node->get_k();
for (size_t i = 0; i < f0->get_output_size(); ++i)
{
- EXPECT_TRUE(f0->get_output_op(i)->is_output());
+ EXPECT_TRUE(op::is_output(f0->get_output_op(i)));
}
}