self.run_pass('constant_propagation', constant_prop.graph)
self.assertExpected(canonical(constant_prop.graph))
+ def test_constant_prop_none(self):
+ # Verify the constant-propagation pass folds away prim::None values:
+ # typed_none() always returns None, so `a is None and b is None` is
+ # statically true, the branch is folded, and no prim::None node should
+ # remain in the optimized graph.
+ @torch.jit.script
+ def typed_none():
+ # type: () -> Optional[int]
+ return None
+
+ @torch.jit.script
+ def constant_prop():
+ a = typed_none()
+ b = typed_none()
+ if (a is None and b is None):
+ a = 2
+ else:
+ a = 1
+ return a
+
+ self.run_pass('constant_propagation', constant_prop.graph)
+ graph_str = str(constant_prop.graph)
+ # After folding, the graph must contain zero prim::None nodes.
+ self.assertTrue(graph_str.count("prim::None") == 0)
+
def test_trace_records_names(self):
def foo(bar, baz):
baz = bar + 3
auto op = getOperation(n);
Stack stack;
for (auto input : n->inputs()) {
- stack.push_back(*(toIValue(input)));
+ if (input->node()->kind() == prim::None) {
+ stack.emplace_back(IValue());
+ } else {
+ stack.push_back(*(toIValue(input)));
+ }
}
op(stack);
auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
}
void propagateNode(Node* n) {
- auto outputs = runNode(n);
+ std::vector<IValue> outputs;
+ try {
+ outputs = runNode(n);
+ } catch (const c10::Error& e) {
+ // catch AT_ASSERT errors. This op may never actually be reached,
+ // so catch the error here & leave the op in the graph
+ return;
+ }
auto graph = n->owningGraph();
WithInsertPoint guard(n);
for (size_t i = 0; i < outputs.size(); ++i) {
void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) {
bool constant_inputs =
std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
- return v->node()->kind() == prim::Constant;
+ return v->node()->kind() == prim::Constant || v->node()->kind() == prim::None;
});
bool supported_node = !n->kind().is_onnx() &&
skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
[](const Node* node) -> Operation {
return [=](Stack& stack) {
auto val = pop(stack);
- JIT_ASSERTM(!val.isNone(), "Unwrapping null optional");
+ AT_CHECK(!val.isNone(), "Unwrapping null optional");
push(stack, val);
return 0;
};
if full:
mask = target > 1
loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask]
- if reduction is 'none':
+ if reduction == 'none':
ret = loss
- if reduction is 'mean':
+ if reduction == 'mean':
ret = torch.mean(loss)
else:
ret = torch.sum(loss)