Modify the python interface to toco to provide arithmetic operations used by the...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 11 May 2018 22:03:34 +0000 (15:03 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 22:11:43 +0000 (15:11 -0700)
PiperOrigin-RevId: 196314416

tensorflow/contrib/lite/toco/model.h
tensorflow/contrib/lite/toco/python/toco.i
tensorflow/contrib/lite/toco/python/toco_python_api.cc
tensorflow/contrib/lite/toco/python/toco_python_api.h
tensorflow/contrib/lite/toco/toco_tooling.cc

index aefa9ac..d878ac5 100644 (file)
@@ -1829,6 +1829,8 @@ class Model {
   }
   const ArrayMap& GetArrayMap() const { return arrays; }
 
+  int64 ArithmeticOpsCount() const { return ops_count; }
+
   // Optional arrays are used for optional tensors,
   // these tensors do not have data, but with reserved names as op inputs.
   std::set<string> optional_arrays;
@@ -1845,6 +1847,8 @@ class Model {
   std::size_t transient_data_size = 0;
   // For code-generation only: required alignment of the transient_data buffer
   std::size_t transient_data_alignment = 0;
+  // Arithmatic operations performed in the model.
+  int64 ops_count = 0;
 
  private:
   // The associative array mapping names to Array's.
index 3787cba..0d2fbdd 100644 (file)
@@ -24,9 +24,12 @@ namespace toco {
 // Convert a model represented in `input_contents`. `model_flags_proto`
 // describes model parameters. `toco_flags_proto` describes conversion
 // parameters (see relevant .protos for more information). Returns a string
-// representing the contents of the converted model.
+// representing the contents of the converted model. When extended_return
+// flag is set to true returns a dictionary that contains string representation
+// of the converted model and some statitics like arithmetic ops count.
 PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
                         PyObject* toco_flags_proto_txt_raw,
-                        PyObject* input_contents_txt_raw);
+                        PyObject* input_contents_txt_raw,
+                        bool extended_return = false);
 
 } // namespace toco
\ No newline at end of file
index 153c117..5b1db85 100644 (file)
@@ -37,7 +37,7 @@ namespace toco {
 // sure we input and output bytes rather than unicode strings for Python3.
 PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
                       PyObject* toco_flags_proto_txt_raw,
-                      PyObject* input_contents_txt_raw) {
+                      PyObject* input_contents_txt_raw, bool extended_return) {
   // Use Python C API to validate and convert arguments. In py3 (bytes),
   // in py2 (str).
   auto ConvertArg = [&](PyObject* obj, bool* error) {
@@ -78,6 +78,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
   Export(toco_flags, *model, toco_flags.allow_custom_ops(),
          &output_file_contents_txt);
 
+  if (extended_return) {
+    PyObject* dict = PyDict_New();
+    PyDict_SetItemString(
+        dict, "flatbuffer",
+        TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(),
+                                  output_file_contents_txt.size()));
+    PyDict_SetItemString(dict, "arithmetic_ops",
+                         PyLong_FromLong(model->ArithmeticOpsCount()));
+    return dict;
+  }
   // Convert arguments back to byte (py3) or str (py2)
   return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(),
                                    output_file_contents_txt.size());
index dc37835..9af38e9 100644 (file)
@@ -23,10 +23,13 @@ namespace toco {
 // Convert a model represented in `input_contents`. `model_flags_proto`
 // describes model parameters. `toco_flags_proto` describes conversion
 // parameters (see relevant .protos for more information). Returns a string
-// representing the contents of the converted model.
+// representing the contents of the converted model. When extended_return
+// flag is set to true returns a dictionary that contains string representation
+// of the converted model and some statitics like arithmetic ops count.
 PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
                       PyObject* toco_flags_proto_txt_raw,
-                      PyObject* input_contents_txt_raw);
+                      PyObject* input_contents_txt_raw,
+                      bool extended_return = false);
 
 }  // namespace toco
 
index d894916..b5531ca 100644 (file)
@@ -373,6 +373,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
     LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count
               << " billion (note that a multiply-add is counted as 2 ops).";
   }
+  model->ops_count = ops_count;
 }
 
 void Export(const TocoFlags& toco_flags, const Model& model,