[nnc] Remove limitation on input shape in importers (#2903)
author Sergei Barannikov/AI Tools Lab/SRR/Engineer/Samsung Electronics <s.barannikov@samsung.com>
Tue, 22 Jan 2019 12:25:44 +0000 (15:25 +0300)
committer Efimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 22 Jan 2019 12:25:44 +0000 (15:25 +0300)
Remove limitations on input shape in Caffe2, TensorFlow Lite and ONNX importers.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp

index d9bd917..4c28296 100644 (file)
@@ -499,8 +499,6 @@ Caffe2OpCreator::convertReshape(const std::vector<mir::IODescriptor>& inputs,
 
 std::vector<IODescriptor>
 Caffe2OpCreator::createInput(const std::string& name, const mir::Shape& shape) {
-  // TODO For now we only support convolutional networks with one element per batch.
-  assert(shape.rank() == 4 && shape.dim(0) == 1);
   auto variable = _graph->create<ops::VariableOp>(name, shape);
   return {variable->getOutput(0)};
 }
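
For context, the assertion removed above restricted Caffe2 graph inputs to rank-4 shapes with a batch dimension of 1. A minimal standalone sketch of that old check (plain C++, with std::vector<int32_t> standing in for mir::Shape; this is not the nnc API, just an illustration):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // The removed restriction, reproduced here only to show what it rejected.
  auto old_check = [](const std::vector<int32_t>& shape) {
    return shape.size() == 4 && shape[0] == 1;
  };

  std::vector<int32_t> nchw_input{1, 3, 224, 224}; // the only kind of input accepted before
  std::vector<int32_t> fc_input{32, 128};          // rank-2 input with batch 32

  assert(old_check(nchw_input));  // accepted both before and after this change
  assert(!old_check(fc_input));   // would have aborted the importer before this change
  return 0;
}

After the patch, createInput simply forwards whatever shape the model declares to ops::VariableOp.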
index 3ddc20a..b78de2f 100644 (file)
@@ -185,13 +185,11 @@ void ONNXImporterImpl::createGraphInputs() {
       auto constant = _graph->create<mir::ops::ConstantOp>(name, _constantTensors.at(name));
       _tensorNameToIODescriptor[name] = constant->getOutput(0);
     } else {
-      // We're dealing with graph input (assuming the picture only)
-      auto onnx_input_shape = input.type().tensor_type().shape();
-      assert(onnx_input_shape.dim_size() == 4);
-      mir::Shape shape(4);
+      const auto& onnx_input_shape = input.type().tensor_type().shape();
+      mir::Shape shape(onnx_input_shape.dim_size());
       for (int i = 0; i < onnx_input_shape.dim_size(); i++) {
         assert(onnx_input_shape.dim(i).has_dim_value());
-        shape.dim(i) = onnx_input_shape.dim(i).dim_value();
+        shape.dim(i) = static_cast<int32_t>(onnx_input_shape.dim(i).dim_value());
       }
       // TODO: Temporary solution!
       auto node = _graph->create<mir::ops::VariableOp>(name, shape);
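
With the hard-coded rank of 4 gone, the ONNX importer now builds the shape with the rank reported by the model and narrows each 64-bit protobuf dim value to mir's 32-bit dimensions via an explicit static_cast. A standalone sketch of that conversion (std::vector<int64_t> stands in for the TensorShapeProto dims; the range check only illustrates why the cast is explicit and is not part of the patch):

#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper mirroring the loop above: 64-bit ONNX dims -> 32-bit dims.
std::vector<int32_t> makeShape(const std::vector<int64_t>& onnx_dims) {
  std::vector<int32_t> shape(onnx_dims.size());
  for (std::size_t i = 0; i < onnx_dims.size(); ++i) {
    assert(onnx_dims[i] >= 0 &&
           onnx_dims[i] <= std::numeric_limits<int32_t>::max());
    shape[i] = static_cast<int32_t>(onnx_dims[i]);
  }
  return shape;
}

int main() {
  // A rank-3 input (e.g. a sequence model) that the old rank-4 assert would have rejected.
  auto shape = makeShape({1, 16, 512});
  assert(shape.size() == 3 && shape[2] == 512);
  return 0;
}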
index bcfdb9a..d837548 100644 (file)
@@ -154,9 +154,6 @@ void TfliteImporter::walkSubGraph(const SubGraph* s) {
     const Tensor* t = (*s->tensors())[i];
     Shape input_shape = ShapeHelper::createShape(*t->shape(), t->shape()->size());
 
-    // TODO Remove this limitation.
-    assert(input_shape.dim(0) == 1);
-
     auto input = _graph->create<mir::ops::VariableOp>(t->name()->c_str(), input_shape);
     _tensorMap[i] = input->getOutput(0);
   }