From 19fe2b9db474819c93b19695e552f0e549c3eceb Mon Sep 17 00:00:00 2001
From: Rui Zhu
Date: Sun, 31 Mar 2019 17:37:17 -0700
Subject: [PATCH] Adding quantized tensor shape/type info support for
 caffe2=>glow in caffe2 side (#18621)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18621

This diff added caffe2 support for onnxifi quantization.

Reviewed By: yinghai

Differential Revision: D14648767

fbshipit-source-id: 4ddb492cacbba6142305866e6dbb875880acaea3
---
 caffe2/core/blob.h                            |  5 ++
 caffe2/core/operator.cc                       |  8 +++
 caffe2/core/operator.h                        |  4 ++
 caffe2/core/tensor.cc                         | 18 +++++-
 caffe2/operators/onnxifi_op.cc                | 86 +++++++++++++++++++++------
 caffe2/operators/quantized/int8_quantize_op.h |  1 +
 caffe2/opt/backend_transformer_base.cc        | 27 ++++++++-
 caffe2/opt/backend_transformer_base.h         |  5 ++
 caffe2/opt/bound_shape_inferencer.cc          | 66 ++++++++++++++++----
 caffe2/opt/bound_shape_inferencer.h           |  3 +-
 caffe2/opt/onnxifi_transformer.cc             | 49 ++++++++++++---
 caffe2/opt/shape_info.cc                      |  6 ++
 caffe2/opt/shape_info.h                       | 26 +++++++-
 13 files changed, 261 insertions(+), 43 deletions(-)

diff --git a/caffe2/core/blob.h b/caffe2/core/blob.h
index 3ab8e77..cb571bb 100644
--- a/caffe2/core/blob.h
+++ b/caffe2/core/blob.h
@@ -12,9 +12,14 @@
 #include
 #include "caffe2/core/logging.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/core/tensor_int8.h"
 
 namespace caffe2 {
 
+inline bool BlobIsInt8TensorCPUType(const Blob& blob) {
+  return blob.meta().Match<int8::Int8TensorCPU>();
+}
+
 inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
   bool is_match = blob.meta().Match<Tensor>();
   if (!is_match) {
diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc
index d7b66bc..21d2329 100644
--- a/caffe2/core/operator.cc
+++ b/caffe2/core/operator.cc
@@ -7,6 +7,7 @@
 #include "caffe2/core/net.h"
 #include "caffe2/core/operator_gradient.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/core/tensor_int8.h"
 #include "caffe2/core/types.h"
 #include "caffe2/core/workspace.h"
 
@@ -576,6 +577,13 @@ TensorShapes InferBlobShapesAndTypes(
   return tps;
 }
 
+void LoadInt8TensorInfoOfBlob(float* scale, float* offset, const Blob* b) {
+  const int8::Int8TensorCPU* i8tc =
+      static_cast<const int8::Int8TensorCPU*>(b->GetRaw());
+  *scale = i8tc->scale;
+  *offset = i8tc->zero_point;
+}
+
 TensorShape GetTensorShapeOfBlob(const Blob* b) {
   TypeCall type_fun = GetTypeCallFunction(b->meta().id());
   TensorInfoCall tensor_info_fun = GetTensorInfoFunction(b->meta().id());
diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h
index 17a0392..7930a02 100644
--- a/caffe2/core/operator.h
+++ b/caffe2/core/operator.h
@@ -19,6 +19,7 @@
 #include "caffe2/core/operator_gradient.h"
 #include "caffe2/core/operator_schema.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/core/tensor_int8.h"
 #include "caffe2/core/types.h"
 #include "caffe2/core/workspace.h"
 #include "caffe2/proto/caffe2_pb.h"
@@ -1336,6 +1337,9 @@ CAFFE2_API void SetOpEnginePref(
     const std::string& op_type,
     const CaffeMap<DeviceType, EnginePrefType>& op_pref);
 
+CAFFE2_API void
+LoadInt8TensorInfoOfBlob(float* scale, float* offset, const Blob* b);
+
 CAFFE2_API TensorShape GetTensorShapeOfBlob(const Blob* b);
 
 CAFFE2_API TensorShapes InferBlobShapesAndTypes(
diff --git a/caffe2/core/tensor.cc b/caffe2/core/tensor.cc
index afcae81..e4278b6 100644
--- a/caffe2/core/tensor.cc
+++ b/caffe2/core/tensor.cc
@@ -1,4 +1,5 @@
 #include "caffe2/core/tensor.h"
+#include "caffe2/core/tensor_int8.h"
 
 #include "caffe2/core/blob_stats.h"
 
@@ -56,10 +57,16 @@ TypeMeta GetTensorType(const void* c) {
   const Tensor* tc = static_cast<const Tensor*>(c);
   return tc->dtype();
 }
+TypeMeta GetInt8TensorType(const void* c) {
+  const int8::Int8TensorCPU* i8tc = static_cast<const int8::Int8TensorCPU*>(c);
+  return (i8tc->t).dtype();
+}
 
 // TODO(jerryzh): Remove
 static CaffeMap<TypeIdentifier, TypeCall> type_call_registry_{
-    {TypeMeta::Id<Tensor>(), GetTensorType}};
+    {TypeMeta::Id<Tensor>(), GetTensorType},
+    {TypeMeta::Id<int8::Int8TensorCPU>(), GetInt8TensorType},
+};
 
 TypeCall GetTypeCallFunction(TypeIdentifier id) {
   auto f = type_call_registry_.find(id);
@@ -89,9 +96,16 @@ vector<int64_t> GetTensorInfo(
   return tc->sizes().vec();
 }
 
+vector<int64_t>
+GetInt8TensorInfo(const void* c, size_t* capacity, DeviceOption* device) {
+  const int8::Int8TensorCPU* i8tc = static_cast<const int8::Int8TensorCPU*>(c);
+  return GetTensorInfo(&(i8tc->t), capacity, device);
+}
 
 // since we only have one tensor, probably need to remove this at some point?
 static CaffeMap<TypeIdentifier, TensorInfoCall> tensor_info_call_registry_{
-    {TypeMeta::Id<Tensor>(), GetTensorInfo}};
+    {TypeMeta::Id<Tensor>(), GetTensorInfo},
+    {TypeMeta::Id<int8::Int8TensorCPU>(), GetInt8TensorInfo},
+};
 
 // TODO: Remove this code in a separate diff, since we only have one
 // GetTensorInfo function now
diff --git a/caffe2/operators/onnxifi_op.cc b/caffe2/operators/onnxifi_op.cc
index e8b582f..d2b3942 100644
--- a/caffe2/operators/onnxifi_op.cc
+++ b/caffe2/operators/onnxifi_op.cc
@@ -34,7 +34,26 @@ void SetInputTensorDescriptorTypeAndBuffer(
   }
 }
 
-TypeMeta OnnixfiTypeToDataType(uint64_t onnxifi_type) {
+void SetInputTensorDescriptorTypeAndBuffer(
+    const int8::Int8TensorCPU& cpu_int8tensor,
+    onnxTensorDescriptorV1* desc) {
+  const Tensor& cpu_tensor = cpu_int8tensor.t;
+  if (cpu_tensor.template IsType<uint8_t>()) {
+    desc->dataType = ONNXIFI_DATATYPE_UINT8;
+    desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<uint8_t>());
+  } else if (cpu_tensor.template IsType<int32_t>()) {
+    desc->dataType = ONNXIFI_DATATYPE_INT32;
+    desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<int32_t>());
+  } else {
+    CAFFE_THROW(
+        "Unsupported Int8Tensor type in ONNXIFI: ", cpu_tensor.dtype().name());
+  }
+  desc->is_quantized = true;
+  desc->scale = cpu_int8tensor.scale;
+  desc->bias = cpu_int8tensor.zero_point;
+}
+
+TypeMeta OnnxifiTypeToDataType(uint64_t onnxifi_type) {
   static std::map<uint64_t, TypeMeta> data_type_map {
     {ONNXIFI_DATATYPE_FLOAT32, TypeMeta::Make<float>()},
     {ONNXIFI_DATATYPE_INT32, TypeMeta::Make<int>()},
@@ -45,7 +64,10 @@ TypeMeta OnnixfiTypeToDataType(uint64_t onnxifi_type) {
     {ONNXIFI_DATATYPE_UINT16, TypeMeta::Make<uint16_t>()},
   };
   const auto it = data_type_map.find(onnxifi_type);
-  CAFFE_ENFORCE(it != data_type_map.end(), "Unsupported ONXNIFI data type: ", onnxifi_type);
+  CAFFE_ENFORCE(
+      it != data_type_map.end(),
+      "Unsupported ONNXIFI data type: ",
+      onnxifi_type);
   return it->second;
 }
 
@@ -54,9 +76,23 @@ void SetOutputTensorDescriptorTypeAndBuffer(
     Tensor* cpu_tensor,
     onnxTensorDescriptorV1* desc) {
   desc->dataType = onnxifi_type;
-  desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor->raw_mutable_data(OnnixfiTypeToDataType(onnxifi_type)));
+  desc->buffer = reinterpret_cast<onnxPointer>(
+      cpu_tensor->raw_mutable_data(OnnxifiTypeToDataType(onnxifi_type)));
 }
 
+void SetOutputTensorDescriptorTypeAndBuffer(
+    uint64_t onnxifi_type,
+    int8::Int8TensorCPU* cpu_int8tensor,
+    onnxTensorDescriptorV1* desc) {
+  desc->dataType = onnxifi_type;
+  Tensor* cpu_tensor = &(cpu_int8tensor->t);
+
+  desc->buffer = reinterpret_cast<onnxPointer>(
+      cpu_tensor->raw_mutable_data(OnnxifiTypeToDataType(onnxifi_type)));
+  desc->is_quantized = true;
+  desc->scale = cpu_int8tensor->scale;
+  desc->bias = cpu_int8tensor->zero_point;
+}
 void BlobToTensorDescriptor(
     const std::string& name,
     Workspace* ws,
@@ -64,26 +100,39 @@ void BlobToTensorDescriptor(
     std::vector<std::vector<uint64_t>>* shapes) {
   const Blob* blob = ws->GetBlob(name);
   CAFFE_ENFORCE(blob, "Blob ", name, " doesn't exist");
exist"); - + const bool is_int8tensor = + blob->meta().id() == TypeMeta::Id(); // Memory type - // We only allow weights to be CPU tensor for now + // We only allow weights to be CPU tensor or int8tensor for now CAFFE_ENFORCE( - BlobIsTensorType(*blob, CPU), + (BlobIsTensorType(*blob, CPU) || BlobIsInt8TensorCPUType(*blob)), "Initialization blob ", name, - " needs to be TensorCPU"); + " needs to be TensorCPU or Int8TensorCPU"); desc->tag = ONNXIFI_TAG_TENSOR_DESCRIPTOR_V1; desc->memoryType = ONNXIFI_MEMORY_TYPE_CPU; - // Data type - const auto& cpu_tensor = blob->template Get(); - SetInputTensorDescriptorTypeAndBuffer(cpu_tensor, desc); - - // Set dims - const auto shape = cpu_tensor.sizes(); - desc->dimensions = shape.size(); - shapes->emplace_back(shape.cbegin(), shape.cend()); - desc->shape = shapes->back().data(); + if (is_int8tensor) { + // Data type + const auto& cpu_int8tensor = blob->template Get(); + const auto& cpu_tensor = cpu_int8tensor.t; + SetInputTensorDescriptorTypeAndBuffer(cpu_int8tensor, desc); + // Set dims + const auto shape = cpu_tensor.sizes(); + desc->dimensions = shape.size(); + shapes->emplace_back(shape.cbegin(), shape.cend()); + desc->shape = shapes->back().data(); + } else { + // Data type + const auto& cpu_tensor = blob->template Get(); + SetInputTensorDescriptorTypeAndBuffer(cpu_tensor, desc); + // Set dims + const auto shape = cpu_tensor.sizes(); + desc->dimensions = shape.size(); + shapes->emplace_back(shape.cbegin(), shape.cend()); + desc->shape = shapes->back().data(); + desc->is_quantized = 0; + } } } // namespace @@ -148,7 +197,10 @@ bool OnnxifiOp::RunOnDevice() { tensor_descriptor.shape = output_shapes_.back().data(); std::vector tensor_dims_int64; std::copy(tensor_dims.cbegin(), tensor_dims.cend(), std::back_inserter(tensor_dims_int64)); - auto* output_tensor = Output(i, tensor_dims_int64, at::dtype(OnnixfiTypeToDataType(type)).device(CPU)); + auto* output_tensor = Output( + i, + tensor_dims_int64, + at::dtype(OnnxifiTypeToDataType(type)).device(CPU)); SetOutputTensorDescriptorTypeAndBuffer( type, output_tensor, &tensor_descriptor); } diff --git a/caffe2/operators/quantized/int8_quantize_op.h b/caffe2/operators/quantized/int8_quantize_op.h index 61c56ba..32bdd15 100644 --- a/caffe2/operators/quantized/int8_quantize_op.h +++ b/caffe2/operators/quantized/int8_quantize_op.h @@ -21,6 +21,7 @@ void Int8Quantize( const int32_t Y_offset) { const float inv_scale = 1.0f / Y_scale; uint32_t i = 0; + #ifdef INT8_NEON_SIMD const float32x4_t vinv_scale = vdupq_n_f32(inv_scale); // magic float and magic int to take care of rounding diff --git a/caffe2/opt/backend_transformer_base.cc b/caffe2/opt/backend_transformer_base.cc index 96212c1..7c33fa6 100644 --- a/caffe2/opt/backend_transformer_base.cc +++ b/caffe2/opt/backend_transformer_base.cc @@ -45,6 +45,27 @@ TensorProto BackendTransformerBase::wrapShapeInfoIntoTensorProto( return t; } +QTensorProto BackendTransformerBase::wrapShapeInfoIntoQTensorProto( + const std::string& name, + const ShapeInfo& shape_info) const { + QTensorProto t; + CAFFE_ENFORCE( + shape_info.is_quantized == true, + "Only quantized shapeinfo can be extracted into QTensor!"); + t.set_name(name); + t.set_data_type(shape_info.shape.data_type()); + t.set_scale(shape_info.q_info.scale); + t.set_bias(shape_info.q_info.offset); + // precision and is_signed is not used in onnxifi workflow, but it is required + // field + t.set_precision(0); + t.set_is_signed(0); + for (const auto i : shape_info.shape.dims()) { + t.add_dims(i); + } + return t; +} + 
 std::unordered_map<std::string, std::string>
 BackendTransformerBase::ssaRewriteAndMapNames(
     Workspace* ws,
@@ -106,7 +127,11 @@ ShapeInfoMap BackendTransformerBase::inferShapes(
     shape_map.emplace(
         std::piecewise_construct,
         std::forward_as_tuple(kv.first),
-        std::forward_as_tuple(kv.second.dim_type, kv.second.shape));
+        std::forward_as_tuple(
+            kv.second.dim_type,
+            kv.second.shape,
+            kv.second.is_quantized,
+            kv.second.q_info));
   }
   return shape_map;
 }
diff --git a/caffe2/opt/backend_transformer_base.h b/caffe2/opt/backend_transformer_base.h
index a7281a8..e845445 100644
--- a/caffe2/opt/backend_transformer_base.h
+++ b/caffe2/opt/backend_transformer_base.h
@@ -53,6 +53,11 @@ class BackendTransformerBase {
       const std::string& name,
       const ShapeInfo& shape_info) const;
 
+  // Wrap Quantized TensorShape into QTensorProto
+  QTensorProto wrapShapeInfoIntoQTensorProto(
+      const std::string& name,
+      const ShapeInfo& shape_info) const;
+
   // Do bound shape inference and collect shape infos
   ShapeInfoMap inferShapes(
       Workspace* ws,
diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
index 717f843..b7c20d5 100644
--- a/caffe2/opt/bound_shape_inferencer.cc
+++ b/caffe2/opt/bound_shape_inferencer.cc
@@ -87,10 +87,16 @@ TensorShape& BoundShapeInferencer::CheckAndSetTensorShapeAndType(
     const std::string& name,
     ShapeInfo::DimType t,
     std::vector<int64_t> bound_dims,
-    TensorProto::DataType type) {
+    TensorProto::DataType type,
+    bool is_quantized) {
   auto rt = shape_info_.emplace(name, ShapeInfo());
   ShapeInfo& shape_info = rt.first->second;
   TensorShape& shape = shape_info.shape;
+  if (is_quantized) {
+    shape_info.is_quantized = true;
+    shape_info.q_info.scale = 1;
+    shape_info.q_info.offset = 0;
+  }
   if (!rt.second) {
     // Check shape consistency
     CAFFE_ENFORCE_EQ(shape.dims_size(), bound_dims.size());
@@ -154,12 +160,14 @@ void BoundShapeInferencer::InferLengthsRangeFill(const OperatorDef& op) {
       op.input(0),
       ShapeInfo::DimType::BATCH,
       {spec_.max_batch_size},
-      TensorProto_DataType_INT32);
+      TensorProto_DataType_INT32,
+      false);
   CheckAndSetTensorShapeAndType(
       op.output(0),
       ShapeInfo::DimType::SEQ,
       {spec_.max_seq_size},
-      TensorProto_DataType_INT32);
+      TensorProto_DataType_INT32,
+      false);
   current_dim_type_ = ShapeInfo::DimType::SEQ;
 }
 
@@ -191,7 +199,8 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
         op.input(weight),
         ShapeInfo::DimType::SEQ,
         {spec_.max_seq_size},
-        TensorProto_DataType_FLOAT);
+        TensorProto_DataType_FLOAT,
+        false);
   }
 
   // Bound inputs
@@ -199,12 +208,14 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
       op.input(1 + weight),
       ShapeInfo::DimType::SEQ,
       {spec_.max_seq_size},
-      TensorProto_DataType_INT64);
+      TensorProto_DataType_INT64,
+      false);
   CheckAndSetTensorShapeAndType(
       op.input(2 + weight),
       ShapeInfo::DimType::BATCH,
       {spec_.max_batch_size},
-      TensorProto_DataType_INT32);
+      TensorProto_DataType_INT32,
+      false);
 
   // Infer output
   CAFFE_ENFORCE_EQ(it->second.shape.dims_size(), 2);
@@ -221,7 +232,8 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
       op.output(0),
       ShapeInfo::DimType::BATCH,
       {spec_.max_batch_size, output_dim1},
-      TensorProto_DataType_FLOAT);
+      TensorProto_DataType_FLOAT,
+      false);
 }
 
 void BoundShapeInferencer::InferShape(const OperatorDef& op) {
@@ -336,7 +348,11 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
     current_dim_type_ = ShapeInfo::DimType::BATCH;
     current_max_batch_size_ = spec_.max_batch_size;
     CheckAndSetTensorShapeAndType(
-        op.input(0), ShapeInfo::DimType::BATCH, dims, w_shape.data_type());
+        op.input(0),
+        ShapeInfo::DimType::BATCH,
+        dims,
+        w_shape.data_type(),
+        false);
   } else {
     ShapeInfo& x_shape_info = x_it->second;
     if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) {
@@ -355,7 +371,8 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
       op.output(0),
       ShapeInfo::DimType::BATCH,
       ConvertToVec(output_shapes[0].dims()),
-      output_shapes[0].data_type());
+      output_shapes[0].data_type(),
+      false);
 }
 
 void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
@@ -378,7 +395,35 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
     std::vector<TensorShape> output_shapes;
     output_shapes = schema->InferTensor(op, input_shapes);
     int i = 0;
+    bool is_quantized =
+        !(op.type().compare(0, 4, "Int8")) && (op.type() != "Int8Dequantize");
+    TensorProto::DataType infered_data_type = TensorProto::UNDEFINED;
+    if (is_quantized) {
+      const static std::map<std::string, int> type_info_from_input = {
+          {"Int8Quantize", -1}, // Force this op's output to be uint8
+          {"Int8ConvRelu", 1},
+          {"Int8MaxPool", 0},
+          {"Int8AveragePool", 0},
+          {"Int8FC", 1},
+          {"Int8Conv", 1},
+          {"Int8SumRelu", 0}};
+      CAFFE_ENFORCE(
+          type_info_from_input.find(op.type()) != type_info_from_input.end(),
+          "Undefined quantized output data type, add it into type_info_from_input");
+      int target = type_info_from_input.find(op.type())->second;
+      if (target == -1) {
+        infered_data_type = TensorProto::UINT8;
+      } else {
+        CAFFE_ENFORCE(target < input_shapes.size());
+        infered_data_type = input_shapes[target].data_type();
+      }
+    } else if (op.type() == "Int8Dequantize") {
+      infered_data_type = TensorProto::FLOAT;
+    }
     for (const auto& shape : output_shapes) {
+      if (infered_data_type == TensorProto::UNDEFINED) {
+        infered_data_type = shape.data_type();
+      }
       if (shape.unknown_shape()) {
         ++i;
         continue;
@@ -387,7 +432,8 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
           op.output(i++),
           current_dim_type_,
           ConvertToVec(shape.dims()),
-          shape.data_type());
+          infered_data_type,
+          is_quantized);
     }
   } catch (const caffe2::EnforceNotMet& e) {
     LOG(ERROR) << "Enforce not met while inferring shapes for " << op.type()
diff --git a/caffe2/opt/bound_shape_inferencer.h b/caffe2/opt/bound_shape_inferencer.h
index 333eca3..ee1c670 100644
--- a/caffe2/opt/bound_shape_inferencer.h
+++ b/caffe2/opt/bound_shape_inferencer.h
@@ -61,7 +61,8 @@ class CAFFE2_API BoundShapeInferencer {
       const std::string& name,
       ShapeInfo::DimType t,
       std::vector<int64_t> bound_dims,
-      TensorProto::DataType type);
+      TensorProto::DataType type,
+      bool is_quantized);
 
   void InferGivenTensorFill(const OperatorDef& op);
   void InferSparseLengthsSum(const OperatorDef& op);
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 702c294..797c3f4 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -486,7 +486,9 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
       &initialization_list,
       &total_inputs_vec);
   auto* shape_arg = onnxifi_net.add_arg();
+  auto* qshape_arg = onnxifi_net.add_arg();
   shape_arg->set_name("input_shape_info");
+  qshape_arg->set_name("input_qshape_info");
   onnxifi_net.clear_external_input();
   for (const auto& i : total_inputs_vec) {
     auto input = i;
@@ -495,8 +497,14 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
       input = it->second;
     }
     onnxifi_net.add_external_input(input);
-    shape_arg->mutable_tensors()->Add()->CopyFrom(
-        wrapShapeInfoIntoTensorProto(input, shape_hints.at(i)));
+    auto info = shape_hints.at(i);
+    if (!info.is_quantized) {
+      shape_arg->mutable_tensors()->Add()->CopyFrom(
+          wrapShapeInfoIntoTensorProto(input, shape_hints.at(i)));
+    } else {
+      qshape_arg->mutable_qtensors()->Add()->CopyFrom(
+          wrapShapeInfoIntoQTensorProto(input, shape_hints.at(i)));
+    }
   }
 
   // Compute output shape hints
@@ -812,24 +820,39 @@ bool OnnxifiTransformer::supportOpC2(
   // Encode the input/output shapes to an argument
   auto* shape_arg = net.add_arg();
+  auto* qshape_arg = net.add_arg();
   shape_arg->set_name("input_shape_info");
+  qshape_arg->set_name("input_qshape_info");
   for (const auto& i : op.input()) {
     const auto it = shape_hints.find(i);
     if (it == shape_hints.end()) {
       return false;
     }
-    shape_arg->mutable_tensors()->Add()->CopyFrom(
-        wrapShapeInfoIntoTensorProto(i, it->second));
+    if ((it->second).is_quantized == false) {
+      shape_arg->mutable_tensors()->Add()->CopyFrom(
+          wrapShapeInfoIntoTensorProto(i, it->second));
+    } else {
+      qshape_arg->mutable_qtensors()->Add()->CopyFrom(
+          wrapShapeInfoIntoQTensorProto(i, it->second));
+    }
   }
+
+  qshape_arg = net.add_arg();
   shape_arg = net.add_arg();
   shape_arg->set_name("output_shape_info");
+  qshape_arg->set_name("output_qshape_info");
   for (const auto& i : op.output()) {
     const auto it = shape_hints.find(i);
     if (it == shape_hints.end()) {
       return false;
     }
-    shape_arg->mutable_tensors()->Add()->CopyFrom(
-        wrapShapeInfoIntoTensorProto(i, it->second));
+    if ((it->second).is_quantized == false) {
+      shape_arg->mutable_tensors()->Add()->CopyFrom(
+          wrapShapeInfoIntoTensorProto(i, it->second));
+    } else {
+      qshape_arg->mutable_qtensors()->Add()->CopyFrom(
+          wrapShapeInfoIntoQTensorProto(i, it->second));
+    }
   }
 
   std::string c2_model_str;
@@ -1002,11 +1025,19 @@ void OnnxifiTransformer::transform(
   if (opts_.debug) {
     NetDef shape_net(*pred_net);
     auto* shape_arg = shape_net.add_arg();
+    auto* qshape_arg = shape_net.add_arg();
     shape_arg->set_name("shape_info");
+    qshape_arg->set_name("qshape_info");
    for (const auto& kv : shape_hints) {
-      auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second);
-      t.add_int32_data(static_cast<int32_t>(kv.second.dim_type));
-      shape_arg->mutable_tensors()->Add()->CopyFrom(t);
+      if (!kv.second.is_quantized) {
+        auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second);
+        t.add_int32_data(static_cast<int32_t>(kv.second.dim_type));
+        shape_arg->mutable_tensors()->Add()->CopyFrom(t);
+      } else {
+        auto t = wrapShapeInfoIntoQTensorProto(kv.first, kv.second);
+        t.add_data(static_cast<int32_t>(kv.second.dim_type));
+        qshape_arg->mutable_qtensors()->Add()->CopyFrom(t);
+      }
     }
     WriteProtoToTextFile(shape_net, "debug_ssa_net.pb_txt");
   }
diff --git a/caffe2/opt/shape_info.cc b/caffe2/opt/shape_info.cc
index 2f09bf7..ebd7b68 100644
--- a/caffe2/opt/shape_info.cc
+++ b/caffe2/opt/shape_info.cc
@@ -1,4 +1,5 @@
 #include "caffe2/opt/shape_info.h"
+#include "caffe2/core/tensor_int8.h"
 
 #include "caffe2/core/operator.h"
 
@@ -10,6 +11,11 @@ ShapeInfo getShapeInfoFromBlob(const Blob* blob) {
   shape_info.dim_type = shape_info.shape.unknown_shape()
       ? ShapeInfo::DimType::UNKNOWN
       : ShapeInfo::DimType::CONSTANT;
+  if (blob->meta().id() == TypeMeta::Id<int8::Int8TensorCPU>()) {
+    shape_info.is_quantized = true;
+    LoadInt8TensorInfoOfBlob(
+        &shape_info.q_info.scale, &shape_info.q_info.offset, blob);
+  }
   return shape_info;
 }
 
diff --git a/caffe2/opt/shape_info.h b/caffe2/opt/shape_info.h
index 24672a2..06d4282 100644
--- a/caffe2/opt/shape_info.h
+++ b/caffe2/opt/shape_info.h
@@ -4,15 +4,35 @@
 
 namespace caffe2 {
 
+struct CAFFE2_API QShapeInfo {
+  QShapeInfo(float o = 0, float s = 1) : offset(o), scale(s) {}
+  float offset;
+  float scale;
+  // TODO zrphercule
+  // Add multi offset/scale support here
+};
+
 struct CAFFE2_API ShapeInfo {
   enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
-  ShapeInfo() {}
-  ShapeInfo(DimType t, TensorShape&& s) : dim_type(t), shape(std::move(s)) {}
-  ShapeInfo(DimType t, const TensorShape& s) : dim_type(t), shape(s) {}
+  ShapeInfo(bool q = false) : is_quantized(q) {}
+  ShapeInfo(DimType t, TensorShape&& s, bool q = false)
+      : dim_type(t), shape(std::move(s)), is_quantized(q) {}
+  ShapeInfo(DimType t, const TensorShape& s, bool q = false)
+      : dim_type(t), shape(s), is_quantized(q) {}
+
+  ShapeInfo(bool q, const QShapeInfo& info) : is_quantized(q), q_info(info) {}
+  ShapeInfo(DimType t, TensorShape&& s, bool q, const QShapeInfo& info)
+      : dim_type(t), shape(std::move(s)), is_quantized(q), q_info(info) {}
+  ShapeInfo(DimType t, const TensorShape& s, bool q, const QShapeInfo& info)
+      : dim_type(t), shape(s), is_quantized(q), q_info(info) {}
 
   // type of the shape according its first dim
   DimType dim_type{DimType::UNKNOWN};
   TensorShape shape;
+
+  // quantization related information
+  bool is_quantized;
+  QShapeInfo q_info;
 };
 
 using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
-- 
2.7.4
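
Illustration (not part of the patch): a minimal standalone C++ sketch of the routing this change introduces, using simplified stand-in structs rather than the real caffe2/protobuf types. Shape hints that carry quantization metadata (scale/offset) are kept separate (they become QTensorProto entries in a "qshape" argument in the real code), while ordinary hints stay in the regular "shape" argument (TensorProto). The blob names and quantization parameters below are hypothetical.

// Standalone illustration only; ShapeInfo/QShapeInfo here are simplified
// stand-ins for the structs added in caffe2/opt/shape_info.h.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct QShapeInfo {
  float offset = 0;
  float scale = 1;
};

struct ShapeInfo {
  std::vector<long long> dims;
  bool is_quantized = false;
  QShapeInfo q_info;
};

int main() {
  // Shape hints as a bound shape inference pass might produce them.
  std::map<std::string, ShapeInfo> shape_hints;

  ShapeInfo fp32_hint;
  fp32_hint.dims = {20, 1024};
  shape_hints["fc_w"] = fp32_hint;

  ShapeInfo int8_hint;
  int8_hint.dims = {20, 1024};
  int8_hint.is_quantized = true;
  int8_hint.q_info.scale = 0.0117f; // hypothetical quantization params
  int8_hint.q_info.offset = 0.f;
  shape_hints["fc_w_int8"] = int8_hint;

  // Mirrors the split done when building the onnxifi net arguments:
  // quantized hints would be serialized as QTensorProto (qshape arg),
  // the rest as TensorProto (ordinary shape arg).
  std::vector<std::string> shape_arg;
  std::vector<std::string> qshape_arg;
  for (const auto& kv : shape_hints) {
    if (kv.second.is_quantized) {
      qshape_arg.push_back(kv.first);
    } else {
      shape_arg.push_back(kv.first);
    }
  }

  std::cout << "input_shape_info entries:  " << shape_arg.size() << "\n"
            << "input_qshape_info entries: " << qshape_arg.size() << "\n";
  return 0;
}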