This PR adds an overload of Tensor::cat that accepts a user-provided output tensor.
This allows the result of the concatenation to be stored in an existing tensor instead of allocating a new one.
**Changes proposed in this PR:**
- Add an overload of the cat function that takes an output tensor to store the result (see the usage sketch below).
- Add negative test cases for tensor concatenation.
**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped
Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
return ret;
}
-Tensor FloatTensor::concat(const std::vector<Tensor> &tensors, int axis) {
- if (axis == -1) {
- axis = 3;
- }
+Tensor FloatTensor::concat(const std::vector<Tensor> &tensors, int axis,
+ Tensor &output) {
+ bool is_format_nchw = (tensors.front().getDim().getFormat() == Tformat::NCHW);
- auto ref_dim = tensors.front().getDim();
- bool is_format_nchw = (ref_dim.getFormat() == Tformat::NCHW);
- ref_dim.setTensorDim(axis, 1);
- NNTR_THROW_IF(!std::all_of(tensors.begin(), tensors.end(),
- [&ref_dim, axis](const Tensor &t) {
- auto cur_dim = t.getDim();
- cur_dim.setTensorDim(axis, 1);
- return ref_dim == cur_dim;
- }),
- std::invalid_argument)
- << " all tensor must have the same dimension except for the axis, ref_dim: "
- << ref_dim << " axis : " << axis;
-
- auto axis_dim = std::accumulate(tensors.begin(), tensors.end(), 0u,
- [axis](unsigned cur, const Tensor &t) {
- return cur += t.getDim().getTensorDim(axis);
- });
auto iter_value =
[is_format_nchw](std::array<unsigned, 4> &loc,
const std::array<unsigned, 4> &start_loc, Tensor &t,
return value;
};
- auto ret_dim = ref_dim;
- ret_dim.setTensorDim(axis, axis_dim);
-
- Tensor ret = Tensor(ret_dim);
-
std::array<unsigned, 4> loc = {0, 0, 0, 0};
for (auto &t : tensors) {
std::array<unsigned, 4> start_loc = loc;
std::array<unsigned, 4> tensor_dim_arr;
- if (is_format_nchw) {
- tensor_dim_arr[0] = t.getDim().getTensorDim(0);
- tensor_dim_arr[1] = t.getDim().getTensorDim(1);
- tensor_dim_arr[2] = t.getDim().getTensorDim(2);
- tensor_dim_arr[3] = t.getDim().getTensorDim(3);
- } else {
- tensor_dim_arr[0] = t.getDim().getTensorDim(0);
- tensor_dim_arr[1] = t.getDim().getTensorDim(2);
- tensor_dim_arr[2] = t.getDim().getTensorDim(3);
- tensor_dim_arr[3] = t.getDim().getTensorDim(1);
- }
+ TensorDim curr_dim = t.getDim();
+
+ tensor_dim_arr[0] = curr_dim.getTensorDim(0);
+ tensor_dim_arr[1] =
+ is_format_nchw ? curr_dim.getTensorDim(1) : curr_dim.getTensorDim(2);
+ tensor_dim_arr[2] =
+ is_format_nchw ? curr_dim.getTensorDim(2) : curr_dim.getTensorDim(3);
+ tensor_dim_arr[3] =
+ is_format_nchw ? curr_dim.getTensorDim(3) : curr_dim.getTensorDim(1);
for (size_t i = 0u, sz = t.size(); i < sz; ++i) {
- iter_value(loc, start_loc, ret, tensor_dim_arr) = t.getValue<float>(i);
+ iter_value(loc, start_loc, output, tensor_dim_arr) = t.getValue<float>(i);
}
if (is_format_nchw) {
- loc[axis] += t.getDim().getTensorDim(axis);
+ loc[axis] += curr_dim.getTensorDim(axis);
} else {
if (axis == 0) {
- loc[0] += t.getDim().getTensorDim(axis);
+ loc[0] += curr_dim.getTensorDim(axis);
} else if (axis == 1) {
- loc[3] += t.getDim().getTensorDim(axis);
+ loc[3] += curr_dim.getTensorDim(axis);
} else if (axis == 2 || axis == 3) {
- loc[axis - 1] += t.getDim().getTensorDim(axis);
+ loc[axis - 1] += curr_dim.getTensorDim(axis);
}
}
}
- return ret;
+ return output;
}
void FloatTensor::print(std::ostream &out) const {
std::vector<Tensor> split(std::vector<size_t> sizes, int axis) override;
/**
- * @copydoc Tensor::cat(const std::vector<Tensor> &tensors, int axis)
+ * @copydoc Tensor::concat()
*/
- Tensor concat(const std::vector<Tensor> &tensors, int axis) override;
+ Tensor concat(const std::vector<Tensor> &tensors, int axis,
+ Tensor &output) override;
/**
* @copydoc Tensor::copy(const Tensor &from)
return ret;
}
-Tensor HalfTensor::concat(const std::vector<Tensor> &tensors, int axis) {
- if (axis == -1) {
- axis = 3;
- }
- auto ref_dim = tensors.front().getDim();
- bool is_format_nchw = (ref_dim.getFormat() == Tformat::NCHW);
- ref_dim.setTensorDim(axis, 1);
- NNTR_THROW_IF(!std::all_of(tensors.begin(), tensors.end(),
- [&ref_dim, axis](const Tensor &t) {
- auto cur_dim = t.getDim();
- cur_dim.setTensorDim(axis, 1);
- return ref_dim == cur_dim;
- }),
- std::invalid_argument)
- << " all tensor must have the same dimension except for the axis, ref_dim: "
- << ref_dim << " axis : " << axis;
+Tensor HalfTensor::concat(const std::vector<Tensor> &tensors, int axis,
+ Tensor &output) {
+ bool is_format_nchw = (tensors.front().getDim().getFormat() == Tformat::NCHW);
- auto axis_dim = std::accumulate(tensors.begin(), tensors.end(), 0u,
- [axis](unsigned cur, const Tensor &t) {
- return cur += t.getDim().getTensorDim(axis);
- });
auto iter_value =
[is_format_nchw](std::array<unsigned, 4> &loc,
const std::array<unsigned, 4> &start_loc, Tensor &t,
return value;
};
- auto ret_dim = ref_dim;
- ret_dim.setTensorDim(axis, axis_dim);
-
- Tensor output = Tensor(ret_dim);
-
std::array<unsigned, 4> loc = {0, 0, 0, 0};
for (auto &t : tensors) {
std::array<unsigned, 4> start_loc = loc;
std::array<unsigned, 4> tensor_dim_arr;
- if (is_format_nchw) {
- tensor_dim_arr[0] = t.getDim().getTensorDim(0);
- tensor_dim_arr[1] = t.getDim().getTensorDim(1);
- tensor_dim_arr[2] = t.getDim().getTensorDim(2);
- tensor_dim_arr[3] = t.getDim().getTensorDim(3);
- } else {
- tensor_dim_arr[0] = t.getDim().getTensorDim(0);
- tensor_dim_arr[1] = t.getDim().getTensorDim(2);
- tensor_dim_arr[2] = t.getDim().getTensorDim(3);
- tensor_dim_arr[3] = t.getDim().getTensorDim(1);
- }
+ TensorDim curr_dim = t.getDim();
+
+ tensor_dim_arr[0] = curr_dim.getTensorDim(0);
+ tensor_dim_arr[1] =
+ is_format_nchw ? curr_dim.getTensorDim(1) : curr_dim.getTensorDim(2);
+ tensor_dim_arr[2] =
+ is_format_nchw ? curr_dim.getTensorDim(2) : curr_dim.getTensorDim(3);
+ tensor_dim_arr[3] =
+ is_format_nchw ? curr_dim.getTensorDim(3) : curr_dim.getTensorDim(1);
for (size_t i = 0u, sz = t.size(); i < sz; ++i) {
iter_value(loc, start_loc, output, tensor_dim_arr) = t.getValue<_FP16>(i);
}
if (is_format_nchw) {
- loc[axis] += t.getDim().getTensorDim(axis);
+ loc[axis] += curr_dim.getTensorDim(axis);
} else {
if (axis == 0) {
- loc[0] += t.getDim().getTensorDim(axis);
+ loc[0] += curr_dim.getTensorDim(axis);
} else if (axis == 1) {
- loc[3] += t.getDim().getTensorDim(axis);
+ loc[3] += curr_dim.getTensorDim(axis);
} else if (axis == 2 || axis == 3) {
- loc[axis - 1] += t.getDim().getTensorDim(axis);
+ loc[axis - 1] += curr_dim.getTensorDim(axis);
}
}
}
std::vector<Tensor> split(std::vector<size_t> sizes, int axis) override;
/**
- * @copydoc Tensor::cat(const std::vector<Tensor> &tensors, int axis)
+ * @copydoc Tensor::concat()
*/
- Tensor concat(const std::vector<Tensor> &tensors, int axis) override;
+ Tensor concat(const std::vector<Tensor> &tensors, int axis,
+ Tensor &output) override;
/**
* @copydoc Tensor::copy(const Tensor &from)
#ifdef ENABLE_FP16
#include <half_tensor.h>
#endif
-
namespace nntrainer {
Tensor::Tensor(std::string name_, Tformat fm, Tdatatype d_type) {
return itensor->split(sizes, axis);
}
-Tensor Tensor::concat(const std::vector<Tensor> &tensors, int axis) {
- NNTR_THROW_IF(!(-1 <= axis && axis < 4), std::invalid_argument)
- << "cannot split axis of axis: " << axis;
+Tensor Tensor::concat(const std::vector<Tensor> &tensors, int axis,
+ Tensor &output) {
+ return itensor->concat(tensors, axis, output);
+}
+
+Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
+ if (axis == -1) {
+ axis = 3;
+ }
- NNTR_THROW_IF(tensors.empty(), std::invalid_argument)
- << "given tensor vector is empty";
+ // Create an output tensor to store the concatenation result
+ TensorDim out_dim = Tensor::calculateConcatOutputDim(tensors, axis);
+ Tensor output = Tensor(out_dim);
- return itensor->concat(tensors, axis);
+ return output.concat(tensors, axis, output);
}
-Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
- Tensor input = tensors[0];
- return input.concat(tensors, axis);
+Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis,
+ Tensor &output) {
+ if (axis == -1) {
+ axis = 3;
+ }
+
+ // Check if the given output tensor dimension is valid
+ TensorDim out_dim = Tensor::calculateConcatOutputDim(tensors, axis);
+
+ NNTR_THROW_IF(out_dim != output.getDim(), std::invalid_argument)
+ << "invalid output dim for concatenation " << output.getDim()
+ << "expected output dim " << out_dim;
+
+ return output.concat(tensors, axis, output);
}
void Tensor::print(std::ostream &out) const {
itensor->setTensorVar(d, buf, offset);
}
+TensorDim Tensor::calculateConcatOutputDim(const std::vector<Tensor> &tensors,
+ int axis) {
+  // Check that the axis along which the tensors are concatenated is valid.
+  NNTR_THROW_IF(!(-1 <= axis && axis < 4), std::invalid_argument)
+    << "cannot concatenate tensors along axis: " << axis;
+
+ // Check if the number of input tensors is valid.
+ NNTR_THROW_IF(tensors.size() <= 1, std::invalid_argument)
+ << "received an invalid tensor vector. size must be greater than 1.";
+
+ auto out_dim = tensors.front().getDim();
+
+ // Check if all tensor data types are the same.
+ for (auto &t : tensors) {
+ NNTR_THROW_IF(t.getDataType() != out_dim.getDataType(),
+ std::invalid_argument)
+ << "cannot concatenate tensors with different data types.";
+ }
+
+  // Compute the dimension of the output tensor.
+ out_dim.setTensorDim(axis, 1);
+ NNTR_THROW_IF(!std::all_of(tensors.begin(), tensors.end(),
+ [&out_dim, axis](const Tensor &t) {
+ auto cur_dim = t.getDim();
+ cur_dim.setTensorDim(axis, 1);
+ return out_dim == cur_dim;
+ }),
+ std::invalid_argument)
+ << " all tensor must have the same dimension except for the axis, out_dim: "
+ << out_dim << " axis : " << axis;
+
+ auto axis_dim = std::accumulate(tensors.begin(), tensors.end(), 0u,
+ [axis](unsigned cur, const Tensor &t) {
+ return cur += t.getDim().getTensorDim(axis);
+ });
+
+ out_dim.setTensorDim(axis, axis_dim);
+ return out_dim;
+}
+
std::ostream &operator<<(std::ostream &out, Tensor const &input) {
input.print(out);
return out;
*
* @param tensors tensors to be concatenated to the first tensor
* @param axis axis
+ * @param output output tensor to store the result
* @return Tensor concatenated tensor
+ *
+ * @note This function should not be used directly. Please use cat() instead.
*/
- Tensor concat(const std::vector<Tensor> &tensors, int axis = 0);
+ Tensor concat(const std::vector<Tensor> &tensors, int axis, Tensor &output);
/**
* @brief concatenate tensors along axis
*/
static Tensor cat(const std::vector<Tensor> &tensors, int axis = 0);
+ /**
+ * @brief concatenate tensors along axis
+ *
+ * @param tensors tensors to be concatenated to the first tensor
+ * @param axis axis
+ * @param output output tensor to store the result
+ * @return Tensor concatenated tensor
+ */
+ static Tensor cat(const std::vector<Tensor> &tensors, int axis,
+ Tensor &output);
+
/**
* @brief Print element
* @param[in] out out stream
* @param[in] offset offset to be used
*/
void setTensorVar(TensorDim d, void *buf, size_t offset);
+
+ /**
+   * @brief Calculate the output tensor dimension when concatenating a list of
+   * tensors along the given axis.
+   *
+   * @param[in] tensors tensors to be concatenated to the first tensor
+   * @param[in] axis axis along which the tensors are concatenated
+   * @return TensorDim expected dimension of the output tensor
+ */
+ static TensorDim calculateConcatOutputDim(const std::vector<Tensor> &tensors,
+ int axis);
};
/**
return ret;
}
-Tensor TensorBase::concat(const std::vector<Tensor> &tensors, int axis) {
+Tensor TensorBase::concat(const std::vector<Tensor> &tensors, int axis,
+ Tensor &output) {
throw std::invalid_argument(
"Tensor::concat() is currently not supported in tensor data type " +
getStringDataType());
virtual std::vector<Tensor> split(std::vector<size_t> sizes, int axis);
/**
- * @copydoc Tensor::concat(const std::vector<Tensor> &tensors, int axis)
+ * @copydoc Tensor::concat()
*/
- virtual Tensor concat(const std::vector<Tensor> &tensors, int axis);
+ virtual Tensor concat(const std::vector<Tensor> &tensors, int axis,
+ Tensor &output);
/**
* @copydoc Tensor::print(std::ostream &out)
18, 54, 55, 56, 36, 37, 19, 57, 58, 59, 38, 39, 20, 60, 61, 62, 40, 41,
21, 63, 64, 65, 42, 43, 22, 66, 67, 68, 44, 45, 23, 69, 70, 71, 46, 47};
nntrainer::Tensor answer(ml::train::TensorDim{3, 2, 4, 6}, answer_data);
- EXPECT_EQ(nntrainer::Tensor::cat(inputs, 3), answer);
+ EXPECT_EQ(nntrainer::Tensor::cat(inputs, -1), answer);
}
}
}
}
+// concatenate an empty list of tensors
+TEST(nntrainer_Tensor, cat_03_n) {
+ std::vector<nntrainer::Tensor> inputs;
+ EXPECT_THROW(nntrainer::Tensor::cat(inputs, 0), std::invalid_argument);
+}
+
+// concatenate a single tensor
+TEST(nntrainer_Tensor, cat_04_n) {
+ std::vector<nntrainer::Tensor> inputs;
+ inputs.reserve(1);
+ inputs.emplace_back(nntrainer::Tensor(2, 1, 1, 2));
+ EXPECT_THROW(nntrainer::Tensor::cat(inputs, 0), std::invalid_argument);
+}
+
+// concatenate tensors with different data types
+TEST(nntrainer_Tensor, cat_05_n) {
+ std::vector<nntrainer::Tensor> inputs;
+ inputs.reserve(2);
+ inputs.emplace_back(nntrainer::Tensor(2, 1, 1, 2));
+ inputs.emplace_back(nntrainer::Tensor(
+ 2, 1, 1, 2, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8}));
+ EXPECT_THROW(nntrainer::Tensor::cat(inputs, 0), std::invalid_argument);
+}
+
+// incorrect output tensor dimension
+TEST(nntrainer_Tensor, cat_06_n) {
+ std::vector<nntrainer::Tensor> inputs;
+ inputs.reserve(2);
+ inputs.emplace_back(nntrainer::Tensor(3, 2, 4, 1));
+ inputs.emplace_back(nntrainer::Tensor(3, 2, 4, 3));
+ nntrainer::Tensor output(3, 2, 4, 5);
+ EXPECT_THROW(nntrainer::Tensor::cat(inputs, 3, output),
+ std::invalid_argument);
+}
+
+// tensors not having the same shape except for the axis
+TEST(nntrainer_Tensor, cat_07_n) {
+ std::vector<nntrainer::Tensor> inputs;
+ inputs.reserve(2);
+ inputs.emplace_back(nntrainer::Tensor(3, 2, 4, 1));
+ inputs.emplace_back(nntrainer::Tensor(3, 1, 4, 3));
+ EXPECT_THROW(nntrainer::Tensor::cat(inputs, 1), std::invalid_argument);
+ EXPECT_THROW(nntrainer::Tensor::cat(inputs, 3), std::invalid_argument);
+}
+
TEST(nntrainer_Tensor, zoneout_mask_01_n) {
const float zoneout_rate = 0.3f;
nntrainer::Tensor t(10, 10, 10, 10);