fp16 ocl support for googlenet
authorLi Peng <peng.li@intel.com>
Thu, 26 Apr 2018 11:22:29 +0000 (19:22 +0800)
committerLi Peng <peng.li@intel.com>
Wed, 16 May 2018 14:45:02 +0000 (22:45 +0800)
Signed-off-by: Li Peng <peng.li@intel.com>
32 files changed:
modules/dnn/src/layers/concat_layer.cpp
modules/dnn/src/layers/convolution_layer.cpp
modules/dnn/src/layers/elementwise_layers.cpp
modules/dnn/src/layers/eltwise_layer.cpp
modules/dnn/src/layers/flatten_layer.cpp
modules/dnn/src/layers/fully_connected_layer.cpp
modules/dnn/src/layers/lrn_layer.cpp
modules/dnn/src/layers/mvn_layer.cpp
modules/dnn/src/layers/pooling_layer.cpp
modules/dnn/src/layers/slice_layer.cpp
modules/dnn/src/layers/softmax_layer.cpp
modules/dnn/src/ocl4dnn/include/ocl4dnn.hpp
modules/dnn/src/ocl4dnn/src/math_functions.cpp
modules/dnn/src/ocl4dnn/src/ocl4dnn_conv_spatial.cpp
modules/dnn/src/ocl4dnn/src/ocl4dnn_inner_product.cpp
modules/dnn/src/ocl4dnn/src/ocl4dnn_lrn.cpp
modules/dnn/src/ocl4dnn/src/ocl4dnn_pool.cpp
modules/dnn/src/ocl4dnn/src/ocl4dnn_softmax.cpp
modules/dnn/src/opencl/activations.cl
modules/dnn/src/opencl/concat.cl
modules/dnn/src/opencl/conv_layer_spatial.cl
modules/dnn/src/opencl/eltwise.cl
modules/dnn/src/opencl/gemm_buffer.cl [new file with mode: 0644]
modules/dnn/src/opencl/gemm_image.cl
modules/dnn/src/opencl/math.cl
modules/dnn/src/opencl/matvec_mul.cl
modules/dnn/src/opencl/mvn.cl
modules/dnn/src/opencl/ocl4dnn_lrn.cl
modules/dnn/src/opencl/ocl4dnn_pooling.cl
modules/dnn/src/opencl/slice.cl
modules/dnn/src/opencl/softmax.cl
modules/dnn/src/opencl/softmax_loss.cl

index 172d0a0..a72b282 100644 (file)
@@ -128,14 +128,14 @@ public:
             for( i = 0; i < ninputs; i++ )
             {
                 Mat& inp = *inputs[i];
-                CV_Assert( inp.isContinuous() && inp.type() == CV_32F &&
+                CV_Assert( inp.isContinuous() && (inp.type() == CV_32F || inp.type() == CV_16S) &&
                            inp.dims == 4 && inp.size[0] == output.size[0] &&
                            inp.size[2] == output.size[2] &&
                            inp.size[3] == output.size[3] );
                 nchannels += inp.size[1];
             }
             CV_Assert( nchannels == output.size[1] );
-            CV_Assert( output.isContinuous() && output.type() == CV_32F );
+            CV_Assert( output.isContinuous() && (output.type() == CV_32F || output.type() == CV_16S) );
 
             cc.chptrs.resize(nchannels*batchsz);
 
@@ -186,6 +186,7 @@ public:
         std::vector<UMat> inputs;
         std::vector<UMat> outputs;
 
+        bool use_half = (inps.depth() == CV_16S);
         inps.getUMatVector(inputs);
         outs.getUMatVector(outputs);
 
@@ -199,11 +200,12 @@ public:
         int num_concats = total(shape(inputs[0]), 0, cAxis);
         int offset_concat_axis = 0;
         UMat& outMat = outputs[0];
-        String buildopt = String("-DDtype=") + ocl::typeToStr(inputs[0].type()) + String(" ");
+        String buildopt = format(" -DDtype=%s", (use_half) ? "half" : "float");
+        String kname = format("concat_%s", use_half ? "half" : "float");
 
         for (size_t i = 0; i < inputs.size(); i++)
         {
-            ocl::Kernel kernel("concat", ocl::dnn::concat_oclsrc, buildopt);
+            ocl::Kernel kernel(kname.c_str(), ocl::dnn::concat_oclsrc, buildopt);
             if (kernel.empty())
                 return false;
 
@@ -235,7 +237,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
index 2bb96f9..96a9d5b 100644 (file)
@@ -94,7 +94,7 @@ public:
         CV_Assert(blobs[0].dims == 4 && blobs[0].size[3] == kernel.width && blobs[0].size[2] == kernel.height);
 
         const Mat &input = *inputs[0];
-        CV_Assert(input.dims == 4 && (input.type() == CV_32F || input.type() == CV_64F));
+        CV_Assert(input.dims == 4 && (input.type() == CV_32F || input.type() == CV_64F || input.type() == CV_16S));
         for (size_t i = 0; i < inputs.size(); i++)
         {
             CV_Assert(inputs[i]->type() == input.type());
@@ -288,7 +288,7 @@ public:
         newActiv = true;
         activType = OCL4DNN_CONV_FUSED_ACTIV_NONE;
 
-        if (preferableTarget == DNN_TARGET_OPENCL)
+        if (IS_DNN_OPENCL_TARGET(preferableTarget))
         {
             Ptr<PowerLayer> activ_power = activ.dynamicCast<PowerLayer>();
             if (!activ_power.empty())
@@ -842,6 +842,7 @@ public:
         std::vector<UMat> inputs;
         std::vector<UMat> outputs;
 
+        bool use_half = (inps.depth() == CV_16S);
         inps.getUMatVector(inputs);
         outs.getUMatVector(outputs);
 
@@ -860,6 +861,7 @@ public:
             config.dilation = dilation;
             config.group = inputs[0].size[1] / umat_blobs[0].size[1];
             config.bias_term = (hasBias()) ? true : false;
+            config.use_half = use_half;
 
             convolutionOp = Ptr<OCL4DNNConvSpatial<float> >(new OCL4DNNConvSpatial<float>(config));
         }
@@ -964,7 +966,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
@@ -1360,6 +1362,9 @@ public:
         std::vector<UMat> outputs;
         std::vector<UMat> internals;
 
+        if (inputs_.depth() == CV_16S)
+            return false;
+
         inputs_.getUMatVector(inputs);
         outputs_.getUMatVector(outputs);
         internals_.getUMatVector(internals);
@@ -1450,7 +1455,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
index 8600967..a24b913 100644 (file)
@@ -176,7 +176,7 @@ public:
     {
         CV_TRACE_FUNCTION();
 
-        CV_OCL_RUN((this->preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(this->preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    func.applyOCL(inputs_arr, outputs_arr, internals_arr))
 
@@ -223,7 +223,12 @@ public:
 #ifdef HAVE_OPENCL
 static String oclGetTMacro(const UMat &m)
 {
-    return String("-DT=") + ocl::typeToStr(m.type()) + String(" ");
+    String str_name = ocl::typeToStr(m.type());
+
+    if (str_name == "short")
+        str_name = "half";
+
+    return format("-DT=%s -Dconvert_T=convert_%s ", str_name.c_str(), str_name.c_str());
 }
 #endif
 
@@ -516,8 +521,28 @@ struct SigmoidFunctor
 #ifdef HAVE_OPENCL
     bool applyOCL(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
     {
-        // TODO: implement OCL version
-        return false;
+        std::vector<UMat> inputs;
+        std::vector<UMat> outputs;
+
+        inps.getUMatVector(inputs);
+        outs.getUMatVector(outputs);
+        String buildopt = oclGetTMacro(inputs[0]);
+
+        for (size_t i = 0; i < inputs.size(); i++)
+        {
+            UMat& src = inputs[i];
+            UMat& dst = outputs[i];
+
+            ocl::Kernel kernel("SigmoidForward", ocl::dnn::activations_oclsrc, buildopt);
+            kernel.set(0, (int)src.total());
+            kernel.set(1, ocl::KernelArg::PtrReadOnly(src));
+            kernel.set(2, ocl::KernelArg::PtrWriteOnly(dst));
+
+            size_t gSize = src.total();
+            CV_Assert(kernel.run(1, &gSize, NULL, false));
+        }
+
+        return true;
     }
 #endif
 
@@ -561,8 +586,28 @@ struct ELUFunctor
 #ifdef HAVE_OPENCL
     bool applyOCL(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
     {
-        // TODO: implement OCL version
-        return false;
+        std::vector<UMat> inputs;
+        std::vector<UMat> outputs;
+
+        inps.getUMatVector(inputs);
+        outs.getUMatVector(outputs);
+        String buildopt = oclGetTMacro(inputs[0]);
+
+        for (size_t i = 0; i < inputs.size(); i++)
+        {
+            UMat& src = inputs[i];
+            UMat& dst = outputs[i];
+
+            ocl::Kernel kernel("ELUForward", ocl::dnn::activations_oclsrc, buildopt);
+            kernel.set(0, (int)src.total());
+            kernel.set(1, ocl::KernelArg::PtrReadOnly(src));
+            kernel.set(2, ocl::KernelArg::PtrWriteOnly(dst));
+
+            size_t gSize = src.total();
+            CV_Assert(kernel.run(1, &gSize, NULL, false));
+        }
+
+        return true;
     }
 #endif
 
@@ -604,8 +649,28 @@ struct AbsValFunctor
 #ifdef HAVE_OPENCL
     bool applyOCL(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
     {
-        // TODO: implement OCL version
-        return false;
+        std::vector<UMat> inputs;
+        std::vector<UMat> outputs;
+
+        inps.getUMatVector(inputs);
+        outs.getUMatVector(outputs);
+        String buildopt = oclGetTMacro(inputs[0]);
+
+        for (size_t i = 0; i < inputs.size(); i++)
+        {
+            UMat& src = inputs[i];
+            UMat& dst = outputs[i];
+
+            ocl::Kernel kernel("AbsValForward", ocl::dnn::activations_oclsrc, buildopt);
+            kernel.set(0, (int)src.total());
+            kernel.set(1, ocl::KernelArg::PtrReadOnly(src));
+            kernel.set(2, ocl::KernelArg::PtrWriteOnly(dst));
+
+            size_t gSize = src.total();
+            CV_Assert(kernel.run(1, &gSize, NULL, false));
+        }
+
+        return true;
     }
 #endif
 
index 58a651e..39961ab 100644 (file)
@@ -271,6 +271,9 @@ public:
         std::vector<UMat> inputs;
         std::vector<UMat> outputs;
 
+        if (inputs_.depth() == CV_16S && op != SUM)
+            return false;
+
         inputs_.getUMatVector(inputs);
         outputs_.getUMatVector(outputs);
 
@@ -284,10 +287,15 @@ public:
                     {
                         size_t localsize[] = { 128 };
                         size_t globalsize[] = { (size_t)channels / 4 * localsize[0] };
+                        String opts;
+                        if (inputs_.depth() == CV_16S)
+                            opts = " -DDtype=half -DDtype4=half4 -DDtype8=half8";
+                        else
+                            opts = " -DDtype=float -DDtype4=float4 -DDtype8=float8";
 
                         for (int i = 0; i < (inputs.size() - 1); ++i)
                         {
-                            String buildopt = format("-DLOOP=%d", i);
+                            String buildopt = format("-DLOOP=%d", i) + opts;
                             ocl::Kernel kernel("op_sum4", ocl::dnn::eltwise_oclsrc, buildopt);
                             int idx = 0;
                             UMat inpMat = (i == 0) ? inputs[0] : UMat();
@@ -306,6 +314,9 @@ public:
                     }
                     else
                     {
+                        if (inputs_.depth() == CV_16S)
+                            return false;
+
                         float coeff1 = coeffs.empty() ? 1.f : coeffs[0];
                         float coeff2 = coeffs.empty() ? 1.f : coeffs[1];
                         UMat mul0, mul1;
@@ -343,7 +354,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
index 1df1681..f737ac2 100644 (file)
@@ -140,7 +140,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    outputs_arr.isUMatVector() &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
index 9ee7e98..d459e65 100644 (file)
@@ -64,6 +64,7 @@ public:
 #ifdef HAVE_OPENCL
     Ptr<OCL4DNNInnerProduct<float> > innerProductOp;
     std::vector<UMat> umat_blobs;
+    std::vector<UMat> half_blobs;
 #endif
 
     FullyConnectedLayerImpl(const LayerParams& params)
@@ -277,6 +278,7 @@ public:
         std::vector<UMat> inputs;
         std::vector<UMat> outputs;
 
+        bool use_half = (inps.depth() == CV_16S);
         inps.getUMatVector(inputs);
         outs.getUMatVector(outputs);
 
@@ -293,6 +295,17 @@ public:
             config.bias_term = bias;
             config.M = outerSize;
             config.K = innerSize;
+            config.use_half = use_half;
+
+            if (use_half)
+            {
+                half_blobs.resize(umat_blobs.size());
+                for (int i = 0; i < umat_blobs.size(); i++)
+                {
+                    if (!umat_blobs[i].empty())
+                        convertFp16(umat_blobs[i], half_blobs[i]);
+                }
+            }
 
             innerProductOp = Ptr<OCL4DNNInnerProduct<float> >(new OCL4DNNInnerProduct<float>(config));
         }
@@ -309,13 +322,15 @@ public:
             dstMat = outputs[i].reshape(1, outshape.size(), &outshape[0]);
             dstMat.setTo(0.0f);
 
-            if (!innerProductOp->Forward(srcMat, umat_blobs[0], (bias) ? umat_blobs[1] : UMat(), dstMat))
+            if (!innerProductOp->Forward(srcMat, (use_half) ? half_blobs[0] : umat_blobs[0],
+                                         (bias) ? (use_half ? half_blobs[1] : umat_blobs[1]) : UMat(),
+                                         dstMat))
             {
                 ret = false;
                 break;
             }
 
-            if (bias && (outerSize > 1))
+            if (!use_half && bias && (outerSize > 1))
             {
                 UMat& biases = umat_blobs[1];
                 cv::gemm(biasOnesMat, biases, 1, dstMat, 1, dstMat, 0);
@@ -353,7 +368,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
index 25eb154..1b2a902 100644 (file)
@@ -106,6 +106,7 @@ public:
         std::vector<UMat> inputs;
         std::vector<UMat> outputs;
 
+        bool use_half = (inps.depth() == CV_16S);
         inps.getUMatVector(inputs);
         outs.getUMatVector(outputs);
 
@@ -128,6 +129,7 @@ public:
             config.height = inputs[0].size[2];
             config.width = inputs[0].size[3];
             config.norm_by_size = normBySize;
+            config.use_half = use_half;
 
             lrnOp = Ptr<OCL4DNNLRN<float> >(new OCL4DNNLRN<float>(config));
         }
@@ -146,7 +148,7 @@ public:
 
         CV_Assert(inputs_arr.total() == outputs_arr.total());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
index f948c71..647308a 100644 (file)
@@ -102,6 +102,9 @@ public:
     {
         UMat bnorm_weight = scale.empty() ? UMat() : scale.getUMat(ACCESS_READ);
         UMat bnorm_bias = shift.empty() ? UMat() : shift.getUMat(ACCESS_READ);
+        bool use_half = (inputs[0].depth() == CV_16S);
+        String opts = format(" -DT=%s -DT4=%s -Dconvert_T=%s", use_half ? "half" : "float",
+                             use_half ? "half4" : "float4", use_half ? "convert_half4" : "convert_float4");
 
         int splitDim = (acrossChannels) ? 1 : 2;
         for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
@@ -111,12 +114,11 @@ public:
             int newRows = total(shape(inpMat), 0, splitDim);
 
             MatShape s = shape(newRows, inpMat.total() / newRows);
-            UMat oneMat = UMat::ones(s[1], 1, CV_32F);
-            UMat meanMat = UMat(s[0], 1, CV_32F);
+            UMat meanMat = UMat(s[0], 1, (use_half) ? CV_16S : CV_32F);
             UMat tmpMat  = UMat(s[0], s[1], CV_32F);
             float alpha = 1.0f / s[1];
 
-            String buildopt = "-DNUM=4";
+            String buildopt = "-DNUM=4" + opts;
             ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt);
             size_t localsize[] = { 128 };
             size_t globalsize[] = { (size_t)s[0] / 4 * localsize[0] };
@@ -167,13 +169,14 @@ public:
         int row_size = total(shape(inputs[0]), 0, splitDim);
         int plane_size = total(shape(inputs[0]), splitDim);
         if (normVariance && (row_size % 4 == 0) && (plane_size % 4 == 0))
-        {
-            bool ret = fast_forward_ocl(inputs, outputs);
-            return ret;
-        }
+            return fast_forward_ocl(inputs, outputs);
+
+        if (inputs[0].depth() == CV_16S)
+            return false;
 
         UMat bnorm_weight = scale.empty() ? UMat() : scale.getUMat(ACCESS_READ);
         UMat bnorm_bias = shift.empty() ? UMat() : shift.getUMat(ACCESS_READ);
+        String opts = format(" -DT=float -DT4=float4 -Dconvert_T=convert_float4");
 
         for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
         {
@@ -195,7 +198,7 @@ public:
 
             int number = (s[1] % 8 == 0) ? 8 : ((s[1] % 4 == 0) ? 4 : 1);
             size_t global[] = { (size_t)s[0], (size_t)(s[1] / number) };
-            String buildopt = format("-DNUM=%d", number);
+            String buildopt = format("-DNUM=%d", number) + opts;
             if (normVariance)
             {
                 String kname = format("calc_mean%d", number);
@@ -249,7 +252,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
index bee9d5d..2bcce1d 100644 (file)
@@ -147,6 +147,7 @@ public:
         std::vector<UMat> inputs;
         std::vector<UMat> outputs;
 
+        bool use_half = (inps.depth() == CV_16S);
         inps.getUMatVector(inputs);
         outs.getUMatVector(outputs);
 
@@ -164,6 +165,7 @@ public:
                                 (type == AVE ? LIBDNN_POOLING_METHOD_AVE :
                                                LIBDNN_POOLING_METHOD_STO);
             config.avePoolPaddedArea = avePoolPaddedArea;
+            config.use_half = use_half;
             poolOp = Ptr<OCL4DNNPool<float> >(new OCL4DNNPool<float>(config));
         }
 
@@ -189,7 +191,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
index 826c640..4b3a975 100644 (file)
@@ -181,6 +181,7 @@ public:
         std::vector<UMat> inputs;
         std::vector<UMat> outputs;
 
+        bool use_half = (inputs_.depth() == CV_16S);
         inputs_.getUMatVector(inputs);
         outputs_.getUMatVector(outputs);
 
@@ -188,6 +189,11 @@ public:
             (total(shape(outputs[0]), 2) % 4 != 0))
             return false;
 
+        String opts;
+        if (use_half)
+            opts = "-DDtype=half -DDtype4=half4 -DDtype8=half8";
+        else
+            opts = "-DDtype=float -DDtype4=float4 -DDtype8=float8";
         const UMat& inpMat = inputs[0];
         for (size_t i = 0; i < outputs.size(); i++)
         {
@@ -196,7 +202,7 @@ public:
             int rows = outputs[i].size[2];
             int cols = outputs[i].size[3];
 
-            ocl::Kernel kernel("slice", ocl::dnn::slice_oclsrc);
+            ocl::Kernel kernel("slice", ocl::dnn::slice_oclsrc, opts);
             size_t local[] = { 128 };
             size_t global[] = { (size_t)groups * channels / 4 * local[0] };
             int idx = 0;
@@ -222,7 +228,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
index 1712153..c26028e 100644 (file)
@@ -99,15 +99,16 @@ public:
         softmaxOp.release();
     }
 
-    bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays itns)
+    bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_)
     {
         std::vector<UMat> inputs;
         std::vector<UMat> outputs;
         std::vector<UMat> internals;
 
-        inps.getUMatVector(inputs);
-        outs.getUMatVector(outputs);
-        itns.getUMatVector(internals);
+        bool use_half = (inputs_.depth() == CV_16S);
+        inputs_.getUMatVector(inputs);
+        outputs_.getUMatVector(outputs);
+        internals_.getUMatVector(internals);
 
         if (softmaxOp.empty())
         {
@@ -117,6 +118,7 @@ public:
             config.axis = axisRaw;
             config.channels = inputs[0].size[axisRaw];
             config.logsoftmax = logSoftMax;
+            config.use_half = use_half;
 
             softmaxOp = Ptr<OCL4DNNSoftmax<float> >(new OCL4DNNSoftmax<float>(config));
         }
@@ -128,15 +130,13 @@ public:
             return true;
 
         UMat& bufMat = internals[0];
-        src.copyTo(dstMat);
-
         int axis = clamp(axisRaw, src.dims);
         MatShape s = shape(src);
         size_t outerSize = total(s, 0, axis);
         size_t channels = src.size[axis];
         size_t innerSize = total(s, axis + 1);
 
-        String buildOpts = String("-DT=") + ocl::typeToStr(src.type());
+        String buildOpts = format("-DT=%s", use_half ? "half" : "float");
         ocl::Kernel kmax, ksub, ksum, kdiv;
 
         if (!kmax.create("kernel_channel_max", ocl::dnn::softmax_oclsrc, buildOpts))
@@ -152,38 +152,31 @@ public:
         if (!kdiv.create("kernel_channel_div", ocl::dnn::softmax_oclsrc, buildOpts))
             return false;
 
-        size_t wgSize = ocl::Device::getDefault().maxWorkGroupSize();
         size_t bufSize = internals[0].total();
         size_t totalSize = src.total();
 
-        // adjust local/global size
-        size_t internal_localSize[1] = { (bufSize == 1) ? 1 : wgSize };
-        size_t internal_globalSize[1] = { divUp(bufSize, (unsigned int)internal_localSize[0]) * internal_localSize[0] };
-
-        // adjust local/global size (total)
-        size_t total_localSize[1] = { (totalSize == 1) ? 1 : wgSize };
-        size_t total_globalSize[1] = { divUp(totalSize, (unsigned int)total_localSize[0]) * total_localSize[0] };
+        size_t internal_globalSize[1] = { bufSize };
+        size_t total_globalSize[1] = { totalSize };
 
         kmax.args((int)outerSize, (int)channels, (int)innerSize,
-                  ocl::KernelArg::PtrReadOnly(dstMat), ocl::KernelArg::PtrReadWrite(bufMat));
-        if (!kmax.run(1, internal_globalSize, internal_localSize, false))
+                  ocl::KernelArg::PtrReadOnly(src), ocl::KernelArg::PtrReadWrite(bufMat));
+        if (!kmax.run(1, internal_globalSize, NULL, false))
             return false;
 
         ksub.args((int)totalSize, (int)outerSize, (int)channels, (int)innerSize,
-                  ocl::KernelArg::PtrReadOnly(bufMat), ocl::KernelArg::PtrReadWrite(dstMat));
-        if (!ksub.run(1, total_globalSize, total_localSize, false))
+                  ocl::KernelArg::PtrReadOnly(bufMat),
+                  ocl::KernelArg::PtrReadOnly(src), ocl::KernelArg::PtrWriteOnly(dstMat));
+        if (!ksub.run(1, total_globalSize, NULL, false))
             return false;
 
-        cv::exp(dstMat, dstMat);
-
         ksum.args((int)outerSize, (int)channels, (int)innerSize,
                   ocl::KernelArg::PtrReadOnly(dstMat), ocl::KernelArg::PtrReadWrite(bufMat));
-        if (!ksum.run(1, internal_globalSize, internal_localSize, false))
+        if (!ksum.run(1, internal_globalSize, NULL, false))
             return false;
 
         kdiv.args((int)totalSize, (int)outerSize, (int)channels, (int)innerSize,
                   ocl::KernelArg::PtrReadOnly(bufMat), ocl::KernelArg::PtrReadWrite(dstMat));
-        if (!kdiv.run(1, total_globalSize, total_localSize, false))
+        if (!kdiv.run(1, total_globalSize, NULL, false))
             return false;
 
         return true;
@@ -195,7 +188,7 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
-        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) &&
                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
 
index 6eb60ee..e0ce77e 100644 (file)
@@ -59,7 +59,8 @@ struct OCL4DNNConvConfig
         stride(1, 1),
         dilation(1, 1),
         group(1),
-        bias_term(false)
+        bias_term(false),
+        use_half(false)
     {}
     MatShape in_shape;
     MatShape out_shape;
@@ -69,6 +70,7 @@ struct OCL4DNNConvConfig
     Size dilation;
     int group; // = 1;
     bool bias_term; // = false;
+    bool use_half; // = false;
 };
 
 typedef enum {
@@ -272,6 +274,8 @@ class OCL4DNNConvSpatial
         int32_t group_;
         bool bias_term_;
         UMat swizzled_weights_umat;
+        UMat weights_half;
+        UMat bias_half;
         UMat bottom_data2_;
 
         int32_t bottom_index_;
@@ -327,6 +331,7 @@ class OCL4DNNConvSpatial
         ocl4dnnFusedActiv_t fused_activ_;
         float power_;
         bool fused_eltwise_;
+        bool use_half_;
 };
 
 typedef enum {
@@ -345,7 +350,8 @@ struct OCL4DNNPoolConfig
         channels(0),
         pool_method(LIBDNN_POOLING_METHOD_MAX),
         global_pooling(false),
-        avePoolPaddedArea(false)
+        avePoolPaddedArea(true),
+        use_half(false)
     {}
     MatShape in_shape;
     MatShape out_shape;
@@ -358,6 +364,7 @@ struct OCL4DNNPoolConfig
     ocl4dnnPoolingMethod_t pool_method; // = LIBDNN_POOLING_METHOD_MAX;
     bool global_pooling; // = false;
     bool avePoolPaddedArea;
+    bool use_half;
 };
 
 template<typename Dtype>
@@ -391,13 +398,14 @@ class OCL4DNNPool
         int32_t pooled_height_;
         int32_t pooled_width_;
         bool avePoolPaddedArea;
+        bool use_half;
 };
 
 struct OCL4DNNInnerProductConfig
 {
     OCL4DNNInnerProductConfig() :
         num_output(0), M(0), K(0),
-        bias_term(false), transpose(false), phase_test(true)
+        bias_term(false), transpose(false), phase_test(true), use_half(false)
     {}
     int num_output;
     int M;
@@ -405,6 +413,7 @@ struct OCL4DNNInnerProductConfig
     bool bias_term;
     bool transpose; // = false;
     bool phase_test; // = true;
+    bool use_half; // = false;
 };
 
 template<typename Dtype>
@@ -428,6 +437,7 @@ class OCL4DNNInnerProduct
         bool transpose_;
         bool image_copied_;
         bool phase_test_;
+        bool use_half_;
 };
 
 typedef enum {
@@ -441,7 +451,7 @@ struct OCL4DNNLRNConfig
         lrn_type(LRNParameter_NormRegion_ACROSS_CHANNELS),
         phase_test(true),
         local_size(0), alpha(0.f), beta(0.f), k(0.f), norm_by_size(false),
-        batch_size(0), channels(0), height(0), width(0)
+        batch_size(0), channels(0), height(0), width(0), use_half(false)
     {}
     MatShape in_shape;
     LRNParameter_NormRegion_WITHIN_CHANNEL_t lrn_type;
@@ -455,6 +465,7 @@ struct OCL4DNNLRNConfig
     int32_t channels;
     int32_t height;
     int32_t width;
+    bool use_half;
 };
 
 template<typename Dtype>
@@ -477,16 +488,18 @@ class OCL4DNNLRN
         int32_t height_;
         int32_t width_;
         bool norm_by_size_;
+        bool use_half_;
 };
 
 struct OCL4DNNSoftmaxConfig
 {
-    OCL4DNNSoftmaxConfig() : axis(0), channels(0), logsoftmax(false)
+    OCL4DNNSoftmaxConfig() : axis(0), channels(0), logsoftmax(false), use_half(false)
     {}
     MatShape in_shape;
     int axis;
     int channels;
     bool logsoftmax;
+    bool use_half;
 };
 
 template<typename Dtype>
@@ -506,6 +519,7 @@ class OCL4DNNSoftmax
         bool use_slm_;
         bool log_softmax_;
         UMat scale_data_;
+        bool use_half_;
 };
 
 }}} // namespace cv::dnn::ocl4dnn
