Concat,
Split,
Reshape,
- Broadcast,
+ Expand,
Shrink,
+ StridedSlice,
Empty = -1,
ReduceAnd = 93,
ReverseSequence = 94,
Gather = 100,
+ Exp = 101,
+ Floor = 102,
+ TopK = 104,
+ ReduceMin = 105,
)
//
);
//
-// StageNode
+// ScalePropagationStep
//
VPU_DECLARE_ENUM(ScalePropagationStep,
Propagate
);
+//
+// TopKMode
+//
+
+// Firmware implementations must be aligned with these values
+VPU_DECLARE_ENUM(TopKMode,
+ Max = 0,
+ Min = 1)
+
+//
+// TopKSort
+//
+
+// Firmware implementations must be aligned with these values
+VPU_DECLARE_ENUM(TopKSort,
+ None = 0,
+ Value = 1,
+ Index = 2)
+
+//
+// StageDataInfo
+//
+
template <typename Val>
class StageDataInfo final {
public:
SmallVector<Optional<Val>> _outputVals;
};
+//
+// StageNode
+//
+
class StageNode :
public EnableHandleFromThis<StageNode>,
public EnableCustomAttributes {
// Edges wrappers
//
+ VPU_MODEL_ATTRIBUTE(Handle<Model>, model, nullptr)
+
public:
struct StageNameCmp final {
inline bool operator()(const Stage& left, const Stage& right) const {
// Bindings with IE
//
- inline std::string origLayerName() const { return _origLayer != nullptr ? _origLayer->name : std::string(); }
+ inline std::string origLayerName() const {
+ return _origLayer != nullptr ? _origLayer->name : std::string();
+ }
//
// SHAVEs allocation
ScalePropagationStep step);
// Data order propagation from inputs to outputs.
- const StageDataInfo<DimsOrder>& propagateDataOrder() const;
+ const StageDataInfo<DimsOrder>& propagateDataOrder();
// Get Data strides requirements
- const StageDataInfo<StridesRequirement>& getDataStridesRequirements() const;
+ const StageDataInfo<StridesRequirement>& getDataStridesRequirements();
// Finalize internal parameter to final Data layout.
void finalizeDataLayout();
// Information about batch support.
- const StageDataInfo<BatchSupport>& getBatchSupportInfo() const;
+ const StageDataInfo<BatchSupport>& getBatchSupportInfo();
// Resources requirements.
StageSHAVEsRequirements getSHAVEsRequirements() const;
- // Final check.
+ void initialCheck() const;
void finalCheck() const;
// Name postfix for modified stage
- inline void appendNamePostfix(const std::string& postfix) { _name = _name + postfix; }
+ inline void appendNamePostfix(const std::string& postfix) {
+ _name = _name + postfix;
+ }
//
// Backend utilities
virtual void propagateScaleFactorsImpl(
const SmallVector<float>& inputScales,
- ScalePropagationStep step);
+ ScalePropagationStep step,
+ StageDataInfo<float>& scaleInfo);
- virtual void propagateDataOrderImpl() const = 0;
+ virtual void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) = 0;
- virtual void getDataStridesRequirementsImpl() const = 0;
+ virtual void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) = 0;
virtual void finalizeDataLayoutImpl() = 0;
- virtual void getBatchSupportInfoImpl() const = 0;
+ virtual void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) = 0;
virtual StageSHAVEsRequirements getSHAVEsRequirementsImpl() const;
- virtual void finalCheckImpl() const = 0;
+ virtual void initialCheckImpl() const {}
+ virtual void finalCheckImpl() const {}
virtual void serializeParamsImpl(BlobSerializer& serializer) const = 0;
_posInModel(this) {
}
-protected:
- Handle<Model> _model;
-
- mutable StageDataInfo<float> _scaleInfo;
- mutable StageDataInfo<DimsOrder> _orderInfo;
- mutable StageDataInfo<StridesRequirement> _stridesInfo;
- mutable StageDataInfo<BatchSupport> _batchInfo;
-
private:
+ StageDataInfo<float> _scaleInfo;
+ StageDataInfo<DimsOrder> _orderInfo;
+ StageDataInfo<StridesRequirement> _stridesInfo;
+ StageDataInfo<BatchSupport> _batchInfo;
+
StagePtrList::iterator _ptrPosInModel;
IntrusivePtrListNode<StageNode> _posInModel;
void printTo(std::ostream& os, const Stage& stage);
+void assertAllInputsOutputsTypes(const StageNode* stage,
+ const DataType& expectedInputsType,
+ const DataType& expectedOutputsType);
+
+void assertInputsOutputsTypes(const StageNode* stage,
+ const std::vector<EnumSet<DataType>>& expectedInputsTypes,
+ const std::vector<EnumSet<DataType>>& expectedOutputsTypes);
+
} // namespace vpu