From 2bf5fd2b5e5c032d0c1803b271c2462e171e5d40 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Sun, 1 Dec 2019 07:41:00 -0800 Subject: [PATCH] [Runtime] Make ADTObject POD container type (#4346) --- include/tvm/runtime/container.h | 279 ++++++++++++++++++++++++++++++++++++++++ include/tvm/runtime/memory.h | 79 +++++++++++- include/tvm/runtime/vm.h | 29 ----- src/runtime/vm/object.cc | 29 ++--- src/runtime/vm/vm.cc | 21 ++- tests/cpp/container_test.cc | 130 ++++++++++++++++++- 6 files changed, 498 insertions(+), 69 deletions(-) create mode 100644 include/tvm/runtime/container.h diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h new file mode 100644 index 0000000..2714ac2 --- /dev/null +++ b/include/tvm/runtime/container.h @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container.h + * \brief Common POD(plain old data) container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_H_ +#define TVM_RUNTIME_CONTAINER_H_ +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Base template for classes with array like memory layout. + * + * It provides general methods to access the memory. The memory + * layout is ArrayType + [ElemType]. The alignment of ArrayType + * and ElemType is handled by the memory allocator. + * + * \tparam ArrayType The array header type, contains object specific metadata. + * \tparam ElemType The type of objects stored in the array right after + * ArrayType. + * + * \code + * // Example usage of the template to define a simple array wrapper + * class ArrayObj : public InplaceArrayBase { + * public: + * // Wrap EmplaceInit to initialize the elements + * template + * void Init(Iterator begin, Iterator end) { + * size_t num_elems = std::distance(begin, end); + * auto it = begin; + * this->size = 0; + * for (size_t i = 0; i < num_elems; ++i) { + * InplaceArrayBase::EmplaceInit(i, *it++); + * this->size++; + * } + * } + * } + * + * void test_function() { + * vector fields; + * auto ptr = make_inplace_array_object(fields.size()); + * ptr->Init(fields.begin(), fields.end()); + * + * // Access the 0th element in the array. + * assert(ptr->operator[](0) == fields[0]); + * } + * + * \endcode + */ +template +class InplaceArrayBase { + public: + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Const reference to ElemType at the index. + */ + const ElemType& operator[](size_t idx) const { + size_t size = Self()->GetSize(); + CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Reference to ElemType at the index. + */ + ElemType& operator[](size_t idx) { + size_t size = Self()->GetSize(); + CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Destroy the Inplace Array Base object + */ + ~InplaceArrayBase() { + if (!(std::is_standard_layout::value && + std::is_trivial::value)) { + size_t size = Self()->GetSize(); + for (size_t i = 0; i < size; ++i) { + ElemType* fp = reinterpret_cast(AddressOf(i)); + fp->ElemType::~ElemType(); + } + } + } + + protected: + /*! + * \brief Construct a value in place with the arguments. + * + * \tparam Args Type parameters of the arguments. + * \param idx Index of the element. + * \param args Arguments to construct the new value. + * + * \note Please make sure ArrayType::GetSize returns 0 before first call of + * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. + */ + template + void EmplaceInit(size_t idx, Args&&... args) { + void* field_ptr = AddressOf(idx); + new (field_ptr) ElemType(std::forward(args)...); + } + + private: + /*! + * \brief Return the self object for the array. + * + * \return Pointer to ArrayType. + */ + inline ArrayType* Self() const { + return static_cast(const_cast(this)); + } + + /*! + * \brief Return the raw pointer to the element at idx. + * + * \param idx The index of the element. + * \return Raw pointer to the element. + */ + void* AddressOf(size_t idx) const { + static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && + sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); + + size_t kDataStart = sizeof(ArrayType); + ArrayType* self = Self(); + char* data_start = reinterpret_cast(self) + kDataStart; + return data_start + idx * sizeof(ElemType); + } +}; + +/*! \brief An object representing a structure or enumeration. */ +class ADTObj : public Object, public InplaceArrayBase { + public: + /*! \brief The tag representing the constructor used. */ + uint32_t tag; + /*! \brief Number of fields in the ADT object. */ + uint32_t size; + // The fields of the structure follows directly in memory. + + static constexpr const uint32_t _type_index = TypeIndex::kVMADT; + static constexpr const char* _type_key = "vm.ADT"; + TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); + + private: + /*! + * \return The number of elements in the array. + */ + size_t GetSize() const { return size; } + + /*! + * \brief Initialize the elements in the array. + * + * \tparam Iterator Iterator type of the array. + * \param begin The begin iterator. + * \param end The end iterator. + */ + template + void Init(Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + this->size = 0; + auto it = begin; + for (size_t i = 0; i < num_elems; ++i) { + InplaceArrayBase::EmplaceInit(i, *it++); + // Only increment size after the initialization succeeds + this->size++; + } + } + + friend class ADT; + friend class InplaceArrayBase; +}; + +/*! \brief reference to algebraic data type objects. */ +class ADT : public ObjectRef { + public: + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param fields The fields of the ADT object. + * \return The constructed ADT object reference. + */ + ADT(uint32_t tag, std::vector fields) + : ADT(tag, fields.begin(), fields.end()){}; + + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param begin The begin iterator to the start of the fields array. + * \param end The end iterator to the end of the fields array. + * \return The constructed ADT object reference. + */ + template + ADT(uint32_t tag, Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + auto ptr = make_inplace_array_object(num_elems); + ptr->tag = tag; + ptr->Init(begin, end); + data_ = std::move(ptr); + } + + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param init The initializer list of fields. + * \return The constructed ADT object reference. + */ + ADT(uint32_t tag, std::initializer_list init) + : ADT(tag, init.begin(), init.end()){}; + + /*! + * \brief Access element at index. + * + * \param idx The array index + * \return const ObjectRef + */ + const ObjectRef& operator[](size_t idx) const { + return operator->()->operator[](idx); + } + + /*! + * \brief Return the ADT tag. + */ + size_t tag() const { return operator->()->tag; } + + /*! + * \brief Return the number of fields. + */ + size_t size() const { return operator->()->size; } + + /*! + * \brief Construct a tuple object. + * + * \tparam Args Type params of tuple feilds. + * \param args Tuple fields. + * \return ADT The tuple object reference. + */ + template + static ADT Tuple(Args&&... args) { + return ADT(0, std::forward(args)...); + } + + TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_H_ diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 07e22d7..63f3e4e 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -23,6 +23,7 @@ #ifndef TVM_RUNTIME_MEMORY_H_ #define TVM_RUNTIME_MEMORY_H_ +#include #include #include #include "object.h" @@ -33,7 +34,7 @@ namespace runtime { * \brief Allocate an object using default allocator. * \param args arguments to the constructor. * \tparam T the node type. - * \return The NodePtr to the allocated object. + * \return The ObjectPtr to the allocated object. */ template inline ObjectPtr make_object(Args&&... args); @@ -67,13 +68,33 @@ class ObjAllocatorBase { inline ObjectPtr make_object(Args&&... args) { using Handler = typename Derived::template Handler; static_assert(std::is_base_of::value, - "make_node can only be used to create NodeBase"); + "make can only be used to create Object"); T* ptr = Handler::New(static_cast(this), std::forward(args)...); ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); } + + /*! + * \tparam ArrayType The type to be allocated. + * \tparam ElemType The type of array element. + * \tparam Args The constructor signature. + * \param num_elems The number of array elements. + * \param args The arguments. + */ + template + inline ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { + using Handler = typename Derived::template ArrayHandler; + static_assert(std::is_base_of::value, + "make_inplace_array can only be used to create Object"); + ArrayType* ptr = Handler::New(static_cast(this), + num_elems, + std::forward(args)...); + ptr->type_index_ = ArrayType::RuntimeTypeIndex(); + ptr->deleter_ = Handler::Deleter(); + return ObjectPtr(ptr); + } }; // Simple allocator that uses new/delete. @@ -123,6 +144,54 @@ class SimpleObjAllocator : delete reinterpret_cast(tptr); } }; + + // Array handler that uses new/delete. + template + class ArrayHandler { + public: + using StorageType = typename std::aligned_union::type; + + template + static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { + // NOTE: the first argument is not needed for ArrayObjAllocator + // It is reserved for special allocators that needs to recycle + // the object to itself (e.g. in the case of object pool). + // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + size_t factor = sizeof(ArrayType) / sizeof(ElemType); + num_elems = (num_elems + factor - 1) / factor; + StorageType* data = new StorageType[num_elems+1]; + new (data) ArrayType(std::forward(args)...); + return reinterpret_cast(data); + } + + static Object::FDeleter Deleter() { + return Deleter_; + } + + private: + static void Deleter_(Object* objptr) { + // NOTE: this is important to cast back to ArrayType* + // because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + ArrayType* tptr = static_cast(objptr); + // It is important to do tptr->ArrayType::~ArrayType(), + // so that we explicitly call the specific destructor + // instead of tptr->~ArrayType(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->ArrayType::~ArrayType(); + StorageType* p = reinterpret_cast(tptr); + delete []p; + } + }; }; template @@ -130,6 +199,12 @@ inline ObjectPtr make_object(Args&&... args) { return SimpleObjAllocator().make_object(std::forward(args)...); } +template +inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&... args) { + return SimpleObjAllocator().make_inplace_array( + num_elems, std::forward(args)...); +} + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_MEMORY_H_ diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index f7188e4..59e9ae8 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -55,35 +55,6 @@ class Tensor : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj); }; - -/*! \brief An object representing a structure or enumeration. */ -class ADTObj : public Object { - public: - /*! \brief The tag representing the constructor used. */ - size_t tag; - /*! \brief The fields of the structure. */ - std::vector fields; - - static constexpr const uint32_t _type_index = TypeIndex::kVMADT; - static constexpr const char* _type_key = "vm.ADT"; - TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); -}; - -/*! \brief reference to algebraic data type objects. */ -class ADT : public ObjectRef { - public: - ADT(size_t tag, std::vector fields); - - /*! - * \brief construct a tuple object. - * \param fields The fields of the tuple. - * \return The constructed tuple type. - */ - static ADT Tuple(std::vector fields); - - TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); -}; - /*! \brief An object representing a closure. */ class ClosureObj : public Object { public: diff --git a/src/runtime/vm/object.cc b/src/runtime/vm/object.cc index 12edf51..988ba5d 100644 --- a/src/runtime/vm/object.cc +++ b/src/runtime/vm/object.cc @@ -22,6 +22,7 @@ * \brief VM related objects. */ #include +#include #include #include #include @@ -39,17 +40,6 @@ Tensor::Tensor(NDArray data) { data_ = std::move(ptr); } -ADT::ADT(size_t tag, std::vector fields) { - auto ptr = make_object(); - ptr->tag = tag; - ptr->fields = std::move(fields); - data_ = std::move(ptr); -} - -ADT ADT::Tuple(std::vector fields) { - return ADT(0, fields); -} - Closure::Closure(size_t func_index, std::vector free_vars) { auto ptr = make_object(); ptr->func_index = func_index; @@ -69,17 +59,15 @@ TVM_REGISTER_GLOBAL("_vmobj.GetTensorData") TVM_REGISTER_GLOBAL("_vmobj.GetADTTag") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; - const auto* cell = obj.as(); - CHECK(cell != nullptr); - *rv = static_cast(cell->tag); + const auto& adt = Downcast(obj); + *rv = static_cast(adt.tag()); }); TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; - const auto* cell = obj.as(); - CHECK(cell != nullptr); - *rv = static_cast(cell->fields.size()); + const auto& adt = Downcast(obj); + *rv = static_cast(adt.size()); }); @@ -87,10 +75,9 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; int idx = args[1]; - const auto* cell = obj.as(); - CHECK(cell != nullptr); - CHECK_LT(idx, cell->fields.size()); - *rv = cell->fields[idx]; + const auto& adt = Downcast(obj); + CHECK_LT(idx, adt.size()); + *rv = adt[idx]; }); TVM_REGISTER_GLOBAL("_vmobj.Tensor") diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 333dd1e..41fe71a 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -755,7 +756,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, size_t arity = 0; for (Index i = 0; i < arg_count; i++) { if (const auto* obj = args[i].as()) { - arity += obj->fields.size(); + arity += obj->size; } else { ++arity; } @@ -767,7 +768,8 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, int idx = 0; for (Index i = 0; i < arg_count; i++) { if (const auto* dt_cell = args[i].as()) { - for (auto obj : dt_cell->fields) { + for (size_t fi = 0; fi < dt_cell->size; ++fi) { + auto obj = (*dt_cell)[fi]; const auto* tensor = obj.as(); CHECK(tensor != nullptr); setter(idx++, tensor->data); @@ -924,23 +926,16 @@ void VirtualMachine::RunLoop() { } case Opcode::GetField: { auto object = ReadRegister(instr.object); - const auto* tuple = object.as(); - CHECK(tuple != nullptr) - << "Object is not data type object, register " << instr.object << ", Object tag " - << object->type_index(); - auto field = tuple->fields[instr.field_index]; + const auto& tuple = Downcast(object); + auto field = tuple[instr.field_index]; WriteRegister(instr.dst, field); pc_++; goto main_loop; } case Opcode::GetTag: { auto object = ReadRegister(instr.get_tag.object); - const auto* data = object.as(); - CHECK(data != nullptr) - << "Object is not data type object, register " - << instr.get_tag.object << ", Object tag " - << object->type_index(); - auto tag = data->tag; + const auto& adt = Downcast(object); + auto tag = adt.tag(); auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); reinterpret_cast(tag_tensor->data)[0] = tag; WriteRegister(instr.dst, Tensor(tag_tensor)); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 005e159..4428642 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -17,11 +17,132 @@ * under the License. */ -#include -#include #include #include #include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::runtime; + +class TestErrorSwitch { + public: + // Need this so that destructor of temporary objects don't interrupt our + // testing. + TestErrorSwitch(const TestErrorSwitch& other) + : should_fail(other.should_fail) { + const_cast(other).should_fail = false; + } + + TestErrorSwitch(bool fail_flag) : should_fail{fail_flag} {} + bool should_fail{false}; + + ~TestErrorSwitch() { + if (should_fail) { + exit(1); + } + } +}; + +class TestArrayObj : public Object, + public InplaceArrayBase { + public: + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "test.TestArrayObj"; + TVM_DECLARE_FINAL_OBJECT_INFO(TestArrayObj, Object); + uint32_t size; + + size_t GetSize() const { return size; } + + template + void Init(Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + this->size = 0; + auto it = begin; + for (size_t i = 0; i < num_elems; ++i) { + InplaceArrayBase::EmplaceInit(i, *it++); + if (i == 1) { + throw std::bad_alloc(); + } + // Only increment size after the initialization succeeds + this->size++; + } + } + + template + void WrongInit(Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + this->size = num_elems; + auto it = begin; + for (size_t i = 0; i < num_elems; ++i) { + InplaceArrayBase::EmplaceInit(i, *it++); + if (i == 1) { + throw std::bad_alloc(); + } + } + } + + friend class InplaceArrayBase; +}; + +TEST(ADT, Constructor) { + std::vector fields; + auto f1 = ADT::Tuple(fields); + auto f2 = ADT::Tuple(fields); + ADT v1{1, {f1, f2}}; + ASSERT_EQ(f1.tag(), 0); + ASSERT_EQ(f2.size(), 0); + ASSERT_EQ(v1.tag(), 1); + ASSERT_EQ(v1.size(), 2); + ASSERT_EQ(Downcast(v1[0]).tag(), 0); + ASSERT_EQ(Downcast(v1[1]).size(), 0); +} + +TEST(InplaceArrayBase, BadExceptionSafety) { + auto wrong_init = []() { + TestErrorSwitch f1{false}; + // WrongInit will set size to 3 so it will call destructor at index 1, which + // will exit with error status. + TestErrorSwitch f2{true}; + TestErrorSwitch f3{false}; + std::vector fields{f1, f2, f3}; + auto ptr = + make_inplace_array_object(fields.size()); + try { + ptr->WrongInit(fields.begin(), fields.end()); + } catch (...) { + } + // Call ~InplaceArrayBase + ptr.reset(); + // never reaches here. + exit(0); + }; + ASSERT_EXIT(wrong_init(), ::testing::ExitedWithCode(1), ""); +} + +TEST(InplaceArrayBase, ExceptionSafety) { + auto correct_init = []() { + TestErrorSwitch f1{false}; + // Init will fail at index 1, so destrucotr at index 1 should not be called + // since it's not initalized. + TestErrorSwitch f2{true}; + std::vector fields{f1, f2}; + auto ptr = + make_inplace_array_object(fields.size()); + try { + ptr->Init(fields.begin(), fields.end()); + } catch (...) { + } + // Call ~InplaceArrayBase + ptr.reset(); + // Skip the destructors of f1, f2, and fields + exit(0); + }; + ASSERT_EXIT(correct_init(), ::testing::ExitedWithCode(0), ""); +} TEST(Array, Expr) { using namespace tvm; @@ -99,11 +220,12 @@ TEST(Map, Iterator) { using namespace tvm; Expr a = 1, b = 2; Map map1{{a, b}}; - std::unordered_map map2(map1.begin(), map1.end()); + std::unordered_map map2(map1.begin(), + map1.end()); CHECK(map2[a].as()->value == 2); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); -- 2.7.4