1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef SIMPLE_SUM_HPP
18 #define SIMPLE_SUM_HPP
20 #include "cpu_sum.hpp"
21 #include "cpu_isa_traits.hpp"
22 #include "bfloat16_utils.hpp"
29 struct sum_bf16_params_t { // Per-thread workspace sizing for the bf16-source path; filled in by pd_t::init_scratchpad().
30 size_t ws_cvt_elements_per_thread_; // one cacheline worth of acc_data_t elements (cacheline_size_ / sizeof(acc_data_t))
31 size_t ws_acc_elements_per_thread_; // equals ws_cvt_elements_per_thread_ when dst is bf16; the other case is not visible in this copy
32 size_t ws_elements_per_thread_; // ws_cvt + ws_acc elements; the scratchpad reserves this many acc_data_t per thread
33 size_t acc_loop_step_; // equals ws_cvt_elements_per_thread_ when dst is bf16; other case not visible here. NOTE(review): the struct's closing `};` is not present in this copy
37 template <data_type_t src_data_type, data_type_t dst_data_type>
38 struct simple_sum_t: public cpu_primitive_t {
39 using cpu_memory_pd_t = cpu_memory_t::pd_t;
41 struct pd_t: public cpu_sum_pd_t {
42 pd_t(const memory_desc_t *output_d, int n, const float *scales,
43 const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
44 : cpu_sum_pd_t(output_d, n, scales, input_pds, attr) {}
46 DECLARE_CPU_SUM_PD_T("simple:any", simple_sum_t);
48 virtual status_t init() override { // Returns unimplemented unless every check below holds. NOTE(review): the `bool ok = ...` heads that the `&&` lines continue, and the closing braces/return, are not present in this copy
50 && cpu_sum_pd_t::init() == success // base sum pd initialization must succeed
51 && src_pds_.size() <= max_num_arrs; // at most max_num_arrs inputs
52 if (!ok) return unimplemented;
54 const memory_desc_wrapper o_d(&dst_pd_);
56 && o_d.data_type() == dst_data_type // destination must have the template's dst_data_type
58 if (!ok) return unimplemented;
60 const auto n = src_pds_.size();
61 for (size_t i = 0; i < n; ++i) { // each input must have src_data_type and the destination's format
62 const memory_desc_wrapper i_d(&src_pds_[i]);
64 && utils::everyone_is(src_data_type, i_d.data_type())
65 && i_d.format() == o_d.format()
67 if (!ok) return unimplemented;
76 sum_bf16_params_t bf16_p_; // bf16 workspace sizing; set by init_scratchpad() only when src_data_type is bf16
77 size_t block_size_, nelems_, blocks_number_, tail_; // element blocking of the destination; set by compute_blocking()
81 const size_t cacheline_size_ = 64; // bytes
82 const size_t half_L1_size_ = 16 * 1024; // bytes; presumably the block byte budget for non-bf16 sources (its use is not visible in this copy)
84 void compute_blocking() { // Splits the dst elements into blocks_number_ blocks of block_size_ elements plus tail_ leftovers.
85 block_size_ = (src_data_type == data_type::bf16 // bf16 sources use a byte budget of 16 cachelines
86 ? 16 * cacheline_size_ // NOTE(review): the ':' alternative and ')' of this conditional are missing in this copy (presumably ': half_L1_size_)') - confirm
88 / sizeof(src_data_type); // NOTE(review): this is the size of the data_type_t template parameter itself, not of one src_data_t element; sizeof(src_data_t) looks intended - confirm before changing, as it alters block_size_
89 nelems_ = memory_desc_wrapper(dst_pd()).nelems(); // total number of destination elements
90 blocks_number_ = nelems_ / block_size_; // number of full blocks
91 tail_ = nelems_ % block_size_; // elements left over after the full blocks
94 void init_scratchpad() { // For bf16 sources, sizes bf16_p_ and books a per-thread acc_data_t workspace in the scratchpad.
95 if (src_data_type == data_type::bf16) {
96 bool is_dst_bf16_ = dst_data_type == data_type::bf16; // a local, despite the member-style trailing underscore
97 bf16_p_.ws_cvt_elements_per_thread_ =
98 cacheline_size_ / sizeof(acc_data_t); // one cacheline of acc_data_t
100 bf16_p_.ws_acc_elements_per_thread_ =
102 ? bf16_p_.ws_cvt_elements_per_thread_ // NOTE(review): the condition and ':' branch of this conditional are not present in this copy
105 bf16_p_.acc_loop_step_ = is_dst_bf16_
106 ? bf16_p_.ws_cvt_elements_per_thread_ // NOTE(review): the ':' branch of this conditional is not present in this copy
109 bf16_p_.ws_elements_per_thread_ = bf16_p_.ws_cvt_elements_per_thread_
110 + bf16_p_.ws_acc_elements_per_thread_;
111 size_t bf16cvt_buf_sz_ = sizeof(acc_data_t) * bf16_p_.ws_elements_per_thread_ // bytes: per-thread workspace times the maximum thread count
112 * mkldnn_get_max_threads();
113 auto scratchpad = scratchpad_registry().registrar();
114 scratchpad.book(memory_tracking::names::key_sum_bf16cvt, bf16cvt_buf_sz_); // reserve the workspace under key_sum_bf16cvt
119 simple_sum_t(const pd_t *apd, const input_vector &inputs,
120 const output_vector &outputs)
121 : cpu_primitive_t(apd, inputs, outputs) {
124 virtual void execute(event_t *e) const { // Event-style entry point.
126 e->set_state(event_t::ready); // NOTE(review): as shown, this overload only marks e ready - the private execute() is not called and the closing brace is absent in this copy; presumably execute() should run first - confirm
129 enum {max_num_arrs = 16 }; // upper bound on the number of inputs accepted by pd_t::init()
130 typedef typename prec_traits<src_data_type>::type src_data_t; // element type for src_data_type (the inputs' required data type)
131 typedef typename prec_traits<dst_data_type>::type dst_data_t; // element type for dst_data_type (the output's required data type)
132 typedef typename prec_traits<data_type::f32>::type acc_data_t; // f32 element type; also the bf16 workspace element in init_scratchpad()
135 void execute() const; // the summation itself; its definition is not in this section
136 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } // base pd viewed as this primitive's pd_t
145 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s