CANONICAL_NODE(TensorConcat, TensorConcat)
CANONICAL_NODE(TensorBiasAdd, BiasAdd<Domain::Tensor>)
CANONICAL_NODE(TensorBroadcast, TensorBroadcast)
+CANONICAL_NODE(TensorReduce, TensorReduce)
CANONICAL_NODE(TensorSoftmax, Softmax<Domain::Tensor>)
CANONICAL_NODE(TransposedConv2D, TransposedConv2D)
#include "loco/IR/Stride.h"
#include "loco/IR/Padding2D.h"
#include "loco/IR/TensorAxis.h"
+#include "loco/IR/TensorAxisSet.h"
#include "loco/IR/FeatureCodec.h"
#include "loco/IR/FilterCodec.h"
#include "loco/IR/DepthwiseFilterCodec.h"
};
/**
+ * @brief Reduce type functions
+ */
+enum class ReduceFunc
+{
+ Mean, // ReduceMean
+ // TODO Support other reduce operations
+};
+
+/**
+ * @brief Computes ReduceFunc operations for Tensor domain
+ * @note All the reduce functions always keep dimensions
+ */
+class TensorReduce final
+ : public CanonicalNodeDef<CanonicalOpcode::TensorReduce, FixedArity<1>::Mixin>
+{
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+public:
+ const TensorAxisSet *axes(void) const { return &_axes; }
+ TensorAxisSet *axes(void) { return &_axes; }
+
+public:
+ ReduceFunc func(void) const { return _func; }
+ void func(ReduceFunc func) { _func = func; }
+
+private:
+ TensorAxisSet _axes;
+ ReduceFunc _func;
+};
+
+/**
* @brief 2D Transposed Convolution
*
* @note TransposedConv2D have a few important conventions that IR users should
return loco::NodeShape{tensor_shape};
}
+ // CASE: TensorReduce
+ loco::NodeShape visit(const loco::TensorReduce *node) final
+ {
+ auto tensor_shape = node_shape(node->input()).as<loco::TensorShape>();
+ auto const tensor_rank = tensor_shape.rank();
+
+ for (uint32_t d = 0; d < tensor_rank; ++d)
+ if (node->axes()->defined(d))
+ tensor_shape.dim(d) = 1;
+
+ return loco::NodeShape{tensor_shape};
+ }
+
// CASE: TensorSoftmax
loco::NodeShape visit(const loco::TensorSoftmax *node) final { return node_shape(node->input()); }
};
loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); }
loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); }
loco::DataType visit(const loco::TensorBroadcast *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::TransposedConv2D *node) { return loco::dtype_get(node->ifm()); }
};