index 3f4a70b..b2dda73 100644 (file)
 
 namespace cv { namespace dnn { namespace ocl4dnn {
 
+enum gemm_data_type_t
+{
+    TYPE_FLOAT = 1,
+    TYPE_HALF = 2
+};
+
 // Create and copy buffer to image for GEMM's matrix A and B.
 // Will return image to caller if the input image is NULL. Otherwise,
 // will use the image directly. It's caller's responsibility to
@@ -60,6 +66,7 @@ ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset,
                                           int width, int ld)
 {
     ocl::Image2D image;
+    String opts = format("-DTYPE=%d", TYPE_FLOAT);
 
     if (!is_matrix_a && transpose)
     {
@@ -73,7 +80,8 @@ ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset,
             UMat mat(height, width, CV_32FC1);
             image = ocl::Image2D(mat);
 
-            ocl::Kernel oclk_gemm_copy("gemm_buffer_copy_image_transpose_float", ocl::dnn::gemm_image_oclsrc);
+            ocl::Kernel oclk_gemm_copy("gemm_buffer_copy_image_transpose_float",
+                                       ocl::dnn::gemm_image_oclsrc, opts);
 
             size_t global_copy[2];
             global_copy[0] = width;
@@ -96,7 +104,7 @@ ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset,
             image = ocl::Image2D(mat);
 
             ocl::Kernel oclk_gemm_copy("gemm_buffer_copy_image_no_transpose_float",
-                                       ocl::dnn::gemm_image_oclsrc);
+                                       ocl::dnn::gemm_image_oclsrc, opts);
 
             size_t global_copy[2];
             global_copy[0] = padded_width;
@@ -129,7 +137,7 @@ enum gemm_type_t
     GEMM_TYPE_FAST_IMAGE_32_1,
     GEMM_TYPE_FAST_IMAGE_32_2,
     GEMM_TYPE_FAST_IMAGE_B_IMAGE,
-    GEMM_TYPE_MAX
+    GEMM_TYPE_FAST_BUFFER
 };
 
 template<typename Dtype>
@@ -145,6 +153,8 @@ static bool ocl4dnnFastImageGEMM(const CBLAS_TRANSPOSE TransA,
     CHECK_EQ(gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || gemm_type == GEMM_TYPE_FAST_IMAGE_32_2 ||
              gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE, true) << "Invalid fast image gemm type." << std::endl;
 
+    bool halfPrecisionMode = (A.depth() == CV_16S);
+
     if (is_image_a)
     {
         CHECK_EQ(offA, 0) << "Invalid input image offset." << std::endl;
@@ -157,6 +167,7 @@ static bool ocl4dnnFastImageGEMM(const CBLAS_TRANSPOSE TransA,
         return false;
     }
 
+    String opts = format("-DTYPE=%d", halfPrecisionMode ? TYPE_HALF : TYPE_FLOAT);
     int widthA = (TransA == CblasNoTrans) ? K : M;
     int heightA = (TransA == CblasNoTrans) ? M : K;
     int widthB = (TransB == CblasNoTrans) ? N : K;
@@ -178,7 +189,7 @@ static bool ocl4dnnFastImageGEMM(const CBLAS_TRANSPOSE TransA,
     int blockC_width = blocksize;
     int blockC_height = blocksize;
 
-    int use_buffer_indicator = 8;
+    int use_buffer_indicator = (halfPrecisionMode) ? 16 : 8;
     // To fix the edge problem caused by the sub group block read.
     // we have to pad the image if it's not multiple of tile.
     // just padding one line is enough as the sub group block read
@@ -221,9 +232,13 @@ static bool ocl4dnnFastImageGEMM(const CBLAS_TRANSPOSE TransA,
     else
         kernel_name += "1";
 
-    kernel_name += "_float";
+    if (halfPrecisionMode) {
+        kernel_name += "_half";
+    } else {
+        kernel_name += "_float";
+    }
 
-    ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_image_oclsrc);
+    ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_image_oclsrc, opts);
     if (oclk_gemm_float.empty())
         return false;
 
@@ -255,6 +270,10 @@ static bool ocl4dnnFastImageGEMM(const CBLAS_TRANSPOSE TransA,
                 bool padding_A = false;
                 bool padding_B = false;
 
+                if (halfPrecisionMode && is_image_b) {
+                    padding_A = true;
+                }
+
                 if (!is_image_a && !is_image_b)
                 {
                     if (M * K < N * K)
@@ -265,17 +284,19 @@ static bool ocl4dnnFastImageGEMM(const CBLAS_TRANSPOSE TransA,
 
                 if (!is_image_a)
                 {
-                    ImA = ocl4dnnGEMMCopyBufferToImage<Dtype>(A, blockA_offset,
-                                                              true, TransA != CblasNoTrans,
-                                                              padding_A, imageA_h, imageA_w,
-                                                              blockA_height, blockA_width, ldA);
+                    if (!halfPrecisionMode)
+                        ImA = ocl4dnnGEMMCopyBufferToImage<Dtype>(A, blockA_offset,
+                                                                  true, TransA != CblasNoTrans,
+                                                                  padding_A, imageA_h, imageA_w,
+                                                                  blockA_height, blockA_width, ldA);
                 }
                 if (!is_image_b)
                 {
-                    ImB = ocl4dnnGEMMCopyBufferToImage<Dtype>(B, blockB_offset,
-                                                              false, false,
-                                                              padding_B, imageB_h, imageB_w,
-                                                              blockB_height, blockB_width, ldB);
+                    if (!halfPrecisionMode)
+                        ImB = ocl4dnnGEMMCopyBufferToImage<Dtype>(B, blockB_offset,
+                                                                  false, false,
+                                                                  padding_B, imageB_h, imageB_w,
+                                                                  blockB_height, blockB_width, ldB);
                 }
             } else {
                 // We will use normal read_imagef to read image B when B has transpose.
@@ -283,32 +304,48 @@ static bool ocl4dnnFastImageGEMM(const CBLAS_TRANSPOSE TransA,
                 if (!is_image_a)
                 {
                     bool padding;
-                    padding = !is_image_b;
-                    ImA = ocl4dnnGEMMCopyBufferToImage<Dtype>(A, blockA_offset,
-                                                              true, TransA != CblasNoTrans,
-                                                              padding, imageA_h, imageA_w,
-                                                              blockA_height, blockA_width, ldA);
+                    padding = !is_image_b || halfPrecisionMode;
+                    if (!halfPrecisionMode)
+                        ImA = ocl4dnnGEMMCopyBufferToImage<Dtype>(A, blockA_offset,
+                                                                  true, TransA != CblasNoTrans,
+                                                                  padding, imageA_h, imageA_w,
+                                                                  blockA_height, blockA_width, ldA);
                 }
 
                 if (!is_image_b && (K % use_buffer_indicator != 0))
                 {
-                    ImB = ocl4dnnGEMMCopyBufferToImage<Dtype>(B, blockB_offset,
-                                                              false, true, false, imageB_h, imageB_w,
-                                                              blockB_height, blockB_width, ldB);
+                    if (!halfPrecisionMode)
+                        ImB = ocl4dnnGEMMCopyBufferToImage<Dtype>(B, blockB_offset,
+                                                                  false, true, false,
+                                                                  imageB_h, imageB_w,
+                                                                  blockB_height, blockB_width, ldB);
                 }
             }
 
             size_t global[2];
             if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 || gemm_type == GEMM_TYPE_FAST_IMAGE_B_IMAGE)
             {
-                global[0] = (size_t)( blockC_width + 7 ) & ~7;
+                if (halfPrecisionMode) {
+                    global[0] = (size_t)( blockC_width + 15 ) & ~15;
+                } else {
+                    global[0] = (size_t)( blockC_width + 7 ) & ~7;
+                }
             } else {
-                global[0] = (size_t)( (blockC_width / 2 ) + 7 ) ^ ~7;
+                if (halfPrecisionMode) {
+                    global[0] = (size_t)( (blockC_width / 2 ) + 15 ) ^ ~15;
+                } else {
+                    global[0] = (size_t)( (blockC_width / 2 ) + 7 ) ^ ~7;
+                }
             }
             global[1] = (size_t)(blockC_height + 31) / 32;
 
             size_t local[2];
-            local[0] = 8;
+            if (halfPrecisionMode)
+            {
+                local[0] = 16;
+            } else {
+                local[0] = 8;
+            }
             local[1] = 1;
 
             cl_uint arg_idx = 0;
@@ -386,13 +423,109 @@ static bool ocl4dnnFastImageGEMM(const CBLAS_TRANSPOSE TransA,
 }
 
 template<typename Dtype>
+static bool ocl4dnnFastBufferGEMM(const CBLAS_TRANSPOSE TransA,
+                                  const CBLAS_TRANSPOSE TransB, const int32_t M,
+                                  const int32_t N, const int32_t K, const Dtype alpha,
+                                  const UMat A, const int32_t offA, const UMat B,
+                                  const int32_t offB, const Dtype beta, UMat C,
+                                  const int32_t offC, enum gemm_type_t gemm_type)
+{
+    CHECK_EQ(gemm_type == GEMM_TYPE_FAST_BUFFER, true)
+             << "Invalid fast buffer gemm type." << std::endl;
+
+    bool halfPrecisionMode = (A.depth() == CV_16S);
+
+    size_t sub_group_size = 8;
+    bool is_small_batch = (M == 2 || M == 4 || M == 8);
+    String kernel_name("gemm_buffer_");
+    if (TransA == CblasNoTrans && TransB == CblasNoTrans) {
+        kernel_name += "NN";
+        if (halfPrecisionMode) {
+            sub_group_size = 16;
+        }
+    } else if (TransA == CblasNoTrans && TransB != CblasNoTrans) {
+        if (M == 2)
+            kernel_name +="NT_M_2";
+        else if (M == 4)
+            kernel_name +="NT_M_4";
+        else if (M == 8)
+            kernel_name +="NT_M_8";
+        else
+            kernel_name += "NT";
+    }
+
+    if (halfPrecisionMode) {
+        kernel_name += "_half";
+    } else {
+        kernel_name += "_float";
+    }
+
+    String opts = format("-DTYPE=%d", halfPrecisionMode ? TYPE_HALF : TYPE_FLOAT);
+    ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts);
+    size_t local[2] = {};
+    size_t global[2] = {};
+    if (TransA == CblasNoTrans && TransB != CblasNoTrans && is_small_batch) {
+        if (M == 8)
+            local[0] = 16;
+        else if (M == 4)
+            local[0] = 32;
+        else
+            local[0] = 64;
+        local[1] = 1;
+
+        if (M == 8)
+            global[0] = N * local[0];
+        else
+            global[0] = (N + 3) / 4 * local[0];
+        global[1] = 1;
+    } else {
+        size_t lx = sub_group_size;
+        size_t ly = (TransB != CblasNoTrans && TransA == CblasNoTrans && halfPrecisionMode) ? 2 : 4;
+        int dx = (TransB != CblasNoTrans && TransA == CblasNoTrans) ? 1 : 4;
+        int dy = 8;
+        size_t gx = (size_t)(N + dx - 1) / dx;
+        size_t gy = (size_t)(M + dy - 1) / dy;
+        global[0] = (gx + lx - 1) / lx * lx;
+        global[1] = (gy + ly - 1) / ly * ly;
+        local[0] = lx;
+        local[1] = ly;
+    }
+
+    int arg_idx = 0;
+    oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(A));
+    oclk_gemm_float.set(arg_idx++, offA);
+    oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(B));
+    oclk_gemm_float.set(arg_idx++, offB);
+    oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrWriteOnly(C));
+    oclk_gemm_float.set(arg_idx++, offC);
+    oclk_gemm_float.set(arg_idx++, M);
+    oclk_gemm_float.set(arg_idx++, N);
+    oclk_gemm_float.set(arg_idx++, K);
+    oclk_gemm_float.set(arg_idx++, (float)alpha);
+    oclk_gemm_float.set(arg_idx++, (float)beta);
+
+    bool ret;
+    if (TransB == CblasNoTrans || TransA != CblasNoTrans) {
+        int stride = 256;
+        for (int start_index = 0; start_index < K; start_index += stride) {
+            oclk_gemm_float.set(arg_idx, start_index);
+            ret = oclk_gemm_float.run(2, global, local, false);
+        }
+    } else {
+        ret = oclk_gemm_float.run(2, global, local, false);
+    }
+    return ret;
+}
+
+template<typename Dtype>
 bool ocl4dnnGEMMCommon(const CBLAS_TRANSPOSE TransB,
                        const int32_t M, const int32_t N, const int32_t K,
                        const UMat A, const UMat B,
                        const UMat B_image, UMat C,
                        const size_t max_image_size)
 {
-    gemm_type_t gemm_type = GEMM_TYPE_FAST_IMAGE_32_1;
+    bool halfPrecisionMode = (A.depth() == CV_16S);
+    gemm_type_t gemm_type = halfPrecisionMode ? GEMM_TYPE_FAST_BUFFER : GEMM_TYPE_FAST_IMAGE_32_1;
 
     if (gemm_type == GEMM_TYPE_FAST_IMAGE_32_1 ||
         gemm_type == GEMM_TYPE_FAST_IMAGE_32_2)
@@ -409,6 +542,11 @@ bool ocl4dnnGEMMCommon(const CBLAS_TRANSPOSE TransB,
                                            GEMM_TYPE_FAST_IMAGE_B_IMAGE,
                                            max_image_size);
     }
+    else if (gemm_type == GEMM_TYPE_FAST_BUFFER)
+    {
+        return ocl4dnnFastBufferGEMM<Dtype>(CblasNoTrans, TransB, M, N, K,
+                                            1.f, A, 0, B, 0, 0.f, C, 0, gemm_type);
+    }
     return false;
 }
 
@@ -436,10 +574,17 @@ bool ocl4dnnGEMV<float>(const CBLAS_TRANSPOSE TransA,
                  const int32_t offy)
 {
     bool ret = false;
+    bool use_half = (A.depth() == CV_16S);
+    String opts;
+    if (use_half)
+        opts = format("-DDtype=%s -DDtype4=%s -Dconvert_Dtype=convert_%s", "half", "half4", "half");
+    else
+        opts = format("-DDtype=%s -DDtype4=%s -Dconvert_Dtype=convert_%s", "float", "float4", "float");
 
     if (TransA == CblasNoTrans)
     {
-        ocl::Kernel k(CL_KERNEL_SELECT("matvec_mul4"), cv::ocl::dnn::matvec_mul_oclsrc);
+        String kname = format("matvec_mul4_%s", use_half ? "half" : "float");
+        ocl::Kernel k(kname.c_str(), cv::ocl::dnn::matvec_mul_oclsrc, opts);
         if (k.empty())
             return false;
 
@@ -469,7 +614,8 @@ bool ocl4dnnGEMV<float>(const CBLAS_TRANSPOSE TransA,
 
         if ((row_size % 4) != 0 && ret)
         {
-            ocl::Kernel k_1(CL_KERNEL_SELECT("matvec_mul1"), cv::ocl::dnn::matvec_mul_oclsrc);
+            String kname = format("matvec_mul1_%s", use_half ? "half" : "float");
+            ocl::Kernel k_1(kname.c_str(), cv::ocl::dnn::matvec_mul_oclsrc, opts);
             size_t localsize[] = { 128 };
             size_t globalsize[] = { row_size % 4 * localsize[0] };
             uint row_offset = row_size - (row_size % 4);
@@ -499,7 +645,15 @@ bool ocl4dnnAXPY(const int32_t N, const Dtype alpha,
                  const UMat X, const int32_t offX, UMat Y,
                  const int32_t offY)
 {
-    ocl::Kernel oclk_axpy(CL_KERNEL_SELECT("axpy"), cv::ocl::dnn::math_oclsrc);
+    bool use_half = (X.depth() == CV_16S);
+    String opts;
+    if (use_half)
+        opts = "-DDtype=half -DDtype4=half4 -Dconvert_Dtype=convert_half";
+    else
+        opts = "-DDtype=float -DDtype4=float4 -Dconvert_Dtype=convert_float";
+
+    String kname = format("axpy_%s", use_half ? "half" : "float");
+    ocl::Kernel oclk_axpy(kname.c_str(), cv::ocl::dnn::math_oclsrc, opts);
     if (oclk_axpy.empty())
         return false;
 
index 8543229..44a622f 100644 (file)
@@ -54,6 +54,7 @@
 #include "opencl_kernels_dnn.hpp"
 #include "../include/math_functions.hpp"
 #include "../include/default_kernel_config.hpp"
+#include "opencv2/dnn/shape_utils.hpp"
 
 #if defined WIN32 || defined _WIN32
 #include <windows.h>
@@ -85,6 +86,7 @@ OCL4DNNConvSpatial<Dtype>::OCL4DNNConvSpatial(OCL4DNNConvConfig config)
     max_value_ = 0;
     prev_kernel_type_ = -1;
     tuned_ = false;
+    use_half_ = config.use_half;
 
     // assumption: spatial dimension is 2.
     kernel_h_ = config.kernel.height;
@@ -204,18 +206,40 @@ void OCL4DNNConvSpatial<Dtype>::setFusionArg(ocl4dnnFusedActiv_t fused_activ, bo
     return;
 }
 
+typedef enum {
+    TYPE_FLOAT = 1,
+    TYPE_HALF = 2
+} ocl4dnnConvSpatialType_t;
+
 template<typename Dtype>
 void OCL4DNNConvSpatial<Dtype>::collectCommonInformation()
 {
-    addDef("Dtype", "float");
-    addDef("Dtype2", "float2");
-    addDef("Dtype4", "float4");
-    addDef("Dtype8", "float8");
-    addDef("Dtype16", "float16");
-    addDef("as_Dtype", "as_float");
-    addDef("as_Dtype2", "as_float2");
-    addDef("as_Dtype4", "as_float4");
-    addDef("as_Dtype8", "as_float8");
+    if (use_half_)
+    {
+        addDef("TYPE", TYPE_HALF);
+        addDef("Dtype", "half");
+        addDef("Dtype2", "half2");
+        addDef("Dtype4", "half4");
+        addDef("Dtype8", "half8");
+        addDef("Dtype16", "half16");
+        addDef("as_Dtype", "as_half");
+        addDef("as_Dtype2", "as_half2");
+        addDef("as_Dtype4", "as_half4");
+        addDef("as_Dtype8", "as_half8");
+    }
+    else
+    {
+        addDef("TYPE", TYPE_FLOAT);
+        addDef("Dtype", "float");
+        addDef("Dtype2", "float2");
+        addDef("Dtype4", "float4");
+        addDef("Dtype8", "float8");
+        addDef("Dtype16", "float16");
+        addDef("as_Dtype", "as_float");
+        addDef("as_Dtype2", "as_float2");
+        addDef("as_Dtype4", "as_float4");
+        addDef("as_Dtype8", "as_float8");
+    }
 }
 
 typedef enum {
@@ -477,10 +501,16 @@ bool OCL4DNNConvSpatial<Dtype>::Forward(const UMat& bottom,
         fused_eltwise_ = false;
     }
 
-    prepareKernel(bottom, top, weight, bias, numImages);
+    if (use_half_ && bias_half.empty() && !bias.empty())
+        convertFp16((UMat&)bias, bias_half);
+
+    if (use_half_ && weights_half.empty())
+        convertFp16((UMat&)weight, weights_half);
+
+    prepareKernel(bottom, top, weight, (use_half_) ? bias_half : bias, numImages);
     if (bestKernelConfig.empty())
         return false;
-    return convolve(bottom, top, weight, bias, numImages, bestKernelConfig);
+    return convolve(bottom, top, weight, (use_half_) ? bias_half : bias, numImages, bestKernelConfig);
 }
 
 template<typename Dtype>
@@ -556,6 +586,12 @@ std::string OCL4DNNConvSpatial<Dtype>::generateSpecificKey(int32_t type, int32_t
                << "_" << blockWidth
                << "_" << blockHeight
                << "_" << blockDepth;
+
+    if (!use_half_)
+        keyBuilder << "_float";
+    else
+        keyBuilder << "_half";
+
     return keyBuilder.str();
 }
 
@@ -637,9 +673,13 @@ bool OCL4DNNConvSpatial<Dtype>::swizzleWeight(const UMat &weight,
 
     if (swizzled_weights_umat.empty())
         swizzled_weights_umat.create(1, (int)alignSize(num_output_, 16) * channels_ *
-                                     kernel_h_ * (int)alignSize(kernel_w_, 2), CV_32FC1);
+                                     kernel_h_ * (int)alignSize(kernel_w_, 2),
+                                     (use_half_) ? CV_16SC1 : CV_32FC1);
+
+    UMat swizzled_weights_tmp;
+    if (use_half_)
+        swizzled_weights_tmp.create(shape(swizzled_weights_umat), CV_32F);
 
-    ocl::Queue queue = ocl::Queue::getDefault();
     if (!interleave) {
         cl_uint argIdx = 0;
         int32_t channels = channels_ / group_;
@@ -650,7 +690,10 @@ bool OCL4DNNConvSpatial<Dtype>::swizzleWeight(const UMat &weight,
             return false;
 
         oclk_copy_weight.set(argIdx++, ocl::KernelArg::PtrReadOnly(weight));
-        oclk_copy_weight.set(argIdx++, ocl::KernelArg::PtrWriteOnly(swizzled_weights_umat));
+        if (use_half_)
+            oclk_copy_weight.set(argIdx++, ocl::KernelArg::PtrWriteOnly(swizzled_weights_tmp));
+        else
+            oclk_copy_weight.set(argIdx++, ocl::KernelArg::PtrWriteOnly(swizzled_weights_umat));
         oclk_copy_weight.set(argIdx++, kernel_w_);
         oclk_copy_weight.set(argIdx++, kernel_h_);
         oclk_copy_weight.set(argIdx++, channels);
@@ -669,7 +712,11 @@ bool OCL4DNNConvSpatial<Dtype>::swizzleWeight(const UMat &weight,
         // assumption: kernel dimesion is 2
         Mat weightMat = weight.getMat(ACCESS_READ);
         Dtype* cpu_weight = (Dtype *)weightMat.ptr<float>();
-        Mat swizzledWeightMat = swizzled_weights_umat.getMat(ACCESS_WRITE);
+        Mat swizzledWeightMat;
+        if (use_half_)
+            swizzledWeightMat = swizzled_weights_tmp.getMat(ACCESS_WRITE);
+        else
+            swizzledWeightMat = swizzled_weights_umat.getMat(ACCESS_WRITE);
         Dtype* cpu_swizzled_weight = (Dtype *)swizzledWeightMat.ptr<float>();
 
         int interleavedRows = (kernel_w_ / 2) * 2;
@@ -694,6 +741,10 @@ bool OCL4DNNConvSpatial<Dtype>::swizzleWeight(const UMat &weight,
                          rowAlignment);
         free(tmpSwizzledWeight);
     }
+
+    if (use_half_)
+        convertFp16(swizzled_weights_tmp, swizzled_weights_umat);
+
     return true;
 }
 
@@ -727,9 +778,10 @@ void OCL4DNNConvSpatial<float>::CreateSubBuffer(const UMat& buffer, UMat& sub_bu
     cl_mem sub_mem;
     cl_buffer_region region;
     cl_int err;
+    size_t element_size = (use_half_) ? sizeof(short) : sizeof(float);
 
-    region.origin = offset * sizeof(float);
-    region.size = size * sizeof(float);
+    region.origin = offset * element_size;
+    region.size = size * element_size;
     sub_mem = clCreateSubBuffer((cl_mem)buffer.handle(ACCESS_READ),
                                 write_only ? CL_MEM_WRITE_ONLY : CL_MEM_READ_ONLY,
                                 CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
@@ -739,8 +791,9 @@ void OCL4DNNConvSpatial<float>::CreateSubBuffer(const UMat& buffer, UMat& sub_bu
         return;
     }
 
-    int step = sizeof(float), rows = size, cols = 1;
-    ocl::convertFromBuffer(sub_mem, step, rows, cols, CV_32FC1, sub_buffer);
+    int step = element_size, rows = size, cols = 1;
+    ocl::convertFromBuffer(sub_mem, step, rows, cols,
+                           (use_half_) ? CV_16SC1 : CV_32FC1, sub_buffer);
 
     //decrease ocl mem refcount
     clReleaseMemObject(sub_mem);
@@ -978,7 +1031,10 @@ bool OCL4DNNConvSpatial<float>::convolve(const UMat &bottom, UMat &top,
         cl_uint argIdx = 0;
         setFusionArg(fused_activ_, fused_eltwise_, kernel, argIdx);
         kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom));
-        kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(weight));
+        if (use_half_)
+            kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(weights_half));
+        else
+            kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(weight));
         if (bias_term_)
             kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(bias));
         kernel.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top));
@@ -1018,7 +1074,10 @@ bool OCL4DNNConvSpatial<float>::convolve(const UMat &bottom, UMat &top,
                 setFusionArg(fused_activ_, fused_eltwise_, kernel, argIdx);
                 kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom));
                 kernel.set(argIdx++, image_offset);
-                kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(weight));
+                if (use_half_)
+                    kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(weights_half));
+                else
+                    kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(weight));
                 kernel.set(argIdx++, kernel_offset);
                 if (bias_term_)
                     kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(bias));
