1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef REF_CONCAT_HPP
18 #define REF_CONCAT_HPP
20 #include "cpu_concat.hpp"
21 #include "reorder_pd.hpp"
// Concat primitive derived from cpu_primitive_t. It keeps a list of reorder
// primitives (reorders_) and execute() runs them one after another.
27 struct ref_concat_t: public cpu_primitive_t {
// Local shorthand for the CPU memory primitive-descriptor type.
28 using cpu_memory_pd_t = cpu_memory_t::pd_t;
// Primitive descriptor for ref_concat_t, derived from cpu_concat_pd_t. Its
// init() collects one reorder descriptor per input into reorder_pds_.
30 struct pd_t: public cpu_concat_pd_t {
// Constructs the descriptor by forwarding every argument unchanged to the
// cpu_concat_pd_t constructor; the constructor body itself is empty.
// No reorder descriptors are created here -- that happens in init().
31 pd_t(const memory_desc_t *output_d, int n, int concat_dim,
32 const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
33 : cpu_concat_pd_t(output_d, n, concat_dim, input_pds, attr) {}
// NOTE(review): the signature line that owns this initializer list is not
// visible in this excerpt; the use of `rhs` suggests the pd_t copy
// constructor -- confirm against the full file.
// Copy-constructs the cpu_concat_pd_t base from rhs.
35 : cpu_concat_pd_t(rhs)
// Appends a clone() of every entry of rhs.reorder_pds_ (cast to
// const reorder_pd_t *) instead of copying the pointers themselves.
37 for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) {
38 reorder_pds_.push_back(
39 (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
// Deletes every pointer stored in reorder_pds_.
// NOTE(review): the enclosing function header is not visible in this excerpt;
// this looks like the body of the pd_t destructor -- confirm.
44 for (size_t i = 0; i < reorder_pds_.size(); ++i) {
45 delete reorder_pds_[i];
// Factory: allocates a pd_t from the given arguments, runs init() on it and,
// if that succeeds, passes it to safe_ptr_assign together with *concat_pd.
//
// input_pds is cast from const memory_pd_t ** to const cpu_memory_pd_t **
// before being handed to the pd_t constructor.
//
// Returns out_of_memory if the allocation yields nullptr, unimplemented if
// init() does not return success (the pd_t is deleted first), otherwise the
// status returned by safe_ptr_assign.
49 static status_t create(concat_pd_t **concat_pd,
50 const memory_desc_t *output_d, int n, int concat_dim,
51 const memory_pd_t **input_pds, const primitive_attr_t *attr) {
52 auto _pd = new pd_t(output_d, n, concat_dim,
53 (const cpu_memory_pd_t **)input_pds, attr);
// NOTE(review): a plain new-expression normally reports failure by throwing
// rather than returning nullptr, so this branch only fires if allocation is
// non-throwing in this build -- confirm.
54 if (_pd == nullptr) return out_of_memory;
55 if (_pd->init() != success) { delete _pd; return unimplemented; }
56 return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd);
// Builds a ref_concat_t from this descriptor.
//
// For each index i it asks reorder_pds_[i] to create a primitive into
// reorders[i], passing &inputs[i] and the shared outputs pointer. It then
// wraps the first n_ inputs and the single output into vectors, constructs
// the ref_concat_t and stores it through safe_ptr_assign into *primitive.
// When the verbose level is at least 2 a "mkldnn_verbose,create" line with
// info() and ms is printed.
//
// NOTE(review): several statements of this function are not visible in this
// excerpt -- the definition of `n`, any sizing of `reorders` before it is
// indexed, any update of `ms` after the start timestamp, and the final
// return of `ret`. Confirm against the full file.
58 virtual status_t create_primitive(primitive_t **primitive,
59 const primitive_at_t *inputs,
60 const primitive_t **outputs) const override {
// Timestamp taken before any work, used for the verbose report below.
61 double ms = get_msec();
63 nstl::vector<primitive_t *> reorders;
65 for (int i = 0; i < n; ++i) {
// CHECK is a macro defined elsewhere; presumably it returns early on a
// non-success status. NOTE(review): if so, reorders created in earlier
// iterations would not be released on that path -- verify.
66 CHECK(reorder_pds_[i]->create_primitive(&reorders[i],
67 &inputs[i], outputs));
69 primitive_t::input_vector ins(inputs, inputs + n_);
// Exactly one output is taken from the outputs array.
70 primitive_t::output_vector outs(outputs, outputs + 1);
71 auto ret = safe_ptr_assign<primitive_t>(*primitive,
72 new ref_concat_t(this, ins, outs, reorders));
74 if (mkldnn_verbose()->level >= 2) {
75 printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms);
// Returns a newly allocated pd_t copy-constructed from *this.
// The result comes from a plain `new`; presumably the caller releases it.
80 virtual pd_t *clone() const override { return new pd_t(*this); }
// Returns the fixed implementation name string "ref:any".
81 virtual const char *name() const override { return "ref:any"; }
// Initializes the descriptor and collects the reorder descriptors.
//
// Fails with unimplemented if cpu_concat_pd_t::init() does not return
// success. Otherwise, for each of the n_ inputs, it walks the engine's reorder
// implementation list (terminated by a null entry) and calls each candidate
// with &src_pds_[i], &src_image_pds_[i] and a default-constructed attribute;
// a candidate returning status::success has its r_pd pushed to reorder_pds_.
//
// Returns success only if reorder_pds_ ends up with exactly n_ entries.
//
// NOTE(review): the declaration of r_pd and whatever follows the push_back
// (e.g. leaving the inner loop after the first match) are not visible in
// this excerpt -- confirm against the full file.
83 virtual status_t init() override {
84 assert(engine()->kind() == engine_kind::cpu);
86 bool ok = cpu_concat_pd_t::init() == success;
87 if (!ok) return unimplemented;
89 for (int i = 0; i < n_; ++i) {
90 auto r_impls = engine_->get_reorder_implementation_list();
// Iterate until the null entry that ends the implementation list.
91 for (auto r = r_impls; *r; ++r) {
92 const primitive_attr_t dummy_attr; /* alpha == 1. */
94 if ((*r)(&r_pd, &src_pds_[i], &src_image_pds_[i],
95 &dummy_attr) == status::success) {
97 reorder_pds_.push_back(r_pd);
// An input without a matching reorder leaves the list short, which is
// reported as unimplemented.
102 return (size_t)n_ == reorder_pds_.size() ? success : unimplemented;
// Reorder primitive descriptors, filled by init() and by the cloning loop
// that copies them from another pd_t; create_primitive() indexes this list
// to build the reorder primitives.
105 nstl::vector<const reorder_pd_t *> reorder_pds_;
// Constructs the primitive: apd, inputs and outputs go to the cpu_primitive_t
// base, and the by-value `reorders` vector is copied into reorders_.
// The constructor body is empty.
108 ref_concat_t(const pd_t *apd, const input_vector &inputs,
109 const output_vector &outputs, nstl::vector<primitive_t *> reorders)
110 : cpu_primitive_t(apd, inputs, outputs),
111 reorders_(reorders) {}
// Loop over every entry of reorders_, with the size read once up front.
// NOTE(review): neither the enclosing function header nor the loop body is
// visible in this excerpt; presumably this is the ref_concat_t destructor
// releasing each stored reorder -- confirm against the full file.
114 const auto n = reorders_.size();
115 for (size_t i = 0; i < n; ++i)
// Runs the concat: calls execute(&ei) on each stored reorder in index order,
// then marks the caller's event e as event_t::ready.
// NOTE(review): the declaration of `ei` is not visible in this excerpt;
// presumably a per-iteration local event_t -- confirm.
119 virtual void execute(event_t *e) const {
120 for (size_t i = 0; i < reorders_.size(); ++i) {
122 reorders_[i]->execute(&ei);
// Set only after the loop over all reorders has finished.
124 e->set_state(event_t::ready);
// Returns primitive_t::pd() cast to this primitive's own pd_t type.
128 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Reorder primitives received by the constructor; execute() runs them in
// index order.
129 nstl::vector<primitive_t *> reorders_;