[flang] Implement MATMUL in the runtime
authorpeter klausler <pklausler@nvidia.com>
Mon, 17 May 2021 21:06:44 +0000 (14:06 -0700)
committerpeter klausler <pklausler@nvidia.com>
Tue, 18 May 2021 17:59:52 +0000 (10:59 -0700)
Define an API for the transformational intrinsic function MATMUL,
implement it, and add some basic unit tests.  The large number of
possible argument type combinations are covered by a set of
generalized templates that are instantiated for each valid
pair of possible argument types.

Places where BLAS-2/3 routines could be called for acceleration
are marked with TODOs.  Handling for other special cases (e.g.,
known-shape 3x3 matrices and vectors) are deferred.

Some minor tweaks were made to the recent related implementation
of DOT_PRODUCT to reflect lessons learned.

Differential Revision: https://reviews.llvm.org/D102652

flang/runtime/CMakeLists.txt
flang/runtime/dot-product.cpp
flang/runtime/matmul.cpp [new file with mode: 0644]
flang/runtime/matmul.h [new file with mode: 0644]
flang/runtime/reduction.h
flang/unittests/RuntimeGTest/CMakeLists.txt
flang/unittests/RuntimeGTest/Matmul.cpp [new file with mode: 0644]

index 84d13f1..a484c94 100644 (file)
@@ -53,6 +53,7 @@ add_flang_library(FortranRuntime
   io-error.cpp
   io-stmt.cpp
   main.cpp
+  matmul.cpp
   memory.cpp
   misc-intrinsic.cpp
   namelist.cpp
index 1c83d8d..075d987 100644 (file)
 
 namespace Fortran::runtime {
 
-template <typename ACCUMULATOR>
-static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
-    Terminator &terminator) -> typename ACCUMULATOR::Result {
+template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+class Accumulator {
+public:
+  using Result = RESULT;
+  Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
+    if constexpr (XCAT == TypeCategory::Complex) {
+      sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
+          static_cast<Result>(*y_.Element<YT>(&yAt));
+    } else if constexpr (XCAT == TypeCategory::Logical) {
+      sum_ = sum_ ||
+          (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
+    } else {
+      sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
+          static_cast<Result>(*y_.Element<YT>(&yAt));
+    }
+  }
+  Result GetResult() const { return sum_; }
+
+private:
+  const Descriptor &x_, &y_;
+  Result sum_{};
+};
+
+template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+static inline RESULT DoDotProduct(
+    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
   RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
   SubscriptValue n{x.GetDimension(0).Extent()};
   if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
@@ -25,18 +49,27 @@ static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
         "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
         static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
   }
+  if constexpr (std::is_same_v<XT, YT>) {
+    if constexpr (std::is_same_v<XT, float>) {
+      // TODO: call BLAS-1 SDOT or SDSDOT
+    } else if constexpr (std::is_same_v<XT, double>) {
+      // TODO: call BLAS-1 DDOT
+    } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+      // TODO: call BLAS-1 CDOTC
+    } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+      // TODO: call BLAS-1 ZDOTC
+    }
+  }
   SubscriptValue xAt{x.GetDimension(0).LowerBound()};
   SubscriptValue yAt{y.GetDimension(0).LowerBound()};
-  ACCUMULATOR accumulator{x, y};
+  Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y};
   for (SubscriptValue j{0}; j < n; ++j) {
     accumulator.Accumulate(xAt++, yAt++);
   }
   return accumulator.GetResult();
 }
 
