dnn(ocl4dnn): support log softmax in ocl4dnn
[platform/upstream/opencv.git] / modules / dnn / src / ocl4dnn / src / ocl4dnn_softmax.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2017, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of the copyright holders may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "../../precomp.hpp"
43 #include <vector>
44 #include "common.hpp"
45 #include "ocl4dnn.hpp"
46 #include "opencl_kernels_dnn.hpp"
47
48 #ifdef HAVE_OPENCL
49 namespace cv { namespace dnn { namespace ocl4dnn {
50 template<typename Dtype>
51 OCL4DNNSoftmax<Dtype>::OCL4DNNSoftmax(OCL4DNNSoftmaxConfig config)
52 {
53     softmax_axis_ = config.axis;
54     channels_ = config.channels;
55     log_softmax_ = config.logsoftmax;
56
57     inner_num_ = 1;
58     outer_num_ = 1;
59     count_ = 1;
60     int32_t scale_sz = 1;
61     for (int32_t i = softmax_axis_ + 1; i < config.in_shape.size(); i++)
62         inner_num_ *= config.in_shape[i];
63     use_slm_ = (config.in_shape[softmax_axis_] * inner_num_ + inner_num_ * 17) <= 8192;
64     for (int32_t i = 0; i < softmax_axis_; i++)
65         outer_num_ *= config.in_shape[i];
66     count_ = inner_num_ + outer_num_;
67
68     std::vector<int32_t> scale_dims = config.in_shape;
69     scale_dims[softmax_axis_] = use_slm_ ? 1 : 17;
70     for (int32_t i = 0; i < scale_dims.size(); i++)
71         scale_sz *= scale_dims[i];
72
73     scale_data_.create(1, scale_sz, CV_32FC1);
74 }
75
76 template<typename Dtype>
77 OCL4DNNSoftmax<Dtype>::~OCL4DNNSoftmax()
78 {
79     scale_data_.release();
80 }
81
82 template<typename Dtype>
83 bool OCL4DNNSoftmax<Dtype>::Forward(const UMat& bottom, UMat& top)
84 {
85     bool ret = false;
86     ocl::Queue queue = ocl::Queue::getDefault();
87     bool intel_subgroup = ocl::Device::getDefault().intelSubgroupsSupport();
88     if (intel_subgroup && inner_num_ < 128)
89     {
90         String opts = clOptionSupport("-cl-no-subgroup-ifp") ? " -cl-no-subgroup-ifp " : "";
91         String kname;
92         ocl::Kernel oclk_softmax_forward_kernel;
93
94         if (log_softmax_) opts += " -DLOG_SOFTMAX ";
95         if (use_slm_)
96             kname = CL_KERNEL_SELECT("softmax_forward_slm");
97         else
98             kname = CL_KERNEL_SELECT("softmax_forward");
99
100         if (!oclk_softmax_forward_kernel.create(kname.c_str(), ocl::dnn::softmax_loss_oclsrc, opts))
101             return false;
102
103         size_t global_size[] = { 256, (size_t)outer_num_, 1 };
104         size_t local_size[] = { 256, 1, 1 };
105         cl_uint argIdx = 0;
106
107         if (use_slm_)
108         {
109             oclk_softmax_forward_kernel.set(argIdx++, outer_num_);
110             oclk_softmax_forward_kernel.set(argIdx++, channels_);
111             oclk_softmax_forward_kernel.set(argIdx++, inner_num_);
112             oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrWriteOnly(scale_data_));
113             oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom));
114             oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top));
115             oclk_softmax_forward_kernel.set(argIdx++, NULL, channels_ * inner_num_* sizeof(Dtype));
116             oclk_softmax_forward_kernel.set(argIdx++, NULL, inner_num_* sizeof(Dtype));
117             oclk_softmax_forward_kernel.set(argIdx++, NULL, 16 * inner_num_* sizeof(Dtype));
118         }
119         else
120         {
121             oclk_softmax_forward_kernel.set(argIdx++, outer_num_);
122             oclk_softmax_forward_kernel.set(argIdx++, channels_);
123             oclk_softmax_forward_kernel.set(argIdx++, inner_num_);
124             oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrWriteOnly(scale_data_));
125             oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom));
126             oclk_softmax_forward_kernel.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top));
127         }
128         ret = oclk_softmax_forward_kernel.run(3, global_size, local_size, false);
129     }
130     return ret;
131 }
132
// Explicit instantiation: this file emits the member definitions above for
// Dtype = float only. NOTE(review): other Dtype values would need their own
// instantiation unless one is provided elsewhere - confirm before adding types.
133 template class OCL4DNNSoftmax<float>;
134 } // namespace ocl4dnn
135 }
136 }
137 #endif // HAVE_OPENCL