[MLInliner] Simplify TFUTILS_SUPPORTED_TYPES
author     Mircea Trofin <mtrofin@google.com>
           Tue, 25 Aug 2020 16:58:49 +0000 (09:58 -0700)
committer  Mircea Trofin <mtrofin@google.com>
           Tue, 25 Aug 2020 21:19:39 +0000 (14:19 -0700)
We only need the C++ type and the corresponding TF enum. The short
type name parameter was only used for the "type" field of the output
spec JSON file, but we can just standardize on the C++ type name there.
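
For context, TFUTILS_SUPPORTED_TYPES is an X-macro: clients define a
per-entry macro and apply the list to it, so each supported type is
stamped out once per use site. A minimal, self-contained sketch of the
two-parameter pattern (illustrative names and stand-in enum values, not
the actual TFUtils code):

  #include <cstdint>
  #include <cstdio>

  // Stand-ins for TensorFlow's TF_DataType enumerators.
  enum TFDataType { TF_FLOAT = 1, TF_INT64 = 9 };

  #define SUPPORTED_TYPES(M)                                                   \
    M(float, TF_FLOAT)                                                         \
    M(int64_t, TF_INT64)

  // #T stringizes the C++ type; E is usable directly as an enum constant.
  #define PRINT_TYPE(T, E) std::printf("%s -> %d\n", #T, static_cast<int>(E));

  int main() {
    SUPPORTED_TYPES(PRINT_TYPE) // prints: float -> 1, int64_t -> 9
    return 0;
  }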

Differential Revision: https://reviews.llvm.org/D86549

llvm/include/llvm/Analysis/Utils/TFUtils.h
llvm/lib/Analysis/TFUtils.cpp
llvm/lib/Analysis/models/inliner/output_spec.json
llvm/test/Transforms/Inline/ML/Inputs/test_output_spec.json
llvm/unittests/Analysis/TFUtilsTest.cpp

diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
index a6cfb16..bba275b 100644
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -90,6 +90,13 @@ private:
   size_t ElementCount = 0;
 };
 
+/// Construct a TensorSpec from a JSON dictionary of the form:
+/// { "name": <string>,
+///   "port": <int>,
+///   "type": <string. Use LLVM's types, e.g. float, double, int64_t>,
+///   "shape": <array of ints> }
+/// For the "type" field, see the C++ primitive types used in
+/// TFUTILS_SUPPORTED_TYPES.
 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                            const json::Value &Value);
 
@@ -155,23 +162,22 @@ private:
   std::unique_ptr<TFModelEvaluatorImpl> Impl;
 };
 
-/// List of supported types, as a triple:
-/// C++ type
-/// short name (for strings, for instance)
-/// capitalized short name (for enums, for instance)
+/// List of supported types, as a pair:
+/// - C++ type
+/// - enum name (implementation-specific)
 #define TFUTILS_SUPPORTED_TYPES(M)                                             \
-  M(float, float, FLOAT)                                                       \
-  M(double, double, DOUBLE)                                                    \
-  M(int8_t, int8, INT8)                                                        \
-  M(uint8_t, uint8, UINT8)                                                     \
-  M(int16_t, int16, INT16)                                                     \
-  M(uint16_t, uint16, UINT16)                                                  \
-  M(int32_t, int32, INT32)                                                     \
-  M(uint32_t, uint32, UINT32)                                                  \
-  M(int64_t, int64, INT64)                                                     \
-  M(uint64_t, uint64, UINT64)
-
-#define TFUTILS_GETDATATYPE_DEF(T, S, C)                                       \
+  M(float, TF_FLOAT)                                                           \
+  M(double, TF_DOUBLE)                                                         \
+  M(int8_t, TF_INT8)                                                           \
+  M(uint8_t, TF_UINT8)                                                         \
+  M(int16_t, TF_INT16)                                                         \
+  M(uint16_t, TF_UINT16)                                                       \
+  M(int32_t, TF_INT32)                                                         \
+  M(uint32_t, TF_UINT32)                                                       \
+  M(int64_t, TF_INT64)                                                         \
+  M(uint64_t, TF_UINT64)
+
+#define TFUTILS_GETDATATYPE_DEF(T, E)                                          \
   template <> int TensorSpec::getDataType<T>();
 
 TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
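
With the two-parameter list, TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
stamps out one specialization declaration per entry; the first expands to

  template <> int TensorSpec::getDataType<float>();

and likewise for the remaining nine types.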
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
index 99b6330..648a3a4 100644
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -122,8 +122,8 @@ Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
   if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
     return EmitError("'shape' property not present or not an int array");
 
-#define PARSE_TYPE(T, S, E)                                                    \
-  if (TensorType == #S)                                                        \
+#define PARSE_TYPE(T, E)                                                       \
+  if (TensorType == #T)                                                        \
     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
   TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
 #undef PARSE_TYPE
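
With the short name gone, PARSE_TYPE stringizes the C++ type itself, so
the float entry now expands to

  if (TensorType == "float")
    return TensorSpec::createSpec<float>(TensorName, TensorShape, TensorPort);

which is what lets the JSON files below use the C++ spelling directly.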
@@ -307,8 +307,8 @@ TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
   return TF_TensorData(Impl->getOutput()[Index]);
 }
 
-#define TFUTILS_GETDATATYPE_IMPL(T, S, E)                                      \
-  template <> int TensorSpec::getDataType<T>() { return TF_##E; }
+#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
+  template <> int TensorSpec::getDataType<T>() { return E; }
 
 TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
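
And since the list now carries the full TF enum name, the TF_##E token
pasting is gone; the float entry expands to

  template <> int TensorSpec::getDataType<float>() { return TF_FLOAT; }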
 
diff --git a/llvm/lib/Analysis/models/inliner/output_spec.json b/llvm/lib/Analysis/models/inliner/output_spec.json
index d9e2060..5f9d13d 100644
--- a/llvm/lib/Analysis/models/inliner/output_spec.json
+++ b/llvm/lib/Analysis/models/inliner/output_spec.json
@@ -4,7 +4,7 @@
         "tensor_spec": {
             "name": "StatefulPartitionedCall",
             "port": 0,
-            "type": "int64",
+            "type": "int64_t",
             "shape": [
                 1
             ]
diff --git a/llvm/test/Transforms/Inline/ML/Inputs/test_output_spec.json b/llvm/test/Transforms/Inline/ML/Inputs/test_output_spec.json
index bd6a19c..2a70e3a 100644
--- a/llvm/test/Transforms/Inline/ML/Inputs/test_output_spec.json
+++ b/llvm/test/Transforms/Inline/ML/Inputs/test_output_spec.json
@@ -4,7 +4,7 @@
         "tensor_spec": {
             "name": "StatefulPartitionedCall",
             "port": 0,
-            "type": "int64",
+            "type": "int64_t",
             "shape": [
                 1
             ]
@@ -15,7 +15,7 @@
         "tensor_spec": {
             "name": "StatefulPartitionedCall",
             "port": 0,
-            "type": "int64",
+            "type": "int64_t",
             "shape": [
                 1
             ]
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
index 5d8425d..19ca1f2 100644
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -103,7 +103,7 @@ TEST(TFUtilsTest, JSONParsing) {
   auto Value = json::parse(
       R"({"name": "tensor_name", 
         "port": 2, 
-        "type": "int32", 
+        "type": "int32_t", 
         "shape":[1,4]
         })");
   EXPECT_TRUE(!!Value);
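
One consequence, sketched here as a hypothetical negative case: specs
still using the old short names no longer parse, which is why the JSON
files and the test above move from "int64"/"int32" to "int64_t"/"int32_t".

  auto Old = json::parse(
      R"({"name": "n", "port": 0, "type": "int32", "shape": [1]})");
  // getTensorSpecFromJSON(Ctx, *Old) returns None: no entry of
  // TFUTILS_SUPPORTED_TYPES stringizes to "int32".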