Adding quantized tensor shape/type info support for caffe2=>glow in caffe2 side ...
authorRui Zhu <zrphercule@fb.com>
Mon, 1 Apr 2019 00:37:17 +0000 (17:37 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 1 Apr 2019 00:42:27 +0000 (17:42 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18621

This diff added caffe2 support for onnxifi quantization.

Reviewed By: yinghai

Differential Revision: D14648767

fbshipit-source-id: 4ddb492cacbba6142305866e6dbb875880acaea3

13 files changed:
caffe2/core/blob.h
caffe2/core/operator.cc
caffe2/core/operator.h
caffe2/core/tensor.cc
caffe2/operators/onnxifi_op.cc
caffe2/operators/quantized/int8_quantize_op.h
caffe2/opt/backend_transformer_base.cc
caffe2/opt/backend_transformer_base.h
caffe2/opt/bound_shape_inferencer.cc
caffe2/opt/bound_shape_inferencer.h
caffe2/opt/onnxifi_transformer.cc
caffe2/opt/shape_info.cc
caffe2/opt/shape_info.h

index 3ab8e77..cb571bb 100644 (file)
 #include <c10/util/typeid.h>
 #include "caffe2/core/logging.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/core/tensor_int8.h"
 
 namespace caffe2 {
 
+inline bool BlobIsInt8TensorCPUType(const Blob& blob) {
+  return blob.meta().Match<int8::Int8TensorCPU>();
+}
+
 inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
   bool is_match = blob.meta().Match<Tensor>();
   if (!is_match) {
index d7b66bc..21d2329 100644 (file)
@@ -7,6 +7,7 @@
 #include "caffe2/core/net.h"
 #include "caffe2/core/operator_gradient.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/core/tensor_int8.h"
 #include "caffe2/core/types.h"
 #include "caffe2/core/workspace.h"
 
@@ -576,6 +577,13 @@ TensorShapes InferBlobShapesAndTypes(
   return tps;
 }
 
+void LoadInt8TensorInfoOfBlob(float* scale, float* offset, const Blob* b) {
+  const int8::Int8TensorCPU* i8tc =
+      static_cast<const int8::Int8TensorCPU*>(b->GetRaw());
+  *scale = i8tc->scale;
+  *offset = i8tc->zero_point;
+}
+
 TensorShape GetTensorShapeOfBlob(const Blob* b) {
   TypeCall type_fun = GetTypeCallFunction(b->meta().id());
   TensorInfoCall tensor_info_fun = GetTensorInfoFunction(b->meta().id());
index 17a0392..7930a02 100644 (file)
@@ -19,6 +19,7 @@
 #include "caffe2/core/operator_gradient.h"
 #include "caffe2/core/operator_schema.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/core/tensor_int8.h"
 #include "caffe2/core/types.h"
 #include "caffe2/core/workspace.h"
 #include "caffe2/proto/caffe2_pb.h"
@@ -1336,6 +1337,9 @@ CAFFE2_API void SetOpEnginePref(
     const std::string& op_type,
     const CaffeMap<DeviceType, EnginePrefType>& op_pref);
 
+CAFFE2_API void
+LoadInt8TensorInfoOfBlob(float* scale, float* offset, const Blob* b);
+
 CAFFE2_API TensorShape GetTensorShapeOfBlob(const Blob* b);
 
 CAFFE2_API TensorShapes InferBlobShapesAndTypes(
index afcae81..e4278b6 100644 (file)
@@ -1,4 +1,5 @@
 #include "caffe2/core/tensor.h"
+#include "caffe2/core/tensor_int8.h"
 
 #include "caffe2/core/blob_stats.h"
 
@@ -56,10 +57,16 @@ TypeMeta GetTensorType(const void* c) {
   const Tensor* tc = static_cast<const Tensor*>(c);
   return tc->dtype();
 }
+TypeMeta GetInt8TensorType(const void* c) {
+  const int8::Int8TensorCPU* i8tc = static_cast<const int8::Int8TensorCPU*>(c);
+  return (i8tc->t).dtype();
+}
 
 // TODO(jerryzh): Remove
 static CaffeMap<TypeIdentifier, TypeCall> type_call_registry_{
-    {TypeMeta::Id<Tensor>(), GetTensorType}};
+    {TypeMeta::Id<Tensor>(), GetTensorType},
+    {TypeMeta::Id<int8::Int8TensorCPU>(), GetInt8TensorType},
+};
 
 TypeCall GetTypeCallFunction(TypeIdentifier id) {
   auto f = type_call_registry_.find(id);
@@ -89,9 +96,16 @@ vector<int64_t> GetTensorInfo(
   return tc->sizes().vec();
 }
 
+vector<int64_t>
+GetInt8TensorInfo(const void* c, size_t* capacity, DeviceOption* device) {
+  const int8::Int8TensorCPU* i8tc = static_cast<const int8::Int8TensorCPU*>(c);
+  return GetTensorInfo(&(i8tc->t), capacity, device);
+}
 // since we only have one tensor, probably need to remove this at some point?
 static CaffeMap<TypeIdentifier, TensorInfoCall> tensor_info_call_registry_{
-    {TypeMeta::Id<Tensor>(), GetTensorInfo}};
+    {TypeMeta::Id<Tensor>(), GetTensorInfo},
+    {TypeMeta::Id<int8::Int8TensorCPU>(), GetInt8TensorInfo},
+};
 
 // TODO: Remove this code in a separate diff, since we only have one
 // GetTensorInfo function now
index e8b582f..d2b3942 100644 (file)
@@ -34,7 +34,26 @@ void SetInputTensorDescriptorTypeAndBuffer(
   }
 }
 
-TypeMeta OnnixfiTypeToDataType(uint64_t onnxifi_type) {
+void SetInputTensorDescriptorTypeAndBuffer(
+    const int8::Int8TensorCPU& cpu_int8tensor,
+    onnxTensorDescriptorV1* desc) {
+  const Tensor& cpu_tensor = cpu_int8tensor.t;
+  if (cpu_tensor.template IsType<uint8_t>()) {
+    desc->dataType = ONNXIFI_DATATYPE_UINT8;
+    desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<uint8_t>());
+  } else if (cpu_tensor.template IsType<int32_t>()) {
+    desc->dataType = ONNXIFI_DATATYPE_INT32;
+    desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor.data<int32_t>());
+  } else {
+    CAFFE_THROW(
+        "Unsupported Int8Tensor type in ONNXIFI: ", cpu_tensor.dtype().name());
+  }
+  desc->is_quantized = true;
+  desc->scale = cpu_int8tensor.scale;
+  desc->bias = cpu_int8tensor.zero_point;
+}
+
+TypeMeta OnnxifiTypeToDataType(uint64_t onnxifi_type) {
   static std::map<uint64_t, TypeMeta> data_type_map {
     {ONNXIFI_DATATYPE_FLOAT32, TypeMeta::Make<float>()},
     {ONNXIFI_DATATYPE_INT32, TypeMeta::Make<int>()},
@@ -45,7 +64,10 @@ TypeMeta OnnixfiTypeToDataType(uint64_t onnxifi_type) {
     {ONNXIFI_DATATYPE_UINT16, TypeMeta::Make<uint16_t>()},
   };
   const auto it = data_type_map.find(onnxifi_type);
-  CAFFE_ENFORCE(it != data_type_map.end(), "Unsupported ONXNIFI data type: ", onnxifi_type);
+  CAFFE_ENFORCE(
+      it != data_type_map.end(),
+      "Unsupported ONNXIFI data type: ",
+      onnxifi_type);
   return it->second;
 }
 
@@ -54,9 +76,23 @@ void SetOutputTensorDescriptorTypeAndBuffer(
     Tensor* cpu_tensor,
     onnxTensorDescriptorV1* desc) {
   desc->dataType = onnxifi_type;
-  desc->buffer = reinterpret_cast<onnxPointer>(cpu_tensor->raw_mutable_data(OnnixfiTypeToDataType(onnxifi_type)));
+  desc->buffer = reinterpret_cast<onnxPointer>(
+      cpu_tensor->raw_mutable_data(OnnxifiTypeToDataType(onnxifi_type)));
 }
 
+void SetOutputTensorDescriptorTypeAndBuffer(
+    uint64_t onnxifi_type,
+    int8::Int8TensorCPU* cpu_int8tensor,
+    onnxTensorDescriptorV1* desc) {
+  desc->dataType = onnxifi_type;
+  Tensor* cpu_tensor = &(cpu_int8tensor->t);
+
+  desc->buffer = reinterpret_cast<onnxPointer>(
+      cpu_tensor->raw_mutable_data(OnnxifiTypeToDataType(onnxifi_type)));
+  desc->is_quantized = true;
+  desc->scale = cpu_int8tensor->scale;
+  desc->bias = cpu_int8tensor->zero_point;
+}
 void BlobToTensorDescriptor(
     const std::string& name,
     Workspace* ws,
@@ -64,26 +100,39 @@ void BlobToTensorDescriptor(
     std::vector<std::vector<uint64_t>>* shapes) {
   const Blob* blob = ws->GetBlob(name);
   CAFFE_ENFORCE(blob, "Blob ", name, " doesn't exist");
-
+  const bool is_int8tensor =
+      blob->meta().id() == TypeMeta::Id<int8::Int8TensorCPU>();
   // Memory type
-  // We only allow weights to be CPU tensor for now
+  // We only allow weights to be CPU tensor or int8tensor for now
   CAFFE_ENFORCE(
-      BlobIsTensorType(*blob, CPU),
+      (BlobIsTensorType(*blob, CPU) || BlobIsInt8TensorCPUType(*blob)),
       "Initialization blob ",
       name,
-      " needs to be TensorCPU");
+      " needs to be TensorCPU or Int8TensorCPU");
   desc->tag = ONNXIFI_TAG_TENSOR_DESCRIPTOR_V1;
   desc->memoryType = ONNXIFI_MEMORY_TYPE_CPU;
 
-  // Data type
-  const auto& cpu_tensor = blob->template Get<TensorCPU>();
-  SetInputTensorDescriptorTypeAndBuffer(cpu_tensor, desc);
-
-  // Set dims
-  const auto shape = cpu_tensor.sizes();
-  desc->dimensions = shape.size();
-  shapes->emplace_back(shape.cbegin(), shape.cend());
-  desc->shape = shapes->back().data();
+  if (is_int8tensor) {
+    // Data type
+    const auto& cpu_int8tensor = blob->template Get<int8::Int8TensorCPU>();
+    const auto& cpu_tensor = cpu_int8tensor.t;
+    SetInputTensorDescriptorTypeAndBuffer(cpu_int8tensor, desc);
+    // Set dims
+    const auto shape = cpu_tensor.sizes();
+    desc->dimensions = shape.size();
+    shapes->emplace_back(shape.cbegin(), shape.cend());
+    desc->shape = shapes->back().data();
+  } else {
+    // Data type
+    const auto& cpu_tensor = blob->template Get<TensorCPU>();
+    SetInputTensorDescriptorTypeAndBuffer(cpu_tensor, desc);
+    // Set dims
+    const auto shape = cpu_tensor.sizes();
+    desc->dimensions = shape.size();
+    shapes->emplace_back(shape.cbegin(), shape.cend());
+    desc->shape = shapes->back().data();
+    desc->is_quantized = 0;
+  }
 }
 } // namespace
 
@@ -148,7 +197,10 @@ bool OnnxifiOp<float, CPUContext>::RunOnDevice() {
     tensor_descriptor.shape = output_shapes_.back().data();
     std::vector<int64_t> tensor_dims_int64;
     std::copy(tensor_dims.cbegin(), tensor_dims.cend(), std::back_inserter(tensor_dims_int64));
-    auto* output_tensor = Output(i, tensor_dims_int64, at::dtype(OnnixfiTypeToDataType(type)).device(CPU));
+    auto* output_tensor = Output(
+        i,
+        tensor_dims_int64,
+        at::dtype(OnnxifiTypeToDataType(type)).device(CPU));
     SetOutputTensorDescriptorTypeAndBuffer(
         type, output_tensor, &tensor_descriptor);
   }
index 61c56ba..32bdd15 100644 (file)
@@ -21,6 +21,7 @@ void Int8Quantize(
     const int32_t Y_offset) {
   const float inv_scale = 1.0f / Y_scale;
   uint32_t i = 0;
+
 #ifdef INT8_NEON_SIMD
   const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
   // magic float and magic int to take care of rounding
index 96212c1..7c33fa6 100644 (file)
@@ -45,6 +45,27 @@ TensorProto BackendTransformerBase::wrapShapeInfoIntoTensorProto(
   return t;
 }
 
+QTensorProto BackendTransformerBase::wrapShapeInfoIntoQTensorProto(
+    const std::string& name,
+    const ShapeInfo& shape_info) const {
+  QTensorProto t;
+  CAFFE_ENFORCE(
+      shape_info.is_quantized == true,
+      "Only quantized shapeinfo can be extracted into QTensor!");
+  t.set_name(name);
+  t.set_data_type(shape_info.shape.data_type());
+  t.set_scale(shape_info.q_info.scale);
+  t.set_bias(shape_info.q_info.offset);
+  // precision and is_signed is not used in onnxifi workflow, but it is required
+  // field
+  t.set_precision(0);
+  t.set_is_signed(0);
+  for (const auto i : shape_info.shape.dims()) {
+    t.add_dims(i);
+  }
+  return t;
+}
+
 std::unordered_map<std::string, TensorShape>
 BackendTransformerBase::ssaRewriteAndMapNames(
     Workspace* ws,
@@ -106,7 +127,11 @@ ShapeInfoMap BackendTransformerBase::inferShapes(
     shape_map.emplace(
         std::piecewise_construct,
         std::forward_as_tuple(kv.first),
-        std::forward_as_tuple(kv.second.dim_type, kv.second.shape));
+        std::forward_as_tuple(
+            kv.second.dim_type,
+            kv.second.shape,
+            kv.second.is_quantized,
+            kv.second.q_info));
   }
   return shape_map;
 }
index a7281a8..e845445 100644 (file)
@@ -53,6 +53,11 @@ class BackendTransformerBase {
       const std::string& name,
       const ShapeInfo& shape_info) const;
 
+  // Wrap Quantized TensorShape into QTensorProto
+  QTensorProto wrapShapeInfoIntoQTensorProto(
+      const std::string& name,
+      const ShapeInfo& shape_info) const;
+
   // Do bound shape inference and collect shape infos
   ShapeInfoMap inferShapes(
       Workspace* ws,
index 717f843..b7c20d5 100644 (file)
@@ -87,10 +87,16 @@ TensorShape& BoundShapeInferencer::CheckAndSetTensorShapeAndType(
     const std::string& name,
     ShapeInfo::DimType t,
     std::vector<int64_t> bound_dims,
-    TensorProto::DataType type) {
+    TensorProto::DataType type,
+    bool is_quantized) {
   auto rt = shape_info_.emplace(name, ShapeInfo());
   ShapeInfo& shape_info = rt.first->second;
   TensorShape& shape = shape_info.shape;
+  if (is_quantized) {
+    shape_info.is_quantized = true;
+    shape_info.q_info.scale = 1;
+    shape_info.q_info.offset = 0;
+  }
   if (!rt.second) {
     // Check shape consistency
     CAFFE_ENFORCE_EQ(shape.dims_size(), bound_dims.size());
@@ -154,12 +160,14 @@ void BoundShapeInferencer::InferLengthsRangeFill(const OperatorDef& op) {
       op.input(0),
       ShapeInfo::DimType::BATCH,
       {spec_.max_batch_size},
-      TensorProto_DataType_INT32);
+      TensorProto_DataType_INT32,
+      false);
   CheckAndSetTensorShapeAndType(
       op.output(0),
       ShapeInfo::DimType::SEQ,
       {spec_.max_seq_size},
-      TensorProto_DataType_INT32);
+      TensorProto_DataType_INT32,
+      false);
   current_dim_type_ = ShapeInfo::DimType::SEQ;
 }
 
@@ -191,7 +199,8 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
         op.input(weight),
         ShapeInfo::DimType::SEQ,
         {spec_.max_seq_size},
-        TensorProto_DataType_FLOAT);
+        TensorProto_DataType_FLOAT,
+        false);
   }
 
   // Bound inputs
