int inputNum = 0;
for (auto &input : inputs) {
auto inputLayout = input.second->getTensorDesc().getLayout();
- if (inputLayout != Layout::NC && inputLayout != Layout::CN && inputLayout != NCHW) {
+ if (inputLayout != Layout::NC && inputLayout != Layout::CN && inputLayout != Layout::CHW && inputLayout != Layout::NCHW) {
THROW_GNA_EXCEPTION << "Expected input blob to have Layout::NC or Layout::CN, but was: "
<< input.second->getTensorDesc().getLayout();
}
- if (inputLayout == NCHW) {
- inputLayout = NC;
+
+ auto dims = input.second->getTensorDesc().getDims();
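+ // in CHW layout dims are ordered (C, H, W); only C == 1 and H == 1 is accepted here so the blob can be squeezed to NC below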
+ if (inputLayout == Layout::CHW && (dims[0] != 1 || dims[1] != 1)) {
+ THROW_GNA_EXCEPTION << "For Layout::CHW only tensors with channel = 1 and height = 1 are supported, but got dims: "
+ << dims;
+ }
+
+ if (inputLayout == Layout::NCHW || inputLayout == Layout::CHW) {
+ // specific case that can be squeezed to 2d
+ inputLayout = Layout::NC;
}
+
auto is2D = input.second->getTensorDesc().getLayout() == Layout::NC || input.second->getTensorDesc().getLayout() == Layout::CN;
+ auto is3D = input.second->getTensorDesc().getLayout() == Layout::CHW;
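+ // a CHW blob carries no batch dimension, so it is treated as a single frame when importing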
if (!inputsDesc->ptr_inputs_global_id.count(input.first)) {
// should not happen in user code, however it might happen if there is a non-executable-network-based integration of a GNAPlugin instance
}
}
- auto dims = input.second->getTensorDesc().getDims();
-
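+ // for 3D/4D blobs a single frame spans the three innermost dims (C * H * W)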
auto importedElements = is2D ? dims[dims.size() - 1] : dims[dims.size() - 1] * dims[dims.size() - 2] * dims[dims.size() - 3];
- auto importedFrames = dims[0];
+ auto importedFrames = is3D ? 1 : dims[0];
auto targetGroups = is2D ? dims[dims.size() - 2] : dims[0]; // TODO: no proper support for groups yet
auto importedElementSizeBytes = gnaFlags->sw_fp32 ? 4 : 2;
auto & outputBlob = outputBlobIt.second;
auto & outputDesc = outputsDesc[output_idx];
if (outputBlob->getTensorDesc().getLayout() == Layout::NC || outputBlob->getTensorDesc().getLayout() == Layout::CN
- || outputBlob->getTensorDesc().getLayout() == Layout::NCHW || outputBlob->getTensorDesc().getLayout() == Layout::NHWC) {
+ || outputBlob->getTensorDesc().getLayout() == Layout::NCHW || outputBlob->getTensorDesc().getLayout() == Layout::CHW) {
// TODO: rotate can be incorporated with exporting - used only in unit tests so far
// TODO: restore:
// if (orientation_out != kDnnInterleavedOrientation) {
// dims[dims.size() - 1]);
// }
auto is2D = outputBlob->getTensorDesc().getLayout() == Layout::NC || outputBlob->getTensorDesc().getLayout() == Layout::CN;
+ auto is3D = outputBlob->getTensorDesc().getLayout() == Layout::CHW;
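+ // a CHW output blob also has no batch dimension, so it is exported as a single batch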
auto& exportOutputDims = outputBlob->getTensorDesc().getDims();
- auto batchSize = exportOutputDims[0];
+ auto batchSize = is3D ? 1 : exportOutputDims[0];
auto elementsPerBatch = is2D ? exportOutputDims[exportOutputDims.size() - 1]
: exportOutputDims[exportOutputDims.size() - 1]
* exportOutputDims[exportOutputDims.size() - 2]
* exportOutputDims[exportOutputDims.size() - 3];
#endif
}
} else {
- THROW_GNA_EXCEPTION << "Expected output blob to have Layout::NC, Layout::CN, Layout::NCHW or Layout::NHWC. But was "
+ THROW_GNA_EXCEPTION << "Expected output blob to have Layout::NC, Layout::CN, Layout::NCHW or Layout::CHW. But was "
<< outputBlob->getTensorDesc().getLayout();
}
// need to have intermediate blob for interleave conversion
InferenceEngine::Blob::Ptr outputBlob;
auto outputDims = outputsDataMap[name]->getTensorDesc().getDims();
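+ // pick the blob layout from the tensor rank: 2 dims -> NC, 3 dims -> CHW, otherwise NCHW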
- outputBlob = make_blob_with_precision(TensorDesc(precision, outputDims, outputDims.size() == 2 ? NC : NCHW));
+ outputBlob = make_blob_with_precision(TensorDesc(precision, outputDims, outputDims.size() == 2 ? NC : (outputDims.size() == 3 ? CHW : NCHW)));
outputBlob->allocate();
return outputBlob;
}
// need to have intermediate blob for interleave conversion
// TODO: NCHW format support is experimental: the C++ MO inserts a reshape, while the TF MO does not
auto inputDims = inputsDataMap[name]->getTensorDesc().getDims();
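+ // same rank-based layout selection as for the output blobs above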
- inputBlob = make_blob_with_precision(TensorDesc(precision, inputDims, inputDims.size() == 2 ? NC : NCHW));
+ inputBlob = make_blob_with_precision(TensorDesc(precision, inputDims, inputDims.size() == 2 ? NC : (inputDims.size() == 3 ? CHW : NCHW)));
inputBlob->allocate();
return inputBlob;
}