#include "TFOpcode.h"
#include "TFNodeVisitor.forward.h"
+#include <array>
+
namespace moco
{
namespace tf
TFOpcode opcode(void) const final { return Code; }
};
+/**
+ * @brief Nodes with the fixed number of inputs
+ */
+template <unsigned N, typename Base> class FixedArityNode : public Base
+{
+public:
+ FixedArityNode()
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args[n] = std::unique_ptr<loco::Use>{new loco::Use{this}};
+ }
+ }
+
+ virtual ~FixedArityNode() = default;
+
+public:
+ unsigned arity(void) const final { return N; }
+
+ loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+ void drop(void) final
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args.at(n)->node(nullptr);
+ }
+ }
+
+protected:
+ // This API allows inherited classes to access "_args" field.
+ loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+
+private:
+ std::array<std::unique_ptr<loco::Use>, N> _args;
+};
+
} // namespace tf
} // namespace moco
}
*/
-class TFAdd final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Add>>
+class TFAdd final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Add>>
{
public:
TFAdd() = default;
}
*/
-class TFAvgPool final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::AvgPool>>
+class TFAvgPool final : public FixedArityNode<1, TFNodeImpl<TFOpcode::AvgPool>>
{
public:
TFAvgPool() = default;
}
*/
-class TFBiasAdd final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::BiasAdd>>
+class TFBiasAdd final : public FixedArityNode<2, TFNodeImpl<TFOpcode::BiasAdd>>
{
public:
TFBiasAdd() = default;
* @note Implementation for this class came from Canonical ConstGen
* Read comments in loco::ConstGen for details
*/
-class TFConst final : public loco::FixedArityNode<0, TFNodeImpl<TFOpcode::Const>>,
+class TFConst final : public FixedArityNode<0, TFNodeImpl<TFOpcode::Const>>,
public loco::NodeMixin<loco::NodeTrait::DataType>,
public loco::NodeMixin<loco::NodeTrait::TensorShape>
{
namespace tf
{
-class TFConv2D final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Conv2D>>
+class TFConv2D final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Conv2D>>
{
public:
loco::Node *ifm(void) const { return at(0)->node(); }
{
class TFDepthwiseConv2dNative final
- : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::DepthwiseConv2dNative>>
+ : public FixedArityNode<2, TFNodeImpl<TFOpcode::DepthwiseConv2dNative>>
{
public:
loco::Node *ifm(void) const { return at(0)->node(); }
namespace tf
{
-class TFFusedBatchNorm final : public loco::FixedArityNode<5, TFNodeImpl<TFOpcode::FusedBatchNorm>>
+class TFFusedBatchNorm final : public FixedArityNode<5, TFNodeImpl<TFOpcode::FusedBatchNorm>>
{
public:
TFFusedBatchNorm() = default;
}
*/
-class TFIdentity final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Identity>>
+class TFIdentity final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Identity>>
{
public:
TFIdentity() = default;
}
*/
-class TFMaxPool final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::MaxPool>>
+class TFMaxPool final : public FixedArityNode<1, TFNodeImpl<TFOpcode::MaxPool>>
{
public:
TFMaxPool() = default;
}
*/
-class TFMul final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Mul>>
+class TFMul final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Mul>>
{
public:
TFMul() = default;
}
*/
-class TFRealDiv final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::RealDiv>>
+class TFRealDiv final : public FixedArityNode<2, TFNodeImpl<TFOpcode::RealDiv>>
{
public:
TFRealDiv() = default;
namespace tf
{
-class TFRelu final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Relu>>
+class TFRelu final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Relu>>
{
public:
TFRelu() = default;
namespace tf
{
-class TFRelu6 final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Relu6>>
+class TFRelu6 final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Relu6>>
{
public:
TFRelu6() = default;
}
*/
-class TFReshape final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Reshape>>
+class TFReshape final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Reshape>>
{
public:
TFReshape() = default;
}
*/
-class TFRsqrt final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Rsqrt>>
+class TFRsqrt final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Rsqrt>>
{
public:
TFRsqrt() = default;
*/
/// @note Mixed in dtype() is for 'out_type' attribute
-class TFShape final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Shape>>,
+class TFShape final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Shape>>,
public loco::NodeMixin<loco::NodeTrait::DataType>
{
public:
namespace tf
{
-class TFSoftmax final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Softmax>>
+class TFSoftmax final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Softmax>>
{
public:
TFSoftmax() = default;
}
*/
-class TFSqrt final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Sqrt>>
+class TFSqrt final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Sqrt>>
{
public:
TFSqrt() = default;
}
*/
-class TFSquaredDifference final
- : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::SquaredDifference>>
+class TFSquaredDifference final : public FixedArityNode<2, TFNodeImpl<TFOpcode::SquaredDifference>>
{
public:
TFSquaredDifference() = default;
}
*/
-class TFSqueeze final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Squeeze>>
+class TFSqueeze final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Squeeze>>
{
public:
TFSqueeze() = default;
}
*/
-class TFStopGradient final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::StopGradient>>
+class TFStopGradient final : public FixedArityNode<1, TFNodeImpl<TFOpcode::StopGradient>>
{
public:
TFStopGradient() = default;
}
*/
-class TFSub final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Sub>>
+class TFSub final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Sub>>
{
public:
TFSub() = default;
namespace tf
{
-class TFTanh final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Tanh>>
+class TFTanh final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Tanh>>
{
public:
TFTanh() = default;