@@ -1132,14 +1191,27 @@ bool OCL4DNNConvSpatial<float>::verifyResult(const UMat &bottom,
         return false;
 
     int32_t sz[4] = {numImages, num_output_, output_h_, output_w_};
-    top.zeros(4, sz, CV_32FC1);
+    top.zeros(4, sz, (use_half_) ? CV_16SC1 : CV_32FC1);
     bool saved_tuned = tuned_;
     tuned_ = false;
     convolve(bottom, top, weight, bias, numImages, config);
     tuned_ = saved_tuned;
 
-    float *data = (float *)top.getMat(ACCESS_READ).ptr<float>();
-    float *verify_data = (float *)verifyTop.getMat(ACCESS_READ).ptr<float>();
+    UMat new_top, new_verify_top;
+    float *data, *verify_data;
+    if (use_half_)
+    {
+        convertFp16(top, new_top);
+        convertFp16(verifyTop, new_verify_top);
+
+        data = (float *)new_top.getMat(ACCESS_READ).ptr<float>();
+        verify_data = (float *)new_verify_top.getMat(ACCESS_READ).ptr<float>();
+    }
+    else
+    {
+        data = (float *)top.getMat(ACCESS_READ).ptr<float>();
+        verify_data = (float *)verifyTop.getMat(ACCESS_READ).ptr<float>();
+    }
 
     for (int32_t n = 0; n < num_; ++n) {
         for (int32_t g = 0; g < group_; ++g) {
@@ -1148,9 +1220,19 @@ bool OCL4DNNConvSpatial<float>::verifyResult(const UMat &bottom,
                 for (int h = 0; h < output_h_ && !verificationFail; h++)
                     for (int w = 0; w < output_w_; w++) {
                         size_t offset = output_image_offset + out_ch * output_w_ * output_h_ + h * output_w_ + w;
-                        if (fabs(data[offset] - verify_data[offset]) > 0.1 * fabs(verify_data[offset]) &&
-                            !(fabs(verify_data[offset]) < 1.e-3 &&
-                            fabs(data[offset] - verify_data[offset]) < 1.e-4))
+
+                        float error_factor = fabs(data[offset] - verify_data[offset]);
+                        if (use_half_ && error_factor > 0.1 * fabs(verify_data[offset]) &&
+                            error_factor > 0.04 && !(fabs(verify_data[offset]) < 1.e-3 && error_factor < 1.e-4))
+                        {
+                            dbgPrint(printf("test verification failed @ image %d group %d"
+                                            "out_ch %d h %d w %d got %G expected %G\n",
+                                            n, g, out_ch, h, w, data[offset], verify_data[offset]));
+                            verificationFail = 1;
+                            goto out;
+                        }
+                        else if (!use_half_ && error_factor > 0.1 * fabs(verify_data[offset]) &&
+                                 !(fabs(verify_data[offset]) < 1.e-3 && error_factor < 1.e-4))
                         {
                             dbgPrint(printf("test verification failed @ image %d group %d"
                                             "out_ch %d h %d w %d got %G expected %G\n",
@@ -1719,15 +1801,16 @@ void OCL4DNNConvSpatial<Dtype>::prepareKernel(const UMat &bottom, UMat &top,
     if (loadTunedConfig()) // check external storage
         return;
 
-    UMat benchData(1, numImages * top_dim_, CV_32FC1);
+    UMat benchData(1, numImages * top_dim_, (use_half_) ? CV_16SC1 : CV_32FC1);
+
+    calculateBenchmark(bottom, benchData, (use_half_) ? weights_half : weight, bias, numImages);
+
     if (force_auto_tuning_)
     {
-        calculateBenchmark(bottom, benchData, weight, bias, numImages);
         setupConvolution(bottom, top, weight, bias, numImages, benchData);
     }
     else
     {
-        calculateBenchmark(bottom, benchData, weight, bias, numImages);
         useFirstAvailable(bottom, top, weight, bias, numImages, benchData);
     }
     cacheTunedConfig();
index aabee57..ee7a2c7 100644 (file)
@@ -56,6 +56,7 @@ OCL4DNNInnerProduct<Dtype>::OCL4DNNInnerProduct(OCL4DNNInnerProductConfig config
     K_ = config.K;
     phase_test_ = config.phase_test;
     image_copied_ = false;
+    use_half_ = config.use_half;
 }
 
 template<typename Dtype>
@@ -89,13 +90,24 @@ bool OCL4DNNInnerProduct<Dtype>::Forward(const UMat& bottom,
         if (M_ <= max_image_size &&
             N_ <= max_image_size &&
             K_ <= max_image_size &&
-            cv::traits::Depth<Dtype>::value == CV_32F &&
             ocl::Device::getDefault().intelSubgroupsSupport())
         {
             ret = ocl4dnnGEMMCommon<Dtype>(transpose_ ? CblasNoTrans : CblasTrans,
                                            M_, N_, K_, bottom, weight, UMat(), top,
                                            max_image_size);
         }
+
+        if (use_half_ && bias_term_)
+        {
+            UMat biasOneMat = UMat::ones(M_, 1, CV_32F);
+            UMat newbias, tmpTop;
+
+            convertFp16(bias, newbias);
+            convertFp16(top, tmpTop);
+            cv::gemm(biasOneMat, newbias, 1, tmpTop, 1, tmpTop, 0);
+            convertFp16(tmpTop, top);
+        }
+
         return ret;
     }
 }
index c7062f4..b0fcfa9 100644 (file)
@@ -61,6 +61,7 @@ OCL4DNNLRN<Dtype>::OCL4DNNLRN(OCL4DNNLRNConfig config)
     channels_ = config.channels;
     height_ = config.height;
     width_ = config.width;
+    use_half_ = config.use_half;
 }
 
 template<typename Dtype>
@@ -97,8 +98,10 @@ bool OCL4DNNLRN<Dtype>::crossChannelForward(const UMat& bottom, UMat& top)
     int32_t n_threads = num_ * height_ * width_;
     size_t global_work_size_[1] = {(size_t)n_threads};
     String opts = clOptionSupport("-cl-no-subgroup-ifp") ? " -cl-no-subgroup-ifp " : "";
+    opts += format("-D Dtype=%s", (use_half_) ? "half" : "float");
     ocl::Kernel oclk_lrn_fill;
-    if (!oclk_lrn_fill.create(CL_KERNEL_SELECT("lrn_full_no_scale"), ocl::dnn::ocl4dnn_lrn_oclsrc, opts))
+    String kname = format("lrn_full_no_scale_%s", (use_half_) ? "half" : "float");
+    if (!oclk_lrn_fill.create(kname.c_str(), ocl::dnn::ocl4dnn_lrn_oclsrc, opts))
         return false;
 
     oclk_lrn_fill.set(argIdx++, n_threads);
index 2d9c4dc..81238e9 100644 (file)
@@ -56,6 +56,7 @@ OCL4DNNPool<Dtype>::OCL4DNNPool(OCL4DNNPoolConfig config)
     channels_ = config.channels;
     pool_method_ = config.pool_method;
     avePoolPaddedArea = config.avePoolPaddedArea;
+    use_half = config.use_half;
 
     for (int i = 0; i < spatial_dims; ++i)
     {
@@ -105,12 +106,15 @@ bool OCL4DNNPool<Dtype>::Forward(const UMat& bottom,
     case LIBDNN_POOLING_METHOD_MAX:
         {
             bool haveMask = !top_mask.empty();
+            String kname = haveMask ? "max_pool_forward_mask" : "max_pool_forward";
+            kname += (use_half) ? "_half" : "_float";
             ocl::Kernel oclk_max_pool_forward(
-                haveMask ? CL_KERNEL_SELECT("max_pool_forward_mask") : CL_KERNEL_SELECT("max_pool_forward"),
+                kname.c_str(),
                 ocl::dnn::ocl4dnn_pooling_oclsrc,
-                format("-D KERNEL_MAX_POOL=1 -D KERNEL_W=%d -D KERNEL_H=%d"
+                format(" -D Dtype=%s -D KERNEL_MAX_POOL=1 -D KERNEL_W=%d -D KERNEL_H=%d"
                        " -D STRIDE_W=%d -D STRIDE_H=%d"
                        " -D PAD_W=%d -D PAD_H=%d%s",
+                       (use_half) ? "half" : "float",
                        kernel_w_, kernel_h_,
                        stride_w_, stride_h_,
                        pad_w_, pad_h_,
@@ -139,11 +143,14 @@ bool OCL4DNNPool<Dtype>::Forward(const UMat& bottom,
         {
             CV_Assert(top_mask.empty());
 
-            ocl::Kernel oclk_ave_pool_forward(CL_KERNEL_SELECT("ave_pool_forward"),
+            String kname = format("ave_pool_forward_%s", (use_half) ? "half" : "float");
+            ocl::Kernel oclk_ave_pool_forward(
+                kname.c_str(),
                 ocl::dnn::ocl4dnn_pooling_oclsrc,
-                format("-D KERNEL_AVE_POOL=1 -D KERNEL_W=%d -D KERNEL_H=%d"
+                format(" -D Dtype=%s -D KERNEL_AVE_POOL=1 -D KERNEL_W=%d -D KERNEL_H=%d"
                        " -D STRIDE_W=%d -D STRIDE_H=%d"
                        " -D PAD_W=%d -D PAD_H=%d%s",
+                       (use_half) ? "half" : "float",
                        kernel_w_, kernel_h_,
                        stride_w_, stride_h_,
                        pad_w_, pad_h_,
@@ -171,7 +178,9 @@ bool OCL4DNNPool<Dtype>::Forward(const UMat& bottom,
         {
             CV_Assert(top_mask.empty());
 
-            ocl::Kernel oclk_sto_pool_forward(CL_KERNEL_SELECT("sto_pool_forward_test"),
+            String kname = format("sto_pool_forward_test_%s", (use_half) ? "half" : "float");
+            ocl::Kernel oclk_sto_pool_forward(
+                kname.c_str(),
                 ocl::dnn::ocl4dnn_pooling_oclsrc,
                 format("-D KERNEL_STO_POOL=1 -D KERNEL_W=%d -D KERNEL_H=%d"
                        " -D STRIDE_W=%d -D STRIDE_H=%d",
index 6b95764..7857671 100644 (file)
@@ -52,6 +52,7 @@ OCL4DNNSoftmax<Dtype>::OCL4DNNSoftmax(OCL4DNNSoftmaxConfig config)
     softmax_axis_ = config.axis;
     channels_ = config.channels;
     log_softmax_ = config.logsoftmax;
+    use_half_ = config.use_half;
 
     inner_num_ = 1;
     outer_num_ = 1;
@@ -91,10 +92,13 @@ bool OCL4DNNSoftmax<Dtype>::Forward(const UMat& bottom, UMat& top)
 
         if (log_softmax_) opts += " -DLOG_SOFTMAX ";
         if (use_slm_)
-            kname = CL_KERNEL_SELECT("softmax_forward_slm");
+            kname = "softmax_forward_slm";
         else
-            kname = CL_KERNEL_SELECT("softmax_forward");
+            kname = "softmax_forward";
 
+        kname += format("%s", (use_half_) ? "_half" : "_float");
+        opts += format(" -D Dtype=%s -D DTYPE_MAX=%s", (use_half_) ? "half" : "float",
+                       (use_half_) ? "HALF_MAX" : "FLT_MAX");
         if (!oclk_softmax_forward_kernel.create(kname.c_str(), ocl::dnn::softmax_loss_oclsrc, opts))
             return false;
 
index ab2532e..9b5a9bb 100644 (file)
 //
 //M*/
 
+#define CONCAT(A,B) A##_##B
+#define TEMPLATE(name,type) CONCAT(name,type)
+#define KERNEL_ARG_DTYPE float
+
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
 __kernel void ReLUForward(const int count, __global const T* in, __global T* out
 #ifndef RELU_NO_SLOPE
-, T negative_slope
+, KERNEL_ARG_DTYPE negative_slope
 #endif
 ) {
   int index = get_global_id(0);
@@ -55,18 +63,19 @@ __kernel void ReLUForward(const int count, __global const T* in, __global T* out
 }
 
 __kernel void ReLU6Forward(const int count, __global const T* in, __global T* out,
-                           const T minValue, const T maxValue)
+                           const KERNEL_ARG_DTYPE minValue, const KERNEL_ARG_DTYPE maxValue)
 {
   int index = get_global_id(0);
   if(index < count)
   {
     T x = in[index];
-    out[index] = clamp(x, minValue, maxValue);
+    out[index] = clamp(x, convert_T(minValue), convert_T(maxValue));
   }
 }
 
 __kernel void PReLUForward(const int count, const int channels, const int plane_size,
-                           __global const T* in, __global T* out, __global const T* slope_data)
+                           __global const T* in, __global T* out,
+                           __global const KERNEL_ARG_DTYPE* slope_data)
 {
   int index = get_global_id(0);
   int c = (index / plane_size) % channels;
@@ -99,8 +108,22 @@ __kernel void AbsValForward(const int n, __global const T* in, __global T* out)
     out[index] = fabs(in[index]);
 }
 
-__kernel void PowForward(const int n, __global const T* in, __global T* out, const T power, const T scale, const T shift) {
+__kernel void PowForward(const int n, __global const T* in, __global T* out,
+                         const KERNEL_ARG_DTYPE power,
+                         const KERNEL_ARG_DTYPE scale,
+                         const KERNEL_ARG_DTYPE shift)
+{
   int index = get_global_id(0);
   if (index < n)
     out[index] = pow(shift + scale * in[index], power);
 }
+
+__kernel void ELUForward(const int n, __global const T* in, __global T* out)
+{
+  int index = get_global_id(0);
+  if (index < n)
+  {
+    T src = in[index];
+    out[index] = (src >= 0.f) ? src : exp(src) - 1;
+  }
+}
index 041e6ac..69fb752 100644 (file)
 //
 //M*/
 
-__kernel void concat(const int nthreads,
-                     __global const Dtype* in_data,
-                     const int num_concats,
-                     const int concat_size,
-                     const int top_concat_axis,
-                     const int bottom_concat_axis,
-                     const int offset_concat_axis,
-                     __global Dtype* out_data) {
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
 
-  for (int index = get_global_id(0); index < nthreads;
-      index += get_global_size(0)) {
-    const int total_concat_size = concat_size * bottom_concat_axis;
-    const int concat_num = index / total_concat_size;
-    const int concat_index = index % total_concat_size;
-    const int top_index = concat_index
-        + (concat_num * top_concat_axis + offset_concat_axis) * concat_size;
-    out_data[top_index] = in_data[index];
-  }
+#define CONCAT(A,B) A##_##B
+#define TEMPLATE(name,type) CONCAT(name,type)
+
+__kernel void TEMPLATE(concat, Dtype)(const int nthreads,
+                                      __global const Dtype* in_data,
+                                      const int num_concats,
+                                      const int concat_size,
+                                      const int top_concat_axis,
+                                      const int bottom_concat_axis,
+                                      const int offset_concat_axis,
+                                      __global Dtype* out_data)
+{
+    for (int index = get_global_id(0); index < nthreads; index += get_global_size(0))
+    {
+        const int total_concat_size = concat_size * bottom_concat_axis;
+        const int concat_num = index / total_concat_size;
+        const int concat_index = index % total_concat_size;
+        const int top_index = concat_index +
+                              (concat_num * top_concat_axis + offset_concat_axis) * concat_size;
+        out_data[top_index] = in_data[index];
+    }
 }
index 5308bf1..621ab6f 100644 (file)
 //
 //M*/
 
-#if APPLY_BIAS
-#define BIAS_KERNEL_ARG __global Dtype * biases_base,
-#else
-#define BIAS_KERNEL_ARG
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
 #endif
 
+#define KERNEL_ARG_DTYPE float
+#define TYPE_FLOAT  1
+#define TYPE_HALF   2
+
 #if defined(FUSED_CONV_RELU)
-#define ACTIVATION_RELU_FUNCTION(x, c) ((Dtype)(x) > 0 ? (Dtype)(x) : ((Dtype)(x) * (Dtype)(negative_slope)))
-#define FUSED_ARG Dtype negative_slope,
+#define ACTIVATION_RELU_FUNCTION(x, c) ((Dtype)(x) > 0 ? (Dtype)(x) : ((Dtype)(x) * (negative_slope)))
+#define FUSED_ARG KERNEL_ARG_DTYPE negative_slope,
 #elif defined(FUSED_CONV_PRELU)
-#define ACTIVATION_RELU_FUNCTION(x, c) ((Dtype)(x) > 0 ? (Dtype)(x) : ((Dtype)(x) * (Dtype)(negative_slope[c])))
-#define FUSED_ARG __global const Dtype *negative_slope,
+#define ACTIVATION_RELU_FUNCTION(x, c) ((Dtype)(x) > 0 ? (Dtype)(x) : ((Dtype)(x) * (negative_slope[c])))
+#define FUSED_ARG __global const KERNEL_ARG_DTYPE* negative_slope,
 #elif defined(FUSED_CONV_POWER)
-#define ACTIVATION_RELU_FUNCTION(x, c) pow(x, power)
-#define FUSED_ARG Dtype power,
+#define ACTIVATION_RELU_FUNCTION(x, c) pow(x, (Dtype)power)
+#define FUSED_ARG KERNEL_ARG_DTYPE power,
 #elif defined(FUSED_CONV_TANH)
 #define ACTIVATION_RELU_FUNCTION(x, c) tanh(x)
 #define FUSED_ARG
 #elif defined(FUSED_CONV_RELU6)
-#define ACTIVATION_RELU_FUNCTION(x, c) (clamp((Dtype)(x), min_value, max_value))
-#define FUSED_ARG Dtype min_value, Dtype max_value,
+#define ACTIVATION_RELU_FUNCTION(x, c) (clamp((Dtype)(x), (Dtype)min_value, (Dtype)max_value))
+#define FUSED_ARG KERNEL_ARG_DTYPE min_value, KERNEL_ARG_DTYPE max_value,
 #else
 #define ACTIVATION_RELU_FUNCTION(x, c) (x)
 #define FUSED_ARG
 #define ELTWISE_DATA_ARG
 #endif
 
+#if APPLY_BIAS
+#define BIAS_KERNEL_ARG __global Dtype * biases_base,
+#else
+#define BIAS_KERNEL_ARG
+#endif
 
 #define __CAT(x, y) x##y
 #define CAT(x, y) __CAT(x, y)
 #define LOOP(N, VAR, STMT) CAT(LOOP, N)((VAR), (STMT))
 
 #if defined(convolve_simd) || defined(Conv_Interleaved)
+#if TYPE == TYPE_HALF
+#define INT_TYPE ushort
+#define INT_TYPE2 ushort2
+#define INT_TYPE4 ushort4
+#define INT_TYPE8 ushort8
+#define SUB_GROUP_BLOCK_READ2 intel_sub_group_block_read_us2
+#define SUB_GROUP_BLOCK_READ4 intel_sub_group_block_read_us4
+#define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read_us8
+#define SUB_GROUP_BLOCK_READ intel_sub_group_block_read_us
+#else
 #define INT_TYPE uint
 #define INT_TYPE2 uint2
 #define INT_TYPE4 uint4
 #define SUB_GROUP_BLOCK_READ8 intel_sub_group_block_read8
 #define SUB_GROUP_BLOCK_READ intel_sub_group_block_read
 #endif
+#endif
 
 #ifdef KERNEL_BASIC
 
@@ -418,6 +436,25 @@ typedef struct float15 { float s0; float s1; float s2; float s3; float s4; float
                          float s6; float s7; float s8; float s9; float sa; float sb; float sc; float sd; float se; } float15;
 typedef struct float0 { float s0; } float0; //never used but makes compiler happy.
 
+typedef struct half1 { half s0; } half1;
+typedef struct half5 { half s0; half s1; half s2; half s3; half s4; } half5;
+typedef struct half6 { half s0; half s1; half s2; half s3; half s4; half s5; } half6;
+typedef struct half7 { half s0; half s1; half s2; half s3; half s4; half s5; half s6; } half7;
+typedef struct half9 { half s0; half s1; half s2; half s3; half s4; half s5; half s6; half s7; half s8; } half9;
+typedef struct half10 { half s0; half s1; half s2; half s3; half s4; half s5;
+                        half s6; half s7; half s8; half s9; } half10;
+typedef struct half11 { half s0; half s1; half s2; half s3; half s4; half s5;
+                        half s6; half s7; half s8; half s9; half sa; } half11;
+typedef struct half12 { half s0; half s1; half s2; half s3; half s4; half s5;
+                        half s6; half s7; half s8; half s9; half sa; half sb; } half12;
+typedef struct half13 { half s0; half s1; half s2; half s3; half s4; half s5;
+                        half s6; half s7; half s8; half s9; half sa; half sb; half sc; } half13;
+typedef struct half14 { half s0; half s1; half s2; half s3; half s4; half s5;
+                        half s6; half s7; half s8; half s9; half sa; half sb; half sc; half sd; } half14;
+typedef struct half15 { half s0; half s1; half s2; half s3; half s4; half s5;
+                        half s6; half s7; half s8; half s9; half sa; half sb; half sc; half sd; half se; } half15;
+typedef struct half0 { half s0; } half0; //never used but makes compiler happy.
+
 #define OUT_PITCH_X output_width
 #define ROW_PITCH input_width
 
index 6f3a374..80d3305 100644 (file)
@@ -40,9 +40,9 @@
 //
 //M*/
 
-#define Dtype float
-#define Dtype4 float4
-#define Dtype8 float8
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
 
 __kernel void op_sum4(__global const Dtype * A,
                       __global const Dtype * B,
@@ -73,20 +73,20 @@ __kernel void op_sum4(__global const Dtype * A,
         a2 = vload4(i, src0_read + 2 * A_col_size);
         a3 = vload4(i, src0_read + 3 * A_col_size);
 
-        dot0 = a0 * coeff1 + b0 * coeff2;
-        dot1 = a1 * coeff1 + b1 * coeff2;
-        dot2 = a2 * coeff1 + b2 * coeff2;
-        dot3 = a3 * coeff1 + b3 * coeff2;
+        dot0 = a0 * (Dtype4)coeff1 + b0 * (Dtype4)coeff2;
+        dot1 = a1 * (Dtype4)coeff1 + b1 * (Dtype4)coeff2;
+        dot2 = a2 * (Dtype4)coeff1 + b2 * (Dtype4)coeff2;
+        dot3 = a3 * (Dtype4)coeff1 + b3 * (Dtype4)coeff2;
 #else
         a0 = vload4(i, dst0_read);
         a1 = vload4(i, dst0_read + A_col_size);
         a2 = vload4(i, dst0_read + 2 * A_col_size);
         a3 = vload4(i, dst0_read + 3 * A_col_size);
 
-        dot0 = a0 + b0 * coeff2;
-        dot1 = a1 + b1 * coeff2;
-        dot2 = a2 + b2 * coeff2;
-        dot3 = a3 + b3 * coeff2;
+        dot0 = a0 + b0 * (Dtype4)coeff2;
+        dot1 = a1 + b1 * (Dtype4)coeff2;
+        dot2 = a2 + b2 * (Dtype4)coeff2;
+        dot3 = a3 + b3 * (Dtype4)coeff2;
 #endif
         vstore4(dot0, i, dst0_read);
         vstore4(dot1, i, dst0_read + A_col_size);
diff --git a/modules/dnn/src/opencl/gemm_buffer.cl b/modules/dnn/src/opencl/gemm_buffer.cl
new file mode 100644 (file)
index 0000000..8cbc34d
--- /dev/null
@@ -0,0 +1,1342 @@
+/*M///////////////////////////////////////////////////////////////////////////////////////
+//
+//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
+//
+//  By downloading, copying, installing or using the software you agree to this license.
+//  If you do not agree to this license, do not download, install,
+//  copy or use the software.
+//
+//
+//                           License Agreement
+//                For Open Source Computer Vision Library
+//
+// Copyright (C) 2017, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
+//
+// Redistribution and use in source and binary forms, with or without modification,
+// are permitted provided that the following conditions are met:
+//
+//   * Redistribution's of source code must retain the above copyright notice,
+//     this list of conditions and the following disclaimer.
+//
+//   * Redistribution's in binary form must reproduce the above copyright notice,
+//     this list of conditions and the following disclaimer in the documentation
+//     and/or other materials provided with the distribution.
+//
+//   * The name of the copyright holders may not be used to endorse or promote products
+//     derived from this software without specific prior written permission.
+//
+// This software is provided by the copyright holders and contributors "as is" and
+// any express or implied warranties, including, but not limited to, the implied
+// warranties of merchantability and fitness for a particular purpose are disclaimed.
+// In no event shall the Intel Corporation or contributors be liable for any direct,
+// indirect, incidental, special, exemplary, or consequential damages
+// (including, but not limited to, procurement of substitute goods or services;
+// loss of use, data, or profits; or business interruption) however caused
+// and on any theory of liability, whether in contract, strict liability,
+// or tort (including negligence or otherwise) arising in any way out of
+// the use of this software, even if advised of the possibility of such damage.
+//
+//M*/
+
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
+#define CONCAT(A,B) A##_##B
+#define TEMPLATE(name,type) CONCAT(name,type)
+
+#define KERNEL_ARG_DTYPE float
+#define TYPE_FLOAT  1
+#define TYPE_HALF   2
+
+#if TYPE == TYPE_HALF
+#define Dtype  half
+#define Dtype2 half2
+#define Dtype4 half4
+#define Dtype8 half8
+#define Dtype16 half16
+
+#define as_Dtype  as_half
+#define as_Dtype2 as_half2
+#define as_Dtype4 as_half4
+#define as_Dtype8 as_half8
+#define as_Dtype16 as_half16
+#else
+#define Dtype  float
+#define Dtype2 float2
+#define Dtype4 float4
+#define Dtype8 float8
+#define Dtype16 float16
+
+#define as_Dtype  as_float
+#define as_Dtype2 as_float2
+#define as_Dtype4 as_float4
+#define as_Dtype8 as_float8
+#define as_Dtype16 as_float16
+#endif
+
+#if TYPE == TYPE_HALF
+#define SHUFFLE_TYPE2(val) as_ushort2(val)
+#define SHUFFLE_TYPE8(val) as_ushort8(val)
+#define SIMD_SIZE_GEMM 16
+#else
+#define SHUFFLE_TYPE2(val) val
+#define SHUFFLE_TYPE8(val) val
+#define SIMD_SIZE_GEMM 8
+#endif
+
+#if defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION  cl_intel_subgroups : enable
+#endif
+
+#define VEC_SIZE        4
+#define LWG_HEIGHT      4
+#define TILE_M          8
+#if TYPE == TYPE_HALF
+#define TILE_K          32
+#define TILE_N          64
+#else
+#define TILE_K          16
+#define TILE_N          32
+#endif
+
+__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1)))
+__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))
+__kernel void TEMPLATE(gemm_buffer_NN, Dtype)(
+    const __global Dtype *src0, int off0,
+    const __global Dtype *src1, int off1,
+    __global Dtype *dst, int offd,
+    int M,
+    int N,
+    int K,
+    KERNEL_ARG_DTYPE alpha_in,
+    KERNEL_ARG_DTYPE beta_in,
+    int start_index)
+{
+    const Dtype alpha = (Dtype)alpha_in;
+    const Dtype beta = (Dtype)beta_in;
+    const int group_x = get_group_id(0);
+    const int group_y = get_group_id(1);
+    const int local_x = get_local_id(0);
+    const int local_y = get_local_id(1);
+    const int global_x = get_global_id(0);
+    const int global_y = get_global_id(1);
+
+    Dtype4 brow;
+    Dtype2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7;
+
+    __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;
+
+    const __global Dtype *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + start_index + off0;
+
+    const __global Dtype *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;
+
+    int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M);
+
+    int row0 = mad24(global_y, TILE_M, 0) < M ? 0 : border;
+    int row1 = mad24(global_y, TILE_M, 1) < M ? 1 : border;
+    int row2 = mad24(global_y, TILE_M, 2) < M ? 2 : border;
+    int row3 = mad24(global_y, TILE_M, 3) < M ? 3 : border;
+    int row4 = mad24(global_y, TILE_M, 4) < M ? 4 : border;
+    int row5 = mad24(global_y, TILE_M, 5) < M ? 5 : border;
+    int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;
+    int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;
+
+    Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);
+    Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);
+    Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);
+    Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);
+    Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);
+    Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);
+    Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);
+    Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);
+
+    int end_index = min(start_index + 256, K);
+    int w = start_index;
+    while( w + TILE_K <= end_index ) {
+        arow0 = alpha * vload2(0, src0_read + row0 * K);
+        arow1 = alpha * vload2(0, src0_read + row1 * K);
+        arow2 = alpha * vload2(0, src0_read + row2 * K);
+        arow3 = alpha * vload2(0, src0_read + row3 * K);
+        arow4 = alpha * vload2(0, src0_read + row4 * K);
+        arow5 = alpha * vload2(0, src0_read + row5 * K);
+        arow6 = alpha * vload2(0, src0_read + row6 * K);
+        arow7 = alpha * vload2(0, src0_read + row7 * K);
+
+#define MM_DOT_PRODUCT( index, suffix )   \
+        brow = vload4(0, src1_read0);  src1_read0 += N; \
+        dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \
+        dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \
+        dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \
+        dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \
+        dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \
+        dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \
+        dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \
+        dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );
+
+        MM_DOT_PRODUCT(0, 0);
+        MM_DOT_PRODUCT(0, 1);
+        MM_DOT_PRODUCT(1, 0);
+        MM_DOT_PRODUCT(1, 1);
+        MM_DOT_PRODUCT(2, 0);
+        MM_DOT_PRODUCT(2, 1);
+        MM_DOT_PRODUCT(3, 0);
+        MM_DOT_PRODUCT(3, 1);
+        MM_DOT_PRODUCT(4, 0);
+        MM_DOT_PRODUCT(4, 1);
+        MM_DOT_PRODUCT(5, 0);
+        MM_DOT_PRODUCT(5, 1);
+        MM_DOT_PRODUCT(6, 0);
+        MM_DOT_PRODUCT(6, 1);
+        MM_DOT_PRODUCT(7, 0);
+        MM_DOT_PRODUCT(7, 1);
+#if TYPE == TYPE_HALF
+        MM_DOT_PRODUCT(8, 0);
+        MM_DOT_PRODUCT(8, 1);
+        MM_DOT_PRODUCT(9, 0);
+        MM_DOT_PRODUCT(9, 1);
+        MM_DOT_PRODUCT(10, 0);
+        MM_DOT_PRODUCT(10, 1);
+        MM_DOT_PRODUCT(11, 0);
+        MM_DOT_PRODUCT(11, 1);
+        MM_DOT_PRODUCT(12, 0);
+        MM_DOT_PRODUCT(12, 1);
+        MM_DOT_PRODUCT(13, 0);
+        MM_DOT_PRODUCT(13, 1);
+        MM_DOT_PRODUCT(14, 0);
+        MM_DOT_PRODUCT(14, 1);
+        MM_DOT_PRODUCT(15, 0);
+        MM_DOT_PRODUCT(15, 1);
+#endif
+#undef MM_DOT_PRODUCT
+
+        src0_read += TILE_K;
+        w += TILE_K;
+    }
+
+    if(w < end_index) {
+        arow0.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row0 * K)[0] : 0.0f;
+        arow0.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row0 * K)[1] : 0.0f;
+        arow1.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row1 * K)[0] : 0.0f;
+        arow1.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row1 * K)[1] : 0.0f;
+        arow2.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row2 * K)[0] : 0.0f;
+        arow2.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row2 * K)[1] : 0.0f;
+        arow3.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row3 * K)[0] : 0.0f;
+        arow3.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row3 * K)[1] : 0.0f;
+        arow4.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row4 * K)[0] : 0.0f;
+        arow4.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row4 * K)[1] : 0.0f;
+        arow5.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row5 * K)[0] : 0.0f;
+        arow5.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row5 * K)[1] : 0.0f;
+        arow6.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row6 * K)[0] : 0.0f;
+        arow6.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row6 * K)[1] : 0.0f;
+        arow7.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row7 * K)[0] : 0.0f;
+        arow7.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row7 * K)[1] : 0.0f;
+
+#define MM_DOT_PRODUCT( index, suffix )   \
+        brow = (w < K) ? vload4(0, src1_read0) : (Dtype4)0.0f;  src1_read0 += N; w++; \
+        dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \
+        dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \
+        dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \
+        dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \
+        dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \
+        dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \
+        dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \
+        dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );
+
+        MM_DOT_PRODUCT(0, 0);
+        MM_DOT_PRODUCT(0, 1);
+        MM_DOT_PRODUCT(1, 0);
+        MM_DOT_PRODUCT(1, 1);
+        MM_DOT_PRODUCT(2, 0);
+        MM_DOT_PRODUCT(2, 1);
+        MM_DOT_PRODUCT(3, 0);
+        MM_DOT_PRODUCT(3, 1);
+        MM_DOT_PRODUCT(4, 0);
+        MM_DOT_PRODUCT(4, 1);
+        MM_DOT_PRODUCT(5, 0);
+        MM_DOT_PRODUCT(5, 1);
+        MM_DOT_PRODUCT(6, 0);
+        MM_DOT_PRODUCT(6, 1);
+        MM_DOT_PRODUCT(7, 0);
+        MM_DOT_PRODUCT(7, 1);
+#if TYPE == TYPE_HALF
+        MM_DOT_PRODUCT(8, 0);
+        MM_DOT_PRODUCT(8, 1);
+        MM_DOT_PRODUCT(9, 0);
+        MM_DOT_PRODUCT(9, 1);
+        MM_DOT_PRODUCT(10, 0);
+        MM_DOT_PRODUCT(10, 1);
+        MM_DOT_PRODUCT(11, 0);
+        MM_DOT_PRODUCT(11, 1);
+        MM_DOT_PRODUCT(12, 0);
+        MM_DOT_PRODUCT(12, 1);
+        MM_DOT_PRODUCT(13, 0);
+        MM_DOT_PRODUCT(13, 1);
+        MM_DOT_PRODUCT(14, 0);
+        MM_DOT_PRODUCT(14, 1);
+        MM_DOT_PRODUCT(15, 0);
+        MM_DOT_PRODUCT(15, 1);
+#endif
+#undef MM_DOT_PRODUCT
+    }
+
+    if(global_x * 4 < N && global_y * 8 < M) {
+        if(mad24(global_x, 4, 3) < N) {
+            vstore4(dot00, 0, dst_write0); dst_write0 += N;
+            if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 7) < M) { vstore4(dot07, 0, dst_write0); }
+        } else if(mad24(global_x, 4, 2) < N) {
+            vstore2(dot00.xy, 0, dst_write0);
+            dst_write0[2] = dot00.z;
+            dst_write0 += N;
+            if(mad24(global_y, 8, 1) < M) {
+                vstore2(dot01.xy, 0, dst_write0);
+                dst_write0[2] = dot01.z;
+                dst_write0 += N;
+            } else
+                return;
+            if(mad24(global_y, 8, 2) < M) {
+                vstore2(dot02.xy, 0, dst_write0);
+                dst_write0[2] = dot02.z;
+                dst_write0 += N;
+            } else
+                return;
+            if(mad24(global_y, 8, 3) < M) {
+                vstore2(dot03.xy, 0, dst_write0);
+                dst_write0[2] = dot03.z;
+                dst_write0 += N;
+            } else
+                return;
+            if(mad24(global_y, 8, 4) < M) {
+                vstore2(dot04.xy, 0, dst_write0);
+                dst_write0[2] = dot04.z;
+                dst_write0 += N;
+            } else
+                return;
+            if(mad24(global_y, 8, 5) < M) {
+                vstore2(dot05.xy, 0, dst_write0);
+                dst_write0[2] = dot05.z;
+                dst_write0 += N;
+            } else
+                return;
+            if(mad24(global_y, 8, 6) < M) {
+                vstore2(dot06.xy, 0, dst_write0);
+                dst_write0[2] = dot06.z;
+                dst_write0 += N;
+            } else
+                return;
+            if(mad24(global_y, 8, 7) < M) {
+                vstore2(dot07.xy, 0, dst_write0);
+                dst_write0[2] = dot07.z;
+            }
+        } else if(mad24(global_x, 4, 1) < N) {
+            vstore2(dot00.xy, 0, dst_write0); dst_write0 += N;
+            if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); }
+        } else {
+            dst_write0[0] = dot00.x; dst_write0 += N;
+            if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; }
+            else return;
+            if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; }
+        }
+    }
+}
+
+#undef VEC_SIZE
+#undef LWG_HEIGHT
+#undef TILE_M
+#undef TILE_K
+#undef TILE_N
+
+#define VEC_SIZE        1
+#define TILE_M          8
+#define TILE_N          8
+#define SLM_BLOCK       128
+
+#if TYPE == TYPE_HALF
+#define LWG_HEIGHT      2
+#define TILE_K          64
+#else
+#define LWG_HEIGHT      4
+#define TILE_K          32
+#endif
+
+#if TYPE == TYPE_HALF
+__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
+__attribute__((intel_reqd_sub_group_size(8)))
+__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
+    const __global Dtype *src0, int off0,
+    const __global Dtype *src1, int off1,
+    __global Dtype *dst, int offd,
+    int M,
+    int N,
+    int K,
+    KERNEL_ARG_DTYPE alpha_in,
+    KERNEL_ARG_DTYPE beta_in)
+{
+    const Dtype alpha = (Dtype)alpha_in;
+    const Dtype beta = (Dtype)beta_in;
+    const int group_x = get_group_id(0);
+    const int group_y = get_group_id(1);
+    const int local_x = get_local_id(0);
+    const int local_y = get_local_id(1);
+    const int global_x = get_global_id(0);
+    const int global_y = get_global_id(1);
+
+    Dtype8 dot00 = 0.f;
+    Dtype8 dot01 = 0.f;
+    Dtype8 dot02 = 0.f;
+    Dtype8 dot03 = 0.f;
+    Dtype8 dot04 = 0.f;
+    Dtype8 dot05 = 0.f;
+    Dtype8 dot06 = 0.f;
+    Dtype8 dot07 = 0.f;
+
+    Dtype8 brow0;
+    Dtype8 brow1;
+    Dtype8 brow2;
+    Dtype8 brow3;
+    Dtype8 brow4;
+    Dtype8 brow5;
+    Dtype8 brow6;
+    Dtype8 brow7;
+
+    __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;
+
+    const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;
+
+    const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;
+
+    __local Dtype slm_brow[8 * SLM_BLOCK];
+    __local Dtype* slm_brow0;
+
+    int local_index = mad24(local_y, 8, local_x) * 8;
+    int w;
+    for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
+        barrier(CLK_LOCAL_MEM_FENCE);
+        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(0, K, local_index))), 0, (__local float *)(slm_brow + mad24(0, SLM_BLOCK, local_index)));
+        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(1, K, local_index))), 0, (__local float *)(slm_brow + mad24(1, SLM_BLOCK, local_index)));
+        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(2, K, local_index))), 0, (__local float *)(slm_brow + mad24(2, SLM_BLOCK, local_index)));
+        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(3, K, local_index))), 0, (__local float *)(slm_brow + mad24(3, SLM_BLOCK, local_index)));
+        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(4, K, local_index))), 0, (__local float *)(slm_brow + mad24(4, SLM_BLOCK, local_index)));
+        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(5, K, local_index))), 0, (__local float *)(slm_brow + mad24(5, SLM_BLOCK, local_index)));
+        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(6, K, local_index))), 0, (__local float *)(slm_brow + mad24(6, SLM_BLOCK, local_index)));
+        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(7, K, local_index))), 0, (__local float *)(slm_brow + mad24(7, SLM_BLOCK, local_index)));
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        slm_brow0 = slm_brow + local_x * (TILE_K / 8);
+        w = b_tile;
+        int end_w = min(b_tile + SLM_BLOCK, K);
+        while( w + TILE_K <= end_w ) {
+            Dtype8 arow;
+
+            brow0 = as_half8(vload4(0, (__local float *)(slm_brow0 + 0 * SLM_BLOCK)));
+            brow1 = as_half8(vload4(0, (__local float *)(slm_brow0 + 1 * SLM_BLOCK)));
+            brow2 = as_half8(vload4(0, (__local float *)(slm_brow0 + 2 * SLM_BLOCK)));
+            brow3 = as_half8(vload4(0, (__local float *)(slm_brow0 + 3 * SLM_BLOCK)));
+            brow4 = as_half8(vload4(0, (__local float *)(slm_brow0 + 4 * SLM_BLOCK)));
+            brow5 = as_half8(vload4(0, (__local float *)(slm_brow0 + 5 * SLM_BLOCK)));
+            brow6 = as_half8(vload4(0, (__local float *)(slm_brow0 + 6 * SLM_BLOCK)));
+            brow7 = as_half8(vload4(0, (__local float *)(slm_brow0 + 7 * SLM_BLOCK)));
+
+#define MM_DOT_PRODUCT( _row, _dot )   \
+            arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K)));                           \
+            _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \
+            _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \
+            _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \
+            _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \
+            _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \
+            _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \
+            _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \
+            _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );
+
+            MM_DOT_PRODUCT( 0, dot00 );
+            MM_DOT_PRODUCT( 1, dot01 );
+            MM_DOT_PRODUCT( 2, dot02 );
+            MM_DOT_PRODUCT( 3, dot03 );
+            MM_DOT_PRODUCT( 4, dot04 );
+            MM_DOT_PRODUCT( 5, dot05 );
+            MM_DOT_PRODUCT( 6, dot06 );
+            MM_DOT_PRODUCT( 7, dot07 );
+#undef MM_DOT_PRODUCT
+
+            src0_read += TILE_K;
+            slm_brow0 += TILE_K;
+            w += TILE_K;
+        }
+        src1_read0 += SLM_BLOCK;
+    }
+
+    if(w < K) {
+        Dtype8 arow;
+
+#define READ_BROW(_brow, _row) \
+        _brow = as_half8(vload4(0, (__local float *)(slm_brow0 + _row * SLM_BLOCK))); \
+        _brow.s0 = (mad24(local_x, 8, w) < K) ? _brow.s0 : 0.0f; \
+        _brow.s1 = (mad24(local_x, 8, w + 1) < K) ? _brow.s1 : 0.0f; \
+        _brow.s2 = (mad24(local_x, 8, w + 2) < K) ? _brow.s2 : 0.0f; \
+        _brow.s3 = (mad24(local_x, 8, w + 3) < K) ? _brow.s3 : 0.0f; \
+        _brow.s4 = (mad24(local_x, 8, w + 4) < K) ? _brow.s4 : 0.0f; \
+        _brow.s5 = (mad24(local_x, 8, w + 5) < K) ? _brow.s5 : 0.0f; \
+        _brow.s6 = (mad24(local_x, 8, w + 6) < K) ? _brow.s6 : 0.0f; \
+        _brow.s7 = (mad24(local_x, 8, w + 7) < K) ? _brow.s7 : 0.0f;
+
+        READ_BROW(brow0, 0);
+        READ_BROW(brow1, 1);
+        READ_BROW(brow2, 2);
+        READ_BROW(brow3, 3);
+        READ_BROW(brow4, 4);
+        READ_BROW(brow5, 5);
+        READ_BROW(brow6, 6);
+        READ_BROW(brow7, 7);
+
+#undef READ_BROW
+
+#define MM_DOT_PRODUCT( _row, _dot )   \
+        arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K)));                           \
+        arow.s0 = (mad24(local_x, 8, w) < K) ? arow.s0 : 0.0f; \
+        arow.s1 = (mad24(local_x, 8, w + 1) < K) ? arow.s1 : 0.0f; \
+        arow.s2 = (mad24(local_x, 8, w + 2) < K) ? arow.s2 : 0.0f; \
+        arow.s3 = (mad24(local_x, 8, w + 3) < K) ? arow.s3 : 0.0f; \
+        arow.s4 = (mad24(local_x, 8, w + 4) < K) ? arow.s4 : 0.0f; \
+        arow.s5 = (mad24(local_x, 8, w + 5) < K) ? arow.s5 : 0.0f; \
+        arow.s6 = (mad24(local_x, 8, w + 6) < K) ? arow.s6 : 0.0f; \
+        arow.s7 = (mad24(local_x, 8, w + 7) < K) ? arow.s7 : 0.0f; \
+        _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \
+        _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \
+        _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \
+        _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \
+        _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \
+        _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \
+        _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \
+        _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );
+
+        MM_DOT_PRODUCT( 0, dot00 );
+        MM_DOT_PRODUCT( 1, dot01 );
+        MM_DOT_PRODUCT( 2, dot02 );
+        MM_DOT_PRODUCT( 3, dot03 );
+        MM_DOT_PRODUCT( 4, dot04 );
+        MM_DOT_PRODUCT( 5, dot05 );
+        MM_DOT_PRODUCT( 6, dot06 );
+        MM_DOT_PRODUCT( 7, dot07 );
+#undef MM_DOT_PRODUCT
+    }
+
+#define REDUCE(_dot) \
+    _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) +  \
+           as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));
+
+    REDUCE(dot00);
+    REDUCE(dot01);
+    REDUCE(dot02);
+    REDUCE(dot03);
+    REDUCE(dot04);
+    REDUCE(dot05);
+    REDUCE(dot06);
+    REDUCE(dot07);
+#undef REDUCE
+
+    Dtype output = 0.0f;
+#define OUTPUT( _dot) \
+    output = (local_x == 0) ? _dot.s0 : output; \
+    output = (local_x == 1) ? _dot.s1 : output; \
+    output = (local_x == 2) ? _dot.s2 : output; \
+    output = (local_x == 3) ? _dot.s3 : output; \
+    output = (local_x == 4) ? _dot.s4 : output; \
+    output = (local_x == 5) ? _dot.s5 : output; \
+    output = (local_x == 6) ? _dot.s6 : output; \
+    output = (local_x == 7) ? _dot.s7 : output; \
+    dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
+    dst_write0 += N;
+
+    if(global_x < N && global_y * 8 < M) {
+        OUTPUT(dot00);
+        if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
+        if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
+        if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
+        if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
+        if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
+        if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
+        if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
+    }
+#undef OUTPUT
+}
+
+#else
+
+__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
+__attribute__((intel_reqd_sub_group_size(8)))
+__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
+    const __global Dtype *src0, int off0,
+    const __global Dtype *src1, int off1,
+    __global Dtype *dst, int offd,
+    int M,
+    int N,
+    int K,
+    KERNEL_ARG_DTYPE alpha_in,
+    KERNEL_ARG_DTYPE beta_in)
+{
+    const Dtype alpha = (Dtype)alpha_in;
+    const Dtype beta = (Dtype)beta_in;
+    const int group_x = get_group_id(0);
+    const int group_y = get_group_id(1);
+    const int local_x = get_local_id(0);
+    const int local_y = get_local_id(1);
+    const int global_x = get_global_id(0);
+    const int global_y = get_global_id(1);
+
+    Dtype8 dot00 = 0.f;
+    Dtype8 dot01 = 0.f;
+    Dtype8 dot02 = 0.f;
+    Dtype8 dot03 = 0.f;
+    Dtype8 dot04 = 0.f;
+    Dtype8 dot05 = 0.f;
+    Dtype8 dot06 = 0.f;
+    Dtype8 dot07 = 0.f;
+
+    Dtype4 brow0;
+    Dtype4 brow1;
+    Dtype4 brow2;
+    Dtype4 brow3;
+    Dtype4 brow4;
+    Dtype4 brow5;
+    Dtype4 brow6;
+    Dtype4 brow7;
+
+    __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;
+
+    const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;
+
+    const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;
+
+    __local Dtype slm_brow[8 * SLM_BLOCK];
+    __local Dtype* slm_brow0;
+
+    int local_index = mad24(local_y, 8, local_x) * 4;
+    int w;
+    for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
+        barrier(CLK_LOCAL_MEM_FENCE);
+        vstore4(vload4(0, src1_read0 + mad24(0, K, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index));
+        vstore4(vload4(0, src1_read0 + mad24(1, K, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index));
+        vstore4(vload4(0, src1_read0 + mad24(2, K, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index));
+        vstore4(vload4(0, src1_read0 + mad24(3, K, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index));
+        vstore4(vload4(0, src1_read0 + mad24(4, K, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index));
+        vstore4(vload4(0, src1_read0 + mad24(5, K, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index));
+        vstore4(vload4(0, src1_read0 + mad24(6, K, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index));
+        vstore4(vload4(0, src1_read0 + mad24(7, K, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index));
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        slm_brow0 = slm_brow + local_x * (TILE_K / 8);
+        w = b_tile;
+        int end_w = min(b_tile + SLM_BLOCK, K);
+        while( w + TILE_K <= end_w ) {
+            Dtype4 arow;
+
+            brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK);
+            brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK);
+            brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK);
+            brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK);
+            brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK);
+            brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK);
+            brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK);
+            brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK);
+
+#define MM_DOT_PRODUCT( _row, _dot )   \
+            arow = vload4(0, src0_read + _row * K);                           \
+            _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
+            _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
+            _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
+            _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );
+
+            MM_DOT_PRODUCT( 0, dot00 );
+            MM_DOT_PRODUCT( 1, dot01 );
+            MM_DOT_PRODUCT( 2, dot02 );
+            MM_DOT_PRODUCT( 3, dot03 );
+            MM_DOT_PRODUCT( 4, dot04 );
+            MM_DOT_PRODUCT( 5, dot05 );
+            MM_DOT_PRODUCT( 6, dot06 );
+            MM_DOT_PRODUCT( 7, dot07 );
+#undef MM_DOT_PRODUCT
+
+            src0_read += TILE_K;
+            slm_brow0 += TILE_K;
+            w += TILE_K;
+        }
+        src1_read0 += SLM_BLOCK;
+    }
+
+    if(w < K) {
+        Dtype4 arow;
+
+#define READ_BROW(_brow, _row) \
+        _brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \
+        _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \
+        _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \
+        _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \
+        _brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;
+
+        READ_BROW(brow0, 0);
+        READ_BROW(brow1, 1);
+        READ_BROW(brow2, 2);
+        READ_BROW(brow3, 3);
+        READ_BROW(brow4, 4);
+        READ_BROW(brow5, 5);
+        READ_BROW(brow6, 6);
+        READ_BROW(brow7, 7);
+
+#undef READ_BROW
+
+#define MM_DOT_PRODUCT( _row, _dot )   \
+        arow = vload4(0, src0_read + _row * K);                           \
+        arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \
+        arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \
+        arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \
+        arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; \
+        _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
+        _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
+        _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
+        _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );
+
+        MM_DOT_PRODUCT( 0, dot00 );
+        MM_DOT_PRODUCT( 1, dot01 );
+        MM_DOT_PRODUCT( 2, dot02 );
+        MM_DOT_PRODUCT( 3, dot03 );
+        MM_DOT_PRODUCT( 4, dot04 );
+        MM_DOT_PRODUCT( 5, dot05 );
+        MM_DOT_PRODUCT( 6, dot06 );
+        MM_DOT_PRODUCT( 7, dot07 );
+#undef MM_DOT_PRODUCT
+    }
+
+#define REDUCE(_dot) \
+    _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) +  \
+           as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));
+
+    REDUCE(dot00);
+    REDUCE(dot01);
+    REDUCE(dot02);
+    REDUCE(dot03);
+    REDUCE(dot04);
+    REDUCE(dot05);
+    REDUCE(dot06);
+    REDUCE(dot07);
+#undef REDUCE
+
+    Dtype output = 0.0f;
+#define OUTPUT( _dot) \
+    output = (local_x == 0) ? _dot.s0 : output; \
+    output = (local_x == 1) ? _dot.s1 : output; \
+    output = (local_x == 2) ? _dot.s2 : output; \
+    output = (local_x == 3) ? _dot.s3 : output; \
+    output = (local_x == 4) ? _dot.s4 : output; \
+    output = (local_x == 5) ? _dot.s5 : output; \
+    output = (local_x == 6) ? _dot.s6 : output; \
+    output = (local_x == 7) ? _dot.s7 : output; \
+    dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
+    dst_write0 += N;
+
+    if(global_x < N && global_y * 8 < M) {
+        OUTPUT(dot00);
+        if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
+        if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
+        if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
+        if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
+        if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
+        if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
+        if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
+    }
+#undef OUTPUT
+}
+#endif
+
+#undef VEC_SIZE
+#undef LWG_HEIGHT
+#undef TILE_M
+#undef TILE_K
+#undef TILE_N
+#undef SLM_BLOCK
+
+#define SLM_SIZE 64
+void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
+                           const __global Dtype* srca_read0,
+                           const __global Dtype* srca_read1,
+                           const __global Dtype* srcb_read,
+                           __local Dtype4* work0,
+                           __local Dtype4* work1,
+                           int N,
+                           int K,
+                           int x_gid,
+                           int lid,
+                           Dtype alpha,
+                           Dtype beta,
+                           __global Dtype* dstc0,
+                           __global Dtype* dstc1)
+{
+  __local Dtype* work_each0 = (__local Dtype*)work0;
+  __local Dtype* work_each1 = (__local Dtype*)work1;
+
+  int rows = N - x_gid * 4;
+
+  Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+  Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+
+  int i = lid;
+  while( i < K / 4) {
+    const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};
+    const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};
+#pragma unroll
+    for(int j = 0; j < rows; ++j) {
+      dot0[j] += b0 * vload4(i, srcb_read + j * K);
+      dot1[j] += b1 * vload4(i, srcb_read + j * K);
+    }
+
+    i += get_local_size(0);
+  }
+#pragma unroll
+  for(int j = 0; j < rows; ++j) {
+    work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
+    work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
+  }
+
+  if(i == K / 4) {
+    short tail_items = K % 4;
+
+    if(tail_items != 0) {
+      const __global Dtype *srcb_tail = srcb_read + i * 4;
+      const __global Dtype *srca_tail0 = srca_read0 + i * 4;
+      const __global Dtype *srca_tail1 = srca_read1 + i * 4;
+#pragma unroll
+      for(short i = 0; i < tail_items; ++i) {
+        const Dtype at0 = srca_tail0[i];
+        const Dtype at1 = srca_tail1[i];
+#pragma unroll
+        for(int j = 0; j < rows; ++j) {
+          work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];
+          work_each1[lid * 4 + j] += at1 * srcb_tail[i + j * K];
+        }
+      }
+    }
+  }
+
+  for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if(lid < stride) {
+      work0[lid] += work0[lid+stride];
+      work1[lid] += work1[lid+stride];
+    }
+  }
+
+  if(lid == 0) {
+#pragma unroll
+    for(int j = 0; j < rows; ++j) {
+      dstc0[(x_gid * 4  + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
+      dstc1[(x_gid * 4  + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
+    }
+  }
+}
+
+__kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(
+          __global const Dtype * A,
+          int offA,
+          __global const Dtype * B,
+          int offB,
+          __global Dtype * C,
+          int offC,
+          int M,
+          int N,
+          int K,
+          KERNEL_ARG_DTYPE alpha_f,
+          KERNEL_ARG_DTYPE beta_f)
+{
+  Dtype alpha = (Dtype)alpha_f;
+  Dtype beta = (Dtype)beta_f;
+  int x_gid = get_group_id(0);
+  int lid = get_local_id(0);
+
+  const __global Dtype *srca_read0 = A + offA;
+  const __global Dtype *srca_read1 = srca_read0 + K;
+
+  const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;
+
+  __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);
+  __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);
+
+  __local Dtype4 work0[SLM_SIZE];
+  __local Dtype4 work1[SLM_SIZE];
+  __local Dtype* work_each0 = (__local Dtype*)work0;
+  __local Dtype* work_each1 = (__local Dtype*)work1;
+
+  if(x_gid == N / 4) {
+    TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) \
+         (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1);
+  } else {
+    Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+    Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+    int i = lid;
+    while( i < K / 4) {
+      const Dtype4 b0 = vload4(i, srca_read0);
+      const Dtype4 b1 = vload4(i, srca_read1);
+#pragma unroll
+      for(int j = 0; j < 4; ++j) {
+        Dtype4 a = vload4(i, srcb_read + j * K);
+        dot0[j] += b0 * a;
+        dot1[j] += b1 * a;
+      }
+      i += get_local_size(0);
+    }
+
+#pragma unroll
+    for(int j = 0; j < 4; ++j) {
+      work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
+      work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
+    }
+
+    if(i == K / 4) {
+      short tail_items = K % 4;
+      if(tail_items != 0) {
+        const __global Dtype *srcb_tail = srcb_read + i * 4;
+
+        const __global Dtype *srca_tail0 = srca_read0 + i * 4;
+        const __global Dtype *srca_tail1 = srca_read1 + i * 4;
+#pragma unroll
+        for(short i = 0; i < tail_items; ++i) {
+          const Dtype at0 = srca_tail0[i];
+          const Dtype at1 = srca_tail1[i];
+#pragma unroll
+          for(int j = 0; j < 4; ++j) {
+            work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];
+            work_each1[lid * 4 + j] += at1 * srcb_tail[i + j * K];
+          }
+        }
+      }
+    }
+
+    for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
+      barrier(CLK_LOCAL_MEM_FENCE);
+      if(lid < stride) {
+        work0[lid] += work0[lid+stride];
+        work1[lid] += work1[lid+stride];
+      }
+    }
+
+    if(lid == 0) {
+      dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
+      dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
+    }
+  }
+}
+#undef SLM_SIZE
+
+#define SLM_SIZE 32
+void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(
+                           const __global Dtype* srca_read0,
+                           const __global Dtype* srca_read1,
+                           const __global Dtype* srca_read2,
+                           const __global Dtype* srca_read3,
+                           const __global Dtype* srcb_read,
+                           __local Dtype4* work0,
+                           __local Dtype4* work1,
+                           __local Dtype4* work2,
+                           __local Dtype4* work3,
+                           int N,
+                           int K,
+                           int x_gid,
+                           int lid,
+                           Dtype alpha,
+                           Dtype beta,
+                           __global Dtype* dstc0,
+                           __global Dtype* dstc1,
+                           __global Dtype* dstc2,
+                           __global Dtype* dstc3)
+{
+  __local Dtype* work_each0 = (__local Dtype*)(work0 + lid);
+  __local Dtype* work_each1 = (__local Dtype*)(work1 + lid);
+  __local Dtype* work_each2 = (__local Dtype*)(work2 + lid);
+  __local Dtype* work_each3 = (__local Dtype*)(work3 + lid);
+
+  int rows = N - x_gid * 4;
+
+  Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+  Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+  Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+  Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+
+  int i = lid;
+  while( i < K / 4) {
+    const Dtype4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};
+    const Dtype4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};
+    const Dtype4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]};
+    const Dtype4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]};
+#pragma unrol
+    for(int j = 0; j < rows; ++j) {
+      dot0[j] += a0 * vload4(i, srcb_read + j * K);
+      dot1[j] += a1 * vload4(i, srcb_read + j * K);
+      dot2[j] += a2 * vload4(i, srcb_read + j * K);
+      dot3[j] += a3 * vload4(i, srcb_read + j * K);
+    }
+
+    i += get_local_size(0);
+  }
+#pragma unroll
+  for(int j = 0; j < rows; ++j) {
+    work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
+    work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
+    work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w;
+    work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;
+  }
+
+  if(i == K / 4) {
+    short tail_items = K % 4;
+
+    if(tail_items != 0) {
+      const __global Dtype *srcb_tail = srcb_read + i * 4;
+
+      const __global Dtype *srca_tail0 = srca_read0 + i * 4;
+      const __global Dtype *srca_tail1 = srca_read1 + i * 4;
+      const __global Dtype *srca_tail2 = srca_read2 + i * 4;
+      const __global Dtype *srca_tail3 = srca_read3 + i * 4;
+#pragma unroll
+      for(short i = 0; i < tail_items; ++i) {
+        const Dtype at0 = srca_tail0[i];
+        const Dtype at1 = srca_tail1[i];
+        const Dtype at2 = srca_tail2[i];
+        const Dtype at3 = srca_tail3[i];
+#pragma unroll
+        for(int j = 0; j < rows; ++j) {
+          work_each0[j] += at0 * srcb_tail[i + j * K];
+          work_each1[j] += at1 * srcb_tail[i + j * K];
+          work_each2[j] += at2 * srcb_tail[i + j * K];
+          work_each3[j] += at3 * srcb_tail[i + j * K];
+        }
+      }
+    }
+  }
+
+  for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if(lid < stride) {
+      work0[lid] += work0[lid+stride];
+      work1[lid] += work1[lid+stride];
+      work2[lid] += work2[lid+stride];
+      work3[lid] += work3[lid+stride];
+    }
+  }
+
+  if(lid == 0) {
+#pragma unroll
+    for(int j = 0; j < rows; ++j) {
+      dstc0[(x_gid * 4  + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
+      dstc1[(x_gid * 4  + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
+      dstc2[(x_gid * 4  + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)];
+      dstc3[(x_gid * 4  + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)];
+    }
+  }
+}
+
+__kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(
+          __global const Dtype * A,
+          int offA,
+          __global const Dtype * B,
+          int offB,
+          __global Dtype * C,
+          int offC,
+          int M,
+          int N,
+          int K,
+          KERNEL_ARG_DTYPE alpha_f,
+          KERNEL_ARG_DTYPE beta_f)
+{
+  Dtype alpha = (Dtype)alpha_f;
+  Dtype beta = (Dtype)beta_f;
+  int x_gid = get_group_id(0);
+  int lid = get_local_id(0);
+  int lsize = get_local_size(0);
+
+  const __global Dtype *srca_read0 = A + offA;
+  const __global Dtype *srca_read1 = srca_read0 + K;
+  const __global Dtype *srca_read2 = srca_read1 + K;
+  const __global Dtype *srca_read3 = srca_read2 + K;
+
+  const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;
+
+  __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);
+  __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);
+  __global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) + N);
+  __global Dtype4 *dstc3 = (__global Dtype4*)((__global Dtype*)(dstc2) + N);
+
+  __local Dtype4 work0[SLM_SIZE];
+  __local Dtype4 work1[SLM_SIZE];
+  __local Dtype4 work2[SLM_SIZE];
+  __local Dtype4 work3[SLM_SIZE];
+  __local Dtype* work_each0 = (__local Dtype*)(work0 + lid);
+  __local Dtype* work_each1 = (__local Dtype*)(work1 + lid);
+  __local Dtype* work_each2 = (__local Dtype*)(work2 + lid);
+  __local Dtype* work_each3 = (__local Dtype*)(work3 + lid);
+
+  if(x_gid == N / 4) {
+    TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) \
+         (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, \
+         work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, \
+         (__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3);
+  } else {
+    Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+    Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+    Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+    Dtype4 dot3[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
+
+    int kid = lid;
+    while( kid < K / 4) {
+      const Dtype4 b0 = vload4(kid, srca_read0);
+      const Dtype4 b1 = vload4(kid, srca_read1);
+      const Dtype4 b2 = vload4(kid, srca_read2);
+      const Dtype4 b3 = vload4(kid, srca_read3);
+#pragma unroll
+      for(int j = 0; j < 4; ++j) {
+        Dtype4 a = vload4(kid, srcb_read + j * K);
+        dot0[j] += b0 * a;
+        dot1[j] += b1 * a;
+        dot2[j] += b2 * a;
+        dot3[j] += b3 * a;
+      }
+      kid += lsize;
+    }
+#pragma unroll
+    for(int j = 0; j < 4; ++j) {
+      work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
+      work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
+      work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w;
+      work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;
+    }
+
+    if(kid == (K >> 2)) {
+      short tail_items = K % 4;
+      if(tail_items != 0) {
+        int offset = kid << 2;
+        const __global Dtype *srcb_tail = srcb_read + offset;
+
+        const __global Dtype *srca_tail0 = srca_read0 + offset;
+        const __global Dtype *srca_tail1 = srca_read1 + offset;
+        const __global Dtype *srca_tail2 = srca_read2 + offset;
+        const __global Dtype *srca_tail3 = srca_read3 + offset;
+#pragma unroll
+        for(short i = 0; i < tail_items; ++i) {
+          const Dtype at0 = srca_tail0[i];
+          const Dtype at1 = srca_tail1[i];
+          const Dtype at2 = srca_tail2[i];
+          const Dtype at3 = srca_tail3[i];
+#pragma unroll
+          for(int j = 0; j < 4; ++j) {
+            work_each0[j] += at0 * srcb_tail[i + j * K];
+            work_each1[j] += at1 * srcb_tail[i + j * K];
+            work_each2[j] += at2 * srcb_tail[i + j * K];
+            work_each3[j] += at3 * srcb_tail[i + j * K];
+          }
+        }
+      }
+    }
+
+    for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
+      barrier(CLK_LOCAL_MEM_FENCE);
+      if(lid < stride) {
+        work0[lid] += work0[lid+stride];
+        work1[lid] += work1[lid+stride];
+        work2[lid] += work2[lid+stride];
+        work3[lid] += work3[lid+stride];
+      }
+    }
+
+    if(lid == 0) {
+      dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
+      dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
+      dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
+      dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];
+    }
+  }
+}
+#undef SLM_SIZE
+
+#define SLM_SIZE 16
+__kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
+          __global const Dtype * A,
+          int offA,
+          __global const Dtype * B,
+          int offB,
+          __global Dtype * C,
+          int offC,
+          int M,
+          int N,
+          int K,
+          KERNEL_ARG_DTYPE alpha_f,
+          KERNEL_ARG_DTYPE beta_f)
+{
+  Dtype alpha = (Dtype)alpha_f;
+  Dtype beta = (Dtype)beta_f;
+  int x_gid = get_group_id(0);
+  int lid = get_local_id(0);
+  int lsize = get_local_size(0);
+
+  const __global Dtype *srca_read0 = A + offA;
+  const __global Dtype *srca_read1 = srca_read0 + K;
+  const __global Dtype *srca_read2 = srca_read1 + K;
+  const __global Dtype *srca_read3 = srca_read2 + K;
+  const __global Dtype *srca_read4 = srca_read3 + K;
+  const __global Dtype *srca_read5 = srca_read4 + K;
+  const __global Dtype *srca_read6 = srca_read5 + K;
+  const __global Dtype *srca_read7 = srca_read6 + K;
+
+  const __global Dtype *srcb_read = B + x_gid * K + offB;
+
+  __global Dtype *dstc0 = C + offC;
+  __global Dtype *dstc1 = dstc0 + N;
+  __global Dtype *dstc2 = dstc1 + N;
+  __global Dtype *dstc3 = dstc2 + N;
+  __global Dtype *dstc4 = dstc3 + N;
+  __global Dtype *dstc5 = dstc4 + N;
+  __global Dtype *dstc6 = dstc5 + N;
+  __global Dtype *dstc7 = dstc6 + N;
+
+  __local Dtype work0[SLM_SIZE];
+  __local Dtype work1[SLM_SIZE];
+  __local Dtype work2[SLM_SIZE];
+  __local Dtype work3[SLM_SIZE];
+  __local Dtype work4[SLM_SIZE];
+  __local Dtype work5[SLM_SIZE];
+  __local Dtype work6[SLM_SIZE];
+  __local Dtype work7[SLM_SIZE];
+
+  Dtype4 dot0 = (Dtype4)(0.);
+  Dtype4 dot1 = (Dtype4)(0.);
+  Dtype4 dot2 = (Dtype4)(0.);
+  Dtype4 dot3 = (Dtype4)(0.);
+  Dtype4 dot4 = (Dtype4)(0.);
+  Dtype4 dot5 = (Dtype4)(0.);
+  Dtype4 dot6 = (Dtype4)(0.);
+  Dtype4 dot7 = (Dtype4)(0.);
+
+  int kid = lid;
+  while( kid < K / 4) {
+    const Dtype4 a0 = vload4(kid, srca_read0);
+    const Dtype4 a1 = vload4(kid, srca_read1);
+    const Dtype4 a2 = vload4(kid, srca_read2);
+    const Dtype4 a3 = vload4(kid, srca_read3);
+    const Dtype4 a4 = vload4(kid, srca_read4);
+    const Dtype4 a5 = vload4(kid, srca_read5);
+    const Dtype4 a6 = vload4(kid, srca_read6);
+    const Dtype4 a7 = vload4(kid, srca_read7);
+    Dtype4 b = vload4(kid, srcb_read);
+    dot0 += a0 * b;
+    dot1 += a1 * b;
+    dot2 += a2 * b;
+    dot3 += a3 * b;
+    dot4 += a4 * b;
+    dot5 += a5 * b;
+    dot6 += a6 * b;
+    dot7 += a7 * b;
+
+    kid += lsize;
+  }
+  work0[lid] = dot0.x + dot0.y + dot0.z + dot0.w;
+  work1[lid] = dot1.x + dot1.y + dot1.z + dot1.w;
+  work2[lid] = dot2.x + dot2.y + dot2.z + dot2.w;
+  work3[lid] = dot3.x + dot3.y + dot3.z + dot3.w;
+  work4[lid] = dot4.x + dot4.y + dot4.z + dot4.w;
+  work5[lid] = dot5.x + dot5.y + dot5.z + dot5.w;
+  work6[lid] = dot6.x + dot6.y + dot6.z + dot6.w;
+  work7[lid] = dot7.x + dot7.y + dot7.z + dot7.w;
+
+  if(kid == (K >> 2)) {
+    short tail_items = K % 4;
+    if(tail_items != 0) {
+      int offset = kid << 2;
+      const __global Dtype *srcb_tail = srcb_read + offset;
+
+      const __global Dtype *srca_tail0 = srca_read0 + offset;
+      const __global Dtype *srca_tail1 = srca_read1 + offset;
+      const __global Dtype *srca_tail2 = srca_read2 + offset;
+      const __global Dtype *srca_tail3 = srca_read3 + offset;
+      const __global Dtype *srca_tail4 = srca_read4 + offset;
+      const __global Dtype *srca_tail5 = srca_read5 + offset;
+      const __global Dtype *srca_tail6 = srca_read6 + offset;
+      const __global Dtype *srca_tail7 = srca_read7 + offset;
+#pragma unroll
+      for(short item = 0; item < tail_items; ++item) {
+        work0[lid] += srca_tail0[item] * srcb_tail[item];
+        work1[lid] += srca_tail1[item] * srcb_tail[item];
+        work2[lid] += srca_tail2[item] * srcb_tail[item];
+        work3[lid] += srca_tail3[item] * srcb_tail[item];
+        work4[lid] += srca_tail4[item] * srcb_tail[item];
+        work5[lid] += srca_tail5[item] * srcb_tail[item];
+        work6[lid] += srca_tail6[item] * srcb_tail[item];
+        work7[lid] += srca_tail7[item] * srcb_tail[item];
+      }
+    }
+  }
+
+  for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if(lid < stride) {
+      work0[lid] += work0[lid+stride];
+      work1[lid] += work1[lid+stride];
+      work2[lid] += work2[lid+stride];
+      work3[lid] += work3[lid+stride];
+      work4[lid] += work4[lid+stride];
+      work5[lid] += work5[lid+stride];
+      work6[lid] += work6[lid+stride];
+      work7[lid] += work7[lid+stride];
+    }
+  }
+
+  if(lid == 0) {
+    dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
+    dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
+    dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
+    dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];
+    dstc4[x_gid] = alpha * work4[0] + beta * dstc4[x_gid];
+    dstc5[x_gid] = alpha * work5[0] + beta * dstc5[x_gid];
+    dstc6[x_gid] = alpha * work6[0] + beta * dstc6[x_gid];
+    dstc7[x_gid] = alpha * work7[0] + beta * dstc7[x_gid];
+  }
+}
+#undef SLM_SIZE
+
+#undef VEC_SIZE
+#undef LWG_HEIGHT
+#undef TILE_M
+#undef TILE_K
+#undef TILE_N
+#undef SIMD_SIZE_GEMM
+#undef SHUFFLE_TYPE2
+#undef SHUFFLE_TYPE8
index 37ae523..710637a 100644 (file)
 //
 //M*/
 
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
 #define CONCAT(A,B) A##_##B
 #define TEMPLATE(name,type) CONCAT(name,type)
 
