Add tensor quantization info to python wrapper
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 22 Mar 2018 02:12:18 +0000 (19:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 02:14:52 +0000 (19:14 -0700)
PiperOrigin-RevId: 190005998

tensorflow/contrib/lite/python/interpreter.py
tensorflow/contrib/lite/python/interpreter_test.py
tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h

index accdd04..b863800 100644 (file)
@@ -71,6 +71,7 @@ class Interpreter(object):
     tensor_name = self._interpreter.TensorName(tensor_index)
     tensor_size = self._interpreter.TensorSize(tensor_index)
     tensor_type = self._interpreter.TensorType(tensor_index)
+    tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
 
     if not tensor_name or not tensor_type:
       raise ValueError('Could not get tensor details')
@@ -80,6 +81,7 @@ class Interpreter(object):
         'index': tensor_index,
         'shape': tensor_size,
         'dtype': tensor_type,
+        'quantization': tensor_quantization,
     }
 
     return details
index e85390c..bf12441 100644 (file)
@@ -39,12 +39,14 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     self.assertEqual('input', input_details[0]['name'])
     self.assertEqual(np.float32, input_details[0]['dtype'])
     self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+    self.assertEqual((0.0, 0), input_details[0]['quantization'])
 
     output_details = interpreter.get_output_details()
     self.assertEqual(1, len(output_details))
     self.assertEqual('output', output_details[0]['name'])
     self.assertEqual(np.float32, output_details[0]['dtype'])
     self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+    self.assertEqual((0.0, 0), output_details[0]['quantization'])
 
     test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
     expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
@@ -67,12 +69,14 @@ class InterpreterTest(test_util.TensorFlowTestCase):
       self.assertEqual('input', input_details[0]['name'])
       self.assertEqual(np.uint8, input_details[0]['dtype'])
       self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+      self.assertEqual((1.0, 0), input_details[0]['quantization'])
 
       output_details = interpreter.get_output_details()
       self.assertEqual(1, len(output_details))
       self.assertEqual('output', output_details[0]['name'])
       self.assertEqual(np.uint8, output_details[0]['dtype'])
       self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+      self.assertEqual((1.0, 0), output_details[0]['quantization'])
 
       test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
       expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
index 14e1190..35ad226 100644 (file)
@@ -109,6 +109,13 @@ PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
   return PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
 }
 
+PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
+  PyObject* result = PyTuple_New(2);
+  PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
+  PyTuple_SET_ITEM(result, 1, PyInt_FromLong(param.zero_point));
+  return result;
+}
+
 }  // namespace
 
 InterpreterWrapper::InterpreterWrapper(
@@ -214,6 +221,16 @@ PyObject* InterpreterWrapper::TensorSize(int i) const {
   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
 }
 
+PyObject* InterpreterWrapper::TensorQuantization(int i) const {
+  if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
+    Py_INCREF(Py_None);
+    return Py_None;
+  }
+
+  const TfLiteTensor* tensor = interpreter_->tensor(i);
+  return PyTupleFromQuantizationParam(tensor->params);
+}
+
 bool InterpreterWrapper::SetTensor(int i, PyObject* value) {
   if (!interpreter_) {
     LOG(ERROR) << "Invalid interpreter.";
index 63bdb30..0972c57 100644 (file)
@@ -54,6 +54,7 @@ class InterpreterWrapper {
   std::string TensorName(int i) const;
   PyObject* TensorType(int i) const;
   PyObject* TensorSize(int i) const;
+  PyObject* TensorQuantization(int i) const;
   bool SetTensor(int i, PyObject* value);
   PyObject* GetTensor(int i) const;