@@ -199,12 +208,14 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
       op.input(1 + weight),
       ShapeInfo::DimType::SEQ,
       {spec_.max_seq_size},
-      TensorProto_DataType_INT64);
+      TensorProto_DataType_INT64,
+      false);
   CheckAndSetTensorShapeAndType(
       op.input(2 + weight),
       ShapeInfo::DimType::BATCH,
       {spec_.max_batch_size},
-      TensorProto_DataType_INT32);
+      TensorProto_DataType_INT32,
+      false);
 
   // Infer output
   CAFFE_ENFORCE_EQ(it->second.shape.dims_size(), 2);
@@ -221,7 +232,8 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
       op.output(0),
       ShapeInfo::DimType::BATCH,
       {spec_.max_batch_size, output_dim1},
-      TensorProto_DataType_FLOAT);
+      TensorProto_DataType_FLOAT,
+      false);
 }
 
 void BoundShapeInferencer::InferShape(const OperatorDef& op) {
@@ -336,7 +348,11 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
     current_dim_type_ = ShapeInfo::DimType::BATCH;
     current_max_batch_size_ = spec_.max_batch_size;
     CheckAndSetTensorShapeAndType(
-        op.input(0), ShapeInfo::DimType::BATCH, dims, w_shape.data_type());
+        op.input(0),
+        ShapeInfo::DimType::BATCH,
+        dims,
+        w_shape.data_type(),
+        false);
   } else {
     ShapeInfo& x_shape_info = x_it->second;
     if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) {
@@ -355,7 +371,8 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
       op.output(0),
       ShapeInfo::DimType::BATCH,
       ConvertToVec(output_shapes[0].dims()),
-      output_shapes[0].data_type());
+      output_shapes[0].data_type(),
+      false);
 }
 
 void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
