Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / include / vpu / model / stage.hpp
index 44b48b7..683ec4d 100644 (file)
@@ -48,8 +48,9 @@ VPU_DECLARE_ENUM(StageType,
     Concat,
     Split,
     Reshape,
-    Broadcast,
+    Expand,
     Shrink,
+    StridedSlice,
 
     Empty = -1,
 
@@ -136,6 +137,10 @@ VPU_DECLARE_ENUM(StageType,
     ReduceAnd = 93,
     ReverseSequence = 94,
     Gather = 100,
+    Exp = 101,
+    Floor = 102,
+    TopK = 104,
+    ReduceMin = 105,
 )
 
 //
@@ -172,7 +177,7 @@ VPU_DECLARE_ENUM(StageSHAVEsRequirements,
 );
 
 //
-// StageNode
+// ScalePropagationStep
 //
 
 VPU_DECLARE_ENUM(ScalePropagationStep,
@@ -181,6 +186,29 @@ VPU_DECLARE_ENUM(ScalePropagationStep,
     Propagate
 );
 
+//
+// TopKMode
+//
+
+// Firmware implementations must be aligned with these values
+VPU_DECLARE_ENUM(TopKMode,
+    Max = 0,
+    Min = 1)
+
+//
+// TopKSort
+//
+
+// Firmware implementations must be aligned with these values
+VPU_DECLARE_ENUM(TopKSort,
+    None = 0,
+    Value = 1,
+    Index = 2)
+
+//
+// StageDataInfo
+//
+
 template <typename Val>
 class StageDataInfo final {
 public:
@@ -258,6 +286,10 @@ private:
     SmallVector<Optional<Val>> _outputVals;
 };
 
+//
+// StageNode
+//
+
 class StageNode :
         public EnableHandleFromThis<StageNode>,
         public EnableCustomAttributes {
@@ -297,6 +329,8 @@ class StageNode :
     // Edges wrappers
     //
 
+    VPU_MODEL_ATTRIBUTE(Handle<Model>, model, nullptr)
+
 public:
     struct StageNameCmp final {
         inline bool operator()(const Stage& left, const Stage& right) const {
@@ -445,7 +479,9 @@ public:
     // Bindings with IE
     //
 
-    inline std::string origLayerName() const { return _origLayer != nullptr ? _origLayer->name : std::string(); }
+    inline std::string origLayerName() const {
+        return _origLayer != nullptr ? _origLayer->name : std::string();
+    }
 
     //
     // SHAVEs allocation
@@ -463,25 +499,27 @@ public:
             ScalePropagationStep step);
 
     // Data order propagation from inputs to outputs.
-    const StageDataInfo<DimsOrder>& propagateDataOrder() const;
+    const StageDataInfo<DimsOrder>& propagateDataOrder();
 
     // Get Data strides requirements
-    const StageDataInfo<StridesRequirement>& getDataStridesRequirements() const;
+    const StageDataInfo<StridesRequirement>& getDataStridesRequirements();
 
     // Finalize internal parameter to final Data layout.
     void finalizeDataLayout();
 
     // Information about batch support.
-    const StageDataInfo<BatchSupport>& getBatchSupportInfo() const;
+    const StageDataInfo<BatchSupport>& getBatchSupportInfo();
 
     // Resources requirements.
     StageSHAVEsRequirements getSHAVEsRequirements() const;
 
-    // Final check.
+    void initialCheck() const;
     void finalCheck() const;
 
     // Name postfix for modified stage
-    inline void appendNamePostfix(const std::string& postfix) { _name = _name + postfix; }
+    inline void appendNamePostfix(const std::string& postfix) {
+        _name = _name + postfix;
+    }
 
     //
     // Backend utilities
@@ -498,19 +536,21 @@ protected:
 
     virtual void propagateScaleFactorsImpl(
             const SmallVector<float>& inputScales,
-            ScalePropagationStep step);
+            ScalePropagationStep step,
+            StageDataInfo<float>& scaleInfo);
 
-    virtual void propagateDataOrderImpl() const = 0;
+    virtual void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) = 0;
 
-    virtual void getDataStridesRequirementsImpl() const = 0;
+    virtual void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) = 0;
 
     virtual void finalizeDataLayoutImpl() = 0;
 
-    virtual void getBatchSupportInfoImpl() const = 0;
+    virtual void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) = 0;
 
     virtual StageSHAVEsRequirements getSHAVEsRequirementsImpl() const;
 
-    virtual void finalCheckImpl() const = 0;
+    virtual void initialCheckImpl() const {}
+    virtual void finalCheckImpl() const {}
 
     virtual void serializeParamsImpl(BlobSerializer& serializer) const = 0;
 
@@ -535,15 +575,12 @@ protected:
             _posInModel(this) {
     }
 
-protected:
-    Handle<Model> _model;
-
-    mutable StageDataInfo<float> _scaleInfo;
-    mutable StageDataInfo<DimsOrder> _orderInfo;
-    mutable StageDataInfo<StridesRequirement> _stridesInfo;
-    mutable StageDataInfo<BatchSupport> _batchInfo;
-
 private:
+    StageDataInfo<float> _scaleInfo;
+    StageDataInfo<DimsOrder> _orderInfo;
+    StageDataInfo<StridesRequirement> _stridesInfo;
+    StageDataInfo<BatchSupport> _batchInfo;
+
     StagePtrList::iterator _ptrPosInModel;
     IntrusivePtrListNode<StageNode> _posInModel;
 
@@ -552,4 +589,12 @@ private:
 
 void printTo(std::ostream& os, const Stage& stage);
 
+void assertAllInputsOutputsTypes(const StageNode* stage,
+                                 const DataType& expectedInputsType,
+                                 const DataType& expectedOutputsType);
+
+void assertInputsOutputsTypes(const StageNode* stage,
+                              const std::vector<EnumSet<DataType>>& expectedInputsTypes,
+                              const std::vector<EnumSet<DataType>>& expectedOutputsTypes);
+
 }  // namespace vpu