1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
7 * @brief A header file for the MKL-DNN Generic Primitive API
8 * @file mkldnn_generic_primitive.hpp
12 #include "mkldnn_extension_types.hpp"
13 #include "details/ie_irelease.hpp"
16 namespace InferenceEngine {
17 namespace MKLDNNPlugin {
20 * @deprecated use new extensibility API
21 * @brief The MKLDNNGenericFormats stores weights, biases, inputs and outputs of the primitive
// Holds the MemoryFormat descriptors of one primitive configuration: vectors of
// input and output formats plus one weights format and one biases format, all
// supplied to the constructor and exposed through const accessors.
23 class MKLDNNGenericFormats {
26 * @brief Constructs the format set from input and output formats, with optional weights and biases formats
27 * @param ins - vector of inputs
28 * @param outs - vector of outputs
29 * @param weights - weights, format_undef by default
30 * @param biases - biases, format_undef by default
// Copies `ins` and `outs` into the `inputs` / `outputs` members through the
// member initializer list. `weights` and `biases` are taken by value and both
// default to MemoryFormat::format_undef when the caller omits them.
32 MKLDNNGenericFormats(const std::vector<MemoryFormat> &ins, const std::vector<MemoryFormat> &outs,
33 const MemoryFormat weights = MemoryFormat::format_undef,
34 const MemoryFormat biases = MemoryFormat::format_undef) : inputs(ins), outputs(outs) {
// `this->` is needed because the parameters have the same names as the members.
// NOTE(review): these two assignments could join the initializer list above so
// that all four members are initialised the same way.
35 this->weights = weights;
36 this->biases = biases;
40 * @brief Get input formats
41 * @return vector of input formats
// Const, noexcept accessor that hands back a const reference (no copy is made).
// NOTE(review): the body is not part of the visible lines; presumably it
// returns the `inputs` member - confirm.
43 const std::vector<MemoryFormat>& GetInputs() const noexcept {
48 * @brief Get output formats
49 * @return vector of output formats
// Const, noexcept accessor that hands back a const reference (no copy is made).
// NOTE(review): the body is not part of the visible lines; presumably it
// returns the `outputs` member - confirm.
51 const std::vector<MemoryFormat>& GetOutputs() const noexcept {
56 * @brief Get weights format
57 * @return weights format
// Const, noexcept accessor returning a const reference to a single MemoryFormat.
// NOTE(review): the body is not part of the visible lines; presumably it
// returns the `weights` member assigned in the constructor - confirm.
59 const MemoryFormat& GetWeights() const noexcept {
64 * @brief Get biases format
65 * @return biases format
// Const, noexcept accessor returning a const reference to a single MemoryFormat.
// NOTE(review): the body is not part of the visible lines; presumably it
// returns the `biases` member assigned in the constructor - confirm.
67 const MemoryFormat& GetBiases() const noexcept {
// Format vectors copy-initialised from the constructor's `ins` and `outs` arguments.
// NOTE(review): the constructor also writes `weights` and `biases` members whose
// declarations are not part of the visible lines.
72 std::vector<MemoryFormat> inputs;
73 std::vector<MemoryFormat> outputs;
79 * @deprecated use new extensibility API
80 * @brief The IMKLDNNGenericPrimitive is the main Generic Primitive interface
// Abstract base class for generic primitives. It derives from IRelease, keeps
// the input and output MKLDNNPrimitiveMemory vectors set through SetMemory(),
// and leaves GetSupportedFormats() and Execute() pure virtual for implementers.
82 class IMKLDNNGenericPrimitive : public InferenceEngine::details::IRelease {
// Overrides the virtual Release() of the IRelease base; declared noexcept.
// NOTE(review): the body is not part of the visible lines; presumably it
// destroys this object - confirm before relying on that.
84 void Release() noexcept override {
89 * @brief Sets inputs and outputs
90 * @param inputs - vector of input primitives
91 * @param outputs - vector of output primitives
// Stores copies of both argument vectors in the `inputs` / `outputs` members,
// replacing whatever they held before.
// NOTE(review): vector copy-assignment may allocate and therefore throw; because
// this function is noexcept, such an exception would end in std::terminate.
93 void SetMemory(const std::vector<MKLDNNPrimitiveMemory>& inputs,
94 const std::vector<MKLDNNPrimitiveMemory>& outputs) noexcept {
// `this->` selects the members, which the same-named parameters would otherwise hide.
95 this->inputs = inputs;
96 this->outputs = outputs;
100 * @brief Gets supported formats
101 * @return vector of supported formats
// Pure virtual (= 0): every concrete primitive must implement it. The result is
// returned by value as a vector of MKLDNNGenericFormats, and the function is
// declared noexcept.
103 virtual std::vector<MKLDNNGenericFormats> GetSupportedFormats() noexcept = 0;
106 * @brief Entry point of actual execution of primitive.
107 * There is no error-reporting mechanism here, so static checks should be done in the constructor
// Pure virtual (= 0) with a void return and a noexcept specification, so an
// implementation can neither return a status nor let an exception escape.
109 virtual void Execute() noexcept = 0;
113 * @brief Vector of input primitives
// Overwritten by SetMemory() with a copy of its `inputs` argument.
115 std::vector<MKLDNNPrimitiveMemory> inputs;
117 * @brief Vector of output primitives
// Overwritten by SetMemory() with a copy of its `outputs` argument.
119 std::vector<MKLDNNPrimitiveMemory> outputs;
122 } // namespace MKLDNNPlugin
123 } // namespace InferenceEngine