*/
/*!
- * Copyright (c) 2017 by Contributors
* \file compute_expr.h
- * \brief Utility integer expression with quick eager simplification.
- * This is weaker than Simplify but can be done Eagerly.
+ * \brief Utility to invoke certain compute operations.
*/
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
* \return The result.
*/
template<typename OP>
-inline Expr ComputeExpr(Expr lhs, Expr rhs) {
+inline Expr Compute(Expr lhs, Expr rhs) {
return OP::make(lhs, rhs);
}
}
template<>
-inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
+inline Expr Compute<ir::Add>(Expr a, Expr b) {
return a + b;
}
template<>
-inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
+inline Expr Compute<ir::Sub>(Expr a, Expr b) {
return a - b;
}
template<>
-inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
+inline Expr Compute<ir::Mul>(Expr a, Expr b) {
return a * b;
}
template<>
-inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
+inline Expr Compute<ir::Div>(Expr a, Expr b) {
return a / b;
}
template<>
-inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
+inline Expr Compute<ir::Mod>(Expr a, Expr b) {
return a % b;
}
template<>
-inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
+inline Expr Compute<ir::Max>(Expr a, Expr b) {
return max(a, b);
}
template<>
-inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
+inline Expr Compute<ir::Min>(Expr a, Expr b) {
return min(a, b);
}
}
Expr res = values[0];
for (size_t i = 1; i < values.size(); ++i) {
- res = ComputeExpr<Op>(res, values[i]);
+ res = Compute<Op>(res, values[i]);
}
return res;
}
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
-#include "compute_expr.h"
namespace tvm {
namespace arith {
Expr AddCombine(Expr a, Expr b) {
if (!a.defined()) return b;
if (!b.defined()) return a;
- return ComputeExpr<Add>(a, b);
+ return a + b;
}
Expr SubCombine(Expr a, Expr b) {
// Check b first in case they are both undefined
if (!b.defined()) return a;
if (!a.defined()) return -b;
- return ComputeExpr<Sub>(a, b);
+ return a - b;
}
Expr MulCombine(Expr a, Expr b) {
if (!a.defined()) return a;
if (!b.defined()) return b;
- return ComputeExpr<Mul>(a, b);
+ return a * b;
}
};
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <vector>
#include <string>
#include "codegen_cuda.h"
-#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
std::function<void(int i, llvm::Value* v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) {
for (int i = 0; i < ramp->type.lanes(); ++i) {
- Expr offset = arith::ComputeExpr<Add>(
- ramp->base,
- arith::ComputeExpr<Mul>(ramp->stride, i));
+ Expr offset = ramp->base + (ramp->stride * i);
f(i, MakeValue(offset));
}
} else {
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <string>
-#include "../../arithmetic/compute_expr.h"
#include "codegen_spirv.h"
+#include "../../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
spirv::Value v = base;
if (i != 0) {
spirv::Value offset = MakeValue(
- arith::ComputeExpr<Mul>(make_const(op->stride.type(), i), op->stride));
+ make_const(op->stride.type(), i) * op->stride);
v = builder_->Add(v, offset);
}
values.push_back(v);
std::function<void(int i, spirv::Value v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) {
for (int i = 0; i < ramp->type.lanes(); ++i) {
- Expr offset = arith::ComputeExpr<Add>(
- ramp->base,
- arith::ComputeExpr<Mul>(ramp->stride, i));
+ Expr offset = ramp->base + ramp->stride * i;
f(i, MakeValue(offset));
}
} else {
extent = make_const(self->DefaultIndexType(), 1);
} else if (self->strides.size() == self->shape.size()) {
int highest_dim = 0;
- extent = arith::ComputeExpr<ir::Mul>(
- self->strides[highest_dim], self->shape[highest_dim]) - offset;
+ extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
} else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2017 by Contributors
* \file arg_binder.cc
* \brief Helper utility to match and bind arguments.
*/
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
+#include <tvm/expr_operator.h>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
- it->second.stride = arith::ComputeReduce<Mul>
- (op->extents, Expr()) * op->type.lanes();
+ it->second.stride = arith::ComputeReduce<Mul>(
+ op->extents, Expr()) * op->type.lanes();
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
<< "It is better to split with multiple of 2";
CHECK(is_zero(old_loop->min));
Expr zero = old_loop->min;
- Expr new_ext = arith::ComputeExpr<Sub>(
- old_loop->extent, make_const(old_loop->loop_var.type(), 1));
+ Expr new_ext =
+ old_loop->extent - make_const(old_loop->loop_var.type(), 1);
Expr factor = make_const(new_ext.type(), split_loop_);
- Expr outer_ext = arith::ComputeExpr<Div>(new_ext, factor);
- Expr tail_base = arith::ComputeExpr<Mul>(outer_ext, factor);
+ Expr outer_ext = new_ext / factor;
+ Expr tail_base = outer_ext * factor;
Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type());
std::unordered_map<const Variable*, Expr> vmap;
std::vector<Stmt> loop_seq;
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2017 by Contributors
* \file inject_virtual_thread.cc
*/
#include <tvm/ir.h>
explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
bool check_write)
: touched_var_(touched), check_write_(check_write) {}
+
void Visit(const NodeRef& n) final {
// early stopping
if (expr_touched_ && !check_write_) return;
visit_touched_var_ = true;
Expr offset = Mutate(op->args[2]);
Expr extent = Mutate(op->args[3]);
- Expr stride = arith::ComputeExpr<Div>(
- it->second, make_const(offset.type(), dtype.lanes()));
+ Expr stride =
+ it->second / make_const(offset.type(), dtype.lanes());
offset = stride * var_ + offset;
return Call::make(
op->type, op->name,
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2018 by Contributors
- *
* Lower warp memory to use local memory
* and shuffle intrinsics.
*
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include "ir_util.h"
#include "arg_binder.h"
-#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
stride = ir::Simplify(stride);
}
rstrides.push_back(stride);
- stride = arith::ComputeExpr<Mul>(stride, shape[dim]);
+ stride = stride * shape[dim];
}
strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
}
int first_dim = 0;
ret = Allocate::make(
e.buffer->data, storage_type,
- {arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
+ {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
make_const(Bool(e.buffer->dtype.lanes()), true), body);
} else {
shape = e.buffer->shape;
if (be.bounds.size() != 0) {
CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
- begins.push_back(
- arith::ComputeExpr<Sub>(tuple->args[2 * i], be.bounds[i]->min));
+ begins.push_back(tuple->args[2 * i] - be.bounds[i]->min);
extents.push_back(tuple->args[2 * i + 1]);
}
} else {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2017 by Contributors
* Loop unrolling as in Halide pipeline.
* \file unroll_loop.cc
*/
}
Stmt Unroll(const For* op) {
- using arith::ComputeExpr;
int value = GetExtent(op);
// For loop must have a constant integer extent
CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
Stmt unrolled;
for (int i = 0; i < value; ++i) {
Var lv(op->loop_var.node_);
- vmap.Set(lv,
- ComputeExpr<Add>(
- op->min, make_const(op->loop_var.type(), i)));
+ vmap.Set(lv, op->min + make_const(op->loop_var.type(), i));
Stmt step = Substitute(body, vmap);
if (unrolled.defined()) {
unrolled = Block::make(unrolled, step);
*/
/*!
- * Copyright (c) 2017 by Contributors
* \file vectorize_loop.cc
*/
// Loop vectorizer as in Halide pipeline.
const Ramp* a_ramp = a.as<Ramp>();
if (a.type().lanes() == 1 && b_ramp) {
return Ramp::make(
- arith::ComputeExpr<T>(a, b_ramp->base),
- arith::ComputeExpr<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
+ arith::Compute<T>(a, b_ramp->base),
+ arith::Compute<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
b_ramp->lanes);
}
if (b.type().lanes() == 1 && a_ramp) {
return Ramp::make(
- arith::ComputeExpr<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
+ arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
*/
/*!
- * Copyright (c) 2017 by Contributors
* \file message_passing.cc
* \brief The message passing domain.
*/
namespace schedule {
using namespace ir;
-using namespace arith;
void Update(std::unordered_map<IterVar, Range>* p_state,
const IterVar& iv,
Range r,
- Analyzer* analyzer) {
+ arith::Analyzer* analyzer) {
auto it = p_state->find(iv);
if (it == p_state->end()) {
(*p_state)[iv] = r;
Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
- state[s->outer] = ComputeExpr<Div>(value, factor);
- state[s->inner] = ComputeExpr<Mod>(value, factor);
+ state[s->outer] = value / factor;
+ state[s->inner] = value % factor;
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = state[s->outer] + outer_min;
CHECK(is_zero(r->min));
Expr parent = state.at(s->parent);
Expr factor = r->extent;
- state[s->outer] = ComputeExpr<Div>(parent, factor);
- state[s->inner] = ComputeExpr<Mod>(parent, factor);
+ state[s->outer] = parent / factor;
+ state[s->inner] = parent % factor;
} else if (const FuseNode* s = rel.as<FuseNode>()) {
if (!state.count(s->inner) && !state.count(s->outer)) {
CHECK(allow_missing);
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
- *parent = EvalSet(
+ *parent = arith::EvalSet(
s->outer->var * factor + s->inner->var + parent_min,
{{s->outer, outer}, {s->inner, inner}});
}
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
- *parent = EvalSet(s->rebased->var + parent_min,
- {{s->rebased, rebased}});
+ *parent = arith::EvalSet(s->rebased->var + parent_min,
+ {{s->rebased, rebased}});
}
void PassUpDomain(const Stage& stage,
const std::unordered_map<IterVar, Expr>& value_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter) {
- Analyzer analyzer;
+ arith::Analyzer analyzer;
std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : stage->leaf_iter_vars) {
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
if (bound_state.at(iv)) {
Range dom = dom_map.at(iv);
- Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
+ Expr value = value_map.at(iv) - dom->min;
Expr vmax = EvalSet(value, iset_dmap).max();
if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
Range dom = dom_map.at(iv);
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
- Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
+ Expr value = value_map.at(iv) - iv->dom->min;
IntSet s = EvalSet(value, iset_dmap);
Expr vmin = s.min();
Expr vmax = s.max();