--- /dev/null
+#include "ConcatSpec.h"
+
+#include <cassert>
+
+using namespace nncc::core::ADT::tensor;
+
+nncc::core::ADT::tensor::Shape ConcatSpec::forward(const ShapeList &inputs) const
+{
+ assert(inputs.size() > 0);
+
+ Shape output_shape = inputs.at(0);
+
+ for (uint32_t n = 1; n < inputs.size(); ++n)
+ {
+ // The current implementation assumes that "inputs" is well-formed
+ // TODO Verify whether "inputs" is really well-formed
+ const auto &input_shape = inputs.at(n);
+ output_shape.dim(_axis) += input_shape.dim(_axis);
+ }
+
+ return output_shape;
+}
+
+ConcatSpec concat_spec(uint32_t axis) { return ConcatSpec{axis}; }
--- /dev/null
+#ifndef __CONCAT_SPEC_H__
+#define __CONCAT_SPEC_H__
+
+#include <nncc/core/ADT/tensor/Shape.h>
+
+#include <vector>
+
+using ShapeList = std::vector<nncc::core::ADT::tensor::Shape>;
+
+class ConcatSpec
+{
+public:
+ explicit ConcatSpec(uint32_t axis) : _axis{axis}
+ {
+ // DO NOTHING
+ }
+
+public:
+ // @brief Return the output shape when inputs of given shape are
+ // concatenated along _axis
+ nncc::core::ADT::tensor::Shape forward(const ShapeList &) const;
+
+private:
+ uint32_t _axis;
+};
+
+ConcatSpec concat_spec(uint32_t axis);
+
+#endif // __CONCAT_SPEC_H__
--- /dev/null
+#include "ConcatSpec.h"
+
+#include <gtest/gtest.h>
+
+using nncc::core::ADT::tensor::Shape;
+
+namespace
+{
+class ConcatSpecTest : public ::testing::Test
+{
+ // FOR FUTURE USE
+};
+} // namespace
+
+TEST_F(ConcatSpecTest, ifm_shape)
+{
+ const Shape in_1{1, 1, 4, 4};
+ const Shape in_2{1, 2, 4, 4};
+ const Shape in_3{1, 3, 4, 4};
+ const Shape in_4{1, 4, 4, 4};
+
+ auto expected = Shape{1, 10, 4, 4};
+ auto obtained = concat_spec(1).forward({in_1, in_2, in_3, in_4});
+
+ ASSERT_EQ(expected, obtained);
+}