Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/ATen/native/xpu/BatchLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#include <ATen/native/BatchLinearAlgebra.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/ops/linalg_qr_native.h>
#include <ATen/ops/linalg_qr_cpu_dispatch.h>
#if defined(USE_ONEMKL_XPU)
#include <ATen/native/xpu/mkl/BatchLinearAlgebra.h>
#endif // USE_ONEMKL_XPU
Expand Down Expand Up @@ -64,4 +66,21 @@ void lu_factor_kernel_xpu(

REGISTER_XPU_DISPATCH(lu_factor_stub, &lu_factor_kernel_xpu);

// Structured-kernel body for linalg_qr.out on XPU.
//
// With oneMKL available the factorisation is delegated to
// xpu::linalg_qr_kernel. Otherwise the input is moved to the CPU, the CPU
// implementation of linalg_qr is run on host-side outputs shaped like Q and R,
// and the results are copied back into Q and R.
TORCH_IMPL_FUNC(linalg_qr_xpu_out)(const Tensor& A,
                                   std::string_view mode,
                                   const Tensor& Q,
                                   const Tensor& R) {
#if !defined(USE_ONEMKL_XPU)
  // Host fallback: compute on the CPU, then copy the factors back.
  const auto host_input = A.to(at::kCPU);
  auto host_q = at::empty_like(Q, at::kCPU);
  auto host_r = at::empty_like(R, at::kCPU);
  at::cpu::linalg_qr_out(host_q, host_r, host_input, mode);
  Q.copy_(host_q);
  R.copy_(host_r);
#else
  xpu::linalg_qr_kernel(A, mode, Q, R);
#endif // USE_ONEMKL_XPU
}
Comment on lines +69 to +83
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion is to register geqrf_kernel_xpu/orgqr_kernel_xpu with geqrf_stub/orgqr_stub. That lets us reuse the op-level code in stock PyTorch and reuse these two kernels in the future.



} // namespace at::native
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"linalg_lstsq.out",
"linalg_lu.out",
"linalg_matrix_exp",
"linalg_qr.out",
"linalg_solve_triangular",
"linalg_solve_triangular.out",
"_linalg_svd.U",
Expand Down
104 changes: 104 additions & 0 deletions src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,4 +561,108 @@ void lu_factor_mkl(
pivots.copy_(pivots_);
}

