[flang] Shape analysis for result of MATMUL
authorpeter klausler <pklausler@nvidia.com>
Sat, 20 Jun 2020 01:11:46 +0000 (18:11 -0700)
committerpeter klausler <pklausler@nvidia.com>
Mon, 22 Jun 2020 16:45:30 +0000 (09:45 -0700)
Implement shape analysis for the result of the MATMUL
generic transformational intrinsic function, based on
the shapes of its arguments.  Correct the names of the
arguments to match the standard, too.

Reviewed By: PeteSteinfeld
Differential Revision: https://reviews.llvm.org/D82250

flang/lib/Evaluate/intrinsics.cpp
flang/lib/Evaluate/shape.cpp
flang/lib/Semantics/check-call.cpp

index 0f92dd8..be155f8 100644 (file)
@@ -496,28 +496,28 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
     {"logical", {{"l", AnyLogical}, DefaultingKIND}, KINDLogical},
     {"log_gamma", {{"x", SameReal}}, SameReal},
     {"matmul",
-        {{"array_a", AnyLogical, Rank::vector},
-            {"array_b", AnyLogical, Rank::matrix}},
+        {{"matrix_a", AnyLogical, Rank::vector},
+            {"matrix_b", AnyLogical, Rank::matrix}},
         ResultLogical, Rank::vector, IntrinsicClass::transformationalFunction},
     {"matmul",
-        {{"array_a", AnyLogical, Rank::matrix},
-            {"array_b", AnyLogical, Rank::vector}},
+        {{"matrix_a", AnyLogical, Rank::matrix},
+            {"matrix_b", AnyLogical, Rank::vector}},
         ResultLogical, Rank::vector, IntrinsicClass::transformationalFunction},
     {"matmul",
-        {{"array_a", AnyLogical, Rank::matrix},
-            {"array_b", AnyLogical, Rank::matrix}},
+        {{"matrix_a", AnyLogical, Rank::matrix},
+            {"matrix_b", AnyLogical, Rank::matrix}},
         ResultLogical, Rank::matrix, IntrinsicClass::transformationalFunction},
     {"matmul",
-        {{"array_a", AnyNumeric, Rank::vector},
-            {"array_b", AnyNumeric, Rank::matrix}},
+        {{"matrix_a", AnyNumeric, Rank::vector},
+            {"matrix_b", AnyNumeric, Rank::matrix}},
         ResultNumeric, Rank::vector, IntrinsicClass::transformationalFunction},
     {"matmul",
-        {{"array_a", AnyNumeric, Rank::matrix},
-            {"array_b", AnyNumeric, Rank::vector}},
+        {{"matrix_a", AnyNumeric, Rank::matrix},
+            {"matrix_b", AnyNumeric, Rank::vector}},
         ResultNumeric, Rank::vector, IntrinsicClass::transformationalFunction},
     {"matmul",
-        {{"array_a", AnyNumeric, Rank::matrix},
-            {"array_b", AnyNumeric, Rank::matrix}},
+        {{"matrix_a", AnyNumeric, Rank::matrix},
+            {"matrix_b", AnyNumeric, Rank::matrix}},
         ResultNumeric, Rank::matrix, IntrinsicClass::transformationalFunction},
     {"maskl", {{"i", AnyInt}, DefaultingKIND}, KINDInt},
     {"maskr", {{"i", AnyInt}, DefaultingKIND}, KINDInt},
@@ -1904,7 +1904,6 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::Probe(
   }
 
   if (call.isSubroutineCall) {
-    parser::Messages buffer;
     auto subrRange{subroutines_.equal_range(call.name)};
     for (auto iter{subrRange.first}; iter != subrRange.second; ++iter) {
       if (auto specificCall{
index 507de42..e943b49 100644 (file)
@@ -545,6 +545,23 @@ auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
       if (!call.arguments().empty()) {
         return (*this)(call.arguments()[0]);
       }
+    } else if (intrinsic->name == "matmul") {
+      if (call.arguments().size() == 2) {
+        if (auto ashape{(*this)(call.arguments()[0])}) {
+          if (auto bshape{(*this)(call.arguments()[1])}) {
+            if (ashape->size() == 1 && bshape->size() == 2) {
+              bshape->erase(bshape->begin());
+              return std::move(*bshape); // matmul(vector, matrix)
+            } else if (ashape->size() == 2 && bshape->size() == 1) {
+              ashape->pop_back();
+              return std::move(*ashape); // matmul(matrix, vector)
+            } else if (ashape->size() == 2 && bshape->size() == 2) {
+              (*ashape)[1] = std::move((*bshape)[1]);
+              return std::move(*ashape); // matmul(matrix, matrix)
+            }
+          }
+        }
+      }
     } else if (intrinsic->name == "reshape") {
       if (call.arguments().size() >= 2 && call.arguments().at(1)) {
         // SHAPE(RESHAPE(array,shape)) -> shape
index 282a877..8c3810c 100644 (file)
@@ -607,7 +607,8 @@ static void CheckExplicitInterfaceArg(evaluate::ActualArgument &arg,
                 // ok
               } else {
                 messages.Say(
-                    "Actual argument is not a variable or typed expression"_err_en_US);
+                    "Actual argument '%s' associated with %s is not a variable or typed expression"_err_en_US,
+                    expr->AsFortran(), dummyName);
               }
             } else {
               const Symbol &assumed{DEREF(arg.GetAssumedTypeDummy())};