49c34211a7e466b0473238612fe0109cdddcd958
[platform/core/ml/nnfw.git] / compute / cker / include / cker / eigen / EigenSupport.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_EIGEN_EIGEN_SUPPORT_H__
19 #define __NNFW_CKER_EIGEN_EIGEN_SUPPORT_H__
20
21 //#if defined(CKER_OPTIMIZED_EIGEN)
22
23 #include <Eigen/Core>
24 #include <thread>
25 #include "cker/eigen/eigen_spatial_convolutions.h"
26
27 #ifdef EIGEN_USE_THREADS
28 #include <unsupported/Eigen/CXX11/ThreadPool>
29 #endif
30
31 namespace nnfw
32 {
33 namespace cker
34 {
35 namespace eigen_support
36 {
37
38 // Shorthands for the types we need when interfacing with the EigenTensor
39 // library.
40 typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor, Eigen::DenseIndex>,
41                          Eigen::Aligned>
42     EigenMatrix;
43 typedef Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor, Eigen::DenseIndex>,
44                          Eigen::Aligned>
45     ConstEigenMatrix;
46
47 typedef Eigen::TensorMap<Eigen::Tensor<float, 4, Eigen::RowMajor, Eigen::DenseIndex>,
48                          Eigen::Aligned>
49     EigenTensor;
50 typedef Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor, Eigen::DenseIndex>,
51                          Eigen::Aligned>
52     ConstEigenTensor;
53
54 // Utility functions we need for the EigenTensor API.
55 template <typename Device, typename T> struct MatMulConvFunctor
56 {
57   // Computes on device "d": out = in0 * in1, where * is matrix
58   // multiplication.
59   void operator()(const Device &d, EigenMatrix out, ConstEigenMatrix in0, ConstEigenMatrix in1,
60                   const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> &dim_pair)
61   {
62     out.device(d) = in0.contract(in1, dim_pair);
63   }
64 };
65
66 // We have a single global threadpool for all convolution operations. This means
67 // that inferences started from different threads may block each other, but
68 // since the underlying resource of CPU cores should be consumed by the
69 // operations anyway, it shouldn't affect overall performance.
70 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface
71 {
72 public:
73   // Takes ownership of 'pool'
74   explicit EigenThreadPoolWrapper(Eigen::ThreadPool *pool) : pool_(pool) {}
75   ~EigenThreadPoolWrapper() override {}
76
77   void Schedule(std::function<void()> fn) override { pool_->Schedule(std::move(fn)); }
78   int NumThreads() const override { return pool_->NumThreads(); }
79   int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
80
81 private:
82   std::unique_ptr<Eigen::ThreadPool> pool_;
83 };
84
85 struct EigenContext
86 {
87   constexpr static int default_num_threadpool_threads = 4;
88   std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper;
89   std::unique_ptr<Eigen::ThreadPoolDevice> device;
90
91   EigenContext()
92   {
93     int num_threads = std::thread::hardware_concurrency();
94     if (num_threads == 0)
95     {
96       num_threads = default_num_threadpool_threads;
97     }
98     device.reset(); // destroy before we invalidate the thread pool
99     thread_pool_wrapper.reset(new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads)));
100     device.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper.get(), num_threads));
101   }
102
103   static inline EigenContext &GetEigenContext()
104   {
105     static EigenContext instance;
106     return instance;
107   }
108 };
109
110 inline const Eigen::ThreadPoolDevice *GetThreadPoolDevice()
111 {
112   auto &ctx = EigenContext::GetEigenContext();
113   return ctx.device.get();
114 }
115
116 } // namespace eigen_support
117 } // namespace cker
118 } // namespace nnfw
119
120 //#endif // defined(CKER_OPTIMIZED_EIGEN)
121
122 #endif // __NNFW_CKER_EIGEN_EIGEN_SUPPORT_H__