tensor_name = self._interpreter.TensorName(tensor_index)
tensor_size = self._interpreter.TensorSize(tensor_index)
tensor_type = self._interpreter.TensorType(tensor_index)
+ tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
if not tensor_name or not tensor_type:
raise ValueError('Could not get tensor details')
'index': tensor_index,
'shape': tensor_size,
'dtype': tensor_type,
+ 'quantization': tensor_quantization,
}
return details
self.assertEqual('input', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+ self.assertEqual((0.0, 0), input_details[0]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0.0, 0), output_details[0]['quantization'])
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
self.assertEqual('input', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+ self.assertEqual((1.0, 0), input_details[0]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
self.assertEqual(np.uint8, output_details[0]['dtype'])
self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+ self.assertEqual((1.0, 0), output_details[0]['quantization'])
test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
return PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
}
+PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
+ PyObject* result = PyTuple_New(2);
+ PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
+ PyTuple_SET_ITEM(result, 1, PyInt_FromLong(param.zero_point));
+ return result;
+}
+
} // namespace
InterpreterWrapper::InterpreterWrapper(
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
+PyObject* InterpreterWrapper::TensorQuantization(int i) const {
+ if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
+ Py_INCREF(Py_None);
+ return Py_None;
+ }
+
+ const TfLiteTensor* tensor = interpreter_->tensor(i);
+ return PyTupleFromQuantizationParam(tensor->params);
+}
+
bool InterpreterWrapper::SetTensor(int i, PyObject* value) {
if (!interpreter_) {
LOG(ERROR) << "Invalid interpreter.";