Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / blob_factory.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <utility>
8 #include <memory>
9 #include "inference_engine.hpp"
10
11 template <InferenceEngine::Precision::ePrecision precision>
12 class BlobFactory {
13  public:
14     using BlobType = typename InferenceEngine::PrecisionTrait<precision>::value_type;
15     static InferenceEngine::Blob::Ptr make(InferenceEngine::Layout l, InferenceEngine::SizeVector dims) {
16         return InferenceEngine::make_shared_blob<BlobType>(precision, l, dims);
17     }
18     static InferenceEngine::Blob::Ptr make(InferenceEngine::Layout l, InferenceEngine::SizeVector dims, void* ptr) {
19         return InferenceEngine::make_shared_blob<BlobType>(precision, l, dims, reinterpret_cast<BlobType*>(ptr));
20     }
21     static InferenceEngine::Blob::Ptr make(const InferenceEngine::TensorDesc& desc) {
22         return InferenceEngine::make_shared_blob<BlobType>(desc);
23     }
24     static InferenceEngine::Blob::Ptr make(const InferenceEngine::TensorDesc& desc, void* ptr) {
25         return InferenceEngine::make_shared_blob<BlobType>(desc, reinterpret_cast<BlobType*>(ptr));
26     }
27     static InferenceEngine::Blob::Ptr make(const InferenceEngine::TensorDesc& desc, const std::shared_ptr<InferenceEngine::IAllocator>& alloc) {
28         return InferenceEngine::make_shared_blob<BlobType>(desc, alloc);
29     }
30 };
31
32 template <InferenceEngine::Precision::ePrecision precision, class ... Args> InferenceEngine::Blob::Ptr make_shared_blob2(Args && ... args) {
33     return BlobFactory<precision>::make(std::forward<Args>(args) ...);
34 }
35
// Expands to a single `case` label for the precision switch in the variadic
// make_blob_with_precision(Precision, Args&&...) template: for Precision::<precision> it returns
// make_shared_blob2<Precision::<precision>>(std::forward<Args>(args)...).
// NOTE: the expansion refers to `Args` and `args`, so it only compiles inside a function template
// whose parameter pack uses exactly those names.
// The expansion has no trailing ';' on purpose — each use site `USE_FACTORY(X);` supplies it,
// avoiding an empty statement per case.
#define USE_FACTORY(precision)\
    case InferenceEngine::Precision::precision  : return make_shared_blob2<InferenceEngine::Precision::precision>(std::forward<Args>(args) ...)
39
40 INFERENCE_ENGINE_API_CPP(InferenceEngine::Blob::Ptr) make_blob_with_precision(const InferenceEngine::TensorDesc& desc);
41 INFERENCE_ENGINE_API_CPP(InferenceEngine::Blob::Ptr) make_blob_with_precision(const InferenceEngine::TensorDesc& desc, void* ptr);
42 INFERENCE_ENGINE_API_CPP(InferenceEngine::Blob::Ptr) make_blob_with_precision(const InferenceEngine::TensorDesc& desc,
43                                                                               const std::shared_ptr<InferenceEngine::IAllocator>& alloc);
44 INFERENCE_ENGINE_API_CPP(InferenceEngine::Blob::Ptr) make_plain_blob(InferenceEngine::Precision prec, const InferenceEngine::SizeVector dims);
45
46 INFERENCE_ENGINE_API_CPP(InferenceEngine::Layout) plain_layout(InferenceEngine::SizeVector dims);
47
48 template <class ... Args>
49 InferenceEngine::Blob::Ptr make_blob_with_precision(InferenceEngine::Precision precision, Args &&... args) {
50     switch (precision) {
51         USE_FACTORY(FP32);
52         USE_FACTORY(FP16);
53         USE_FACTORY(Q78);
54         USE_FACTORY(I16);
55         USE_FACTORY(U8);
56         USE_FACTORY(I8);
57         USE_FACTORY(U16);
58         USE_FACTORY(I32);
59         USE_FACTORY(BIN);
60         default:
61             THROW_IE_EXCEPTION << "cannot locate blob for precision: " << precision;
62     }
63 }
64
65 #undef USE_FACTORY
66
67 /**
68  * Create blob with custom precision
69  * @tparam T - type off underlined elements
70  * @tparam Args
71  * @param args
72  * @return
73  */
74 template <class T, class ... Args>
75 InferenceEngine::Blob::Ptr make_custom_blob(Args &&... args) {
76     return InferenceEngine::make_shared_blob<T>(InferenceEngine::Precision::fromType<T>(), std::forward<Args>(args) ...);
77 }
78
79 /**
80  * @brief Creates a TBlob<> object from a Data node
81  * @param Data reference to a smart pointer of the Data node
82  * @return Smart pointer to TBlob<> with the relevant C type to the precision of the data node
83  */
84 INFERENCE_ENGINE_API_CPP(InferenceEngine::Blob::Ptr) CreateBlobFromData(const InferenceEngine::DataPtr &data);