const float *rdata = rhs.getData();
for (size_t i = 0; i < len; ++i) {
+ /** not checking sign change is intentional to avoid float calculation
+ * errors around 0 */
if (std::isnan(data[i]) || std::isnan(rdata[i]) ||
std::fabs(data[i] - rdata[i]) > epsilon)
return false;
return sum(axes, ret, alpha);
}
+void Tensor::merge_axis(unsigned int axis1, unsigned int axis2) {
+ if (axis2 != axis1 + 1)
+ throw std::invalid_argument("Axis to be merged must be continuous.");
+
+ dim.setTensorDim(axis2, dim.getTensorDim(axis1) * dim.getTensorDim(axis2));
+ dim.setTensorDim(axis1, 1);
+}
+
Tensor &Tensor::sum(const std::vector<unsigned int> &axes, Tensor &output,
float alpha) const {
if (axes.empty())
if (axes.size() == 1) {
this->sum(axes[0], output, alpha);
} else {
- Tensor ret = this->sum(axes[0], alpha);
+ /** club axes together */
+ Tensor new_reshaped = *this;
+ std::vector<unsigned int> new_axes = {axes[0]};
+ for (unsigned int i = 1; i < axes.size(); ++i) {
+ if (axes[i] == axes[i - 1] + 1) {
+ new_reshaped.merge_axis(axes[i - 1], axes[i]);
+ new_axes.back() = axes[i];
+ } else {
+ new_axes.push_back(axes[i]);
+ }
+ }
- for (unsigned int i = 1; i < axes.size() - 1; ++i)
+ Tensor ret = new_reshaped.sum(new_axes[0]);
+ for (unsigned int i = 1; i < new_axes.size() - 1; ++i)
ret = ret.sum(axes[i]);
-
- ret.sum(axes.back(), output);
+ ret.sum(new_axes.back(), output, alpha);
}
return output;