MLCE-101 Add dilation parameter to serializer
authorMatthew Bentham <matthew.bentham@arm.com>
Mon, 13 May 2019 09:02:45 +0000 (10:02 +0100)
committerMatteo Martincigh <matteo.martincigh@arm.com>
Mon, 13 May 2019 16:41:09 +0000 (16:41 +0000)
Change-Id: I8142e179d38c7a2a9163cf3d30bd1f411e8e109c
Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
src/armnnDeserializer/Deserializer.cpp
src/armnnSerializer/ArmnnSchema.fbs
src/armnnSerializer/Serializer.cpp

index cbded60..8b790f7 100644 (file)
@@ -980,6 +980,8 @@ void Deserializer::ParseConvolution2d(GraphPtr graph, unsigned int layerIndex)
     descriptor.m_PadBottom = serializerDescriptor->padBottom();
     descriptor.m_StrideX = serializerDescriptor->strideX();
     descriptor.m_StrideY = serializerDescriptor->strideY();;
+    descriptor.m_DilationX = serializerDescriptor->dilationX();
+    descriptor.m_DilationY = serializerDescriptor->dilationY();;
     descriptor.m_BiasEnabled = serializerDescriptor->biasEnabled();;
     descriptor.m_DataLayout = ToDataLayout(serializerDescriptor->dataLayout());
 
index e8d72fc..0419c4b 100644 (file)
@@ -172,6 +172,8 @@ table Convolution2dDescriptor {
     padBottom:uint;
     strideX:uint;
     strideY:uint;
+    dilationX:uint = 1;
+    dilationY:uint = 1;
     biasEnabled:bool = false;
     dataLayout:DataLayout = NCHW;
 }
@@ -296,6 +298,8 @@ table DepthwiseConvolution2dDescriptor {
     padBottom:uint;
     strideX:uint;
     strideY:uint;
+    dilationX:uint = 1;
+    dilationY:uint = 1;
     biasEnabled:bool = false;
     dataLayout:DataLayout = NCHW;
 }
index 0b8ad06..865ed7a 100644 (file)
@@ -237,6 +237,8 @@ void SerializerVisitor::VisitConvolution2dLayer(const armnn::IConnectableLayer*
                                                               descriptor.m_PadBottom,
                                                               descriptor.m_StrideX,
                                                               descriptor.m_StrideY,
+                                                              descriptor.m_DilationX,
+                                                              descriptor.m_DilationY,
                                                               descriptor.m_BiasEnabled,
                                                               GetFlatBufferDataLayout(descriptor.m_DataLayout));
     auto flatBufferWeightsConstTensorInfo = CreateConstTensorInfo(weights);
@@ -272,6 +274,8 @@ void SerializerVisitor::VisitDepthwiseConvolution2dLayer(const armnn::IConnectab
                                                                descriptor.m_PadBottom,
                                                                descriptor.m_StrideX,
                                                                descriptor.m_StrideY,
+                                                               descriptor.m_DilationX,
+                                                               descriptor.m_DilationY,
                                                                descriptor.m_BiasEnabled,
                                                                GetFlatBufferDataLayout(descriptor.m_DataLayout));