Simplify and enforce diagnostic ArrayDataType strings.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 20 Feb 2018 21:00:26 +0000 (13:00 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Feb 2018 21:13:46 +0000 (13:13 -0800)
PiperOrigin-RevId: 186348846

tensorflow/contrib/lite/toco/dump_graphviz.cc
tensorflow/contrib/lite/toco/tooling_util.cc
tensorflow/contrib/lite/toco/tooling_util.h

index c726eb6..2184e8f 100644 (file)
@@ -142,14 +142,8 @@ NodeProperties GetPropertiesForArray(const Model& model,
 
   // Append array shape to the label.
   auto& array = model.GetArray(array_name);
-
-  if (array.data_type == ArrayDataType::kFloat) {
-    AppendF(&node_properties.label, "\\nType: float");
-  } else if (array.data_type == ArrayDataType::kInt32) {
-    AppendF(&node_properties.label, "\\nType: int32");
-  } else if (array.data_type == ArrayDataType::kUint8) {
-    AppendF(&node_properties.label, "\\nType: uint8");
-  }
+  AppendF(&node_properties.label, "\\nType: %s",
+          ArrayDataTypeName(array.data_type));
 
   if (array.has_shape()) {
     auto& array_shape = array.shape();
index dcb409c..eec35b7 100644 (file)
@@ -62,6 +62,35 @@ string LogName(const Operator& op) {
   }
 }
 
+string ArrayDataTypeName(ArrayDataType data_type) {
+  switch (data_type) {
+    case ArrayDataType::kFloat:
+      return "Float";
+    case ArrayDataType::kInt8:
+      return "Int8";
+    case ArrayDataType::kUint8:
+      return "Uint8";
+    case ArrayDataType::kInt16:
+      return "Int16";
+    case ArrayDataType::kUint16:
+      return "Uint16";
+    case ArrayDataType::kInt32:
+      return "Int32";
+    case ArrayDataType::kUint32:
+      return "Uint32";
+    case ArrayDataType::kInt64:
+      return "Int64";
+    case ArrayDataType::kUint64:
+      return "Uint64";
+    case ArrayDataType::kString:
+      return "String";
+    case ArrayDataType::kNone:
+      return "None";
+    default:
+      LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
+  }
+}
+
 bool IsInputArray(const Model& model, const string& name) {
   for (const auto& input_array : model.flags.input_arrays()) {
     if (input_array.name() == name) {
@@ -363,48 +392,9 @@ void LogSummary(int log_level, const Model& model) {
 void LogArray(int log_level, const Model& model, const string& name) {
   const auto& array = model.GetArray(name);
   VLOG(log_level) << "Array: " << name;
-  switch (array.data_type) {
-    case ArrayDataType::kNone:
-      VLOG(log_level) << "  Data type:";
-      break;
-    case ArrayDataType::kFloat:
-      VLOG(log_level) << "  Data type: kFloat";
-      break;
-    case ArrayDataType::kInt32:
-      VLOG(log_level) << "  Data type: kInt32";
-      break;
-    case ArrayDataType::kUint8:
-      VLOG(log_level) << "  Data type: kUint8";
-      break;
-    case ArrayDataType::kString:
-      VLOG(log_level) << "  Data type: kString";
-      break;
-    default:
-      VLOG(log_level) << "  Data type: other (numerical value: "
-                      << static_cast<int>(array.data_type) << ")";
-      break;
-  }
-  switch (array.final_data_type) {
-    case ArrayDataType::kNone:
-      VLOG(log_level) << "  Final type:";
-      break;
-    case ArrayDataType::kFloat:
-      VLOG(log_level) << "  Final type: kFloat";
-      break;
-    case ArrayDataType::kInt32:
-      VLOG(log_level) << "  Final type: kInt32";
-      break;
-    case ArrayDataType::kUint8:
-      VLOG(log_level) << "  Final type: kUint8";
-      break;
-    case ArrayDataType::kString:
-      VLOG(log_level) << "  Final type: kString";
-      break;
-    default:
-      VLOG(log_level) << "  Final type: other (numerical value: "
-                      << static_cast<int>(array.data_type) << ")";
-      break;
-  }
+  VLOG(log_level) << "  Data type: " << ArrayDataTypeName(array.data_type);
+  VLOG(log_level) << "  Final type: "
+                  << ArrayDataTypeName(array.final_data_type);
   if (array.buffer) {
     VLOG(log_level) << "  Constant Buffer";
   }
index 0aaa0f6..11208ed 100644 (file)
@@ -54,6 +54,8 @@ absl::string_view FindLongestCommonPrefix(absl::string_view a,
                                           absl::string_view b);
 string LogName(const Operator& op);
 
+string ArrayDataTypeName(ArrayDataType data_type);
+
 bool IsInputArray(const Model& model, const string& name);
 bool IsArrayConsumed(const Model& model, const string& name);
 int CountTrueOutputs(const Model& model, const Operator& op);