}
  // The JIT does not support complex and qint types yet.
- TORCH_INTERNAL_ASSERT((typeConstraints & (kQintTypes | kComplexTypes)) == 0);
+ TORCH_INTERNAL_ASSERT(
+ (typeConstraints & (kQintTypes | kComplexTypes)) == 0,
+ buildErrorMessage(
+ "Qint and Complex types are not supported in the fuser."));
return false;
}
namespace jit {
namespace tensorexpr {
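+// Builds a uniform fuser error message by appending a generic suffix
+// (eventually, instructions for disabling the fuser; see TODO below) to
+// the detail string `s`.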
+std::string buildErrorMessage(const std::string& s) {
+ // TODO: Update this generic error message to include details regarding
+ // turning off the fuser.
+  static const std::string generic_error_message = "";
+  if (generic_error_message.empty()) {
+    // Avoid appending a trailing space while the generic message is empty.
+    return s;
+  }
+  return s + " " + generic_error_message;
+}
+
static int te_cuda_pointwise_loop_levels = -1;
static int te_cuda_pointwise_block_count = -1;
static int te_cuda_pointwise_block_size = -1;
for (auto const& input : node->inputs()) {
if (auto tt = input->type()->cast<TensorType>()) {
if (auto inputDevice = tt->device()) {
- TORCH_INTERNAL_ASSERT(!device || *device == *inputDevice);
+ TORCH_INTERNAL_ASSERT(
+ !device || *device == *inputDevice,
+ buildErrorMessage(
+ "Different devices specified for inputs to the fuser."));
device = inputDevice;
}
}
}
}
- TORCH_INTERNAL_ASSERT(device);
+ TORCH_INTERNAL_ASSERT(
+ device,
+ buildErrorMessage("Could not find device in fuser graph inputs."));
return device;
}
void annotateInputShapes(
const std::shared_ptr<Graph>& graph,
const std::vector<c10::optional<at::Tensor>>& example_inputs) {
- TORCH_INTERNAL_ASSERT(graph->inputs().size() == example_inputs.size());
+ TORCH_INTERNAL_ASSERT(
+ graph->inputs().size() == example_inputs.size(),
+ buildErrorMessage("Given inputs do not match the fuser graph inputs."));
for (size_t idx = 0; idx < example_inputs.size(); idx++) {
if (auto t = example_inputs[idx]) {
auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
throw std::runtime_error("Empty input list is passed to aten::cat");
}
- TORCH_INTERNAL_ASSERT(n->input(1)->node()->kind() == prim::Constant);
+ TORCH_INTERNAL_ASSERT(
+ n->input(1)->node()->kind() == prim::Constant,
+ buildErrorMessage(
+ "aten::cat op's dim input is not constant in fuser."));
int64_t dim = n->input(1)->node()->i(attr::value);
auto shape = sizesForValue(inputs[0]);
auto norm_dim = normalizeAndCheckIndex(dim, shape.size());
blockSize = default_uint8_blocksize;
}
std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
- TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");
+ TORCH_INTERNAL_ASSERT(
+ !loops.empty(),
+ buildErrorMessage(
+ "No loops found for the buffer " + buf->name_hint() +
+ " in the fuser."));
ForPtr flattened = nullptr;
LoopNest::flatten(loops, &flattened);
-  assert(flattened);
+  TORCH_INTERNAL_ASSERT(
+      flattened, buildErrorMessage("Loop flattening failed in the fuser."));