IVGCVSW-4597 Modify BF16 optimizer to convert only inputs and weights of Convolution2d and FullyConnected layers
author    Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Thu, 26 Mar 2020 09:20:43 +0000 (09:20 +0000)
committer Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Thu, 26 Mar 2020 16:16:55 +0000 (16:16 +0000)

 * Add InsertConvertFp32ToBf16LayersBefore
 * Add ConvertWeight to ConvertFp32NetworkToBf16Impl for Conv2d and FullyConnected
 * Allow different input and output data types when the input is BF16 and the output is FP32 for Conv2d and FullyConnected layers
 * Add unit tests

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ic8f92ff28edcae08a72a3114a28f50c4619f919b
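
For context, the whole path is gated on the existing OptimizerOptions::m_ReduceFp32ToBf16 flag checked in the Network.cpp hunk below; a minimal caller-side sketch follows (the OptimizeWithBf16 helper name and backend choice are illustrative, not part of this patch):

    #include <armnn/ArmNN.hpp>

    // Sketch: opt in to the Fp32 -> Bf16 reduction when optimizing a network.
    // The INetwork is assumed to have been built elsewhere (e.g. with a
    // Convolution2d or FullyConnected layer holding Float32 weights).
    armnn::IOptimizedNetworkPtr OptimizeWithBf16(const armnn::INetwork& network,
                                                 const armnn::IRuntimePtr& runtime)
    {
        armnn::OptimizerOptions options;
        options.m_ReduceFp32ToBf16 = true; // enables Fp32NetworkToBf16Converter inside Optimize()

        return armnn::Optimize(network,
                               { armnn::Compute::CpuRef }, // CpuRef gains the relaxed Bf16 rules in this patch
                               runtime->GetDeviceSpec(),
                               options);
    }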

src/armnn/Network.cpp
src/armnn/NetworkUtils.cpp
src/armnn/NetworkUtils.hpp
src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp
src/armnn/test/optimizations/Fp32NetworkToBf16ConverterTests.cpp
src/backends/backendsCommon/WorkloadData.cpp
src/backends/reference/RefLayerSupport.cpp

diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 5f77197..0272b3d 100644 (file)
@@ -1020,10 +1020,11 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
     }
 
     // If Fp32 to Bf16 optimization is set convert Fp32 network to Bf16
+    // Convert inputs of Convolution2d and FullyConnected layers from Fp32 to Bf16
+    // Only the constant weights of Convolution2d and FullyConnected layers are converted from Fp32 to Bf16
     if (options.m_ReduceFp32ToBf16)
     {
         Optimizer::Pass(optGraph, MakeOptimizations(Fp32NetworkToBf16Converter()));
-        Optimizer::Pass(optGraph, MakeOptimizations(ConvertConstantsFloatToBFloat()));
     }
 
     // Initialize backend settings
diff --git a/src/armnn/NetworkUtils.cpp b/src/armnn/NetworkUtils.cpp
index 8653a08..0549a11 100644 (file)
@@ -87,6 +87,45 @@ std::vector<ConvertBf16ToFp32Layer*> InsertConvertBf16ToFp32LayersBefore(Graph&
     return convertLayers;
 }
 
+std::vector<ConvertFp32ToBf16Layer*> InsertConvertFp32ToBf16LayersBefore(Graph& graph,
+                                                                         Layer& layer,
+                                                                         bool expectCorrectInputType)
+{
+    std::vector<ConvertFp32ToBf16Layer*> convertLayers;
+    convertLayers.reserve(layer.GetNumInputSlots());
+
+    // Insert a ConvertFp32ToBf16Layer before each input slot
+    for (auto&& inputSlot = layer.BeginInputSlots(); inputSlot != layer.EndInputSlots(); ++inputSlot)
+    {
+        bool allowInsert = true;
+        if (expectCorrectInputType)
+        {
+            // Only insert ConvertFp32ToBf16Layer before FP32 input slots
+            OutputSlot* connectedOutputSlot = inputSlot->GetConnectedOutputSlot();
+            allowInsert =
+                connectedOutputSlot && connectedOutputSlot->GetTensorInfo().GetDataType() == DataType::Float32;
+        }
+
+        if (allowInsert)
+        {
+            const std::string name =
+                std::string("convert_fp32_to_bf16-" + std::to_string(inputSlot->GetSlotIndex()) + "-") +
+                layer.GetName();
+            ConvertFp32ToBf16Layer* convertLayer =
+                graph.InsertNewLayer<ConvertFp32ToBf16Layer>(*inputSlot, name.c_str());
+
+            TensorInfo convertInfo = convertLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+            convertInfo.SetDataType(DataType::BFloat16);
+
+            convertLayer->GetOutputSlot().SetTensorInfo(convertInfo);
+
+            convertLayers.emplace_back(convertLayer);
+        }
+    }
+
+    return convertLayers;
+}
+
 std::vector<ConvertFp16ToFp32Layer*> InsertConvertFp16ToFp32LayersBefore(Graph& graph,
                                                                          Layer& layer,
                                                                          bool expectCorrectInputType)
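
A brief usage sketch of the new helper (an internal src/armnn API, so this only builds inside ArmNN's own sources; the graph and the Convolution2dLayer* conv are assumed to exist, as in the unit tests further down):

    // Sketch: insert a ConvertFp32ToBf16Layer in front of each Float32 input slot of 'conv'.
    std::vector<armnn::ConvertFp32ToBf16Layer*> converts =
        armnn::InsertConvertFp32ToBf16LayersBefore(graph, *conv);

    // Each returned layer is wired between the producer and 'conv'; its output slot keeps the
    // original shape retagged as DataType::BFloat16, and its name follows the
    // "convert_fp32_to_bf16-<slot index>-<layer name>" pattern built above.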
diff --git a/src/armnn/NetworkUtils.hpp b/src/armnn/NetworkUtils.hpp
index 064545a..a922770 100644 (file)
@@ -15,6 +15,10 @@ std::vector<ConvertBf16ToFp32Layer*> InsertConvertBf16ToFp32LayersBefore(Graph&
                                                                          Layer& layer,
                                                                          bool expectCorrectInputType = true);
 
+std::vector<ConvertFp32ToBf16Layer*> InsertConvertFp32ToBf16LayersBefore(Graph& graph,
+                                                                         Layer& layer,
+                                                                         bool expectCorrectInputType = true);
+
 std::vector<ConvertFp32ToBf16Layer*> InsertConvertFp32ToBf16LayersAfter(Graph& graph, Layer& layer);
 
 std::vector<ConvertFp16ToFp32Layer*> InsertConvertFp16ToFp32LayersBefore(Graph& graph,
diff --git a/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp
index d6350c3..222414c 100644 (file)
@@ -4,68 +4,62 @@
 //
 #pragma once
 
-#include "Optimization.hpp"
 #include "NetworkUtils.hpp"
+#include "Optimization.hpp"
 
 namespace armnn
 {
 namespace optimizations
 {
 
+template <typename LayerT>
+inline LayerT* ConvertWeight(Layer* l)
+{
+    LayerT* layer = boost::polymorphic_downcast<LayerT*>(l);
+    if ((layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected)
+         && layer->m_Weight)
+    {
+        const TensorInfo& info = layer->m_Weight->GetTensorInfo();
+
+        if (info.GetDataType() == DataType::Float32)
+        {
+            std::vector<BFloat16> newValues(info.GetNumElements());
+
+            armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(layer->m_Weight->template GetTensor<float>(),
+                                                                         info.GetNumElements(),
+                                                                         newValues.data());
+
+            TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
+            ConstTensor newInput(newInfo, newValues);
+            layer->m_Weight.reset(new ScopedCpuTensorHandle(newInput));
+        }
+    }
+    return layer;
+}
+
 class ConvertFp32NetworkToBf16Impl
 {
 public:
+
     void Run(Graph& graph, Layer& layer) const
     {
-        if(layer.GetType() == LayerType::Input)
+        // Only convert the inputs of Convolution2d and FullyConnected layers from Float32 to BFloat16.
+        // Also convert their weight data from Float32 to BFloat16.
+        // Do not convert the bias data type.
+        if (layer.GetType() == LayerType::Convolution2d)
         {
-            // if the outputs of this layer are DataType::Float32
-            // add a ConvertFloat32ToBFloat16 layer after each of the outputs
             if (layer.GetDataType() == DataType::Float32)
             {
-                InsertConvertFp32ToBf16LayersAfter(graph, layer);
+                InsertConvertFp32ToBf16LayersBefore(graph,layer);
+                ConvertWeight<Convolution2dLayer>(&layer);
             }
         }
-        else if (layer.GetType() == LayerType::Output)
+        else if (layer.GetType() == LayerType::FullyConnected)
         {
-            // if the inputs of this layer are DataType::Float32
-            // add a ConvertBFloat16ToFloat32 layer before each of the inputs
             if (layer.GetDataType() == DataType::Float32)
             {
-                // NOTE: We need to call InsertConvertBf16ToFp32LayersBefore with expectCorrectInputType = false
-                // here, otherwise it will expect the inputs to be DataType::BFloat16
-                InsertConvertBf16ToFp32LayersBefore(graph, layer, false);
-            }
-        }
-        else if (layer.GetType() != LayerType::ConvertFp32ToBf16 && layer.GetType() != LayerType::ConvertBf16ToFp32)
-        {
-            // if the inputs/outputs of this layer are DataType::Float32
-            // change the data type for all inputs and outputs to DataType::BFloat16
-            for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input)
-            {
-                // if it is connected to OutputSlot of the InputLayer do not change the DataType of connection
-                // InputSlots of the current layer will be updated when conversion layer is inserted after InputLayer
-                Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer();
-                if (base.GetType() != LayerType::Input)
-                {
-                    TensorInfo convertInfo = input->GetConnection()->GetTensorInfo();
-                    if (convertInfo.GetDataType() == DataType::Float32)
-                    {
-                        convertInfo.SetDataType(DataType::BFloat16);
-                        input->GetConnection()->SetTensorInfo(convertInfo);
-                    }
-                }
-            }
-
-            // change outputs to DataType::BFloat16
-            for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
-            {
-                TensorInfo convertInfo = output->GetTensorInfo();
-                if (convertInfo.GetDataType() == DataType::Float32)
-                {
-                    convertInfo.SetDataType(DataType::BFloat16);
-                    output->SetTensorInfo(convertInfo);
-                }
+                InsertConvertFp32ToBf16LayersBefore(graph,layer);
+                ConvertWeight<FullyConnectedLayer>(&layer);
             }
         }
     }
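
The rewritten Run() is still applied through the existing Fp32NetworkToBf16Converter pass; a sketch of the call, assuming an armnn::Graph named graph, exactly as in the unit tests below:

    // After this pass, only Convolution2d/FullyConnected layers gain a preceding
    // ConvertFp32ToBf16Layer and have their constant weights rewritten to BFloat16;
    // every other layer, and every bias tensor, stays in Float32.
    armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(
        armnn::optimizations::Fp32NetworkToBf16Converter()));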
diff --git a/src/armnn/test/optimizations/Fp32NetworkToBf16ConverterTests.cpp b/src/armnn/test/optimizations/Fp32NetworkToBf16ConverterTests.cpp
index 90a1548..b35f983 100644 (file)
 BOOST_AUTO_TEST_SUITE(Optimizer)
 using namespace armnn::optimizations;
 
-BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationTest)
+BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationNoConversionTest)
 {
     armnn::Graph graph;
 
     const armnn::TensorInfo infoFP32({ 2, 2, 1, 3 }, armnn::DataType::Float32);
 
-    // Create the simple test network
+    // Create the simple test network without Conv2D/FullyConnected.
     auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
     input->GetOutputSlot().SetTensorInfo(infoFP32);
 
@@ -38,8 +38,148 @@ BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationTest)
     armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(Fp32NetworkToBf16Converter()));
 
     BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
-                             &IsLayerOfType<armnn::ConvertFp32ToBf16Layer>, &IsLayerOfType<armnn::FloorLayer>,
-                             &IsLayerOfType<armnn::ConvertBf16ToFp32Layer>, &IsLayerOfType<armnn::OutputLayer>));
+                             &IsLayerOfType<armnn::FloorLayer>,
+                             &IsLayerOfType<armnn::OutputLayer>));
+}
+
+BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationConv2DTest)
+{
+    armnn::Graph graph;
+
+    const armnn::TensorInfo infoFP32({ 2, 3, 8, 1 }, armnn::DataType::Float32);
+
+    // Create const tensor fp32 data
+    unsigned int dims[] = { 4, 2, 1, 1 };
+    std::vector<float> floatWeights{ 0.0f, -1.0f,
+                                     3.8f, // 0x40733333 Round down
+                                     3.1055E+29f, // 0x707ADC3C Round up
+                                     9.149516E-10f, // 0x307B7FFF Round down
+                                    -3.8f, // 0xC0733333 Round down
+                                    -3.1055E+29f, // 0xF07ADC3C Round up
+                                    -9.149516E-10f // 0xB07B7FFF Round down
+                                   };
+    armnn::ConstTensor weights(armnn::TensorInfo(4, dims, armnn::DataType::Float32), floatWeights);
+
+    // Create const bias fp32 data
+    unsigned int biasDims[] {4};
+    std::vector<float> floatBias{ 1.0f, 2.0f, 3.0f, 4.0f };
+    armnn::ConstTensor bias(armnn::TensorInfo(1, biasDims, armnn::DataType::Float32), floatBias);
+
+    // A network with Convolution2d layer
+    auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
+    input->GetOutputSlot().SetTensorInfo(infoFP32);
+
+    armnn::Convolution2dDescriptor descriptor;
+
+    auto conv = graph.AddLayer<armnn::Convolution2dLayer>(descriptor, "conv2d");
+    conv->m_Weight = std::make_unique<armnn::ScopedCpuTensorHandle>(weights);
+    conv->m_Bias = std::make_unique<armnn::ScopedCpuTensorHandle>(bias);
+    conv->GetOutputSlot().SetTensorInfo(infoFP32);
+
+    auto output = graph.AddLayer<armnn::OutputLayer>(1, "output");
+
+    // Connect up the layers
+    input->GetOutputSlot().Connect(conv->GetInputSlot(0));
+    conv->GetOutputSlot().Connect(output->GetInputSlot(0));
+
+    BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
+                             &IsLayerOfType<armnn::Convolution2dLayer>, &IsLayerOfType<armnn::OutputLayer>));
+
+    // Run the optimizer
+    armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(Fp32NetworkToBf16Converter()));
+
+    BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
+                             &IsLayerOfType<armnn::ConvertFp32ToBf16Layer>, &IsLayerOfType<armnn::Convolution2dLayer>,
+                             &IsLayerOfType<armnn::OutputLayer>));
+
+    armnn::TensorInfo inputTensor = conv->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+    armnn::TensorInfo outputTensor = conv->GetOutputSlot(0).GetTensorInfo();
+    BOOST_TEST((conv->GetDataType() == armnn::DataType::BFloat16));
+    BOOST_TEST((conv->m_Weight->GetTensorInfo().GetDataType() == armnn::DataType::BFloat16));
+    BOOST_TEST((conv->m_Bias->GetTensorInfo().GetDataType() == armnn::DataType::Float32));
+    BOOST_TEST((inputTensor.GetDataType() == armnn::DataType::BFloat16));
+    BOOST_TEST((outputTensor.GetDataType() == armnn::DataType::Float32));
+
+    // Check whether data matches expected Bf16 data
+    armnn::BFloat16* data = conv->m_Weight->GetTensor<armnn::BFloat16>();
+    BOOST_CHECK(data[0] == armnn::BFloat16(0.0f));
+    BOOST_CHECK(data[1] == armnn::BFloat16(-1.0f));
+    BOOST_CHECK(data[2] == armnn::BFloat16(3.796875f)); // 0x4073
+    BOOST_CHECK(data[3] == armnn::BFloat16(3.1072295E29f)); // 0x707B
+    BOOST_CHECK(data[4] == armnn::BFloat16(9.131327E-10f)); // 0x307B
+    BOOST_CHECK(data[5] == armnn::BFloat16(-3.796875f)); // 0xC073
+    BOOST_CHECK(data[6] == armnn::BFloat16(-3.1072295E29f)); // 0xF07B
+    BOOST_CHECK(data[7] == armnn::BFloat16(-9.131327E-10f)); // 0xB07B
+}
+
+BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationFullyConnectedTest)
+{
+    armnn::Graph graph;
+
+    const armnn::TensorInfo infoFP32({ 2, 3, 8, 1 }, armnn::DataType::Float32);
+
+    // Create const tensor fp32 data
+    unsigned int dims[] = { 4, 2, 1, 1 };
+    std::vector<float> floatWeights{ 0.0f, -1.0f,
+                                     3.8f, // 0x40733333 Round down
+                                     3.1055E+29f, // 0x707ADC3C Round up
+                                     9.149516E-10f, // 0x307B7FFF Round down
+                                    -3.8f, // 0xC0733333 Round down
+                                    -3.1055E+29f, // 0xF07ADC3C Round up
+                                    -9.149516E-10f // 0xB07B7FFF Round down
+                                   };
+    armnn::ConstTensor weights(armnn::TensorInfo(4, dims, armnn::DataType::Float32), floatWeights);
+
+    // Create const bias fp32 data
+    unsigned int biasDims[] {4};
+    std::vector<float> floatBias{ 1.0f, 2.0f, 3.0f, 4.0f };
+    armnn::ConstTensor bias(armnn::TensorInfo(1, biasDims, armnn::DataType::Float32), floatBias);
+
+    // A network with FullyConnected layer
+    auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
+    input->GetOutputSlot().SetTensorInfo(infoFP32);
+
+    armnn::FullyConnectedDescriptor descriptor;
+
+    auto fc = graph.AddLayer<armnn::FullyConnectedLayer>(descriptor, "fully");
+    fc->m_Weight = std::make_unique<armnn::ScopedCpuTensorHandle>(weights);
+    fc->m_Bias = std::make_unique<armnn::ScopedCpuTensorHandle>(bias);
+    fc->GetOutputSlot().SetTensorInfo(infoFP32);
+
+    auto output = graph.AddLayer<armnn::OutputLayer>(1, "output");
+
+    // Connect up the layers
+    input->GetOutputSlot().Connect(fc->GetInputSlot(0));
+    fc->GetOutputSlot().Connect(output->GetInputSlot(0));
+
+    BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
+                             &IsLayerOfType<armnn::FullyConnectedLayer>, &IsLayerOfType<armnn::OutputLayer>));
+
+    // Run the optimizer
+    armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(Fp32NetworkToBf16Converter()));
+
+    BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
+                             &IsLayerOfType<armnn::ConvertFp32ToBf16Layer>, &IsLayerOfType<armnn::FullyConnectedLayer>,
+                             &IsLayerOfType<armnn::OutputLayer>));
+
+    armnn::TensorInfo inputTensor = fc->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+    armnn::TensorInfo outputTensor = fc->GetOutputSlot(0).GetTensorInfo();
+    BOOST_TEST((fc->GetDataType() == armnn::DataType::BFloat16));
+    BOOST_TEST((fc->m_Weight->GetTensorInfo().GetDataType() == armnn::DataType::BFloat16));
+    BOOST_TEST((fc->m_Bias->GetTensorInfo().GetDataType() == armnn::DataType::Float32));
+    BOOST_TEST((inputTensor.GetDataType() == armnn::DataType::BFloat16));
+    BOOST_TEST((outputTensor.GetDataType() == armnn::DataType::Float32));
+
+    // Check whether data matches expected Bf16 data
+    armnn::BFloat16* data = fc->m_Weight->GetTensor<armnn::BFloat16>();
+    BOOST_CHECK(data[0] == armnn::BFloat16(0.0f));
+    BOOST_CHECK(data[1] == armnn::BFloat16(-1.0f));
+    BOOST_CHECK(data[2] == armnn::BFloat16(3.796875f)); // 0x4073
+    BOOST_CHECK(data[3] == armnn::BFloat16(3.1072295E29f)); // 0x707B
+    BOOST_CHECK(data[4] == armnn::BFloat16(9.131327E-10f)); // 0x307B
+    BOOST_CHECK(data[5] == armnn::BFloat16(-3.796875f)); // 0xC073
+    BOOST_CHECK(data[6] == armnn::BFloat16(-3.1072295E29f)); // 0xF07B
+    BOOST_CHECK(data[7] == armnn::BFloat16(-9.131327E-10f)); // 0xB07B
 }
 
 BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file
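
The expected weight values in both tests follow from rounding the FP32 bit pattern to the nearest BFloat16 value (none of the vectors is an exact tie, so ties-to-even and ties-away agree here). A standalone illustration of that rounding, reproducing the 0x4073 / 0x707B / 0x307B bit patterns noted in the comments above; this is a sketch of the rule, not ArmNN's FloatingPointConverter implementation:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Illustration only: round an IEEE-754 float to the nearest value representable
    // in the top 16 bits (bfloat16), ties to even. NaN/Inf handling is omitted.
    uint16_t Fp32ToBf16Bits(float value)
    {
        uint32_t bits;
        std::memcpy(&bits, &value, sizeof(bits));
        const uint32_t keptLsb = (bits >> 16) & 1u;        // lowest bit that survives truncation
        const uint32_t rounded = bits + 0x7FFFu + keptLsb; // round to nearest, ties to even
        return static_cast<uint16_t>(rounded >> 16);
    }

    int main()
    {
        std::printf("%#06x\n", Fp32ToBf16Bits(3.8f));          // 0x4073 -> 3.796875f (rounded down)
        std::printf("%#06x\n", Fp32ToBf16Bits(3.1055e+29f));   // 0x707b              (rounded up)
        std::printf("%#06x\n", Fp32ToBf16Bits(9.149516e-10f)); // 0x307b              (rounded down)
        return 0;
    }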
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 85c074a..f968ad7 100644 (file)
@@ -26,10 +26,9 @@ DataType GetBiasDataType(DataType inputDataType)
 {
     switch (inputDataType)
     {
-        case DataType::BFloat16:
-            return DataType::BFloat16;
         case DataType::Float16:
             return DataType::Float16;
+        case DataType::BFloat16:
         case DataType::Float32:
             return DataType::Float32;
         case DataType::QAsymmS8:
@@ -1009,7 +1008,20 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
     };
 
     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
-    ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+    // For FullyConnected, a BFloat16 input with a Float32 output is allowed as an optimization.
+    if (inputTensorInfo.GetDataType() == DataType::BFloat16)
+    {
+        if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
+        {
+            throw InvalidArgumentException(descriptorName  + ": " + " Output tensor type must be BFloat16 or Float32 "
+                                           "for BFloat16 input.");
+        }
+    }
+    else
+    {
+        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+    }
 }
 
 void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
@@ -1206,7 +1218,20 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
     };
 
     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
-    ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+    // For Convolution2d, a BFloat16 input with a Float32 output is allowed as an optimization.
+    if (inputTensorInfo.GetDataType() == DataType::BFloat16)
+    {
+        if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
+        {
+            throw InvalidArgumentException(descriptorName  + ": " + " Output tensor type must be BFloat16 or Float32 "
+                                           "for BFloat16 input.");
+        }
+    }
+    else
+    {
+        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+    }
 }
 
 void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
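
Both Validate() overloads above now apply the same relaxed rule; restated as a standalone predicate for clarity (a hypothetical helper, not part of the patch):

    #include <armnn/Types.hpp>

    // Hypothetical restatement of the relaxed rule: a BFloat16 input may pair with a
    // BFloat16 or Float32 output; any other input type must match the output exactly.
    bool InputOutputTypesCompatible(armnn::DataType input, armnn::DataType output)
    {
        if (input == armnn::DataType::BFloat16)
        {
            return output == armnn::DataType::BFloat16 || output == armnn::DataType::Float32;
        }
        return input == output;
    }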
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 551a7b5..7b25a43 100644 (file)
@@ -474,8 +474,20 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                   "Reference Convolution2d: output is not a supported type.");
 
-    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+    // For Convolution2d, a BFloat16 input with a Float32 output is allowed as an optimization.
+    if (input.GetDataType() == DataType::BFloat16)
+    {
+        if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
+        {
+            reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
+            supported = false;
+        }
+    }
+    else
+    {
+        supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                   "Reference Convolution2d: input and output types mismatched.");
+    }
 
     const DataType inputType = input.GetDataType();
     if (IsQuantized8BitType(inputType))
@@ -882,12 +894,24 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                   "Reference Fully Connected: output type not supported.");
 
-    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
-                                  "Reference Fully Connected: input and output types mismatched.");
-
     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                   "Reference Fully Connected: weights type not supported.");
 
+    // For FullyConnected, a BFloat16 input with a Float32 output is allowed as an optimization.
+    if (input.GetDataType() == DataType::BFloat16)
+    {
+        if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
+        {
+            reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
+            supported = false;
+        }
+    }
+    else
+    {
+        supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference Fully Connected: input and output types mismatched.");
+    }
+
     ARMNN_NO_DEPRECATE_WARN_BEGIN
     std::array<DataType, 3> supportedWeightTypes =
     {