#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
-#include <algorithm>
#include <unordered_map>
#include <vector>
std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) {
auto call = Downcast<Call>(input);
- auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
- if (it != src->outs.end()) {
+ auto it = src->nodes.find(call->args[0]);
+ if (it != src->nodes.end()) {
dest->outs.remove(*it);
ins_to_remove.push_back(input);
}
void VisitExpr_(const CallNode* call) final {
if (call->op == compiler_end_op) {
auto region = regions_->GetRegion(GetRef<Call>(call));
+ if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
// set the region target
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
region_targets_[region->GetID()] = compiler_attrs->compiler;
}
}
// get the mergeable regions now all the parents have been visited
- std::vector<AnnotatedRegion> mergeable_regions;
+ std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
auto parent_region = regions_->GetRegion(begin->args[0]);
if (!parent_region.defined()) continue;
- mergeable_regions.push_back(parent_region);
+ mergeable_regions.insert(parent_region);
}
auto& region_restrictions = region_restrictions_[region->GetID()];
for (const auto& parent_region : mergeable_regions) {