arm_compute v18.02
[platform/upstream/armcl.git] / arm_compute / core / CL / kernels / CLGEMMMatrixMultiplyKernel.h
index 4e73d7e..7260c4a 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -58,8 +58,10 @@ public:
      * @param[out] output                    Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
      * @param[in]  alpha                     Weight of the matrix product
      * @param[in]  is_interleaved_transposed (Optional) True if input0 and input1 have been reshaped respectively using @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel
+     * @param[in]  reshape_info              (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
+     *
      */
-    void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed = true);
+    void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed = true, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyKernel
      *
      * @param[in] input0                    Input tensor containing the Matrix A. Data types supported: QS8/QS16/F16/F32
@@ -67,11 +69,13 @@ public:
      * @param[in] output                    Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
      * @param[in] alpha                     Weight of the matrix product
      * @param[in] is_interleaved_transposed True if input0 and input1 have been reshaped respectively using @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel
+     * @param[in] reshape_info              GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
      * @param[in] gpu_target                GPU Target
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed, GPUTarget gpu_target);
+    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
+                           GPUTarget gpu_target);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;