+#include <type_traits>
+#include <limits>
+
#include "model_analyzer.h"
#include "nnc/core/IR/model/graph/ir_node.h"
+#include "nncc/core/ADT/tensor/Shape.h"
+#include "nnc/core/linalg/ShapeRange.h"
+
+#include "nnc/core/IR/model/operations/concat_op.h"
+#include "nnc/core/IR/model/operations/conv_2d_op.h"
+#include "nnc/core/IR/model/operations/depthwise_conv2d_op.h"
+#include "nnc/core/IR/model/operations/softmax_op.h"
+#include "nnc/core/IR/model/operations/pool_op.h"
+#include "nnc/core/IR/model/operations/fully_connected_op.h"
+#include "nnc/core/IR/model/operations/capped_relu_op.h"
+#include "nnc/core/IR/model/operations/bias_add_op.h"
+#include "nnc/core/IR/model/operations/relu_op.h"
+#include "nnc/core/IR/model/operations/reshape_op.h"
using namespace std;
namespace soft
{
+using nncc::core::ADT::tensor::Shape;
+using nncc::core::ADT::tensor::Index;
+using nncc::contrib::core::data::ShapeRange;
+using nncc::contrib::core::ADT::TensorVariant;
+
+void ModelAnalyzer::packData(const void *data, size_t size)
+{
+ // Append a raw byte span to the serialized parameter buffer.
+ const char *bytes = static_cast<const char *>(data);
+ _packedParameters.insert(_packedParameters.end(), bytes, bytes + size);
+}
+
+template <typename T>
+void ModelAnalyzer::serializeT(const T &obj)
+{
+ // Dump the object's raw in-memory bytes into the parameter buffer.
+ packData(static_cast<const void *>(&obj), sizeof(obj));
+}
+
+// Convert an enum value to its underlying integral type,
+// for compact serialization of enum-typed fields.
+template <class E>
+typename underlying_type<E>::type etoi(E enumVal)
+{
+ return static_cast<typename underlying_type<E>::type>(enumVal);
+}
+
+void ModelAnalyzer::serializeShape(const Shape &s)
+{
+ // Rank goes out as a single byte (asserted to fit),
+ // then every dimension follows as a 32-bit value.
+ const uint32_t rank = s.rank();
+ assert(rank < 100);
+ serializeT<char>(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+   serializeT<uint32_t>(s.dim(axis));
+ }
+}
+
+void ModelAnalyzer::serializeTensor(const TensorVariant &t)
+{
+ // serialize element type as one byte (asserted to fit)
+ assert(etoi(t.getDataType()) < 100);
+ serializeT<char>(etoi(t.getDataType()));
+ // serialize element size in bytes as one byte
+ size_t eSize = t.getElementSize();
+ assert(eSize < 100);
+ serializeT<char>(eSize);
+ // serialize shape
+ const Shape &shape = t.getShape();
+ serializeShape(shape);
+ // serialize actual data
+ size_t tSize = eSize * num_elements(shape);
+
+ // reserve once so the per-element packData calls below do not reallocate
+ size_t oldSize = _packedParameters.size();
+ _packedParameters.reserve(oldSize + tSize);
+ for (const Index &idx: ShapeRange(shape))
+ {
+ packData(t.at(idx), eSize);
+ }
+}
+
void ModelAnalyzer::addOpDescr(ADT::INode *node, const string &opName)
{
size_t offset = _packedParameters.size();
void ModelAnalyzer::visit(ADT::INode *node, ops::ConcatOp &op)
{
addOpDescr(node, "concat");
- // TODO add parameters dump
+ // concat axis is serialized as a single byte; assert guards truncation
+ // NOTE(review): assumes getAxis() is non-negative — confirm upstream
+ assert(op.getAxis() < 100);
+ serializeT<char>(op.getAxis());
}
void ModelAnalyzer::visit(ADT::INode *node, ops::Conv2DOp &op)
{
addOpDescr(node, "conv2d");
- // TODO add parameters dump
+ // serialize kernel tensor (type, element size, shape, raw data)
+ const TensorVariant &kernel = op.getKernel();
+ serializeTensor(kernel);
+ // serialize strides
+ serializeShape(op.getStrides());
+ // serialize padding type as one byte
+ assert(etoi(op.getPaddingType()) < 100);
+ serializeT<char>(etoi(op.getPaddingType()));
+ // serialize per-dimension pads; count taken from the input shape rank
+ // NOTE(review): the depthwiseConv2d visitor takes this rank from the
+ // kernel shape instead — confirm which one the deserializer expects
+ uint32_t padsRank = op.getInputShape(0).rank();
+ assert(padsRank < 100);
+ serializeT<char>(padsRank);
+ for (uint32_t i = 0; i < padsRank; ++i)
+ {
+   auto pad = op.getPadding(i);
+   assert(pad <= numeric_limits<int32_t>::max());
+   assert(pad >= 0);
+   serializeT<int32_t>(pad);
+ }
}
void ModelAnalyzer::visit(ADT::INode *node, ops::DepthwiseConv2DOp &op)
{
addOpDescr(node, "depthwiseConv2d");
- // TODO add parameters dump
+ // serialize kernel tensor (type, element size, shape, raw data)
+ const TensorVariant &kernel = op.getKernel();
+ serializeTensor(kernel);
+ // serialize strides
+ serializeShape(op.getStrides());
+ // serialize padding type as one byte
+ assert(etoi(op.getPaddingType()) < 100);
+ serializeT<char>(etoi(op.getPaddingType()));
+ // serialize per-dimension pads; count taken from the kernel shape rank
+ // NOTE(review): the conv2d visitor takes this rank from the input shape
+ // instead — confirm which one the deserializer expects
+ uint32_t padsRank = kernel.getShape().rank();
+ assert(padsRank < 100);
+ serializeT<char>(padsRank);
+ for (uint32_t i = 0; i < padsRank; ++i)
+ {
+ auto pad = op.getPadding(i);
+ assert(pad <= numeric_limits<int32_t>::max());
+ assert(pad >= 0);
+ serializeT<int32_t>(pad);
+ }
}
void ModelAnalyzer::visit(ADT::INode *node, ops::SoftmaxOp &op)
{
addOpDescr(node, "softmax");
- // TODO add parameters dump
+ // softmax axis is serialized as a single byte; assert guards truncation
+ // NOTE(review): assumes getAxis() is non-negative — confirm upstream
+ assert(op.getAxis() < 100);
+ serializeT<char>(op.getAxis());
}
void ModelAnalyzer::visit(ADT::INode *node, ops::PoolOp &op)
{
addOpDescr(node, "pool");
- // TODO add parameters dump
+ // serialize padding type as one byte
+ assert(etoi(op.getPaddingType()) < 100);
+ serializeT<char>(etoi(op.getPaddingType()));
+ // serialize pooling type as one byte
+ assert(etoi(op.getPoolingType()) < 100);
+ serializeT<char>(etoi(op.getPoolingType()));
+ // serialize window shape
+ const Shape &windowShape = op.getWindowShape();
+ serializeShape(windowShape);
+ // serialize strides
+ serializeShape(op.getStrides());
+ // serialize per-dimension pads
+ // NOTE(review): unlike conv2d/depthwiseConv2d the pad count is not
+ // serialized before the values — confirm the deserializer derives it
+ // from the window shape rank
+ uint32_t rank = windowShape.rank();
+ assert(rank < 100);
+ for (uint32_t i = 0; i < rank; ++i)
+ {
+   auto pad = op.getPadding(i);
+   assert(pad <= numeric_limits<int32_t>::max());
+   assert(pad >= 0);
+   serializeT<int32_t>(pad);
+ }
}
void ModelAnalyzer::visit(ADT::INode *node, ops::FullyConnectedOp &op)
{
addOpDescr(node, "fullConnect");
- // TODO add parameters dump
+ // serialize weight matrix (type, element size, shape, raw data)
+ serializeTensor(op.getWeights());
}
void ModelAnalyzer::visit(ADT::INode *node, ops::CappedReluOp &op)
{
addOpDescr(node, "cappedRelu");
- // TODO add parameters dump
+ // cap is dumped as its raw 4-byte float representation;
+ // static_assert pins the size assumption at compile time
+ static_assert(sizeof(float) == 4, "unsupported float type");
+ serializeT<float>(op.getCap());
}
void ModelAnalyzer::visit(ADT::INode *node, ops::BiasAddOp &op)
{
addOpDescr(node, "biasAdd");
- // TODO add parameters dump
+ // serialize bias tensor (type, element size, shape, raw data)
+ serializeTensor(op.getWeights());
}
void ModelAnalyzer::visit(ADT::INode *node, ops::VariableOp &op)
{
assert(node->getPrevNodes().empty());
addOpDescr(node, "in");
- // TODO add parameters dump
+ // graph input node: carries no parameters to dump
}
void ModelAnalyzer::visit(ADT::INode *node, ops::ReluOp &op)
{
addOpDescr(node, "relu");
- // TODO add parameters dump
+ // plain relu has no parameters to dump
}
void ModelAnalyzer::visit(ADT::INode *node, ops::ReshapeOp &op)
{
addOpDescr(node, "reshape");
- // TODO add parameters dump
+ // only the target (output) shape is needed to replay the reshape
+ serializeShape(op.getOutputShape(0));
}
} // namespace soft