Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_batch_normalization.hpp
index 7dbc47a..857e3a0 100644 (file)
 #include <assert.h>
 
 #include "c_types_map.hpp"
-#include "cpu_batch_normalization_pd.hpp"
-#include "cpu_engine.hpp"
 #include "type_helpers.hpp"
 #include "utils.hpp"
 
+#include "cpu_batch_normalization_pd.hpp"
 #include "jit_generator.hpp"
 
 namespace mkldnn {
@@ -46,58 +45,21 @@ struct jit_uni_batch_normalization_fwd_t: public cpu_primitive_t {
                 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
                 jit_uni_batch_normalization_fwd_t<isa>);
 
-        virtual status_t init() override {
-            using namespace prop_kind;
-            using namespace data_type;
-            using namespace memory_format;
-            assert(engine()->kind() == engine_kind::cpu);
-            auto desired_fmt = (ndims() == 4)
-                ? isa == avx512_common ? nChw16c : nChw8c
-                : isa == avx512_common ? nCdhw16c : nCdhw8c;
-            bool ok = true
-                && mayiuse(isa)
-                && is_fwd()
-                && !has_zero_dim_memory()
-                && utils::one_of(ndims(), 4, 5)
-                && desc()->data_desc.data_type == f32
-                && IMPLICATION(use_scaleshift(),
-                        desc()->data_scaleshift_desc.data_type == f32)
-                && desc()->data_desc.format == desired_fmt
-                && (attr()->has_default_values() || this->with_relu_post_op());
-            if (!ok) return status::unimplemented;
-
-            if (is_training() && fuse_bn_relu()) {
-                if (isa < avx2) return status::unimplemented;
-                bn_init_default_ws(this, this->workspace_pd_, 1);
-            }
-            if (memory_desc_wrapper(&data_pd_).blocking_desc()
-                .padding_dims[1] != this->C() && isa < avx2)
-                return status::unimplemented;
-
-            if (stats_is_src() || is_training()) {
-                memory_desc_t stats_d;
-                dims_t stats_dims = { C() };
-                mkldnn_memory_desc_init(&stats_d, 1, stats_dims,
-                        data_type::f32, x);
-                mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
-                variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
-            }
-
-            return success;
-        }
+        virtual status_t init() override;
     };
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    jit_uni_batch_normalization_fwd_t(const pd_t *pd,
+    jit_uni_batch_normalization_fwd_t(const pd_t *apd,
             const input_vector &inputs, const output_vector &outputs);
     ~jit_uni_batch_normalization_fwd_t();
-    virtual void execute(event_t *e);
+
+    virtual void execute(event_t *e) const;
 
 private:
-    uni_bnorm_driver_t<isa> *bnorm_driver_;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 
-    pd_t conf_;
+    uni_bnorm_driver_t<isa> *bnorm_driver_;
 };
 
 template <cpu_isa_t isa>
@@ -113,63 +75,21 @@ struct jit_uni_batch_normalization_bwd_t: public cpu_primitive_t {
                 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
                 jit_uni_batch_normalization_bwd_t<isa>);
 
-        virtual status_t init() override {
-            using namespace prop_kind;
-            using namespace data_type;
-            using namespace utils;
-            using namespace memory_format;
-            assert(engine()->kind() == engine_kind::cpu);
-            auto desired_fmt = (ndims() == 4)
-                ? utils::one_of(isa, sse42, avx2) ? nChw8c : nChw16c
-                : utils::one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c;
-            bool ok = true
-                && mayiuse(isa)
-                && is_bwd()
-                && !has_zero_dim_memory()
-                && utils::one_of(ndims(), 4, 5)
-                && everyone_is(f32, desc()->data_desc.data_type,
-                        desc()->diff_data_desc.data_type)
-                && IMPLICATION(use_scaleshift(),
-                        desc()->data_scaleshift_desc.data_type == f32)
-                && everyone_is(desired_fmt, desc()->diff_data_desc.format,
-                        desc()->data_desc.format)
-                && attr()->has_default_values();
-            if (!ok) return status::unimplemented;
-            if (memory_desc_wrapper(&data_pd_).blocking_desc()
-                .padding_dims[1] != this->C() && isa < avx2)
-                return status::unimplemented;
-
-            if (fuse_bn_relu()) {
-                if (isa < avx2) return status::unimplemented;
-                bn_init_default_ws(this, this->workspace_pd_, 1);
-                const size_t this_ws_sz
-                    = memory_desc_wrapper(this->workspace_pd()).size();
-
-                bool ws_ok = true
-                    && hint_fwd_pd_->workspace_pd()
-                    && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
-                            == this_ws_sz;
-                if (!ws_ok)
-                    return status::unimplemented;
-            }
-
-            /* TODO: extra checks required */
-
-            return success;
-        }
+        virtual status_t init() override;
     };
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    jit_uni_batch_normalization_bwd_t(const pd_t *pd,
+    jit_uni_batch_normalization_bwd_t(const pd_t *apd,
             const input_vector &inputs, const output_vector &outputs);
     ~jit_uni_batch_normalization_bwd_t();
-    virtual void execute(event_t *e);
+
+    virtual void execute(event_t *e) const;
 
 private:
-    uni_bnorm_driver_t<isa> *bnorm_driver_;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 
-    pd_t conf_;
+    uni_bnorm_driver_t<isa> *bnorm_driver_;
 };
 
 }