struct TfLiteContext*, TfLiteRegistration registration,
const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
+ // Number of threads that are recommended to subsystems like gemmlowp and
+ // eigen.
+ int recommended_num_threads;
+
// TODO(ahentz): we should create a more general mechanism for this sort of
// library-global objects.
void* gemm_context;
+ void* eigen_context;
} TfLiteContext;
typedef struct _TfLiteRegistration {
context_.AddTensors = AddTensors;
context_.tensors = nullptr;
context_.tensors_size = 0;
+ context_.eigen_context = nullptr;
context_.gemm_context = nullptr;
+ context_.recommended_num_threads = 0;
// Invalid to call these these except from TfLiteDelegate
SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
}
void Interpreter::SetNumThreads(int num_threads) {
- // TODO(ahentz): this forces us to link against gemmlowp even when the ops
- // don't use it. We should implement some dynamic mechanism for this sort of
- // library-specific initialization.
- tflite::gemm_support::SetMaxNumThreads(&context_, num_threads);
+ context_.recommended_num_threads = num_threads;
}
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
)
cc_library(
+ name = "eigen_support",
+ srcs = [
+ "eigen_support.cc",
+ ],
+ hdrs = [
+ "eigen_support.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":op_macros",
+ "//tensorflow/contrib/lite:context",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
name = "gemm_support",
srcs = [
"gemm_support.cc",
}),
deps = [
":activation_functor",
+ ":eigen_support",
":kernel_util",
":op_macros",
"//tensorflow/contrib/lite:builtin_op_data",
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h"
// to carry information from Prepare() to Eval().
auto* data = new OpData;
gemm_support::IncrementUsageCounter(context);
+ eigen_support::IncrementUsageCounter(context);
- // TODO(ahentz): This is the gemmlowp context, which really only applies to
- // quantized kernels. However, Interpreter::SetNumThreads() should also be
- // setting the number of kernel on Eigen, so this works OK as a proxy for
- // now.
- int num_threads = gemm_support::GetFromContext(context)->max_num_threads();
- data->run_multithreaded_kernel = num_threads != 1;
+ data->run_multithreaded_kernel = context->recommended_num_threads != 1;
return data;
}
void Free(TfLiteContext* context, void* buffer) {
+ eigen_support::DecrementUsageCounter(context);
gemm_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer);
}
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/eigen_support.h"
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace eigen_support {
+
+struct RefCountedEigenContext {
+ int num_references = 0;
+};
+
+void IncrementUsageCounter(TfLiteContext* context) {
+ auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ if (ptr == nullptr) {
+ Eigen::setNbThreads(context->recommended_num_threads);
+
+ ptr = new RefCountedEigenContext;
+ ptr->num_references = 0;
+ context->eigen_context = ptr;
+ }
+ ptr->num_references++;
+}
+
+void DecrementUsageCounter(TfLiteContext* context) {
+ auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ if (ptr == nullptr) {
+ TF_LITE_FATAL(
+ "Call to DecrementUsageCounter() not preceded by "
+ "IncrementUsageCounter()");
+ }
+ if (--ptr->num_references == 0) {
+ delete ptr;
+ }
+}
+
+} // namespace eigen_support
+} // namespace tflite
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace eigen_support {
+
+// Let the framework know that the op will be using Eigen. If necessary a set of
+// temporary Eigen objects might be created and placed in 'context'.
+void IncrementUsageCounter(TfLiteContext* context);
+
+// Let the framework know that the op stopped using Eigen. If there are no more
+// usages all temporary Eigen objects will be deleted.
+void DecrementUsageCounter(TfLiteContext* context);
+
+} // namespace eigen_support
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
if (ptr == nullptr) {
ptr = new RefCountedGemmContext;
ptr->gemm_context_ = new gemmlowp::GemmContext();
+ ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads);
ptr->num_references_ = 0;
context->gemm_context = ptr;
}
return ptr->gemm_context_;
}
-void SetMaxNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- GetFromContext(context)->set_max_num_threads(num_threads);
- DecrementUsageCounter(context);
-}
-
} // namespace gemm_support
} // namespace tflite
// 'context'. If there are no more usages the GemmContext will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
-// Set the maximum number threads available for gemmlowp operations.
-void SetMaxNumThreads(TfLiteContext* context, int num_threads);
-
} // namespace gemm_support
} // namespace tflite