#define CPU_RNN_PD_HPP
#include "c_types_map.hpp"
-#include "cpu_engine.hpp"
-#include "cpu_memory.hpp"
-#include "cpu_primitive.hpp"
+#include "../cpu_engine.hpp"
+#include "../cpu_memory.hpp"
+#include "../cpu_primitive.hpp"
#include "nstl.hpp"
#include "rnn_pd.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
+#include "rnn_utils.hpp"
namespace mkldnn {
namespace impl {
using namespace memory_format;
if (src_layer_pd_.desc()->format == any)
CHECK(src_layer_pd_.set_format(tnc));
- if (weights_layer_pd_.desc()->format == any)
- CHECK(weights_layer_pd_.set_format(ldigo));
- if (weights_iter_pd_.desc()->format == any)
- CHECK(weights_iter_pd_.set_format(ldigo));
if (dst_layer_pd_.desc()->format == any)
CHECK(dst_layer_pd_.set_format(tnc));
return status::success;
}
+
+ status_t check_layout_consistency() {
+ using namespace memory_format;
+ using namespace utils;
+ using namespace data_type;
+ bool ok = true;
+ ok = ok && src_layer_pd_.desc()->format == tnc
+ && dst_layer_pd_.desc()->format == tnc;
+ ok = ok && IMPLICATION(!src_iter_pd_.is_zero(),
+ src_iter_pd_.desc()->format == ldsnc)
+ && IMPLICATION(!dst_iter_pd_.is_zero(),
+ dst_iter_pd_.desc()->format == ldsnc);
+
+ ok = ok && one_of(weights_layer_pd_.desc()->format, ldigo, rnn_packed)
+ && one_of(weights_iter_pd_.desc()->format, ldigo, rnn_packed);
+ ok = ok && IMPLICATION(weights_iter_pd_.desc()->format == rnn_packed,
+ weights_iter_pd_.desc()
+ ->layout_desc.rnn_packed_desc.format
+ == mkldnn_ldigo_p);
+ ok = ok && IMPLICATION(weights_layer_pd_.desc()->format == rnn_packed,
+ weights_layer_pd_.desc()
+ ->layout_desc.rnn_packed_desc.format
+ == mkldnn_ldigo_p);
+
+ ok = ok && IMPLICATION(!bias_pd_.is_zero(),
+ bias_pd_.desc()->format == ldgo);
+
+ /* Int8 is supported only for packed weights */
+ data_type_t weights_iter_dt = weights_iter_pd_.desc()->data_type;
+ data_type_t weights_layer_dt = weights_layer_pd_.desc()->data_type;
+ ok = ok && IMPLICATION(weights_iter_dt == s8,
+ weights_iter_pd_.desc()->format == rnn_packed);
+ ok = ok && IMPLICATION(weights_layer_dt == s8,
+ weights_layer_pd_.desc()->format == rnn_packed);
+
+ return ok ? status::success : status::unimplemented;
+ }
};
struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
using cpu_memory_pd_t = cpu_memory_t::pd_t;
cpu_rnn_bwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
- const primitive_attr_t *attr, const rnn_bwd_pd_t *hint_bwd_pd)
- : rnn_bwd_pd_t(engine, adesc, attr, hint_bwd_pd)
+ const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd)
+ : rnn_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
, src_layer_pd_(engine, &desc_.src_layer_desc)
, src_iter_pd_(engine, &desc_.src_iter_desc)
, weights_layer_pd_(engine, &desc_.weights_layer_desc)
CHECK(src_layer_pd_.set_format(tnc));
if (diff_src_layer_pd_.desc()->format == any)
CHECK(diff_src_layer_pd_.set_format(tnc));
- if (weights_layer_pd_.desc()->format == any)
- CHECK(weights_layer_pd_.set_format(ldgoi));
- if (diff_weights_layer_pd_.desc()->format == any)
- CHECK(diff_weights_layer_pd_.set_format(ldigo));
- if (weights_iter_pd_.desc()->format == any)
- CHECK(weights_iter_pd_.set_format(ldgoi));
- if (diff_weights_iter_pd_.desc()->format == any)
- CHECK(diff_weights_iter_pd_.set_format(ldigo));
+ if (diff_weights_layer_pd_.desc()->format == any) {
+ memory_desc_t md = *(diff_weights_layer_pd_.desc());
+ md.format = ldigo;
+ CHECK(memory_desc_wrapper::compute_blocking(md));
+ CHECK(rnn_utils::set_good_strides(md));
+ cpu_memory_t::pd_t new_pd(engine_, &md);
+ diff_weights_layer_pd_ = new_pd;
+ }
+ if (diff_weights_iter_pd_.desc()->format == any) {
+ memory_desc_t md = *(diff_weights_iter_pd_.desc());
+ md.format = ldigo;
+ CHECK(memory_desc_wrapper::compute_blocking(md));
+ CHECK(rnn_utils::set_good_strides(md));
+ cpu_memory_t::pd_t new_pd(engine_, &md);
+ diff_weights_iter_pd_ = new_pd;
+ }
if (dst_layer_pd_.desc()->format == any)
CHECK(dst_layer_pd_.set_format(tnc));
if (diff_dst_layer_pd_.desc()->format == any)
return status::success;
}
+
+ status_t check_layout_consistency() {
+ using namespace memory_format;
+ using namespace utils;
+ bool ok = true;
+ ok = ok && src_layer_pd_.desc()->format == tnc
+ && dst_layer_pd_.desc()->format == tnc;
+ ok = ok && IMPLICATION(!src_iter_pd_.is_zero(),
+ src_iter_pd_.desc()->format == ldsnc)
+ && IMPLICATION(!dst_iter_pd_.is_zero(),
+ dst_iter_pd_.desc()->format == ldsnc);
+
+ ok = ok && one_of(weights_layer_pd_.desc()->format, ldgoi, rnn_packed)
+ && one_of(weights_iter_pd_.desc()->format, ldgoi, rnn_packed);
+ ok = ok && IMPLICATION(weights_iter_pd_.desc()->format == rnn_packed,
+ weights_iter_pd_.desc()
+ ->layout_desc.rnn_packed_desc.format
+ == mkldnn_ldgoi_p);
+ ok = ok && IMPLICATION(weights_layer_pd_.desc()->format == rnn_packed,
+ weights_layer_pd_.desc()
+ ->layout_desc.rnn_packed_desc.format
+ == mkldnn_ldgoi_p);
+
+ ok = ok && IMPLICATION(!bias_pd_.is_zero(),
+ bias_pd_.desc()->format == ldgo);
+
+ ok = ok && diff_src_layer_pd_.desc()->format == tnc
+ && diff_dst_layer_pd_.desc()->format == tnc;
+ ok = ok && IMPLICATION(!diff_states_pd_.is_zero(),
+ diff_states_pd_.desc()->format == ldsnc)
+ && IMPLICATION(!diff_dst_iter_pd_.is_zero(),
+ diff_dst_iter_pd_.desc()->format == ldsnc);
+ ok = ok && diff_weights_layer_pd_.desc()->format == ldigo
+ && diff_weights_iter_pd_.desc()->format == ldigo;
+ ok = ok && IMPLICATION(!diff_bias_pd_.is_zero(),
+ diff_bias_pd_.desc()->format == ldgo);
+
+ return ok ? status::success : status::unimplemented;
+ }
};
}
}