[IE][VPU][GT]: Use topological order in shape allocation (#1281)
authorAndrew Bakalin <andrew.bakalin@intel.com>
Fri, 17 Jul 2020 08:47:37 +0000 (11:47 +0300)
committerGitHub <noreply@github.com>
Fri, 17 Jul 2020 08:47:37 +0000 (11:47 +0300)
* Some pass creates datas duplicate with a different order from time to time (because of unordered_set usage). It leads to a different order in model->datas() list and affects the shape allocation process which relies on this order.
* Make shape allocation be relied on topological order of datas which is stable and doesn't depend on order datas creation during different passes.

inference-engine/src/vpu/graph_transformer/include/vpu/middleend/allocator/allocator.hpp
inference-engine/src/vpu/graph_transformer/include/vpu/model/data.hpp
inference-engine/src/vpu/graph_transformer/src/middleend/allocator/allocator.cpp
inference-engine/src/vpu/graph_transformer/src/middleend/passes/allocate_resources.cpp
inference-engine/src/vpu/graph_transformer/src/model/data.cpp

index 08fca0b..359ea80 100644 (file)
@@ -77,7 +77,7 @@ public:
      * Allocates memory for single data node
      */
     bool allocateData(const Data& data);
-    ShapeLocation allocateShape(Data& data);
+    ShapeLocation allocateShape(const Data& data);
     void freeData(const Data& data, DeallocationMode mode = DeallocationMode::JustFree);
 
     void selfCheck();
index 8c33d52..32609aa 100644 (file)
@@ -82,10 +82,19 @@ struct ShapeLocation final {
     int dimsOffset;
     Location stridesLocation;
     int stridesOffset;
+
+    bool operator==(const ShapeLocation& shapeLocation) const {
+        return std::tie(dimsLocation, dimsOffset, stridesLocation, stridesOffset) ==
+               std::tie(shapeLocation.dimsLocation, shapeLocation.dimsOffset, shapeLocation.stridesLocation, shapeLocation.stridesOffset);
+    }
+
+    bool operator!=(const ShapeLocation& shapeLocation) const {
+        return !(*this == shapeLocation);
+    }
 };
 
 static constexpr ShapeLocation defaultShapeLocation = {
-        Location::None, 0, Location::None, 0
+    Location::None, 0, Location::None, 0
 };
 
 //
@@ -251,6 +260,8 @@ public:
 
     void setShapeAllocationInfo(const ShapeLocation& shapeLocation);
 
+    bool isShapeAllocated() const;
+
     //
     // Backend utilities
     //
index 6084400..4ac2a76 100644 (file)
@@ -296,7 +296,7 @@ bool Allocator::allocateData(const Data& data) {
     return chunk->memType == memoryType;
 }
 
-ShapeLocation Allocator::allocateShape(Data& data) {
+ShapeLocation Allocator::allocateShape(const Data& data) {
     ShapeLocation shapeLocation;
 
     const auto dimsByteSize = data->desc().dimsByteSize();
index 91e0c1f..f852313 100644 (file)
@@ -206,9 +206,23 @@ AllocationResult runAllocator(const Model& model, EnableShapeAllocation enableSh
     //
 
     if (enableShapeAllocation == EnableShapeAllocation::YES) {
-        for (auto data : model->datas()) {
-            const auto shapeLocation = allocator.allocateShape(data);
-            data->setShapeAllocationInfo(shapeLocation);
+        for (const auto& stage : model->getStages()) {
+            const auto& allocateShape = [&allocator](const Data& data) {
+                if (!data->isShapeAllocated()) {
+                    const auto shapeLocation = allocator.allocateShape(data);
+                    data->setShapeAllocationInfo(shapeLocation);
+                }
+            };
+
+            for (const auto& input : stage->inputs()) {
+                allocateShape(input);
+            }
+            for (const auto& output : stage->outputs()) {
+                allocateShape(output);
+            }
+            for (const auto& tempBuffer : stage->tempBuffers()) {
+                allocateShape(tempBuffer);
+            }
         }
     }
 
index 72d9968..9f86a3f 100644 (file)
@@ -176,6 +176,10 @@ void DataNode::setShapeAllocationInfo(const ShapeLocation& shapeLocation) {
     _shapeLocation = shapeLocation;
 }
 
+bool DataNode::isShapeAllocated() const {
+    return _shapeLocation != defaultShapeLocation;
+}
+
 void DataNode::serializeBuffer(
         BlobSerializer& serializer) {
     serializeDescImpl(serializer, _desc, this->shapeLocation());