IVGCVSW-2915 Add Merge Layer and no-op factory method
Author:    Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
           Fri, 5 Apr 2019 12:37:19 +0000 (13:37 +0100)
Committer: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
           Fri, 5 Apr 2019 12:37:29 +0000 (13:37 +0100)
Change-Id: I54549671e0d3b207904cf9796a843eb2b0a631f7
Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
31 files changed:
Android.mk
CMakeLists.txt
include/armnn/ILayerSupport.hpp
include/armnn/ILayerVisitor.hpp
include/armnn/INetwork.hpp
include/armnn/LayerSupport.hpp
include/armnn/LayerVisitorBase.hpp
src/armnn/InternalTypes.cpp
src/armnn/InternalTypes.hpp
src/armnn/LayerSupport.cpp
src/armnn/LayersFwd.hpp
src/armnn/Network.cpp
src/armnn/Network.hpp
src/armnn/layers/MergeLayer.cpp [new file with mode: 0644]
src/armnn/layers/MergeLayer.hpp [new file with mode: 0644]
src/armnn/test/NetworkTests.cpp
src/armnnDeserializer/Deserializer.cpp
src/armnnDeserializer/Deserializer.hpp
src/armnnDeserializer/DeserializerSupport.md
src/armnnSerializer/ArmnnSchema.fbs
src/armnnSerializer/Serializer.cpp
src/armnnSerializer/Serializer.hpp
src/armnnSerializer/SerializerSupport.md
src/armnnSerializer/test/SerializerTests.cpp
src/backends/backendsCommon/LayerSupportBase.cpp
src/backends/backendsCommon/LayerSupportBase.hpp
src/backends/backendsCommon/WorkloadData.cpp
src/backends/backendsCommon/WorkloadData.hpp
src/backends/backendsCommon/WorkloadFactory.cpp
src/backends/backendsCommon/WorkloadFactory.hpp
src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
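
For reference, the new public API is exercised end to end by the NetworkTests and SerializerTests changes below. A minimal usage sketch of AddMergeLayer, based on the added Network_AddMerge test (shapes and names are illustrative only; the sketch assumes the usual armnn headers, e.g. armnn/INetwork.hpp):

    // Build a small network that routes two inputs through the new Merge layer.
    armnn::INetworkPtr network = armnn::INetwork::Create();

    armnn::IConnectableLayer* input0 = network->AddInputLayer(0);
    armnn::IConnectableLayer* input1 = network->AddInputLayer(1);
    armnn::IConnectableLayer* merge  = network->AddMergeLayer("merge");
    armnn::IConnectableLayer* output = network->AddOutputLayer(0);

    // Merge takes exactly two inputs and produces one output.
    input0->GetOutputSlot(0).Connect(merge->GetInputSlot(0));
    input1->GetOutputSlot(0).Connect(merge->GetInputSlot(1));
    merge->GetOutputSlot(0).Connect(output->GetInputSlot(0));

    // Both inputs and the output must share the same shape and data type.
    const armnn::TensorInfo info({ 3, 1 }, armnn::DataType::Float32);
    input0->GetOutputSlot(0).SetTensorInfo(info);
    input1->GetOutputSlot(0).SetTensorInfo(info);
    merge->GetOutputSlot(0).SetTensorInfo(info);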

index 85bd214..6d5a0fa 100644 (file)
@@ -108,6 +108,7 @@ LOCAL_SRC_FILES := \
         src/armnn/layers/MaximumLayer.cpp \
         src/armnn/layers/MeanLayer.cpp \
         src/armnn/layers/MemCopyLayer.cpp \
+        src/armnn/layers/MergeLayer.cpp \
         src/armnn/layers/MergerLayer.cpp \
         src/armnn/layers/MinimumLayer.cpp \
         src/armnn/layers/MultiplicationLayer.cpp \
index ec237aa..d1fe635 100644 (file)
@@ -239,6 +239,8 @@ list(APPEND armnn_sources
     src/armnn/layers/MeanLayer.cpp
     src/armnn/layers/MemCopyLayer.hpp
     src/armnn/layers/MemCopyLayer.cpp
+    src/armnn/layers/MergeLayer.hpp
+    src/armnn/layers/MergeLayer.cpp
     src/armnn/layers/MergerLayer.hpp
     src/armnn/layers/MergerLayer.cpp
     src/armnn/layers/MinimumLayer.cpp
index fe44071..1b75810 100644 (file)
@@ -171,6 +171,11 @@ public:
                                     const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
 
+    virtual bool IsMergeSupported(const TensorInfo& input0,
+                                  const TensorInfo& input1,
+                                  const TensorInfo& output,
+                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
     virtual bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                    const TensorInfo& output,
                                    const OriginsDescriptor& descriptor,
index e23cf5e..3a4c39b 100644 (file)
@@ -199,6 +199,12 @@ public:
                                 const MeanDescriptor& meanDescriptor,
                                 const char* name = nullptr) = 0;
 
+    /// Function that a merge layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+    /// @param layer - pointer to the layer which is calling back to this visit function.
+    /// @param name - Optional name for the layer.
+    virtual void VisitMergeLayer(const IConnectableLayer* layer,
+                                 const char* name = nullptr) = 0;
+
     /// Function that a merger layer should call back to when its Accept(ILayerVisitor&) function is invoked.
     /// @param layer - pointer to the layer which is calling back to this visit function.
     /// @param mergerDescriptor - WindowsDescriptor to configure the merging process. Number of Views must be equal to
@@ -337,4 +343,4 @@ public:
 
 
 };
-} // namespace armnn
\ No newline at end of file
+} // namespace armnn
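
Client code receives the new callback by overriding it on a visitor. A minimal sketch building on LayerVisitorBase, mirroring the pattern used in the Network_AddMerge test below (the struct and member names are illustrative, not part of this change):

    // Visitor that reacts only to Merge layers and ignores everything else.
    struct MergeCounter : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
    {
        void VisitMergeLayer(const armnn::IConnectableLayer* /*layer*/,
                             const char* /*name*/) override
        {
            ++m_NumMergeLayers;
        }

        unsigned int m_NumMergeLayers = 0;
    };

    // Usage: MergeCounter counter; network->Accept(counter);
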
index 5a9d4f2..8243b39 100644 (file)
@@ -235,6 +235,11 @@ public:
     virtual IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor
         , const char* name = nullptr) = 0;
 
+    /// Adds a merge layer to the network.
+    /// @param name - Optional name for the layer.
+    /// @return - Interface for configuring the layer.
+    virtual IConnectableLayer* AddMergeLayer(const char* name = nullptr) = 0;
+
     /// Adds a merger layer to the network.
     /// @param mergerDescriptor - WindowsDescriptor to configure the merging process. Number of Views must be equal to
     ///                           the number of inputs, and their order must match - e.g. first view corresponds to
index 7c6bc13..e23fdd0 100644 (file)
@@ -204,6 +204,14 @@ bool IsMemCopySupported(const BackendId& backend,
                         size_t reasonIfUnsupportedMaxLength = 1024);
 
 /// Deprecated in favor of IBackend and ILayerSupport interfaces
+bool IsMergeSupported(const BackendId& backend,
+                      const TensorInfo& input0,
+                      const TensorInfo& input1,
+                      const TensorInfo& output,
+                      char* reasonIfUnsupported = nullptr,
+                      size_t reasonIfUnsupportedMaxLength = 1024);
+
+/// Deprecated in favor of IBackend and ILayerSupport interfaces
 bool IsMergerSupported(const BackendId& backend,
                        const std::vector<const TensorInfo*> inputs,
                        const TensorInfo& output,
index a5459e1..f4e0f43 100644 (file)
@@ -87,6 +87,9 @@ public:
                             const ViewsDescriptor&,
                             const char*) override { DefaultPolicy::Apply(); }
 
+    void VisitMergeLayer(const IConnectableLayer*,
+                         const char*) override { DefaultPolicy::Apply(); }
+
     void VisitMergerLayer(const IConnectableLayer*,
                           const OriginsDescriptor&,
                           const char*) override { DefaultPolicy::Apply(); }
index fe1542b..93a4f94 100644 (file)
@@ -39,6 +39,7 @@ char const* GetLayerTypeAsCString(LayerType type)
         case LayerType::Maximum: return "Maximum";
         case LayerType::Mean: return "Mean";
         case LayerType::MemCopy: return "MemCopy";
+        case LayerType::Merge: return "Merge";
         case LayerType::Merger: return "Merger";
         case LayerType::Minimum: return "Minimum";
         case LayerType::Multiplication: return "Multiplication";
index 1972e9c..7c7c601 100644 (file)
@@ -39,6 +39,7 @@ enum class LayerType
     Maximum,
     Mean,
     MemCopy,
+    Merge,
     Merger,
     Minimum,
     Multiplication,
index 0309733..bc6eec8 100644 (file)
@@ -355,6 +355,16 @@ bool IsMemCopySupported(const BackendId &backend,
     FORWARD_LAYER_SUPPORT_FUNC(backend, IsMemCopySupported, input, output);
 }
 
+bool IsMergeSupported(const BackendId& backend,
+                      const TensorInfo& input0,
+                      const TensorInfo& input1,
+                      const TensorInfo& output,
+                      char* reasonIfUnsupported,
+                      size_t reasonIfUnsupportedMaxLength)
+{
+    FORWARD_LAYER_SUPPORT_FUNC(backend, IsMergeSupported, input0, input1, output);
+}
+
 bool IsMergerSupported(const BackendId& backend,
                        std::vector<const TensorInfo*> inputs,
                        const TensorInfo& output,
index 9d87aee..0bd68e0 100644 (file)
@@ -31,6 +31,7 @@
 #include "layers/MaximumLayer.hpp"
 #include "layers/MeanLayer.hpp"
 #include "layers/MemCopyLayer.hpp"
+#include "layers/MergeLayer.hpp"
 #include "layers/MergerLayer.hpp"
 #include "layers/MinimumLayer.hpp"
 #include "layers/MultiplicationLayer.hpp"
@@ -102,6 +103,7 @@ DECLARE_LAYER(Lstm)
 DECLARE_LAYER(Maximum)
 DECLARE_LAYER(Mean)
 DECLARE_LAYER(MemCopy)
+DECLARE_LAYER(Merge)
 DECLARE_LAYER(Merger)
 DECLARE_LAYER(Minimum)
 DECLARE_LAYER(Multiplication)
index 6dbd461..73db2e8 100644 (file)
@@ -966,6 +966,11 @@ IConnectableLayer* Network::AddGatherLayer(const char* name)
     return m_Graph->AddLayer<GatherLayer>(name);
 }
 
+IConnectableLayer* Network::AddMergeLayer(const char* name)
+{
+    return m_Graph->AddLayer<MergeLayer>(name);
+}
+
 void Network::Accept(ILayerVisitor& visitor) const
 {
     for (auto layer : GetGraph())
index 782531a..bb7b9eb 100644 (file)
@@ -174,6 +174,8 @@ public:
 
     IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
 
+    IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
+
     void Accept(ILayerVisitor& visitor) const override;
 
 private:
diff --git a/src/armnn/layers/MergeLayer.cpp b/src/armnn/layers/MergeLayer.cpp
new file mode 100644 (file)
index 0000000..1d4dc49
--- /dev/null
@@ -0,0 +1,65 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "MergeLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+MergeLayer::MergeLayer(const char* name)
+    : Layer(2, 1, LayerType::Merge, name)
+{}
+
+std::unique_ptr<IWorkload> MergeLayer::CreateWorkload(const Graph& graph,
+                                                      const IWorkloadFactory& factory) const
+{
+    return nullptr;
+}
+
+MergeLayer* MergeLayer::Clone(Graph& graph) const
+{
+    return CloneBase<MergeLayer>(graph, GetName());
+}
+
+void MergeLayer::ValidateTensorShapesFromInputs()
+{
+    VerifyLayerConnections(2, CHECK_LOCATION());
+
+    std::vector<TensorShape> inferredShapes = InferOutputShapes({
+        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
+    });
+
+    BOOST_ASSERT(inferredShapes.size() == 1);
+
+    ConditionalThrowIfNotEqual<LayerValidationException>(
+        "MergeLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+        GetOutputSlot(0).GetTensorInfo().GetShape(),
+        inferredShapes[0]);
+}
+
+std::vector<TensorShape> MergeLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+    BOOST_ASSERT(inputShapes.size() == 2);
+
+    ConditionalThrowIfNotEqual<LayerValidationException>(
+        "MergeLayer: TensorShapes set on inputs do not match",
+        inputShapes[0],
+        inputShapes[1]
+    );
+
+    return {inputShapes[0]};
+}
+
+void MergeLayer::Accept(ILayerVisitor& visitor) const
+{
+    visitor.VisitMergeLayer(this, GetName());
+}
+
+} // namespace armnn
diff --git a/src/armnn/layers/MergeLayer.hpp b/src/armnn/layers/MergeLayer.hpp
new file mode 100644 (file)
index 0000000..66664ca
--- /dev/null
@@ -0,0 +1,47 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "Layer.hpp"
+
+namespace armnn
+{
+
+/// This layer represents a merge operation.
+class MergeLayer : public Layer
+{
+public:
+    /// Makes a workload for the Merge type.
+    /// @param [in] graph The graph where this layer can be found.
+    /// @param [in] factory The workload factory which will create the workload.
+    /// @return A pointer to the created workload, or nullptr if not created.
+    virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
+                                                      const IWorkloadFactory& factory) const override;
+
+    /// Creates a dynamically-allocated copy of this layer.
+    /// @param [in] graph The graph into which this layer is being cloned.
+    MergeLayer* Clone(Graph& graph) const override;
+
+    /// Check if the input tensor shape(s)
+    /// will lead to a valid configuration of @ref MergeLayer.
+    void ValidateTensorShapesFromInputs() override;
+
+    /// Infers the output shapes from given input shapes.
+    /// @param [in] inputShapes The input shapes layer has.
+    /// @return A vector containing the inferred output shape.
+    std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
+
+    void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+    /// Constructor to create a MergeLayer.
+    /// @param [in] name Optional name for the layer.
+    MergeLayer(const char* name);
+
+    /// Default destructor
+    ~MergeLayer() = default;
+};
+
+} // namespace armnn
index 4de09a2..dd8eb77 100644 (file)
@@ -417,4 +417,56 @@ BOOST_AUTO_TEST_CASE(Network_AddQuantize)
 
 }
 
+BOOST_AUTO_TEST_CASE(Network_AddMerge)
+{
+    struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+    {
+        void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override
+        {
+            m_Visited = true;
+
+            BOOST_TEST(layer);
+
+            std::string expectedName = std::string("merge");
+            BOOST_TEST(std::string(layer->GetName()) == expectedName);
+            BOOST_TEST(std::string(name) == expectedName);
+
+            BOOST_TEST(layer->GetNumInputSlots() == 2);
+            BOOST_TEST(layer->GetNumOutputSlots() == 1);
+
+            const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+            BOOST_TEST((infoIn0.GetDataType() == armnn::DataType::Float32));
+
+            const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo();
+            BOOST_TEST((infoIn1.GetDataType() == armnn::DataType::Float32));
+
+            const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
+            BOOST_TEST((infoOut.GetDataType() == armnn::DataType::Float32));
+        }
+
+        bool m_Visited = false;
+    };
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+
+    armnn::IConnectableLayer* input0 = network->AddInputLayer(0);
+    armnn::IConnectableLayer* input1 = network->AddInputLayer(1);
+    armnn::IConnectableLayer* merge = network->AddMergeLayer("merge");
+    armnn::IConnectableLayer* output = network->AddOutputLayer(0);
+
+    input0->GetOutputSlot(0).Connect(merge->GetInputSlot(0));
+    input1->GetOutputSlot(0).Connect(merge->GetInputSlot(1));
+    merge->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+    const armnn::TensorInfo info({3,1}, armnn::DataType::Float32);
+    input0->GetOutputSlot(0).SetTensorInfo(info);
+    input1->GetOutputSlot(0).SetTensorInfo(info);
+    merge->GetOutputSlot(0).SetTensorInfo(info);
+
+    Test testMerge;
+    network->Accept(testMerge);
+
+    BOOST_TEST(testMerge.m_Visited == true);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
index 943c6a7..09cdd7c 100644 (file)
@@ -206,6 +206,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
     m_ParserFunctions[Layer_MaximumLayer]                = &Deserializer::ParseMaximum;
     m_ParserFunctions[Layer_MeanLayer]                   = &Deserializer::ParseMean;
     m_ParserFunctions[Layer_MinimumLayer]                = &Deserializer::ParseMinimum;
+    m_ParserFunctions[Layer_MergeLayer]                  = &Deserializer::ParseMerge;
     m_ParserFunctions[Layer_MergerLayer]                 = &Deserializer::ParseMerger;
     m_ParserFunctions[Layer_MultiplicationLayer]         = &Deserializer::ParseMultiplication;
     m_ParserFunctions[Layer_NormalizationLayer]          = &Deserializer::ParseNormalization;
@@ -271,6 +272,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
             return graphPtr->layers()->Get(layerIndex)->layer_as_MinimumLayer()->base();
         case Layer::Layer_MaximumLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_MaximumLayer()->base();
+        case Layer::Layer_MergeLayer:
+            return graphPtr->layers()->Get(layerIndex)->layer_as_MergeLayer()->base();
         case Layer::Layer_MergerLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_MergerLayer()->base();
         case Layer::Layer_MultiplicationLayer:
@@ -2085,4 +2088,24 @@ void Deserializer::ParseDequantize(GraphPtr graph, unsigned int layerIndex)
     RegisterOutputSlots(graph, layerIndex, layer);
 }
 
+void Deserializer::ParseMerge(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+
+    Deserializer::TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
+    CHECK_VALID_SIZE(inputs.size(), 2);
+
+    Deserializer::TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
+    CHECK_VALID_SIZE(outputs.size(), 1);
+
+    const std::string layerName = GetLayerName(graph, layerIndex);
+    IConnectableLayer* layer = m_Network->AddMergeLayer(layerName.c_str());
+
+    armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+    RegisterInputSlots(graph, layerIndex, layer);
+    RegisterOutputSlots(graph, layerIndex, layer);
+}
+
 } // namespace armnnDeserializer
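
With ParseMerge registered, a network containing the new layer can be round-tripped through the serializer and deserializer. A sketch of that flow using the public APIs, following the pattern of the SerializeMerge test added below (the stream handling here is illustrative; the test uses its own helpers):

    // Serialize a network that contains a Merge layer, then load it back.
    armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
    serializer->Serialize(*network);                  // 'network' built as in the earlier sketch

    std::stringstream stream;
    serializer->SaveSerializedToStream(stream);

    const std::string data = stream.str();
    armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
    armnn::INetworkPtr restored =
        parser->CreateNetworkFromBinary(std::vector<uint8_t>(data.begin(), data.end()));
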
index f18c163..df983d9 100644 (file)
@@ -97,6 +97,7 @@ private:
     void ParseMaximum(GraphPtr graph, unsigned int layerIndex);
     void ParseMean(GraphPtr graph, unsigned int layerIndex);
     void ParseMinimum(GraphPtr graph, unsigned int layerIndex);
+    void ParseMerge(GraphPtr graph, unsigned int layerIndex);
     void ParseMerger(GraphPtr graph, unsigned int layerIndex);
     void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
     void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
index 77856cf..4e5610c 100644 (file)
@@ -25,6 +25,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
 * Lstm
 * Maximum
 * Mean
+* Merge
 * Merger
 * Minimum
 * Multiplication
index 3aa644d..8b275b6 100644 (file)
@@ -118,7 +118,8 @@ enum LayerType : uint {
     DetectionPostProcess = 33,
     Lstm = 34,
     Quantize = 35,
-    Dequantize = 36
+    Dequantize = 36,
+    Merge = 37
 }
 
 // Base layer table to be used as part of other layers
@@ -524,6 +525,10 @@ table DequantizeLayer {
     base:LayerBase;
 }
 
+table MergeLayer {
+    base:LayerBase;
+}
+
 union Layer {
     ActivationLayer,
     AdditionLayer,
@@ -561,7 +566,8 @@ union Layer {
     DetectionPostProcessLayer,
     LstmLayer,
     QuantizeLayer,
-    DequantizeLayer
+    DequantizeLayer,
+    MergeLayer
 }
 
 table AnyLayer {
index 7181f01..fe30c3e 100644 (file)
@@ -500,6 +500,14 @@ void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer,
     CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer);
 }
 
+void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name)
+{
+    auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge);
+    auto fbMergeLayer     = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer);
+
+    CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer);
+}
+
 void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
                                          const armnn::OriginsDescriptor& mergerDescriptor,
                                          const char* name)
index 5c3e48a..775df83 100644 (file)
@@ -129,6 +129,9 @@ public:
     void VisitMaximumLayer(const armnn::IConnectableLayer* layer,
                            const char* name = nullptr) override;
 
+    void VisitMergeLayer(const armnn::IConnectableLayer* layer,
+                         const char* name = nullptr) override;
+
     void VisitMergerLayer(const armnn::IConnectableLayer* layer,
                           const armnn::OriginsDescriptor& mergerDescriptor,
                           const char* name = nullptr) override;
index a3c5852..a8335e1 100644 (file)
@@ -25,6 +25,7 @@ The Arm NN SDK Serializer currently supports the following layers:
 * Lstm
 * Maximum
 * Mean
+* Merge
 * Merger
 * Minimum
 * Multiplication
index 0979076..a1ef9ee 100644 (file)
@@ -1185,6 +1185,46 @@ BOOST_AUTO_TEST_CASE(SerializeMean)
     deserializedNetwork->Accept(verifier);
 }
 
+BOOST_AUTO_TEST_CASE(SerializeMerge)
+{
+    class MergeLayerVerifier : public LayerVerifierBase
+    {
+    public:
+        MergeLayerVerifier(const std::string& layerName,
+                           const std::vector<armnn::TensorInfo>& inputInfos,
+                           const std::vector<armnn::TensorInfo>& outputInfos)
+        : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+
+        void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override
+        {
+            VerifyNameAndConnections(layer, name);
+        }
+    };
+
+    const std::string layerName("merge");
+    const armnn::TensorInfo info({ 1, 2, 2, 3 }, armnn::DataType::Float32);
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
+    armnn::IConnectableLayer* const mergeLayer = network->AddMergeLayer(layerName.c_str());
+    armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+    inputLayer0->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(0));
+    inputLayer1->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(1));
+    mergeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+    inputLayer0->GetOutputSlot(0).SetTensorInfo(info);
+    inputLayer1->GetOutputSlot(0).SetTensorInfo(info);
+    mergeLayer->GetOutputSlot(0).SetTensorInfo(info);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    MergeLayerVerifier verifier(layerName, {info, info}, {info});
+    deserializedNetwork->Accept(verifier);
+}
+
 BOOST_AUTO_TEST_CASE(SerializeMerger)
 {
     class MergerLayerVerifier : public LayerVerifierBase
index 04f822c..fc2d502 100644 (file)
@@ -253,6 +253,14 @@ bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo& input,
     return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
 }
 
+bool LayerSupportBase::IsMergeSupported(const TensorInfo& input0,
+                                        const TensorInfo& input1,
+                                        const TensorInfo& output,
+                                        Optional<std::string&> reasonIfUnsupported) const
+{
+    return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
 bool LayerSupportBase::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                          const TensorInfo& output,
                                          const OriginsDescriptor& descriptor,
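
LayerSupportBase reports Merge as unsupported by default, so a backend opts in by overriding IsMergeSupported. A sketch of such an override (MyLayerSupport and the support policy are hypothetical, not part of this change):

    bool MyLayerSupport::IsMergeSupported(const armnn::TensorInfo& input0,
                                          const armnn::TensorInfo& input1,
                                          const armnn::TensorInfo& output,
                                          armnn::Optional<std::string&> reasonIfUnsupported) const
    {
        // Example policy: require all three tensors to share shape and data type.
        const bool supported = input0.GetShape()    == input1.GetShape()    &&
                               input0.GetShape()    == output.GetShape()    &&
                               input0.GetDataType() == input1.GetDataType() &&
                               input0.GetDataType() == output.GetDataType();

        if (!supported && reasonIfUnsupported.has_value())
        {
            reasonIfUnsupported.value() = "Merge: inputs and output must match in shape and type";
        }
        return supported;
    }
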
index 7d64095..7c38b67 100644 (file)
@@ -160,6 +160,11 @@ public:
                             const TensorInfo& output,
                             Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsMergeSupported(const TensorInfo& input0,
+                          const TensorInfo& input1,
+                          const TensorInfo& output,
+                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                            const TensorInfo& output,
                            const OriginsDescriptor& descriptor,
index 91b1c57..348c864 100644 (file)
@@ -1170,6 +1170,28 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
     }
 }
 
+void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+    ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor");
+    ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor");
+
+    ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+                              workloadInfo.m_InputTensorInfos[1],
+                              "MergeQueueDescriptor",
+                              "input0",
+                              "input1");
+
+    ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+                              workloadInfo.m_OutputTensorInfos[0],
+                              "MergeQueueDescriptor",
+                              "input0",
+                              "output");
+
+    const DataType dataType = workloadInfo.m_InputTensorInfos[0].GetDataType();
+    ValidateTensorDataType(workloadInfo.m_InputTensorInfos[1], dataType, "MergeQueueDescriptor", "input1");
+    ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output");
+}
+
 void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
     // This is internally generated so it should not need validation.
index 5640701..1bf7352 100644 (file)
@@ -421,4 +421,9 @@ struct DequantizeQueueDescriptor : QueueDescriptor
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
+struct MergeQueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
 } //namespace armnn
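
A short illustration of what the new MergeQueueDescriptor::Validate rules mean in practice (the tensor infos below are hypothetical; on a mismatch Validate throws via the ValidateTensorShapesMatch / ValidateTensorDataType helpers shown above):

    armnn::MergeQueueDescriptor descriptor;
    armnn::WorkloadInfo info;

    // Exactly two inputs and one output, all with identical shape and data type.
    info.m_InputTensorInfos  = { armnn::TensorInfo({ 2, 3 }, armnn::DataType::Float32),
                                 armnn::TensorInfo({ 2, 3 }, armnn::DataType::Float32) };
    info.m_OutputTensorInfos = { armnn::TensorInfo({ 2, 3 }, armnn::DataType::Float32) };

    descriptor.Validate(info);   // passes; changing any shape or type above would throw
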
index 6534a00..4ea3ea9 100644 (file)
@@ -519,6 +519,18 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                                             reason);
             break;
         }
+        case LayerType::Merge:
+        {
+            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+
+            result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
+                                                          OverrideDataType(input1, dataType),
+                                                          OverrideDataType(output, dataType),
+                                                          reason);
+            break;
+        }
         case LayerType::Merger:
         {
             auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
@@ -915,6 +927,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDes
     return std::unique_ptr<IWorkload>();
 }
 
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
+                                                         const WorkloadInfo& info) const
+{
+    return std::unique_ptr<IWorkload>();
+}
+
 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
 {
index ed7303c..889bc9d 100644 (file)
@@ -121,6 +121,9 @@ public:
     virtual std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
                                                      const WorkloadInfo& info) const;
 
+    virtual std::unique_ptr<IWorkload> CreateMerge(const MergeQueueDescriptor& descriptor,
+                                                    const WorkloadInfo& info) const;
+
     virtual std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor,
                                                     const WorkloadInfo&          info) const;
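
The base CreateMerge above is the "no-op factory method" from the commit title: it returns an empty pointer, so no backend runs Merge yet. A backend adding support would override it roughly as follows (MyWorkloadFactory and MyMergeWorkload are hypothetical):

    std::unique_ptr<armnn::IWorkload> MyWorkloadFactory::CreateMerge(
        const armnn::MergeQueueDescriptor& descriptor,
        const armnn::WorkloadInfo& info) const
    {
        // Hand the validated descriptor to a backend-specific workload implementation.
        return std::make_unique<MyMergeWorkload>(descriptor, info);
    }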
 
index 26fb03f..0588607 100644 (file)
@@ -362,6 +362,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Maximum)
 
 DECLARE_LAYER_POLICY_2_PARAM(Mean)
 
+DECLARE_LAYER_POLICY_1_PARAM(Merge)
+
 DECLARE_LAYER_POLICY_2_PARAM(Merger)
 
 DECLARE_LAYER_POLICY_1_PARAM(Minimum)