void propagateScaleFactorsImpl(
const SmallVector<float>& inputScales,
- ScalePropagationStep step) override {
- IE_ASSERT(_inputEdges.size() == 1);
- IE_ASSERT(_outputEdges.size() == 1);
-
+ ScalePropagationStep step,
+ StageDataInfo<float>& scaleInfo) override {
if (step == ScalePropagationStep::Propagate) {
- _scaleInfo.setOutput(_outputEdges[0], inputScales[0]);
+ scaleInfo.setOutput(outputEdge(0), inputScales[0]);
} else {
// Reshape can only propagate scaling.
- _scaleInfo.setInput(_inputEdges[0], 1.0f);
- _scaleInfo.setOutput(_outputEdges[0], 1.0f);
+ scaleInfo.setInput(inputEdge(0), 1.0f);
+ scaleInfo.setOutput(outputEdge(0), 1.0f);
}
}
- void propagateDataOrderImpl() const override {
- IE_ASSERT(_inputEdges.size() == 1);
- IE_ASSERT(_outputEdges.size() == 1);
-
- auto input = _inputEdges[0]->input();
- auto output = _outputEdges[0]->output();
+ void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
+ auto input = inputEdge(0)->input();
+ auto output = outputEdge(0)->output();
// Only default order is supported
- _orderInfo.setInput(_inputEdges[0], DimsOrder::fromNumDims(input->desc().numDims()));
- _orderInfo.setOutput(_outputEdges[0], DimsOrder::fromNumDims(output->desc().numDims()));
+ orderInfo.setInput(inputEdge(0), DimsOrder::fromNumDims(input->desc().numDims()));
+ orderInfo.setOutput(outputEdge(0), DimsOrder::fromNumDims(output->desc().numDims()));
}
- void getDataStridesRequirementsImpl() const override {
- IE_ASSERT(_inputEdges.size() == 1);
- IE_ASSERT(_outputEdges.size() == 1);
-
- _stridesInfo.setInput(_inputEdges[0], StridesRequirement::compact());
- _stridesInfo.setOutput(_outputEdges[0], StridesRequirement::compact());
+ void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
+ stridesInfo.setInput(inputEdge(0), StridesRequirement::compact());
+ stridesInfo.setOutput(outputEdge(0), StridesRequirement::compact());
}
void finalizeDataLayoutImpl() override {
}
- void getBatchSupportInfoImpl() const override {
+ void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
}
- void finalCheckImpl() const override {
- IE_ASSERT(_inputEdges.size() == 1);
- IE_ASSERT(_outputEdges.size() == 1);
-
- auto input = _inputEdges[0]->input();
- auto output = _outputEdges[0]->output();
-
- IE_ASSERT(input->desc().totalDimSize() == output->desc().totalDimSize());
+ void initialCheckImpl() const override {
+ const auto& firstInputPrecision = input(0)->desc().type();
+ assertInputsOutputsTypes(this, {{firstInputPrecision}}, {{firstInputPrecision}});
+ IE_ASSERT(input(0)->desc().totalDimSize() == output(0)->desc().totalDimSize());
}
void serializeParamsImpl(BlobSerializer&) const override {