From ce5045a00485f8a8c35814c0781ccbcca5678e5c Mon Sep 17 00:00:00 2001
From: Kevin May
Date: Wed, 2 Oct 2019 14:07:47 +0100
Subject: [PATCH] IVGCVSW-3932 Add frontend for INSTANCE_NORMALIZATION

Signed-off-by: Kevin May
Change-Id: Ib152148ccd8d2733c617d0cf9402661fc6b71316
---
 Android.mk                                            |  1 +
 CMakeLists.txt                                        |  2 +
 include/armnn/Descriptors.hpp                         | 20 +++++++++
 include/armnn/DescriptorsFwd.hpp                      |  1 +
 include/armnn/ILayerSupport.hpp                       |  6 +++
 include/armnn/ILayerVisitor.hpp                       |  8 ++++
 include/armnn/INetwork.hpp                            |  7 +++
 include/armnn/LayerVisitorBase.hpp                    |  4 ++
 src/armnn/InternalTypes.cpp                           |  1 +
 src/armnn/InternalTypes.hpp                           |  1 +
 src/armnn/LayersFwd.hpp                               |  2 +
 src/armnn/Network.cpp                                 |  6 +++
 src/armnn/Network.hpp                                 |  3 ++
 src/armnn/layers/InstanceNormalizationLayer.cpp       | 52 ++++++++++++++++++++++
 src/armnn/layers/InstanceNormalizationLayer.hpp       | 43 ++++++++++++++++++
 .../test/TestNameAndDescriptorLayerVisitor.cpp        | 23 ++++++++++
 .../test/TestNameAndDescriptorLayerVisitor.hpp        | 34 ++++++++++++++
 src/armnnSerializer/Serializer.cpp                    |  8 ++++
 src/armnnSerializer/Serializer.hpp                    |  4 ++
 src/backends/backendsCommon/LayerSupportBase.cpp      |  8 ++++
 src/backends/backendsCommon/LayerSupportBase.hpp      |  6 +++
 src/backends/backendsCommon/WorkloadData.cpp          | 46 +++++++++++++++++++
 src/backends/backendsCommon/WorkloadData.hpp          | 15 +++++++
 src/backends/backendsCommon/WorkloadFactory.cpp       | 22 +++++++++
 src/backends/backendsCommon/WorkloadFactory.hpp       |  4 ++
 .../test/IsLayerSupportedTestImpl.hpp                 |  2 +
 26 files changed, 329 insertions(+)
 create mode 100644 src/armnn/layers/InstanceNormalizationLayer.cpp
 create mode 100644 src/armnn/layers/InstanceNormalizationLayer.hpp

diff --git a/Android.mk b/Android.mk
index 4c3789c..6bf9a50 100644
--- a/Android.mk
+++ b/Android.mk
@@ -138,6 +138,7 @@ LOCAL_SRC_FILES := \
         src/armnn/layers/GatherLayer.cpp \
         src/armnn/layers/GreaterLayer.cpp \
         src/armnn/layers/InputLayer.cpp \
+        src/armnn/layers/InstanceNormalizationLayer.cpp \
         src/armnn/layers/L2NormalizationLayer.cpp \
         src/armnn/layers/LstmLayer.cpp \
         src/armnn/layers/MaximumLayer.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6ae352d..94da6bf 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -280,6 +280,8 @@ list(APPEND armnn_sources
     src/armnn/layers/GreaterLayer.hpp
     src/armnn/layers/InputLayer.hpp
     src/armnn/layers/InputLayer.cpp
+    src/armnn/layers/InstanceNormalizationLayer.hpp
+    src/armnn/layers/InstanceNormalizationLayer.cpp
     src/armnn/layers/L2NormalizationLayer.hpp
     src/armnn/layers/L2NormalizationLayer.cpp
     src/armnn/layers/LstmLayer.cpp
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index c973089..5bf4043 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -468,6 +468,26 @@ struct BatchNormalizationDescriptor
     DataLayout m_DataLayout;
 };
 
+/// An InstanceNormalizationDescriptor for InstanceNormalizationLayer
+struct InstanceNormalizationDescriptor
+{
+    InstanceNormalizationDescriptor()
+        : m_Gamma(1.0f)
+        , m_Beta(0.0f)
+        , m_Eps(1e-12f)
+        , m_DataLayout(DataLayout::NCHW)
+    {}
+
+    /// Gamma, the scale scalar value applied for the normalized tensor. Defaults to 1.0.
+    float m_Gamma;
+    /// Beta, the offset scalar value applied for the normalized tensor. Defaults to 0.0.
+    float m_Beta;
+    /// Epsilon, small scalar value added to variance to avoid dividing by zero. Defaults to 1e-12f.
+    float m_Eps;
+    /// The data layout to be used (NCHW, NHWC).
+    DataLayout m_DataLayout;
+};
+
 /// A BatchToSpaceNdDescriptor for the BatchToSpaceNdLayer.
 struct BatchToSpaceNdDescriptor
 {
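Note (reference only, not part of the patch): the three scalars above configure the standard instance normalization transform. With mu[n,c] and var[n,c] the mean and variance taken over the spatial dimensions (H, W) of sample n, channel c, the layer is expected to compute

    y[n,c,h,w] = m_Gamma * (x[n,c,h,w] - mu[n,c]) / sqrt(var[n,c] + m_Eps) + m_Beta

The same per-sample, per-channel reduction applies for NHWC; only the memory layout differs. With the defaults (m_Gamma = 1, m_Beta = 0) the layer simply whitens each channel of each sample.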
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index e9624f1..2cc9582 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -16,6 +16,7 @@ struct DepthwiseConvolution2dDescriptor;
 struct DetectionPostProcessDescriptor;
 struct FakeQuantizationDescriptor;
 struct FullyConnectedDescriptor;
+struct InstanceNormalizationDescriptor;
 struct L2NormalizationDescriptor;
 struct LstmDescriptor;
 struct MeanDescriptor;
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index e18b86a..fef7595 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -157,6 +157,12 @@ public:
     virtual bool IsInputSupported(const TensorInfo& input,
                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
 
+    virtual bool IsInstanceNormalizationSupported(
+        const TensorInfo& input,
+        const TensorInfo& output,
+        const InstanceNormalizationDescriptor& descriptor,
+        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
     virtual bool IsL2NormalizationSupported(const TensorInfo& input,
                                             const TensorInfo& output,
                                             const L2NormalizationDescriptor& descriptor,
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index 486a13f..b9c96d5 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -206,6 +206,14 @@ public:
                                  LayerBindingId id,
                                  const char* name = nullptr) = 0;
 
+    /// Function that an instance normalization layer should call back to when its Accept(ILayerVisitor&)
+    /// function is invoked.
+    /// @param layer - pointer to the layer which is calling back to this visit function.
+    /// @param desc - Parameters for the instance normalization operation.
+    /// @param name - Optional name for the layer.
+    virtual void VisitInstanceNormalizationLayer(const IConnectableLayer* layer,
+                                                 const InstanceNormalizationDescriptor& desc,
+                                                 const char* name = nullptr) = 0;
 
     /// Function that an L2 normalization layer should call back to when its Accept(ILayerVisitor&)
     /// function is invoked. Normalization is performed along dimension 1, but requires a 4d input.
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 0e0b99a..dc831db 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -329,6 +329,13 @@ public:
     virtual IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
                                               const char* name = nullptr) = 0;
 
+    /// Adds an instance normalization layer to the network.
+    /// @param desc - Parameters for the instance normalization operation.
+    /// @param name - Optional name for the layer.
+    /// @return - Interface for configuring the layer.
+    virtual IConnectableLayer* AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc,
+                                                             const char* name = nullptr) = 0;
+
     /// Adds an L2 normalization layer to the network.
     /// Normalization is performed along dimension 1, but requires a 4d input.
     /// @param desc - Parameters for the L2 normalization operation.
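Note (illustrative only, not part of the patch): once a backend provides the workload, a caller would drive the new frontend entry point like any other layer. The shapes and names below are invented for the example.

    #include <armnn/ArmNN.hpp>

    void BuildExampleNetwork()
    {
        using namespace armnn;

        INetworkPtr network = INetwork::Create();

        IConnectableLayer* input = network->AddInputLayer(0, "input");

        InstanceNormalizationDescriptor desc;
        desc.m_Gamma      = 1.0f;
        desc.m_Beta       = 0.0f;
        desc.m_Eps        = 1e-5f;
        desc.m_DataLayout = DataLayout::NHWC;

        IConnectableLayer* instanceNorm = network->AddInstanceNormalizationLayer(desc, "instancenorm");
        IConnectableLayer* output       = network->AddOutputLayer(0, "output");

        input->GetOutputSlot(0).Connect(instanceNorm->GetInputSlot(0));
        instanceNorm->GetOutputSlot(0).Connect(output->GetInputSlot(0));

        // The layer is shape-preserving, so input and output use the same TensorInfo.
        TensorInfo info(TensorShape({ 1, 16, 16, 3 }), DataType::Float32);
        input->GetOutputSlot(0).SetTensorInfo(info);
        instanceNorm->GetOutputSlot(0).SetTensorInfo(info);
    }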
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index 65d2303..719e59d 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -112,6 +112,10 @@ public:
                          LayerBindingId,
                          const char*) override { DefaultPolicy::Apply(__func__); }
 
+    void VisitInstanceNormalizationLayer(const IConnectableLayer*,
+                                         const InstanceNormalizationDescriptor&,
+                                         const char*) override { DefaultPolicy::Apply(__func__); }
+
     void VisitL2NormalizationLayer(const IConnectableLayer*,
                                    const L2NormalizationDescriptor&,
                                    const char*) override { DefaultPolicy::Apply(__func__); }
diff --git a/src/armnn/InternalTypes.cpp b/src/armnn/InternalTypes.cpp
index e6f7367..612d00b 100644
--- a/src/armnn/InternalTypes.cpp
+++ b/src/armnn/InternalTypes.cpp
@@ -38,6 +38,7 @@ char const* GetLayerTypeAsCString(LayerType type)
         case LayerType::Gather: return "Gather";
         case LayerType::Greater: return "Greater";
         case LayerType::Input: return "Input";
+        case LayerType::InstanceNormalization: return "InstanceNormalization";
         case LayerType::L2Normalization: return "L2Normalization";
         case LayerType::Lstm: return "Lstm";
         case LayerType::Maximum: return "Maximum";
diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp
index fbca9bc..039d0f8 100644
--- a/src/armnn/InternalTypes.hpp
+++ b/src/armnn/InternalTypes.hpp
@@ -38,6 +38,7 @@ enum class LayerType
     Gather,
     Greater,
     Input,
+    InstanceNormalization,
     L2Normalization,
     Lstm,
     Maximum,
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index 3599eac..1f539f3 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -30,6 +30,7 @@
 #include "layers/GatherLayer.hpp"
 #include "layers/GreaterLayer.hpp"
 #include "layers/InputLayer.hpp"
+#include "layers/InstanceNormalizationLayer.hpp"
 #include "layers/L2NormalizationLayer.hpp"
 #include "layers/LstmLayer.hpp"
 #include "layers/MaximumLayer.hpp"
@@ -113,6 +114,7 @@ DECLARE_LAYER(FullyConnected)
 DECLARE_LAYER(Gather)
 DECLARE_LAYER(Greater)
 DECLARE_LAYER(Input)
+DECLARE_LAYER(InstanceNormalization)
 DECLARE_LAYER(L2Normalization)
 DECLARE_LAYER(Lstm)
 DECLARE_LAYER(Maximum)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index cf9a138..9d10b9a 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1224,6 +1224,12 @@ resizeDescriptor, const char* name)
     return m_Graph->AddLayer<ResizeLayer>(resizeDescriptor, name);
 }
 
+IConnectableLayer* Network::AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc,
+                                                          const char* name)
+{
+    return m_Graph->AddLayer<InstanceNormalizationLayer>(desc, name);
+}
+
 IConnectableLayer* Network::AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
                                                     const char* name)
 {
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index 4a8bfbc..e11f3d2 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -152,6 +152,9 @@ public:
     IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
                                       const char* name = nullptr) override;
 
+    IConnectableLayer* AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc,
+                                                     const char* name = nullptr) override;
+
     IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
                                                const char* name = nullptr) override;
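Note (illustrative only): because LayerVisitorBase forwards the new callback to its default policy, visitors derived from LayerVisitorBase keep compiling unchanged; a visitor only overrides VisitInstanceNormalizationLayer if it cares about the layer. A minimal sketch of such a visitor, assuming the VisitorNoThrowPolicy default policy; the class name is invented:

    #include <armnn/LayerVisitorBase.hpp>
    #include <iostream>

    // Logs the parameters of every instance normalization layer it visits.
    // All other layer types fall through to VisitorNoThrowPolicy and are ignored.
    class InstanceNormInspector : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
    {
    public:
        void VisitInstanceNormalizationLayer(const armnn::IConnectableLayer* /* layer */,
                                             const armnn::InstanceNormalizationDescriptor& desc,
                                             const char* name = nullptr) override
        {
            std::cout << (name ? name : "<unnamed>")
                      << ": gamma=" << desc.m_Gamma
                      << " beta="   << desc.m_Beta
                      << " eps="    << desc.m_Eps << std::endl;
        }
    };

Such a visitor would typically be handed to INetwork::Accept() after the network has been built.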
diff --git a/src/armnn/layers/InstanceNormalizationLayer.cpp b/src/armnn/layers/InstanceNormalizationLayer.cpp
new file mode 100644
index 0000000..fc3044a
--- /dev/null
+++ b/src/armnn/layers/InstanceNormalizationLayer.cpp
@@ -0,0 +1,52 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "InstanceNormalizationLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <armnn/TypesUtils.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+InstanceNormalizationLayer::InstanceNormalizationLayer(const InstanceNormalizationDescriptor& param, const char* name)
+    : LayerWithParameters(1, 1, LayerType::InstanceNormalization, param, name)
+{
+}
+
+std::unique_ptr<IWorkload> InstanceNormalizationLayer::CreateWorkload(const Graph& graph,
+                                                                      const IWorkloadFactory& factory) const
+{
+    InstanceNormalizationQueueDescriptor descriptor;
+    return factory.CreateInstanceNormalization(descriptor, PrepInfoAndDesc(descriptor, graph));
+}
+
+InstanceNormalizationLayer* InstanceNormalizationLayer::Clone(Graph& graph) const
+{
+    return CloneBase<InstanceNormalizationLayer>(graph, m_Param, GetName());
+}
+
+void InstanceNormalizationLayer::ValidateTensorShapesFromInputs()
+{
+    VerifyLayerConnections(1, CHECK_LOCATION());
+
+    auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
+
+    BOOST_ASSERT(inferredShapes.size() == 1);
+
+    ConditionalThrowIfNotEqual<LayerValidationException>(
+        "InstanceNormalizationLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+        GetOutputSlot(0).GetTensorInfo().GetShape(),
+        inferredShapes[0]);
+}
+
+void InstanceNormalizationLayer::Accept(ILayerVisitor& visitor) const
+{
+    visitor.VisitInstanceNormalizationLayer(this, GetParameters(), GetName());
+}
+
+} // namespace armnn
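Note (reference only, not part of the patch): no backend implements CreateInstanceNormalization() within this change, so the workload created above is empty for every backend. For reviewers, a naive NCHW Float32 sketch of the computation a backend workload is eventually expected to perform (plain C++, no ArmNN types):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    void InstanceNormalizationReference(const std::vector<float>& input,
                                        std::vector<float>& output,
                                        std::size_t batch, std::size_t channels,
                                        std::size_t height, std::size_t width,
                                        float gamma, float beta, float eps)
    {
        const std::size_t spatial = height * width;
        output.resize(input.size());

        for (std::size_t n = 0; n < batch; ++n)
        {
            for (std::size_t c = 0; c < channels; ++c)
            {
                const std::size_t base = (n * channels + c) * spatial;

                // Mean over the spatial dimensions of this (sample, channel) pair.
                float mean = 0.0f;
                for (std::size_t i = 0; i < spatial; ++i) { mean += input[base + i]; }
                mean /= static_cast<float>(spatial);

                // Population variance over the same elements.
                float variance = 0.0f;
                for (std::size_t i = 0; i < spatial; ++i)
                {
                    const float diff = input[base + i] - mean;
                    variance += diff * diff;
                }
                variance /= static_cast<float>(spatial);

                // Normalize, then apply the gamma scale and beta offset from the descriptor.
                const float scale = gamma / std::sqrt(variance + eps);
                for (std::size_t i = 0; i < spatial; ++i)
                {
                    output[base + i] = scale * (input[base + i] - mean) + beta;
                }
            }
        }
    }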
diff --git a/src/armnn/layers/InstanceNormalizationLayer.hpp b/src/armnn/layers/InstanceNormalizationLayer.hpp
new file mode 100644
index 0000000..9ba5673
--- /dev/null
+++ b/src/armnn/layers/InstanceNormalizationLayer.hpp
@@ -0,0 +1,43 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "LayerWithParameters.hpp"
+
+namespace armnn
+{
+
+/// This layer represents an instance normalization operation.
+class InstanceNormalizationLayer : public LayerWithParameters<InstanceNormalizationDescriptor>
+{
+public:
+    /// Makes a workload for the InstanceNormalization type.
+    /// @param [in] graph The graph where this layer can be found.
+    /// @param [in] factory The workload factory which will create the workload.
+    /// @return A pointer to the created workload, or nullptr if not created.
+    virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
+                                                      const IWorkloadFactory& factory) const override;
+
+    /// Creates a dynamically-allocated copy of this layer.
+    /// @param [in] graph The graph into which this layer is being cloned.
+    InstanceNormalizationLayer* Clone(Graph& graph) const override;
+
+    /// Check if the input tensor shape(s)
+    /// will lead to a valid configuration of @ref InstanceNormalizationLayer.
+    void ValidateTensorShapesFromInputs() override;
+
+    void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+    /// Constructor to create an InstanceNormalizationLayer.
+    /// @param [in] param InstanceNormalizationDescriptor to configure the instance normalization operation.
+    /// @param [in] name Optional name for the layer.
+    InstanceNormalizationLayer(const InstanceNormalizationDescriptor& param, const char* name);
+
+    /// Default destructor
+    ~InstanceNormalizationLayer() = default;
+};
+
+} // namespace
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
index 653612f..dcc5dc4 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
@@ -282,6 +282,29 @@ BOOST_AUTO_TEST_CASE(CheckResizeLayerVisitorNameNullAndDescriptor)
     layer->Accept(visitor);
 }
 
+BOOST_AUTO_TEST_CASE(CheckInstanceNormalizationLayerVisitorNameAndDescriptor)
+{
+    const char* layerName = "InstanceNormalizationLayer";
+    InstanceNormalizationDescriptor descriptor;
+    descriptor.m_DataLayout = DataLayout::NHWC;
+    TestInstanceNormalizationLayerVisitor visitor(descriptor, layerName);
+    Network net;
+
+    IConnectableLayer *const layer = net.AddInstanceNormalizationLayer(descriptor, layerName);
+    layer->Accept(visitor);
+}
+
+BOOST_AUTO_TEST_CASE(CheckInstanceNormalizationLayerVisitorNameNullAndDescriptor)
+{
+    InstanceNormalizationDescriptor descriptor;
+    descriptor.m_DataLayout = DataLayout::NHWC;
+    TestInstanceNormalizationLayerVisitor visitor(descriptor);
+    Network net;
+
+    IConnectableLayer *const layer = net.AddInstanceNormalizationLayer(descriptor);
+    layer->Accept(visitor);
+}
+
 BOOST_AUTO_TEST_CASE(CheckL2NormalizationLayerVisitorNameAndDescriptor)
 {
     const char* layerName = "L2NormalizationLayer";
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
index f1936d6..aa0b359 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
@@ -418,6 +418,40 @@ public:
     };
 };
 
+class TestInstanceNormalizationLayerVisitor : public TestLayerVisitor
+{
+private:
+    InstanceNormalizationDescriptor m_VisitorDescriptor;
+
+public:
+    explicit TestInstanceNormalizationLayerVisitor(const InstanceNormalizationDescriptor& desc,
+                                                   const char* name = nullptr)
+        : TestLayerVisitor(name)
+    {
+        m_VisitorDescriptor.m_Beta       = desc.m_Beta;
+        m_VisitorDescriptor.m_Gamma      = desc.m_Gamma;
+        m_VisitorDescriptor.m_Eps        = desc.m_Eps;
+        m_VisitorDescriptor.m_DataLayout = desc.m_DataLayout;
+    };
+
+    void CheckDescriptor(const InstanceNormalizationDescriptor& desc)
+    {
+        BOOST_CHECK(desc.m_Beta       == m_VisitorDescriptor.m_Beta);
+        BOOST_CHECK(desc.m_Gamma      == m_VisitorDescriptor.m_Gamma);
+        BOOST_CHECK(desc.m_Eps        == m_VisitorDescriptor.m_Eps);
+        BOOST_CHECK(desc.m_DataLayout == m_VisitorDescriptor.m_DataLayout);
+    }
+
+    void VisitInstanceNormalizationLayer(const IConnectableLayer* layer,
+                                         const InstanceNormalizationDescriptor& desc,
+                                         const char* name = nullptr) override
+    {
+        CheckLayerPointer(layer);
+        CheckDescriptor(desc);
+        CheckLayerName(name);
+    };
+};
+
 class TestL2NormalizationLayerVisitor : public TestLayerVisitor
 {
 private:
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 99ba7e3..84a1b6b 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -426,6 +426,14 @@ void SerializerVisitor::VisitGreaterLayer(const armnn::IConnectableLayer* layer,
     CreateAnyLayer(fbGreaterLayer.o, serializer::Layer::Layer_GreaterLayer);
 }
 
+void SerializerVisitor::VisitInstanceNormalizationLayer(
+    const armnn::IConnectableLayer* layer,
+    const armnn::InstanceNormalizationDescriptor& instanceNormalizationDescriptor,
+    const char* name)
+{
+    throw UnimplementedException("SerializerVisitor::InstanceNormalizationLayer is not implemented");
+}
+
 void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
                                                   const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
                                                   const char* name)
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 429487d..f98bd17 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -125,6 +125,10 @@ public:
                          armnn::LayerBindingId id,
                          const char* name = nullptr) override;
 
+    void VisitInstanceNormalizationLayer(const armnn::IConnectableLayer* layer,
+                                         const armnn::InstanceNormalizationDescriptor& instanceNormalizationDescriptor,
+                                         const char* name = nullptr) override;
+
     void VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
                                    const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
                                    const char* name = nullptr) override;
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 656407d..c41f0b1 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -234,6 +234,14 @@ bool LayerSupportBase::IsInputSupported(const TensorInfo& input,
     return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
 }
 
+bool LayerSupportBase::IsInstanceNormalizationSupported(const TensorInfo& input,
+                                                        const TensorInfo& output,
+                                                        const InstanceNormalizationDescriptor& descriptor,
+                                                        Optional<std::string&> reasonIfUnsupported) const
+{
+    return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
 bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   const L2NormalizationDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index c3875e6..495870e 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -141,6 +141,12 @@ public:
     bool IsInputSupported(const TensorInfo& input,
                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsInstanceNormalizationSupported(
+        const TensorInfo& input,
+        const TensorInfo& output,
+        const InstanceNormalizationDescriptor& descriptor,
+        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     bool IsL2NormalizationSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const L2NormalizationDescriptor& descriptor,
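Note (illustrative only): LayerSupportBase answers "not supported" for the new query by default, so no backend changes behaviour until it opts in. A backend claiming support would override the new method along these lines; the class name and the exact checks are invented for the sketch:

    #include <backendsCommon/LayerSupportBase.hpp>
    #include <armnn/Descriptors.hpp>
    #include <string>

    class MyBackendLayerSupport : public armnn::LayerSupportBase
    {
    public:
        bool IsInstanceNormalizationSupported(
            const armnn::TensorInfo& input,
            const armnn::TensorInfo& output,
            const armnn::InstanceNormalizationDescriptor& descriptor,
            armnn::Optional<std::string&> reasonIfUnsupported = armnn::EmptyOptional()) const override
        {
            // Mirror the checks in InstanceNormalizationQueueDescriptor::Validate:
            // floating-point types only, matching input/output shapes and types.
            if (input.GetDataType() != armnn::DataType::Float32 &&
                input.GetDataType() != armnn::DataType::Float16)
            {
                if (reasonIfUnsupported.has_value())
                {
                    reasonIfUnsupported.value() = "Only Float32 and Float16 are supported";
                }
                return false;
            }

            if (input.GetShape() != output.GetShape() || input.GetDataType() != output.GetDataType())
            {
                if (reasonIfUnsupported.has_value())
                {
                    reasonIfUnsupported.value() = "Input and output must match in shape and data type";
                }
                return false;
            }

            // Both layouts expressible in the descriptor are acceptable here.
            return descriptor.m_DataLayout == armnn::DataLayout::NCHW ||
                   descriptor.m_DataLayout == armnn::DataLayout::NHWC;
        }
    };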
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index e49fd09..aca5023 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1233,6 +1233,52 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
     }
 }
 
+void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+    const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
+
+    ValidateNumInputs(workloadInfo, descriptorName, 1);
+    ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+    const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
+    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+    if (inputTensorInfo.GetNumDimensions() > 4)
+    {
+        throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
+    }
+
+    ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo,
+                              descriptorName, "input", "output");
+
+    // Check the supported data types
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::Float32,
+        DataType::Float16
+    };
+
+    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
+    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+
+    ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+    ValidatePointer(m_Beta, descriptorName, "beta");
+    ValidatePointer(m_Eps, descriptorName, "epsilon");
+    ValidatePointer(m_Gamma, descriptorName, "gamma");
+
+    const TensorInfo& beta    = m_Beta->GetTensorInfo();
+    const TensorInfo& epsilon = m_Eps->GetTensorInfo();
+    const TensorInfo& gamma   = m_Gamma->GetTensorInfo();
+
+    ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
+    ValidateTensorNumDimensions(epsilon, descriptorName, 1, "epsilon");
+    ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
+
+    ValidateTensorDataTypesMatch(inputTensorInfo, beta, descriptorName, "input", "beta");
+    ValidateTensorDataTypesMatch(inputTensorInfo, epsilon, descriptorName, "input", "epsilon");
+    ValidateTensorDataTypesMatch(inputTensorInfo, gamma, descriptorName, "input", "gamma");
+}
+
 void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
     const std::string descriptorName{"L2NormalizationQueueDescriptor"};
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 177bfb7..14d7b58 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -307,6 +307,21 @@ struct FakeQuantizationQueueDescriptor : QueueDescriptorWithParameters<FakeQuantizationDescriptor>
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
+struct InstanceNormalizationQueueDescriptor : QueueDescriptorWithParameters<InstanceNormalizationDescriptor>
+{
+    InstanceNormalizationQueueDescriptor()
+        : m_Beta(nullptr)
+        , m_Eps(nullptr)
+        , m_Gamma(nullptr)
+    {
+    }
+
+    const ConstCpuTensorHandle* m_Beta;
+    const ConstCpuTensorHandle* m_Eps;
+    const ConstCpuTensorHandle* m_Gamma;
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
 struct L2NormalizationQueueDescriptor : QueueDescriptorWithParameters<L2NormalizationDescriptor>
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 44888b3..98fe158 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -371,6 +371,21 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
             result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
             break;
         }
+        case LayerType::InstanceNormalization:
+        {
+            auto cLayer = boost::polymorphic_downcast<const InstanceNormalizationLayer*>(&layer);
+            const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
+
+            const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+
+            result = layerSupportObject->IsInstanceNormalizationSupported(
+                OverrideDataType(input, dataType),
+                OverrideDataType(output, dataType),
+                descriptor,
+                reason);
+            break;
+        }
         case LayerType::L2Normalization:
         {
             auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
@@ -1139,6 +1154,13 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDes
     return std::unique_ptr<IWorkload>();
 }
 
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
+    const InstanceNormalizationQueueDescriptor& descriptor,
+    const WorkloadInfo& info) const
+{
+    return std::unique_ptr<IWorkload>();
+}
+
 std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
                                                                    const WorkloadInfo& info) const
 {
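Note (illustrative only): the base factory method added above returns an empty pointer, which is what keeps backends without support building and running. A backend adding real support would override CreateInstanceNormalization() in its own factory and return a workload shaped roughly like this; the class name is invented, and the normalization itself (sketched earlier) would live in Execute():

    #include <backendsCommon/Workload.hpp>
    #include <backendsCommon/WorkloadData.hpp>

    class MyInstanceNormalizationWorkload
        : public armnn::BaseWorkload<armnn::InstanceNormalizationQueueDescriptor>
    {
    public:
        MyInstanceNormalizationWorkload(const armnn::InstanceNormalizationQueueDescriptor& descriptor,
                                        const armnn::WorkloadInfo& info)
            : armnn::BaseWorkload<armnn::InstanceNormalizationQueueDescriptor>(descriptor, info)
        {
        }

        void Execute() const override
        {
            // m_Data.m_Parameters carries the gamma/beta/eps/data layout from the descriptor,
            // and m_Data.m_Inputs[0] / m_Data.m_Outputs[0] are the tensor handles to read and write.
        }
    };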
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 2809e2f..9fa0221 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -120,6 +120,10 @@ public:
     virtual std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
                                                      const WorkloadInfo& info) const;
 
+    virtual std::unique_ptr<IWorkload> CreateInstanceNormalization(
+        const InstanceNormalizationQueueDescriptor& descriptor,
+        const WorkloadInfo& info) const;
+
     virtual std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const;
 
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index e492cd6..c860414 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -435,6 +435,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Greater)
 
 DECLARE_LAYER_POLICY_CUSTOM_PARAM(Input, armnn::LayerBindingId)
 
+DECLARE_LAYER_POLICY_2_PARAM(InstanceNormalization)
+
 DECLARE_LAYER_POLICY_2_PARAM(L2Normalization)
 
 DECLARE_LAYER_POLICY_2_PARAM(Lstm)
-- 
2.7.4