Skip to content

Commit e3c976f

Browse files
committed
Implement constant checks for symmetry
1 parent 5d0f44c commit e3c976f

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,11 +593,45 @@ SymmetricResultAnalysis initSymmetricResultAnalysis() {
593593
}
594594

595595
bool SymmetricResultAnalysis::constantIntCheck(DenseElementsAttr attr) {
596-
return false; // TODO
596+
auto type = dyn_cast<RankedTensorType>(attr.getType());
597+
if (!type || type.getRank() != 2)
598+
return false;
599+
600+
int64_t n = type.getDimSize(0);
601+
int64_t m = type.getDimSize(1);
602+
if (n != m)
603+
return false;
604+
605+
auto values = attr.getValues<APInt>();
606+
for (int64_t i = 0; i < n; ++i) {
607+
for (int64_t j = 0; j < n; ++j) {
608+
if (values[i * n + j] != values[j * n + i])
609+
return false;
610+
}
611+
}
612+
613+
return true;
597614
}
598615

599616
bool SymmetricResultAnalysis::constantFloatCheck(DenseElementsAttr attr) {
600-
return false; // TODO
617+
auto type = dyn_cast<RankedTensorType>(attr.getType());
618+
if (!type || type.getRank() != 2)
619+
return false;
620+
621+
int64_t n = type.getDimSize(0);
622+
int64_t m = type.getDimSize(1);
623+
if (n != m)
624+
return false;
625+
626+
auto values = attr.getValues<APFloat>();
627+
for (int64_t i = 0; i < n; ++i) {
628+
for (int64_t j = 0; j < n; ++j) {
629+
if (values[i * n + j] != values[j * n + i])
630+
return false;
631+
}
632+
}
633+
634+
return true;
601635
}
602636

603637
SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(

0 commit comments

Comments
 (0)