Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_1x1_conv_utils.hpp
index d360a14..a3ed769 100644 (file)
 #ifndef JIT_UNI_1x1_CONV_UTILS_HPP
 #define JIT_UNI_1x1_CONV_UTILS_HPP
 
+#include "memory_tracking.hpp"
 #include "mkldnn_thread.hpp"
-#include "utils.hpp"
 #include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
 
 #include "jit_generator.hpp"
 
@@ -29,6 +31,12 @@ namespace cpu {
 
 using namespace mkldnn::impl::utils;
 
+struct reduce_to_unit_stride_t {
+    convolution_desc_t conv_d_;
+    bool reduce_src_;
+    size_t space_per_thread_;
+};
+
 /* 1x1-kernel does not support non-unit strides so far, so the idea is:
  *  - for fwd or bwd_weights: to copy src to a scratch memory (with strides
  *    equal to 1) and then call the kernel
@@ -38,7 +46,7 @@ using namespace mkldnn::impl::utils;
 template <typename conv_pd_t>
 inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
         const memory_desc_t *&src_d, const memory_desc_t *dst_d) {
-    const bool is_bwd_data = self->cdesc()->prop_kind
+    const bool is_bwd_data = self->desc()->prop_kind
         == prop_kind::backward_data;
 
     const int ndims = src_d->ndims;
@@ -83,6 +91,22 @@ inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
     }
 }
 
+template <typename conv_pd_t>
+inline void rtus_prepare_space_info(conv_pd_t *self,
+        memory_tracking::registrar_t &scratchpad) {
+    const auto &jcp = self->jcp_;
+
+    const int max_threads = mkldnn_get_max_threads();
+    const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind,
+            jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking);
+    size_t typesize = types::data_type_size(
+            conv_prop_agnostic_src_d(self->desc())->data_type);
+
+    self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block;
+    scratchpad.book(memory_tracking::names::key_conv_rtus_space,
+            typesize * max_threads * self->rtus_.space_per_thread_);
+}
+
 template <cpu_isa_t isa>
 struct rtus_driver_t: public jit_generator {
 
@@ -246,62 +270,44 @@ struct rtus_driver_t: public jit_generator {
 
 template <cpu_isa_t isa, typename conv_t>
 inline void init_rtus_driver(conv_t *self) {
-    const auto &conf = self->conf_;
-    const auto &cd = *conf.cdesc();
-    const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data;
-    const int ndims = conf.ndims();
-
+    const auto &conf = *self->pd();
     if (!conf.rtus_.reduce_src_) return;
 
-    const int max_threads = mkldnn_get_max_threads();
-    size_t factor = 0;
-    switch (cd.prop_kind) {
-    case prop_kind::forward_training: case prop_kind::forward_inference:
-        factor = conf.jcp_.nb_reduce; break;
-    case prop_kind::backward_data:
-        factor = conf.jcp_.nb_load_blocking_max; break;
-    case prop_kind::backward_weights:
-        factor = conf.jcp_.nb_bcast_blocking; break;
-    default: assert(!"unsupported prop_kind");
-    }
-
-    size_t typesize = sizeof(decltype(*self->scratch_));
-
-    self->ws_per_thread_ = factor * conf.jcp_.is * conf.jcp_.ic_block;
-    self->scratch_ = (decltype(self->scratch_))malloc(
-            max_threads * self->ws_per_thread_ * typesize, 64);
-
+    const auto &cd = *conf.desc();
+    const int ndims = conf.ndims();
     const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0];
     const int stride_w = cd.strides[ndims - 3];
 
+    const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data;
     const auto &src_d = is_bwd_data ? *conf.diff_src_pd()->desc()
                                     : *conf.src_pd()->desc();
     assert((isa == avx2 && utils::one_of(src_d.format, memory_format::nCw8c,
         memory_format::nChw8c)) || (isa == avx512_common && utils::one_of(
             src_d.format, memory_format::nCw16c, memory_format::nChw16c)));
 
-    const int ih = (ndims == 3) ? 1 : src_d.dims[2];
+    const int ih = ndims == 3 ? 1 : src_d.dims[2];
     const int iw = src_d.dims[ndims - 1];
 
     const int src_step_h = stride_h * iw;
     const int src_step_icb = ih * iw;
     const int ws_step_icb = conf.jcp_.is;
     const bool src_to_ws = !is_bwd_data;
+    const size_t typesize = types::data_type_size(
+            conv_prop_agnostic_src_d(self->pd()->desc())->data_type);
+
     self->rtus_driver_ = new rtus_driver_t<isa>(iw, stride_w, src_step_h,
             src_step_icb, ws_step_icb, src_to_ws, typesize);
 }
 
-inline float loss_ratio(int amount, int divider)
-{
-    return float(rnd_up(amount, divider) - amount) / rnd_up(amount, divider);
-}
-
 inline int best_divider(int value, int min_divider, int max_divider,
-                        bool find_max, int step = 1)
+        bool find_max, int step = 1)
 {
     max_divider = nstl::max(1, nstl::min(max_divider, value));
     min_divider = nstl::max(1, nstl::min(min_divider, max_divider));
 
+    auto loss_ratio = [](int total, int chunk)
+    { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); };
+
     float min_loss = FLT_MAX;
     int x_divider = max_divider;
     for (int divider = max_divider; divider >= min_divider; divider -= step) {