throw std::invalid_argument(
"Strided multiplication does not support broadcasting");
- /** @todo optimize this with a tensor iterator */
- for (unsigned int b = 0; b < batch(); ++b) {
- for (unsigned int c = 0; c < channel(); ++c) {
- for (unsigned int h = 0; h < height(); ++h) {
- for (unsigned int w = 0; w < width(); ++w) {
- output.setValue(b, c, h, w,
- getValue(b, c, h, w) * m.getValue(b, c, h, w));
+ if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1) {
+ for (unsigned int b = 0; b < batch(); ++b) {
+ for (unsigned int c = 0; c < channel(); ++c) {
+ for (unsigned int h = 0; h < height(); ++h) {
+ for (unsigned int w = 0; w < width(); ++w) {
+ output.setValue(b, c, h, w,
+ getValue(b, c, h, w) * m.getValue(b, c, h, w));
+ }
+ }
+ }
+ }
+ } else {
+ /** @todo optimize this by combining these loops where stride is 1 */
+ for (unsigned int b = 0; b < batch(); ++b) {
+ for (unsigned int c = 0; c < channel(); ++c) {
+ for (unsigned int h = 0; h < height(); ++h) {
+ float *out_data = output.getAddress(b, c, h, 0);
+ const float *m_data = m.getAddress(b, c, h, 0);
+ const float *in_data = getAddress(b, c, h, 0);
+ std::transform(in_data, in_data + width(), m_data, out_data,
+ std::multiplies<float>());
}
}
}
const float *data = getData();
float *rdata = output.getData();
std::transform(data, data + size(), rdata, f);
+ } else if (strides[3] == 1 && output.strides[3] == 1) {
+ /** @todo optimize this by combining these loops where stride is 1 */
+ for (unsigned int b = 0; b < batch(); ++b) {
+ for (unsigned int c = 0; c < channel(); ++c) {
+ for (unsigned int h = 0; h < height(); ++h) {
+ float *out_data = output.getAddress(b, c, h, 0);
+ const float *in_data = getAddress(b, c, h, 0);
+ std::transform(in_data, in_data + width(), out_data, f);
+ }
+ }
+ }
} else {
- /** @todo optimize this with a tensor iterator */
for (unsigned int b = 0; b < batch(); ++b) {
for (unsigned int c = 0; c < channel(); ++c) {
for (unsigned int h = 0; h < height(); ++h) {