arm_compute v18.02
[platform/upstream/armcl.git] / arm_compute / core / NEON / kernels / NEGEMMMatrixMultiplyKernel.h
index 9923c31..d54522c 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -39,6 +39,10 @@ class ITensor;
 class NEGEMMMatrixMultiplyKernel : public INEKernel
 {
 public:
+    const char *name() const override
+    {
+        return "NEGEMMMatrixMultiplyKernel";
+    }
     /** Constructor */
     NEGEMMMatrixMultiplyKernel();
     /** Prevent instances of this class from being copied (As this class contains pointers) */
@@ -54,22 +58,28 @@ public:
      * @note If the output tensor is a matrix, the input matrices @p input0 and @p input1 should be the output of the kernels: @ref NEGEMMInterleave4x4Kernel and @ref NEGEMMTranspose1xWKernel
      *       These two kernels change the layout of the original matrices to be more cache-friendly.
      *
-     * @param[in]  input0 Input tensor containing the interleaved Matrix A or the vector A. Data types supported: QS8/QS16/F16/F32
-     * @param[in]  input1 Input tensor containing the transposed Matrix B if the first input tensor A is not a vector.
-     *                    If the output tensor is a vector, input1 must contain the matrix B not reshaped. Data type supported: same as @p input0
-     * @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
-     * @param[in]  alpha  Weight of the matrix product
+     * @param[in]  input0         Input tensor containing the interleaved Matrix A or the vector A. Data types supported: QS8/QS16/F16/F32
+     * @param[in]  input1         Input tensor containing the transposed Matrix B if the first input tensor A is not a vector.
+     *                            If the output tensor is a vector, input1 must contain the matrix B not reshaped. Data type supported: same as @p input0
+     * @param[out] output         Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+     * @param[in]  alpha          Weight of the matrix product
+     * @param[in]  is_interleaved (Optional) True if input0 and input1 have been reshaped respectively using @ref NEGEMMInterleave4x4Kernel and @ref NEGEMMTranspose1xWKernel
+     * @param[in]  reshape_info   (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
      */
-    void configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha);
+    void configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref NEGEMMMatrixMultiplyKernel
      *
-     * @param[in] input0 Input tensor containing the Matrix A. Data types supported: QS8/QS16/F16/F32
-     * @param[in] input1 Input tensor containing the Matrix B. Data type supported: same as @p input0
-     * @param[in] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
+     * @param[in] input0         Input tensor containing the interleaved Matrix A or the vector A. Data types supported: QS8/QS16/F16/F32
+     * @param[in] input1         Input tensor containing the transposed Matrix B if the first input tensor A is not a vector.
+     *                           If the output tensor is a vector, input1 must contain the matrix B not reshaped. Data type supported: same as @p input0
+     * @param[in] output         Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+     * @param[in] alpha          Weight of the matrix product
+     * @param[in] is_interleaved (Optional) True if input0 and input1 have been reshaped respectively using @ref NEGEMMInterleave4x4Kernel and @ref NEGEMMTranspose1xWKernel
+     * @param[in] reshape_info   (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output);
+    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info);
 
     // Inherited methods overridden:
     void run(const Window &window, const ThreadInfo &info) override;