
Commit 8c2e29e

snonkavik-pal and Avik Pal authored
add constant int/float check for symm (#1621)
* constant check
* fmt
* add splat and broadcast
* fix: formatting
* refactor: cleanup into a single function
* add test

Co-authored-by: Avik Pal <[email protected]>
1 parent aded9a2 · commit 8c2e29e

2 files changed: +74 -2 lines changed


src/enzyme_ad/jax/Utils.cpp

Lines changed: 54 additions & 2 deletions
@@ -592,12 +592,55 @@ SymmetricResultAnalysis initSymmetricResultAnalysis() {
   return SymmetricResultAnalysis();
 }
 
+bool checkNotEqual(APInt a, APInt b) { return a != b; }
+
+bool checkNotEqual(APFloat a, APFloat b) {
+  return a.compare(b) != llvm::APFloat::cmpEqual;
+}
+
+template <typename Ty> bool checkConstantSymmetric(DenseElementsAttr attr) {
+  if (!attr)
+    return false;
+
+  auto type = dyn_cast<RankedTensorType>(attr.getType());
+  if (!type)
+    return false;
+
+  if (type.getRank() == 0)
+    return true;
+  if (type.getRank() != 2)
+    return false;
+
+  auto shape = type.getShape();
+  int64_t rows = shape[0];
+  int64_t cols = shape[1];
+
+  if (rows != cols)
+    return false;
+  if (attr.isSplat())
+    return true;
+
+  auto values = attr.getValues<Ty>();
+  auto it = values.begin();
+
+  for (int64_t i = 0; i < rows; i++) {
+    for (int64_t j = i + 1; j < cols; j++) {
+      auto a = *(it + i * cols + j);
+      auto b = *(it + j * cols + i);
+      if (checkNotEqual(a, b))
+        return false;
+    }
+  }
+
+  return true;
+}
+
 bool SymmetricResultAnalysis::constantIntCheck(DenseElementsAttr attr) {
-  return false; // TODO
+  return checkConstantSymmetric<APInt>(attr);
 }
 
 bool SymmetricResultAnalysis::constantFloatCheck(DenseElementsAttr attr) {
-  return false; // TODO
+  return checkConstantSymmetric<APFloat>(attr);
 }
 
 SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
@@ -633,6 +676,15 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
 
   // TODO: check for dot_general as well
 
+  if (auto broadcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(op)) {
+    auto operand = broadcastOp.getOperand();
+    auto operandTy = cast<RankedTensorType>(operand.getType());
+    auto dims = broadcastOp.getBroadcastDimensions();
+    if (operandTy.getRank() == 0 && dims.empty()) {
+      return State::GUARANTEED;
+    }
+  }
+
   // commutative operation with A and A^T will always be symmetric
   // op(A, A^T) will also always be symmetric
   if (stablehlo::hasTraitElementwise(op) &&
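
The core of the patch is the upper-triangle scan in checkConstantSymmetric: rank-0 constants and splats are accepted immediately, non-square 2-D constants are rejected, and otherwise every element above the diagonal is compared against its mirror below the diagonal. Below is a minimal, MLIR-free sketch of the same check on a plain row-major buffer; isSymmetric and its double-valued buffer are illustrative assumptions, not part of the patch (the real code walks APInt/APFloat values out of a DenseElementsAttr).

// Standalone analogue of the upper-triangle symmetry check (illustrative only).
#include <cstdint>
#include <iostream>
#include <vector>

// A row-major rows x cols buffer is symmetric iff it is square and every
// element above the diagonal equals its mirror below the diagonal.
bool isSymmetric(const std::vector<double> &values, int64_t rows, int64_t cols) {
  if (rows != cols)
    return false;
  for (int64_t i = 0; i < rows; i++) {
    for (int64_t j = i + 1; j < cols; j++) {
      // Compare the (i, j) entry with its transpose partner (j, i).
      if (values[i * cols + j] != values[j * cols + i])
        return false;
    }
  }
  return true;
}

int main() {
  // Mirrored off-diagonal entries match, so this one is symmetric.
  std::vector<double> sym = {1, 2, 3,
                             2, 5, 6,
                             3, 6, 9};
  // Entry (0, 1) = 2 differs from entry (1, 0) = 4, so this one is not.
  std::vector<double> asym = {1, 2, 3,
                              4, 5, 6,
                              7, 8, 9};
  std::cout << isSymmetric(sym, 3, 3) << "\n";  // prints 1
  std::cout << isSymmetric(asym, 3, 3) << "\n"; // prints 0
}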

test/lit_tests/structured_tensors/transpose_symmetric.mlir

Lines changed: 20 additions & 0 deletions
@@ -26,6 +26,26 @@ func.func @pass2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
 // CHECK-NEXT: return %1 : tensor<2x2xf32>
 // CHECK-NEXT: }
 
+func.func @pass3() -> tensor<3x3xf32> {
+  %cst = stablehlo.constant dense<3.000000e+00> : tensor<f32>
+  %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
+  %0 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
+  %1 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
+  %2 = stablehlo.transpose %0, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
+  %3 = stablehlo.transpose %1, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
+  %4 = stablehlo.add %2, %3 : tensor<3x3xf32>
+  return %4 : tensor<3x3xf32>
+}
+
+// CHECK: func.func @pass3() -> tensor<3x3xf32> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<3.000000e+00> : tensor<f32>
+// CHECK-NEXT: %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK-NEXT: %0 = stablehlo.broadcast_in_dim %cst_0, dims = [] {enzymexla.guaranteed_symmetric = true} : (tensor<f32>) -> tensor<3x3xf32>
+// CHECK-NEXT: %1 = stablehlo.broadcast_in_dim %cst, dims = [] {enzymexla.guaranteed_symmetric = true} : (tensor<f32>) -> tensor<3x3xf32>
+// CHECK-NEXT: %2 = stablehlo.add %0, %1 : tensor<3x3xf32>
+// CHECK-NEXT: return %2 : tensor<3x3xf32>
+// CHECK-NEXT: }
+
 func.func @fail1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
   %1 = stablehlo.subtract %arg0, %0 : tensor<2x2xf32>
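
As a quick sanity check on why @pass3 should fold (my reading of the test, not text from the patch): a broadcast_in_dim of a rank-0 tensor with dims = [] yields a matrix B with B[i, j] = c for all i and j, so B[i, j] = B[j, i] and B is symmetric by construction. The new BroadcastInDimOp case therefore marks both %0 and %1 as GUARANTEED, their transposes can be dropped, and add(%2, %3) collapses to add(%0, %1), which is exactly what the CHECK lines for @pass3 assert, along with the enzymexla.guaranteed_symmetric = true attribute on each broadcast.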
