2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_EIGEN_EIGEN_SUPPORT_H__
19 #define __NNFW_CKER_EIGEN_EIGEN_SUPPORT_H__
21 //#if defined(CKER_OPTIMIZED_EIGEN)
25 #include "cker/eigen/eigen_spatial_convolutions.h"
27 #ifdef EIGEN_USE_THREADS
28 #include <unsupported/Eigen/CXX11/ThreadPool>
35 namespace eigen_support
38 // Shorthands for the types we need when interfacing with the EigenTensor
// Each shorthand below is an Eigen::TensorMap over an Eigen::Tensor that is
// stored row-major (Eigen::RowMajor) and indexed with Eigen::DenseIndex.
// Rank-2 tensor of mutable float elements.
40 typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor, Eigen::DenseIndex>,
// Rank-2 tensor of const float elements.
43 typedef Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor, Eigen::DenseIndex>,
// Rank-4 tensor of mutable float elements.
47 typedef Eigen::TensorMap<Eigen::Tensor<float, 4, Eigen::RowMajor, Eigen::DenseIndex>,
// Rank-4 tensor of const float elements.
50 typedef Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor, Eigen::DenseIndex>,
54 // Utility functions we need for the EigenTensor API.
55 template <typename Device, typename T> struct MatMulConvFunctor
57 // Computes on device "d": out = in0 * in1, where * is matrix
59 void operator()(const Device &d, EigenMatrix out, ConstEigenMatrix in0, ConstEigenMatrix in1,
60 const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> &dim_pair)
62 out.device(d) = in0.contract(in1, dim_pair);
66 // We have a single global threadpool for all convolution operations. This means
67 // that inferences started from different threads may block each other, but
68 // since the underlying resource of CPU cores should be consumed by the
69 // operations anyway, it shouldn't affect overall performance.
70 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface
73 // Takes ownership of 'pool'
74 explicit EigenThreadPoolWrapper(Eigen::ThreadPool *pool) : pool_(pool) {}
75 ~EigenThreadPoolWrapper() override {}
77 void Schedule(std::function<void()> fn) override { pool_->Schedule(std::move(fn)); }
78 int NumThreads() const override { return pool_->NumThreads(); }
79 int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
82 std::unique_ptr<Eigen::ThreadPool> pool_;
87 constexpr static int default_num_threadpool_threads = 4;
88 std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper;
89 std::unique_ptr<Eigen::ThreadPoolDevice> device;
93 int num_threads = std::thread::hardware_concurrency();
96 num_threads = default_num_threadpool_threads;
98 device.reset(); // destroy before we invalidate the thread pool
99 thread_pool_wrapper.reset(new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads)));
100 device.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper.get(), num_threads));
103 static inline EigenContext &GetEigenContext()
105 static EigenContext instance;
110 inline const Eigen::ThreadPoolDevice *GetThreadPoolDevice()
112 auto &ctx = EigenContext::GetEigenContext();
113 return ctx.device.get();
116 } // namespace eigen_support
120 //#endif // defined(CKER_OPTIMIZED_EIGEN)
122 #endif // __NNFW_CKER_EIGEN_EIGEN_SUPPORT_H__