-template <TypeCategory RCAT, int RKIND,
-    template <typename, TypeCategory, typename, typename> class ACCUM>
-struct DotProduct {
+template <TypeCategory RCAT, int RKIND> struct DotProduct {
   using Result = CppTypeFor<RCAT, RKIND>;
   template <TypeCategory XCAT, int XKIND> struct DP1 {
     template <TypeCategory YCAT, int YKIND> struct DP2 {
@@ -46,9 +79,8 @@ struct DotProduct {
                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
           if constexpr (resultType->first == RCAT &&
               resultType->second <= RKIND) {
-            using Accum = ACCUM<Result, XCAT, CppTypeFor<XCAT, XKIND>,
-                CppTypeFor<YCAT, YKIND>>;
-            return DoDotProduct<Accum>(x, y, terminator);
+            return DoDotProduct<Result, XCAT, CppTypeFor<XCAT, XKIND>,
+                CppTypeFor<YCAT, YKIND>>(x, y, terminator);
           }
         }
         terminator.Crash(
@@ -73,127 +105,76 @@ struct DotProduct {
   }
 };
 
-template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
-class NumericAccumulator {
-public:
-  using Result = RESULT;
-  NumericAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
-  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
-    if constexpr (XCAT == TypeCategory::Complex) {
-      sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
-          static_cast<Result>(*y_.Element<YT>(&yAt));
-    } else {
-      sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
-          static_cast<Result>(*y_.Element<YT>(&yAt));
-    }
-  }
-  Result GetResult() const { return sum_; }
-
-private:
-  const Descriptor &x_, &y_;
-  Result sum_{0};
-};
-
-template <typename, TypeCategory, typename XT, typename YT>
-class LogicalAccumulator {
-public:
-  using Result = bool;
-  LogicalAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
-  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
-    result_ = result_ ||
-        (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
-  }
-  bool GetResult() const { return result_; }
-
-private:
-  const Descriptor &x_, &y_;
-  bool result_{false};
-};
-
 extern "C" {
 std::int8_t RTNAME(DotProductInteger1)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
 }
 std::int16_t RTNAME(DotProductInteger2)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
 }
 std::int32_t RTNAME(DotProductInteger4)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
 }
 std::int64_t RTNAME(DotProductInteger8)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
 }
 #ifdef __SIZEOF_INT128__
 common::int128_t RTNAME(DotProductInteger16)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 16, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
 }
 #endif
 
 // TODO: REAL/COMPLEX(2 & 3)
 float RTNAME(DotProductReal4)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
 }
 double RTNAME(DotProductReal8)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
 }
 #if LONG_DOUBLE == 80
 long double RTNAME(DotProductReal10)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Real, 10, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
 }
 #elif LONG_DOUBLE == 128
 long double RTNAME(DotProductReal16)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Real, 16, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
 }
 #endif
 
 void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  auto z{DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
-      x, y, source, line)};
+  auto z{DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line)};
   result = std::complex<float>{
       static_cast<float>(z.real()), static_cast<float>(z.imag())};
 }
 void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  result = DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
 }
 #if LONG_DOUBLE == 80
 void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  result = DotProduct<TypeCategory::Complex, 10, NumericAccumulator>{}(
-      x, y, source, line);
+  result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
 }
 #elif LONG_DOUBLE == 128
 void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  result = DotProduct<TypeCategory::Complex, 16, NumericAccumulator>{}(
-      x, y, source, line);
+  result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
 }
 #endif
 
 bool RTNAME(DotProductLogical)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Logical, 1, LogicalAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
 }
 } // extern "C"
 } // namespace Fortran::runtime
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
new file mode 100644 (file)
index 0000000..3d10ca0
--- /dev/null
@@ -0,0 +1,220 @@
+//===-- runtime/matmul.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Implements all forms of MATMUL (Fortran 2018 16.9.124)
+//
+// There are two main entry points; one establishes a descriptor for the
+// result and allocates it, and the other expects a result descriptor that
+// points to existing storage.
+//
+// This implementation must handle all combinations of numeric types and
+// kinds (100 - 165 cases depending on the target), plus all combinations
+// of logical kinds (16).  A single template undergoes many instantiations
+// to cover all of the valid possibilities.
+//
+// Places where BLAS routines could be called are marked as TODO items.
+
+#include "matmul.h"
+#include "cpp-type.h"
+#include "descriptor.h"
+#include "terminator.h"
+#include "tools.h"
+
+namespace Fortran::runtime {
+
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+class Accumulator {
+public:
+  // Accumulate floating-point results in (at least) double precision
+  using Result = CppTypeFor<RCAT,
+      RCAT == TypeCategory::Real || RCAT == TypeCategory::Complex
+          ? std::max(RKIND, static_cast<int>(sizeof(double)))
+          : RKIND>;
+  Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+  void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) {
+    if constexpr (RCAT == TypeCategory::Logical) {
+      sum_ = sum_ ||
+          (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt));
+    } else {
+      sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) *
+          static_cast<Result>(*y_.Element<YT>(yAt));
+    }
+  }
+  Result GetResult() const { return sum_; }
+
+private:
+  const Descriptor &x_, &y_;
+  Result sum_{};
+};
+
+// Implements an instance of MATMUL for given argument types.
+template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
+    typename YT>
+static inline void DoMatmul(
+    std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
+    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
+  int xRank{x.rank()};
+  int yRank{y.rank()};
+  int resRank{xRank + yRank - 2};
+  if (xRank * yRank != 2 * resRank) {
+    terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
+  }
+  SubscriptValue extent[2]{
+      xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
+      resRank == 2 ? y.GetDimension(1).Extent() : 0};
+  if constexpr (IS_ALLOCATING) {
+    result.Establish(
+        RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
+    for (int j{0}; j < resRank; ++j) {
+      result.GetDimension(j).SetBounds(1, extent[j]);
+    }
+    if (int stat{result.Allocate()}) {
+      terminator.Crash(
+          "MATMUL: could not allocate memory for result; STAT=%d", stat);
+    }
+  } else {
+    RUNTIME_CHECK(terminator, resRank == result.rank());
+    RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND}));
+    RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
+    RUNTIME_CHECK(terminator,
+        resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
+  }
+  using WriteResult =
+      CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
+          RKIND>;
+  SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
+  if (n != y.GetDimension(0).Extent()) {
+    terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)",
+        static_cast<std::intmax_t>(n),
+        static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
+  }
+  SubscriptValue xAt[2], yAt[2], resAt[2];
+  x.GetLowerBounds(xAt);
+  y.GetLowerBounds(yAt);
+  result.GetLowerBounds(resAt);
+  if (resRank == 2) { // M*M -> M
+    if constexpr (std::is_same_v<XT, YT>) {
+      if constexpr (std::is_same_v<XT, float>) {
+        // TODO: call BLAS-3 SGEMM
+      } else if constexpr (std::is_same_v<XT, double>) {
+        // TODO: call BLAS-3 DGEMM
+      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+        // TODO: call BLAS-3 CGEMM
+      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+        // TODO: call BLAS-3 ZGEMM
+      }
+    }
+    SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
+    for (SubscriptValue i{0}; i < extent[0]; ++i) {
+      for (SubscriptValue j{0}; j < extent[1]; ++j) {
+        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+        yAt[1] = y1 + j;
+        for (SubscriptValue k{0}; k < n; ++k) {
+          xAt[1] = x1 + k;
+          yAt[0] = y0 + k;
+          accumulator.Accumulate(xAt, yAt);
+        }
+        resAt[1] = res1 + j;
+        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+      }
+      ++resAt[0];
+      ++xAt[0];
+    }
+  } else {
+    if constexpr (std::is_same_v<XT, YT>) {
+      if constexpr (std::is_same_v<XT, float>) {
+        // TODO: call BLAS-2 SGEMV
+      } else if constexpr (std::is_same_v<XT, double>) {
+        // TODO: call BLAS-2 DGEMV
+      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+        // TODO: call BLAS-2 CGEMV
+      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+        // TODO: call BLAS-2 ZGEMV
+      }
+    }
+    if (xRank == 2) { // M*V -> V
+      SubscriptValue x1{xAt[1]}, y0{yAt[0]};
+      for (SubscriptValue j{0}; j < extent[0]; ++j) {
+        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+        for (SubscriptValue k{0}; k < n; ++k) {
+          xAt[1] = x1 + k;
+          yAt[0] = y0 + k;
+          accumulator.Accumulate(xAt, yAt);
+        }
+        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+        ++resAt[0];
+        ++xAt[0];
+      }
+    } else { // V*M -> V
+      SubscriptValue x0{xAt[0]}, y0{yAt[0]};
+      for (SubscriptValue j{0}; j < extent[0]; ++j) {
+        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+        for (SubscriptValue k{0}; k < n; ++k) {
+          xAt[0] = x0 + k;
+          yAt[0] = y0 + k;
+          accumulator.Accumulate(xAt, yAt);
+        }
+        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+        ++resAt[0];
+        ++yAt[1];
+      }
+    }
+  }
+}
+
+// Maps the dynamic type information from the arguments' descriptors
+// to the right instantiation of DoMatmul() for valid combinations of
+// types.
+template <bool IS_ALLOCATING> struct Matmul {
+  using ResultDescriptor =
+      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+  template <TypeCategory XCAT, int XKIND> struct MM1 {
+    template <TypeCategory YCAT, int YKIND> struct MM2 {
+      void operator()(ResultDescriptor &result, const Descriptor &x,
+          const Descriptor &y, Terminator &terminator) const {
+        if constexpr (constexpr auto resultType{
+                          GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+          if constexpr (common::IsNumericTypeCategory(resultType->first) ||
+              resultType->first == TypeCategory::Logical) {
+            return DoMatmul<IS_ALLOCATING, resultType->first,
+                resultType->second, CppTypeFor<XCAT, XKIND>,
+                CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
+          }
+        }
+        terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
+            static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
+      }
+    };
+    void operator()(ResultDescriptor &result, const Descriptor &x,
+        const Descriptor &y, Terminator &terminator, TypeCategory yCat,
+        int yKind) const {
+      ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
+    }
+  };
+  void operator()(ResultDescriptor &result, const Descriptor &x,
+      const Descriptor &y, const char *sourceFile, int line) const {
+    Terminator terminator{sourceFile, line};
+    auto xCatKind{x.type().GetCategoryAndKind()};
+    auto yCatKind{y.type().GetCategoryAndKind()};
+    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+    ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
+        x, y, terminator, yCatKind->first, yCatKind->second);
+  }
+};
+
+extern "C" {
+void RTNAME(Matmul)(Descriptor &result, const Descriptor &x,
+    const Descriptor &y, const char *sourceFile, int line) {
+  Matmul<true>{}(result, x, y, sourceFile, line);
+}
+void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x,
+    const Descriptor &y, const char *sourceFile, int line) {
+  Matmul<false>{}(result, x, y, sourceFile, line);
+}
+} // extern "C"
+} // namespace Fortran::runtime
diff --git a/flang/runtime/matmul.h b/flang/runtime/matmul.h
new file mode 100644 (file)
index 0000000..8334d66
--- /dev/null
@@ -0,0 +1,29 @@
+//===-- runtime/matmul.h ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// API for the transformational intrinsic function MATMUL.
+
+#ifndef FORTRAN_RUNTIME_MATMUL_H_
+#define FORTRAN_RUNTIME_MATMUL_H_
+#include "entry-names.h"
+namespace Fortran::runtime {
+class Descriptor;
+extern "C" {
+
+// The most general MATMUL.  All type and shape information is taken from the
+// arguments' descriptors, and the result is dynamically allocated.
+void RTNAME(Matmul)(Descriptor &, const Descriptor &, const Descriptor &,
+    const char *sourceFile = nullptr, int line = 0);
+
+// A non-allocating variant; the result's descriptor must be established
+// and have a valid base address.
+void RTNAME(MatmulDirect)(const Descriptor &, const Descriptor &,
+    const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+} // extern "C"
+} // namespace Fortran::runtime
+#endif // FORTRAN_RUNTIME_MATMUL_H_
index cec3084..379fcb8 100644 (file)
@@ -7,9 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 // Defines the API for the reduction transformational intrinsic functions.
-// (Except the complex-valued DOT_PRODUCT and the complex-valued total reduction
-// forms of SUM & PRODUCT; the API for those is in complex-reduction.h so that
-// C's _Complex can be used for their return types.)
 
 #ifndef FORTRAN_RUNTIME_REDUCTION_H_
 #define FORTRAN_RUNTIME_REDUCTION_H_
@@ -36,10 +33,10 @@ extern "C" {
 // results in a caller-supplied descriptor, which is assumed to
 // be large enough.
 //
-// Complex-valued SUM and PRODUCT reductions have their API
-// entry points defined in complex-reduction.h; these are C wrappers
-// around C++ implementations so as to keep usage of C's _Complex
-// types out of C++ code.
+// Complex-valued SUM and PRODUCT reductions and complex-valued
+// DOT_PRODUCT have their API entry points defined in complex-reduction.h;
+// these here are C wrappers around C++ implementations so as to keep
+// usage of C's _Complex types out of C++ code.
 
 // SUM()
 
index cad827a..3d45cf6 100644 (file)
@@ -2,6 +2,7 @@ add_flang_unittest(FlangRuntimeTests
   CharacterTest.cpp
   CrashHandlerFixture.cpp
   Format.cpp
+  Matmul.cpp
   MiscIntrinsic.cpp
   Namelist.cpp
   Numeric.cpp
diff --git a/flang/unittests/RuntimeGTest/Matmul.cpp b/flang/unittests/RuntimeGTest/Matmul.cpp
new file mode 100644 (file)
index 0000000..ae9e7a8
--- /dev/null
@@ -0,0 +1,98 @@
+//===-- flang/unittests/RuntimeGTest/Matmul.cpp---- -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../../runtime/matmul.h"
+#include "gtest/gtest.h"
+#include "tools.h"
+#include "../../runtime/allocatable.h"
+#include "../../runtime/cpp-type.h"
+#include "../../runtime/descriptor.h"
+#include "../../runtime/type-code.h"
+
+using namespace Fortran::runtime;
+using Fortran::common::TypeCategory;
+
+TEST(Matmul, Basic) {
+  // X 0 2 4   Y 6  9   V -1 -2
+  //   1 3 5     7 10
+  //             8 11
+  auto x{MakeArray<TypeCategory::Integer, 4>(
+      std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})};
+  auto y{MakeArray<TypeCategory::Integer, 2>(
+      std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
+  auto v{MakeArray<TypeCategory::Integer, 8>(
+      std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
+  StaticDescriptor<2> statDesc;
+  Descriptor &result{statDesc.descriptor()};
+
+  RTNAME(Matmul)(result, *x, *y, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+
+  std::memset(
+      result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
+  result.GetDimension(0).SetLowerBound(0);
+  result.GetDimension(1).SetLowerBound(2);
+  RTNAME(MatmulDirect)(result, *x, *y, __FILE__, __LINE__);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
+  RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
+  result.Destroy();
+
+  RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+  result.Destroy();
+
+  // X F F T  Y F T
+  //   F T T    F T
+  //            F F
+  auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{2, 3},
+      std::vector<std::uint8_t>{false, false, false, true, true, false})};
+  auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2},
+      std::vector<std::uint16_t>{false, false, false, true, true, false})};
+  RTNAME(Matmul)(result, *xLog, *yLog, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
+  EXPECT_TRUE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+}