namespace kernel
{
-struct Layout
+class Layout
{
- virtual ~Layout() = default;
+public:
+ using Func = uint32_t (*)(const Shape &, uint32_t n, uint32_t ch, uint32_t row, uint32_t col);
- virtual uint32_t offset(const Shape &, uint32_t n, uint32_t ch, uint32_t row,
- uint32_t col) const = 0;
+public:
+ Layout(const Func &func);
+
+public:
+ uint32_t offset(const Shape &shape, uint32_t n, uint32_t ch, uint32_t row, uint32_t col) const
+ {
+ return _func(shape, n, ch, row, col);
+ }
+
+private:
+ Func _func;
};
} // namespace kernel
--- /dev/null
+#include "nncc/core/ADT/kernel/Layout.h"
+
+#include <gtest/gtest.h>
+
+using nncc::core::ADT::kernel::Shape;
+using nncc::core::ADT::kernel::Layout;
+
+static uint32_t offset_0(const Shape &, uint32_t, uint32_t, uint32_t, uint32_t) { return 0; }
+static uint32_t offset_1(const Shape &, uint32_t, uint32_t, uint32_t, uint32_t) { return 1; }
+
+TEST(ADT_FEATURE_LAYOUT, ctor)
+{
+ Layout l{offset_0};
+
+ ASSERT_EQ(l.offset(Shape{4, 3, 6, 5}, 1, 1, 1, 1), 0);
+}
+
+TEST(ADT_FEATURE_LAYOUT, copy)
+{
+ Layout orig{offset_0};
+ Layout copy{offset_1};
+
+ ASSERT_EQ(copy.offset(Shape{4, 3, 6, 5}, 1, 1, 1, 1), 1);
+
+ copy = orig;
+
+ ASSERT_EQ(copy.offset(Shape{4, 3, 6, 5}, 1, 1, 1, 1), 0);
+}
+
+TEST(ADT_FEATURE_LAYOUT, move)
+{
+ Layout orig{offset_0};
+ Layout move{offset_1};
+
+ ASSERT_EQ(move.offset(Shape{4, 3, 6, 5}, 1, 1, 1, 1), 1);
+
+ move = std::move(orig);
+
+ ASSERT_EQ(move.offset(Shape{4, 3, 6, 5}, 1, 1, 1, 1), 0);
+}
#include "nncc/core/ADT/kernel/NCHWLayout.h"
+using nncc::core::ADT::kernel::Shape;
+
+static uint32_t NCHW_offset(const Shape &shape, uint32_t n, uint32_t ch, uint32_t row, uint32_t col)
+{
+ return (((n * shape.depth() + ch) * shape.height() + row) * shape.width() + col);
+}
+
namespace nncc
{
namespace core
namespace kernel
{
-uint32_t NCHWLayout::offset(const Shape &shape, uint32_t n, uint32_t ch, uint32_t row,
- uint32_t col) const
+NCHWLayout::NCHWLayout() : Layout{NCHW_offset}
{
- return (((n * shape.depth() + ch) * shape.height() + row) * shape.width() + col);
+ // DO NOTHING
}
} // namespace kernel