Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_softmax.hpp
index c82f5b2..8023785 100644 (file)
 * limitations under the License.
 *******************************************************************************/
 
-#ifndef CPU_REF_SOFTMAX_FWD_HPP
-#define CPU_REF_SOFTMAX_FWD_HPP
+#ifndef CPU_REF_SOFTMAX_HPP
+#define CPU_REF_SOFTMAX_HPP
 
 #include <assert.h>
 
 #include "c_types_map.hpp"
-#include "cpu_softmax_pd.hpp"
-#include "cpu_engine.hpp"
+#include "memory_tracking.hpp"
 #include "type_helpers.hpp"
 #include "utils.hpp"
 
+#include "cpu_softmax_pd.hpp"
+
 namespace mkldnn {
 namespace impl {
 namespace cpu {
@@ -49,63 +50,68 @@ struct ref_softmax_fwd_t: public cpu_primitive_t {
                 && attr()->has_default_values();
             if (!ok) return status::unimplemented;
 
+            init_scratchpad();
+
             return status::success;
         }
+
+    private:
+        void init_scratchpad() {
+            const int inner_size = utils::array_product(
+                    desc()->data_desc.dims + desc()->softmax_axis + 1,
+                    desc()->data_desc.ndims - desc()->softmax_axis - 1);
+
+            if (inner_size > 1) {
+                auto scratchpad = scratchpad_registry().registrar();
+                scratchpad.book(memory_tracking::names::key_softmax_reduction,
+                        sizeof(data_t) * 2 * inner_size);
+            }
+        }
     };
 
-    ref_softmax_fwd_t(const pd_t *pd, const input_vector &inputs,
+    ref_softmax_fwd_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), ws_(nullptr) {
-        auto ndims = conf_.desc()->data_desc.ndims;
-        auto dims = conf_.desc()->data_desc.dims;
-        auto axis = conf_.desc()->softmax_axis;
+        : cpu_primitive_t(apd, inputs, outputs)
+    {
+        auto ndims = pd()->desc()->data_desc.ndims;
+        auto dims = pd()->desc()->data_desc.dims;
+        auto axis = pd()->desc()->softmax_axis;
 
         outer_size_ = utils::array_product(dims, axis);
         channels_ = dims[axis];
         inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1);
-        val_max_ = val_denom_ = 0;
-
-        if (inner_size_ > 1) {
-            ws_ = new data_t[2*inner_size_];
-            max_ = &ws_[0];
-            denom_ = &ws_[inner_size_];
-        } else {
-            max_ = &val_max_;
-            denom_ = &val_denom_;
-        }
 
-        const memory_desc_wrapper data_d(conf_.src_pd());
+        const memory_desc_wrapper data_d(pd()->src_pd());
         use_dense_ = inner_size_ == 1 && data_d.is_dense()
             && data_d.blocking_desc().block_dims[axis] == 1
             && data_d.blocking_desc().strides[0][axis] == 1;
     }
-    ~ref_softmax_fwd_t() { if (ws_) delete [] ws_; }
+    ~ref_softmax_fwd_t() {}
+
     typedef typename prec_traits<data_type>::type data_t;
 
-    virtual void execute(event_t *e) {
+    virtual void execute(event_t *e) const {
         if (use_dense_) execute_forward_dense();
         else execute_forward_generic();
         e->set_state(event_t::ready);
     }
 
 private:
-    void execute_forward_dense();
-    void execute_forward_generic();
+    void execute_forward_dense() const;
+    void execute_forward_generic() const;
 
-    void _max(int n, const data_t *x, data_t *max_data);
-    void _sub(int n, data_t alpha, const data_t *x, data_t *y);
-    void _exp(int n, const data_t *a, data_t *r);
-    void _exp_parallel(int n, const data_t *a, data_t *r);
-    void _sum(int n, const data_t *x, data_t *sum_data);
-    void _scal(int n, data_t alpha, data_t *x);
-    void _scal_parallel(int n, data_t alpha, data_t *x);
+    void _max(int n, const data_t *x, data_t *max_data) const;
+    void _sub(int n, data_t alpha, const data_t *x, data_t *y) const;
+    void _exp(int n, const data_t *a, data_t *r) const;
+    void _exp_parallel(int n, const data_t *a, data_t *r) const;
+    void _sum(int n, const data_t *x, data_t *sum_data) const;
+    void _scal(int n, data_t alpha, data_t *x) const;
+    void _scal_parallel(int n, data_t alpha, data_t *x) const;
 
-    pd_t conf_;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 
     bool use_dense_;
     int outer_size_, channels_, inner_size_;
-    data_t val_max_, val_denom_;
-    data_t *ws_, *max_, *denom_;
 };
 
 template <impl::data_type_t data_type>
@@ -132,20 +138,20 @@ struct ref_softmax_bwd_t: public cpu_primitive_t {
         }
     };
 
-    ref_softmax_bwd_t(const pd_t *pd, const input_vector &inputs,
+    ref_softmax_bwd_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {
-        auto dims = conf_.desc()->diff_desc.dims;
-        auto axis = conf_.desc()->softmax_axis;
-        auto ndims = conf_.desc()->diff_desc.ndims;
+        : cpu_primitive_t(apd, inputs, outputs) {
+        auto dims = pd()->desc()->diff_desc.dims;
+        auto axis = pd()->desc()->softmax_axis;
+        auto ndims = pd()->desc()->diff_desc.ndims;
 
         outer_size_ = utils::array_product(dims, axis);
         channels_ = dims[axis];
         inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1);
 
         // Diff desc as well as data desc should be checked
-        const memory_desc_wrapper data_d(conf_.dst_pd());
-        const memory_desc_wrapper diff_d(conf_.diff_dst_pd());
+        const memory_desc_wrapper data_d(pd()->dst_pd());
+        const memory_desc_wrapper diff_d(pd()->diff_dst_pd());
         use_dense_ = true
             && inner_size_ == 1
             && diff_d == data_d
@@ -154,23 +160,22 @@ struct ref_softmax_bwd_t: public cpu_primitive_t {
             && diff_d.blocking_desc().strides[0][axis] == 1;
     }
     ~ref_softmax_bwd_t() {}
+
     typedef typename prec_traits<data_type>::type data_t;
 
-    virtual void execute(event_t *e) {
+    virtual void execute(event_t *e) const {
         if (use_dense_) execute_backward_dense();
         else execute_backward_generic();
         e->set_state(event_t::ready);
     }
 
 private:
-    void execute_backward_dense();
-    void execute_backward_generic();
+    void execute_backward_dense() const;
+    void execute_backward_generic() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 
-    pd_t conf_;
     bool use_dense_;
     int outer_size_, channels_, inner_size_;
-    data_t val_max_, val_denom_;
-    data_t *max_, *denom_;
 };