-// Types used for parameters, offset computations and so on
-#define int_tp int
-#define uint_tp unsigned int
-
+#define KERNEL_ARG_DTYPE float
+#define TYPE_FLOAT  1
+#define TYPE_HALF   2
+
+#if TYPE == TYPE_HALF
+#define Dtype  half
+#define Dtype2 half2
+#define Dtype4 half4
+#define Dtype8 half8
+#define Dtype16 half16
+
+#define as_Dtype  as_half
+#define as_Dtype2 as_half2
+#define as_Dtype4 as_half4
+#define as_Dtype8 as_half8
+#define as_Dtype16 as_half16
+#else
 #define Dtype  float
 #define Dtype2 float2
 #define Dtype4 float4
 #define Dtype8 float8
+#define Dtype16 float16
 
 #define as_Dtype  as_float
 #define as_Dtype2 as_float2
 #define as_Dtype4 as_float4
 #define as_Dtype8 as_float8
-
-#define KERNEL_ARG_DTYPE float
+#define as_Dtype16 as_float16
+#endif
 
 #if defined(cl_intel_subgroups)
 #pragma OPENCL EXTENSION  cl_intel_subgroups : enable
 
 // common block to calculate (alpha * AxB + beta * C) and output to destination image.
 
+#if TYPE == TYPE_HALF
+#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read_us8( __image, __coord )
+#define SHUFFLE_TYPE2(val) as_ushort2(val)
+#define SHUFFLE_TYPE8(val) as_ushort8(val)
+#define READ_IMAGE(__image, __coord) read_imageh(__image, sampler, __coord)
+#define SIZE_OF_ELEMENT sizeof(ushort)
+#define SIMD_SIZE_GEMM 16
+#define TILE_N 16
+#else
 #define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read8( __image, __coord )
 #define SHUFFLE_TYPE2(val) val
 #define SHUFFLE_TYPE8(val) val
 #define SIZE_OF_ELEMENT sizeof(uint)
 #define SIMD_SIZE_GEMM 8
 #define TILE_N 8
