return std::make_shared<GRNStage>(*this);
}
- void propagateDataOrderImpl() const override {
- IE_ASSERT(_inputEdges.size() == 1);
- IE_ASSERT(_outputEdges.size() == 1);
+ void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
+ auto input = inputEdge(0)->input();
- auto input = _inputEdges[0]->input();
-
- _orderInfo.setOutput(_outputEdges[0], input->desc().dimsOrder());
+ orderInfo.setOutput(outputEdge(0), input->desc().dimsOrder());
}
- void getDataStridesRequirementsImpl() const override {
+ void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
}
void finalizeDataLayoutImpl() override {
}
- void getBatchSupportInfoImpl() const override {
- IE_ASSERT(_inputEdges.size() == 1);
- IE_ASSERT(_outputEdges.size() == 1);
-
- _batchInfo.setInput(_inputEdges[0], BatchSupport::Split);
- _batchInfo.setOutput(_outputEdges[0], BatchSupport::Split);
+ void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
+ batchInfo.setInput(inputEdge(0), BatchSupport::Split);
+ batchInfo.setOutput(outputEdge(0), BatchSupport::Split);
}
- void finalCheckImpl() const override {
+ void initialCheckImpl() const override {
+ assertInputsOutputsTypes(this, {{DataType::FP16}}, {{DataType::FP16}});
}
void serializeParamsImpl(BlobSerializer& serializer) const override {
}
void serializeDataImpl(BlobSerializer& serializer) const override {
- IE_ASSERT(_inputEdges.size() == 1);
- IE_ASSERT(_outputEdges.size() == 1);
- IE_ASSERT(_tempBufferEdges.empty());
-
- auto input = _inputEdges[0]->input();
- auto output = _outputEdges[0]->output();
+ auto input = inputEdge(0)->input();
+ auto output = outputEdge(0)->output();
input->serializeNewBuffer(serializer);
output->serializeNewBuffer(serializer);