// QR decomposition of a float32 matrix (or batch of matrices) using oneMKL's
// LAPACK geqrf (Householder factorisation) and orgqr (explicit Q).
//
// A:    input of shape (*, rows, cols); it is not modified.
// mode: "reduced"  -> Q is (*, rows, k),    R is (*, k, cols), k = min(rows, cols)
//       "complete" -> Q is (*, rows, rows), R is (*, rows, cols)
//       "r"        -> only R (as in "reduced") is computed; Q becomes a 0x0 tensor.
// Q, R: outputs; they are re-pointed (set_) at the freshly computed results.
//
// oneMKL works on column-major data, so each matrix is processed as its
// transpose in row-major storage and transposed back at the end.
void linalg_qr_kernel(
    const at::Tensor& A,
    std::string_view mode,
    const at::Tensor& Q,
    const at::Tensor& R) {
  // The kernel below is hard-wired to float (data_ptr<float>, <float>
  // scratchpad queries); fail early with a readable message otherwise.
  TORCH_CHECK(
      A.scalar_type() == at::kFloat,
      "linalg_qr_kernel: only float32 input is supported on XPU, got ",
      A.scalar_type());
  TORCH_CHECK(
      A.dim() >= 2,
      "linalg_qr_kernel: input must have at least 2 dimensions, got ",
      A.dim());

  const at::Tensor a_contig = A.contiguous();

  // Work on a private copy: geqrf overwrites its input in place. The
  // transpose + contiguous() lays every matrix out column-major.
  at::Tensor result_r = at::clone(a_contig);
  result_r = result_r.transpose(-2, -1).contiguous();

  const auto options = at::TensorOptions().dtype(at::kFloat).device(kXPU);

  const int64_t ndim = a_contig.dim();
  const int64_t rows = a_contig.size(-2);
  const int64_t cols = a_contig.size(-1);
  const int64_t matrix_numel = rows * cols;
  // Guard the division: empty matrices (rows == 0 or cols == 0) are legal
  // inputs and must not trigger an integer division by zero.
  const int64_t batch_count =
      matrix_numel == 0 ? 0 : a_contig.numel() / matrix_numel;

  const bool want_q = mode != "r";
  const int64_t tau_len = rows < cols ? rows : cols; // number of reflectors
  // Columns of Q: all `rows` of them in "complete" mode, min(rows, cols)
  // otherwise.
  const int64_t q_cols = (mode == "complete") ? rows : tau_len;

  // Q is stored transposed (column-major), hence shape (*, q_cols, rows).
  std::vector<int64_t> q_shape;
  if (want_q) {
    const auto dims = A.sizes();
    q_shape.assign(dims.begin(), dims.end());
    q_shape[ndim - 1] = rows;
    q_shape[ndim - 2] = q_cols;
  } else {
    q_shape = {0, 0};
  }
  at::Tensor result_q = at::empty(q_shape, options);

  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();

  if (batch_count > 0) {
    const int64_t geqrf_size =
        oneapi::mkl::lapack::geqrf_scratchpad_size<float>(
            queue, rows, cols, rows);
    // Query orgqr with exactly the arguments used in the call below so the
    // scratchpad is large enough (e.g. "complete" mode with rows > cols).
    const int64_t orgqr_size = want_q
        ? oneapi::mkl::lapack::orgqr_scratchpad_size<float>(
              queue, rows, q_cols, tau_len, rows)
        : 0;
    const int64_t scratch_size =
        orgqr_size > geqrf_size ? orgqr_size : geqrf_size;

    // Tensor-backed temporaries are released automatically, including when
    // oneMKL throws.
    at::Tensor scratch = at::empty({scratch_size}, options);
    at::Tensor tau = at::empty({tau_len}, options);

    float* scratch_buf = scratch.data_ptr<float>();
    float* tau_buf = tau.data_ptr<float>();
    float* r_buf = result_r.data_ptr<float>();
    float* q_buf = want_q ? result_q.data_ptr<float>() : nullptr;

    // Leading columns of the factored matrix that seed Q before orgqr.
    const int64_t copy_columns = q_cols > cols ? cols : q_cols;

    for (int64_t batch_item = 0; batch_item < batch_count; ++batch_item) {
      oneapi::mkl::lapack::geqrf(
          queue, rows, cols, r_buf, rows, tau_buf, scratch_buf, scratch_size)
          .wait();

      if (want_q) {
        // Copy the Householder vectors (relevant part of the factored
        // matrix) into Q, then expand them into an explicit orthogonal Q.
        queue.memcpy(q_buf, r_buf, rows * copy_columns * sizeof(float))
            .wait();

        oneapi::mkl::lapack::orgqr(
            queue,
            rows,
            q_cols,
            tau_len,
            q_buf,
            rows,
            tau_buf,
            scratch_buf,
            scratch_size)
            .wait();

        q_buf += rows * q_cols;
      }

      r_buf += matrix_numel;
    } // batch
  }

  if (want_q && matrix_numel == 0) {
    // No factorisation ran for empty matrices. Q is then either empty
    // ("reduced") or the identity ("complete" with rows > 0); both are
    // produced by zero-filling and setting the diagonal.
    result_q.zero_();
    result_q.diagonal(0, -2, -1).fill_(1);
  }

  if ((mode == "reduced" || mode == "r") && rows > cols) {
    // Reduced R is (cols x cols): drop the rows below the square part.
    result_r = result_r
                   .index(
                       {"...",
                        at::indexing::Slice(0, rows),
                        at::indexing::Slice(0, cols)})
                   .contiguous();
  }

  // Transpose back to row-major views; triu_ clears the Householder vectors
  // that geqrf leaves below R's diagonal.
  Q.set_(result_q.transpose(-2, -1));
  R.set_(result_r.transpose(-2, -1).triu_());
  queue.wait();
}

} // namespace at::native::xpu
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/mkl/BatchLinearAlgebra.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,10 @@ TORCH_XPU_API void lu_factor_mkl(
const Tensor& info,
bool pivot);

// Computes the QR decomposition of `A` with oneMKL (geqrf + orgqr) and stores
// the factors in `Q` and `R`.
// `mode` is compared against "reduced", "complete" and "r"; for "r" no Q is
// generated. The implementation works on float data only.
TORCH_XPU_API void linalg_qr_kernel(
const at::Tensor& A,
std::string_view mode,
const at::Tensor& Q,
const at::Tensor& R);

} // namespace at::native::xpu
53 changes: 53 additions & 0 deletions test/xpu/test_linalg_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,57 @@ def __tunableop_ctx(self):
pass


