void propagateScaleFactorsImpl(
const SmallVector<float>&,
- ScalePropagationStep) override {
+ ScalePropagationStep,
+ StageDataInfo<float>&) override {
VPU_THROW_EXCEPTION << "Must never be called";
}
- void propagateDataOrderImpl() const override {
- IE_ASSERT(_inputEdges.size() == 3);
- IE_ASSERT(_outputEdges.size() == 1);
-
- auto input = _inputEdges[0]->input();
- auto weights = _inputEdges[1]->input();
- auto output = _outputEdges[0]->output();
+ void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
+ auto input = inputEdge(0)->input();
+ auto weights = inputEdge(1)->input();
+ auto output = outputEdge(0)->output();
auto finalOrder = input->desc().dimsOrder();
if (finalOrder.dimInd(Dim::C) == 1) {
if (_type == StageType::Conv ||
_type == StageType::Im2ColConvolution) {
if (finalOrder != input->desc().dimsOrder()) {
- _orderInfo.setInput(_inputEdges[0], finalOrder);
+ orderInfo.setInput(inputEdge(0), finalOrder);
}
- _orderInfo.setOutput(_outputEdges[0], finalOrder);
+ orderInfo.setOutput(outputEdge(0), finalOrder);
} else if (_type == StageType::DepthConv) {
if (finalOrder != input->desc().dimsOrder()) {
- _orderInfo.setInput(_inputEdges[0], finalOrder);
+ orderInfo.setInput(inputEdge(0), finalOrder);
}
- _orderInfo.setOutput(_outputEdges[0], finalOrder);
+ orderInfo.setOutput(outputEdge(0), finalOrder);
} else {
- _orderInfo.setInput(_inputEdges[0], finalOrder.createMovedDim(Dim::C, 0));
- _orderInfo.setOutput(_outputEdges[0], finalOrder.createMovedDim(Dim::C, 0));
+ orderInfo.setInput(inputEdge(0), finalOrder.createMovedDim(Dim::C, 0));
+ orderInfo.setOutput(outputEdge(0), finalOrder.createMovedDim(Dim::C, 0));
}
}
- void getDataStridesRequirementsImpl() const override {
- IE_ASSERT(_inputEdges.size() == 3);
- IE_ASSERT(_outputEdges.size() == 1);
-
+ void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
if (_type != StageType::DepthConv) {
- _stridesInfo.setInput(_inputEdges[0], StridesRequirement::compact());
- _stridesInfo.setOutput(_outputEdges[0], StridesRequirement::compact());
+ stridesInfo.setInput(inputEdge(0), StridesRequirement::compact());
+ stridesInfo.setOutput(outputEdge(0), StridesRequirement::compact());
}
}
void finalizeDataLayoutImpl() override {
- IE_ASSERT(_inputEdges.size() == 3);
- IE_ASSERT(_outputEdges.size() == 1);
-
- auto input = _inputEdges[0]->input();
- auto weights = _inputEdges[1]->input();
- auto output = _outputEdges[0]->output();
+ auto input = inputEdge(0)->input();
+ auto weights = inputEdge(1)->input();
+ auto output = outputEdge(0)->output();
auto kernelSizeX = attrs().get<int>("kernelSizeX");
auto kernelSizeY = attrs().get<int>("kernelSizeY");
IE_ASSERT(swWeights != nullptr);
- _model->replaceStageInput(_inputEdges[1], swWeights);
+ _model->replaceStageInput(inputEdge(1), swWeights);
}
- void getBatchSupportInfoImpl() const override {
- IE_ASSERT(_inputEdges.size() == 3);
- IE_ASSERT(_outputEdges.size() == 1);
-
- _batchInfo.setInput(_inputEdges[0], BatchSupport::Split);
- _batchInfo.setOutput(_outputEdges[0], BatchSupport::Split);
+ void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
+ batchInfo.setInput(inputEdge(0), BatchSupport::Split);
+ batchInfo.setOutput(outputEdge(0), BatchSupport::Split);
}
void finalCheckImpl() const override {
+ assertInputsOutputsTypes(this,
+ {{DataType::FP16}, {DataType::FP16}, {DataType::FP16}},
+ {{DataType::FP16}});
}
void serializeParamsImpl(BlobSerializer& serializer) const override {
}
void serializeDataImpl(BlobSerializer& serializer) const override {
- IE_ASSERT(_inputEdges.size() == 3);
- IE_ASSERT(_outputEdges.size() == 1);
-
- auto input = _inputEdges[0]->input();
- auto weights = _inputEdges[1]->input();
- auto biases = _inputEdges[2]->input();
- auto output = _outputEdges[0]->output();
+ auto input = inputEdge(0)->input();
+ auto weights = inputEdge(1)->input();
+ auto biases = inputEdge(2)->input();
+ auto output = outputEdge(0)->output();
input->serializeOldBuffer(handle_from_this(), serializer);
output->serializeOldBuffer(handle_from_this(), serializer);
weights->serializeOldBuffer(handle_from_this(), serializer);
- if (!_tempBufferEdges.empty()) {
- _tempBufferEdges[0]->tempBuffer()->serializeOldBuffer(handle_from_this(), serializer);
+ if (numTempBuffers() == 1) {
+ tempBuffer(0)->serializeOldBuffer(handle_from_this(), serializer);
}
// TODO: remove this