: parent(prim, prog)
, split(this->get_primitive()->split())
, depthwise_sep_opt(false)
+ , groups(this->get_primitive()->groups)
{
+ support_padding(true);
}
// Stores the given value in this node's depthwise_sep_opt flag.
void set_depthwise_sep_opt(bool node_depthwise_sep_opt) { depthwise_sep_opt = node_depthwise_sep_opt; }
// Returns the node's depthwise_sep_opt flag (the constructor initialises it to false).
bool get_depthwise_sep_opt() const { return depthwise_sep_opt; }
// Overwrites the group count cached on this node; the constructor seeds it
// from the primitive's `groups` field.
+ void set_groups(uint32_t node_groups) { groups = node_groups; }
// Returns the group count currently cached on this node.
+ uint32_t get_groups() const { return groups; }
+
// Returns dependency 0 of this node, which this class treats as the input.
program_node& input() const { return get_dependency(0); }
program_node& weights(size_t idx = 0) const
if (static_cast<int32_t>(idx) > 0)
throw std::range_error("Only one input for fused sum is supported");
- int d_idx = 1 + this->get_split() + idx;
+ size_t d_idx = 1 + this->get_split() + idx;
d_idx += bias_term() ? this->get_split() : 0;
return get_dependency(d_idx);
}
// Reports whether the node carries exactly one dependency beyond the ones
// counted below (input, per-split weights and, when present, per-split bias).
// That extra dependency is what this class calls the fused sum.
// The patched (+) lines keep the arithmetic in size_t so the result can be
// compared with dependencies.size() directly.
bool has_fused_sum() const
{
- int d_idx = 1 + this->get_split();
+ size_t d_idx = 1 + this->get_split(); // 1 for the input plus get_split() weights dependencies
d_idx += bias_term() ? this->get_split() : 0; // add get_split() more slots when bias_term() is true
- return static_cast<int>(dependencies.size()) == (d_idx + 1);
+ return dependencies.size() == (d_idx + 1); // both operands are unsigned now, so the old cast is dropped
}
private:
int32_t split;
bool depthwise_sep_opt;
+ uint32_t groups;
};
// Shorthand for the typed_program_node specialization for the deconvolution primitive.
using deconvolution_node = typed_program_node<deconvolution>;
// Returns the memory object that holds the weights for `index`.
//
// The patched (+) version distinguishes two cases:
//   * node.get_groups() == 1: `index` must be below node.get_split(),
//     otherwise std::range_error("weights offset too big") is thrown; the
//     result is dep_memory(1 + index).
//   * any other group count: dep_memory(1) is returned and `index` is not
//     consulted, since all weights are expected to sit in one buffer.
memory_impl& weights_memory(size_t index) const
{
- if (static_cast<int32_t>(index) >= node.get_split())
- throw std::range_error("weights offset too big");
-
- return dep_memory(1 + index);
+ if (node.get_groups() == 1) {
+ if (static_cast<int32_t>(index) >= node.get_split()) // index is narrowed to int32_t to match get_split()'s signed result
+ throw std::range_error("weights offset too big");
+ return dep_memory(1 + index); // slot 0 is skipped; the node's input() accessor reads dependency 0
+ }
+ else { // all weights are in one buffer
+ return dep_memory(1);
+ }
}
// Returns the memory object that holds the bias for `index`.
//
// The patched (+) version distinguishes two cases:
//   * node.get_groups() == 1: throws std::range_error("no bias data") when
//     argument.bias is empty and index >= node.get_split(), throws
//     std::range_error("bias offset too big") when index > node.get_split(),
//     and otherwise returns dep_memory(1 + node.get_split() + index).
//   * any other group count: returns dep_memory(2) and `index` is not
//     consulted, since all biases are expected to sit in one buffer.
//
// NOTE(review): the second bound uses `>` where weights_memory uses `>=`, so
// index == get_split() is accepted when bias data exists — looks like an
// off-by-one; confirm.
// NOTE(review): with an empty bias list and index < get_split() neither check
// fires and a dependency is still returned — confirm this is intended.
// NOTE(review): the grouped branch does not check that a bias exists before
// reading slot 2, and presumably assumes a single weights dependency at
// slot 1 — verify against callers.
memory_impl& bias_memory(size_t index) const
{
- if (argument.bias.size() == 0 && static_cast<int32_t>(index) >= node.get_split())
- throw std::range_error("no bias data");
-
- if (static_cast<int32_t>(index) > node.get_split())
- throw std::range_error("bias offset too big");
-
- return dep_memory(1 + node.get_split() + index);
+ if (node.get_groups() == 1) {
+ if (argument.bias.size() == 0 && static_cast<int32_t>(index) >= node.get_split())
+ throw std::range_error("no bias data");
+ if (static_cast<int32_t>(index) > node.get_split())
+ throw std::range_error("bias offset too big");
+ return dep_memory(1 + node.get_split() + index); // skips slot 0 and the get_split() slots that weights_memory reads
+ }
+ else { // all bias are in one buffer
+ return dep_memory(2);
+ }
}
bool bias_term() const