size_t ElementCount = 0;
};
+/// Construct a TensorSpec from a JSON dictionary of the form:
+/// { "name": <string>,
+///   "port": <int>,
+///   "type": <string. Use LLVM's C++ types, e.g. "float", "double", "int64_t">,
+///   "shape": <array of ints> }
+/// For the "type" field, see the C++ primitive types used in
+/// TFUTILS_SUPPORTED_TYPES.
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
const json::Value &Value);
std::unique_ptr<TFModelEvaluatorImpl> Impl;
};
-/// List of supported types, as a triple:
-/// C++ type
-/// short name (for strings, for instance)
-/// capitalized short name (for enums, for instance)
+/// List of supported types, as a pair:
+/// - C++ type
+/// - TensorFlow C API data type enum value (e.g. TF_FLOAT)
#define TFUTILS_SUPPORTED_TYPES(M) \
- M(float, float, FLOAT) \
- M(double, double, DOUBLE) \
- M(int8_t, int8, INT8) \
- M(uint8_t, uint8, UINT8) \
- M(int16_t, int16, INT16) \
- M(uint16_t, uint16, UINT16) \
- M(int32_t, int32, INT32) \
- M(uint32_t, uint32, UINT32) \
- M(int64_t, int64, INT64) \
- M(uint64_t, uint64, UINT64)
-
-#define TFUTILS_GETDATATYPE_DEF(T, S, C) \
+ M(float, TF_FLOAT) \
+ M(double, TF_DOUBLE) \
+ M(int8_t, TF_INT8) \
+ M(uint8_t, TF_UINT8) \
+ M(int16_t, TF_INT16) \
+ M(uint16_t, TF_UINT16) \
+ M(int32_t, TF_INT32) \
+ M(uint32_t, TF_UINT32) \
+ M(int64_t, TF_INT64) \
+ M(uint64_t, TF_UINT64)
+
+#define TFUTILS_GETDATATYPE_DEF(T, E) \
template <> int TensorSpec::getDataType<T>();
TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
return EmitError("'shape' property not present or not an int array");
-#define PARSE_TYPE(T, S, E) \
- if (TensorType == #S) \
+#define PARSE_TYPE(T, E) \
+ if (TensorType == #T) \
return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
#undef PARSE_TYPE
return TF_TensorData(Impl->getOutput()[Index]);
}
-#define TFUTILS_GETDATATYPE_IMPL(T, S, E) \
- template <> int TensorSpec::getDataType<T>() { return TF_##E; }
+#define TFUTILS_GETDATATYPE_IMPL(T, E) \
+ template <> int TensorSpec::getDataType<T>() { return E; }
TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)