1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #include "cpu_sum.hpp"
21 #include "reorder_pd.hpp"
// Sum primitive derived from cpu_primitive_t. As the members further down
// show, it stores a vector of primitive_t pointers (reorders_) and its
// execute() runs each of them in index order.
27 struct ref_sum_t: public cpu_primitive_t {
// Local shorthand for cpu_memory_t::pd_t.
28 using cpu_memory_pd_t = cpu_memory_t::pd_t;
// Descriptor type for ref_sum_t, derived from cpu_sum_pd_t. It additionally
// stores a vector of reorder descriptors (reorder_pds_).
30 struct pd_t: public cpu_sum_pd_t {
// Constructs the descriptor by forwarding output_d, n, scales, input_pds and
// attr unchanged to the cpu_sum_pd_t constructor; the body itself is empty.
31 pd_t(const memory_desc_t *output_d, int n, const float *scales,
32 const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
33 : cpu_sum_pd_t(output_d, n, scales, input_pds, attr) {}
// Copy constructor: copy-constructs the cpu_sum_pd_t base from rhs, then
// appends every element of rhs.scales_ to scales_ and a clone() of every
// element of rhs.reorder_pds_ to reorder_pds_.
// NOTE(review): scales_ is not declared in this struct, so it is presumably
// inherited. If the base copy constructor already copies it, the first loop
// would append the same values a second time -- confirm against the base.
34 pd_t(const pd_t &rhs): cpu_sum_pd_t(rhs) {
35 for (size_t i = 0; i < rhs.scales_.size(); ++i) {
36 scales_.push_back(rhs.scales_[i]);
// Each clone() result is cast to const reorder_pd_t * before being stored.
38 for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) {
39 reorder_pds_.push_back(
40 (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
// Passes every pointer stored in reorder_pds_ to delete.
// NOTE(review): the header of the enclosing function is not visible in this
// listing -- presumably this is the pd_t destructor; confirm.
45 for (size_t i = 0; i < reorder_pds_.size(); ++i) {
46 delete reorder_pds_[i];
// Factory for pd_t. Allocates a pd_t from the given arguments (input_pds is
// cast to const cpu_memory_pd_t **), calls init() on it and, if that
// succeeds, passes the object and *sum_pd to safe_ptr_assign.
// Returns out_of_memory when the allocation yields nullptr, unimplemented
// when init() does not return success (the pd_t is deleted first), and
// otherwise whatever safe_ptr_assign returns.
50 static status_t create(sum_pd_t **sum_pd,
51 const memory_desc_t *output_d, int n, const float *scales,
52 const memory_pd_t **input_pds, const primitive_attr_t *attr) {
53 auto _pd = new pd_t(output_d, n, scales,
54 (const cpu_memory_pd_t **)input_pds, attr);
// NOTE(review): a plain new expression normally throws on failure rather
// than returning nullptr, so this check may be unreachable -- confirm
// whether the project replaces operator new or builds without exceptions.
55 if (_pd == nullptr) return out_of_memory;
56 if (_pd->init() != success) { delete _pd; return unimplemented; }
57 return safe_ptr_assign<sum_pd_t>(*sum_pd, _pd);
// Creates the ref_sum_t primitive described by this pd_t and hands it to
// safe_ptr_assign together with *primitive.
// inputs is read at indices [0, n_); outputs is read at index 0 only.
60 virtual status_t create_primitive(primitive_t **primitive,
61 const primitive_at_t *inputs, const primitive_t **outputs)
// ms is initialised from get_msec() and later printed in the verbose line.
63 double ms = get_msec();
64 nstl::vector<primitive_t *> reorders;
// One primitive is created per reorder descriptor: the i-th descriptor
// receives &reorders[i] as destination, &inputs[i] and outputs.
// NOTE(review): no statement sizing `reorders` is visible in this listing
// before it is indexed -- confirm against the full file.
// NOTE(review): CHECK is defined elsewhere; presumably it returns early on a
// non-success status, in which case primitives created by earlier
// iterations would not be released on that path -- confirm.
66 for (int i = 0; i < n_; ++i)
67 CHECK(reorder_pds_[i]->create_primitive(&reorders[i],
68 &inputs[i], outputs));
// Wrap the raw arrays in the vector types taken by the ref_sum_t constructor.
70 primitive_t::input_vector ins(inputs, inputs + n_);
71 primitive_t::output_vector outs(outputs, outputs + 1);
72 auto ret = safe_ptr_assign<primitive_t>(*primitive,
73 new ref_sum_t(this, ins, outs, reorders));
// When the verbose level is at least 2, print one line containing
// this->info() and ms.
75 if (mkldnn_verbose()->level >= 2) {
76 printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms);
// Returns a pointer to a new pd_t that is copy-constructed from *this.
81 virtual pd_t *clone() const override { return new pd_t(*this); }
// Returns the fixed string "ref:any".
82 virtual const char *name() const override { return "ref:any"; }
// Initialises the descriptor and fills reorder_pds_.
// Returns unimplemented when cpu_sum_pd_t::init() does not return success,
// or when the final size check on reorder_pds_ and scales_ fails; returns
// success otherwise.
84 virtual status_t init() override {
85 bool ok = cpu_sum_pd_t::init() == success;
86 if (!ok) return unimplemented;
// For each input index i, walk the list returned by
// engine_->get_reorder_implementation_list() until an entry evaluates to
// false.
88 for (int i = 0; i < n_; ++i) {
89 auto r_impls = engine_->get_reorder_implementation_list();
90 for (auto r = r_impls; *r; ++r) {
// A temporary attribute is built per candidate: scales_[i] is set as its
// output scale and append_sum(1.0) is called on its post-ops.
91 primitive_attr_t dummy_attr;
92 dummy_attr.output_scales_.set(scales_[i]);
95 dummy_attr.post_ops_.append_sum(1.0);
// The candidate is called with &r_pd, the i-th source descriptor, the
// destination descriptor and the temporary attribute; r_pd is then pushed
// onto reorder_pds_.
// NOTE(review): the declaration of r_pd and the rest of this condition are
// not visible in this listing.
97 if ((*r)(&r_pd, &src_pds_[i], &dst_pd_, &dummy_attr)
100 reorder_pds_.push_back(r_pd);
// NOTE(review): utils::everyone_is is defined elsewhere -- presumably true
// only when the two sizes are equal; confirm.
105 ok = utils::everyone_is(reorder_pds_.size(), scales_.size());
106 return ok ? success : unimplemented;
// Reorder descriptors held by this pd_t: appended to by init() and by the
// copy constructor, and indexed by create_primitive().
109 nstl::vector<const reorder_pd_t *> reorder_pds_;
// Constructs the primitive: apd, inputs and outputs are forwarded to the
// cpu_primitive_t constructor, and the reorders vector (taken by value) is
// copied into reorders_. The body itself is empty.
112 ref_sum_t(const pd_t *apd, const input_vector &inputs,
113 const output_vector &outputs, nstl::vector<primitive_t *> reorders)
114 : cpu_primitive_t(apd, inputs, outputs),
115 reorders_(reorders) {}
// Iterates over every index of reorders_.
// NOTE(review): neither the enclosing function header nor the loop body is
// visible in this listing -- presumably this is the ref_sum_t destructor
// releasing the stored primitives; confirm.
118 const auto n = reorders_.size();
119 for (size_t i = 0; i < n; ++i)
// Calls execute(&ei) on every primitive in reorders_, in index order, and
// then sets the state of the caller's event e to event_t::ready.
// NOTE(review): the declaration of ei is not visible in this listing.
123 virtual void execute(event_t *e) const {
124 const auto n = reorders_.size();
125 for (size_t i = 0; i < n; ++i) {
127 reorders_[i]->execute(&ei);
129 e->set_state(event_t::ready);
// Returns primitive_t::pd() cast to const pd_t *.
133 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Primitives run by execute(); initialised from the constructor's reorders
// argument.
134 nstl::vector<primitive_t *> reorders_;