_(prim, NoneGenerator) \
_(prim, MMTreeReduce) \
_(prim, MMBatchSide) \
+ _(aten, warn) \
_(aten, floordiv) \
_(aten, __round_to_zero_floordiv)\
_(prim, fork) \
%4 : bool = prim::TensorToBool(%3)
= prim::If(%4)
block0() {
- = prim::Print(%1)
+ = aten::warn(%1, %2)
-> ()
}
block1() {
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
auto node = *it;
- if (node->isNondeterministic() || node->kind() == prim::PythonOp ||
- node->kind() == prim::Print || aliasDb.hasWriters(node) ||
- aliasDb.hasWildcard(node)) {
+ if (node->kind() == prim::PythonOp || node->kind() == prim::Print ||
+ node->kind() == aten::warn || node->isNondeterministic() ||
+ aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
// Do NOT have enough information to do CSE on these nodes.
continue;
}
if (it != memo_.end())
return it->second;
bool has_side_effects = node->kind() == prim::Print ||
+ node->kind() == aten::warn ||
node->kind() == prim::RaiseException ||
node->kind() == prim::PythonOp ||
std::any_of(node->blocks().begin(),
};
}),
Operator(
+ FunctionSchema("aten::warn", {Argument("message", StringType::get()), Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)}, {}),
+ [](const Node* node) {
+ return [](Stack& stack) {
+ drop(stack, 1);
+ AT_WARN(pop(stack).toStringRef());
+ return 0;
+ };
+ }),
+
+ Operator(
"prim::RaiseException(str msg) -> ()",
[](const Node* node) -> Operation {
return [](Stack& stack) {
return torch.reciprocal(b) * a
)SCRIPT");
-auto python_builtins_source = R"SCRIPT(
-def warn(string: str):
- print(string)
-)SCRIPT";
-
-auto python_builtins_source_overloads = R"SCRIPT(
-def warn(string: str, stacklevel: int):
- print(string)
-)SCRIPT";
-
auto _ntuple_ops = CodeTemplate(
R"SCRIPT(
def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
env.s("Scalar", scalar);
loadSource(scalar_operators_source.format(env));
}
- loadSource(python_builtins_source);
- loadSource(python_builtins_source_overloads);
using str_pair = std::pair<std::string, std::string>;
const std::vector<str_pair> name_len = {