@@ -378,7 +395,35 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
   std::vector<TensorShape> output_shapes;
     output_shapes = schema->InferTensor(op, input_shapes);
   int i = 0;
+  bool is_quantized =
+      !(op.type().compare(0, 4, "Int8")) && (op.type() != "Int8Dequantize");
+  TensorProto::DataType infered_data_type = TensorProto::UNDEFINED;
+  if (is_quantized) {
+    const static std::map<string, int> type_info_from_input = {
+        {"Int8Quantize", -1}, // Force this op's output to be uint8
+        {"Int8ConvRelu", 1},
+        {"Int8MaxPool", 0},
+        {"Int8AveragePool", 0},
+        {"Int8FC", 1},
+        {"Int8Conv", 1},
+        {"Int8SumRelu", 0}};
+    CAFFE_ENFORCE(
+        type_info_from_input.find(op.type()) != type_info_from_input.end(),
+        "Undefined quantized output data type, add it into type_info_from_input");
+    int target = type_info_from_input.find(op.type())->second;
+    if (target == -1) {
+      infered_data_type = TensorProto::UINT8;
+    } else {
+      CAFFE_ENFORCE(target < input_shapes.size());
+      infered_data_type = input_shapes[target].data_type();
+    }
+  } else if (op.type() == "Int8Dequantize") {
+    infered_data_type = TensorProto::FLOAT;
+  }
   for (const auto& shape : output_shapes) {
+    if (infered_data_type == TensorProto::UNDEFINED) {
+      infered_data_type = shape.data_type();
+    }
     if (shape.unknown_shape()) {
       ++i;
       continue;
@@ -387,7 +432,8 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
         op.output(i++),
         current_dim_type_,
         ConvertToVec(shape.dims()),
-        shape.data_type());
+        infered_data_type,
+        is_quantized);
   }
   } catch (const caffe2::EnforceNotMet& e) {
     LOG(ERROR) << "Enforce not met while inferring shapes for " << op.type()
index 333eca3..ee1c670 100644 (file)
@@ -61,7 +61,8 @@ class CAFFE2_API BoundShapeInferencer {
       const std::string& name,
       ShapeInfo::DimType t,
       std::vector<int64_t> bound_dims,
-      TensorProto::DataType type);
+      TensorProto::DataType type,
+      bool is_quantized);
 
   void InferGivenTensorFill(const OperatorDef& op);
   void InferSparseLengthsSum(const OperatorDef& op);
