From: narpra01 Date: Wed, 23 Jan 2019 15:23:11 +0000 (+0000) Subject: IVGCVSW-2511 Add end to end Gather layer test X-Git-Tag: submit/tizen/20200316.035456~935 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=db2b160bf9e7759d0157dfa57ee940290f5170e3;p=platform%2Fupstream%2Farmnn.git IVGCVSW-2511 Add end to end Gather layer test * Add end to end test for Gather operator * Add Support for int32 to Constant layer for Ref * Add Int32Workload * Add RefConstantWorkload as template for float, uint8, int32 * Remove unused RefBaseConstantWorkload * Remove unused RefConstantFloat32Workload * Remove unused RefConstantUint8Workload * Add support check for int32 in LayerSupport functions Change-Id: Ic970588a49ebe2aafb12be8adef52371feacaa7b --- diff --git a/src/armnn/LayerSupportCommon.hpp b/src/armnn/LayerSupportCommon.hpp index c309f8c..109728c 100644 --- a/src/armnn/LayerSupportCommon.hpp +++ b/src/armnn/LayerSupportCommon.hpp @@ -12,12 +12,13 @@ namespace armnn { -template +template bool IsSupportedForDataTypeGeneric(Optional reasonIfUnsupported, DataType dataType, Float16Func float16FuncPtr, Float32Func float32FuncPtr, Uint8Func uint8FuncPtr, + Int32Func int32FuncPtr, Params&&... params) { switch(dataType) @@ -28,6 +29,8 @@ bool IsSupportedForDataTypeGeneric(Optional reasonIfUnsupported, return float32FuncPtr(reasonIfUnsupported, std::forward(params)...); case DataType::QuantisedAsymm8: return uint8FuncPtr(reasonIfUnsupported, std::forward(params)...); + case DataType::Signed32: + return int32FuncPtr(reasonIfUnsupported, std::forward(params)...); default: return false; } @@ -76,6 +79,16 @@ bool FalseFuncU8(Optional reasonIfUnsupported, Params&&... params) } template +bool FalseFuncI32(Optional reasonIfUnsupported, Params&&... params) +{ + if (reasonIfUnsupported) + { + reasonIfUnsupported.value() = "Layer is not supported with int32 data type"; + } + return false; +} + +template bool FalseInputFuncF32(Optional reasonIfUnsupported, Params&&... params) { if (reasonIfUnsupported) diff --git a/src/backends/backendsCommon/MakeWorkloadHelper.hpp b/src/backends/backendsCommon/MakeWorkloadHelper.hpp index 78a9669..7784cc6 100644 --- a/src/backends/backendsCommon/MakeWorkloadHelper.hpp +++ b/src/backends/backendsCommon/MakeWorkloadHelper.hpp @@ -37,8 +37,8 @@ struct MakeWorkloadForType // Makes a workload for one the specified types based on the data type requirements of the tensorinfo. // Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos. -template +template std::unique_ptr MakeWorkloadHelper(const QueueDescriptorType& descriptor, const WorkloadInfo& info, Args&&... args) @@ -58,6 +58,8 @@ std::unique_ptr MakeWorkloadHelper(const QueueDescriptorType& descrip return MakeWorkloadForType::Func(descriptor, info, std::forward(args)...); case DataType::QuantisedAsymm8: return MakeWorkloadForType::Func(descriptor, info, std::forward(args)...); + case DataType::Signed32: + return MakeWorkloadForType::Func(descriptor, info, std::forward(args)...); default: BOOST_ASSERT_MSG(false, "Unknown DataType."); return nullptr; @@ -73,10 +75,9 @@ std::unique_ptr MakeWorkloadHelper(const QueueDescriptorType& descrip const WorkloadInfo& info, Args&&... args) { - return MakeWorkloadHelper(descriptor, info, + return MakeWorkloadHelper(descriptor, info, std::forward(args)...); } - } //namespace } //namespace armnn diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp index 6539219..34d1363 100644 --- a/src/backends/backendsCommon/Workload.hpp +++ b/src/backends/backendsCommon/Workload.hpp @@ -162,6 +162,9 @@ template using Uint8Workload = TypedWorkload; template +using Int32Workload = TypedWorkload; + +template using Float16ToFloat32Workload = MultiTypedWorkload; diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt index 8107176..80a9cfe 100644 --- a/src/backends/backendsCommon/test/CMakeLists.txt +++ b/src/backends/backendsCommon/test/CMakeLists.txt @@ -16,6 +16,8 @@ list(APPEND armnnBackendsCommonUnitTests_sources DebugTestImpl.hpp EndToEndTestImpl.hpp FullyConnectedTestImpl.hpp + GatherTestImpl.hpp + GatherEndToEndTestImpl.hpp IsLayerSupportedTestImpl.hpp JsonPrinterTestImpl.cpp JsonPrinterTestImpl.hpp diff --git a/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp new file mode 100644 index 0000000..d30da54 --- /dev/null +++ b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp @@ -0,0 +1,124 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include +#include + +namespace{ + +armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo, + const armnn::TensorInfo& indicesInfo, + const armnn::TensorInfo& outputInfo, + const std::vector& indicesData) +{ + armnn::INetworkPtr net(armnn::INetwork::Create()); + + armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0); + armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData)); + armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer("gather"); + armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output"); + Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0); + Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1); + Connect(gatherLayer, outputLayer, outputInfo, 0, 0); + + return net; +} + +template> +void GatherEndToEnd(const std::vector& backends) +{ + armnn::TensorInfo paramsInfo({ 8 }, ArmnnType); + armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); + armnn::TensorInfo outputInfo({ 3 }, ArmnnType); + + paramsInfo.SetQuantizationScale(1.0f); + paramsInfo.SetQuantizationOffset(0); + outputInfo.SetQuantizationScale(1.0f); + outputInfo.SetQuantizationOffset(0); + + // Creates structures for input & output. + std::vector paramsData{ + 1, 2, 3, 4, 5, 6, 7, 8 + }; + + std::vector indicesData{ + 7, 6, 5 + }; + + std::vector expectedOutput{ + 8, 7, 6 + }; + + // Builds up the structure of the network + armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData); + + BOOST_TEST_CHECKPOINT("create a network"); + + std::map> inputTensorData = {{ 0, paramsData }}; + std::map> expectedOutputData = {{ 0, expectedOutput }}; + + EndToEndLayerTestImpl(move(net), inputTensorData, expectedOutputData, backends); +} + +template> +void GatherMultiDimEndToEnd(const std::vector& backends) +{ + armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType); + armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32); + armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType); + + paramsInfo.SetQuantizationScale(1.0f); + paramsInfo.SetQuantizationOffset(0); + outputInfo.SetQuantizationScale(1.0f); + outputInfo.SetQuantizationOffset(0); + + // Creates structures for input & output. + std::vector paramsData{ + 1, 2, 3, + 4, 5, 6, + + 7, 8, 9, + 10, 11, 12, + + 13, 14, 15, + 16, 17, 18 + }; + + std::vector indicesData{ + 1, 2, 1, + 2, 1, 0 + }; + + std::vector expectedOutput{ + 7, 8, 9, + 10, 11, 12, + 13, 14, 15, + 16, 17, 18, + 7, 8, 9, + 10, 11, 12, + + 13, 14, 15, + 16, 17, 18, + 7, 8, 9, + 10, 11, 12, + 1, 2, 3, + 4, 5, 6 + }; + + // Builds up the structure of the network + armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData); + + BOOST_TEST_CHECKPOINT("create a network"); + + std::map> inputTensorData = {{ 0, paramsData }}; + std::map> expectedOutputData = {{ 0, expectedOutput }}; + + EndToEndLayerTestImpl(move(net), inputTensorData, expectedOutputData, backends); +} + +} // anonymous namespace \ No newline at end of file diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index cb03e8b..3e35f9d 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -121,6 +121,7 @@ bool IsSupportedForDataTypeCl(Optional reasonIfUnsupported, floatFuncPtr, floatFuncPtr, uint8FuncPtr, + &FalseFunc<>, std::forward(params)...); } @@ -265,7 +266,8 @@ bool ClLayerSupport::IsFloorSupported(const TensorInfo& input, input.GetDataType(), &FalseFuncF16<>, &TrueFunc<>, - &FalseFuncU8<>); + &FalseFuncU8<>, + &FalseFuncI32<>); } bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp index 76cdf14..2f83c8f 100644 --- a/src/backends/neon/NeonLayerSupport.cpp +++ b/src/backends/neon/NeonLayerSupport.cpp @@ -71,6 +71,7 @@ bool IsSupportedForDataTypeNeon(Optional reasonIfUnsupported, floatFuncPtr, floatFuncPtr, uint8FuncPtr, + &FalseFunc<>, std::forward(params)...); } @@ -212,7 +213,8 @@ bool NeonLayerSupport::IsFloorSupported(const TensorInfo& input, input.GetDataType(), &FalseFuncF16<>, &TrueFunc<>, - &FalseFuncU8<>); + &FalseFuncU8<>, + &FalseFuncI32<>); } bool NeonLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 25c2baf..45f108c 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -34,6 +34,7 @@ bool IsSupportedForDataTypeRef(Optional reasonIfUnsupported, &FalseFunc, floatFuncPtr, uint8FuncPtr, + &FalseFunc, std::forward(params)...); } @@ -105,10 +106,12 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, bool RefLayerSupport::IsConstantSupported(const TensorInfo& output, Optional reasonIfUnsupported) const { - return IsSupportedForDataTypeRef(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsSupportedForDataTypeGeneric(reasonIfUnsupported, + output.GetDataType(), + &FalseFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>); } bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, @@ -119,12 +122,14 @@ bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, input.GetDataType(), &TrueFunc<>, &FalseInputFuncF32<>, - &FalseFuncU8<>) && + &FalseFuncU8<>, + &FalseFuncI32<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &FalseOutputFuncF16<>, &TrueFunc<>, - &FalseFuncU8<>)); + &FalseFuncU8<>, + &FalseFuncI32<>)); } bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, @@ -135,12 +140,14 @@ bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, input.GetDataType(), &FalseInputFuncF16<>, &TrueFunc<>, - &FalseFuncU8<>) && + &FalseFuncU8<>, + &FalseFuncI32<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &TrueFunc<>, &FalseOutputFuncF32<>, - &FalseFuncU8<>)); + &FalseFuncU8<>, + &FalseFuncI32<>)); } bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 9bdda9d..b112e9d 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -24,7 +24,7 @@ template RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const { - return armnn::MakeWorkloadHelper(descriptor, info); + return armnn::MakeWorkloadHelper(descriptor, info); } RefWorkloadFactory::RefWorkloadFactory() @@ -126,8 +126,8 @@ std::unique_ptr RefWorkloadFactory::CreateFullyConnected( std::unique_ptr RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkloadHelper - (descriptor, info); + return MakeWorkloadHelper(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, @@ -205,7 +205,8 @@ std::unique_ptr RefWorkloadFactory::CreateL2Normalization(const L2Nor std::unique_ptr RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return MakeWorkloadHelper(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 8dd6a51..763f26e 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -24,13 +24,11 @@ BACKEND_SOURCES := \ workloads/Pooling2d.cpp \ workloads/RefActivationFloat32Workload.cpp \ workloads/RefActivationUint8Workload.cpp \ - workloads/RefBaseConstantWorkload.cpp \ workloads/RefBatchNormalizationFloat32Workload.cpp \ workloads/RefBatchNormalizationUint8Workload.cpp \ workloads/RefBatchToSpaceNdFloat32Workload.cpp \ workloads/RefBatchToSpaceNdUint8Workload.cpp \ - workloads/RefConstantFloat32Workload.cpp \ - workloads/RefConstantUint8Workload.cpp \ + workloads/RefConstantWorkload.cpp \ workloads/RefConvertFp16ToFp32Workload.cpp \ workloads/RefConvertFp32ToFp16Workload.cpp \ workloads/RefConvolution2dFloat32Workload.cpp \ diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 4f4a161..330f406 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -4,6 +4,7 @@ // #include +#include #include #include @@ -416,4 +417,24 @@ BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim3Uint8Test) MergerDim3EndToEnd(defaultBackends); } +BOOST_AUTO_TEST_CASE(RefGatherFloatTest) +{ + GatherEndToEnd(defaultBackends); +} + +BOOST_AUTO_TEST_CASE(RefGatherUint8Test) +{ + GatherEndToEnd(defaultBackends); +} + +BOOST_AUTO_TEST_CASE(RefGatherMultiDimFloatTest) +{ + GatherMultiDimEndToEnd(defaultBackends); +} + +BOOST_AUTO_TEST_CASE(RefGatherMultiDimUint8Test) +{ + GatherMultiDimEndToEnd(defaultBackends); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 583c89a..f95fda0 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -32,8 +32,6 @@ list(APPEND armnnRefBackendWorkloads_sources RefActivationFloat32Workload.hpp RefActivationUint8Workload.cpp RefActivationUint8Workload.hpp - RefBaseConstantWorkload.cpp - RefBaseConstantWorkload.hpp RefBatchNormalizationFloat32Workload.cpp RefBatchNormalizationFloat32Workload.hpp RefBatchNormalizationUint8Workload.cpp @@ -42,10 +40,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefBatchToSpaceNdFloat32Workload.hpp RefBatchToSpaceNdUint8Workload.cpp RefBatchToSpaceNdUint8Workload.hpp - RefConstantFloat32Workload.cpp - RefConstantFloat32Workload.hpp - RefConstantUint8Workload.cpp - RefConstantUint8Workload.hpp + RefConstantWorkload.cpp + RefConstantWorkload.hpp RefConvertFp16ToFp32Workload.cpp RefConvertFp16ToFp32Workload.hpp RefConvertFp32ToFp16Workload.cpp diff --git a/src/backends/reference/workloads/RefConstantFloat32Workload.cpp b/src/backends/reference/workloads/RefConstantFloat32Workload.cpp deleted file mode 100644 index 074e8cc..0000000 --- a/src/backends/reference/workloads/RefConstantFloat32Workload.cpp +++ /dev/null @@ -1,19 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefConstantFloat32Workload.hpp" - -#include "Profiling.hpp" - -namespace armnn -{ - -void RefConstantFloat32Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantFloat32Workload_Execute"); - RefBaseConstantWorkload::Execute(); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefConstantFloat32Workload.hpp b/src/backends/reference/workloads/RefConstantFloat32Workload.hpp deleted file mode 100644 index 76e3a42..0000000 --- a/src/backends/reference/workloads/RefConstantFloat32Workload.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "RefBaseConstantWorkload.hpp" - -namespace armnn -{ - -class RefConstantFloat32Workload : public RefBaseConstantWorkload -{ -public: - using RefBaseConstantWorkload::RefBaseConstantWorkload; - virtual void Execute() const override; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefConstantUint8Workload.cpp b/src/backends/reference/workloads/RefConstantUint8Workload.cpp deleted file mode 100644 index 07e4719..0000000 --- a/src/backends/reference/workloads/RefConstantUint8Workload.cpp +++ /dev/null @@ -1,19 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefConstantUint8Workload.hpp" - -#include "Profiling.hpp" - -namespace armnn -{ - -void RefConstantUint8Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantUint8Workload_Execute"); - RefBaseConstantWorkload::Execute(); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefConstantUint8Workload.hpp b/src/backends/reference/workloads/RefConstantUint8Workload.hpp deleted file mode 100644 index 02552ac..0000000 --- a/src/backends/reference/workloads/RefConstantUint8Workload.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "RefBaseConstantWorkload.hpp" - -namespace armnn -{ - -class RefConstantUint8Workload : public RefBaseConstantWorkload -{ -public: - using RefBaseConstantWorkload::RefBaseConstantWorkload; - virtual void Execute() const override; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBaseConstantWorkload.cpp b/src/backends/reference/workloads/RefConstantWorkload.cpp similarity index 81% rename from src/backends/reference/workloads/RefBaseConstantWorkload.cpp rename to src/backends/reference/workloads/RefConstantWorkload.cpp index 647677b..e074c6f 100644 --- a/src/backends/reference/workloads/RefBaseConstantWorkload.cpp +++ b/src/backends/reference/workloads/RefConstantWorkload.cpp @@ -1,9 +1,9 @@ -// +// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // -#include "RefBaseConstantWorkload.hpp" +#include "RefConstantWorkload.hpp" #include "RefWorkloadUtils.hpp" @@ -17,7 +17,7 @@ namespace armnn { template -void RefBaseConstantWorkload::Execute() const +void RefConstantWorkload::Execute() const { // Considering the reference backend independently, it could be possible to initialise the intermediate tensor // created by the layer output handler at workload construction time, rather than at workload execution time. @@ -27,6 +27,8 @@ void RefBaseConstantWorkload::Execute() const // could have a non-owning reference to the layer output tensor managed by the const input layer); again, this is // not an option for other backends, and the extra complexity required to make this work for the reference backend // may not be worth the effort (skipping a memory copy in the first inference). + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantWorkload_Execute"); + if (!m_RanOnce) { const ConstantQueueDescriptor& data = this->m_Data; @@ -43,7 +45,8 @@ void RefBaseConstantWorkload::Execute() const } } -template class RefBaseConstantWorkload; -template class RefBaseConstantWorkload; +template class RefConstantWorkload; +template class RefConstantWorkload; +template class RefConstantWorkload; } //namespace armnn diff --git a/src/backends/reference/workloads/RefBaseConstantWorkload.hpp b/src/backends/reference/workloads/RefConstantWorkload.hpp similarity index 51% rename from src/backends/reference/workloads/RefBaseConstantWorkload.hpp rename to src/backends/reference/workloads/RefConstantWorkload.hpp index 82ee11f..75d7ecc 100644 --- a/src/backends/reference/workloads/RefBaseConstantWorkload.hpp +++ b/src/backends/reference/workloads/RefConstantWorkload.hpp @@ -1,4 +1,4 @@ -// +// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // @@ -15,19 +15,26 @@ namespace armnn // Base class template providing an implementation of the Constant layer common to all data types. template -class RefBaseConstantWorkload : public TypedWorkload +class RefConstantWorkload : public TypedWorkload { public: - RefBaseConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info) + RefConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info) : TypedWorkload(descriptor, info) , m_RanOnce(false) { } + using TypedWorkload::m_Data; + using TypedWorkload::TypedWorkload; + virtual void Execute() const override; private: mutable bool m_RanOnce; }; +using RefConstantFloat32Workload = RefConstantWorkload; +using RefConstantUint8Workload = RefConstantWorkload; +using RefConstantInt32Workload = RefConstantWorkload; + } //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 8550ee5..1cbceb3 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -5,11 +5,10 @@ #pragma once -#include "RefConstantUint8Workload.hpp" #include "ElementwiseFunction.hpp" #include "RefElementwiseWorkload.hpp" #include "ConvImpl.hpp" -#include "RefBaseConstantWorkload.hpp" +#include "RefConstantWorkload.hpp" #include "RefConvolution2dUint8Workload.hpp" #include "RefSplitterUint8Workload.hpp" #include "RefResizeBilinearUint8Workload.hpp" @@ -46,7 +45,6 @@ #include "RefSpaceToBatchNdWorkload.hpp" #include "RefSplitterFloat32Workload.hpp" #include "RefStridedSliceWorkload.hpp" -#include "RefConstantFloat32Workload.hpp" #include "RefActivationFloat32Workload.hpp" #include "RefConvolution2dFloat32Workload.hpp" #include "Pooling2d.hpp"