Torch-MLIR generates a pattern of reduce, then element-wise operation, then broadcast. In vector FHE, a rotate-and-reduce sum leaves the sum in every slot of the vector, not just in one element, so the trailing broadcast is redundant and it would be helpful to remove it. However, linalg.reduce drops the reduced dimensions, which forces the broadcast. What would be a good way to represent a reduction that keeps the dimensions?
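To make the "every slot holds the sum" point concrete, here is a plaintext Python simulation of rotate-and-reduce. It is only a sketch: a Python list stands in for the ciphertext slots, `rotate` is a hypothetical helper modelling a cyclic slot rotation, and the vector length is assumed to be a power of two. After log2(n) rotate-and-add steps, all slots contain the total.

```python
def rotate(slots, k):
    """Cyclic left rotation by k slots (models an FHE ciphertext rotation)."""
    k %= len(slots)
    return slots[k:] + slots[:k]


def rotate_and_reduce_sum(slots):
    """Sum all slots with log2(n) rotate-and-add steps.

    Assumes len(slots) is a power of two. After the loop, every slot
    holds the total, not just slot 0.
    """
    n = len(slots)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    acc = list(slots)
    shift = 1
    while shift < n:
        rotated = rotate(acc, shift)
        acc = [a + b for a, b in zip(acc, rotated)]
        shift *= 2
    return acc


result = rotate_and_reduce_sum([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
print(result)  # every slot is 36.0
```

Because the sum is already replicated across the slots, a separate broadcast afterwards does no useful work in this encoding.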
Sample from torch.layernorm:
module {
  func.func @main(%arg0: tensor<2x64x768xf32> {secret.secret}) -> tensor<2x64x768xf32> {
    %cst = arith.constant dense<0.000000e+00> : tensor<2x64xf32>
    %cst_0 = arith.constant dense<7.680000e+02> : tensor<2x64xf32>
    // Sum over dimension 2 (768 elements); the result drops to tensor<2x64xf32>
    %reduced = linalg.reduce ins(%arg0 : tensor<2x64x768xf32>) outs(%cst : tensor<2x64xf32>) dimensions = [2]
      (%in: f32, %init: f32) {
        %2 = arith.addf %in, %init : f32
        linalg.yield %2 : f32
      }
    // Element-wise divide by 768 gives the mean
    %0 = arith.divf %reduced, %cst_0 : tensor<2x64xf32>
    %1 = tensor.empty() : tensor<2x64x768xf32>
    // Additional broadcast back to 2x64x768
    %broadcasted = linalg.broadcast ins(%0 : tensor<2x64xf32>) outs(%1 : tensor<2x64x768xf32>) dimensions = [2]
    return %broadcasted : tensor<2x64x768xf32>
  }
}
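As a reference for what the sample computes, the following Python sketch mirrors its three ops on a small rank-3 nested list. The function name `reduce_div_broadcast` is hypothetical and the shapes are shrunk for readability; in the IR the innermost axis has 768 elements and the divisor is 768.0.

```python
def reduce_div_broadcast(x, divisor):
    """Mirror the sample IR on a rank-3 nested list of shape [A][B][C].

    1. linalg.reduce over dimension 2 (sum)  -> shape [A][B]
    2. arith.divf by a constant              -> shape [A][B] (the mean)
    3. linalg.broadcast along dimension 2    -> shape [A][B][C] again
    """
    # Step 1: sum the innermost axis, starting from 0 like %cst
    reduced = [[sum(row) for row in plane] for plane in x]
    # Step 2: element-wise divide (768.0 in the sample)
    mean = [[v / divisor for v in row] for row in reduced]
    # Step 3: copy each mean back across the innermost axis
    width = len(x[0][0])
    return [[[v] * width for v in row] for row in mean]


out = reduce_div_broadcast([[[1.0, 2.0, 3.0, 6.0]]], 4.0)
print(out)  # [[[3.0, 3.0, 3.0, 3.0]]]
```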
Could become something like:
module {
  func.func @main(%arg0: tensor<2x64x768xf32> {secret.secret}) -> tensor<2x64x768xf32> {
    %cst = arith.constant dense<0.000000e+00> : tensor<2x64x768xf32>
    %cst_0 = arith.constant dense<7.680000e+02> : tensor<2x64x768xf32>
    // Hypothetical: a reduction over dimension 2 that keeps the dimension and
    // writes the sum to every position along it. This is not valid
    // linalg.reduce as written, since linalg.reduce requires the init to
    // drop the reduced dimensions.
    %reduced = linalg.reduce ins(%arg0 : tensor<2x64x768xf32>) outs(%cst : tensor<2x64x768xf32>) dimensions = [2]
      (%in: f32, %init: f32) {
        %2 = arith.addf %in, %init : f32
        linalg.yield %2 : f32
      }
    %0 = arith.divf %reduced, %cst_0 : tensor<2x64x768xf32>
    return %0 : tensor<2x64x768xf32>
  }
}
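The intended semantics of a dimension-keeping reduction can be written as a loop nest with three parallel loops (d0, d1, d2) and one reduction loop (d3), reading the input at (d0, d1, d3) and writing the output at (d0, d1, d2). As an untested assumption, that same indexing could plausibly be expressed today as a `linalg.generic` with those two indexing maps and iterator types `parallel, parallel, parallel, reduction`, without a new op. The Python sketch below only models the semantics; `keepdims_sum_div` is a hypothetical name, and the naive O(C^2) loop is for clarity, not a claim about cost.

```python
def keepdims_sum_div(x, divisor):
    """Keep-dims reduction: out[i][j][k] = sum over r of x[i][j][r], / divisor.

    Loops i, j, k are parallel; r is the reduction loop. The input is read
    at (i, j, r) and the output written at (i, j, k), so every k along the
    last axis receives the same sum, matching what rotate-and-reduce leaves
    in the slots.
    """
    a, b, c = len(x), len(x[0]), len(x[0][0])
    # Output init is all zeros, like %cst in the proposed IR
    out = [[[0.0] * c for _ in range(b)] for _ in range(a)]
    for i in range(a):
        for j in range(b):
            for k in range(c):
                for r in range(c):
                    out[i][j][k] += x[i][j][r]
                # Fused element-wise divide (arith.divf in the IR)
                out[i][j][k] /= divisor
    return out


out = keepdims_sum_div([[[1.0, 2.0, 3.0, 6.0], [0.0, 0.0, 4.0, 4.0]]], 4.0)
print(out)  # [[[3.0, 3.0, 3.0, 3.0], [2.0, 2.0, 2.0, 2.0]]]
```

The result has the same shape and values as reduce, divide, then broadcast, but with no separate broadcast step.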