[MLGO] Make TFLiteUtils throw an error if some features haven't been passed to the...
authorAiden Grossman <agrossman154@yahoo.com>
Wed, 7 Sep 2022 20:36:09 +0000 (20:36 +0000)
committerAiden Grossman <agrossman154@yahoo.com>
Sat, 10 Sep 2022 22:59:03 +0000 (22:59 +0000)
In the Tensorflow C lib utilities, an error gets thrown if some features
haven't gotten passed into the model (due to differences in ordering
which now don't exist with the transition to TFLite). However, this is
not currently the case when using TFLiteUtils. This patch makes some
minor changes to throw an error when not all inputs of the model have
been passed, which when not handled will result in a seg fault within
TFLite.

Reviewed By: mtrofin

Differential Revision: https://reviews.llvm.org/D133451

llvm/lib/Analysis/TFLiteUtils.cpp
llvm/unittests/Analysis/TFUtilsTest.cpp

index 9c43193..41c9847 100644 (file)
@@ -134,6 +134,7 @@ TFModelEvaluatorImpl::TFModelEvaluatorImpl(
   for (size_t I = 0; I < Interpreter->outputs().size(); ++I)
     OutputsMap[Interpreter->GetOutputName(I)] = I;
 
+  size_t NumberFeaturesPassed = 0;
   for (size_t I = 0; I < InputSpecs.size(); ++I) {
     auto &InputSpec = InputSpecs[I];
     auto MapI = InputsMap.find(InputSpec.name() + ":" +
@@ -147,6 +148,14 @@ TFModelEvaluatorImpl::TFModelEvaluatorImpl(
       return;
     std::memset(Input[I]->data.data, 0,
                 InputSpecs[I].getTotalTensorBufferSize());
+    ++NumberFeaturesPassed;
+  }
+
+  if (NumberFeaturesPassed < Interpreter->inputs().size()) {
+    // we haven't passed all the required features to the model, throw an error.
+    errs() << "Required feature(s) have not been passed to the ML model";
+    invalidate();
+    return;
   }
 
   for (size_t I = 0; I < OutputSpecsSize; ++I) {
index fe3b115..c604afd 100644 (file)
@@ -121,3 +121,12 @@ TEST(TFUtilsTest, UnsupportedFeature) {
   for (auto I = 0; I < 2 * 5; ++I)
     EXPECT_FLOAT_EQ(F[I], 3.14 + I);
 }
+
+TEST(TFUtilsTest, MissingFeature) {
+  std::vector<TensorSpec> InputSpecs{};
+  std::vector<TensorSpec> OutputSpecs{
+      TensorSpec::createSpec<float>("StatefulPartitionedCall", {1})};
+
+  TFModelEvaluator Evaluator(getModelPath(), InputSpecs, OutputSpecs);
+  EXPECT_FALSE(Evaluator.isValid());
+}