Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / model / data_desc.cpp
index d97efe2..9001946 100644 (file)
@@ -44,6 +44,8 @@ DimsOrder DimsOrder::HCW = DimsOrder::fromCode(0x231);
 DimsOrder DimsOrder::NCHW = DimsOrder::fromCode(0x4321);
 DimsOrder DimsOrder::NHWC = DimsOrder::fromCode(0x4213);
 DimsOrder DimsOrder::NHCW = DimsOrder::fromCode(0x4231);
+DimsOrder DimsOrder::NCDHW = DimsOrder::fromCode(0x43521);
+DimsOrder DimsOrder::NDHWC = DimsOrder::fromCode(0x45213);
 
 namespace {
 
@@ -109,12 +111,18 @@ DimsOrder DimsOrder::fromNumDims(int numDims) {
         return DimsOrder::C;
     } else if (numDims == 2) {
         return DimsOrder::NC;
+    } else if (numDims == 3) {
+        return DimsOrder::CHW;
+    } else if (numDims == 4) {
+        return DimsOrder::NCHW;
+    } else if (numDims == 5) {
+        return DimsOrder::NCDHW;
     } else {
         return DimsOrder::fromCode(maskOrder(FULL_ORDER_DEFAULT, numDims));
     }
 }
 
-DimsOrder DimsOrder::fromPermutation(const SmallVector<Dim, MAX_DIMS_64>& perm) {
+DimsOrder DimsOrder::fromPermutation(const DimVector& perm) {
     StorageOrder64 code = 0;
 
     for (int sh = 0, i = 0; i < perm.size(); i++, sh += 4) {
@@ -124,6 +132,20 @@ DimsOrder DimsOrder::fromPermutation(const SmallVector<Dim, MAX_DIMS_64>& perm)
     return DimsOrder::fromCode(code);
 }
 
+DimsOrder DimsOrder::fromLayout(ie::Layout const& layout) {
+    switch (layout) {
+    case ie::Layout::C     : return DimsOrder::C;
+    case ie::Layout::NC    : return DimsOrder::NC;
+    case ie::Layout::CHW   : return DimsOrder::CHW;
+    case ie::Layout::NCHW  : return DimsOrder::NCHW;
+    case ie::Layout::NHWC  : return DimsOrder::NHWC;
+    case ie::Layout::NCDHW : return DimsOrder::NCDHW;
+    case ie::Layout::NDHWC : return DimsOrder::NDHWC;
+    default:
+        VPU_THROW_EXCEPTION << "Unsupported layout " << layout;
+    }
+}
+
 int DimsOrder::numDims() const {
     int out = 0;
 
@@ -262,7 +284,8 @@ void printTo(std::ostream& os, DimsOrder order) {
         {1, 'W'},
         {2, 'H'},
         {3, 'C'},
-        {4, 'N'}
+        {4, 'N'},
+        {5, 'D'}
     });
 
     auto code = order.code();
@@ -289,6 +312,17 @@ void printTo(std::ostream& os, DimsOrder order) {
 }
 
 //
+// Dim
+//
+
+int dimToIeInd(vpu::Dim const& dim, int numDims) {
+    IE_ASSERT(1 <= numDims && numDims <= 8);
+    auto dimsOrder =  DimsOrder::fromNumDims(numDims);
+    int dimInd = dimsOrder.dimInd(dim);
+    return (numDims - 1) - dimInd;
+}
+
+//
 // DataDesc
 //
 
@@ -297,22 +331,7 @@ DataDesc::DataDesc(const ie::TensorDesc& ieDesc) {
     // Parse precision
     //
 
-    switch (ieDesc.getPrecision()) {
-    case ie::Precision::U8:
-        _type = DataType::U8;
-        break;
-    case ie::Precision::I8:
-         _type = DataType::I8;
-        break;
-    case ie::Precision::FP16:
-        _type = DataType::FP16;
-        break;
-    case ie::Precision::FP32:
-        _type = DataType::FP32;
-        break;
-    default:
-        VPU_THROW_EXCEPTION << "Unsupported precision " << ieDesc.getPrecision().name();
-    }
+    _type = fromIEPrecision(ieDesc.getPrecision());
 
     //
     // Parse dimensions and layout
@@ -342,10 +361,14 @@ int DataDesc::elemSize() const {
     switch (_type) {
     case DataType::U8:
         return sizeof(uint8_t);
+    case DataType::I8:
+        return sizeof(int8_t);
     case DataType::FP16:
         return sizeof(fp16_t);
     case DataType::FP32:
         return sizeof(float);
+    case DataType::S32:
+        return sizeof(int32_t);
     default:
         VPU_THROW_EXCEPTION << "Unknown data type " << _type;
     }
@@ -372,6 +395,62 @@ void DataDesc::reorder(DimsOrder dimsOrder) {
     _dimsOrder = dimsOrder;
 }
 
+ie::TensorDesc DataDesc::toTensorDesc() const {
+    ie::TensorDesc desc;
+
+    switch (this->type()) {
+        case DataType::FP16:
+            desc.setPrecision(ie::Precision::FP16);
+            break;
+        case DataType::FP32:
+            desc.setPrecision(ie::Precision::FP32);
+            break;
+        case DataType::I8:
+            desc.setPrecision(ie::Precision::I8);
+            break;
+        case DataType::U8:
+            desc.setPrecision(ie::Precision::U8);
+            break;
+        case DataType::S32:
+            desc.setPrecision(ie::Precision::I32);
+            break;
+        default:
+            desc.setPrecision(ie::Precision::UNSPECIFIED);
+    }
+
+    ie::SizeVector dims{};
+
+    DataDesc descCopy = *this;
+    descCopy.reorder(DimsOrder::fromNumDims(this->numDims()));
+    auto perm = descCopy.dimsOrder().toPermutation();
+    std::reverse(perm.begin(), perm.end());
+    for (auto &p : perm) {
+        dims.push_back(descCopy.dim(p));
+    }
+
+    desc.setDims(dims);
+
+    if (DimsOrder::C == this->dimsOrder()) {
+        desc.setLayout(ie::Layout::C);
+    } else if (DimsOrder::NC == this->dimsOrder()) {
+        desc.setLayout(ie::Layout::NC);
+    } else if (DimsOrder::CHW == this->dimsOrder()) {
+        desc.setLayout(ie::Layout::CHW);
+    } else if (DimsOrder::NCHW == this->dimsOrder()) {
+        desc.setLayout(ie::Layout::NCHW);
+    } else if (DimsOrder::NHWC == this->dimsOrder()) {
+        desc.setLayout(ie::Layout::NHWC);
+    } else if (DimsOrder::NCDHW == this->dimsOrder()) {
+        desc.setLayout(ie::Layout::NCDHW);
+    } else if (DimsOrder::NDHWC == this->dimsOrder()) {
+        desc.setLayout(ie::Layout::NDHWC);
+    } else {
+        desc.setLayout(ie::Layout::BLOCKED);
+    }
+
+    return desc;
+}
+
 void printTo(std::ostream& os, const DataDesc& desc) {
     os << "[" << std::endl;
 
@@ -409,6 +488,35 @@ StridesRequirement StridesRequirement::compact() {
     return reqs;
 }
 
+StridesRequirement StridesRequirement::fixed(const std::vector<int>& strides, const DataDesc& desc) {
+    StridesRequirement reqs;
+
+    const auto dims = desc.dims();
+    const auto dimsOrder = desc.dimsOrder();
+    const auto dimOrderVec = dimsOrder.toPermutation();
+    auto setStride = [&] (Dim d, int val) {
+        IE_ASSERT(dimsOrder.hasDim(d));
+
+        auto perm = dimsOrder.toPermutation();
+        auto idx = dimsOrder.dimInd(d);
+
+        auto minStrideVal = idx == 0 ? desc.elemSize() : reqs._fixedStrides[perm[idx - 1]] * dims[perm[idx - 1]];
+        IE_ASSERT(val >= minStrideVal);
+
+        reqs._fixedStrides.set(d, val);
+    };
+
+    for (const auto& dim : dimOrderVec) {
+        const auto idx = dimToIeInd(dim, dims.size());
+        setStride(dim, strides[idx]);
+    }
+
+    for (int i = 0; i < MAX_DIMS_64; ++i) {
+        reqs.add(i, DimStride::Fixed);
+    }
+    return reqs;
+}
+
 void printTo(std::ostream& os, const StridesRequirement& reqs) {
     os << "[" << std::endl;
 
@@ -457,12 +565,16 @@ DimValues calcStrides(const DataDesc& desc, const StridesRequirement& reqs) {
     auto perm = desc.dimsOrder().toPermutation();
     IE_ASSERT(!perm.empty());
 
-    strides.set(perm[0], desc.elemSize());
-    strides.set(perm[0], applyStrideRequirement(strides[perm[0]], 0, reqs));
+    strides = reqs.fixedStrides();
+
+    if (strides.empty()) {
+        strides.set(perm[0], desc.elemSize());
+        strides.set(perm[0], applyStrideRequirement(strides[perm[0]], 0, reqs));
 
-    for (int i = 1; i < perm.size(); i++) {
-        strides.set(perm[i], strides[perm[i - 1]] * desc.dim(perm[i - 1]));
-        strides.set(perm[i], applyStrideRequirement(strides[perm[i]], i, reqs));
+        for (std::size_t i = 1; i < perm.size(); i++) {
+            strides.set(perm[i], strides[perm[i - 1]] * desc.dim(perm[i - 1]));
+            strides.set(perm[i], applyStrideRequirement(strides[perm[i]], i, reqs));
+        }
     }
 
     return strides;
@@ -472,7 +584,8 @@ bool checkStride(
         const DimValues& strides,
         const DataDesc& desc,
         int ind,
-        DimStride req) {
+        const StridesRequirement& reqs) {
+    const auto req = reqs.get(ind);
     if (req == DimStride::Any) {
         return true;
     }
@@ -496,6 +609,10 @@ bool checkStride(
         if (strideVal % STRIDE_ALIGNMENT != 0) {
             return false;
         }
+    } else if (req == DimStride::Fixed) {
+        if (strideVal != reqs.getFixedStride(perm[ind])) {
+            return false;
+        }
     } else {
         VPU_THROW_EXCEPTION << "Unsupported stride requirement : " << req;
     }
@@ -511,7 +628,7 @@ bool checkStrides(
     IE_ASSERT(!perm.empty());
 
     for (int i = 0; i < perm.size(); i++) {
-        if (!checkStride(strides, desc, i, reqs.get(i))) {
+        if (!checkStride(strides, desc, i, reqs)) {
             return false;
         }
     }
@@ -524,4 +641,15 @@ int calcTotalByteSize(const DataDesc& desc, const DimValues& strides) {
     return strides[perm.back()] * desc.dim(perm.back());
 }
 
+DataType fromIEPrecision(const InferenceEngine::Precision& precision) {
+    switch (precision) {
+        case InferenceEngine::Precision::U8:   return DataType::U8;
+        case InferenceEngine::Precision::I8:   return DataType::I8;
+        case InferenceEngine::Precision::I32:  return DataType::S32;
+        case InferenceEngine::Precision::FP16: return DataType::FP16;
+        case InferenceEngine::Precision::FP32: return DataType::FP32;
+        default: VPU_THROW_EXCEPTION << precision << " isn't supported";
+    }
+}
+
 }  // namespace vpu