Skip to content

Commit 357e961

Browse files
committed
Add test of new symmetry detection
1 parent 3d33eba commit 357e961

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=transpose_symmetric_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
2+
3+
func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
4+
%alpha = stablehlo.constant dense<2.0> : tensor<f32>
5+
%beta = stablehlo.constant dense<3.0> : tensor<f32>
6+
%c = stablehlo.constant dense<[[4.0, 3.0], [3.0, 4.0]]> : tensor<2x2xf32>
7+
%0 = enzymexla.lapack.symm %c, %arg0, %arg1, %alpha, %beta {side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
8+
%1 = stablehlo.subtract %0, %c : tensor<2x2xf32>
9+
%2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
10+
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
11+
return %3 : tensor<2x2xf32>
12+
}
13+
14+
// CHECK: func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
15+
// CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32>
16+
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<f32>
17+
// CHECK-NEXT: %cst_1 = stablehlo.constant {enzymexla.guaranteed_symmetric = true} dense<{{\[\[}}4.000000e+00, 3.000000e+00], [3.000000e+00, 4.000000e+00{{\]\]}}> : tensor<2x2xf32>
18+
// CHECK-NEXT: %0 = enzymexla.lapack.symm %cst_1, %arg0, %arg1, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
19+
// CHECK-NEXT: %1 = stablehlo.subtract %0, %cst_1 {enzymexla.guaranteed_symmetric = true} : tensor<2x2xf32>
20+
// CHECK-NEXT: %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
21+
// CHECK-NEXT: return %2 : tensor<2x2xf32>
22+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)