@@ -46,6 +46,30 @@ func.func @pass3() -> tensor<3x3xf32> {
4646// CHECK-NEXT: return %2 : tensor<3x3xf32>
4747// CHECK-NEXT: }
4848
49+ func.func @pass4 (%arg0: tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 > {
50+ %0 = stablehlo.transpose %arg0 , dims = [1 , 0 ] : (tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
51+ %1 = stablehlo.dot_general %0 , %arg0 , contracting_dims = [1 ] x [0 ], precision = [DEFAULT , DEFAULT ] : (tensor <2 x2 xf32 >, tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
52+ %2 = stablehlo.transpose %1 , dims = [1 , 0 ] : (tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
53+ return %2 : tensor <2 x2 xf32 >
54+ }
55+
56+ // CHECK: func.func @pass4(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
57+ // CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
58+ // CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
59+ // CHECK-NEXT: return %1 : tensor<2x2xf32>
60+ // CHECK-NEXT: }
61+
62+ func.func @pass5 (%arg0: tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 > {
63+ %1 = stablehlo.dot_general %arg0 , %arg0 , contracting_dims = [0 ] x [0 ], precision = [DEFAULT , DEFAULT ] : (tensor <2 x2 xf32 >, tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
64+ %2 = stablehlo.transpose %1 , dims = [1 , 0 ] : (tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
65+ return %2 : tensor <2 x2 xf32 >
66+ }
67+
68+ // CHECK: func.func @pass5(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
69+ // CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
70+ // CHECK-NEXT: return %0 : tensor<2x2xf32>
71+ // CHECK-NEXT: }
72+
4973func.func @fail1 (%arg0: tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 > {
5074 %0 = stablehlo.transpose %arg0 , dims = [1 , 0 ] : (tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
5175 %1 = stablehlo.subtract %arg0 , %0 : tensor <2 x2 xf32 >
@@ -59,3 +83,17 @@ func.func @fail1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
5983// CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
6084// CHECK-NEXT: return %2 : tensor<2x2xf32>
6185// CHECK-NEXT: }
86+
87+ func.func @fail2 (%arg0: tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 > {
88+ %0 = stablehlo.transpose %arg0 , dims = [1 , 0 ] : (tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
89+ %1 = stablehlo.dot_general %0 , %arg0 , contracting_dims = [1 ] x [1 ], precision = [DEFAULT , DEFAULT ] : (tensor <2 x2 xf32 >, tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
90+ %2 = stablehlo.transpose %1 , dims = [1 , 0 ] : (tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
91+ return %2 : tensor <2 x2 xf32 >
92+ }
93+
94+ // CHECK: func.func @fail2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
95+ // CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
96+ // CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
97+ // CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
98+ // CHECK-NEXT: return %2 : tensor<2x2xf32>
99+ // CHECK-NEXT: }
0 commit comments