1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
8 #include <unordered_set>
10 #include <ie_parallel.hpp>
12 #include <vpu/model/data.hpp>
13 #include <vpu/model/stage.hpp>
14 #include <vpu/utils/numeric.hpp>
18 namespace ie = InferenceEngine;
void kchw_to_hwck(const T* src, T* dst, const DataDesc& desc) {
    // Permutes one W*H*C block of elements from `src` into `dst`.
    //
    //   source index:      w + W*h + W*H*c   (w fastest, then h, then c)
    //   destination index: c + C*h + C*H*w   (c fastest, then h, then w)
    //
    // NOTE(review): read outermost->innermost the destination order is W, H, C;
    // the "hwck" name presumably follows a different axis-naming convention --
    // confirm against whatever consumes `dst`.
    //
    // Only W*H*C elements are touched; any N/K extent of `desc` is not iterated
    // here (presumably the caller handles it, or it is 1) -- TODO confirm.
    // `dst` must not overlap `src`: iterations run concurrently and each writes a
    // permuted position. Lambda indices are `int`, so the element count is
    // assumed to fit in int range -- confirm for very large weights.
    IE_ASSERT(desc.numDims() >= 3);

    auto W = desc.dim(Dim::W);
    auto H = desc.dim(Dim::H);
    auto C = desc.dim(Dim::C);

    // One iteration per element. With w < W, h < H, c < C every triple maps to a
    // distinct destination index, so the concurrent writes never collide.
    ie::parallel_for3d(W, H, C, [=](int w, int h, int c) {
        auto inInd = w + W * h + W * H * c;
        auto outInd = c + C * h + C * H * w;
        dst[outInd] = src[inInd];
void kchw_to_khwc(const T* src, T* dst, const DataDesc& desc) {
    // Permutes one W*H*C block of elements from `src` into `dst`.
    //
    //   source index:      w + W*h + W*H*c   (w fastest, then h, then c)
    //   destination index: h + H*w + H*W*c   (h fastest, then w, then c)
    //
    // The channel axis keeps its position; within each channel the H x W plane
    // is transposed (outermost->innermost the destination order is C, W, H).
    // NOTE(review): the "khwc" name presumably follows a different axis-naming
    // convention -- confirm against whatever consumes `dst`.
    //
    // Only W*H*C elements are touched; any N/K extent of `desc` is not iterated
    // here -- TODO confirm callers account for it.
    // `dst` must not overlap `src` (iterations run concurrently). Lambda indices
    // are `int`, so the element count is assumed to fit in int range.
    IE_ASSERT(desc.numDims() >= 3);

    auto W = desc.dim(Dim::W);
    auto H = desc.dim(Dim::H);
    auto C = desc.dim(Dim::C);

    // One iteration per element; each (w, h, c) maps to a distinct destination
    // index, so the concurrent writes never collide.
    ie::parallel_for3d(W, H, C, [=](int w, int h, int c) {
        auto inInd = w + W * h + W * H * c;
        auto outInd = h + H * w + H * W * c;
        dst[outInd] = src[inInd];
void kchw_to_hwkc(const T* src, T* dst, const DataDesc& desc) {
    // Permutes one W*H*C block of elements from `src` into `dst`.
    //
    //   source index:      w + W*h + W*H*c   (w fastest, then h, then c)
    //   destination index: h + H*c + C*H*w   (h fastest, then c, then w)
    //
    // Outermost->innermost the destination order is W, C, H.
    // NOTE(review): the "hwkc" name presumably follows a different axis-naming
    // convention -- confirm against whatever consumes `dst`.
    //
    // Only W*H*C elements are touched; any N/K extent of `desc` is not iterated
    // here -- TODO confirm callers account for it.
    // `dst` must not overlap `src` (iterations run concurrently). Lambda indices
    // are `int`, so the element count is assumed to fit in int range.
    IE_ASSERT(desc.numDims() >= 3);

    auto W = desc.dim(Dim::W);
    auto H = desc.dim(Dim::H);
    auto C = desc.dim(Dim::C);

    // One iteration per element; each (w, h, c) maps to a distinct destination
    // index, so the concurrent writes never collide.
    ie::parallel_for3d(W, H, C, [=](int w, int h, int c) {
        auto inInd = w + W * h + W * H * c;
        auto outInd = h + H * c + C * H * w;
        dst[outInd] = src[inInd];
void deconv_to_conv(const T* src, T* dst, const DataDesc& desc) {
    // Rearranges deconvolution weights for use by a convolution (per the function
    // name). It swaps the two channel axes and flips the kernel in both spatial
    // dimensions (kx -> KX-1-kx, ky -> KY-1-ky, i.e. a 180-degree rotation).
    //
    // Extents come from `desc`: KX = W, KY = H, IC = C, OC = N.
    //   source index:      kx + KX*ky + KX*KY*oc + KX*KY*OC*ic
    //                      (outermost->innermost: IC, OC, KY, KX)
    //   destination index: (KX-1-kx) + KX*(KY-1-ky) + KX*KY*ic + KX*KY*IC*oc
    //                      (outermost->innermost: OC, IC, KY, KX, spatially flipped)
    //
    // NOTE(review): the source is indexed with `ic` (extent of dim C) outermost,
    // not in desc's N-outermost order. Presumably the deconvolution blob stores
    // input channels first -- confirm against the producer of `src`.
    // `dst` must not overlap `src` (iterations run concurrently). Lambda indices
    // are `int`, so the element count is assumed to fit in int range.
    IE_ASSERT(desc.numDims() >= 4);

    auto KX = desc.dim(Dim::W);
    auto KY = desc.dim(Dim::H);
    auto IC = desc.dim(Dim::C);
    auto OC = desc.dim(Dim::N);

    // One iteration per weight. The spatial flip and channel swap form a
    // bijection, so each destination slot is written exactly once.
    ie::parallel_for4d(OC, IC, KY, KX, [=](int oc, int ic, int ky, int kx) {
        auto inInd = kx + ky * KX + oc * KX * KY + ic * KX * KY * OC;
        auto outInd = (KX - kx - 1) + (KY - ky - 1) * KX + ic * KX * KY + oc * KX * KY * IC;
        dst[outInd] = src[inInd];
86 // DefaultSwWeightsContent
89 class DefaultSwWeightsContent final : public CalculatedDataContent {
91 explicit DefaultSwWeightsContent(const DataContent::Ptr& origContent);
94 void fillTempBuf(const SmallVector<DataContent::Ptr, 2>& baseContents, void* tempBuf) const override;
102 const Stage& curStage,
103 const std::unordered_set<StageType, EnumClassHash>& supportedTypes);