NNTR_THROW_IF(!contiguous, std::invalid_argument)
<< getName() << " is not contiguous. Cannot dot product.";
- if (m.dim.rank() > 2) {
- throw exception::not_supported("Error: support only for rank of dot "
- "matrix <= 2");
- }
+ // Comment out with intension to support the calculation wrt. batch and height
+ // direction. It supposes to have this->dim as [ BxCxH,W ] and m.dim is
+ // [BxCxH,W] as well if (m.dim.rank() > 2) {
+ // throw exception::not_supported("Error: support only for rank of dot "
+ // "matrix <= 2");
+ // }
// Comment out with intension to support the calculation wrt. batch and height
// direction of this tensor. It is OK as long as m is 2D
TEST(nntrainer_Tensor, dot_01_n) {
nntrainer::Tensor input(2, 3, 4, 5);
nntrainer::Tensor m(1, 3, 4, 5);
- EXPECT_THROW(nntrainer::Tensor result = input.dot(m),
- nntrainer::exception::not_supported);
+ EXPECT_THROW(nntrainer::Tensor result = input.dot(m), std::runtime_error);
}
TEST(nntrainer_Tensor, dot_02_n) {
nntrainer::Tensor input(2, 3, 4, 5);
nntrainer::Tensor m(1, 3, 4, 5);
EXPECT_THROW(nntrainer::Tensor result = input.dot(m, true),
- nntrainer::exception::not_supported);
+ std::runtime_error);
}
-TEST(nntrainer_Tensor, dot_03_n) {
+TEST(nntrainer_Tensor, dot_02_p) {
+ nntrainer::Tensor input(2, 3, 4, 5);
+ nntrainer::Tensor m(1, 3, 4, 5);
+ EXPECT_NO_THROW(nntrainer::Tensor result = input.dot(m, false, true));
+}
+
+TEST(nntrainer_Tensor, dot_03_p) {
nntrainer::Tensor input(1, 3, 4, 5);
nntrainer::Tensor m(1, 3, 4, 5);
- EXPECT_THROW(nntrainer::Tensor result = input.dot(m, true),
- nntrainer::exception::not_supported);
+ EXPECT_NO_THROW(nntrainer::Tensor result = input.dot(m, true));
}
TEST(nntrainer_Tensor, dot_04_n) {