return Debug1dTest<armnn::DataType::Float32>(workloadFactory, memoryManager);
}
+LayerTestResult<armnn::BFloat16, 4> Debug4dBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return Debug4dTest<armnn::DataType::BFloat16>(workloadFactory, memoryManager);
+}
+
+LayerTestResult<armnn::BFloat16, 3> Debug3dBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return Debug3dTest<armnn::DataType::BFloat16>(workloadFactory, memoryManager);
+}
+
+LayerTestResult<armnn::BFloat16, 2> Debug2dBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return Debug2dTest<armnn::DataType::BFloat16>(workloadFactory, memoryManager);
+}
+
+LayerTestResult<armnn::BFloat16, 1> Debug1dBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return Debug1dTest<armnn::DataType::BFloat16>(workloadFactory, memoryManager);
+}
+
LayerTestResult<uint8_t, 4> Debug4dUint8Test(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
#include "LayerTestResult.hpp"
+#include <BFloat16.hpp>
+
#include <armnn/backends/IBackendInternal.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+LayerTestResult<armnn::BFloat16, 4> Debug4dBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 3> Debug3dBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 2> Debug2dBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 1> Debug1dBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
LayerTestResult<uint8_t, 4> Debug4dUint8Test(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
{
bool supported = true;
- std::array<DataType, 7> supportedTypes =
+ std::array<DataType, 8> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
+ if (IsBFloat16(info))
+ {
+ return std::make_unique<RefDebugBFloat16Workload>(descriptor, info);
+ }
if (IsFloat16(info))
{
return std::make_unique<RefDebugFloat16Workload>(descriptor, info);
ARMNN_AUTO_TEST_CASE(Debug2dFloat32, Debug2dFloat32Test)
ARMNN_AUTO_TEST_CASE(Debug1dFloat32, Debug1dFloat32Test)
+ARMNN_AUTO_TEST_CASE(Debug4dBFloat16, Debug4dBFloat16Test)
+ARMNN_AUTO_TEST_CASE(Debug3dBFloat16, Debug3dBFloat16Test)
+ARMNN_AUTO_TEST_CASE(Debug2dBFloat16, Debug2dBFloat16Test)
+ARMNN_AUTO_TEST_CASE(Debug1dBFloat16, Debug1dBFloat16Test)
+
ARMNN_AUTO_TEST_CASE(Debug4dUint8, Debug4dUint8Test)
ARMNN_AUTO_TEST_CASE(Debug3dUint8, Debug3dUint8Test)
ARMNN_AUTO_TEST_CASE(Debug2dUint8, Debug2dUint8Test)
#include "Debug.hpp"
+#include <BFloat16.hpp>
#include <Half.hpp>
#include <boost/numeric/conversion/cast.hpp>
std::cout << " }" << std::endl;
}
+template void Debug<BFloat16>(const TensorInfo& inputInfo,
+ const BFloat16* inputData,
+ LayerGuid guid,
+ const std::string& layerName,
+ unsigned int slotIndex);
+
template void Debug<Half>(const TensorInfo& inputInfo,
const Half* inputData,
LayerGuid guid,
m_Callback = func;
}
+template class RefDebugWorkload<DataType::BFloat16>;
template class RefDebugWorkload<DataType::Float16>;
template class RefDebugWorkload<DataType::Float32>;
template class RefDebugWorkload<DataType::QAsymmU8>;
DebugCallbackFunction m_Callback;
};
+using RefDebugBFloat16Workload = RefDebugWorkload<DataType::BFloat16>;
using RefDebugFloat16Workload = RefDebugWorkload<DataType::Float16>;
using RefDebugFloat32Workload = RefDebugWorkload<DataType::Float32>;
using RefDebugQAsymmU8Workload = RefDebugWorkload<DataType::QAsymmU8>;