void propagateScaleFactorsImpl(
const SmallVector<float>& inputScales,
- ScalePropagationStep step) override {
- IE_ASSERT(_inputEdges.size() == 1);
- IE_ASSERT(_outputEdges.size() == 1);
-
+ ScalePropagationStep step,
+ StageDataInfo<float>& scaleInfo) override {
auto power = attrs().get<float>("power");
auto& scale = attrs().get<float>("scale");
auto& bias = attrs().get<float>("bias");
if (power != 1.0f) {
- _scaleInfo.setInput(_inputEdges[0], 1.0f);
- _scaleInfo.setOutput(_outputEdges[0], 1.0f);
+ scaleInfo.setInput(inputEdge(0), 1.0f);
+ scaleInfo.setOutput(outputEdge(0), 1.0f);
} else {
auto inputScale = inputScales[0];
- _scaleInfo.setOutput(_outputEdges[0], inputScale);
+ scaleInfo.setOutput(outputEdge(0), inputScale);
if (step == ScalePropagationStep::ScaleInput) {
scale *= inputScale;