1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef SIMPLE_SUM_HPP
18 #define SIMPLE_SUM_HPP
20 #include "cpu_sum.hpp"
26 template <data_type_t data_type>
27 struct simple_sum_t: public cpu_primitive_t {
// Shorthand for the CPU memory primitive-descriptor type; pd_t's constructor
// receives its inputs as an array of pointers to this type.
using cpu_memory_pd_t = cpu_memory_t::pd_t;
30 struct pd_t: public cpu_sum_pd_t {
31 pd_t(const memory_desc_t *output_d, int n, const float *scales,
32 const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
33 : cpu_sum_pd_t(output_d, n, scales, input_pds, attr) {}
// Boilerplate for this implementation's pd_t, generated by the
// DECLARE_CPU_SUM_PD_T macro (presumably defined in cpu_sum.hpp -- its
// expansion is not visible here). "simple:any" is the implementation name
// string passed to it; simple_sum_t is the primitive type it is tied to.
DECLARE_CPU_SUM_PD_T("simple:any", simple_sum_t);
37 virtual status_t init() override {
39 && cpu_sum_pd_t::init() == success
40 && src_pds_.size() <= max_num_arrs;
41 if (!ok) return unimplemented;
43 const memory_desc_wrapper o_d(&dst_pd_);
45 && o_d.data_type() == data_type
48 const auto n = src_pds_.size();
49 for (size_t i = 0; i < n; ++i) {
50 const memory_desc_wrapper i_d(&src_pds_[i]);
52 && utils::everyone_is(data_type, i_d.data_type())
53 && i_d.format() == o_d.format()
57 return ok ? success : unimplemented;
61 simple_sum_t(const pd_t *apd, const input_vector &inputs,
62 const output_vector &outputs)
63 : cpu_primitive_t(apd, inputs, outputs) {}
65 virtual void execute(event_t *e) const {
67 e->set_state(event_t::ready);
/** Largest number of inputs this implementation accepts; init() reports
 *  unimplemented when more source descriptors are supplied. */
enum {
    max_num_arrs = 16,
};
71 typedef typename prec_traits<data_type>::type data_t;
75 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
84 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s