#ifdef LLVM_HAVE_TF_API
#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/JSON.h"
#include <memory>
#include <vector>
int typeIndex() const { return TypeIndex; }
const std::vector<int64_t> &shape() const { return Shape; }
+ bool operator==(const TensorSpec &Other) const {
+ return Name == Other.Name && Port == Other.Port &&
+ TypeIndex == Other.TypeIndex && Shape == Other.Shape;
+ }
+
+ bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
+
private:
TensorSpec(const std::string &Name, int Port, int TypeIndex,
const std::vector<int64_t> &Shape)
std::vector<int64_t> Shape;
};
+Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
+ const json::Value &Value);
+
class TFModelEvaluator final {
public:
/// The result of a model evaluation. Handles the lifetime of the output
std::unique_ptr<TFModelEvaluatorImpl> Impl;
};
-template <> int TensorSpec::getDataType<float>();
-template <> int TensorSpec::getDataType<double>();
-template <> int TensorSpec::getDataType<int8_t>();
-template <> int TensorSpec::getDataType<uint8_t>();
-template <> int TensorSpec::getDataType<int16_t>();
-template <> int TensorSpec::getDataType<uint16_t>();
-template <> int TensorSpec::getDataType<int32_t>();
-template <> int TensorSpec::getDataType<uint32_t>();
-template <> int TensorSpec::getDataType<int64_t>();
-template <> int TensorSpec::getDataType<uint64_t>();
-
+/// List of supported types, as a triple:
+/// C++ type
+/// short name (for strings, for instance)
+/// capitalized short name (for enums, for instance)
+#define TFUTILS_SUPPORTED_TYPES(M) \
+ M(float, float, FLOAT) \
+ M(double, double, DOUBLE) \
+ M(int8_t, int8, INT8) \
+ M(uint8_t, uint8, UINT8) \
+ M(int16_t, int16, INT16) \
+ M(uint16_t, uint16, UINT16) \
+ M(int32_t, int32, INT32) \
+ M(uint32_t, uint32, UINT32) \
+ M(int64_t, int64, INT64) \
+ M(uint64_t, uint64, UINT64)
+
+#define TFUTILS_GETDATATYPE_DEF(T, S, C) \
+ template <> int TensorSpec::getDataType<T>();
+
+TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
+
+#undef TFUTILS_GETDATATYPE_DEF
} // namespace llvm
#endif // LLVM_HAVE_TF_API
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TF_API)
-#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
std::vector<TF_Tensor *> Output;
};
+Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
+ const json::Value &Value) {
+ auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
+ std::string S;
+ llvm::raw_string_ostream OS(S);
+ OS << Value;
+ Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
+ return None;
+ };
+ json::ObjectMapper Mapper(Value);
+ if (!Mapper)
+ return EmitError("Value is not a dict");
+
+ std::string TensorName;
+ int TensorPort = -1;
+ std::string TensorType;
+ std::vector<int64_t> TensorShape;
+
+ if (!Mapper.map<std::string>("name", TensorName))
+ return EmitError("'name' property not present or not a string");
+ if (!Mapper.map<std::string>("type", TensorType))
+ return EmitError("'type' property not present or not a string");
+ if (!Mapper.map<int>("port", TensorPort))
+ return EmitError("'port' property not present or not an int");
+ if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
+ return EmitError("'shape' property not present or not an int array");
+
+#define PARSE_TYPE(T, S, E) \
+ if (TensorType == #S) \
+ return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
+ TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
+#undef PARSE_TYPE
+ return None;
+}
+
class TFModelEvaluatorImpl {
public:
TFModelEvaluatorImpl(StringRef SavedModelPath,
return TF_TensorData(Impl->getOutput()[Index]);
}
-template <> int TensorSpec::getDataType<float>() { return TF_FLOAT; }
-
-template <> int TensorSpec::getDataType<double>() { return TF_DOUBLE; }
-
-template <> int TensorSpec::getDataType<int8_t>() { return TF_INT8; }
-
-template <> int TensorSpec::getDataType<uint8_t>() { return TF_UINT8; }
-
-template <> int TensorSpec::getDataType<int16_t>() { return TF_INT16; }
-
-template <> int TensorSpec::getDataType<uint16_t>() { return TF_UINT16; }
-
-template <> int TensorSpec::getDataType<int32_t>() { return TF_INT32; }
-
-template <> int TensorSpec::getDataType<uint32_t>() { return TF_UINT32; }
+#define TFUTILS_GETDATATYPE_IMPL(T, S, E) \
+ template <> int TensorSpec::getDataType<T>() { return TF_##E; }
-template <> int TensorSpec::getDataType<int64_t>() { return TF_INT64; }
+TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
-template <> int TensorSpec::getDataType<uint64_t>() { return TF_UINT64; }
+#undef TFUTILS_GETDATATYPE_IMPL
TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
TFModelEvaluator::~TFModelEvaluator() {}
EXPECT_FALSE(ER.hasValue());
EXPECT_FALSE(Evaluator.isValid());
}
+
+TEST(TFUtilsTest, JSONParsing) {
+ auto Value = json::parse(
+ R"({"name": "tensor_name",
+ "port": 2,
+ "type": "int32",
+ "shape":[1,4]
+ })");
+ EXPECT_TRUE(!!Value);
+ LLVMContext Ctx;
+ Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
+ EXPECT_TRUE(Spec.hasValue());
+ EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
+}
+
+TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
+ auto Value = json::parse(
+ R"(
+ {"name": "tensor_name",
+ "port": 2,
+ "type": "no such type",
+ "shape":[1,4]
+ }
+ )");
+ EXPECT_TRUE(!!Value);
+ LLVMContext Ctx;
+ auto Spec = getTensorSpecFromJSON(Ctx, *Value);
+ EXPECT_FALSE(Spec.hasValue());
+}