self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
- def test_explicit_bool_cast(self):
- with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
+ def test_conditional_casting(self):
+ def test_bool_cast_tensor(x):
+ if x:
+ return 1
+ else:
+ return 0
+
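+        # cover both 0-dim and 1-dim tensors across negative, zero, and positive values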
+ for make_one_dim in [True, False]:
+ for inp_val in [0.1, 0.0, -0.0, -0.1, -1, 0, 1]:
+ inp_val = [inp_val] if make_one_dim else inp_val
+ self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))
+
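+        # a tensor with more than one element has no well-defined truth value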
+ self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
+ "bool value of Tensor with more than one value")
+
+ def test_cast_int(x):
+ # type: (int) -> int
+ if x:
+ return 1
+ else:
+ return 0
+ self.checkScript(test_cast_int, (1,))
+ self.checkScript(test_cast_int, (0,))
+ self.checkScript(test_cast_int, (-1,))
+
+ def test_cast_float(x):
+ # type: (float) -> int
+ if x:
+ return 1
+ else:
+ return 0
+ self.checkScript(test_cast_float, (1.,))
+ self.checkScript(test_cast_float, (0.,))
+ self.checkScript(test_cast_float, (-1.,))
+
+ with self.assertRaisesRegex(RuntimeError, "expected a bool, int, float, or Tensor"):
@torch.jit.script
- def test_bool_cast(a):
- if a:
- return a + 2
- return a + 1
+ def test_bad_conditional(x):
+ if (1, 2):
+ return
+ else:
+ return 0
def test_while_nonexistent_value(self):
with self.assertRaisesRegex(RuntimeError, "undefined value x"):
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/script/logging.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
+#include <torch/csrc/jit/script/logging.h>
#include <ATen/ExpandUtils.h>
#include <ATen/WrapDimUtils.h>
[](Stack& stack) {
at::Tensor a;
pop(stack, a);
- push(stack, a.item<int64_t>() != 0);
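+        // is_nonzero() handles any dtype and throws if the tensor has more than one element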
+ push(stack, a.is_nonzero());
return 0;
}),
Operator(
userObj->setSlot(slot, std::move(v));
return 0;
};
- })
- });
-
-RegisterOperators logging_operators({
- Operator("prim::AddStatValue(str key, int val) -> ()", [](Stack& stack) {
- auto val = pop(stack).toInt();
- auto key = pop(stack).toString();
-
- auto schema = parseSchema("prim::AddStatValue(str key, int val) -> ()");
- // TODO: remove this custom tracing code once the custom op bugfix lands
- if (jit::tracer::isTracing()) {
- const auto& graph = tracer::getTracingState()->graph;
- Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
- tracer::recordSourceLocation(node);
- node->addInput(insertConstant(*graph, key));
- tracer::addInputs(node, "val", val);
- graph->insertNode(node);
- }
- torch::jit::logging::getLogger()->addStatValue(*key, val);
- return 0;
- }),
- Operator("prim::TimePoint() -> int", [](Stack& stack) {
- auto schema = parseSchema("prim::TimePoint() -> int");
- Node* node = nullptr;
- // TODO: remove this custom tracing code once the custom op bugfix lands
- if (jit::tracer::isTracing()) {
- const auto& graph = tracer::getTracingState()->graph;
- Node* node = graph->create(prim::TimePoint, /*num_outputs=*/0);
- tracer::recordSourceLocation(node);
- graph->insertNode(node);
- }
- auto output = autograd::profiler::getTime();
- push(stack, output);
- if (jit::tracer::isTracing()) {
- jit::tracer::addOutput(node, output);
- }
- return 0;
- })
-});
+ })});
+RegisterOperators logging_operators(
+ {Operator(
+ "prim::AddStatValue(str key, int val) -> ()",
+ [](Stack& stack) {
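+          // schema is (str key, int val), so val is on top of the stack and popped first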
+ auto val = pop(stack).toInt();
+ auto key = pop(stack).toString();
+
+ auto schema =
+ parseSchema("prim::AddStatValue(str key, int val) -> ()");
+ // TODO: remove this custom tracing code once the custom op bugfix
+ // lands
+ if (jit::tracer::isTracing()) {
+ const auto& graph = tracer::getTracingState()->graph;
+ Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
+ tracer::recordSourceLocation(node);
+ node->addInput(insertConstant(*graph, key));
+ tracer::addInputs(node, "val", val);
+ graph->insertNode(node);
+ }
+ torch::jit::logging::getLogger()->addStatValue(*key, val);
+ return 0;
+ }),
+ Operator("prim::TimePoint() -> int", [](Stack& stack) {
+ auto schema = parseSchema("prim::TimePoint() -> int");
+ Node* node = nullptr;
+ // TODO: remove this custom tracing code once the custom op bugfix lands
+ if (jit::tracer::isTracing()) {
+ const auto& graph = tracer::getTracingState()->graph;
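+          // assign to the outer node declared above so addOutput below records the traced output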
+          node = graph->create(prim::TimePoint, /*num_outputs=*/0);
+ tracer::recordSourceLocation(node);
+ graph->insertNode(node);
+ }
+ auto output = autograd::profiler::getTime();
+ push(stack, output);
+ if (jit::tracer::isTracing()) {
+ jit::tracer::addOutput(node, output);
+ }
+ return 0;
+ })});
// define implementations for primitive number ops
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
return 0;
}
-template<typename T>
+template <typename T>
int hashValue(Stack& stack) {
auto value = pop(stack);
auto hash = std::hash<T>()(value.to<T>());
#undef CREATE_MUTABLE_LIST_OPS
#define CREATE_LIST_OPS(decl_type, c_type) \
- Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
+ Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
Operator( \
"aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \
"[]", \
push(stack, t);
return 0;
}),
-#define CREATE_DICT_OPS(key_type) \
- Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
- Operator( \
- "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)", \
- dictKeys), \
- Operator( \
- "aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues), \
- Operator( \
- "prim::DictIndex(Dict(" key_type ", t) self, " key_type \
- " key) -> t(*)", \
- dictIndex), \
- Operator( \
- "aten::get(Dict(" key_type ", t) self, " key_type \
- " key) -> t(*)?", \
- dictGet), \
- Operator( \
- "aten::get(Dict(" key_type ", t) self, " key_type \
- " key, t default_value) -> t(*)", \
- dictGetDefault), \
- Operator( \
- "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \
- " idx, t v) -> ()", \
+#define CREATE_DICT_OPS(key_type) \
+ Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
+ Operator( \
+ "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)", \
+ dictKeys), \
+ Operator( \
+ "aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues), \
+ Operator( \
+ "prim::DictIndex(Dict(" key_type ", t) self, " key_type \
+ " key) -> t(*)", \
+ dictIndex), \
+ Operator( \
+ "aten::get(Dict(" key_type ", t) self, " key_type " key) -> t(*)?", \
+ dictGet), \
+ Operator( \
+ "aten::get(Dict(" key_type ", t) self, " key_type \
+ " key, t default_value) -> t(*)", \
+ dictGetDefault), \
+ Operator( \
+ "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \
+ " idx, t v) -> ()", \
dictSetItem)
CREATE_DICT_OPS("str"),
CREATE_DICT_OPS("float"),
#undef CREATE_DICT_OPS
-
Operator("aten::hash(str t) -> int", hashValue<std::string>),
Operator("aten::hash(int t) -> int", hashValue<int>),
Operator("aten::hash(float t) -> int", hashValue<double>),
{"abs", std::make_shared<BuiltinFunction>(prim::abs, at::nullopt)},
{"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
{"ord", std::make_shared<BuiltinFunction>(aten::ord, at::nullopt)},
- {"rangelist", std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
{"rangelist",
std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
};
Value* emitCond(const Expr& cond) {
Value* v = emitExpr(cond);
if (!v->type()->isSubtypeOf(BoolType::get())) {
- ErrorReport error(cond);
- error << "expected a boolean expression for condition but found "
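+      // try to implicitly cast the condition to bool via prim::Bool;
+      // required=false makes emitBuiltinCall return nullptr instead of
+      // throwing when no cast to bool exists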
+ Value* cast_v = emitBuiltinCall(
+ cond.get()->range(),
+ *v->owningGraph(),
+ prim::Bool,
+ c10::nullopt,
+ {v},
+ {},
+ /*required*/ false);
+ if (cast_v == nullptr) {
+ ErrorReport error(cond);
+ error
+ << "expected a bool, int, float, or Tensor expression for condition but found "
<< v->type()->str();
- if (v->type()->isSubtypeOf(TensorType::get())) {
- error << ", to use a tensor in a boolean"
- << " expression, explicitly cast it with `bool()`";
+ throw error;
+ } else {
+ v = cast_v;
}
- throw error;
}
return v;
}
{},
true);
} else {
- throw ErrorReport(loc)
- << "Indexing only supported on List, Dict, "
- "Tensor, Tuple, and str but got type '"
- << gatherable->type()->python_str() << "'";
+ throw ErrorReport(loc) << "Indexing only supported on List, Dict, "
+ "Tensor, Tuple, and str but got type '"
+ << gatherable->type()->python_str() << "'";
}
}
};