From 2dc6d205acb367fc0c1f3ed5172da00a438d50c1 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 12 Apr 2019 09:52:11 -0700 Subject: [PATCH] [TableGen] Allocate `Operator` object on heap in `RecordOperatorMap` Iterators for a `llvm::DenseMap` can be invalidated when an insertion occurs. In Pattern's `collectBoundArguments()`, we recursively handle all nested DAG nodes and grow the the `RecordOperatorMap`, while retaining a reference. This can cause the reference to be invalid and the program to behave randomly. Allocate each `Operator` object specifically to solve this issue. Also, `llvm::DenseMap` is a great way to map pointers to pointers, or map other small types to each other. This avoids placing the `Operator` object directly into the map. -- PiperOrigin-RevId: 243281486 --- mlir/include/mlir/TableGen/Pattern.h | 10 ++++++++-- mlir/lib/TableGen/Pattern.cpp | 6 +++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index cacf400..b6381f3 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -39,8 +39,14 @@ class Record; namespace mlir { namespace tblgen { -// Mapping from TableGen Record to Operator wrapper object -using RecordOperatorMap = llvm::DenseMap; +// Mapping from TableGen Record to Operator wrapper object. +// +// We allocate each wrapper object in heap to make sure the pointer to it is +// valid throughout the lifetime of this map. This is important because this map +// is shared among multiple patterns to avoid creating the wrapper object for +// the same op again and again. But this map will continuously grow. +using RecordOperatorMap = + llvm::DenseMap>; class Pattern; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 2200346..794b772 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -111,7 +111,11 @@ llvm::StringRef tblgen::DagNode::getOpName() const { Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { llvm::Record *opDef = cast(node->getOperator())->getDef(); - return mapper->try_emplace(opDef, opDef).first->second; + auto it = mapper->find(opDef); + if (it != mapper->end()) + return *it->second; + return *mapper->try_emplace(opDef, llvm::make_unique(opDef)) + .first->second; } unsigned tblgen::DagNode::getNumOps() const { -- 2.7.4