#include <ruy/context.h>
#include "cker/Types.h"
-namespace
-{
-const int kDefaultNumThreadpoolThreads = 4;
-}
-
namespace nnfw
{
namespace cker
namespace ruy_support
{
-struct RuyContext
-{
-public:
- RuyContext() : ruy_context_(new ruy::Context)
- {
- SetMaxNumThreads(onert::util::getConfigInt(onert::util::config::RUY_THREADS));
-#ifdef USE_RUY_GEMV
- ruy_context_->cache_policy = ruy::kCacheLHSOnNarrowMul;
-#endif
- };
-
- ruy::Context *ruy_context() const { return ruy_context_.get(); }
-
- static inline RuyContext &GetRuyContext()
- {
- static thread_local RuyContext instance;
- return instance;
- }
-
- void SetMaxNumThreads(int max_num_threads)
- {
- const int target_num_threads =
- max_num_threads > -1 ? max_num_threads : kDefaultNumThreadpoolThreads;
- ruy_context_->max_num_threads = target_num_threads;
- }
-
-private:
- const std::unique_ptr<ruy::Context> ruy_context_;
-};
-
-inline ruy::Context *GetRuyContext()
-{
- auto &ctx = RuyContext::GetRuyContext();
- return ctx.ruy_context();
-}
-
template <typename Scalar, typename DataPointer>
void MakeRuyMatrix(const MatrixParams<Scalar> ¶ms, DataPointer data_ptr,
ruy::Matrix<Scalar> *dst)