+#endif
 
 //#define USE_IMAGE_C
 #ifdef USE_IMAGE_C
+#if TYPE == TYPE_HALF
+#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read_us8( _C, _coordC ) )
+#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write_us8( _C, _coordC, as_ushort8( _val ) )
+#else
 #define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read8( _C, _coordC ) )
 #define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) )
+#endif
 #define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst
 #define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint))
 #else
             blockC03 += blockAxB03; \
         } \
     } else { \
-        blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC );    coordC.y += 8; \
-        blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC );    coordC.y += 8; \
-        blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC );    coordC.y += 8; \
-        blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \
+        blockC00 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC );    coordC.y += 8; \
+        blockC01 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC );    coordC.y += 8; \
+        blockC02 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC );    coordC.y += 8; \
+        blockC03 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); \
         if (!ALPHA1) { \
           blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \
           blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \
                   intel_sub_group_shuffle( _block.s7, _col ) );
 
 // A's column block multiply B 's row block.
+#if TYPE == TYPE_HALF
+#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB00, _blockB01 )    \
+        {   \
+            const Dtype8    acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 );    \
+            const Dtype8    acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 );    \
+            const Dtype8    acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 );    \
+            const Dtype8    acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 );    \
+            const Dtype8    acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 );    \
+            const Dtype8    acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 );    \
+            const Dtype8    acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 );    \
+            const Dtype8    acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 );    \
+            const Dtype8    acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 );    \
+            const Dtype8    acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 );    \
+            const Dtype8    acola = TRANSPOSE_BLOCK_8( _blockA, 10 );    \
+            const Dtype8    acolb = TRANSPOSE_BLOCK_8( _blockA, 11 );    \
+            const Dtype8    acolc = TRANSPOSE_BLOCK_8( _blockA, 12 );    \
+            const Dtype8    acold = TRANSPOSE_BLOCK_8( _blockA, 13 );    \
+            const Dtype8    acole = TRANSPOSE_BLOCK_8( _blockA, 14 );    \
+            const Dtype8    acolf = TRANSPOSE_BLOCK_8( _blockA, 15 );    \
+            _result = mad( (Dtype8)(_blockB00.s0), acol0, _result );      \
+            _result = mad( (Dtype8)(_blockB00.s1), acol1, _result );      \
+            _result = mad( (Dtype8)(_blockB00.s2), acol2, _result );      \
+            _result = mad( (Dtype8)(_blockB00.s3), acol3, _result );      \
+            _result = mad( (Dtype8)(_blockB00.s4), acol4, _result );      \
+            _result = mad( (Dtype8)(_blockB00.s5), acol5, _result );      \
+            _result = mad( (Dtype8)(_blockB00.s6), acol6, _result );      \
+            _result = mad( (Dtype8)(_blockB00.s7), acol7, _result );      \
+            _result = mad( (Dtype8)(_blockB01.s0), acol8, _result );      \
+            _result = mad( (Dtype8)(_blockB01.s1), acol9, _result );      \
+            _result = mad( (Dtype8)(_blockB01.s2), acola, _result );      \
+            _result = mad( (Dtype8)(_blockB01.s3), acolb, _result );      \
+            _result = mad( (Dtype8)(_blockB01.s4), acolc, _result );      \
+            _result = mad( (Dtype8)(_blockB01.s5), acold, _result );      \
+            _result = mad( (Dtype8)(_blockB01.s6), acole, _result );      \
+            _result = mad( (Dtype8)(_blockB01.s7), acolf, _result );      \
+        }
+#else
 #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB )    \
         {   \
             const Dtype8    acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 );    \
             _result = mad( (Dtype8)(_blockB.s6), acol6, _result );      \
             _result = mad( (Dtype8)(_blockB.s7), acol7, _result );      \
         }
