From b4d59a8f82437b6da897c2c5fe773db7127efdd8 Mon Sep 17 00:00:00 2001 From: Jie Date: Thu, 8 Feb 2018 13:12:31 -0800 Subject: [PATCH] [cleanup] remove TRT_ShapedWeights owned_values_ fallback to converter temporary memory allocation to avoid redundant memcpy --- .../contrib/tensorrt/convert/convert_nodes.cc | 28 ++++------------------ 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 5efef61..5c22c62 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -114,28 +114,18 @@ static std::vector> CreateSamePadding( class TRT_ShapedWeights { public: TRT_ShapedWeights(tensorflow::DataType type, const void* values, - nvinfer1::Dims shape, - const std::vector* owned_values = nullptr) - : shape_(shape), - type_(type), - values_(values), - owned_values_(owned_values ? *owned_values : std::vector({})), - empty_weight_flag_(false) { + nvinfer1::Dims shape) + : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) { // Note: this->shape.type[] is not used } explicit TRT_ShapedWeights(tensorflow::DataType type) - : shape_(), - type_(type), - values_(nullptr), - owned_values_(), - empty_weight_flag_(true) {} + : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {} TRT_ShapedWeights(const TRT_ShapedWeights& rhs) : shape_(rhs.shape_), type_(rhs.type_), values_(rhs.values_), - owned_values_(rhs.owned_values_), empty_weight_flag_(rhs.empty_weight_flag_) {} int64_t count() const { @@ -153,16 +143,9 @@ class TRT_ShapedWeights { return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)}; } - const void* GetValues() const { - if (values_) return values_; - if (owned_values_.size()) return owned_values_.data(); - return nullptr; - } + const void* GetValues() const { return values_; } - void SetValues(const void* values) { - values_ = values; - owned_values_.clear(); - } + void SetValues(const void* values) { values_ = values; } size_t size_bytes() const { int type_size = tensorflow::DataTypeSize(this->type_); @@ -177,7 +160,6 @@ class TRT_ShapedWeights { private: const void* values_; - std::vector owned_values_; bool empty_weight_flag_; }; -- 2.7.4