Make SetNumThreads apply to the eigen threads. (This creates a dependency on eigen!)
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 9 Mar 2018 18:39:50 +0000 (10:39 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Mar 2018 18:44:15 +0000 (10:44 -0800)
PiperOrigin-RevId: 188504172

tensorflow/contrib/lite/context.h
tensorflow/contrib/lite/interpreter.cc
tensorflow/contrib/lite/kernels/BUILD
tensorflow/contrib/lite/kernels/conv.cc
tensorflow/contrib/lite/kernels/eigen_support.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/eigen_support.h [new file with mode: 0644]
tensorflow/contrib/lite/kernels/gemm_support.cc
tensorflow/contrib/lite/kernels/gemm_support.h

index 23946dd..6491d8c 100644 (file)
@@ -324,9 +324,14 @@ typedef struct TfLiteContext {
       struct TfLiteContext*, TfLiteRegistration registration,
       const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
 
+  // Number of threads that are recommended to subsystems like gemmlowp and
+  // eigen.
+  int recommended_num_threads;
+
   // TODO(ahentz): we should create a more general mechanism for this sort of
   // library-global objects.
   void* gemm_context;
+  void* eigen_context;
 } TfLiteContext;
 
 typedef struct _TfLiteRegistration {
index 4710488..819782a 100644 (file)
@@ -92,7 +92,9 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
   context_.AddTensors = AddTensors;
   context_.tensors = nullptr;
   context_.tensors_size = 0;
+  context_.eigen_context = nullptr;
   context_.gemm_context = nullptr;
+  context_.recommended_num_threads = 0;
 
   // Invalid to call these these except from TfLiteDelegate
   SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
@@ -691,10 +693,7 @@ void Interpreter::UseNNAPI(bool enable) {
 }
 
 void Interpreter::SetNumThreads(int num_threads) {
-  // TODO(ahentz): this forces us to link against gemmlowp even when the ops
-  // don't use it. We should implement some dynamic mechanism for this sort of
-  // library-specific initialization.
-  tflite::gemm_support::SetMaxNumThreads(&context_, num_threads);
+  context_.recommended_num_threads = num_threads;
 }
 
 TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
index c6c11b0..9c63269 100644 (file)
@@ -41,6 +41,22 @@ cc_library(
 )
 
 cc_library(
+    name = "eigen_support",
+    srcs = [
+        "eigen_support.cc",
+    ],
+    hdrs = [
+        "eigen_support.h",
+    ],
+    copts = tflite_copts(),
+    deps = [
+        ":op_macros",
+        "//tensorflow/contrib/lite:context",
+        "//third_party/eigen3",
+    ],
+)
+
+cc_library(
     name = "gemm_support",
     srcs = [
         "gemm_support.cc",
@@ -175,6 +191,7 @@ cc_library(
     }),
     deps = [
         ":activation_functor",
+        ":eigen_support",
         ":kernel_util",
         ":op_macros",
         "//tensorflow/contrib/lite:builtin_op_data",
index 6821a22..b91ba1a 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include "tensorflow/contrib/lite/builtin_op_data.h"
 #include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/eigen_support.h"
 #include "tensorflow/contrib/lite/kernels/gemm_support.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h"
@@ -87,18 +88,15 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // to carry information from Prepare() to Eval().
   auto* data = new OpData;
   gemm_support::IncrementUsageCounter(context);
+  eigen_support::IncrementUsageCounter(context);
 
-  // TODO(ahentz): This is the gemmlowp context, which really only applies to
-  // quantized kernels. However, Interpreter::SetNumThreads() should also be
-  // setting the number of kernel on Eigen, so this works OK as a proxy for
-  // now.
-  int num_threads = gemm_support::GetFromContext(context)->max_num_threads();
-  data->run_multithreaded_kernel = num_threads != 1;
+  data->run_multithreaded_kernel = context->recommended_num_threads != 1;
 
   return data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
+  eigen_support::DecrementUsageCounter(context);
   gemm_support::DecrementUsageCounter(context);
   delete reinterpret_cast<OpData*>(buffer);
 }
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
new file mode 100644 (file)
index 0000000..1435a45
--- /dev/null
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/eigen_support.h"
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace eigen_support {
+
+struct RefCountedEigenContext {
+  int num_references = 0;
+};
+
+void IncrementUsageCounter(TfLiteContext* context) {
+  auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+  if (ptr == nullptr) {
+    Eigen::setNbThreads(context->recommended_num_threads);
+
+    ptr = new RefCountedEigenContext;
+    ptr->num_references = 0;
+    context->eigen_context = ptr;
+  }
+  ptr->num_references++;
+}
+
+void DecrementUsageCounter(TfLiteContext* context) {
+  auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+  if (ptr == nullptr) {
+    TF_LITE_FATAL(
+        "Call to DecrementUsageCounter() not preceded by "
+        "IncrementUsageCounter()");
+  }
+  if (--ptr->num_references == 0) {
+    delete ptr;
+  }
+}
+
+}  // namespace eigen_support
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
new file mode 100644 (file)
index 0000000..d47e691
--- /dev/null
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace eigen_support {
+
+// Let the framework know that the op will be using Eigen. If necessary a set of
+// temporary Eigen objects might be created and placed in 'context'.
+void IncrementUsageCounter(TfLiteContext* context);
+
+// Let the framework know that the op stopped using Eigen. If there are no more
+// usages all temporary Eigen objects will be deleted.
+void DecrementUsageCounter(TfLiteContext* context);
+
+}  // namespace eigen_support
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
index eb2b0aa..df8a9c8 100644 (file)
@@ -29,6 +29,7 @@ void IncrementUsageCounter(TfLiteContext* context) {
   if (ptr == nullptr) {
     ptr = new RefCountedGemmContext;
     ptr->gemm_context_ = new gemmlowp::GemmContext();
+    ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads);
     ptr->num_references_ = 0;
     context->gemm_context = ptr;
   }
@@ -58,11 +59,5 @@ gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
   return ptr->gemm_context_;
 }
 
-void SetMaxNumThreads(TfLiteContext* context, int num_threads) {
-  IncrementUsageCounter(context);
-  GetFromContext(context)->set_max_num_threads(num_threads);
-  DecrementUsageCounter(context);
-}
-
 }  // namespace gemm_support
 }  // namespace tflite
index 466781c..37af772 100644 (file)
@@ -45,9 +45,6 @@ void IncrementUsageCounter(TfLiteContext* context);
 // 'context'. If there are no more usages the GemmContext will be deleted.
 void DecrementUsageCounter(TfLiteContext* context);
 
-// Set the maximum number threads available for gemmlowp operations.
-void SetMaxNumThreads(TfLiteContext* context, int num_threads);
-
 }  // namespace gemm_support
 }  // namespace tflite