Imported Upstream version 1.8.0
diff --git a/compute/cker/include/cker/operation/FullyConnected.h b/compute/cker/include/cker/operation/FullyConnected.h
index 9bcf3fd..4280c9a 100644
--- a/compute/cker/include/cker/operation/FullyConnected.h
+++ b/compute/cker/include/cker/operation/FullyConnected.h
@@ -18,6 +18,7 @@
 #ifndef __NNFW_CKER_FULLY_CONNECTED_H__
 #define __NNFW_CKER_FULLY_CONNECTED_H__
 
+#include <ruy/context.h>
 #include "cker/Shape.h"
 #include "cker/Types.h"
 #include "cker/Utils.h"
@@ -78,8 +79,11 @@ inline void FullyConnected(const FullyConnectedParams &params, const Shape &inpu
   MatrixBatchVectorMultiplyAccumulate(weights_data, num_units, input_size, input_data, batch_size,
                                       output_data, /*result_stride=*/1);
 
-  // Apply activation function
-  ApplyActivationToVector(output_data, batch_size * num_units, params.activation, output_data);
+  if (params.activation != FusedActivationFunctionType::kNone)
+  {
+    // Apply activation function
+    ApplyActivationToVector(output_data, batch_size * num_units, params.activation, output_data);
+  }
 }
 
 inline void FullyConnected(const FullyConnectedParams &params, const Shape &input_shape,
@@ -140,7 +144,7 @@ inline void FullyConnectedHybrid(const FullyConnectedParams &params, const Shape
                                  const float *input_data, const Shape &filter_shape,
                                  const int8_t *filter_data, const Shape &, const float *bias_data,
                                  const Shape &output_shape, float *output_data,
-                                 FCTempArena &temp_arena)
+                                 FCTempArena &temp_arena, ruy::Context *ruy_context)
 {
   int total_input_size = input_shape.FlatSize();
   const int input_size = filter_shape.Dims(1);
@@ -186,19 +190,72 @@ inline void FullyConnectedHybrid(const FullyConnectedParams &params, const Shape
   int32_t *scratch = temp_arena.accum_scratch.data();
   MatrixBatchVectorMultiplyAccumulate(filter_data, num_units, input_size, quant_data,
                                       scaling_factors_ptr, batch_size, scratch, output_data,
-                                      /*result_stride=*/1);
+                                      /*result_stride=*/1, ruy_context);
 #else
   MatrixBatchVectorMultiplyAccumulate(filter_data, num_units, input_size, quant_data,
                                       scaling_factors_ptr, batch_size, output_data,
                                       /*result_stride=*/1);
+  UNUSED_RELEASE(ruy_context);
   UNUSED_RELEASE(output_shape);
 #endif
 
   // Apply activation function to floats.
-  ApplyActivationToVector(output_data, batch_size * num_units, params.activation, output_data);
+  if (params.activation != FusedActivationFunctionType::kNone)
+  {
+    // Apply activation function
+    ApplyActivationToVector(output_data, batch_size * num_units, params.activation, output_data);
+  }
   return;
 }
 
+inline void FullyConnectedSparseWeight(const FullyConnectedParams &params, const Shape &input_shape,
+                                       const float *input_data, const Shape &weights_shape,
+                                       const float *weights_data, const Shape &bias_shape,
+                                       const float *bias_data, const Shape &output_shape,
+                                       float *output_data, int w0_size, const uint16_t *w1_segments,
+                                       const uint16_t *w1_indices)
+{
+  UNUSED_RELEASE(params);
+  UNUSED_RELEASE(input_shape);
+
+  assert(weights_shape.DimensionsCount() == 2);
+  assert(output_shape.DimensionsCount() == 2);
+
+  const int output_dims_count = output_shape.DimensionsCount();
+  const int weights_dims_count = weights_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
+  const int output_depth =
+      MatchingDim(weights_shape, weights_dims_count - 2, output_shape, output_dims_count - 1);
+  const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
+
+  UNUSED_RELEASE(bias_shape);
+  if (bias_data)
+  {
+    VectorBatchVectorAssign(bias_data, output_depth, batches, output_data);
+  }
+  else
+  {
+    ZeroVector(output_data, batches * output_depth);
+  }
+  for (int b = 0; b < batches; ++b)
+  {
+    for (int idx_0 = 0; idx_0 < w0_size; ++idx_0)
+    {
+      for (int pw1 = w1_segments[idx_0]; pw1 < w1_segments[idx_0 + 1]; ++pw1)
+      {
+        int idx_1 = w1_indices[pw1];
+        output_data[b * output_depth + idx_0] +=
+            weights_data[pw1] * input_data[b * accum_depth + idx_1];
+      }
+    }
+  }
+  if (params.activation != FusedActivationFunctionType::kNone)
+  {
+    // Apply activation function
+    ApplyActivationToVector(output_data, batches * output_depth, params.activation, output_data);
+  }
+}
+
 } // namespace cker
 } // namespace nnfw
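
Note: the FullyConnectedSparseWeight added above assumes a CSR-like weight layout: for output row i, the range w1_segments[i] .. w1_segments[i+1] delimits that row's nonzero weights, and w1_indices[p] gives the input column of weights_data[p]. The standalone sketch below (hypothetical sizes and data, not the cker API) reproduces only that traversal pattern for a single batch, to illustrate how the two index arrays drive the accumulation.

// Illustrative sketch of the CSR-style accumulation used by
// FullyConnectedSparseWeight. Names and values are made up for the example;
// only the loop structure mirrors the code in the diff.
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
  const int batches = 1, output_depth = 2, accum_depth = 3;
  // Dense equivalent of the sparse weights: [[1 0 2],
  //                                          [0 3 0]]
  std::vector<float> weights_data{1.f, 2.f, 3.f};  // nonzeros, row by row
  std::vector<uint16_t> w1_segments{0, 2, 3};      // row i spans [seg[i], seg[i+1])
  std::vector<uint16_t> w1_indices{0, 2, 1};       // input column of each nonzero
  std::vector<float> input_data{0.5f, 1.f, 2.f};
  std::vector<float> output_data(batches * output_depth, 0.f); // bias omitted

  for (int b = 0; b < batches; ++b)
    for (int i = 0; i < output_depth; ++i)
      for (int p = w1_segments[i]; p < w1_segments[i + 1]; ++p)
        output_data[b * output_depth + i] +=
            weights_data[p] * input_data[b * accum_depth + w1_indices[p]];

  std::printf("%f %f\n", output_data[0], output_data[1]); // prints 4.5 3.0
}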