if (args.op.inputs.size() != 3)
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &tensors = args.reader.tensors();
+ const auto &filter_tensor = tensors.at(inputs.at(1));
+ const auto &filter_shape = filter_tensor.get()->shape;
+ const auto &ifm_tensor = tensors.at(inputs.at(2));
+ const auto &ifm_shape = ifm_tensor.get()->shape;
+
+ // ifm and filters must be 4-D tensor
+ if (ifm_shape.size() != 4)
+ return false;
+ if (filter_shape.size() != 4)
+ return false;
+
+ // input shape : [batch, height, width, in_channels]
+ // filters shape : [output_channels, height, weight, in_channels]
+ if (ifm_tensor.get()->shape.at(3) != filter_tensor.get()->shape.at(3))
+ return false;
+
return true;
}
{
auto *node = graph->nodes()->create<CircleTransposeConv>();
- node->inputSizes(inputs[0]);
- node->filter(inputs[1]);
- node->outBackprop(inputs[2]);
+ node->inputSizes(inputs.at(0));
+ node->filter(inputs.at(1));
+ node->outBackprop(inputs.at(2));
const auto *options = op.builtin_options.AsTransposeConvOptions();
node->padding(luci_padding(options->padding));