#include "loco/ADT/ObjectPool.h"
+#include <initializer_list>
#include <set>
#include <string>
#include <memory>
// Any "TensorShaped" class has the following methods
// - const TensorShape *shape(void) const;
// - void shape(std::unique_ptr<TensorShape> &&);
+ // - void shape(std::initializer_list<Dimension> &&);
//
// TODO Rename NodeMixin::TensorShape as NodeMixin::NDShape
TensorShaped,
public:
const TensorShape *shape(void) const { return _shape.get(); }
void shape(std::unique_ptr<TensorShape> &&shape) { _shape = std::move(shape); }
+ void shape(std::initializer_list<Dimension> dims);
private:
std::unique_ptr<TensorShape> _shape = nullptr;
#include <cassert>
+namespace
+{
+
+std::unique_ptr<loco::TensorShape> make_tensor_shape(std::initializer_list<loco::Dimension> dims)
+{
+ auto tensor_shape = stdex::make_unique<loco::TensorShape>();
+
+ tensor_shape->rank(dims.size());
+ {
+ uint32_t axis = 0;
+ for (auto it = dims.begin(); it != dims.end(); ++it)
+ {
+ tensor_shape->dim(axis++) = *it;
+ }
+ assert(axis == dims.size());
+ }
+
+ return std::move(tensor_shape);
+}
+
+} // namespace
+
namespace loco
{
+void Mixin<Trait::TensorShaped>::shape(std::initializer_list<Dimension> dims)
+{
+ shape(make_tensor_shape(dims));
+}
+
void GraphInput::node(Pull *pull)
{
_pull = pull;
ASSERT_EQ(mixin.dtype(), loco::DataType::FLOAT32);
}
+TEST(TensorShapedMixinTest, setter_and_getter)
+{
+ loco::Mixin<loco::Trait::TensorShaped> mixin;
+
+ mixin.shape({1, 2, 3, 4});
+ ASSERT_NE(mixin.shape(), nullptr);
+ ASSERT_EQ(mixin.shape()->rank(), 4);
+ ASSERT_EQ(mixin.shape()->dim(0), 1);
+ ASSERT_EQ(mixin.shape()->dim(1), 2);
+ ASSERT_EQ(mixin.shape()->dim(2), 3);
+ ASSERT_EQ(mixin.shape()->dim(3), 4);
+}
+
TEST(GraphTest, create_and_destroy_node)
{
auto g = loco::make_graph();