+#endif
 
+#if TYPE == TYPE_HALF
+#define GEMM_NN(ALPHA1, BETA_NOT0) \
+__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
+__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
+__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
+    __read_only image2d_t A, \
+    __read_only image2d_t B, \
+    MATC_PARAMETER, \
+    KERNEL_ARG_DTYPE alpha_in, \
+    KERNEL_ARG_DTYPE beta_in, \
+    int width0, \
+    int isFirstColBlock) \
+{ \
+    const Dtype alpha = (Dtype)alpha_in; \
+    const Dtype beta = (Dtype)beta_in; \
+    const int group_x = get_group_id(0); \
+    const int group_y = get_group_id(1); \
+    Dtype8 blockAxB00 = 0; \
+    Dtype8 blockAxB01 = 0; \
+    Dtype8 blockAxB02 = 0; \
+    Dtype8 blockAxB03 = 0; \
+    int2    coordA = (int2)( 0, group_y * TILE_M ); \
+    int2    coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \
+    do \
+    {  \
+        int2    coordBTemp = coordB; \
+        Dtype8  blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) );    coordB.y += TILE_K; \
+        Dtype8  blockB01 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) );    coordB.y += TILE_K; \
+        int2    coordATemp = coordA; \
+        Dtype8  blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordATemp.y += 8; \
+        Dtype8  blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordATemp.y += 8; \
+        Dtype8  blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordATemp.y += 8; \
+        Dtype8  blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \
+        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, blockB01 ); \
+        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, blockB01 ); \
+        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, blockB01 ); \
+        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, blockB01 ); \
+    } \
+    while( coordB.y < width0 ); \
+    GEMM_OUTPUT(ALPHA1, BETA_NOT0);  \
+}
+#else
 #define GEMM_NN(ALPHA1, BETA_NOT0) \
 __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
 __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
