Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / Nodes / CircleTransposeConv.cpp
index 7bdf46d..ddb1966 100644 (file)
@@ -30,6 +30,24 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const
   if (args.op.inputs.size() != 3)
     return false;
 
+  const auto &inputs = args.op.inputs;
+  const auto &tensors = args.reader.tensors();
+  const auto &filter_tensor = tensors.at(inputs.at(1));
+  const auto &filter_shape = filter_tensor.get()->shape;
+  const auto &ifm_tensor = tensors.at(inputs.at(2));
+  const auto &ifm_shape = ifm_tensor.get()->shape;
+
+  // ifm and filters must be 4-D tensor
+  if (ifm_shape.size() != 4)
+    return false;
+  if (filter_shape.size() != 4)
+    return false;
+
+  // input shape : [batch, height, width, in_channels]
+  // filters shape : [output_channels, height, weight, in_channels]
+  if (ifm_tensor.get()->shape.at(3) != filter_tensor.get()->shape.at(3))
+    return false;
+
   return true;
 }
 
@@ -39,9 +57,9 @@ CircleNode *CircleTransposeConvGraphBuilder::build_node(const circle::OperatorT
 {
   auto *node = graph->nodes()->create<CircleTransposeConv>();
 
-  node->inputSizes(inputs[0]);
-  node->filter(inputs[1]);
-  node->outBackprop(inputs[2]);
+  node->inputSizes(inputs.at(0));
+  node->filter(inputs.at(1));
+  node->outBackprop(inputs.at(2));
 
   const auto *options = op.builtin_options.AsTransposeConvOptions();
   node->padding(luci_padding(options->padding));