index 702c294..797c3f4 100644 (file)
@@ -486,7 +486,9 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
       &initialization_list,
       &total_inputs_vec);
   auto* shape_arg = onnxifi_net.add_arg();
+  auto* qshape_arg = onnxifi_net.add_arg();
   shape_arg->set_name("input_shape_info");
+  qshape_arg->set_name("input_qshape_info");
   onnxifi_net.clear_external_input();
   for (const auto& i : total_inputs_vec) {
     auto input = i;
@@ -495,8 +497,14 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
       input = it->second;
     }
     onnxifi_net.add_external_input(input);
-    shape_arg->mutable_tensors()->Add()->CopyFrom(
-        wrapShapeInfoIntoTensorProto(input, shape_hints.at(i)));
+    auto info = shape_hints.at(i);
+    if (!info.is_quantized) {
+      shape_arg->mutable_tensors()->Add()->CopyFrom(
+          wrapShapeInfoIntoTensorProto(input, shape_hints.at(i)));
+    } else {
+      qshape_arg->mutable_qtensors()->Add()->CopyFrom(
+          wrapShapeInfoIntoQTensorProto(input, shape_hints.at(i)));
+    }
   }
 
   // Compute output shape hints
@@ -812,24 +820,39 @@ bool OnnxifiTransformer::supportOpC2(
 
     // Encode the input/output shapes to an argument
     auto* shape_arg = net.add_arg();
+    auto* qshape_arg = net.add_arg();
     shape_arg->set_name("input_shape_info");
+    qshape_arg->set_name("input_qshape_info");
     for (const auto& i : op.input()) {
       const auto it = shape_hints.find(i);
       if (it == shape_hints.end()) {
         return false;
       }
-      shape_arg->mutable_tensors()->Add()->CopyFrom(
-          wrapShapeInfoIntoTensorProto(i, it->second));
+      if ((it->second).is_quantized == false) {
+        shape_arg->mutable_tensors()->Add()->CopyFrom(
+            wrapShapeInfoIntoTensorProto(i, it->second));
+      } else {
+        qshape_arg->mutable_qtensors()->Add()->CopyFrom(
+            wrapShapeInfoIntoQTensorProto(i, it->second));
+      }
     }
+
+    qshape_arg = net.add_arg();
     shape_arg = net.add_arg();
     shape_arg->set_name("output_shape_info");
+    qshape_arg->set_name("output_qshape_info");
     for (const auto& i : op.output()) {
       const auto it = shape_hints.find(i);
       if (it == shape_hints.end()) {
         return false;
       }
-      shape_arg->mutable_tensors()->Add()->CopyFrom(
-          wrapShapeInfoIntoTensorProto(i, it->second));
+      if ((it->second).is_quantized == false) {
+        shape_arg->mutable_tensors()->Add()->CopyFrom(
+            wrapShapeInfoIntoTensorProto(i, it->second));
+      } else {
+        qshape_arg->mutable_qtensors()->Add()->CopyFrom(
+            wrapShapeInfoIntoQTensorProto(i, it->second));
+      }
     }
 
     std::string c2_model_str;
@@ -1002,11 +1025,19 @@ void OnnxifiTransformer::transform(
   if (opts_.debug) {
     NetDef shape_net(*pred_net);
     auto* shape_arg = shape_net.add_arg();
+    auto* qshape_arg = shape_net.add_arg();
     shape_arg->set_name("shape_info");
+    qshape_arg->set_name("qshape_info");
     for (const auto& kv : shape_hints) {
-      auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second);
-      t.add_int32_data(static_cast<int32_t>(kv.second.dim_type));
-      shape_arg->mutable_tensors()->Add()->CopyFrom(t);
+      if (!kv.second.is_quantized) {
+        auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second);
+        t.add_int32_data(static_cast<int32_t>(kv.second.dim_type));
+        shape_arg->mutable_tensors()->Add()->CopyFrom(t);
+      } else {
+        auto t = wrapShapeInfoIntoQTensorProto(kv.first, kv.second);
+        t.add_data(static_cast<int32_t>(kv.second.dim_type));
+        qshape_arg->mutable_qtensors()->Add()->CopyFrom(t);
+      }
     }
     WriteProtoToTextFile(shape_net, "debug_ssa_net.pb_txt");
   }