@@ -231,6 +344,7 @@ __kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
     while( coordB.y < width0 ); \
     GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
 }
+#endif
 
 GEMM_NN(1, 0) // ALPHA == 1, BETA == 0
 GEMM_NN(1, 1) // ALPHA == 1, BETA != 0
@@ -264,6 +378,45 @@ GEMM_NN(0, 1) // ALPHA != 1, BETA != 0
             _result = mad( (Dtype8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result );      \
         }
 
+#if TYPE == TYPE_HALF
+#define GEMM_TN(ALPHA1, BETA_NOT0) \
+__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
+__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
+__kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
+    __read_only image2d_t A, \
+    __read_only image2d_t B, \
+    MATC_PARAMETER, \
+    KERNEL_ARG_DTYPE alpha_in, \
+    KERNEL_ARG_DTYPE beta_in, \
+    int width0, \
+    int isFirstColBlock) \
+{ \
+    const Dtype alpha = (Dtype)alpha_in; \
+    const Dtype beta = (Dtype)beta_in; \
+    const int group_x = get_group_id(0);\
+    const int group_y = get_group_id(1);\
+    Dtype8 blockAxB00 = 0;\
+    Dtype8 blockAxB01 = 0;\
+    Dtype8 blockAxB02 = 0;\
+    Dtype8 blockAxB03 = 0;\
+    int2    coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\
+    int2    coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\
+    do\
+    {\
+        int2    coordBTemp = coordB;\
+        Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) );    coordB.y += TILE_K;\
+        int2    coordATemp = coordA;\
+        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordATemp.x += 16 * SIZE_OF_ELEMENT;\
+        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordA.y += TILE_K;\
+        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \
+        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \
+        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \
+        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \
+    } \
+    while( coordB.y < width0 ); \
+    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
+}
+#else
 #define GEMM_TN(ALPHA1, BETA_NOT0) \
 __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
 __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
@@ -303,6 +456,7 @@ __kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
     while( coordB.y < width0 ); \
     GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
 }
+#endif
 
 GEMM_TN(1, 0) // ALPHA == 1, BETA == 0
 GEMM_TN(1, 1) // ALPHA == 1, BETA != 0
@@ -324,6 +478,7 @@ GEMM_TN(0, 1) // ALPHA != 1, BETA != 0
                   intel_sub_group_shuffle( _block.s6, _col),   \
                   intel_sub_group_shuffle( _block.s7, _col) )
 
+#if TYPE == TYPE_HALF
 #define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB )    \
         {   \
             const Dtype8    acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 );    \
@@ -334,6 +489,14 @@ GEMM_TN(0, 1) // ALPHA != 1, BETA != 0
             const Dtype8    acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 );    \
             const Dtype8    acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 );    \
             const Dtype8    acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 );    \
+            const Dtype8    acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 );    \
+            const Dtype8    acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 );    \
+            const Dtype8    acola = TRANSPOSE_BLOCK_8( _blockA, 10 );    \
+            const Dtype8    acolb = TRANSPOSE_BLOCK_8( _blockA, 11 );    \
+            const Dtype8    acolc = TRANSPOSE_BLOCK_8( _blockA, 12 );    \
+            const Dtype8    acold = TRANSPOSE_BLOCK_8( _blockA, 13 );    \
+            const Dtype8    acole = TRANSPOSE_BLOCK_8( _blockA, 14 );    \
+            const Dtype8    acolf = TRANSPOSE_BLOCK_8( _blockA, 15 );    \
             _result = mad( (Dtype8)_blockB.s0, acol0, _result );      \
             _result = mad( (Dtype8)_blockB.s1, acol1, _result );      \
             _result = mad( (Dtype8)_blockB.s2, acol2, _result );      \
@@ -342,8 +505,80 @@ GEMM_TN(0, 1) // ALPHA != 1, BETA != 0
             _result = mad( (Dtype8)_blockB.s5, acol5, _result );      \
             _result = mad( (Dtype8)_blockB.s6, acol6, _result );      \
             _result = mad( (Dtype8)_blockB.s7, acol7, _result );      \
+            _result = mad( (Dtype8)_blockB.s8, acol8, _result );      \
+            _result = mad( (Dtype8)_blockB.s9, acol9, _result );      \
+            _result = mad( (Dtype8)_blockB.sa, acola, _result );      \
+            _result = mad( (Dtype8)_blockB.sb, acolb, _result );      \
+            _result = mad( (Dtype8)_blockB.sc, acolc, _result );      \
+            _result = mad( (Dtype8)_blockB.sd, acold, _result );      \
+            _result = mad( (Dtype8)_blockB.se, acole, _result );      \
+            _result = mad( (Dtype8)_blockB.sf, acolf, _result );      \
         }
+#else
+#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB )    \
+        {   \
+            const Dtype8    acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 );    \
+            const Dtype8    acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 );    \
+            const Dtype8    acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 );    \
+            const Dtype8    acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 );    \
+            const Dtype8    acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 );    \
+            const Dtype8    acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 );    \
+            const Dtype8    acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 );    \
+            const Dtype8    acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 );    \
+            _result = mad( (Dtype8)_blockB.s0, acol0, _result );      \
+            _result = mad( (Dtype8)_blockB.s1, acol1, _result );      \
+            _result = mad( (Dtype8)_blockB.s2, acol2, _result );      \
+            _result = mad( (Dtype8)_blockB.s3, acol3, _result );      \
+            _result = mad( (Dtype8)_blockB.s4, acol4, _result );      \
+            _result = mad( (Dtype8)_blockB.s5, acol5, _result );      \
+            _result = mad( (Dtype8)_blockB.s6, acol6, _result );      \
+            _result = mad( (Dtype8)_blockB.s7, acol7, _result );      \
+        }
+#endif
 
+#if TYPE == TYPE_HALF
+#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
+__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
+__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
+__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
+    __read_only image2d_t A, \
+    MATB_PARAMETER, \
+    MATC_PARAMETER, \
+    KERNEL_ARG_DTYPE alpha_in, \
+    KERNEL_ARG_DTYPE beta_in, \
+    int padded_k, \
+    int k, \
+    int isFirstColBlock) \
+{ \
+    const Dtype alpha = (Dtype)alpha_in; \
+    const Dtype beta = (Dtype)beta_in; \
+    const int group_x = get_group_id(0); \
+    const int group_y = get_group_id(1); \
+    Dtype8 blockAxB00 = 0; \
+    Dtype8 blockAxB01 = 0; \
+    Dtype8 blockAxB02 = 0; \
+    Dtype8 blockAxB03 = 0; \
+    int2    coordA = (int2)( 0, group_y * TILE_M ); \
+    int2    coordB = (int2)( 0, ( group_x * TILE_N )); \
+    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
+    do \
+    { \
+        Dtype16 blockB00; \
+        BLOCKB_READ8(blockB00, B, coordB); \
+        int2    coordATemp = coordA; \
+        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordATemp.y += 8; \
+        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordATemp.y += 8; \
+        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordATemp.y += 8; \
+        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \
+        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
+        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
+        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
+        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
+    } \
+    while( coordB.x < padded_k / VECSIZE ); \
+    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
+}
+#else
 #define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
 __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
 __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
@@ -385,12 +620,23 @@ __kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dt
     while( coordB.x < padded_k / VECSIZE ); \
     GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
 }
+#endif
 
+#if TYPE == TYPE_HALF
+#define BLOCKB_READ8(_blockb, _B, _coordB) \
+        int2 _coordBTemp = _coordB; \
+        _coordBTemp.y += get_local_id(0); \
+        _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s89ab = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.scdef = READ_IMAGE(_B, _coordBTemp); _coordB.x += 4;
+#else
 #define BLOCKB_READ8(_blockb, _B, _coordB) \
         int2 _coordBTemp = _coordB; \
         _coordBTemp.y += get_local_id(0); \
         _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
         _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2;
+#endif
 
 #define MATB_PARAMETER __read_only image2d_t B
 
@@ -401,12 +647,21 @@ GEMM_NT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0
 #undef BLOCKB_READ8
 #undef MATB_PARAMETER
 
+#if TYPE == TYPE_HALF
+#define BLOCKB_READ8(_blockb, _B, _coordB) \
+        int2 _coordBTemp = _coordB; \
+        _coordBTemp.y += get_local_id(0); \
+        const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \
+        _blockb = as_Dtype16(as_ushort16(vload8(0, B_read))); \
+        _coordB.x += TILE_K * 2;
+#else
 #define BLOCKB_READ8(_blockb, _B, _coordB) \
         int2 _coordBTemp = _coordB; \
         _coordBTemp.y += get_local_id(0); \
         const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \
         _blockb = vload8(0, B_read); \
         _coordB.x += TILE_K;
+#endif
 
 #define MATB_PARAMETER __global Dtype *B, int offB, int ldb
 
@@ -417,6 +672,45 @@ GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0
 #undef BLOCKB_READ8
 #undef MATB_PARAMETER
 
+#if TYPE == TYPE_HALF
+#define BLOCKB_READ8(_blockb, _B, _coordB) \
+        int2 _coordBTemp = _coordB; \
+        _coordBTemp.y += get_local_id(0); \
+        Dtype4 temp; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s0 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s1 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s2 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s3 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s4 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s5 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s6 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s7 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s8 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.s9 = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.sa = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.sb = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+         _blockb.sc = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.sd = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.se = temp.s0; \
+        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
+        _blockb.sf = temp.s0; \
+        _coordB.x += 16;
+#else
 #define BLOCKB_READ8(_blockb, _B, _coordB) \
         int2 _coordBTemp = _coordB; \
         _coordBTemp.y += get_local_id(0); \
@@ -438,6 +732,7 @@ GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0
         temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
         _blockb.s7 = temp.s0; \
         _coordB.x += 8;
+#endif
 
 #define MATB_PARAMETER __read_only image2d_t B
 
@@ -483,6 +778,47 @@ GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0
             _result = mad( (Dtype8)_blockB.s7, acol7, _result );      \
         }
 
+#if TYPE == TYPE_HALF
+#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
+__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
+__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
+__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
+    __read_only image2d_t A, \
+    MATB_PARAMETER, \
+    MATC_PARAMETER, \
+    KERNEL_ARG_DTYPE alpha_in, \
+    KERNEL_ARG_DTYPE beta_in, \
+    int padded_k, \
+    int k, \
+    int isFirstColBlock) \
+{ \
+    const Dtype alpha = (Dtype)alpha_in; \
+    const Dtype beta = (Dtype)beta_in; \
+    const int group_x = get_group_id(0); \
+    const int group_y = get_group_id(1); \
+    Dtype8 blockAxB00 = 0; \
+    Dtype8 blockAxB01 = 0; \
+    Dtype8 blockAxB02 = 0; \
+    Dtype8 blockAxB03 = 0; \
+    int2    coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \
+    int2    coordB = (int2)( 0, ( group_x * TILE_N )); \
+    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
+    do \
+    { \
+        Dtype8 blockB00;             \
+        BLOCKB_READ8(blockB00, B, coordB); \
+        int2    coordATemp = coordA; \
+        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordATemp.x += 16 * SIZE_OF_ELEMENT;\
+        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) );    coordA.y += TILE_K;\
+        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \
+        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \
+        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \
+        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \
+    } \
+    while( coordB.x < padded_k / VECSIZE ); \
+    GEMM_OUTPUT(ALPHA1, BETA_NOT0);\
+}
+#else
 #define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
 __attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
 __attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
@@ -524,6 +860,7 @@ __kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, D
     while( coordB.x < padded_k / VECSIZE ); \
     GEMM_OUTPUT(ALPHA1, BETA_NOT0);\
 }
+#endif
 
 #define BLOCKB_READ8(_blockb, _B, _coordB) \
         int2 _coordBTemp = _coordB; \
@@ -540,12 +877,21 @@ GEMM_TT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0
 #undef BLOCKB_READ8
 #undef MATB_PARAMETER
 
+#if TYPE == TYPE_HALF
+#define BLOCKB_READ8(_blockb, _B, _coordB) \
+        int2 _coordBTemp = _coordB; \
+        _coordBTemp.y += get_local_id(0); \
+        const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \
+        _blockb = as_Dtype8(as_ushort8(vload4(0, B_read))); \
+        _coordB.x += TILE_K;
+#else
 #define BLOCKB_READ8(_blockb, _B, _coordB) \
         int2 _coordBTemp = _coordB; \
         _coordBTemp.y += get_local_id(0); \
         const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \
         _blockb = vload8(0, B_read); \
         _coordB.x += TILE_K;
+#endif
 
 #define MATB_PARAMETER __global Dtype *B, int offB, int ldb
 
@@ -598,7 +944,7 @@ GEMM_TT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0
 #undef READ_IMAGE
 #undef SIZE_OF_ELEMENT
 
-__kernel void TEMPLATE(gemm_buffer_copy_image_transpose,Dtype)(
+__kernel void TEMPLATE(gemm_buffer_copy_image_transpose, Dtype)(
     __global Dtype* A,
     __write_only image2d_t ImA,
     int offA,
@@ -611,10 +957,14 @@ __kernel void TEMPLATE(gemm_buffer_copy_image_transpose,Dtype)(
     int2 coord_dst = (int2)(gidx, gidy);
     __global Dtype* A_off = A + offA;
     Dtype srcA = A_off[gidy * ldA + gidx];
+#if TYPE == TYPE_HALF
+    write_imageh(ImA, coord_dst, (Dtype4)srcA);
+#else
     write_imagef(ImA, coord_dst, (Dtype4)srcA);
+#endif
 }
 
-__kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose,Dtype)(
+__kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)(
     __global Dtype* A,
     __write_only image2d_t ImA,
     int offA,
@@ -625,6 +975,14 @@ __kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose,Dtype)(
     const int gidx = get_global_id(0);
     const int gidy = get_global_id(1);
     int2 coord_dst = (int2)(gidx, gidy);
+#if TYPE == TYPE_HALF
+    if (gidx >= width || gidy >= height) {
+      write_imageh(ImA, coord_dst, 0);
+      return;
+    }
+    __global Dtype* A_off = A + offA;
+    write_imageh(ImA, coord_dst, A_off[gidy * ldA + gidx]);
+#else
     if (gidx >= width || gidy >= height) {
       write_imageui(ImA, coord_dst, (uint4)0);
       return;
@@ -632,4 +990,5 @@ __kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose,Dtype)(
     __global Dtype* A_off = A + offA;
     uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * ldA + gidx]));
     write_imageui(ImA, coord_dst, srcA);
+#endif
 }
index b8f4eff..2be4f9f 100644 (file)
 //
 //M*/
 
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
 #define CONCAT(A,B) A##_##B
 #define TEMPLATE(name,type) CONCAT(name,type)
-#define Dtype float
+#define KERNEL_ARG_DTYPE float
 
