void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
Node::Node(const graph::operation::Node::InitParam &init_param)
+ : operation::Node{OperandConstraint::createEQ(1u)}
{
assert(init_param.input_count == 7);
assert(init_param.output_count == 1);
_param.activation_index = operand::Index{init_param.inputs[6]};
}
-void Node::setInputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setOutputs(indexes);
-}
-
} // namespace Implicit
} // namespace AvgPool2D
} // namespace operation
virtual std::string getName() const override { return "AvgPool2D"; }
public:
- virtual void setInputs(const operand::IndexSet &indexes) override;
- virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
const Param ¶m() const { return _param; }
private:
void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
Node::Node(const graph::operation::Node::InitParam &init_param)
+ : operation::Node{OperandConstraint::createGE(2u)}
{
assert(init_param.input_count >= 2); // At least one one input tensor and axis
assert(init_param.output_count == 1);
_param.axis_index = operand::Index{init_param.inputs[init_param.input_count - 1]};
}
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setOutputs(indexes);
-}
-
} // namespace Concat
} // namespace operation
} // namespace graph
virtual std::string getName() const override { return "Concat"; }
public:
- virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
const Param ¶m() const { return _param; }
private:
void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
Node::Node(const graph::operation::Node::InitParam &init_param)
+ : operation::Node{OperandConstraint::createEQ(3u)}
{
assert(init_param.input_count == 7 && init_param.output_count == 1);
_param.activation_index = operand::Index{init_param.inputs[6]};
}
-void Node::setInputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 3);
-
- graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setOutputs(indexes);
-}
-
} // namespace Implicit
} // namespace Conv2D
} // namespace operation
virtual std::string getName() const override { return "Conv2D"; }
public:
- virtual void setInputs(const operand::IndexSet &indexes) override;
- virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
const Param ¶m() const { return _param; }
private:
void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
Node::Node(const graph::operation::Node::InitParam &init_param)
+ : operation::Node{OperandConstraint::createEQ(3u)}
{
assert(init_param.input_count == 4 && init_param.output_count == 1);
_param.activation_index = operand::Index{init_param.inputs[3]};
}
-void Node::setInputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 3);
-
- graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setOutputs(indexes);
-}
-
} // namespace FullyConnected
} // namespace operation
} // namespace graph
virtual std::string getName() const override { return "FullyConnected"; }
public:
- virtual void setInputs(const operand::IndexSet &indexes) override;
- virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
const Param ¶m() const { return _param; }
private:
void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
Node::Node(const graph::operation::Node::InitParam &init_param)
+ : operation::Node{OperandConstraint::createEQ(1u)}
{
assert(init_param.input_count == 7);
assert(init_param.output_count == 1);
_param.activation_index = operand::Index{init_param.inputs[6]};
}
-void Node::setInputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setOutputs(indexes);
-}
-
} // namespace Implicit
} // namespace MaxPool2D
} // namespace operation
Node(const graph::operation::Node::InitParam &init_param);
public:
- virtual void setInputs(const operand::IndexSet &indexes) override;
- virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
const Param ¶m() const { return _param; }
private:
void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
+Node::Node(const graph::operation::Node::InitParam &)
+ : operation::Node{OperandConstraint::createEQ(1u)}
+{
+}
+
} // namespace NOP
} // namespace operation
} // namespace graph
class Node : public graph::operation::Node
{
public:
- Node(const graph::operation::Node::InitParam &) {}
+ Node(const graph::operation::Node::InitParam &);
public:
virtual void accept(NodeVisitor &&) const override;
#include "Node.h"
+#include <cassert>
+
#include "LowerInfo.h"
namespace neurun
namespace operation
{
-Node::Node() = default;
+Node::Node(OperandConstraint input_constr) : _input_constr{input_constr} {}
Node::~Node() = default;
+void Node::setInputs(const operand::IndexSet &indexes)
+{
+ assert(_input_constr.check(indexes.size()));
+ _inputs = indexes;
+}
+
+void Node::setOutputs(const operand::IndexSet &indexes) { _outputs = indexes; }
+
void Node::replaceInput(const operand::Index &from, const operand::Index &to)
{
_inputs.replace(from, to);
#include <memory>
#include "graph/operand/IndexSet.h"
+#include "OperandConstraint.h"
namespace neurun
{
};
public:
- Node();
+ Node(OperandConstraint input_constr);
virtual ~Node();
public:
virtual std::string getName() const = 0;
public:
- virtual void replaceInput(const operand::Index &from, const operand::Index &to);
- virtual void replaceOutput(const operand::Index &from, const operand::Index &to);
- virtual const operand::IndexSet &getInputs() const { return _inputs; }
- virtual const operand::IndexSet &getOutputs() const { return _outputs; }
+ void replaceInput(const operand::Index &from, const operand::Index &to);
+ void replaceOutput(const operand::Index &from, const operand::Index &to);
+ const operand::IndexSet &getInputs() const { return _inputs; }
+ const operand::IndexSet &getOutputs() const { return _outputs; }
// It's for only input/output tensors but const data.
- virtual void setInputs(const operand::IndexSet &indexes) { _inputs = indexes; }
- virtual void setOutputs(const operand::IndexSet &indexes) { _outputs = indexes; }
+ void setInputs(const operand::IndexSet &indexes);
+ void setOutputs(const operand::IndexSet &indexes);
public:
void lower_info(std::unique_ptr<LowerInfo> &&lower_info);
private:
operand::IndexSet _inputs;
operand::IndexSet _outputs;
+ OperandConstraint _input_constr;
std::unique_ptr<LowerInfo> _lower_info;
};
--- /dev/null
+#include "OperandConstraint.h"
+
+namespace neurun
+{
+namespace graph
+{
+namespace operation
+{
+
+} // namespace operation
+} // namespace graph
+} // namespace neurun
--- /dev/null
+#ifndef __NEURUN_GRAPH_OPERATION_OPERAND_CONSTRAINT_H__
+#define __NEURUN_GRAPH_OPERATION_OPERAND_CONSTRAINT_H__
+
+#include <stdint.h>
+#include <limits>
+#include <set>
+
+namespace neurun
+{
+namespace graph
+{
+namespace operation
+{
+
+class OperandConstraint
+{
+private:
+ static const uint32_t INF = std::numeric_limits<uint32_t>::max();
+
+public:
+ static OperandConstraint createAny() { return OperandConstraint{0u, INF}; }
+ static OperandConstraint createEQ(uint32_t exact) { return OperandConstraint{exact, exact}; }
+ static OperandConstraint createLE(uint32_t end) { return OperandConstraint{0u, end}; }
+ static OperandConstraint createGE(uint32_t begin) { return OperandConstraint{begin, INF}; }
+ static OperandConstraint createRange(uint32_t begin, uint32_t end)
+ {
+ return OperandConstraint{begin, end};
+ }
+
+private:
+ OperandConstraint(uint32_t begin, uint32_t end) : _begin{begin}, _end{end} {}
+
+public:
+ bool check(uint32_t ind) const { return _begin <= ind && ind <= _end; }
+
+private:
+ uint32_t _begin;
+ uint32_t _end;
+};
+
+} // namespace operation
+} // namespace graph
+} // namespace neurun
+
+#endif // __NEURUN_GRAPH_OPERATION_OPERAND_CONSTRAINT_H__
void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
Node::Node(const operand::Index &input, const operand::Index &output)
+ : operation::Node{OperandConstraint::createEQ(1u)}
{
setInputs({input});
setOutputs({output});
}
-void Node::setInputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setOutputs(indexes);
-}
-
} // namespace Permute
} // namespace operation
} // namespace graph
public:
Node(const operand::Index &input, const operand::Index &output);
-
-public:
- virtual void setInputs(const operand::IndexSet &indexes) override;
- virtual void setOutputs(const operand::IndexSet &indexes) override;
};
} // namespace Permute
void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
Node::Node(const graph::operation::Node::InitParam &init_param)
+ : operation::Node{OperandConstraint::createEQ(1u)}
{
assert(init_param.input_count == 2 && init_param.output_count == 1);
setOutputs({init_param.outputs[0]});
}
-void Node::setInputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1); // TODO Should be 2 (See also the constructor)
-
- graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setOutputs(indexes);
-}
-
} // namespace Reshape
} // namespace operation
} // namespace graph
public:
Node(const graph::operation::Node::InitParam &init_param);
-
-public:
- virtual void setInputs(const operand::IndexSet &indexes) override;
- virtual void setOutputs(const operand::IndexSet &indexes) override;
};
} // namespace Reshape
void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
Node::Node(const graph::operation::Node::InitParam &init_param)
+ : operation::Node{OperandConstraint::createEQ(1u)}
{
assert(init_param.input_count == 2 && init_param.output_count == 1);
_param.scale_index = operand::Index{init_param.inputs[1]};
}
-void Node::setInputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
- assert(indexes.size() == 1);
-
- graph::operation::Node::setOutputs(indexes);
-}
-
} // namespace Softmax
} // namespace operation
} // namespace graph
Node(const graph::operation::Node::InitParam &init_param);
public:
- virtual void setInputs(const operand::IndexSet &indexes) override;
- virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
const Param ¶m() const { return _param; }
private:
public:
SimpleMockNode(const neurun::graph::operand::IndexSet &inputs,
const neurun::graph::operand::IndexSet &outputs)
+ : neurun::graph::operation::Node{neurun::graph::operation::OperandConstraint::createAny()}
{
setInputs(inputs);
setOutputs(outputs);