X-Git-Url: http://review.tizen.org/git/?a=blobdiff_plain;f=inference-engine%2Fsrc%2Fvpu%2Fgraph_transformer%2Fsrc%2Fmodel%2Fdata_desc.cpp;h=9001946c35a9e76b29c6f328afe5943d2414ee65;hb=0923303e0201c5b59386ab146d0e30b2ef79272d;hp=d97efe269d2c5b972e2fa11feade0701abb22872;hpb=ba6e22b1b5ee4cbefcc30e8d9493cddb0bb3dfdf;p=platform%2Fupstream%2Fdldt.git diff --git a/inference-engine/src/vpu/graph_transformer/src/model/data_desc.cpp b/inference-engine/src/vpu/graph_transformer/src/model/data_desc.cpp index d97efe2..9001946 100644 --- a/inference-engine/src/vpu/graph_transformer/src/model/data_desc.cpp +++ b/inference-engine/src/vpu/graph_transformer/src/model/data_desc.cpp @@ -44,6 +44,8 @@ DimsOrder DimsOrder::HCW = DimsOrder::fromCode(0x231); DimsOrder DimsOrder::NCHW = DimsOrder::fromCode(0x4321); DimsOrder DimsOrder::NHWC = DimsOrder::fromCode(0x4213); DimsOrder DimsOrder::NHCW = DimsOrder::fromCode(0x4231); +DimsOrder DimsOrder::NCDHW = DimsOrder::fromCode(0x43521); +DimsOrder DimsOrder::NDHWC = DimsOrder::fromCode(0x45213); namespace { @@ -109,12 +111,18 @@ DimsOrder DimsOrder::fromNumDims(int numDims) { return DimsOrder::C; } else if (numDims == 2) { return DimsOrder::NC; + } else if (numDims == 3) { + return DimsOrder::CHW; + } else if (numDims == 4) { + return DimsOrder::NCHW; + } else if (numDims == 5) { + return DimsOrder::NCDHW; } else { return DimsOrder::fromCode(maskOrder(FULL_ORDER_DEFAULT, numDims)); } } -DimsOrder DimsOrder::fromPermutation(const SmallVector& perm) { +DimsOrder DimsOrder::fromPermutation(const DimVector& perm) { StorageOrder64 code = 0; for (int sh = 0, i = 0; i < perm.size(); i++, sh += 4) { @@ -124,6 +132,20 @@ DimsOrder DimsOrder::fromPermutation(const SmallVector& perm) return DimsOrder::fromCode(code); } +DimsOrder DimsOrder::fromLayout(ie::Layout const& layout) { + switch (layout) { + case ie::Layout::C : return DimsOrder::C; + case ie::Layout::NC : return DimsOrder::NC; + case ie::Layout::CHW : return DimsOrder::CHW; + case ie::Layout::NCHW : return DimsOrder::NCHW; + case ie::Layout::NHWC : return DimsOrder::NHWC; + case ie::Layout::NCDHW : return DimsOrder::NCDHW; + case ie::Layout::NDHWC : return DimsOrder::NDHWC; + default: + VPU_THROW_EXCEPTION << "Unsupported layout " << layout; + } +} + int DimsOrder::numDims() const { int out = 0; @@ -262,7 +284,8 @@ void printTo(std::ostream& os, DimsOrder order) { {1, 'W'}, {2, 'H'}, {3, 'C'}, - {4, 'N'} + {4, 'N'}, + {5, 'D'} }); auto code = order.code(); @@ -289,6 +312,17 @@ void printTo(std::ostream& os, DimsOrder order) { } // +// Dim +// + +int dimToIeInd(vpu::Dim const& dim, int numDims) { + IE_ASSERT(1 <= numDims && numDims <= 8); + auto dimsOrder = DimsOrder::fromNumDims(numDims); + int dimInd = dimsOrder.dimInd(dim); + return (numDims - 1) - dimInd; +} + +// // DataDesc // @@ -297,22 +331,7 @@ DataDesc::DataDesc(const ie::TensorDesc& ieDesc) { // Parse precision // - switch (ieDesc.getPrecision()) { - case ie::Precision::U8: - _type = DataType::U8; - break; - case ie::Precision::I8: - _type = DataType::I8; - break; - case ie::Precision::FP16: - _type = DataType::FP16; - break; - case ie::Precision::FP32: - _type = DataType::FP32; - break; - default: - VPU_THROW_EXCEPTION << "Unsupported precision " << ieDesc.getPrecision().name(); - } + _type = fromIEPrecision(ieDesc.getPrecision()); // // Parse dimensions and layout @@ -342,10 +361,14 @@ int DataDesc::elemSize() const { switch (_type) { case DataType::U8: return sizeof(uint8_t); + case DataType::I8: + return sizeof(int8_t); case DataType::FP16: return sizeof(fp16_t); case DataType::FP32: return sizeof(float); + case DataType::S32: + return sizeof(int32_t); default: VPU_THROW_EXCEPTION << "Unknown data type " << _type; } @@ -372,6 +395,62 @@ void DataDesc::reorder(DimsOrder dimsOrder) { _dimsOrder = dimsOrder; } +ie::TensorDesc DataDesc::toTensorDesc() const { + ie::TensorDesc desc; + + switch (this->type()) { + case DataType::FP16: + desc.setPrecision(ie::Precision::FP16); + break; + case DataType::FP32: + desc.setPrecision(ie::Precision::FP32); + break; + case DataType::I8: + desc.setPrecision(ie::Precision::I8); + break; + case DataType::U8: + desc.setPrecision(ie::Precision::U8); + break; + case DataType::S32: + desc.setPrecision(ie::Precision::I32); + break; + default: + desc.setPrecision(ie::Precision::UNSPECIFIED); + } + + ie::SizeVector dims{}; + + DataDesc descCopy = *this; + descCopy.reorder(DimsOrder::fromNumDims(this->numDims())); + auto perm = descCopy.dimsOrder().toPermutation(); + std::reverse(perm.begin(), perm.end()); + for (auto &p : perm) { + dims.push_back(descCopy.dim(p)); + } + + desc.setDims(dims); + + if (DimsOrder::C == this->dimsOrder()) { + desc.setLayout(ie::Layout::C); + } else if (DimsOrder::NC == this->dimsOrder()) { + desc.setLayout(ie::Layout::NC); + } else if (DimsOrder::CHW == this->dimsOrder()) { + desc.setLayout(ie::Layout::CHW); + } else if (DimsOrder::NCHW == this->dimsOrder()) { + desc.setLayout(ie::Layout::NCHW); + } else if (DimsOrder::NHWC == this->dimsOrder()) { + desc.setLayout(ie::Layout::NHWC); + } else if (DimsOrder::NCDHW == this->dimsOrder()) { + desc.setLayout(ie::Layout::NCDHW); + } else if (DimsOrder::NDHWC == this->dimsOrder()) { + desc.setLayout(ie::Layout::NDHWC); + } else { + desc.setLayout(ie::Layout::BLOCKED); + } + + return desc; +} + void printTo(std::ostream& os, const DataDesc& desc) { os << "[" << std::endl; @@ -409,6 +488,35 @@ StridesRequirement StridesRequirement::compact() { return reqs; } +StridesRequirement StridesRequirement::fixed(const std::vector& strides, const DataDesc& desc) { + StridesRequirement reqs; + + const auto dims = desc.dims(); + const auto dimsOrder = desc.dimsOrder(); + const auto dimOrderVec = dimsOrder.toPermutation(); + auto setStride = [&] (Dim d, int val) { + IE_ASSERT(dimsOrder.hasDim(d)); + + auto perm = dimsOrder.toPermutation(); + auto idx = dimsOrder.dimInd(d); + + auto minStrideVal = idx == 0 ? desc.elemSize() : reqs._fixedStrides[perm[idx - 1]] * dims[perm[idx - 1]]; + IE_ASSERT(val >= minStrideVal); + + reqs._fixedStrides.set(d, val); + }; + + for (const auto& dim : dimOrderVec) { + const auto idx = dimToIeInd(dim, dims.size()); + setStride(dim, strides[idx]); + } + + for (int i = 0; i < MAX_DIMS_64; ++i) { + reqs.add(i, DimStride::Fixed); + } + return reqs; +} + void printTo(std::ostream& os, const StridesRequirement& reqs) { os << "[" << std::endl; @@ -457,12 +565,16 @@ DimValues calcStrides(const DataDesc& desc, const StridesRequirement& reqs) { auto perm = desc.dimsOrder().toPermutation(); IE_ASSERT(!perm.empty()); - strides.set(perm[0], desc.elemSize()); - strides.set(perm[0], applyStrideRequirement(strides[perm[0]], 0, reqs)); + strides = reqs.fixedStrides(); + + if (strides.empty()) { + strides.set(perm[0], desc.elemSize()); + strides.set(perm[0], applyStrideRequirement(strides[perm[0]], 0, reqs)); - for (int i = 1; i < perm.size(); i++) { - strides.set(perm[i], strides[perm[i - 1]] * desc.dim(perm[i - 1])); - strides.set(perm[i], applyStrideRequirement(strides[perm[i]], i, reqs)); + for (std::size_t i = 1; i < perm.size(); i++) { + strides.set(perm[i], strides[perm[i - 1]] * desc.dim(perm[i - 1])); + strides.set(perm[i], applyStrideRequirement(strides[perm[i]], i, reqs)); + } } return strides; @@ -472,7 +584,8 @@ bool checkStride( const DimValues& strides, const DataDesc& desc, int ind, - DimStride req) { + const StridesRequirement& reqs) { + const auto req = reqs.get(ind); if (req == DimStride::Any) { return true; } @@ -496,6 +609,10 @@ bool checkStride( if (strideVal % STRIDE_ALIGNMENT != 0) { return false; } + } else if (req == DimStride::Fixed) { + if (strideVal != reqs.getFixedStride(perm[ind])) { + return false; + } } else { VPU_THROW_EXCEPTION << "Unsupported stride requirement : " << req; } @@ -511,7 +628,7 @@ bool checkStrides( IE_ASSERT(!perm.empty()); for (int i = 0; i < perm.size(); i++) { - if (!checkStride(strides, desc, i, reqs.get(i))) { + if (!checkStride(strides, desc, i, reqs)) { return false; } } @@ -524,4 +641,15 @@ int calcTotalByteSize(const DataDesc& desc, const DimValues& strides) { return strides[perm.back()] * desc.dim(perm.back()); } +DataType fromIEPrecision(const InferenceEngine::Precision& precision) { + switch (precision) { + case InferenceEngine::Precision::U8: return DataType::U8; + case InferenceEngine::Precision::I8: return DataType::I8; + case InferenceEngine::Precision::I32: return DataType::S32; + case InferenceEngine::Precision::FP16: return DataType::FP16; + case InferenceEngine::Precision::FP32: return DataType::FP32; + default: VPU_THROW_EXCEPTION << precision << " isn't supported"; + } +} + } // namespace vpu