graph_str = """
graph(%a : Tensor, %b : Tensor):
- %c : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a, %b)
+ %c : Tensor = aten::mul(%a, %b)
return (%c)
"""
graph = torch._C.parse_ir(graph_str)
# Inject shape info and try compiling again
example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
torch._C._te.annotate_input_shapes(graph, example_inputs)
-
- # TODO: once we have shape propagation as well we should erase type
- # info for %c from the input IR and run shape propagation here - it
- # should be able to reconstruct that info
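+ # Propagate the annotated input shapes through the graph so the output type can be reconstructed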
+ torch._C._jit_pass_propagate_shapes_on_graph(graph)
# Now compilation should pass
kernel = torch._C._te.TensorExprKernel(graph)
np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
- @unittest.skip("Does not work until shape propagation is implemented")
def test_kernel_shape_prop_module(self):
class TestModule(torch.nn.Module):
def forward(self, x, y):
assert exception_thrown
# Remove 'self' argument and try annotating shapes one more time
- graph = torch._C._te.remove_unused_self_argument(graph)
+ torch._C._te.remove_unused_self_argument(graph)
# Inject shape info and try compiling again
torch._C._te.annotate_input_shapes(graph, example_inputs)
-
- # TODO: once we have shape propagation as well we should erase type
- # info for %c from the input IR and run shape propagation here - it
- # should be able to reconstruct that info
+ torch._C._jit_pass_propagate_shapes_on_graph(graph)
# Now compilation should pass
kernel = torch._C._te.TensorExprKernel(graph)
return false;
}
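+// Overwrite the graph input types with concrete tensor types (sizes, strides,
+// dtype, device) derived from the given example tensors; entries that are
+// c10::nullopt are left untouched.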
+void annotateInputShapes(
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
+  TORCH_INTERNAL_ASSERT(
+      graph->inputs().size() == example_inputs.size(),
+      buildErrorMessage("Given inputs do not match the fuser graph inputs."));
+  for (size_t idx = 0; idx < example_inputs.size(); idx++) {
+    if (auto t = example_inputs[idx]) {
+      auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
+      graph->inputs().at(idx)->setType(concrete_tensor_type);
+    }
+  }
+}
+
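+// If the first graph input is an unused module `self` argument (as produced
+// by scripting an nn.Module), erase it from the graph in place and return the
+// same graph; otherwise return the graph unchanged.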
+std::shared_ptr<Graph> removeUnusedSelfArgument(
+ const std::shared_ptr<Graph>& graph) {
+ if (graph->inputs().size() == 0) {
+ return graph;
+ }
+ jit::Value* self_argument = graph->inputs().at(0);
+ if (self_argument->uses().size() != 0 ||
+ !self_argument->type()->is_module()) {
+ return graph;
+ }
+ graph->eraseInput(0);
+ return graph;
+}
+
} // namespace tensorexpr
} // namespace jit
} // namespace torch
bool OptimizeCat(const std::shared_ptr<Graph>& graph);
+TORCH_API void annotateInputShapes(
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<c10::optional<at::Tensor>>& example_inputs);
+TORCH_API std::shared_ptr<Graph> removeUnusedSelfArgument(
+    const std::shared_ptr<Graph>& graph);
+
} // namespace tensorexpr
} // namespace jit
} // namespace torch
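For reference, a minimal sketch of how a caller is expected to use the two helpers declared above (illustrative only, not part of this change; `ir_string` is a placeholder for the textual IR of a fuser subgraph, `parseIR` comes from torch/csrc/jit/ir/irparser.h, and the helpers from torch/csrc/jit/tensorexpr/graph_opt.h):

    auto graph = std::make_shared<Graph>();
    parseIR(ir_string, graph.get());
    // Drop an unused %self input left over from scripting a module, if present.
    graph = removeUnusedSelfArgument(graph);
    // Give the graph inputs concrete types taken from example tensors.
    std::vector<c10::optional<at::Tensor>> example_inputs = {
        at::rand({4, 4}), at::rand({4, 4})};
    annotateInputShapes(graph, example_inputs);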
return true;
}
-void annotateInputShapes(
-    const std::shared_ptr<Graph>& graph,
-    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
-  TORCH_INTERNAL_ASSERT(
-      graph->inputs().size() == example_inputs.size(),
-      buildErrorMessage("Given inputs do not match the fuser graph inputs."));
-  for (size_t idx = 0; idx < example_inputs.size(); idx++) {
-    if (auto t = example_inputs[idx]) {
-      auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
-      graph->inputs().at(idx)->setType(concrete_tensor_type);
-    }
-  }
-}
-
-std::shared_ptr<Graph> removeUnusedSelfArgument(
-    const std::shared_ptr<Graph>& graph) {
-  if (graph->inputs().size() == 0) {
-    return graph;
-  }
-  jit::Value* self_argument = graph->inputs().at(0);
-  if (self_argument->uses().size() != 0 ||
-      !self_argument->type()->is_module()) {
-    return graph;
-  }
-  std::shared_ptr<Graph> res = graph->copy();
-  res->eraseInput(0);
-  return res;
-}
-
std::vector<ExprHandle> valueShape(const ArgValue& v) {
  if (auto b = c10::get_if<tensorexpr::BufHandle>(&v)) {
    return b->dims();
TORCH_API c10::optional<at::Device> pickDeviceType(
    const at::ArrayRef<torch::jit::Value*>& inputs);
-TORCH_API void annotateInputShapes(
-    const std::shared_ptr<Graph>& graph,
-    const std::vector<c10::optional<at::Tensor>>& example_inputs);
-TORCH_API std::shared_ptr<Graph> removeUnusedSelfArgument(
-    const std::shared_ptr<Graph>& graph);
-
} // namespace tensorexpr
} // namespace jit
} // namespace torch
#ifdef USE_CUDA
#include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
#endif
+#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>