1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <vpu/frontend/frontend.hpp>
8 #include <unordered_set>
15 VPU_DECLARE_ENUM(ResampleType,
22 class ResampleStage final : public StageNode {
24 StagePtr cloneImpl() const override {
25 return std::make_shared<ResampleStage>(*this);
28 void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
29 auto input = inputEdge(0)->input();
31 orderInfo.setOutput(outputEdge(0), input->desc().dimsOrder());
34 void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
37 void finalizeDataLayoutImpl() override {
40 void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
41 batchInfo.setInput(inputEdge(0), BatchSupport::Split);
42 batchInfo.setOutput(outputEdge(0), BatchSupport::Split);
45 void initialCheckImpl() const override {
46 assertInputsOutputsTypes(this, {{DataType::FP16}}, {{DataType::FP16}});
49 void serializeParamsImpl(BlobSerializer& serializer) const override {
50 auto antialias = attrs().get<bool>("antialias");
51 auto factor = attrs().get<float>("factor");
52 auto sampleType = attrs().get<ResampleType>("type");
54 serializer.append(static_cast<int32_t>(antialias));
55 serializer.append(static_cast<float>(factor));
56 serializer.append(static_cast<uint32_t>(sampleType));
59 void serializeDataImpl(BlobSerializer& serializer) const override {
60 auto input = inputEdge(0)->input();
61 auto output = outputEdge(0)->output();
63 input->serializeOldBuffer(handle_from_this(), serializer);
64 output->serializeOldBuffer(handle_from_this(), serializer);
70 void FrontEnd::parseResample(
71 const Model::Ptr& model,
72 const ie::CNNLayerPtr& layer,
73 const DataVector& inputs,
74 const DataVector& outputs) {
75 IE_ASSERT(inputs.size() == 1);
76 IE_ASSERT(outputs.size() == 1);
78 ie::details::CaselessEq<std::string> cmp;
80 auto stage = model->addNewStage<ResampleStage>(
87 stage->attrs().set<bool>("antialias", layer->GetParamAsInt("antialias", 0));
88 stage->attrs().set<float>("factor", layer->GetParamAsInt("factor", -1.0f));
90 auto method = layer->GetParamAsString("type", "caffe.ResampleParameter.NEAREST");
91 if (cmp(method, "caffe.ResampleParameter.NEAREST")) {
92 stage->attrs().set<ResampleType>("type", ResampleType::Nearest);
94 stage->attrs().set<ResampleType>("type", ResampleType::Bilinear);