[XLA] Print allowed attributes when the user specifies an invalid attr.
authorJustin Lebar <jlebar@google.com>
Fri, 4 May 2018 22:40:07 +0000 (15:40 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 5 May 2018 15:28:14 +0000 (08:28 -0700)
PiperOrigin-RevId: 195482974

tensorflow/compiler/xla/tools/parser/hlo_parser.cc
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc

index 3a945fb..40dc073 100644 (file)
@@ -30,6 +30,7 @@ namespace {
 
 using tensorflow::StringPiece;
 using tensorflow::gtl::optional;
+using tensorflow::str_util::Join;
 using tensorflow::str_util::Split;
 using tensorflow::str_util::SplitAndParseAsInts;
 using tensorflow::strings::Printf;
@@ -53,7 +54,7 @@ class HloParser {
   std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
 
   // Returns the error information.
-  string GetError() const { return tensorflow::str_util::Join(error_, "\n"); }
+  string GetError() const { return Join(error_, "\n"); }
 
  private:
   // ParseXXX returns false if an error occurred.
@@ -245,7 +246,7 @@ bool HloParser::Error(LocTy loc, StringPiece msg) {
   error_lines.push_back(std::string(lexer_.GetLine(loc)));
   error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^"));
 
-  error_.push_back(tensorflow::str_util::Join(error_lines, "\n"));
+  error_.push_back(Join(error_lines, "\n"));
   VLOG(1) << "Error: " << error_.back();
   return false;
 }
@@ -1488,11 +1489,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
     std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
                                             elems_seen_per_dim.begin() + dim);
     return StrCat("[",
-                  tensorflow::str_util::Join(
-                      elems_seen_until_dim, ",",
-                      [](string* out, const int64& num_elems) {
-                        tensorflow::strings::StrAppend(out, num_elems - 1);
-                      }),
+                  Join(elems_seen_until_dim, ",",
+                       [](string* out, const int64& num_elems) {
+                         tensorflow::strings::StrAppend(out, num_elems - 1);
+                       }),
                   "]");
   };
   do {
@@ -1680,7 +1680,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
         return Error(
             index_loc,
             StrCat("invalid multi-dimension index for shape with rank ", rank,
-                   ": [", tensorflow::str_util::Join(index, ", "), "]"));
+                   ": [", Join(index, ", "), "]"));
       }
     }
     if (!ParseToken(TokKind::kColon,
@@ -1848,7 +1848,19 @@ bool HloParser::ParseAttributeHelper(
   }
   auto attr_it = attrs.find(name);
   if (attr_it == attrs.end()) {
-    return Error(loc, Printf("unexpected attribute %s", name.c_str()));
+    string allowed_attrs;
+    if (attrs.empty()) {
+      allowed_attrs = "No attributes are allowed here.";
+    } else {
+      allowed_attrs = StrCat(
+          "Allowed attributes: ",
+          Join(attrs, ", ",
+               [&](string* out, const std::pair<string, AttrConfig>& kv) {
+                 StrAppend(out, kv.first);
+               }));
+    }
+    return Error(loc, Printf("unexpected attribute \"%s\".  %s", name.c_str(),
+                             allowed_attrs.c_str()));
   }
   AttrTy attr_type = attr_it->second.attr_type;
   void* attr_out_ptr = attr_it->second.result;
index 4e085bc..d38d890 100644 (file)
@@ -1138,7 +1138,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
 
 )";
   ExpectHasSubstr(Parse(original).status().error_message(),
-                  "unexpected attribute calls");
+                  "unexpected attribute \"calls\"");
 }
 
 TEST_F(HloParserTest, MissingAttribute) {