#include "Frontend.h"
+#include <nncc/core/ADT/tensor/Shape.h>
+#include <nncc/core/ADT/tensor/LexicalLayout.h>
+
+#include <map>
+#include <set>
+#include <string>
+
#include <cassert>
+#include <stdexcept>
+
+#include <algorithm>
+
+using namespace nncc::core::ADT;
+
+using tensor::num_elements;
+using tensor::LexicalLayout;
+
+namespace
+{
+
+tensor::Shape as_tensor_shape(const ::caffe::BlobShape &blob_shape)
+{
+ const uint32_t rank = blob_shape.dim_size();
+
+ tensor::Shape res;
+
+ res.resize(rank);
+
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+ res.dim(axis) = blob_shape.dim(axis);
+ }
+
+ return res;
+}
+
+} // namespace
Frontend::Frontend() : _prototxt{new ::caffe::NetParameter}, _caffemodel{new ::caffe::NetParameter}
{
auto m = coco::Module::create();
auto d = coco::Data::create();
- // TODO Remove this restriction
- assert(_prototxt->layer_size() == 0);
- assert(_caffemodel->layer_size() == 0);
+ // For inter-layer communication
+ std::map<std::string, tensor::Shape> shape_ctx;
+ std::map<std::string, coco::Bag *> bag_ctx;
+
+ std::set<std::string> top;
+
+ for (const auto &layer : _prototxt->layer())
+ {
+ assert(layer.has_name());
+ assert(layer.has_type());
+
+ top.clear();
+ top.insert(layer.top().begin(), layer.top().end());
+
+ if (layer.type() == "Input")
+ {
+ assert(layer.has_input_param());
+ const auto ¶m = layer.input_param();
+
+ for (uint32_t n = 0; n < layer.top_size(); ++n)
+ {
+ const auto &name = layer.top(n);
+ const auto shape = as_tensor_shape(param.shape(n));
+
+ auto bag = m->entity()->bag()->create(num_elements(shape));
+ auto input = m->entity()->input()->create(shape);
+
+ input->bag(bag);
+ input->name(name);
+ input->reorder<LexicalLayout>();
+
+ m->input()->insert(input);
+
+ bag_ctx[name] = bag;
+ shape_ctx[name] = shape;
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"Not supported: " + layer.type()};
+ }
+ }
+
+ // Finalize: Create output for each top blob
+ for (const auto &name : top)
+ {
+ const auto &shape = shape_ctx.at(name);
+ auto bag = bag_ctx.at(name);
+
+ auto output = m->entity()->output()->create(shape);
+
+ output->bag(bag);
+ output->name(name);
+ output->reorder<LexicalLayout>();
+
+ m->output()->insert(output);
+ }
enco::Bundle bundle;