return false;
}
+// Returns true if any input or output tensor of the workload is
+// quantised asymmetric 8-bit (QuantisedAsymm8).
+bool IsUint8(const WorkloadInfo& info)
+{
+ auto checkUint8 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::QuantisedAsymm8;};
+ // any_of over both tensor lists replaces the two verbose
+ // find_if-then-compare-to-end checks.
+ return std::any_of(std::begin(info.m_InputTensorInfos),
+ std::end(info.m_InputTensorInfos),
+ checkUint8)
+ || std::any_of(std::begin(info.m_OutputTensorInfos),
+ std::end(info.m_OutputTensorInfos),
+ checkUint8);
+}
+
+
RefWorkloadFactory::RefWorkloadFactory()
{
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefRsqrtFloat32Workload, NullWorkload>(descriptor, info);
+ // The reference Rsqrt workload handles Float32 (and other decodable types)
+ // via Decoder/Encoder; Float16 and Uint8 are not supported here, so both
+ // fall through to NullWorkload. The two unsupported cases had identical
+ // bodies, so they are merged into a single condition.
+ if (IsFloat16(info) || IsUint8(info))
+ {
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ }
+ return std::make_unique<RefRsqrtWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const armnn::GatherQueueDescriptor& descriptor,
workloads/RefReshapeWorkload.cpp \
workloads/RefResizeBilinearFloat32Workload.cpp \
workloads/RefResizeBilinearUint8Workload.cpp \
- workloads/RefRsqrtFloat32Workload.cpp \
+ workloads/RefRsqrtWorkload.cpp \
workloads/RefSoftmaxWorkload.cpp \
workloads/RefSpaceToBatchNdWorkload.cpp \
workloads/RefStridedSliceWorkload.cpp \
BOOST_AUTO_TEST_CASE(CreateRsqrtFloat32)
{
- RefCreateRsqrtTest<RefRsqrtFloat32Workload, armnn::DataType::Float32>();
+ // RefRsqrtFloat32Workload was renamed to RefRsqrtWorkload, which now
+ // serves all data types through Decoder/Encoder.
+ RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::Float32>();
}
template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
RefResizeBilinearFloat32Workload.hpp
RefResizeBilinearUint8Workload.cpp
RefResizeBilinearUint8Workload.hpp
- RefRsqrtFloat32Workload.cpp
- RefRsqrtFloat32Workload.hpp
+ RefRsqrtWorkload.cpp
+ RefRsqrtWorkload.hpp
RefSoftmaxWorkload.cpp
RefSoftmaxWorkload.hpp
RefSpaceToBatchNdWorkload.cpp
+++ /dev/null
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefRsqrtFloat32Workload.hpp"
-
-#include "RefWorkloadUtils.hpp"
-#include "Rsqrt.hpp"
-
-#include <Profiling.hpp>
-
-namespace armnn
-{
-
-void RefRsqrtFloat32Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtFloat32Workload_Execute");
-
- Rsqrt(GetInputTensorDataFloat(0, m_Data),
- GetOutputTensorDataFloat(0, m_Data),
- GetTensorInfo(m_Data.m_Inputs[0]));
-}
-
-} //namespace armnn
--- /dev/null
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefRsqrtWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Rsqrt.hpp"
+
+#include <Profiling.hpp>
+
+namespace armnn
+{
+
+// Computes the elementwise reciprocal square root of the input tensor.
+// Input and output elements are accessed through Decoder/Encoder so the
+// same code path serves every data type the decoders support.
+void RefRsqrtWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtWorkload_Execute");
+
+ const TensorInfo& inputTensorInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+
+ std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputTensorInfo, m_Data.m_Inputs[0]->Map());
+ Decoder<float>& decoder = *decoderPtr;
+
+ const TensorInfo& outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputTensorInfo, m_Data.m_Outputs[0]->Map());
+ Encoder<float>& encoder = *encoderPtr;
+
+ // Reuse inputTensorInfo instead of recomputing GetTensorInfo on the
+ // same input a second time.
+ Rsqrt(decoder, encoder, inputTensorInfo);
+}
+
+} //namespace armnn
namespace armnn
{
-class RefRsqrtFloat32Workload : public Float32Workload<RsqrtQueueDescriptor>
+/// Reference backend workload computing elementwise reciprocal square root.
+/// Derives from BaseWorkload (rather than the former Float32Workload) so a
+/// single implementation can serve multiple data types via Decoder/Encoder.
+class RefRsqrtWorkload : public BaseWorkload<RsqrtQueueDescriptor>
{
public:
- using Float32Workload<RsqrtQueueDescriptor>::Float32Workload;
+ using BaseWorkload<RsqrtQueueDescriptor>::BaseWorkload;
virtual void Execute() const override;
};
#include "RefBatchToSpaceNdUint8Workload.hpp"
#include "RefBatchToSpaceNdFloat32Workload.hpp"
#include "RefDebugWorkload.hpp"
-#include "RefRsqrtFloat32Workload.hpp"
+#include "RefRsqrtWorkload.hpp"
#include "RefDequantizeWorkload.hpp"
#include "RefQuantizeWorkload.hpp"
#include "RefReshapeWorkload.hpp"
namespace armnn
{
-void Rsqrt(const float* in,
- float* out,
+// Elementwise 1/sqrt(x) over tensorInfo.GetNumElements() elements, reading
+// through a Decoder and writing through an Encoder so quantised and float
+// tensors share one implementation.
+void Rsqrt(Decoder<float>& in,
+ Encoder<float>& out,
const TensorInfo& tensorInfo)
{
- for (size_t i = 0; i < tensorInfo.GetNumElements(); i++)
+ for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
{
- out[i] = 1.f / sqrtf(in[i]);
+ // NOTE(review): these expression statements look like no-ops, but
+ // presumably Decoder/Encoder operator[] positions the iterator at
+ // element i so the following Get/Set act on that element — confirm
+ // against BaseIterator before removing them.
+ out[i];
+ in[i];
+ out.Set(1.f / sqrtf(in.Get()));
}
}
// SPDX-License-Identifier: MIT
//
+#include "BaseIterator.hpp"
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
/// Performs the reciprocal squareroot function elementwise
/// on the inputs to give the outputs.
-void Rsqrt(const float* in,
- float* out,
+void Rsqrt(Decoder<float>& in,
+ Encoder<float>& out,
const TensorInfo& tensorInfo);
} //namespace armnn