From 83add2165b680b3cf38403a7ce90ea86febd4cc7 Mon Sep 17 00:00:00 2001 From: Kevin May Date: Tue, 26 Mar 2019 11:39:19 +0000 Subject: [PATCH] MLCE-101 Deeplab v3+ (Add Tf Lite Parser Dilation Check) * Add Parse Exception for convolutions without default dilation Signed-off-by: Kevin May Change-Id: I1b8f75c2d871d81161eb5378ced277438e809ba2 --- src/armnnTfLiteParser/TfLiteParser.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 31e808f..b9a3522 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -226,6 +226,24 @@ void CheckBufferSize(TfLiteParser::BufferRawPtr bufferPtr, #define CHECK_BUFFER_SIZE(BUFFER_PTR, TENSOR_INFO, BUFFER_ID) \ CheckBufferSize(BUFFER_PTR, TENSOR_INFO, BUFFER_ID, CHECK_LOCATION()) +uint32_t CheckDilation(const int32_t dilationFactor, + size_t operatorIndex, + const CheckLocation& location) +{ + if (dilationFactor != 1) + { + std::stringstream ss; + ss << "ArmNN only supports convolution layers with dilations [1,1,1,1] for operator with index " + << operatorIndex << location.AsString(); + throw ParseException(ss.str()); + } + + return static_cast(dilationFactor); +} + +#define CHECK_DILATION(DILATION_FACTOR, OPERATOR_INDEX) \ + CheckDilation(DILATION_FACTOR, OPERATOR_INDEX, CHECK_LOCATION()) + bool IsActivationSupported(tflite::ActivationFunctionType activationType) { switch(activationType) @@ -694,6 +712,9 @@ void TfLiteParser::ParseConv2D(size_t subgraphIndex, size_t operatorIndex) desc.m_StrideY = CHECKED_NON_NEGATIVE(options->stride_h); desc.m_DataLayout = armnn::DataLayout::NHWC; + CHECK_DILATION(options->dilation_h_factor, operatorIndex); + CHECK_DILATION(options->dilation_w_factor, operatorIndex); + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(inputs.size(), 2, 3); @@ -779,6 +800,9 @@ void TfLiteParser::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorInd auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); + CHECK_DILATION(options->dilation_h_factor, operatorIndex); + CHECK_DILATION(options->dilation_w_factor, operatorIndex); + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]); -- 2.7.4