Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_x8s8s32x_convolution.hpp
index 3bc0cc4..e7943ac 100644 (file)
 #define GEMM_X8S8S32X_CONVOLUTION_HPP
 
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
 #include "cpu_convolution_pd.hpp"
 #include "cpu_engine.hpp"
 #include "jit_primitive_conf.hpp"
+#include "jit_generator.hpp"
 #include "gemm_convolution_utils.hpp"
 
-#include "gemm/os_blas.hpp"
+#include "gemm/gemm.hpp"
 
 namespace mkldnn {
 namespace impl {
 namespace cpu {
 
-template <bool with_relu, data_type_t src_type, data_type_t dst_type>
+template <data_type_t src_type, data_type_t dst_type>
 struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t {
-    struct pd_t: public _cpu_convolution_fwd_pd_t<with_relu> {
-        pd_t(engine_t *engine, const typename pd_t::base_desc_t *adesc,
+    struct pd_t: public cpu_convolution_fwd_pd_t {
+        pd_t(engine_t *engine, const convolution_desc_t *adesc,
                 const primitive_attr_t *attr,
                 const typename pd_t::base_class *hint_fwd_pd)
-            : _cpu_convolution_fwd_pd_t<with_relu>(engine, adesc, attr,
-                    hint_fwd_pd), jcp_() {}
+            : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+            , jcp_() {}
 
-        DECLARE_COMMON_PD_T("gemm:blas",
-                _gemm_x8s8s32x_convolution_fwd_t<with_relu, src_type, dst_type>);
+        DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR,
+                _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>);
 
         virtual status_t init() override {
             using namespace data_type;
@@ -48,30 +51,33 @@ struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t {
             assert(this->engine()->kind() == engine_kind::cpu);
 
             bool ok = true
-#if !USE_MKL_IGEMM
-                && false
-#endif
                 && this->set_default_params() == status::success
-                && utils::one_of(this->cdesc_().prop_kind,
+                && utils::one_of(this->desc()->prop_kind,
                         prop_kind::forward_training,
                         prop_kind::forward_inference)
-                && this->cdesc_().alg_kind == alg_kind::convolution_direct
+                && utils::one_of(this->desc()->alg_kind,
+                        alg_kind::convolution_auto,
+                        alg_kind::convolution_direct)
                 && !this->has_zero_dim_memory()
-                && this->cdesc_().src_desc.data_type == src_type
-                && this->cdesc_().dst_desc.data_type == dst_type
-                && this->cdesc_().weights_desc.data_type == s8
+                && this->desc()->src_desc.data_type == src_type
+                && this->desc()->dst_desc.data_type == dst_type
+                && this->desc()->weights_desc.data_type == s8
                 && IMPLICATION(this->with_bias(), utils::one_of(
-                            this->cdesc_().bias_desc.data_type, f32, s32, s8,
+                            this->desc()->bias_desc.data_type, f32, s32, s8,
                             u8))
-                && this->cdesc_().accum_data_type == data_type::s32
+                && this->desc()->accum_data_type == data_type::s32
                 && utils::everyone_is(nhwc, this->src_pd_.desc()->format,
                         this->dst_pd_.desc()->format)
                 && this->weights_pd_.desc()->format == (this->with_groups()
                         ? ((src_type == data_type::s8) ? hwigo_s8s8 : hwigo)
                         : ((src_type == data_type::s8) ? hwio_s8s8 : hwio))
                 && this->is_gemm_conv_format();
+            if (!ok) return status::unimplemented;
 
-            return ok ? status::success : status::unimplemented;
+            auto scratchpad = scratchpad_registry().registrar();
+            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+                    *this->desc(), this->src_pd(), this->weights_pd(0),
+                    this->dst_pd(), mkldnn_get_max_threads());
         }
 
         jit_gemm_conv_conf_t jcp_;
@@ -79,94 +85,127 @@ struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t {
     protected:
         virtual status_t set_default_params() override {
             using namespace memory_format;
-            bool is_sign_input =
-                    (this->cdesc_().src_desc.data_type == data_type::s8);
+            const bool is_sign_input =
+                this->desc()->src_desc.data_type == data_type::s8;
+
             if (this->src_pd_.desc()->format == any)
                 CHECK(this->src_pd_.set_format(nhwc));
             if (this->dst_pd_.desc()->format == any)
                 CHECK(this->dst_pd_.set_format(nhwc));
             if (this->weights_pd_.desc()->format == any)
                 CHECK(this->weights_pd_.set_format(this->with_groups()
-                            ? ((is_sign_input) ? hwigo_s8s8 : hwigo)
-                            : ((is_sign_input) ? hwio_s8s8 : hwio)));
+                            ? (is_sign_input ? hwigo_s8s8 : hwigo)
+                            : (is_sign_input ? hwio_s8s8 : hwio)));
             if (this->bias_pd_.desc()->format == any)
                 CHECK(this->bias_pd_.set_format(x));
+            if (this->desc()->alg_kind == alg_kind::convolution_auto)
+                CHECK(this->set_alg_kind(alg_kind::convolution_direct));
             return status::success;
         }
 
         virtual bool is_gemm_conv_format() const {
             using namespace mkldnn::impl::primitive_kind;
-            bool ok = true;
             auto const &po = this->attr()->post_ops_;
+            auto is_relu = [&](int idx) {
+                return po.entry_[idx].is_relu(true, false); };
+
             switch (po.len_) {
-            case 0: break;
-            case 1: ok = ok
-                    && (po.entry_[0].is_relu() || po.contain(sum, 0));
-                break;
-            case 2: ok = ok
-                    && (po.contain(sum, 0) && po.entry_[1].is_relu());
-                break;
-            default: ok = false;
+            case 0: return true;
+            case 1: return is_relu(0) || po.contain(sum, 0);
+            case 2: return po.contain(sum, 0) && is_relu(1);
+            default: return false;
             }
-            return ok;
+            return false;
         }
     };
 
-    _gemm_x8s8s32x_convolution_fwd_t(const pd_t *pd, const input_vector &inputs,
+    _gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd, const input_vector &inputs,
            const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-        , scratchpad_(nullptr)
-    {
-        jit_gemm_convolution_utils::init_conf(conf_.jcp_,
-            *conf_.cdesc(), conf_.src_pd(), conf_.weights_pd(0),
-            conf_.dst_pd(), mkldnn_get_max_threads(), with_relu, conf_.negative_slope());
-
-        size_t col_size = (size_t)conf_.jcp_.im2col_sz * sizeof(src_data_t);
-        size_t acc_size = (size_t)conf_.jcp_.os * conf_.jcp_.oc
-                            * sizeof(acc_data_t);
-        size_t size = col_size + acc_size;
-
-        jit_gemm_convolution_utils::prepare_scratchpad(this->conf_.jcp_,
-                &this->scratchpad_, size, this->conf_.jcp_.nthr);
+        : cpu_primitive_t(apd, inputs, outputs, true) {
+        pp_ker_ = new pp_ker_t(apd);
     }
-
     ~_gemm_x8s8s32x_convolution_fwd_t() {
-        delete this->scratchpad_;
-    };
+        delete pp_ker_;
+    }
 
     typedef typename prec_traits<src_type>::type src_data_t;
     typedef typename prec_traits<data_type::s8>::type wei_data_t;
     typedef typename prec_traits<dst_type>::type dst_data_t;
     typedef typename prec_traits<data_type::s32>::type acc_data_t;
 
-    virtual void execute(event_t *e) {
+    virtual void execute(event_t *e) const {
         execute_forward();
         e->set_state(event_t::ready);
     }
 
 private:
-    void execute_forward();
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+    void execute_forward() const;
+    // XXX: this is throwaway code that will become unnecessary when we have a
+    // sufficiently advanced igemm jit generator that supports quantization,
+    // relu, and whatnot
+    class pp_ker_t : jit_generator {
+    public:
+        DECLARE_CPU_JIT_AUX_FUNCTIONS(
+        _gemm_x8s8s32x_convolution_fwd_t::pp_kernel);
+        pp_ker_t(const pd_t *pd);
+
+        void operator()(dst_data_t *dst, const acc_data_t *acc,
+            const char *bias, const float *scales,
+            float nslope, float sum_scale, float signed_scale,
+            int g, size_t start, size_t end);
+    private:
+        void generate();
+
+        struct ker_args {
+            dst_data_t *dst;
+            const acc_data_t *acc;
+            const char *bias;
+            const float *scales;
+            float nslope;
+            float sum_scale;
+            float signed_scale;
+            size_t len;
+            size_t oc_offset;
+        };
+        void(*ker_)(const ker_args *args);
+
+        const jit_gemm_conv_conf_t jcp_;
+        size_t OC_;
+        size_t OS_;
+        data_type_t bias_data_type_;
+        size_t bias_data_type_size_;
+        size_t scale_idx_mult_;
+        round_mode_t rmode_;
+        bool do_bias_;
+        bool do_relu_;
+        bool do_sum_;
+        bool do_signed_scaling_;
+        size_t dst_os_stride_;
+        size_t vlen_;
+    };
+
+
     void execute_forward_thr(const int ithr, const int nthr,
             const src_data_t *src_base, const wei_data_t *wei_base,
             const char *bia_base, dst_data_t *dst_base,
-            char *scratchpad);
-    pd_t conf_;
-    scratchpad_t *scratchpad_;
+            const memory_tracking::grantor_t &scratchpad) const;
+
     int nthr_;
+    pp_ker_t *pp_ker_;
+
 };
 
 template <data_type_t dst_type>
 struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t {
     struct pd_t: public cpu_convolution_bwd_data_pd_t{
         pd_t(engine_t *engine,
-                const convolution_desc_t *adesc,
-                const primitive_attr_t *attr,
+                const convolution_desc_t *adesc, const primitive_attr_t *attr,
                 const convolution_fwd_pd_t *hint_fwd_pd)
             : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
-            , jcp_()
-        {}
+            , jcp_() {}
 
-        DECLARE_COMMON_PD_T("gemm:blas",
+        DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR,
                 _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>);
 
         virtual status_t init() override {
@@ -176,12 +215,10 @@ struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t {
             assert(this->engine()->kind() == engine_kind::cpu);
 
             bool ok = true
-#if !USE_MKL_IGEMM
-                && false
-#endif
                 && this->set_default_params() == status::success
                 && this->desc()->prop_kind == prop_kind::backward_data
-                && this->desc()->alg_kind == alg_kind::convolution_direct
+                && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
+                           alg_kind::convolution_direct)
                 && !this->has_zero_dim_memory()
                 && this->desc()->diff_src_desc.data_type == dst_type
                 && this->desc()->diff_dst_desc.data_type == u8
@@ -195,8 +232,12 @@ struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t {
                 && this->weights_pd_.desc()->format == (this->with_groups()
                         ? hwigo : hwio)
                 && attr()->post_ops_.has_default_values();
+            if (!ok) return status::unimplemented;
 
-            return ok ? status::success : status::unimplemented;
+            auto scratchpad = scratchpad_registry().registrar();
+            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+                    *this->desc(), this->diff_src_pd(), this->weights_pd(0),
+                    this->diff_dst_pd(), mkldnn_get_max_threads());
         }
 
         virtual bool support_bias() const override { return true; }
@@ -206,59 +247,44 @@ struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t {
     protected:
         virtual status_t set_default_params() override {
             using namespace memory_format;
+
             if (this->diff_src_pd_.desc()->format == any)
                 CHECK(this->diff_src_pd_.set_format(nhwc));
             if (this->diff_dst_pd_.desc()->format == any)
                 CHECK(this->diff_dst_pd_.set_format(nhwc));
             if (this->weights_pd_.desc()->format == any)
-                CHECK(this->weights_pd_.set_format(this->with_groups()
-                            ? hwigo : hwio));
+                CHECK(this->weights_pd_.set_format(
+                            this->with_groups() ? hwigo : hwio));
             if (bias_pd_.desc()->format == any)
                 CHECK(bias_pd_.set_format(x));
+            if (this->desc()->alg_kind == alg_kind::convolution_auto)
+                CHECK(this->set_alg_kind(alg_kind::convolution_direct));
              return status::success;
         }
     };
 
-    _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *pd, const input_vector &inputs,
+    _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *apd, const input_vector &inputs,
            const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-        , scratchpad_(nullptr)
-    {
-        jit_gemm_convolution_utils::init_conf(conf_.jcp_,
-            *conf_.desc(), conf_.diff_src_pd(), conf_.weights_pd(0),
-            conf_.diff_dst_pd(), mkldnn_get_max_threads());
-
-        size_t col_size = (size_t)conf_.jcp_.im2col_sz * sizeof(acc_data_t);
-        size_t acc_size = (size_t)conf_.jcp_.is * conf_.jcp_.ic
-                            * sizeof(acc_data_t);
-        size_t size = col_size + acc_size;
-
-        jit_gemm_convolution_utils::prepare_scratchpad(this->conf_.jcp_,
-                &this->scratchpad_, size, this->conf_.jcp_.nthr);
-    }
-
-    ~_gemm_u8s8s32x_convolution_bwd_data_t() {
-        delete this->scratchpad_;
-    };
+        : cpu_primitive_t(apd, inputs, outputs, true) {}
+    ~_gemm_u8s8s32x_convolution_bwd_data_t() {}
 
     typedef typename prec_traits<data_type::u8>::type diff_dst_data_t;
     typedef typename prec_traits<data_type::s8>::type wei_data_t;
     typedef typename prec_traits<dst_type>::type diff_src_data_t;
     typedef typename prec_traits<data_type::s32>::type acc_data_t;
 
-    virtual void execute(event_t *e) {
+    virtual void execute(event_t *e) const {
         execute_backward_data();
         e->set_state(event_t::ready);
     }
 
 private:
-    void execute_backward_data();
+    void execute_backward_data() const;
     void execute_backward_data_thr(const int ithr, const int nthr,
             const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
             const char *bia_base, diff_src_data_t *diff_src_base,
-            char *scratchpad);
-    pd_t conf_;
-    scratchpad_t *scratchpad_;
+            const memory_tracking::grantor_t &scratchpad) const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 };
 
 }