Merging upstream
author     Sami Kama <skama@nvidia.com>
           Fri, 2 Mar 2018 05:39:22 +0000 (21:39 -0800)
committer  Sami Kama <skama@nvidia.com>
           Fri, 2 Mar 2018 05:39:22 +0000 (21:39 -0800)
tensorflow/contrib/tensorrt/BUILD
tensorflow/contrib/tensorrt/convert/convert_graph.h
tensorflow/contrib/tensorrt/convert/convert_nodes.cc
tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.cc [deleted file]
tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h [deleted file]
tensorflow/contrib/tensorrt/resources/TRTResourceManager.cc [deleted file]
tensorflow/contrib/tensorrt/resources/TRTResourceManager.h [deleted file]
tensorflow/contrib/tensorrt/resources/TRTResources.h [deleted file]
tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h

diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 1010a89..79ed24b 100644
@@ -3,46 +3,46 @@
 #   and provide TensorRT operators and converter package.
 #   APIs are meant to change over time.
 
-package(default_visibility=["//tensorflow:__subpackages__"])
+package(default_visibility = ["//tensorflow:__subpackages__"])
 
 licenses(["notice"])  # Apache 2.0
 
 exports_files(["LICENSE"])
 
 load(
-  "//tensorflow:tensorflow.bzl",
-  "tf_cc_test",
-  "tf_copts",
-  "tf_cuda_library",
-  "tf_custom_op_library",
-  "tf_custom_op_library_additional_deps",
-  "tf_gen_op_libs",
-  "tf_gen_op_wrapper_py",
+    "//tensorflow:tensorflow.bzl",
+    "tf_cc_test",
+    "tf_copts",
+    "tf_cuda_library",
+    "tf_custom_op_library",
+    "tf_custom_op_library_additional_deps",
+    "tf_gen_op_libs",
+    "tf_gen_op_wrapper_py",
 )
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
 load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
 load(
-  "@local_config_tensorrt//:build_defs.bzl",
-  "if_tensorrt",
+    "@local_config_tensorrt//:build_defs.bzl",
+    "if_tensorrt",
 )
 
 tf_cuda_cc_test(
-  name="tensorrt_test_cc",
-  size="small",
-  srcs=["tensorrt_test.cc"],
-  tags=[
-    "manual",
-    "notap",
-  ],
-  deps=[
-         "//tensorflow/core:lib",
-         "//tensorflow/core:test",
-         "//tensorflow/core:test_main",
-       ] + if_tensorrt([
-    "@local_config_cuda//cuda:cuda_headers",
-    "@local_config_tensorrt//:nv_infer",
-  ]),
+    name = "tensorrt_test_cc",
+    size = "small",
+    srcs = ["tensorrt_test.cc"],
+    tags = [
+        "manual",
+        "notap",
+    ],
+    deps = [
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ] + if_tensorrt([
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_config_tensorrt//:nv_infer",
+    ]),
 )
 
 tf_custom_op_library(
@@ -61,15 +61,15 @@ tf_custom_op_library(
 )
 
 tf_cuda_library(
-  name="trt_shape_function",
-  srcs=["shape_fn/trt_shfn.cc"],
-  hdrs=["shape_fn/trt_shfn.h"],
-  visibility=["//visibility:public"],
-  deps=[
-         ":trt_logging",
-       ] + if_tensorrt([
-    "@local_config_tensorrt//:nv_infer",
-  ]) + tf_custom_op_library_additional_deps(),
+    name = "trt_shape_function",
+    srcs = ["shape_fn/trt_shfn.cc"],
+    hdrs = ["shape_fn/trt_shfn.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":trt_logging",
+    ] + if_tensorrt([
+        "@local_config_tensorrt//:nv_infer",
+    ]) + tf_custom_op_library_additional_deps(),
 )
 
 cc_library(
@@ -83,6 +83,7 @@ cc_library(
         "kernels/trt_engine_op.h",
     ],
     copts = tf_copts(),
+    visibility = ["//visibility:public"],
     deps = [
         ":trt_logging",
         ":trt_resources",
@@ -92,7 +93,6 @@ cc_library(
     ] + if_tensorrt([
         "@local_config_tensorrt//:nv_infer",
     ]) + tf_custom_op_library_additional_deps(),
-    visibility = ["//visibility:public"],
     # TODO(laigd)
     alwayslink = 1,  # buildozer: disable=alwayslink-with-hdrs
 )
@@ -108,15 +108,15 @@ tf_gen_op_libs(
 )
 
 tf_cuda_library(
-  name="trt_logging",
-  srcs=["log/trt_logger.cc"],
-  hdrs=["log/trt_logger.h"],
-  visibility=["//visibility:public"],
-  deps=[
-         "//tensorflow/core:lib_proto_parsing",
-       ] + if_tensorrt([
-    "@local_config_tensorrt//:nv_infer",
-  ]),
+    name = "trt_logging",
+    srcs = ["log/trt_logger.cc"],
+    hdrs = ["log/trt_logger.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:lib_proto_parsing",
+    ] + if_tensorrt([
+        "@local_config_tensorrt//:nv_infer",
+    ]),
 )
 
 tf_gen_op_wrapper_py(
@@ -130,80 +130,60 @@ tf_gen_op_wrapper_py(
 )
 
 tf_custom_op_py_library(
-  name="trt_engine_op_loader",
-  srcs=["python/ops/trt_engine_op.py"],
-  dso=[
+    name = "trt_engine_op_loader",
+    srcs = ["python/ops/trt_engine_op.py"],
+    dso = [
         ":python/ops/_trt_engine_op.so",
-      ] + if_tensorrt([
-    "@local_config_tensorrt//:nv_infer",
-  ]),
-  srcs_version="PY2AND3",
-  deps=[
-    "//tensorflow/python:framework_for_generated_wrappers",
-    "//tensorflow/python:resources",
-  ],
+    ] + if_tensorrt([
+        "@local_config_tensorrt//:nv_infer",
+    ]),
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:resources",
+    ],
 )
 
 py_library(
-  name="init_py",
-  srcs=[
-    "__init__.py",
-    "python/__init__.py",
-  ],
-  srcs_version="PY2AND3",
-  deps=[
-    ":trt_convert_py",
-    ":trt_ops_py",
-  ],
+    name = "init_py",
+    srcs = [
+        "__init__.py",
+        "python/__init__.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":trt_convert_py",
+        ":trt_ops_py",
+    ],
 )
 
 py_library(
-  name="trt_ops_py",
-  srcs_version="PY2AND3",
-  deps=[
-    ":trt_engine_op",
-    ":trt_engine_op_loader",
-  ],
+    name = "trt_ops_py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":trt_engine_op",
+        ":trt_engine_op_loader",
+    ],
 )
 
 py_library(
-  name="trt_convert_py",
-  srcs=["python/trt_convert.py"],
-  srcs_version="PY2AND3",
-  deps=[
-    ":wrap_conversion",
-  ],
+    name = "trt_convert_py",
+    srcs = ["python/trt_convert.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":wrap_conversion",
+    ],
 )
 
 tf_py_wrap_cc(
-  name="wrap_conversion",
-  srcs=["trt_conversion.i"],
-  copts=tf_copts(),
-  deps=[
-    ":trt_conversion",
-    "//tensorflow/core:framework_lite",
-    "//util/python:python_headers",
-  ],
-)
-
-tf_cuda_library(
-  name="trt_resources",
-  srcs=[
-    "resources/TRTInt8Calibrator.cc",
-    "resources/TRTResourceManager.cc",
-  ],
-  hdrs=[
-    "resources/TRTInt8Calibrator.h",
-    "resources/TRTResourceManager.h",
-    "resources/TRTResources.h",
-  ],
-  deps=[
-    "@local_config_tensorrt//:nv_infer",
-    "//tensorflow/core:framework_headers_lib",
-    "//tensorflow/core:framework_lite",
-    "//tensorflow/core:lib_proto_parsing",
-
-  ],
+    name = "wrap_conversion",
+    srcs = ["trt_conversion.i"],
+    copts = tf_copts(),
+    deps = [
+        ":trt_conversion",
+        "//tensorflow/core:framework_lite",
+        "//util/python:python_headers",
+    ],
 )
 
 tf_cuda_library(
@@ -262,43 +242,43 @@ tf_cuda_library(
 
 # Library for the segmenting portion of TensorRT operation creation
 cc_library(
-  name="segment",
-  srcs=["segment/segment.cc"],
-  hdrs=[
-    "segment/segment.h",
-    "segment/union_find.h",
-  ],
-  linkstatic=1,
-  deps=[
-    "//tensorflow/core:graph",
-    "//tensorflow/core:lib_proto_parsing",
-    "//tensorflow/core:protos_all_cc",
-    "@protobuf_archive//:protobuf_headers",
-  ],
+    name = "segment",
+    srcs = ["segment/segment.cc"],
+    hdrs = [
+        "segment/segment.h",
+        "segment/union_find.h",
+    ],
+    linkstatic = 1,
+    deps = [
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/core:protos_all_cc",
+        "@protobuf_archive//:protobuf_headers",
+    ],
 )
 
 tf_cc_test(
-  name="segment_test",
-  size="small",
-  srcs=["segment/segment_test.cc"],
-  deps=[
-    ":segment",
-    "//tensorflow/c:c_api",
-    "//tensorflow/core:lib",
-    "//tensorflow/core:protos_all_cc",
-    "//tensorflow/core:test",
-    "//tensorflow/core:test_main",
-  ],
+    name = "segment_test",
+    size = "small",
+    srcs = ["segment/segment_test.cc"],
+    deps = [
+        ":segment",
+        "//tensorflow/c:c_api",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
 )
 
 filegroup(
-  name="all_files",
-  srcs=glob(
-    ["**/*"],
-    exclude=[
-      "**/METADATA",
-      "**/OWNERS",
-    ],
-  ),
-  visibility=["//tensorflow:__subpackages__"],
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
 )
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index 5d53013..905824c 100644
@@ -38,7 +38,7 @@ tensorflow::Status ConvertGraphDefToTensorRT(
     const tensorflow::GraphDef& graph_def,
     const std::vector<string>& output_names, size_t max_batch_size,
     size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
-    int precision_mode,int minimum_segment_size);
+    int precision_mode, int minimum_segment_size);
 }  // namespace convert
 }  // namespace tensorrt
 }  // namespace tensorflow
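
The hunk above only normalizes spacing in the declaration, but it is the single entry point the rest of this diff feeds. A minimal caller sketch, assuming TensorFlow headers; the output name and tuning values are illustrative assumptions, and precision_mode == 1 selects FP16 per the Converter's fp16_ flag further below:

    #include <string>
    #include <vector>

    #include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
    #include "tensorflow/core/framework/graph.pb.h"

    tensorflow::Status BuildTRTGraph(const tensorflow::GraphDef& graph_def,
                                     tensorflow::GraphDef* trt_graph_def) {
      const std::vector<std::string> output_names = {"logits"};  // hypothetical
      return tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
          graph_def, output_names,
          /*max_batch_size=*/8,
          /*max_workspace_size_bytes=*/size_t(1) << 30,
          trt_graph_def,
          /*precision_mode=*/1,  // 1 maps to FP16 in the Converter below
          /*minimum_segment_size=*/3);
    }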
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index f1925d3..1bd60c6 100644
@@ -25,8 +25,8 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
-#include "tensorflow/contrib/tensorrt/resources/TRTResourceManager.h"
-#include "tensorflow/contrib/tensorrt/resources/TRTResources.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_builder.h"
@@ -319,7 +319,7 @@ void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
 }
 
 template <typename T>
-void reorder2(nvinfer1::DimsHW shape, T const* idata, nvinfer1::DimsHW istrides,
+void Reorder2(nvinfer1::DimsHW shape, T const* idata, nvinfer1::DimsHW istrides,
               T* odata, nvinfer1::DimsHW ostrides) {
   for (int h = 0; h < shape.h(); ++h) {
     for (int w = 0; w < shape.w(); ++w) {
@@ -330,8 +330,8 @@ void reorder2(nvinfer1::DimsHW shape, T const* idata, nvinfer1::DimsHW istrides,
 }
 
 // TODO(jie): fail to tensorflow!!
-void reorder_ck_to_kc(TRT_ShapedWeights const& iweights,
-                      TRT_ShapedWeights* oweights) {
+void ReorderCKtoKC(TRT_ShapedWeights const& iweights,
+                   TRT_ShapedWeights* oweights) {
   int c = iweights.shape_.d[0];
   int k = iweights.shape_.d[1];
   oweights->shape_.d[0] = k;
@@ -340,14 +340,14 @@ void reorder_ck_to_kc(TRT_ShapedWeights const& iweights,
   nvinfer1::DimsHW ostrides = {c, 1};
   switch (iweights.type_) {
     case tensorflow::DataType::DT_FLOAT: {
-      reorder2({k, c}, static_cast<float const*>(iweights.GetValues()),
+      Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()),
                istrides,
                static_cast<float*>(const_cast<void*>(oweights->GetValues())),
                ostrides);
       break;
     }
     case tensorflow::DataType::DT_HALF: {
-      reorder2(
+      Reorder2(
           {k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
           istrides,
           static_cast<Eigen::half*>(const_cast<void*>(oweights->GetValues())),
@@ -427,7 +427,7 @@ class Converter {
   std::unordered_map<string, OpConverter> op_registry_;
   nvinfer1::INetworkDefinition* trt_network_;
   std::list<std::vector<uint8_t>> temp_bufs_;
-  tensorflow::trt::TRTWeightStore* weight_store_;
+  tensorflow::tensorrt::TRTWeightStore* weight_store_;
   bool fp16_;
   void register_op_converters();
   std::vector<TRT_TensorOrWeights> get_inputs(
@@ -464,11 +464,11 @@ class Converter {
 
  public:
   explicit Converter(nvinfer1::INetworkDefinition* trt_network,
-                     tensorflow::trt::TRTWeightStore* ws, bool fp16)
+                     tensorflow::tensorrt::TRTWeightStore* ws, bool fp16)
       : trt_network_(trt_network), weight_store_(ws), fp16_(fp16) {
     this->register_op_converters();
   }
-  tensorflow::trt::TRTWeightStore* weight_store() { return weight_store_; }
+  tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; }
   TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
                                      nvinfer1::Dims shape) {
     TRT_ShapedWeights weights(type, nullptr, shape);
@@ -813,12 +813,12 @@ tensorflow::Status ConstantFoldBinary(
         "Binary op implicit broadcast not supported: " + node_def.op());
 
   // TODO(jie): constant fold should really fall back to TF.
-  int nb_dims = weights_input_l.shape_.nbDims;
+  int num_dims = weights_input_l.shape_.nbDims;
   nvinfer1::Dims output_shape;
-  output_shape.nbDims = nb_dims;
-  VLOG(2) << "nb_dims: " << nb_dims
+  output_shape.nbDims = num_dims;
+  VLOG(2) << "nb_dims: " << num_dims
           << ", the other: " << weights_input_r.shape_.nbDims;
-  for (int i = 0; i < nb_dims; i++) {
+  for (int i = 0; i < num_dims; i++) {
     if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
       output_shape.d[i] = weights_input_l.shape_.d[i];
     } else if (weights_input_l.shape_.d[i] == 1 ||
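
The loop above derives the constant-folded output shape one dimension at a time: equal extents pass through, and the else-if (truncated at this hunk boundary) handles an extent of 1 broadcasting against the other side. A standalone sketch of that rule; taking the larger extent in the broadcast case is the usual convention and an assumption here:

    #include <algorithm>
    #include <stdexcept>
    #include <vector>

    std::vector<int> BroadcastShape(const std::vector<int>& l,
                                    const std::vector<int>& r) {
      if (l.size() != r.size())
        throw std::runtime_error("implicit broadcast not supported");
      std::vector<int> out(l.size());
      for (size_t i = 0; i < l.size(); ++i) {
        if (l[i] == r[i] || l[i] == 1 || r[i] == 1)
          out[i] = std::max(l[i], r[i]);  // equal, or 1 stretches to the other
        else
          throw std::runtime_error("incompatible dimensions");
      }
      return out;
    }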
@@ -1950,27 +1950,6 @@ tensorflow::Status ConvertFusedBatchNorm(
       }
     }
   }
-  // if (scale_weights.type_ != tensorflow::DataType::DT_FLOAT ||
-  //     offset_weights.type_ != tensorflow::DataType::DT_FLOAT ||
-  //     mean_weights.type_ != tensorflow::DataType::DT_FLOAT ||
-  //     variance_weights.type_ != tensorflow::DataType::DT_FLOAT) {
-  //   return tensorflow::errors::Unimplemented(
-  //       "only float32 weights data type is supported, at " +
-  //       node_def.name());
-  // }
-  // for (size_t i = 0; i < nweight; ++i) {
-  //   float scale = (static_cast<float const*>(scale_weights.GetValues()))[i];
-  //   float offset = (static_cast<float
-  //   const*>(offset_weights.GetValues()))[i]; float mean = (static_cast<float
-  //   const*>(mean_weights.GetValues()))[i]; float variance =
-  //       (static_cast<float const*>(variance_weights.GetValues()))[i];
-  //   float& combined_scale_ref = const_cast<float*>(
-  //       static_cast<float const*>(combined_scale_weights.GetValues()))[i];
-  //   float& combined_offset_ref = const_cast<float*>(
-  //       static_cast<float const*>(combined_offset_weights.GetValues()))[i];
-  //   combined_scale_ref = scale / sqrtf(variance + epsilon);
-  //   combined_offset_ref = offset - mean * combined_scale_ref;
-  // }
   nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
       *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
       combined_offset_weights.GetWeightsForTRT(),
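
The commented-out block deleted above preserved the batch-norm folding math that feeds this kCHANNEL IScaleLayer: combined_scale = scale / sqrt(variance + epsilon) and combined_offset = offset - mean * combined_scale, so the layer applies y = combined_scale * x + combined_offset per channel. A standalone single-channel version of those two formulas:

    #include <cmath>
    #include <cstdio>

    int main() {
      const float scale = 1.5f, offset = 0.2f, mean = 0.8f;
      const float variance = 4.0f, epsilon = 1e-3f;
      const float combined_scale = scale / std::sqrt(variance + epsilon);
      const float combined_offset = offset - mean * combined_scale;
      // y = combined_scale * x + combined_offset, applied per channel
      std::printf("scale=%f offset=%f\n", combined_scale, combined_offset);
    }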
@@ -1996,7 +1975,7 @@ tensorflow::Status ConvertMatMul(Converter& ctx,
 
   TRT_ShapedWeights weights_ck = inputs.at(1).weights();
   TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_ck);
-  reorder_ck_to_kc(weights_ck, &weights);
+  ReorderCKtoKC(weights_ck, &weights);
   TRT_ShapedWeights biases(weights.type_);
 
   int noutput = weights.shape_.d[0];
@@ -2022,7 +2001,6 @@ tensorflow::Status ConvertReshape(
   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
   auto dims = tensor->getDimensions();
   // restore implicit batch dimension
-  int nbDims = dims.nbDims + 1;
 
   TRT_ShapedWeights shape = inputs.at(1).weights();
 
@@ -2171,32 +2149,32 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
   for (auto& i : input_names) {
     VLOG(1) << " " << i << " in graph " << nodeMaps.count(i);
   }
-  auto trt_rm = tensorflow::trt::TRTResourceManager::instance();
+  auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
   auto resmgr = trt_rm->getManager("TRTCalibOps");
-  tensorflow::trt::TRTCalibrationResource* calibRes = nullptr;
-  auto status = resmgr->Lookup(res_name, res_name, &calibRes);
-  if (!status.ok() || !calibRes->calibrator) {
+  tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr;
+  auto status = resmgr->Lookup(res_name, res_name, &calib_res);
+  if (!status.ok() || !calib_res->calibrator_) {
     return tensorflow::errors::FailedPrecondition(
         "You must run calibration"
         " and inference conversion in the same proces");
   }
 
-  calibRes->calibrator->setDone();
-  calibRes->thr->join();
-  delete calibRes->thr;
-  if (!calibRes->engine) {
+  calib_res->calibrator_->setDone();
+  calib_res->thr_->join();
+  delete calib_res->thr_;
+  if (!calib_res->engine_) {
     LOG(FATAL) << "Calibration failed!, engine is nullptr";
   }
   auto weight_rmgr = trt_rm->getManager("WeightStore");
-  TF_CHECK_OK(
-      weight_rmgr->Delete<tensorflow::trt::TRTWeightStore>(res_name, res_name));
-  auto engine_plan = calibRes->engine->serialize();
-  calibRes->engine->destroy();
-  calibRes->network->destroy();
-  calibRes->builder->destroy();
-  calibRes->thr = nullptr;
-  calibRes->engine = nullptr;
-  calibRes->builder = nullptr;
+  TF_CHECK_OK(weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>(
+      res_name, res_name));
+  auto engine_plan = calib_res->engine_->serialize();
+  calib_res->engine_->destroy();
+  calib_res->network_->destroy();
+  calib_res->builder_->destroy();
+  calib_res->thr_ = nullptr;
+  calib_res->engine_ = nullptr;
+  calib_res->builder_ = nullptr;
   tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
   std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
   for (const auto in_edge : c_node->in_edges()) {
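
Beyond the member renames, this block encodes a teardown protocol: signal the calibrator done so getBatch() unblocks, join the builder thread (the engine build returns only after calibration finishes), serialize the engine, then destroy the TensorRT objects in reverse creation order. Condensed into one helper, a sketch assuming TensorRT headers and a fully populated resource:

    nvinfer1::IHostMemory* FinishCalibration(
        tensorflow::tensorrt::TRTCalibrationResource* calib_res) {
      calib_res->calibrator_->setDone();  // unblocks getBatch(), ends calibration
      calib_res->thr_->join();            // builder thread finishes the engine
      delete calib_res->thr_;
      calib_res->thr_ = nullptr;
      nvinfer1::IHostMemory* plan = calib_res->engine_->serialize();
      calib_res->engine_->destroy();      // engine, then network, then builder
      calib_res->network_->destroy();
      calib_res->builder_->destroy();
      calib_res->engine_ = nullptr;
      calib_res->builder_ = nullptr;
      return plan;  // caller owns the plan; destroy() it after use
    }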
@@ -2275,23 +2253,23 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
       tensorflow::strings::StrCat(subgraph_name_scope, "my_trt_op", static_id);
   static_id++;
   VLOG(2) << "BUILDING 2";
-  auto trt_rmgr = tensorflow::trt::TRTResourceManager::instance();
+  auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance();
   auto op_rmgr = trt_rmgr->getManager("TRTCalibOps");
-  auto op_res = new tensorflow::trt::TRTCalibrationResource();
+  auto op_res = new tensorflow::tensorrt::TRTCalibrationResource();
   VLOG(1) << "SAMI Creating calibresource " << calib_op_name << " @ " << op_res;
   TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res));
-  op_res->logger = new tensorflow::tensorrt::Logger();
-  op_res->builder = nvinfer1::createInferBuilder(*(op_res->logger));
+  op_res->logger_ = new tensorflow::tensorrt::Logger();
+  op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_));
 
-  if (!op_res->builder) {
+  if (!op_res->builder_) {
     return tensorflow::errors::Internal(
         "failed to create TensorRT builder object");
   }
 
   VLOG(2) << "BUILDING 3";
 
-  op_res->network = op_res->builder->createNetwork();
-  if (!op_res->network) {
+  op_res->network_ = op_res->builder_->createNetwork();
+  if (!op_res->network_) {
     return tensorflow::errors::Internal(
         "failed to create TensorRT network object");
   }
@@ -2300,9 +2278,9 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
 
   // Build the network
   auto weight_rmgr = trt_rmgr->getManager("WeightStore");
-  auto ws = new tensorflow::trt::TRTWeightStore();
+  auto ws = new tensorflow::tensorrt::TRTWeightStore();
   TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws));
-  Converter converter(op_res->network, ws, s.precision_mode == 1);
+  Converter converter(op_res->network_, ws, s.precision_mode == 1);
 
   VLOG(2) << "BUILDING 5";
   std::vector<string> input_names;
@@ -2420,8 +2398,8 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
   VLOG(2) << "finished output";
 
   // Build the engine
-  op_res->builder->setMaxBatchSize(s.max_batch_size);
-  op_res->builder->setMaxWorkspaceSize(s.max_workspace_size_bytes);
+  op_res->builder_->setMaxBatchSize(s.max_batch_size);
+  op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes);
 
   // Build the TRT op
   // TODO(sami,ben,jie): proper naming!
@@ -2505,9 +2483,9 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
   string engine_name =
       tensorflow::strings::StrCat(subgraph_name_scope, "my_trt_op");
   engine_name = tensorflow::strings::StrCat(engine_name, static_id++);
-  auto trt_rmgr = tensorflow::trt::TRTResourceManager::instance();
+  auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance();
   auto weight_rmgr = trt_rmgr->getManager("WeightStore");
-  auto ws = new tensorflow::trt::TRTWeightStore();
+  auto ws = new tensorflow::tensorrt::TRTWeightStore();
   TF_CHECK_OK(weight_rmgr->Create(engine_name, engine_name, ws));
 
   // Build the network
@@ -2680,8 +2658,8 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
     engine_plan_string =
         string(engine_plan_data, engine_plan_data + engine_plan->size());
   }
-  weight_rmgr->Delete<tensorflow::trt::TRTWeightStore>(engine_name,
-                                                       engine_name);
+  weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>(engine_name,
+                                                            engine_name);
   LOG(INFO) << "finished engine " << engine_name;
 
   // Build the TRT op
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
index 1dcb87e..b78ff18 100644
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/stream_executor.h"
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
@@ -113,7 +114,13 @@ void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) {
     ctx->set_output(i, t);
   }
   VLOG(2) << "Filled map for sending";
-  calib_res->calibrator_->setBatch(input_data);
+  // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
+  const cudaStream_t* stream = CHECK_NOTNULL(
+      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
+                                                ->stream()
+                                                ->implementation()
+                                                ->CudaStreamMemberHack()));
+  calib_res->calibrator_->setBatch(input_data, *stream);
   VLOG(2) << "Passed calibration data";
   // TODO(aaroey): make sure we wait for the completion of calibration on the
   // last batch in future PR.
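
The hack above exists because the raw cudaStream_t is otherwise only reachable from *.cu.cc translation units; pulling it out of the StreamExecutor stream lets the new setBatch(data, stream) overload enqueue its device-to-device copies on the op's own compute stream instead of the default stream. The extraction pattern in isolation, assuming TensorFlow's stream_executor headers:

    #include "cuda_runtime_api.h"
    #include "tensorflow/core/framework/op_kernel.h"
    #include "tensorflow/core/platform/stream_executor.h"

    const cudaStream_t* GetCudaStream(tensorflow::OpKernelContext* ctx) {
      // CudaStreamMemberHack() exposes the underlying cudaStream_t; the
      // reinterpret_cast mirrors the kernel code above.
      return reinterpret_cast<const cudaStream_t*>(
          ctx->op_device_context()->stream()->implementation()
              ->CudaStreamMemberHack());
    }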
diff --git a/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.cc b/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.cc
deleted file mode 100644
index 57677a3..0000000
+++ /dev/null
@@ -1,174 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h"
-
-#include <atomic>
-#include <chrono>
-#include <unordered_map>
-#include "cuda_runtime_api.h"
-
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-namespace trt {
-// set the batch size before constructing the thread to execute engine
-int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }
-
-TRTInt8Calibrator::TRTInt8Calibrator(
-    const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
-    int batch_size, string engineName)
-    : batch_size_(batch_size),
-      done_(false),
-      dev_buffers_(dev_buffers),
-      calib_running_(false),
-      engine_name_(engineName) {
-  cudaPointerAttributes pa;
-  int devid = -1;
-  cudaGetDevice(&devid);
-  VLOG(0) << "Constructing calibrator with batch size " << batch_size
-          << " on device" << devid;
-  for (auto b : dev_buffers_) {
-    if (cudaPointerGetAttributes(&pa, b.second.first) == cudaSuccess) {
-      VLOG(1) << "CALIBRATOR " << engine_name_ << " Device buffer name "
-              << b.first << " size" << b.second.second << " @ "
-              << b.second.first << " onDevice "
-              << ((pa.memoryType == cudaMemoryTypeHost) ? "HOST" : "DEVICE");
-    } else {
-      VLOG(1) << "CALIBRATOR " << engine_name_ << " Device buffer name "
-              << b.first << " size" << b.second.second << " @ "
-              << b.second.first;
-    }
-  }
-}
-
-bool TRTInt8Calibrator::setBatch(
-    const std::unordered_map<string, void*>& data) {
-  VLOG(1) << "SAMI SAMI " << engine_name_ << " Waiting to set new batch";
-  if (done_) return false;
-  while (calib_running_.load(
-      std::memory_order_acquire)) {  // wait while calibration is running
-    tensorflow::mutex_lock l(cond_mtx_);
-    cond_.wait_for(l, std::chrono::milliseconds(50));
-    if (done_) return false;
-  }
-  VLOG(1) << "Set Batch Waiting finished";
-  for (const auto it : data) {
-    auto devptr = dev_buffers_.find(it.first);
-    if (devptr == dev_buffers_.end()) {
-      LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
-                 << "' does not match with the buffer names";
-    }
-    const auto& d = devptr->second;
-    if (VLOG_IS_ON(1)) {
-      cudaPointerAttributes pa;
-      VLOG(1) << "cuda memcopy " << engine_name_ << " buff name= " << it.first
-              << " dst= " << d.first << " size= " << d.second
-              << " inp= " << it.second;
-      if (cudaPointerGetAttributes(&pa, it.second) == cudaSuccess) {
-        VLOG(1) << "CALIBRATOR " << engine_name_ << " Device buffer name "
-                << it.first << " size" << d.second << " @ " << d.first
-                << " onDevice "
-                << ((pa.memoryType == cudaMemoryTypeHost) ? "HOST" : "DEVICE");
-      }
-    }
-
-    auto status =
-        cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice);
-    if (status != cudaSuccess) {
-      LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first
-                 << "' failed with " << status;
-    }
-    if (VLOG_IS_ON(1)) {
-      float f[4];
-      f[0] = 3.;
-      f[1] = 0.14159;
-      f[2] = 3.;
-      f[3] = 0.14159;
-      status =
-          cudaMemcpy(f, d.first, sizeof(float) * 2, cudaMemcpyDeviceToHost);
-      if (status != cudaSuccess) {
-        VLOG(1) << "Memcopy failed!";
-      }
-      status = cudaMemcpy(f + 2, it.second, sizeof(float) * 2,
-                          cudaMemcpyDeviceToHost);
-      int devid = -1;
-      cudaGetDevice(&devid);
-      VLOG(1) << "SAMI ORDER SETTING " << engine_name_
-              << " Data in perm storage [0]=" << f[0] << " [1]=" << f[1]
-              << " current device=" << devid << " data in tensor=" << f[2]
-              << " " << f[3];
-    }
-  }
-  calib_running_.store(true, std::memory_order_release);  // release builder
-  cond_.notify_all();
-  return true;
-}
-
-bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
-                                 int nbBindings) {
-  calib_running_.store(false, std::memory_order_release);  // wait for new batch
-  VLOG(1) << "SAMI SAMI Calibrator is waiting for new batch";
-  cond_.notify_all();
-  while (!calib_running_.load(
-      std::memory_order_acquire)) {  // wait until new batch arrives
-    tensorflow::mutex_lock l(cond_mtx_);
-    cond_.wait_for(l, std::chrono::milliseconds(50));
-    if (done_) return false;
-  }
-  if (done_) {
-    return false;
-  }
-
-  for (int i = 0; i < nbBindings; i++) {
-    auto it = dev_buffers_.find(names[i]);
-    if (it == dev_buffers_.end()) {
-      LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
-                 << names[i] << "' at position " << i;
-    }
-
-    bindings[i] = it->second.first;
-    if (VLOG_IS_ON(1)) {
-      VLOG(1) << "Setting buffer " << i << " named=" << names[i] << " @ "
-              << it->second.first;
-      float f[2];
-      f[0] = 3.;
-      f[1] = 0.14159;
-      auto status =
-          cudaMemcpy(f, bindings[i], sizeof(float) * 2, cudaMemcpyDeviceToHost);
-      if (status != cudaSuccess) {
-        VLOG(0) << "Memcopy failed!";
-      }
-      int devid = -1;
-      cudaGetDevice(&devid);
-      VLOG(1) << "ORDER GETTING, " << engine_name_
-              << " Data in perm storage [0]=" << f[0] << " [1]=" << f[1]
-              << " on device=" << devid
-              << " Succeed=" << (status == cudaSuccess ? "True" : "False");
-    }
-  }
-  return true;
-}
-const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
-  return nullptr;
-}
-void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
-                                              std::size_t length) {}
-TRTInt8Calibrator::~TRTInt8Calibrator() {
-  VLOG(1) << "Destroying calibrator for " << engine_name_;
-}
-
-}  // namespace trt
-}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h b/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h
deleted file mode 100644
index 62c2bf9..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTINT8CALIBRATOR_H_
-#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTINT8CALIBRATOR_H_
-
-#include <atomic>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorrt/include/NvInfer.h"
-namespace tensorflow {
-namespace trt {
-
-struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
- public:
-  TRTInt8Calibrator(
-      const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
-      int batch_size, string engineName);
-  int getBatchSize() const;
-  bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
-  bool setBatch(const std::unordered_map<string, void*>& data);
-  void setDone() { done_ = true; }
-  const void* readCalibrationCache(std::size_t& length) override;
-  void writeCalibrationCache(const void* ptr, std::size_t length) override;
-  ~TRTInt8Calibrator();
-
- private:
-  int batch_size_;
-  tensorflow::mutex cond_mtx_;
-  tensorflow::condition_variable cond_;
-  bool done_;
-  const std::unordered_map<string, std::pair<void*, size_t>> dev_buffers_;
-  std::atomic_bool calib_running_;
-  string engine_name_;
-};
-}  // namespace trt
-}  // namespace tensorflow
-#endif  // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTINT8CALIBRATOR_H_
diff --git a/tensorflow/contrib/tensorrt/resources/TRTResourceManager.cc b/tensorflow/contrib/tensorrt/resources/TRTResourceManager.cc
deleted file mode 100644
index 3eea23b..0000000
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/tensorrt/resources/TRTResourceManager.h"
-#include "tensorflow/core/platform/default/logging.h"
-
-std::shared_ptr<tensorflow::ResourceMgr>
-tensorflow::trt::TRTResourceManager::getManager(const std::string& mgr_name) {
-  // mutex is held for lookup only. Most instantiations where mutex will be held
-  // longer will be during op creation and should be ok.
-  tensorflow::mutex_lock lock(map_mutex_);
-  auto s = managers_.find(mgr_name);
-  if (s == managers_.end()) {
-    auto it = managers_.emplace(
-        mgr_name, std::make_shared<tensorflow::ResourceMgr>(mgr_name));
-    VLOG(0) << "Returning a new manager " << mgr_name;
-    return it.first->second;
-  }
-  VLOG(1) << "Returning old manager " << mgr_name;
-  return s->second;
-}
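
The manager itself survives this commit; only the CamelCase copy is deleted in favor of resources/trt_resource_manager.{h,cc} and the tensorflow::tensorrt namespace. Every call site in convert_nodes.cc follows the same lookup-or-create shape against a named per-purpose ResourceMgr. A hedged sketch in the post-rename spelling; the resource name is illustrative:

    #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
    #include "tensorflow/contrib/tensorrt/resources/trt_resources.h"

    tensorflow::tensorrt::TRTCalibrationResource* LookupOrCreateCalibResource(
        const std::string& name) {  // e.g. "my_trt_op0" (hypothetical)
      auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
      auto mgr = trt_rm->getManager("TRTCalibOps");
      tensorflow::tensorrt::TRTCalibrationResource* res = nullptr;
      if (!mgr->Lookup(name, name, &res).ok()) {  // (container, name, out)
        res = new tensorflow::tensorrt::TRTCalibrationResource();
        TF_CHECK_OK(mgr->Create(name, name, res));  // mgr holds a reference
      }
      return res;
    }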
diff --git a/tensorflow/contrib/tensorrt/resources/TRTResourceManager.h b/tensorflow/contrib/tensorrt/resources/TRTResourceManager.h
deleted file mode 100644
index d482c7d..0000000
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCEMANAGER_H_
-
-#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCE_TRTRESOURCEMANAGER_H_
-#include <memory>
-
-#include <string>
-#include <unordered_map>
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace tensorflow {
-namespace trt {
-class TRTResourceManager {
-  TRTResourceManager() = default;
-
- public:
-  static std::shared_ptr<TRTResourceManager> instance() {
-    static std::shared_ptr<TRTResourceManager> instance_(
-        new TRTResourceManager);
-    return instance_;
-  }
-  // returns a manager for given op, if it doesn't exists it creates one
-  std::shared_ptr<tensorflow::ResourceMgr> getManager(const string& op_name);
-
- private:
-  std::unordered_map<string, std::shared_ptr<tensorflow::ResourceMgr>>
-      managers_;
-  tensorflow::mutex map_mutex_;
-};
-}  // namespace trt
-}  // namespace tensorflow
-#endif  // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCEMANAGER_H_
diff --git a/tensorflow/contrib/tensorrt/resources/TRTResources.h b/tensorflow/contrib/tensorrt/resources/TRTResources.h
deleted file mode 100644
index 20ccf0f..0000000
+++ /dev/null
@@ -1,91 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_
-
-#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_
-
-#include <list>
-#include <sstream>
-#include <string>
-#include <thread>
-#include <vector>
-#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
-#include "tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorrt/include/NvInfer.h"
-
-namespace tensorflow {
-namespace trt {
-struct TRTCalibrationResource : public tensorflow::ResourceBase {
-  TRTCalibrationResource()
-      : calibrator(nullptr),
-        builder(nullptr),
-        network(nullptr),
-        engine(nullptr),
-        logger(nullptr),
-        thr(nullptr) {}
-  string DebugString() override {
-    std::stringstream oss;
-#define VALID_OR_NULL(ptr) \
-  (!ptr ? "nullptr" : std::hex << (void)ptr << std::dec << std::endl)
-    oss << " Calibrator = " << std::hex << calibrator << std::dec << std::endl
-        << " Builder    = " << std::hex << builder << std::dec << std::endl
-        << " Network    = " << std::hex << network << std::dec << std::endl
-        << " Engine     = " << std::hex << engine << std::dec << std::endl
-        << " Logger     = " << std::hex << logger << std::dec << std::endl
-        << " Thread     = " << std::hex << thr << std::dec << std::endl;
-    return oss.str();
-#undef VALID_OR_NULL
-  }
-  ~TRTCalibrationResource() {
-    VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
-  }
-  TRTInt8Calibrator* calibrator;
-  nvinfer1::IBuilder* builder;
-  nvinfer1::INetworkDefinition* network;
-  nvinfer1::ICudaEngine* engine;
-  tensorflow::tensorrt::Logger* logger;
-  // TODO(sami): Use threadpool threads!
-  std::thread* thr;
-};
-
-struct TRTWeightStore : public tensorflow::ResourceBase {
-  TRTWeightStore() {}
-  std::list<std::vector<uint8_t>> store_;
-  string DebugString() override {
-    std::stringstream oss;
-    size_t lenBytes = 0;
-    for (const auto& v : store_) {
-      lenBytes += v.size() * sizeof(uint8_t);
-    }
-    oss << " Number of entries     = " << store_.size() << std::endl
-        << " Total number of bytes = "
-        << store_.size() * sizeof(std::vector<uint8_t>) + lenBytes << std::endl;
-    return oss.str();
-  }
-  virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); }
-};
-
-struct TRTEngineResource : public tensorflow::ResourceBase {
-  TRTEngineResource() : runtime(nullptr), ctx(nullptr){};
-  string DebugString() override { return string(""); }
-  nvinfer1::IRuntime* runtime;
-  nvinfer1::IExecutionContext* ctx;
-};
-
-}  // namespace trt
-}  // namespace tensorflow
-#endif  // TENSORFLOW_CONTRIB_TENSORRT_RESOURCEMGR_TRTRESOURCES_H_
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
index 3d5cc76..f157720 100644
@@ -38,22 +38,24 @@ TRTInt8Calibrator::TRTInt8Calibrator(
       done_(false),
       dev_buffers_(dev_buffers),
       calib_running_(false),
+      batch_is_set_(false),
       engine_name_(engine_name) {}
 
-bool TRTInt8Calibrator::setBatch(
-    const std::unordered_map<string, void*>& data) {
+bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
+                                 const cudaStream_t stream) {
   // TODO(aaroey): make sure that in future PR:
   // 1. the mutex_lock is outside of the loop
   // 2. wait() is used instead of wait_for()
   // 3. done_ is to be protected by the mutex
   // 4. the first batch is not missed
   if (done_) return false;
-  while (calib_running_.load(
-      std::memory_order_acquire)) {  // wait while calibration is running
-    tensorflow::mutex_lock l(cond_mtx_);
-    cond_.wait_for(l, std::chrono::milliseconds(50));
+  tensorflow::mutex_lock l(cond_mtx_);
+  while ((calib_running_ || batch_is_set_) &&
+         !done_) {  // wait while calibration is running
+    cond_.wait(l);
     if (done_) return false;
   }
+  CHECK(!calib_running_ && !batch_is_set_);
   VLOG(1) << "Set Batch Waiting finished";
   for (const auto it : data) {
     auto devptr = dev_buffers_.find(it.first);
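
The rewrite above retires the old 50 ms polling loop (re-acquiring the mutex on every iteration around an atomic flag) in favor of one mutex_lock held across a plain condition-variable wait, with both flags demoted to ordinary bools guarded by cond_mtx_. The waiting discipline reduced to std primitives, a sketch of the pattern rather than TF code:

    #include <condition_variable>
    #include <mutex>

    std::mutex cond_mtx_;  // stands in for tensorflow::mutex
    std::condition_variable cond_;
    bool calib_running_ = false, batch_is_set_ = false, done_ = false;

    // Mirrors the top of setBatch(): sleep until the previous batch has been
    // fully consumed, or bail out if setDone() fired meanwhile.
    bool WaitForConsumer() {
      std::unique_lock<std::mutex> l(cond_mtx_);
      while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(l);
      return !done_;
    }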
@@ -65,32 +67,32 @@ bool TRTInt8Calibrator::setBatch(
 
     // TODO(aaroey): we should not use sync copy on default stream. Make sure
     // stream->ThenMemcpy() is used in future PRs.
-    auto status =
-        cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice);
+    auto status = cudaMemcpyAsync(d.first, it.second, d.second,
+                                  cudaMemcpyDeviceToDevice, stream);
     if (status != cudaSuccess) {
       LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first
                  << "' failed with " << status;
     }
   }
-  calib_running_.store(true, std::memory_order_release);  // release builder
+  cudaStreamSynchronize(
+      stream);  // we have to wait for the stream before returning!
+  batch_is_set_ = true;
   cond_.notify_all();
   return true;
 }
 
 bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
                                  int num_bindings) {
-  calib_running_.store(false, std::memory_order_release);  // wait for new batch
+  tensorflow::mutex_lock l(cond_mtx_);
+  calib_running_ = false;
   cond_.notify_all();
-  while (!calib_running_.load(
-      std::memory_order_acquire)) {  // wait until new batch arrives
-    tensorflow::mutex_lock l(cond_mtx_);
-    cond_.wait_for(l, std::chrono::milliseconds(50));
-    if (done_) return false;
+  while ((!batch_is_set_ && !done_)) {  // wait until new batch arrives
+    cond_.wait(l);
   }
   if (done_) {
     return false;
   }
-
+  CHECK(!calib_running_ && batch_is_set_);
   for (int i = 0; i < num_bindings; i++) {
     auto it = dev_buffers_.find(names[i]);
     if (it == dev_buffers_.end()) {
@@ -100,13 +102,19 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
 
     bindings[i] = it->second.first;
   }
+  batch_is_set_ = false;
+  calib_running_ = true;
   return true;
 }
 
 const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
   return nullptr;
 }
-
+void TRTInt8Calibrator::setDone() {
+  tensorflow::mutex_lock l(cond_mtx_);
+  done_ = true;
+  cond_.notify_all();
+}
 void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
                                               std::size_t length) {}
 TRTInt8Calibrator::~TRTInt8Calibrator() {
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
index 8830f7e..cab9c7e 100644
@@ -24,6 +24,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
+#include "cuda_runtime_api.h"
 #include "tensorrt/include/NvInfer.h"
 namespace tensorflow {
 namespace tensorrt {
@@ -39,8 +40,8 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
   int getBatchSize() const override;
   bool getBatch(void* bindings[], const char* names[],
                 int num_bindings) override;
-  bool setBatch(const std::unordered_map<string, void*>& data);
-  void setDone() { done_ = true; }
+  bool setBatch(const std::unordered_map<string, void*>& data, const cudaStream_t stream);
+  void setDone();
   const void* readCalibrationCache(std::size_t& length) override;
   void writeCalibrationCache(const void* ptr, std::size_t length) override;
   ~TRTInt8Calibrator();
@@ -55,7 +56,8 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
   const std::unordered_map<string, std::pair<void*, size_t>>
       dev_buffers_;  // map to keep tensorrt input buffers and sizes keyed with
                      // buffer names
-  std::atomic_bool calib_running_;
+  bool calib_running_;
+  bool batch_is_set_;
   string engine_name_;
 };
 }  // namespace tensorrt
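
Taken together, setBatch()/getBatch() form a two-thread rendezvous: the TensorFlow op thread feeds batches while TensorRT's calibration loop, running inside the builder on the resource's thread, drains them through getBatch(). A standalone reduction of the protocol with std primitives and an int standing in for the device buffers; as the TODOs above note, a batch still in flight when setDone() fires is dropped:

    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>

    struct Calib {
      std::mutex m;
      std::condition_variable cv;
      bool running = false, batch_set = false, done = false;
      int slot = 0;  // stands in for the device buffers

      bool setBatch(int v) {  // producer: the TF op thread
        std::unique_lock<std::mutex> l(m);
        while ((running || batch_set) && !done) cv.wait(l);
        if (done) return false;
        slot = v;
        batch_set = true;
        cv.notify_all();
        return true;
      }

      bool getBatch(int* out) {  // consumer: TensorRT's calibration loop
        std::unique_lock<std::mutex> l(m);
        running = false;  // previous batch fully consumed
        cv.notify_all();
        while (!batch_set && !done) cv.wait(l);
        if (done) return false;
        *out = slot;
        batch_set = false;
        running = true;
        return true;
      }

      void setDone() {
        std::lock_guard<std::mutex> l(m);
        done = true;
        cv.notify_all();
      }
    };

    int main() {
      Calib c;
      std::thread trt([&c] {
        int b;
        while (c.getBatch(&b)) std::printf("calibrating on batch %d\n", b);
      });
      for (int i = 0; i < 3; ++i) c.setBatch(i);
      c.setDone();
      trt.join();
    }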