@parametrize("batch", [1, 3])
@parametrize("m", [0, 1, 12])
@parametrize("n", [0, 1, 17])
@dtypes(torch.float32)
def qr_mode_r(self, device, dtype, batch, m, n):
    """Check torch.linalg.qr(mode="r") on `device` against the CPU result.

    Covers batched and unbatched inputs, including empty matrices
    (m == 0 or n == 0), and verifies that R is upper triangular.
    """
    if batch > 1:
        A_cpu = torch.randn(batch, m, n, dtype=dtype, device="cpu")
    else:
        A_cpu = torch.randn(m, n, dtype=dtype, device="cpu")
    A_xpu = A_cpu.to(device)

    R_cpu = torch.linalg.qr(A_cpu, mode="r").R
    R_xpu = torch.linalg.qr(A_xpu, mode="r").R
    self.assertEqual(R_xpu, R_cpu, atol=1e-5, rtol=1e-5)

    # Verify that R is upper triangular. Compare element-wise rather than
    # via a sum, which could hide non-zero entries that cancel each other.
    lower_triangle = torch.tril(R_xpu, diagonal=-1)
    self.assertEqual(
        lower_triangle, torch.zeros_like(lower_triangle), atol=0.0, rtol=0.0
    )


@parametrize("batch", [1, 3])
@parametrize("m", [0, 1, 12])
@parametrize("n", [0, 1, 17])
@parametrize("mode", ["reduced", "complete"])
@dtypes(torch.float32)
def qr_modes_reduced_complete(self, device, dtype, batch, m, n, mode):
    """Check torch.linalg.qr in "reduced"/"complete" mode on `device`.

    Q and R are compared against the CPU result; additionally Q must have
    orthonormal columns and R must be upper triangular. Empty matrices
    (m == 0 or n == 0) and batched inputs are included.
    """
    if batch > 1:
        A_cpu = torch.randn(batch, m, n, dtype=dtype, device="cpu")
    else:
        A_cpu = torch.randn(m, n, dtype=dtype, device="cpu")
    A_xpu = A_cpu.to(device)

    Q_cpu, R_cpu = torch.linalg.qr(A_cpu, mode=mode)
    Q_xpu, R_xpu = torch.linalg.qr(A_xpu, mode=mode)

    self.assertEqual(Q_xpu, Q_cpu, atol=1e-5, rtol=1e-5)
    self.assertEqual(R_xpu, R_cpu, atol=1e-5, rtol=1e-5)

    # Verify Q is orthogonal: Q^T @ Q should be identity
    QTQ_xpu = torch.matmul(Q_xpu.mT, Q_xpu)
    k = min(m, n) if mode == "reduced" else m
    identity = torch.eye(k, dtype=dtype, device=device)
    if batch > 1:
        identity = identity.expand(batch, k, k)
    self.assertEqual(QTQ_xpu, identity, atol=1e-5, rtol=1e-5)

    # Verify that R is upper triangular. Compare element-wise rather than
    # via a sum, which could hide non-zero entries that cancel each other.
    lower_triangle = torch.tril(R_xpu, diagonal=-1)
    self.assertEqual(
        lower_triangle, torch.zeros_like(lower_triangle), atol=0.0, rtol=0.0
    )


with XPUPatchForImport(False):
from test_linalg import TestLinalg

Expand All @@ -493,6 +544,8 @@ def __tunableop_ctx(self):
TestLinalg.test_ck_blas_library = ck_blas_library
TestLinalg.test_addmm_relu_tunableop_rocm = addmm_relu_tunableop_rocm
TestLinalg._tunableop_ctx = __tunableop_ctx
TestLinalg.test_qr_mode_r = qr_mode_r
TestLinalg.test_qr_modes_reduced_complete = qr_modes_reduced_complete

TestLinalg._default_dtype_check_enabled = True
instantiate_device_type_tests(TestLinalg, globals(), only_for=("xpu"), allow_xpu=True)
Expand Down
11 changes: 11 additions & 0 deletions yaml/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9443,6 +9443,17 @@
- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
python_module: linalg

# Functional variant of QR; structured_delegate points it at linalg_qr.out.
- func: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
python_module: linalg
variants: function
structured_delegate: linalg_qr.out

# Structured out variant; on XPU it dispatches to linalg_qr_xpu_out.
- func: linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
python_module: linalg
structured: True
dispatch:
XPU: linalg_qr_xpu_out

Comment on lines +9446 to +9456
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
python_module: linalg
structured_delegate: linalg_inv_ex.inverse
Expand Down
Loading