#define __COCO_IR_KERNEL_OBJECT_H__
#include "coco/IR/Object.h"
+#include "coco/IR/ElemID.h"
#include <nncc/core/ADT/kernel/Shape.h>
+#include <vector>
+
namespace coco
{
private:
nncc::core::ADT::kernel::Shape const _shape;
+
+public:
+ ElemID &at(uint32_t n, uint32_t ch, uint32_t row, uint32_t col);
+ const ElemID &at(uint32_t n, uint32_t ch, uint32_t row, uint32_t col) const;
+
+private:
+ std::vector<ElemID> _map;
};
} // namespace coco
#include "coco/IR/KernelObject.h"
+#include <nncc/core/ADT/kernel/NCHWLayout.h>
+
+namespace
+{
+
+static nncc::core::ADT::kernel::NCHWLayout l{};
+
+} // namespace
+
namespace coco
{
KernelObject::KernelObject(const nncc::core::ADT::kernel::Shape &shape) : _shape{shape}
{
- // DO NOTHING
+ _map.resize(nncc::core::ADT::kernel::num_elements(shape));
+}
+
+ElemID &KernelObject::at(uint32_t n, uint32_t ch, uint32_t row, uint32_t col)
+{
+ return _map.at(l.offset(_shape, n, ch, row, col));
+}
+
+const ElemID &KernelObject::at(uint32_t n, uint32_t ch, uint32_t row, uint32_t col) const
+{
+ return _map.at(l.offset(_shape, n, ch, row, col));
}
} // namespace coco
ASSERT_EQ(o.shape().height(), shape.height());
ASSERT_EQ(o.shape().width(), shape.width());
}
+
+TEST(IR_KERNEL_OBJECT, at)
+{
+ const uint32_t N = 1;
+ const uint32_t C = 1;
+ const uint32_t H = 3;
+ const uint32_t W = 3;
+
+ const nncc::core::ADT::kernel::Shape shape{N, C, H, W};
+ coco::KernelObject o{shape};
+
+ coco::KernelObject *mutable_ptr = &o;
+ const coco::KernelObject *immutable_ptr = &o;
+
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ for (uint32_t ch = 0; ch < C; ++ch)
+ {
+ for (uint32_t row = 0; row < H; ++row)
+ {
+ for (uint32_t col = 0; col < W; ++col)
+ {
+ mutable_ptr->at(n, ch, row, col) = coco::ElemID{16};
+ }
+ }
+ }
+ }
+
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ for (uint32_t ch = 0; ch < C; ++ch)
+ {
+ for (uint32_t row = 0; row < H; ++row)
+ {
+ for (uint32_t col = 0; col < W; ++col)
+ {
+ ASSERT_EQ(immutable_ptr->at(n, ch, row, col).value(), 16);
+ }
+ }
+ }
+ }
+}