+++ /dev/null
-//===- ReferenceImplGen.cpp - MLIR reference implementation generator -----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// ReferenceImplGen uses the description of operations to generate reference
-// implementations for the ops.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/TableGen/GenInfo.h"
-#include "mlir/TableGen/Operator.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/Signals.h"
-#include "llvm/TableGen/Error.h"
-#include "llvm/TableGen/Record.h"
-#include "llvm/TableGen/TableGenBackend.h"
-
-using namespace llvm;
-using namespace mlir;
-
-using mlir::tblgen::Operator;
-
-static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
- raw_ostream &os) {
- emitSourceFileHeader("Reference implementation file", os);
- const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
-
- os << "void printRefImplementation(StringRef opName, mlir::FuncOp *f) {\n"
- << " using namespace ::mlir::edsc;\n"
- << "if (false) {}";
- for (auto *def : defs) {
- Operator op(def);
- auto referenceImplGenerator = def->getValueInit("referenceImplementation");
- if (!referenceImplGenerator)
- continue;
- os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
- << " edsc::ScopedContext scope(f);\n";
-
- for (auto en : llvm::enumerate(op.getOperands())) {
- os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
- "(void)arg_{0};\n",
- en.value().name, en.index());
- // TODO(jpienaar): this is generally incorrect, not all args are memref
- // in the general case.
- os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
- "(void)view_{0};\n",
- en.value().name, en.index());
- }
- unsigned numOperands = op.getNumOperands();
- unsigned numResults = op.getNumResults();
- for (unsigned idx = 0; idx < numResults; ++idx) {
- os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
- "(void)arg_{0};\n",
- op.getResult(idx).name, numOperands + idx);
- // TODO(jpienaar): this is generally incorrect, not all args are memref
- // in the general case.
- os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
- "(void)view_{0};\n",
- op.getResult(idx).name, numOperands + idx);
- }
-
- // Print the EDSC.
- os << referenceImplGenerator->getAsUnquotedString() << "\n";
- os.indent(2) << "f->print(llvm::outs());\n\n";
- os << "}";
- }
- os << " else {\n";
- os.indent(2) << "f->emitError(\"no reference impl. for \" + opName);\n";
- os.indent(2) << "return;\n";
- os << "}\n";
- os << "}\n";
-}
-
-static mlir::GenRegistration
- genRegister("gen-reference-implementations",
- "Generate reference implementations",
- [](const RecordKeeper &records, raw_ostream &os) {
- emitReferenceImplementations(records, os);
- return false;
- });