return result;
}
-template <typename scalar_t, bool inplace, bool upper>
+template <typename scalar_t, bool upper>
static void apply_triu_tril_single(
- scalar_t* result, scalar_t* self,
+ scalar_t* result, scalar_t* self, bool inplace,
int64_t k, int64_t n, int64_t m,
int64_t res_row_stride, int64_t res_col_stride,
int64_t self_row_stride, int64_t self_col_stride) {
}
}
-template <typename scalar_t, bool inplace, bool upper>
-void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) {
+template <typename scalar_t, bool upper>
+void apply_triu_tril(Tensor& result, const Tensor& self, bool inplace, int64_t k) {
auto n = self.size(-2);
auto m = self.size(-1);
auto self_data = self.data<scalar_t>();
for (b = 0; b < batchsize; b++) {
scalar_t* self_batch = &self_data[b * self_stride];
scalar_t* result_batch = &result_data[b * result_stride];
- apply_triu_tril_single<scalar_t, inplace, upper>(
- result_batch, self_batch, k, n, m,
+ apply_triu_tril_single<scalar_t, upper>(
+ result_batch, self_batch, inplace, k, n, m,
result_row_stride, result_column_stride, self_row_stride, self_column_stride);
}
}
if (self.numel() == 0) {
return self;
}
- if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+ bool inplace = checkTrilTriuBatchContiguous(self);
+ Tensor self_c = inplace ? self : self.contiguous();
+ Tensor result = inplace ? self : at::empty_like(self);
AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
- apply_triu_tril<scalar_t, true, false>(self, self, k);
+ apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
});
+ if (!inplace) self.copy_(result);
return self;
}
}
Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
- apply_triu_tril<scalar_t, false, false>(result, self_c, k);
+ apply_triu_tril<scalar_t, false>(result, self_c, false, k);
});
return result;
}
if (self.numel() == 0) {
return self;
}
- if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+ bool inplace = checkTrilTriuBatchContiguous(self);
+ Tensor self_c = inplace ? self : self.contiguous();
+ Tensor result = inplace ? self : at::empty_like(self);
AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
- apply_triu_tril<scalar_t, true, true>(self, self, k);
+ apply_triu_tril<scalar_t, true>(result, self_c, inplace, k);
});
+ if (!inplace) self.copy_(result);
return self;
}
}
Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
- apply_triu_tril<scalar_t, false, true>(result, self_c, k);
+ apply_triu_tril<scalar_t, true>(result, self_c, false, k);
});
return result;
}
}
Tensor& tril_cuda_(Tensor &self, int64_t k) {
- if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
- return tril_cuda_out(self, self, k);
+ bool inplace = checkTrilTriuBatchContiguous(self);
+ Tensor self_c = inplace ? self : self.contiguous();
+ Tensor result = inplace ? self : at::empty_like(self);
+ tril_cuda_out(result, self_c, k);
+ if (!inplace) self.copy_(result);
+ return self;
}
Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
}
Tensor& triu_cuda_(Tensor &self, int64_t k) {
- if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
- return triu_cuda_out(self, self, k);
+ bool inplace = checkTrilTriuBatchContiguous(self);
+ Tensor self_c = inplace ? self : self.contiguous();
+ Tensor result = inplace ? self : at::empty_like(self);
+ triu_cuda_out(result, self_c, k);
+ if (!inplace) self.copy_(result);
+ return self;
}
Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
assert not x_nc.is_contiguous(), "x is intentionally non-contiguous"
exp_nc = torch.where(exp_mask, torch.tensor(0).type_as(x), x_nc)
self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0)
+ x_nc_is_contiguous = x_nc.is_contiguous()
if upper:
self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0)
else:
self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0)
- # any 3-dimensional tensor should be fine
- if len(shape) <= 3 or s == -2:
- self.assertFalse(x_nc.is_contiguous(),
- "x_nc should remain non-contiguous")
- elif s < -3:
- self.assertTrue(x_nc.is_contiguous(),
- "x_nc should become contiguous")
+ self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous,
+ "contiguity of x_nc should not be changed")
# expanded tensors
expanded_size = (x.size(0),) + x.size()