+#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
-#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/final_returns.h>
#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/script/schema_matching.h>
using AttributeMap = std::unordered_map<std::string, Const>;
using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
+using TypeAndRange = std::pair<TypePtr, const SourceRange*>;
+
+// Holds mappings from a variable name to a refined type for that variable
+// E.g if x is not None is true than we can refine x from type t? to t.
+struct Refinements {
+ // using ordered map for deterministic graph output
+ std::map<std::string, TypeAndRange> mappings_;
+
+ void setRefinement(const std::string& name, TypeAndRange mapping) {
+ mappings_[name] = std::move(mapping);
+ }
+
+ c10::optional<TypeAndRange> getRefinement(const std::string& name) const {
+ const auto& maybe_mapping = mappings_.find(name);
+ if (maybe_mapping == mappings_.end()) {
+ return c10::nullopt;
+ }
+ return maybe_mapping->second;
+ }
+
+ // return the intersection of the values to type mappings between this
+ // types can be unified
+ void intersectRefinements(const Refinements& other) {
+ Refinements ret;
+ for (const auto& name_mapping : mappings_) {
+ const auto& name = name_mapping.first;
+ const auto& mapping = name_mapping.second;
+ if (auto other_mapping = other.getRefinement(name_mapping.first)) {
+ const auto maybe_unified_type =
+ unifyTypes(mapping.first, other_mapping->first);
+ if (maybe_unified_type) {
+ ret.setRefinement(
+ name, TypeAndRange(*maybe_unified_type, mapping.second));
+ }
+ }
+ }
+ mappings_ = std::move(ret.mappings_);
+ }
+
+ // return the union of the values to type mappings in a and b whose
+ // types can be unified
+ void unionRefinements(const Refinements& other) {
+ Refinements ret;
+ for (const auto& name_mapping : mappings_) {
+ const auto& name = name_mapping.first;
+ const auto& mapping = name_mapping.second;
+ TypePtr t_1 = mapping.first;
+ if (auto other_mapping = other.getRefinement(name_mapping.first)) {
+ TypePtr t_2 = other_mapping->first;
+ c10::optional<TypePtr> maybe_unified_type = c10::nullopt;
+ if (t_1->isSubtypeOf(t_2)) {
+ maybe_unified_type = t_1;
+ } else if (t_2->isSubtypeOf(t_1)) {
+ maybe_unified_type = t_2;
+ }
+ if (maybe_unified_type) {
+ ret.setRefinement(
+ name, TypeAndRange(*maybe_unified_type, mapping.second));
+ }
+ } else {
+ ret.setRefinement(name, mapping);
+ }
+ }
+
+ for (auto& name_mapping : other.mappings_) {
+ if (!getRefinement(name_mapping.first)) {
+ ret.setRefinement(name_mapping.first, name_mapping.second);
+ }
+ }
+
+ mappings_ = std::move(ret.mappings_);
+ }
+};
+
+// When a comparison like x is None is made, we associate type refinements
+// with its true value and its false value. If a boolean that has refinements
+// associated with it is used in a conditional of an if statememt, the true and
+// false refinements are inserted into the corresponding blocks
+
+struct BoolInfo {
+ BoolInfo(Refinements true_refinements, Refinements false_refinements)
+ : true_refinements_(std::move(true_refinements)),
+ false_refinements_(std::move(false_refinements)){};
+ BoolInfo() = default;
+
+ Refinements true_refinements_;
+ Refinements false_refinements_;
+
+ BoolInfo* mergeOr(const BoolInfo& other) {
+ // if the result of an OR is true, either a & b could have been true,
+ // so we take the intersection of a.true_refinements & b.true_refinements.
+ // if the result is false, both a and b had to be false,
+ // so we take their union.
+ true_refinements_.intersectRefinements(other.true_refinements_);
+ false_refinements_.unionRefinements(other.false_refinements_);
+ return this;
+ }
+
+ BoolInfo* mergeAnd(const BoolInfo& other) {
+ // if the result of an AND is true, both a & b had to be true,
+ // so we take the union of a.true_refinements and b.true_refinements.
+ // if the result is false, either a or b could have been false,
+ // so we take their intersection.
+ true_refinements_.unionRefinements(other.true_refinements_);
+ false_refinements_.intersectRefinements(other.false_refinements_);
+ return this;
+ }
+};
+
static Value* asSimple(const SugaredValuePtr& value) {
if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
return sv->getValue();
std::shared_ptr<Environment> emitSingleIfBranch(
Block* b,
- const List<Stmt>& branch) {
+ const List<Stmt>& branch,
+ const Refinements& refinements) {
pushFrame(b);
WithInsertPoint guard(b);
+ insertRefinements(refinements);
emitStatements(branch);
return popFrame();
}
}
Value* emitTernaryIf(const TernaryIf& expr) {
+ const auto& bool_info = findRefinements(expr.cond());
Value* cond_value = emitCond(expr.cond());
- auto true_expr = [&] { return emitExpr(expr.true_expr()); };
- auto false_expr = [&] { return emitExpr(expr.false_expr()); };
+ auto true_expr = [&] {
+ insertRefinements(bool_info.true_refinements_);
+ return emitExpr(expr.true_expr());
+ };
+ auto false_expr = [&] {
+ insertRefinements(bool_info.false_refinements_);
+ return emitExpr(expr.false_expr());
+ };
return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
}
+ // Insert subtyping refinements
+ void insertRefinements(const Refinements& ref) {
+ for (const auto& name_mappings : ref.mappings_) {
+ const std::string& name = name_mappings.first;
+ auto type = name_mappings.second.first;
+ const auto& range = *name_mappings.second.second;
+ Value* v = environment_stack->getVar(name, range);
+ if (type != NoneType::get()) {
+ Value* output = graph->insert(prim::unchecked_unwrap_optional, {v});
+ environment_stack->setVar(range, name, output);
+ }
+ // todo @eellison - revisit inserting Nones when None subtypes Optional
+ }
+ }
+
Value* emitShortCircuitIf(
const SourceRange& loc,
const TreeRef& first_expr,
const TreeRef& second_expr,
bool is_or) {
+ const auto first_bool_info = findRefinements(first_expr);
Value* first_value = emitCond(Expr(first_expr));
- auto get_first_expr = [first_value] { return first_value; };
- auto get_second_expr = [&] { return emitCond(Expr(second_expr)); };
+ const Refinements* first_expr_refinements;
+ const Refinements* second_expr_refinements;
+ // if it's an OR the first expr is emitted in the true branch
+ // and the second expr in the false branch, if it's an AND the opposite
+ if (is_or) {
+ first_expr_refinements = &first_bool_info.true_refinements_;
+ second_expr_refinements = &first_bool_info.false_refinements_;
+ } else {
+ first_expr_refinements = &first_bool_info.false_refinements_;
+ second_expr_refinements = &first_bool_info.true_refinements_;
+ }
+
+ auto get_first_expr = [&] {
+ insertRefinements(*first_expr_refinements);
+ return first_value;
+ };
+
+ auto get_second_expr = [&] {
+ insertRefinements(*second_expr_refinements);
+ return emitCond(Expr(second_expr));
+ };
- // if this is an OR, eval second expression if first expr is False.
+ // if this is an OR, eval second expression if first expr is False
// If this is an AND, eval second expression if first expr is True
if (is_or) {
return emitIfExpr(loc, first_value, get_first_expr, get_second_expr);
void emitIfElseBlocks(Value* cond_value, const If& stmt) {
Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
n->addInput(cond_value);
+ const auto bool_info = findRefinements(stmt.cond());
auto* true_block = n->addBlock();
auto* false_block = n->addBlock();
// Emit both blocks once to get the union of all mutated values
- auto save_true = emitSingleIfBranch(true_block, stmt.trueBranch());
- auto save_false = emitSingleIfBranch(false_block, stmt.falseBranch());
+ auto save_true = emitSingleIfBranch(
+ true_block, stmt.trueBranch(), bool_info.true_refinements_);
+ auto save_false = emitSingleIfBranch(
+ false_block, stmt.falseBranch(), bool_info.false_refinements_);
// In python, every variable assigned in an if statement escapes
// the scope of the if statement (all variables are scoped to the function).
// emit the whole If stmt as usual, finish emitCond first
auto lhs_range = cond_op.lhs().get()->range();
auto rhs_range = cond_op.rhs().get()->range();
+
auto kind = getNodeKind(cond.kind(), cond.get()->trees().size());
Value* cond_value = emitBuiltinCall(
cond.get()->range(),
}
}
+ BoolInfo findRefinements(const TreeRef& tree) {
+ switch (tree->kind()) {
+ case TK_IS:
+ case TK_ISNOT: {
+ const auto& inputs = tree->trees();
+ if (inputs.at(0)->kind() == TK_VAR && inputs.at(1)->kind() == TK_NONE) {
+ const std::string& var_name = Var(inputs[0]).name().name();
+ Refinements true_info, false_info;
+ auto type =
+ environment_stack->getVar(var_name, inputs[0]->range())->type();
+ if (auto opt_type = type->cast<OptionalType>()) {
+ false_info.setRefinement(
+ var_name,
+ TypeAndRange(opt_type->getElementType(), &tree->range()));
+ true_info.setRefinement(
+ var_name, TypeAndRange(NoneType::get(), &tree->range()));
+ }
+ if (tree->kind() == TK_IS) {
+ return BoolInfo(true_info, false_info);
+ } else {
+ return BoolInfo(false_info, true_info);
+ }
+ }
+ } break;
+ case TK_NOT: {
+ const auto& inputs = tree->trees();
+ auto bool_info = findRefinements(inputs[0]);
+ return BoolInfo(
+ bool_info.false_refinements_, bool_info.true_refinements_);
+ }
+ case TK_OR:
+ case TK_AND: {
+ const auto& inputs = tree->trees();
+ auto first = findRefinements(inputs[0]);
+ auto second = findRefinements(inputs[1]);
+ if (tree->kind() == TK_OR) {
+ return *first.mergeOr(second);
+ } else {
+ return *first.mergeAnd(second);
+ }
+ }
+ }
+ return BoolInfo();
+ }
+
Value* emitExpr(const Expr& tree, const TypePtr& type_hint = nullptr) {
return emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method);
}
elem_type = values.at(0)->type();
}
for (auto v : values) {
- if (*v->type() != *elem_type) {
+ if (*v->type() != *elem_type) {
throw ErrorReport(tree)
<< "Lists must contain only a single type, expected: "
<< *elem_type << " but found " << *v->type() << " instead";