DimsOrder DimsOrder::NCHW = DimsOrder::fromCode(0x4321);
DimsOrder DimsOrder::NHWC = DimsOrder::fromCode(0x4213);
DimsOrder DimsOrder::NHCW = DimsOrder::fromCode(0x4231);
+DimsOrder DimsOrder::NCDHW = DimsOrder::fromCode(0x43521);
+DimsOrder DimsOrder::NDHWC = DimsOrder::fromCode(0x45213);
namespace {
return DimsOrder::C;
} else if (numDims == 2) {
return DimsOrder::NC;
+ } else if (numDims == 3) {
+ return DimsOrder::CHW;
+ } else if (numDims == 4) {
+ return DimsOrder::NCHW;
+ } else if (numDims == 5) {
+ return DimsOrder::NCDHW;
} else {
return DimsOrder::fromCode(maskOrder(FULL_ORDER_DEFAULT, numDims));
}
}
-DimsOrder DimsOrder::fromPermutation(const SmallVector<Dim, MAX_DIMS_64>& perm) {
+DimsOrder DimsOrder::fromPermutation(const DimVector& perm) {
StorageOrder64 code = 0;
for (int sh = 0, i = 0; i < perm.size(); i++, sh += 4) {
return DimsOrder::fromCode(code);
}
+DimsOrder DimsOrder::fromLayout(ie::Layout const& layout) {
+ switch (layout) {
+ case ie::Layout::C : return DimsOrder::C;
+ case ie::Layout::NC : return DimsOrder::NC;
+ case ie::Layout::CHW : return DimsOrder::CHW;
+ case ie::Layout::NCHW : return DimsOrder::NCHW;
+ case ie::Layout::NHWC : return DimsOrder::NHWC;
+ case ie::Layout::NCDHW : return DimsOrder::NCDHW;
+ case ie::Layout::NDHWC : return DimsOrder::NDHWC;
+ default:
+ VPU_THROW_EXCEPTION << "Unsupported layout " << layout;
+ }
+}
+
int DimsOrder::numDims() const {
int out = 0;
{1, 'W'},
{2, 'H'},
{3, 'C'},
- {4, 'N'}
+ {4, 'N'},
+ {5, 'D'}
});
auto code = order.code();
}
//
+// Dim
+//
+
+int dimToIeInd(vpu::Dim const& dim, int numDims) {
+ IE_ASSERT(1 <= numDims && numDims <= 8);
+ auto dimsOrder = DimsOrder::fromNumDims(numDims);
+ int dimInd = dimsOrder.dimInd(dim);
+ return (numDims - 1) - dimInd;
+}
+
+//
// DataDesc
//
// Parse precision
//
- switch (ieDesc.getPrecision()) {
- case ie::Precision::U8:
- _type = DataType::U8;
- break;
- case ie::Precision::I8:
- _type = DataType::I8;
- break;
- case ie::Precision::FP16:
- _type = DataType::FP16;
- break;
- case ie::Precision::FP32:
- _type = DataType::FP32;
- break;
- default:
- VPU_THROW_EXCEPTION << "Unsupported precision " << ieDesc.getPrecision().name();
- }
+ _type = fromIEPrecision(ieDesc.getPrecision());
//
// Parse dimensions and layout
switch (_type) {
case DataType::U8:
return sizeof(uint8_t);
+ case DataType::I8:
+ return sizeof(int8_t);
case DataType::FP16:
return sizeof(fp16_t);
case DataType::FP32:
return sizeof(float);
+ case DataType::S32:
+ return sizeof(int32_t);
default:
VPU_THROW_EXCEPTION << "Unknown data type " << _type;
}
_dimsOrder = dimsOrder;
}
+ie::TensorDesc DataDesc::toTensorDesc() const {
+ ie::TensorDesc desc;
+
+ switch (this->type()) {
+ case DataType::FP16:
+ desc.setPrecision(ie::Precision::FP16);
+ break;
+ case DataType::FP32:
+ desc.setPrecision(ie::Precision::FP32);
+ break;
+ case DataType::I8:
+ desc.setPrecision(ie::Precision::I8);
+ break;
+ case DataType::U8:
+ desc.setPrecision(ie::Precision::U8);
+ break;
+ case DataType::S32:
+ desc.setPrecision(ie::Precision::I32);
+ break;
+ default:
+ desc.setPrecision(ie::Precision::UNSPECIFIED);
+ }
+
+ ie::SizeVector dims{};
+
+ DataDesc descCopy = *this;
+ descCopy.reorder(DimsOrder::fromNumDims(this->numDims()));
+ auto perm = descCopy.dimsOrder().toPermutation();
+ std::reverse(perm.begin(), perm.end());
+ for (auto &p : perm) {
+ dims.push_back(descCopy.dim(p));
+ }
+
+ desc.setDims(dims);
+
+ if (DimsOrder::C == this->dimsOrder()) {
+ desc.setLayout(ie::Layout::C);
+ } else if (DimsOrder::NC == this->dimsOrder()) {
+ desc.setLayout(ie::Layout::NC);
+ } else if (DimsOrder::CHW == this->dimsOrder()) {
+ desc.setLayout(ie::Layout::CHW);
+ } else if (DimsOrder::NCHW == this->dimsOrder()) {
+ desc.setLayout(ie::Layout::NCHW);
+ } else if (DimsOrder::NHWC == this->dimsOrder()) {
+ desc.setLayout(ie::Layout::NHWC);
+ } else if (DimsOrder::NCDHW == this->dimsOrder()) {
+ desc.setLayout(ie::Layout::NCDHW);
+ } else if (DimsOrder::NDHWC == this->dimsOrder()) {
+ desc.setLayout(ie::Layout::NDHWC);
+ } else {
+ desc.setLayout(ie::Layout::BLOCKED);
+ }
+
+ return desc;
+}
+
void printTo(std::ostream& os, const DataDesc& desc) {
os << "[" << std::endl;
return reqs;
}
+StridesRequirement StridesRequirement::fixed(const std::vector<int>& strides, const DataDesc& desc) {
+ StridesRequirement reqs;
+
+ const auto dims = desc.dims();
+ const auto dimsOrder = desc.dimsOrder();
+ const auto dimOrderVec = dimsOrder.toPermutation();
+ auto setStride = [&] (Dim d, int val) {
+ IE_ASSERT(dimsOrder.hasDim(d));
+
+ auto perm = dimsOrder.toPermutation();
+ auto idx = dimsOrder.dimInd(d);
+
+ auto minStrideVal = idx == 0 ? desc.elemSize() : reqs._fixedStrides[perm[idx - 1]] * dims[perm[idx - 1]];
+ IE_ASSERT(val >= minStrideVal);
+
+ reqs._fixedStrides.set(d, val);
+ };
+
+ for (const auto& dim : dimOrderVec) {
+ const auto idx = dimToIeInd(dim, dims.size());
+ setStride(dim, strides[idx]);
+ }
+
+ for (int i = 0; i < MAX_DIMS_64; ++i) {
+ reqs.add(i, DimStride::Fixed);
+ }
+ return reqs;
+}
+
void printTo(std::ostream& os, const StridesRequirement& reqs) {
os << "[" << std::endl;
auto perm = desc.dimsOrder().toPermutation();
IE_ASSERT(!perm.empty());
- strides.set(perm[0], desc.elemSize());
- strides.set(perm[0], applyStrideRequirement(strides[perm[0]], 0, reqs));
+ strides = reqs.fixedStrides();
+
+ if (strides.empty()) {
+ strides.set(perm[0], desc.elemSize());
+ strides.set(perm[0], applyStrideRequirement(strides[perm[0]], 0, reqs));
- for (int i = 1; i < perm.size(); i++) {
- strides.set(perm[i], strides[perm[i - 1]] * desc.dim(perm[i - 1]));
- strides.set(perm[i], applyStrideRequirement(strides[perm[i]], i, reqs));
+ for (std::size_t i = 1; i < perm.size(); i++) {
+ strides.set(perm[i], strides[perm[i - 1]] * desc.dim(perm[i - 1]));
+ strides.set(perm[i], applyStrideRequirement(strides[perm[i]], i, reqs));
+ }
}
return strides;
const DimValues& strides,
const DataDesc& desc,
int ind,
- DimStride req) {
+ const StridesRequirement& reqs) {
+ const auto req = reqs.get(ind);
if (req == DimStride::Any) {
return true;
}
if (strideVal % STRIDE_ALIGNMENT != 0) {
return false;
}
+ } else if (req == DimStride::Fixed) {
+ if (strideVal != reqs.getFixedStride(perm[ind])) {
+ return false;
+ }
} else {
VPU_THROW_EXCEPTION << "Unsupported stride requirement : " << req;
}
IE_ASSERT(!perm.empty());
for (int i = 0; i < perm.size(); i++) {
- if (!checkStride(strides, desc, i, reqs.get(i))) {
+ if (!checkStride(strides, desc, i, reqs)) {
return false;
}
}
return strides[perm.back()] * desc.dim(perm.back());
}
+DataType fromIEPrecision(const InferenceEngine::Precision& precision) {
+ switch (precision) {
+ case InferenceEngine::Precision::U8: return DataType::U8;
+ case InferenceEngine::Precision::I8: return DataType::I8;
+ case InferenceEngine::Precision::I32: return DataType::S32;
+ case InferenceEngine::Precision::FP16: return DataType::FP16;
+ case InferenceEngine::Precision::FP32: return DataType::FP32;
+ default: VPU_THROW_EXCEPTION << precision << " isn't supported";
+ }
+}
+
} // namespace vpu