-__kernel void TEMPLATE(axpy,Dtype)(const int n, const Dtype alpha, __global const Dtype* x,
+__kernel void TEMPLATE(axpy,Dtype)(const int n, const KERNEL_ARG_DTYPE alpha, __global const Dtype* x,
                                    const int offx, __global Dtype* y,
                                    const int offy) {
   for (int index = get_global_id(0); index < n; index += get_global_size(0)) {
     Dtype src = x[offx + index];
     Dtype dst = y[offy + index];
-    y[offy + index] = alpha * src + dst;
+    y[offy + index] = convert_Dtype(alpha) * src + dst;
   }
 }
index 0dabd62..849c490 100644 (file)
 //
 //M*/
 
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
 #define CONCAT(A,B) A##_##B
 #define TEMPLATE(name,type) CONCAT(name,type)
-#define Dtype float
+#define KERNEL_ARG_DTYPE float
 
 __kernel void TEMPLATE(matvec_mul4,Dtype)(
-          __global const float * A,
+          __global const Dtype * A,
           int offA,
           unsigned int A_col_size,
           unsigned int trail_item,
-          __global const float * v,
+          __global const Dtype * v,
           int offv,
-          float alpha,
-          float beta,
-          __global float4 * result,
+          KERNEL_ARG_DTYPE alpha,
+          KERNEL_ARG_DTYPE beta,
+          __global Dtype4* result,
           int offr,
-          __local float4 * work)
+          __local Dtype4* work)
 {
   unsigned int row_gid = get_group_id(0);
   unsigned int lid = get_local_id(0);
-  const __global float *src0_read = A + row_gid * 4 * A_col_size + offA;
-  const __global float *src1_read = v + offv;
-  result = (__global float4*)((__global float*)result + offr);
-  float4 dot0 = (float4)(0.f);
-  float4 dot1 = (float4)(0.f);
-  float4 dot2 = (float4)(0.f);
-  float4 dot3 = (float4)(0.f);
+  const __global Dtype *src0_read = A + row_gid * 4 * A_col_size + offA;
+  const __global Dtype *src1_read = v + offv;
+  result = (__global Dtype4*)((__global Dtype*)result + offr);
+  Dtype4 dot0 = (Dtype4)(0.f);
+  Dtype4 dot1 = (Dtype4)(0.f);
+  Dtype4 dot2 = (Dtype4)(0.f);
+  Dtype4 dot3 = (Dtype4)(0.f);
 
   unsigned int i = lid;
   while( i < A_col_size / 4) {
-    const float4 a0 = vload4(i, src0_read);
-    const float4 a1 = vload4(i, src0_read + A_col_size);
-    const float4 a2 = vload4(i, src0_read + 2 * A_col_size);
-    const float4 a3 = vload4(i, src0_read + 3 * A_col_size);
+    const Dtype4 a0 = vload4(i, src0_read);
+    const Dtype4 a1 = vload4(i, src0_read + A_col_size);
+    const Dtype4 a2 = vload4(i, src0_read + 2 * A_col_size);
+    const Dtype4 a3 = vload4(i, src0_read + 3 * A_col_size);
 
-    const float4 b0 = vload4(i, src1_read);
+    const Dtype4 b0 = vload4(i, src1_read);
 
     dot0 += a0 * b0;
     dot1 += a1 * b0;
@@ -92,15 +96,15 @@ __kernel void TEMPLATE(matvec_mul4,Dtype)(
   {
     if(trail_item != 0)
     {
-      const __global float *src0_trail = src0_read + i * 4;
-      const __global float *src1_trail = src1_read + i * 4;
+      const __global Dtype *src0_trail = src0_read + i * 4;
+      const __global Dtype *src1_trail = src1_read + i * 4;
       for(unsigned int i = 0; i < trail_item; ++i) {
-        const float at0 = src0_trail[i];
-        const float at1 = src0_trail[i + A_col_size];
-        const float at2 = src0_trail[i + 2 * A_col_size];
-        const float at3 = src0_trail[i + 3 * A_col_size];
+        const Dtype at0 = src0_trail[i];
+        const Dtype at1 = src0_trail[i + A_col_size];
+        const Dtype at2 = src0_trail[i + 2 * A_col_size];
+        const Dtype at3 = src0_trail[i + 3 * A_col_size];
 
-        const float bt = src1_trail[i];
+        const Dtype bt = src1_trail[i];
 
         work[lid].s0 += at0 * bt;
         work[lid].s1 += at1 * bt;
@@ -118,40 +122,40 @@ __kernel void TEMPLATE(matvec_mul4,Dtype)(
   }
   if(lid == 0) {
     if(beta == (Dtype)0)
-      result[row_gid] = alpha * work[0];
+      result[row_gid] = convert_Dtype(alpha) * work[0];
     else
-      result[row_gid] = alpha * work[0] + beta * result[row_gid];
+      result[row_gid] = convert_Dtype(alpha) * work[0] + convert_Dtype(beta) * result[row_gid];
   }
 }
 
 /* This kernel used for the trailing rows when row_of_A %4 !=0 */
 __kernel void TEMPLATE(matvec_mul1,Dtype)(
-          __global const float * A,
+          __global const Dtype * A,
           int offA,
           unsigned int A_col_size,
           unsigned int row_offset,
           unsigned int trail_item,
-          __global const float * v,
+          __global const Dtype * v,
           int offv,
-          float alpha,
-          float beta,
-          __global float * result,
+          KERNEL_ARG_DTYPE alpha,
+          KERNEL_ARG_DTYPE beta,
+          __global Dtype * result,
           int offr,
-          __local float * work)
+          __local Dtype * work)
 {
   unsigned int row_gid = get_group_id(0);
   unsigned int lid = get_local_id(0);
 
-  const __global float *src0_read = A + (row_offset + row_gid) * A_col_size + offA;
-  const __global float *src1_read = v + + offv;
+  const __global Dtype *src0_read = A + (row_offset + row_gid) * A_col_size + offA;
+  const __global Dtype *src1_read = v + + offv;
   result = result + offr;
-  float4 dot0 = (float4)(0.f);
+  Dtype4 dot0 = (Dtype4)(0.f);
 
   unsigned int i = lid;
   while( i < A_col_size / 4)
   {
-    const float4 a0 = vload4(i, src0_read);
-    const float4 b0 = vload4(i, src1_read);
+    const Dtype4 a0 = vload4(i, src0_read);
+    const Dtype4 b0 = vload4(i, src1_read);
 
     dot0 += a0 * b0;
     i += get_local_size(0);
@@ -163,11 +167,11 @@ __kernel void TEMPLATE(matvec_mul1,Dtype)(
   {
     if(trail_item != 0)
     {
-      const __global float *src0_trail = src0_read + i * 4;
-      const __global float *src1_trail = src1_read + i * 4;
+      const __global Dtype *src0_trail = src0_read + i * 4;
+      const __global Dtype *src1_trail = src1_read + i * 4;
       for(unsigned int i = 0; i < trail_item; ++i) {
-        const float at0 = src0_trail[i];
-        const float bt = src1_trail[i];
+        const Dtype at0 = src0_trail[i];
+        const Dtype bt = src1_trail[i];
 
         work[lid] += at0 * bt;
       }
@@ -182,10 +186,10 @@ __kernel void TEMPLATE(matvec_mul1,Dtype)(
 
   if(lid == 0) {
     if(beta == (Dtype)0) {
-      result[row_gid+row_offset] = alpha * work[0];
+      result[row_gid+row_offset] = convert_Dtype(alpha) * work[0];
     } else {
-      result[row_gid+row_offset] *= beta;
-      result[row_gid+row_offset] += alpha * work[0];
+      result[row_gid+row_offset] *= convert_Dtype(beta);
+      result[row_gid+row_offset] += convert_Dtype(alpha) * work[0];
     }
   }
 }
index 9f8ab57..49a8ebb 100644 (file)
 //
 //M*/
 
-#define Dtype float
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
+#define Dtype  float
 #define Dtype4 float4
 #define Dtype8 float8
 
@@ -135,17 +139,17 @@ __kernel void MVN(__global const Dtype* src,
     store(dst_vec, dst, index);
 }
 
-__kernel void MEAN_FUSE(__global const Dtype * A,
+__kernel void MEAN_FUSE(__global const T * A,
                         unsigned int A_col_size,
                         float alpha,
-                        __global Dtype4 * result,
-                        __global Dtype * B,
+                        __global T4 * mean,
+                        __global Dtype * tmp,
                         __local Dtype4 * work)
 {
     unsigned int row_gid = get_group_id(0);
     unsigned int lid = get_local_id(0);
-    const __global Dtype *src0_read = A + row_gid * 4 * A_col_size;
-    __global Dtype *dst0_read = B + row_gid * 4 * A_col_size;
+    const __global T *src0_read = A + row_gid * 4 * A_col_size;
+    __global Dtype *dst0_read = tmp + row_gid * 4 * A_col_size;
     Dtype4 dot0, dot1, dot2, dot3;
     dot0 = dot1 = dot2 = dot3 = (Dtype4)(0.f);
 
@@ -153,15 +157,15 @@ __kernel void MEAN_FUSE(__global const Dtype * A,
     const Dtype4 b0 = (Dtype4)1.f;
     while( i < A_col_size / 4)
     {
-        const Dtype4 a0 = vload4(i, src0_read);
-        const Dtype4 a1 = vload4(i, src0_read + A_col_size);
-        const Dtype4 a2 = vload4(i, src0_read + 2 * A_col_size);
-        const Dtype4 a3 = vload4(i, src0_read + 3 * A_col_size);
+        const T4 a0 = vload4(i, src0_read);
+        const T4 a1 = vload4(i, src0_read + A_col_size);
+        const T4 a2 = vload4(i, src0_read + 2 * A_col_size);
+        const T4 a3 = vload4(i, src0_read + 3 * A_col_size);
 
-        dot0 += a0;
-        dot1 += a1;
-        dot2 += a2;
-        dot3 += a3;
+        dot0 += convert_float4(a0);
+        dot1 += convert_float4(a1);
+        dot2 += convert_float4(a2);
+        dot3 += convert_float4(a3);
 
         i += get_local_size(0);
     }
@@ -181,22 +185,22 @@ __kernel void MEAN_FUSE(__global const Dtype * A,
 
     if(lid == 0)
     {
-        result[row_gid] = alpha * work[0];
+        mean[row_gid] = convert_T(alpha * work[0]);
     }
 
     Dtype4 sum = work[0] * alpha;
     i = lid;
     while( i < A_col_size / 4)
     {
-        const Dtype4 a0 = vload4(i, src0_read);
-        const Dtype4 a1 = vload4(i, src0_read + A_col_size);
-        const Dtype4 a2 = vload4(i, src0_read + 2 * A_col_size);
-        const Dtype4 a3 = vload4(i, src0_read + 3 * A_col_size);
+        const T4 a0 = vload4(i, src0_read);
+        const T4 a1 = vload4(i, src0_read + A_col_size);
+        const T4 a2 = vload4(i, src0_read + 2 * A_col_size);
+        const T4 a3 = vload4(i, src0_read + 3 * A_col_size);
 
-        dot0 = native_powr(a0 - (Dtype4)sum.x, 2);
-        dot1 = native_powr(a1 - (Dtype4)sum.y, 2);
-        dot2 = native_powr(a2 - (Dtype4)sum.z, 2);
-        dot3 = native_powr(a3 - (Dtype4)sum.w, 2);
+        dot0 = native_powr(convert_float4(a0) - (Dtype4)sum.x, 2);
+        dot1 = native_powr(convert_float4(a1) - (Dtype4)sum.y, 2);
+        dot2 = native_powr(convert_float4(a2) - (Dtype4)sum.z, 2);
+        dot3 = native_powr(convert_float4(a3) - (Dtype4)sum.w, 2);
 
         vstore4(dot0, i, dst0_read);
         vstore4(dot1, i, dst0_read + A_col_size);
@@ -208,22 +212,22 @@ __kernel void MEAN_FUSE(__global const Dtype * A,
 }
 
 __kernel void MVN_FUSE(__global const Dtype * tmp,
-                       __global const Dtype * A,
-                       __global const Dtype4 * mean,
+                       __global const T * A,
+                       __global const T4 * mean,
                        unsigned int A_col_size,
                        const float alpha_val,
                        const float eps,
                        const float relu_slope,
                        __global const Dtype4 * bnorm_weight,
                        __global const Dtype4 * bnorm_bias,
-                       __global Dtype * B,
+                       __global T * B,
                        __local Dtype4 * work)
 {
     unsigned int row_gid = get_group_id(0);
     unsigned int lid = get_local_id(0);
     const __global Dtype *src0_read = tmp + row_gid * 4 * A_col_size;
-    const __global Dtype *src1_read = A + row_gid * 4 * A_col_size;
-    __global Dtype *dst0_read = B + row_gid * 4 * A_col_size;
+    const __global T *src1_read = A + row_gid * 4 * A_col_size;
+    __global T *dst0_read = B + row_gid * 4 * A_col_size;
     Dtype4 dot0, dot1, dot2, dot3;
     dot0 = dot1 = dot2 = dot3 = (Dtype4)(0.f);
 
@@ -257,7 +261,7 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
     }
     barrier(CLK_LOCAL_MEM_FENCE);
 
-    Dtype4 mean_val = mean[row_gid];
+    Dtype4 mean_val = convert_float4(mean[row_gid]);
     Dtype4 dev_val = sqrt(work[0] * alpha_val) + (Dtype4)eps;
     Dtype4 alpha = (Dtype4)1.f / dev_val;
 
@@ -271,15 +275,15 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
     i = lid;
     while( i < A_col_size / 4)
     {
-        const Dtype4 a0 = vload4(i, src1_read);
-        const Dtype4 a1 = vload4(i, src1_read + A_col_size);
-        const Dtype4 a2 = vload4(i, src1_read + 2 * A_col_size);
-        const Dtype4 a3 = vload4(i, src1_read + 3 * A_col_size);
+        const T4 a0 = vload4(i, src1_read);
+        const T4 a1 = vload4(i, src1_read + A_col_size);
+        const T4 a2 = vload4(i, src1_read + 2 * A_col_size);
+        const T4 a3 = vload4(i, src1_read + 3 * A_col_size);
 
-        dot0 = (a0 - (Dtype4)mean_val.x) * alpha.x;
-        dot1 = (a1 - (Dtype4)mean_val.y) * alpha.y;
-        dot2 = (a2 - (Dtype4)mean_val.z) * alpha.z;
-        dot3 = (a3 - (Dtype4)mean_val.w) * alpha.w;
+        dot0 = (convert_float4(a0) - (Dtype4)mean_val.x) * alpha.x;
+        dot1 = (convert_float4(a1) - (Dtype4)mean_val.y) * alpha.y;
+        dot2 = (convert_float4(a2) - (Dtype4)mean_val.z) * alpha.z;
+        dot3 = (convert_float4(a3) - (Dtype4)mean_val.w) * alpha.w;
 
         dot0 = dot0 * w.x + (Dtype4)b.x;
         dot1 = dot1 * w.y + (Dtype4)b.y;
@@ -300,10 +304,10 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
         dot3 = select(new3, dot3, dot3 > (Dtype4)0.f);
 #endif
 
-        vstore4(dot0, i, dst0_read);
-        vstore4(dot1, i, dst0_read + A_col_size);
-        vstore4(dot2, i, dst0_read + 2 * A_col_size);
-        vstore4(dot3, i, dst0_read + 3 * A_col_size);
+        vstore4(convert_T(dot0), i, dst0_read);
+        vstore4(convert_T(dot1), i, dst0_read + A_col_size);
+        vstore4(convert_T(dot2), i, dst0_read + 2 * A_col_size);
+        vstore4(convert_T(dot3), i, dst0_read + 3 * A_col_size);
 
         i += get_local_size(0);
     }
index 58477ce..36d9d2a 100644 (file)
 
 #define CONCAT(A,B) A##_##B
 #define TEMPLATE(name,type) CONCAT(name,type)
-#define Dtype float
+#define KERNEL_ARG_DTYPE float
+
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
 
 __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int nthreads, __global const Dtype* in,
                              const int num, const int channels,
                              const int height, const int width, const int size,
-                             const Dtype alpha_over_size, const Dtype k,
+                             const KERNEL_ARG_DTYPE alpha_over_size, const KERNEL_ARG_DTYPE k,
                              __global Dtype* const out,
-                             const Dtype negative_beta) {
+                             const KERNEL_ARG_DTYPE negative_beta) {
   for (int index = get_global_id(0); index < nthreads;
       index += get_global_size(0)) {
     // find out the local offset
@@ -60,11 +64,11 @@ __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int nthreads, __global con
     const int step = height * width;
     __global const Dtype* in_off = in + offset;
     __global Dtype* out_off = out + offset;
-    Dtype scale_val;
+    KERNEL_ARG_DTYPE scale_val;
     int head = 0;
     const int pre_pad = (size - 1) / 2;
     const int post_pad = size - pre_pad - 1;
-    Dtype accum_scale = 0;
+    KERNEL_ARG_DTYPE accum_scale = 0;
     // fill the scale at [n, :, h, w]
     // accumulate values
     while (head < post_pad && head < channels) {
@@ -79,7 +83,7 @@ __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int nthreads, __global con
             * in_off[(head - size) * step];
       }
       scale_val = k + accum_scale * alpha_over_size;
-      out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta);
+      out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta);
       ++head;
     }
     // subtract only
@@ -89,7 +93,7 @@ __kernel void TEMPLATE(lrn_full_no_scale,Dtype)(const int nthreads, __global con
             * in_off[(head - size) * step];
       }
       scale_val = k + accum_scale * alpha_over_size;
-      out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((float)scale_val, (float)negative_beta);
+      out_off[(head - post_pad) * step] = in_off[(head - post_pad) * step] * (Dtype)native_powr((Dtype)scale_val, (Dtype)negative_beta);
       ++head;
     }
   }
index 13e4319..e9d1d26 100644 (file)
 
 #define CONCAT(A,B) A##_##B
 #define TEMPLATE(name,type) CONCAT(name,type)
-#define Dtype float
+
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
 
 #if defined KERNEL_MAX_POOL
 
index 37ba17c..5f96a4e 100644 (file)
@@ -40,9 +40,9 @@
 //
 //M*/
 
-#define Dtype float
-#define Dtype4 float4
-#define Dtype8 float8
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
 
 __kernel void slice(__global const Dtype* src,
                     const int src_plane_size,
index 54cf489..6b525e2 100644 (file)
  * POSSIBILITY OF SUCH DAMAGE.
  **************************************************************************************/
 
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
 __kernel void kernel_channel_max(const int num, const int channels,
     const int spatial_dim, __global const T* data, __global T* out) {
   int index = get_global_id(0);
@@ -40,12 +44,12 @@ __kernel void kernel_channel_max(const int num, const int channels,
 
 __kernel void kernel_channel_subtract(const int count,
     const int num, const int channels,
-    const int spatial_dim, __global const T* channel_max, __global T* data) {
+    const int spatial_dim, __global const T* channel_max, __global const T* src, __global T* data) {
   int index = get_global_id(0);
   if(index < count) {
     int n = index / channels / spatial_dim;
     int s = index % spatial_dim;
-    data[index] -= channel_max[n * spatial_dim + s];
+    data[index] = exp(src[index] - channel_max[n * spatial_dim + s]);
   }
 }
 
index 28a43ae..8ea52cf 100644 (file)
 
 #define CONCAT(A,B) A##_##B
 #define TEMPLATE(name,type) CONCAT(name,type)
-#define Dtype float
 
 #if defined(cl_intel_subgroups)
 #pragma OPENCL EXTENSION  cl_intel_subgroups : enable
 #endif
 
+#if defined(cl_khr_fp16)
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
 __kernel void TEMPLATE(softmax_forward_slm,Dtype)(const int num, const int channels,
                                    const int spatial_dim,
                                    __global Dtype* scale,
@@ -60,12 +63,12 @@ __kernel void TEMPLATE(softmax_forward_slm,Dtype)(const int num, const int chann
   int n = get_global_id(1);
   for (int index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index +=
       get_global_size(0), ++s) {
-    float maxval = -FLT_MAX;
+    Dtype maxval = -DTYPE_MAX;
     for (int c = get_global_id(0); c < channels; c += get_global_size(0)) {
       Dtype tmp = data[(n * channels + c) * spatial_dim + s];
       maxval = max((Dtype)tmp, (Dtype)maxval);
     }
-    maxval = sub_group_reduce_max(maxval * 100000);
+    maxval = sub_group_reduce_max(maxval);
     //if (get_sub_group_local_id() == 0)
     group_tmp[get_sub_group_id() * spatial_dim + s] = maxval;
   }
@@ -77,7 +80,7 @@ __kernel void TEMPLATE(softmax_forward_slm,Dtype)(const int num, const int chann
     int s = index / get_max_sub_group_size();
     Dtype maxval = sub_group_reduce_max(group_tmp[get_sub_group_local_id() * spatial_dim + s]);
     //if (get_sub_group_local_id() == 0)
-    scale_tmp[s] = maxval / 100000;
+    scale_tmp[s] = maxval;
   }
 
   barrier(CLK_LOCAL_MEM_FENCE);
@@ -95,7 +98,7 @@ __kernel void TEMPLATE(softmax_forward_slm,Dtype)(const int num, const int chann
     for (int c = get_global_id(0); c < channels; c += get_global_size(0)) {
       sum += out_tmp[c * spatial_dim + s];
     }
-    sum = sub_group_reduce_add(sum * 100000);
+    sum = sub_group_reduce_add(sum);
     group_tmp[get_sub_group_id() * spatial_dim + s] = sum;
   }
   barrier(CLK_LOCAL_MEM_FENCE);
@@ -105,7 +108,7 @@ __kernel void TEMPLATE(softmax_forward_slm,Dtype)(const int num, const int chann
     int s = index / get_max_sub_group_size();
     Dtype sum = sub_group_reduce_add(group_tmp[get_sub_group_local_id() * spatial_dim + s]);
     //if (get_sub_group_local_id() == 0)
-    scale_tmp[s] = sum / 100000;
+    scale_tmp[s] = sum;
   }
   barrier(CLK_LOCAL_MEM_FENCE);
 
@@ -130,12 +133,12 @@ __kernel void TEMPLATE(softmax_forward,Dtype)(const int num, const int channels,
   __global Dtype *group_tmp = scale + spatial_dim * num + n * get_max_sub_group_size() * spatial_dim;
   for (int index = get_global_id(0), s = 0; index < spatial_dim * get_local_size(0); index +=
       get_global_size(0), ++s) {
-    float maxval = -FLT_MAX;
+    Dtype maxval = -DTYPE_MAX;
     for (int c = get_global_id(0); c < channels; c += get_global_size(0)) {
       Dtype tmp = data[(n * channels + c) * spatial_dim + s];
       maxval = max((Dtype)tmp, (Dtype)maxval);
     }
-    maxval = sub_group_reduce_max(maxval * 100000);
+    maxval = sub_group_reduce_max(maxval);
     //if (get_sub_group_local_id() == 0)
     group_tmp[get_sub_group_id() * spatial_dim + s] = maxval;
   }
@@ -146,7 +149,7 @@ __kernel void TEMPLATE(softmax_forward,Dtype)(const int num, const int channels,
     int s = index / get_max_sub_group_size();
     Dtype maxval = sub_group_reduce_max(group_tmp[get_sub_group_local_id() * spatial_dim + s]);
     //if (get_sub_group_local_id() == 0)
-    scale[n * spatial_dim + s] = maxval / 100000;
+    scale[n * spatial_dim + s] = maxval;
   }
 
   barrier(CLK_GLOBAL_MEM_FENCE);
@@ -164,7 +167,7 @@ __kernel void TEMPLATE(softmax_forward,Dtype)(const int num, const int channels,
     for (int c = get_global_id(0); c < channels; c += get_global_size(0)) {
       sum += out[n * channels * spatial_dim + c * spatial_dim + s];
     }
-    sum = sub_group_reduce_add(sum * 100000);
+    sum = sub_group_reduce_add(sum);
     group_tmp[get_sub_group_id() * spatial_dim + s] = sum;
   }
   barrier(CLK_GLOBAL_MEM_FENCE);
@@ -174,7 +177,7 @@ __kernel void TEMPLATE(softmax_forward,Dtype)(const int num, const int channels,
     int s = index / get_max_sub_group_size();
     Dtype sum = sub_group_reduce_add(group_tmp[get_sub_group_local_id() * spatial_dim + s]);
     //if (get_sub_group_local_id() == 0)
-    scale[n * spatial_dim + s] = sum / 100000;
+    scale[n * spatial_dim + s] = sum;
   }
   barrier(CLK_GLOBAL_MEM_FENCE);