From: Sadik Armagan Date: Fri, 5 Apr 2019 14:25:46 +0000 (+0100) Subject: IVGCVSW-2914 Add Switch Layer and no-op factory method X-Git-Tag: submit/tizen/20200316.035456~727 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=eff363d58992fb6384053259f9e1ee773f8cd4df;p=platform%2Fupstream%2Farmnn.git IVGCVSW-2914 Add Switch Layer and no-op factory method Change-Id: I6a6ece708a49e8a97c83a3e7fec11c88af1e1cfa Signed-off-by: Sadik Armagan Signed-off-by: Aron Virginas-Tar --- diff --git a/Android.mk b/Android.mk index 6d5a0fa..cd26fa5 100644 --- a/Android.mk +++ b/Android.mk @@ -127,6 +127,7 @@ LOCAL_SRC_FILES := \ src/armnn/layers/SplitterLayer.cpp \ src/armnn/layers/StridedSliceLayer.cpp \ src/armnn/layers/SubtractionLayer.cpp \ + src/armnn/layers/SwitchLayer.cpp \ src/armnn/Descriptors.cpp \ src/armnn/Exceptions.cpp \ src/armnn/Graph.cpp \ diff --git a/CMakeLists.txt b/CMakeLists.txt index d1fe635..b297423 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -279,6 +279,8 @@ list(APPEND armnn_sources src/armnn/layers/StridedSliceLayer.hpp src/armnn/layers/SubtractionLayer.cpp src/armnn/layers/SubtractionLayer.hpp + src/armnn/layers/SwitchLayer.cpp + src/armnn/layers/SwitchLayer.hpp src/armnn/BackendSettings.hpp src/armnn/CompatibleTypes.hpp src/armnn/Descriptors.cpp diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index 1b75810..dc84302 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -257,6 +257,12 @@ public: const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const = 0; + + virtual bool IsSwitchSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output0, + const TensorInfo& output1, + Optional reasonIfUnsupported = EmptyOptional()) const = 0; }; // class ILayerSupport using ILayerSupportSharedPtr = std::shared_ptr; diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index 3a4c39b..eabad58 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -341,6 +341,12 @@ public: virtual void VisitSubtractionLayer(const IConnectableLayer* layer, const char* name = nullptr) = 0; + /// Function a switch layer should call back to when its Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param name - Optional name for the layer. + virtual void VisitSwitchLayer(const IConnectableLayer* layer, + const char* name = nullptr) = 0; + }; } // namespace armnn diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 8243b39..a15ceb1 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -399,6 +399,11 @@ public: /// @ return - Interface for configuring the layer. virtual IConnectableLayer* AddGatherLayer(const char* name = nullptr) = 0; + /// Adds a switch layer to the network. + /// @param name - Optional name for the layer. + /// @return - Interface for configuring the layer. 
+ virtual IConnectableLayer* AddSwitchLayer(const char* name = nullptr) = 0; + virtual void Accept(ILayerVisitor& visitor) const = 0; protected: diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp index e23fdd0..c9fc264 100644 --- a/include/armnn/LayerSupport.hpp +++ b/include/armnn/LayerSupport.hpp @@ -338,4 +338,13 @@ bool IsSubtractionSupported(const BackendId& backend, const TensorInfo& output, char* reasonIfUnsupported = nullptr, size_t reasonIfUnsupportedMaxLength = 1024); + +/// Deprecated in favor of IBackend and ILayerSupport interfaces +bool IsSwitchSupported(const BackendId& backend, + const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output0, + const TensorInfo& output1, + char* reasonIfUnsupported = nullptr, + size_t reasonIfUnsupportedMaxLength = 1024); } diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp index f4e0f43..12eb225 100644 --- a/include/armnn/LayerVisitorBase.hpp +++ b/include/armnn/LayerVisitorBase.hpp @@ -178,6 +178,9 @@ public: void VisitGatherLayer(const IConnectableLayer*, const char*) override { DefaultPolicy::Apply(); } + + void VisitSwitchLayer(const IConnectableLayer*, + const char*) override { DefaultPolicy::Apply(); } }; } //namespace armnn diff --git a/src/armnn/InternalTypes.cpp b/src/armnn/InternalTypes.cpp index 93a4f94..a811706 100644 --- a/src/armnn/InternalTypes.cpp +++ b/src/armnn/InternalTypes.cpp @@ -57,6 +57,7 @@ char const* GetLayerTypeAsCString(LayerType type) case LayerType::Splitter: return "Splitter"; case LayerType::StridedSlice: return "StridedSlice"; case LayerType::Subtraction: return "Subtraction"; + case LayerType::Switch: return "Switch"; default: BOOST_ASSERT_MSG(false, "Unknown layer type"); return "Unknown"; diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp index 7c7c601..5765b5b 100644 --- a/src/armnn/InternalTypes.hpp +++ b/src/armnn/InternalTypes.hpp @@ -57,9 +57,10 @@ enum class LayerType SpaceToBatchNd, Splitter, StridedSlice, + Subtraction, // Last layer goes here. 
LastLayer, - Subtraction = LastLayer + Switch = LastLayer }; const char* GetLayerTypeAsCString(LayerType type); diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp index bc6eec8..320d9ce 100644 --- a/src/armnn/LayerSupport.cpp +++ b/src/armnn/LayerSupport.cpp @@ -530,4 +530,15 @@ bool IsSubtractionSupported(const BackendId& backend, FORWARD_LAYER_SUPPORT_FUNC(backend, IsSubtractionSupported, input0, input1, output); } +bool IsSwitchSupported(const BackendId& backend, + const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output0, + const TensorInfo& output1, + char* reasonIfUnsupported, + size_t reasonIfUnsupportedMaxLength) +{ + FORWARD_LAYER_SUPPORT_FUNC(backend, IsSwitchSupported, input0, input1, output0, output1); +} + } // namespace armnn diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp index 0bd68e0..31cfa66 100644 --- a/src/armnn/LayersFwd.hpp +++ b/src/armnn/LayersFwd.hpp @@ -50,6 +50,7 @@ #include "layers/SplitterLayer.hpp" #include "layers/StridedSliceLayer.hpp" #include "layers/SubtractionLayer.hpp" +#include "layers/SwitchLayer.hpp" namespace armnn { @@ -122,5 +123,6 @@ DECLARE_LAYER(SpaceToBatchNd) DECLARE_LAYER(Splitter) DECLARE_LAYER(StridedSlice) DECLARE_LAYER(Subtraction) +DECLARE_LAYER(Switch) } diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 73db2e8..c1462c0 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -971,6 +971,11 @@ IConnectableLayer* Network::AddMergeLayer(const char* name) return m_Graph->AddLayer(name); } +IConnectableLayer* Network::AddSwitchLayer(const char* name) +{ + return m_Graph->AddLayer(name); +} + void Network::Accept(ILayerVisitor& visitor) const { for (auto layer : GetGraph()) diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index bb7b9eb..660ca87 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -176,6 +176,8 @@ public: IConnectableLayer* AddMergeLayer(const char* name = nullptr) override; + IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override; + void Accept(ILayerVisitor& visitor) const override; private: diff --git a/src/armnn/layers/SwitchLayer.cpp b/src/armnn/layers/SwitchLayer.cpp new file mode 100644 index 0000000..eae6e0d --- /dev/null +++ b/src/armnn/layers/SwitchLayer.cpp @@ -0,0 +1,60 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT
+//
+#include "SwitchLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+SwitchLayer::SwitchLayer(const char* name)
+    : Layer(2, 2, LayerType::Switch, name)
+{}
+
+std::unique_ptr<IWorkload> SwitchLayer::CreateWorkload(const Graph& graph,
+                                                       const IWorkloadFactory& factory) const
+{
+    SwitchQueueDescriptor descriptor;
+    return factory.CreateSwitch(descriptor, PrepInfoAndDesc(descriptor, graph));
+}
+
+SwitchLayer* SwitchLayer::Clone(Graph& graph) const
+{
+    return CloneBase<SwitchLayer>(graph, GetName());
+}
+
+void SwitchLayer::ValidateTensorShapesFromInputs()
+{
+    VerifyLayerConnections(2, CHECK_LOCATION());
+
+    BOOST_ASSERT_MSG(GetNumOutputSlots() == 2, "SwitchLayer: The layer should return 2 outputs.");
+
+    // Assuming first input is the Input and second input is the Constant
+    std::vector<TensorShape> inferredShapes = InferOutputShapes({
+        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
+
+    BOOST_ASSERT(inferredShapes.size() == 1);
+
+    ConditionalThrowIfNotEqual<LayerValidationException>(
+        "SwitchLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+        GetOutputSlot(0).GetTensorInfo().GetShape(),
+        inferredShapes[0]);
+
+    ConditionalThrowIfNotEqual<LayerValidationException>(
+        "SwitchLayer: TensorShape set on OutputSlot[1] does not match the inferred shape.",
+        GetOutputSlot(1).GetTensorInfo().GetShape(),
+        inferredShapes[0]);
+}
+
+void SwitchLayer::Accept(ILayerVisitor& visitor) const
+{
+    visitor.VisitSwitchLayer(this, GetName());
+}
+
+} // namespace armnn
diff --git a/src/armnn/layers/SwitchLayer.hpp b/src/armnn/layers/SwitchLayer.hpp
new file mode 100644
index 0000000..bfda8c2
--- /dev/null
+++ b/src/armnn/layers/SwitchLayer.hpp
@@ -0,0 +1,42 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "Layer.hpp"
+
+namespace armnn
+{
+
+/// This layer calculates both true and false outputs for input.
+class SwitchLayer : public Layer
+{
+public:
+    /// Makes a workload for the Switch type.
+    /// @param [in] graph The graph where this layer can be found.
+    /// @param [in] factory The workload factory which will create the workload.
+    /// @return A pointer to the created workload, or nullptr if not created.
+    virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
+                                                      const IWorkloadFactory& factory) const override;
+
+    /// Creates a dynamically-allocated copy of this layer.
+    /// @param [in] graph The graph into which this layer is being cloned.
+    SwitchLayer* Clone(Graph& graph) const override;
+
+    /// Check if the input tensor shape(s)
+    /// will lead to a valid configuration of @ref SwitchLayer.
+    void ValidateTensorShapesFromInputs() override;
+
+    void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+    /// Constructor to create a SwitchLayer.
+    /// @param [in] name Optional name for the layer.
+ SwitchLayer(const char* name); + + /// Default destructor + ~SwitchLayer() = default; +}; + +} // namespace armnn diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 09cdd7c..076072e 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -222,6 +222,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_SplitterLayer] = &Deserializer::ParseSplitter; m_ParserFunctions[Layer_StridedSliceLayer] = &Deserializer::ParseStridedSlice; m_ParserFunctions[Layer_SubtractionLayer] = &Deserializer::ParseSubtraction; + m_ParserFunctions[Layer_SwitchLayer] = &Deserializer::ParseSwitch; } Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex) @@ -306,6 +307,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_StridedSliceLayer()->base(); case Layer::Layer_SubtractionLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_SubtractionLayer()->base(); + case Layer::Layer_SwitchLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_SwitchLayer()->base(); case Layer::Layer_NONE: default: throw ParseException(boost::str( @@ -2108,4 +2111,27 @@ void Deserializer::ParseMerge(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +void Deserializer::ParseSwitch(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + auto inputs = GetInputs(graph, layerIndex); + CHECK_LOCATION(); + CHECK_VALID_SIZE(inputs.size(), 2); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 2); + + auto layerName = GetLayerName(graph, layerIndex); + IConnectableLayer* layer = m_Network->AddSwitchLayer(layerName.c_str()); + + armnn::TensorInfo output0TensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(output0TensorInfo); + + armnn::TensorInfo output1TensorInfo = ToTensorInfo(outputs[1]); + layer->GetOutputSlot(1).SetTensorInfo(output1TensorInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + } // namespace armnnDeserializer diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index df983d9..dfa5b06 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -114,6 +114,7 @@ private: void ParseSplitter(GraphPtr graph, unsigned int layerIndex); void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex); void ParseSubtraction(GraphPtr graph, unsigned int layerIndex); + void ParseSwitch(GraphPtr graph, unsigned int layerIndex); void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot); void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot); diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md index 4e5610c..770f7a8 100644 --- a/src/armnnDeserializer/DeserializerSupport.md +++ b/src/armnnDeserializer/DeserializerSupport.md @@ -41,5 +41,6 @@ The Arm NN SDK Deserialize parser currently supports the following layers: * Splitter * StridedSlice * Subtraction +* Switch More machine learning layers will be supported in future releases. 
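For reference, the graph that ParseSwitch rebuilds can also be constructed directly through the public INetwork API added in this change. A minimal sketch, assuming a 1x4 Float32 tensor; the layer names and constant mask values are illustrative, not taken from this patch:

    #include <armnn/ArmNN.hpp>
    #include <vector>

    using namespace armnn;

    INetworkPtr network = INetwork::Create();
    const TensorInfo info({ 1, 4 }, DataType::Float32);

    std::vector<float> maskData(info.GetNumElements(), 1.0f); // illustrative condition values
    ConstTensor mask(info, maskData);

    IConnectableLayer* input     = network->AddInputLayer(0);
    IConnectableLayer* condition = network->AddConstantLayer(mask, "condition");
    IConnectableLayer* switchLyr = network->AddSwitchLayer("switch");
    IConnectableLayer* outTrue   = network->AddOutputLayer(0);
    IConnectableLayer* outFalse  = network->AddOutputLayer(1);

    input->GetOutputSlot(0).Connect(switchLyr->GetInputSlot(0));     // data input
    condition->GetOutputSlot(0).Connect(switchLyr->GetInputSlot(1)); // condition input
    switchLyr->GetOutputSlot(0).Connect(outTrue->GetInputSlot(0));   // "true" output
    switchLyr->GetOutputSlot(1).Connect(outFalse->GetInputSlot(0));  // "false" output

    input->GetOutputSlot(0).SetTensorInfo(info);
    condition->GetOutputSlot(0).SetTensorInfo(info);
    switchLyr->GetOutputSlot(0).SetTensorInfo(info);
    switchLyr->GetOutputSlot(1).SetTensorInfo(info);

This mirrors the wiring exercised by the SerializeSwitch test further down: two inputs (data and constant condition) and two outputs carrying the true and false results.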
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 8b275b6..e8d72fc 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -119,7 +119,8 @@ enum LayerType : uint { Lstm = 34, Quantize = 35, Dequantize = 36, - Merge = 37 + Merge = 37, + Switch = 38 } // Base layer table to be used as part of other layers @@ -529,6 +530,10 @@ table MergeLayer { base:LayerBase; } +table SwitchLayer { + base:LayerBase; +} + union Layer { ActivationLayer, AdditionLayer, @@ -567,7 +572,8 @@ union Layer { LstmLayer, QuantizeLayer, DequantizeLayer, - MergeLayer + MergeLayer, + SwitchLayer } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index fe30c3e..74d0c43 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -893,6 +893,14 @@ void SerializerVisitor::VisitSubtractionLayer(const armnn::IConnectableLayer* la CreateAnyLayer(fbSubtractionLayer.o, serializer::Layer::Layer_SubtractionLayer); } +void SerializerVisitor::VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name) +{ + auto fbSwitchBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Switch); + auto fbSwitchLayer = serializer::CreateSwitchLayer(m_flatBufferBuilder, fbSwitchBaseLayer); + + CreateAnyLayer(fbSwitchLayer.o, serializer::Layer::Layer_SwitchLayer); +} + fb::Offset SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer, const serializer::LayerType layerType) { diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 775df83..4a71837 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -191,6 +191,9 @@ public: void VisitSubtractionLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; + + void VisitSwitchLayer(const armnn::IConnectableLayer* layer, + const char* name = nullptr) override; private: /// Creates the Input Slots and Output Slots and LayerBase for the layer. diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md index a8335e1..5b54bfd 100644 --- a/src/armnnSerializer/SerializerSupport.md +++ b/src/armnnSerializer/SerializerSupport.md @@ -41,5 +41,6 @@ The Arm NN SDK Serializer currently supports the following layers: * Splitter * StridedSlice * Subtraction +* Switch More machine learning layers will be supported in future releases. 
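The serializer consumes the new hook through the visitor pattern shown above, and any other ILayerVisitor can do the same. A rough sketch of a custom visitor that only counts Switch layers, assuming the VisitorNoThrowPolicy default policy from LayerVisitorBase.hpp; the class and member names are invented for illustration:

    #include <armnn/LayerVisitorBase.hpp>

    class SwitchCounter : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
    {
    public:
        void VisitSwitchLayer(const armnn::IConnectableLayer* /*layer*/,
                              const char* /*name*/) override
        {
            ++m_Count; // every Switch layer in the graph reaches here via INetwork::Accept()
        }

        unsigned int GetCount() const { return m_Count; }

    private:
        unsigned int m_Count = 0;
    };

    // Usage sketch: SwitchCounter counter; network->Accept(counter); counter.GetCount();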
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index a1ef9ee..2724ba4 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -2113,6 +2113,56 @@ BOOST_AUTO_TEST_CASE(SerializeSubtraction) deserializedNetwork->Accept(verifier); } +BOOST_AUTO_TEST_CASE(SerializeSwitch) +{ + class SwitchLayerVerifier : public LayerVerifierBase + { + public: + SwitchLayerVerifier(const std::string& layerName, + const std::vector& inputInfos, + const std::vector& outputInfos) + : LayerVerifierBase(layerName, inputInfos, outputInfos) {} + + void VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name) override + { + VerifyNameAndConnections(layer, name); + } + + void VisitConstantLayer(const armnn::IConnectableLayer* layer, + const armnn::ConstTensor& input, + const char *name) override {} + }; + + const std::string layerName("switch"); + const armnn::TensorInfo info({ 1, 4 }, armnn::DataType::Float32); + + std::vector constantData = GenerateRandomData(info.GetNumElements()); + armnn::ConstTensor constTensor(info, constantData); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const constantLayer = network->AddConstantLayer(constTensor, "constant"); + armnn::IConnectableLayer* const switchLayer = network->AddSwitchLayer(layerName.c_str()); + armnn::IConnectableLayer* const trueOutputLayer = network->AddOutputLayer(0); + armnn::IConnectableLayer* const falseOutputLayer = network->AddOutputLayer(1); + + inputLayer->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(0)); + constantLayer->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(1)); + switchLayer->GetOutputSlot(0).Connect(trueOutputLayer->GetInputSlot(0)); + switchLayer->GetOutputSlot(1).Connect(falseOutputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(info); + constantLayer->GetOutputSlot(0).SetTensorInfo(info); + switchLayer->GetOutputSlot(0).SetTensorInfo(info); + switchLayer->GetOutputSlot(1).SetTensorInfo(info); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + SwitchLayerVerifier verifier(layerName, {info, info}, {info, info}); + deserializedNetwork->Accept(verifier); +} + BOOST_AUTO_TEST_CASE(SerializeDeserializeNonLinearNetwork) { class ConstantLayerVerifier : public LayerVerifierBase diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index fc2d502..6cad7b9 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -397,4 +397,13 @@ bool LayerSupportBase::IsSubtractionSupported(const TensorInfo& input0, return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } +bool LayerSupportBase::IsSwitchSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output0, + const TensorInfo& output1, + Optional reasonIfUnsupported) const +{ + return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); +} + } // namespace armnn diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 7c38b67..3c39f89 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -246,6 +246,12 @@ public: const TensorInfo& input1, 
const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; + + bool IsSwitchSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output0, + const TensorInfo& output1, + Optional reasonIfUnsupported = EmptyOptional()) const override; }; } // namespace armnn diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 348c864..b850a65 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -75,45 +75,23 @@ void ValidateTensorShapesMatch(const TensorInfo& first, } //--------------------------------------------------------------- -void ValidateNoInputs(const WorkloadInfo& workloadInfo, std::string const& descName) +void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize) { - if (workloadInfo.m_InputTensorInfos.size() != 0) + if (workloadInfo.m_InputTensorInfos.size() != expectedSize) { throw InvalidArgumentException(descName + - ": Requires no inputs. " + - to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided."); - } -} - -//--------------------------------------------------------------- -void ValidateSingleInput(const WorkloadInfo& workloadInfo, std::string const& descName) -{ - if (workloadInfo.m_InputTensorInfos.size() != 1) - { - throw InvalidArgumentException(descName + - ": Requires exactly one input. " + - to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided." ); - } -} - -//--------------------------------------------------------------- -void ValidateTwoInputs(const WorkloadInfo& workloadInfo, std::string const& descName) -{ - if (workloadInfo.m_InputTensorInfos.size() != 2) - { - throw InvalidArgumentException(descName + - ": Requires exactly two workloadInfo.m_InputTensorInfos. " + + ": Requires exactly " + to_string(expectedSize) + "input(s). " + to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided."); } } //--------------------------------------------------------------- -void ValidateSingleOutput(const WorkloadInfo& workloadInfo, std::string const& descName) +void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize) { - if (workloadInfo.m_OutputTensorInfos.size() != 1) + if (workloadInfo.m_OutputTensorInfos.size() != expectedSize) { throw InvalidArgumentException(descName + - ": Requires exactly one output. " + + ": Requires exactly " + to_string(expectedSize) + " output(s). 
" + to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided."); } } @@ -242,6 +220,18 @@ void ValidateTensorQuantizationMultiplier(const TensorInfo& inputTensor1, const } } +//--------------------------------------------------------------- +void ValidateDataTypes(const TensorInfo& info, + const std::vector& supportedTypes, + std::string const& descName) +{ + auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType()); + if (iterator == supportedTypes.end()) + { + throw InvalidArgumentException(descName + ": " + " Tensor type is not supported."); + } +} + } //namespace void QueueDescriptor::ValidateInputsOutputs(const std::string& descName, @@ -254,8 +244,8 @@ void QueueDescriptor::ValidateInputsOutputs(const std::string& descName, //--------------------------------------------------------------- void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "MemCopyQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MemCopyQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MemCopyQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MemCopyQueueDescriptor" , 1); if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size()) { @@ -299,8 +289,8 @@ void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "ActivationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ActivationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "ActivationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ActivationQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "ActivationQueueDescriptor", @@ -311,8 +301,8 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "SoftmaxQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "SoftmaxQueueDescriptor"); + ValidateNumInputs(workloadInfo, "SoftmaxQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "SoftmaxQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], @@ -324,7 +314,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "SplitterQueueDescriptor"); + ValidateNumInputs(workloadInfo, "SplitterQueueDescriptor", 1); if (workloadInfo.m_OutputTensorInfos.size() <= 0) { @@ -372,7 +362,7 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void MergerQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleOutput(workloadInfo, "MergerQueueDescriptor"); + ValidateNumOutputs(workloadInfo, "MergerQueueDescriptor", 1); if (m_Inputs.size() <= 0) { @@ -444,8 +434,8 @@ void MergerQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void 
FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "FullyConnectedQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "FullyConnectedQueueDescriptor"); + ValidateNumInputs(workloadInfo, "FullyConnectedQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "FullyConnectedQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FullyConnectedQueueDescriptor", 2, "output"); if (!(workloadInfo.m_InputTensorInfos[0].GetNumDimensions() == 2 || @@ -487,8 +477,8 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c //--------------------------------------------------------------- void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "NormalizationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "NormalizationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "NormalizationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "NormalizationQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "NormalizationQueueDescriptor", @@ -498,8 +488,8 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "AdditionQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "AdditionQueueDescriptor"); + ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -513,8 +503,8 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "MultiplicationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MultiplicationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -526,8 +516,8 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "BatchNormalizationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "BatchNormalizationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "BatchNormalizationQueueDescriptor", @@ -554,8 +544,8 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "Convolution2dQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "Convolution2dQueueDescriptor"); + ValidateNumInputs(workloadInfo, "Convolution2dQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "Convolution2dQueueDescriptor", 1); 
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "output"); @@ -580,8 +570,8 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1); ValidateTensorNumDimensions( workloadInfo.m_InputTensorInfos[0], "DepthwiseConvolution2dQueueDescriptor", 4, "input"); @@ -625,8 +615,8 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "PermuteQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "PermuteQueueDescriptor"); + ValidateNumInputs(workloadInfo, "PermuteQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "PermuteQueueDescriptor", 1); const PermutationVector& mapping = m_Parameters.m_DimMappings; @@ -650,8 +640,8 @@ void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "Pooling2dQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "Pooling2dQueueDescriptor"); + ValidateNumInputs(workloadInfo, "Pooling2dQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "Pooling2dQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "output"); @@ -659,8 +649,8 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "ResizeBilinearQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ResizeBilinearQueueDescriptor"); + ValidateNumInputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "output"); @@ -694,8 +684,8 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "FakeQuantizationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "FakeQuantizationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "output"); @@ -713,8 +703,8 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) void 
L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "L2NormalizationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "L2NormalizationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "L2NormalizationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "L2NormalizationQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "output"); @@ -727,8 +717,8 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateNoInputs(workloadInfo, "ConstantQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ConstantQueueDescriptor"); + ValidateNumInputs(workloadInfo, "ConstantQueueDescriptor", 0); + ValidateNumOutputs(workloadInfo, "ConstantQueueDescriptor", 1); if (!m_LayerOutput) { @@ -744,8 +734,8 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "ReshapeQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ReshapeQueueDescriptor"); + ValidateNumInputs(workloadInfo, "ReshapeQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ReshapeQueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0].GetNumElements() != workloadInfo.m_OutputTensorInfos[0].GetNumElements()) { @@ -757,8 +747,8 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "SpaceToBatchNdQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "SpaceToBatchNdQueueDescriptor"); + ValidateNumInputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "output"); @@ -804,8 +794,8 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "FloorQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "FlootQueueDescriptor"); + ValidateNumInputs(workloadInfo, "FloorQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "FlootQueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0] != workloadInfo.m_OutputTensorInfos[0]) { @@ -821,8 +811,8 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor"); + ValidateNumInputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32) { @@ -843,8 +833,8 @@ void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - 
ValidateSingleInput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor"); + ValidateNumInputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float16) { @@ -864,8 +854,8 @@ void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "DivisionQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "DivisionQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -877,8 +867,8 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "SubtractionQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "SubtractionQueueDescriptor"); + ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -890,8 +880,8 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "MaximumQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MaximumQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -903,8 +893,8 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "MeanQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MeanQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MeanQueueDescriptor", 1); const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; @@ -929,8 +919,8 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "PadQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "PadQueueDescriptor"); + ValidateNumInputs(workloadInfo, "PadQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "PadQueueDescriptor", 1); const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; @@ -948,8 +938,8 @@ void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "QuantizeQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "QuantizeQueueDescriptor"); + ValidateNumInputs(workloadInfo, "QuantizeQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "QuantizeQueueDescriptor", 1); if 
(workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32) @@ -966,14 +956,14 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "BatchToSpaceNdQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "BatchToSpaceNdQueueDescriptor"); + ValidateNumInputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1); } void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "StridedSliceQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "StridedSliceQueueDescriptor"); + ValidateNumInputs(workloadInfo, "StridedSliceQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1); const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; const uint32_t rank = input.GetNumDimensions(); @@ -1015,8 +1005,8 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "MinimumQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MinimumQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -1028,14 +1018,14 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "DebugQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "DebugQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DebugQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "DebugQueueDescriptor", 1); } void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "EqualQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "EqualQueueDescriptor"); + ValidateNumInputs(workloadInfo, "EqualQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "EqualQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -1052,8 +1042,8 @@ void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "GreaterQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "GreaterQueueDescriptor"); + ValidateNumInputs(workloadInfo, "GreaterQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "GreaterQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -1070,8 +1060,8 @@ void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "RsqrtQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "RsqrtQueueDescriptor"); + ValidateNumInputs(workloadInfo, "RsqrtQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "RsqrtQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "RsqrtQueueDescriptor", @@ -1081,8 +1071,8 @@ void 
RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "GatherQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "GatherQueueDescriptor"); + ValidateNumInputs(workloadInfo, "GatherQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "GatherQueueDescriptor", 1); const TensorInfo& indices = workloadInfo.m_InputTensorInfos[1]; @@ -1102,7 +1092,7 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DetectionPostProcessQueueDescriptor", 2); if (workloadInfo.m_OutputTensorInfos.size() != 4) { @@ -1155,8 +1145,8 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "DequantizeQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "DequantizeQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DequantizeQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "DequantizeQueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedAsymm8 && workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedSymm16) @@ -1172,8 +1162,8 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MergeQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "MergeQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -1192,6 +1182,42 @@ void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output"); } +void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateNumInputs(workloadInfo, "SwitchQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "SwitchQueueDescriptor", 2); + + std::vector supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "SwitchQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "SwitchQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "SwitchQueueDescriptor"); + + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_OutputTensorInfos[0], + "SwitchQueueDescriptor", + "input0", + "output0"); + + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_OutputTensorInfos[1], + "SwitchQueueDescriptor", + "input0", + "output1"); +} + void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { // This is internally generated so it should not need validation. 
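As a quick illustration of what the new SwitchQueueDescriptor::Validate enforces: two inputs and two outputs, all with matching shapes and one of the supported data types (Float32, QuantisedAsymm8, QuantisedSymm16), pass; any other count, shape, or type throws InvalidArgumentException. A minimal sketch with an illustrative shape and type:

    #include <backendsCommon/WorkloadData.hpp>
    #include <backendsCommon/WorkloadInfo.hpp>

    armnn::TensorInfo info({ 1, 4 }, armnn::DataType::Float32);

    armnn::WorkloadInfo workloadInfo;
    workloadInfo.m_InputTensorInfos  = { info, info }; // data + condition
    workloadInfo.m_OutputTensorInfos = { info, info }; // true + false outputs

    armnn::SwitchQueueDescriptor descriptor;
    descriptor.Validate(workloadInfo); // throws armnn::InvalidArgumentException on count/shape/type mismatch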
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 1bf7352..1b5f86d 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -426,4 +426,9 @@ struct MergeQueueDescriptor : QueueDescriptor void Validate(const WorkloadInfo& workloadInfo) const; }; +struct SwitchQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + } //namespace armnn diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 4ea3ea9..d9774b0 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -729,6 +729,19 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::Switch: + { + const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); + const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo(); + const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo(); + result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType), + OverrideDataType(input1, dataType), + OverrideDataType(output0, dataType), + OverrideDataType(output1, dataType), + reason); + break; + } case LayerType::Mean: { auto cLayer = boost::polymorphic_downcast(&layer); @@ -1041,4 +1054,10 @@ std::unique_ptr IWorkloadFactory::CreateSubtraction(const Subtraction return std::unique_ptr(); } +std::unique_ptr IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr(); +} + } diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index 889bc9d..5c07b3a 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -177,6 +177,9 @@ public: virtual std::unique_ptr CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor, const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateSwitch(const SwitchQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; }; } //namespace armnn diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index 0588607..a7d7b09 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -402,6 +402,8 @@ DECLARE_LAYER_POLICY_2_PARAM(StridedSlice) DECLARE_LAYER_POLICY_1_PARAM(Subtraction) +DECLARE_LAYER_POLICY_1_PARAM(Switch) + // Generic implementation to get the number of input slots for a given layer type; template