Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_binarization_pd.hpp
index 05d1059..b10a4e5 100644 (file)
@@ -39,7 +39,8 @@ struct cpu_binarization_fwd_pd_t: public binarization_fwd_pd_t {
         : binarization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
         , src_pd_(engine_, &desc_.src_desc)
         , dst_pd_(engine_, &desc_.dst_desc)
-        , weights_pd_(engine_, &desc_.weights_desc) {}
+        , weights_pd_(engine_, &desc_.weights_desc)
+        , output_mask_pd_(engine_, &desc_.output_mask_desc) {}
     virtual ~cpu_binarization_fwd_pd_t() {}
 
     virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
@@ -48,11 +49,12 @@ struct cpu_binarization_fwd_pd_t: public binarization_fwd_pd_t {
     { return index == 0 ? &dst_pd_ : nullptr; }
     virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
         if (index == 0) return &weights_pd_;
+        if (index == 1) return &output_mask_pd_;
         return nullptr;
     }
 
 protected:
-    cpu_memory_pd_t src_pd_, dst_pd_, weights_pd_;
+    cpu_memory_pd_t src_pd_, dst_pd_, weights_pd_, output_mask_pd_;
 
     inline memory_format_t src_format()
     {
@@ -73,6 +75,8 @@ protected:
             CHECK(dst_pd_.set_format(src_pd_.desc()->format));
         if (weights_pd_.desc()->format == any)
             CHECK(weights_pd_.set_format(wei_format()));
+        if (output_mask_pd_.desc()->format == any)
+            CHECK(output_mask_pd_.set_format(wei_format()));
         return status::success;
     }