index 2f09bf7..ebd7b68 100644 (file)
@@ -1,4 +1,5 @@
 #include "caffe2/opt/shape_info.h"
+#include "caffe2/core/tensor_int8.h"
 
 #include "caffe2/core/operator.h"
 
@@ -10,6 +11,11 @@ ShapeInfo getShapeInfoFromBlob(const Blob* blob) {
   shape_info.dim_type = shape_info.shape.unknown_shape()
       ? ShapeInfo::DimType::UNKNOWN
       : ShapeInfo::DimType::CONSTANT;
+  if (blob->meta().id() == TypeMeta::Id<int8::Int8TensorCPU>()) {
+    shape_info.is_quantized = true;
+    LoadInt8TensorInfoOfBlob(
+        &shape_info.q_info.scale, &shape_info.q_info.offset, blob);
+  }
   return shape_info;
 }
 
index 24672a2..06d4282 100644 (file)
@@ -4,15 +4,35 @@
 
 namespace caffe2 {
 
+struct CAFFE2_API QShapeInfo {
+  QShapeInfo(float o = 0, float s = 1) : offset(o), scale(s) {}
+  float offset;
+  float scale;
+  // TODO zrphercule
+  // Add multi offset/scale support here
+};
+
 struct CAFFE2_API ShapeInfo {
   enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
-  ShapeInfo() {}
-  ShapeInfo(DimType t, TensorShape&& s) : dim_type(t), shape(std::move(s)) {}
-  ShapeInfo(DimType t, const TensorShape& s) : dim_type(t), shape(s) {}
+  ShapeInfo(bool q = false) : is_quantized(q) {}
+  ShapeInfo(DimType t, TensorShape&& s, bool q = false)
+      : dim_type(t), shape(std::move(s)), is_quantized(q) {}
+  ShapeInfo(DimType t, const TensorShape& s, bool q = false)
+      : dim_type(t), shape(s), is_quantized(q) {}
+
+  ShapeInfo(bool q, const QShapeInfo& info) : is_quantized(q), q_info(info) {}
+  ShapeInfo(DimType t, TensorShape&& s, bool q, const QShapeInfo& info)
+      : dim_type(t), shape(std::move(s)), is_quantized(q), q_info(info) {}
+  ShapeInfo(DimType t, const TensorShape& s, bool q, const QShapeInfo& info)
+      : dim_type(t), shape(s), is_quantized(q), q_info(info) {}
 
   // type of the shape according its first dim
   DimType dim_type{DimType::UNKNOWN};
   TensorShape shape;
+
+  // quantization related information
+  bool is_quantized;
+  QShapeInfo q_info;
 };
 
 using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;