Torch-MLIR Reduction + Broadcast Pattern #2876

@Someone117

Torch-MLIR generates a recurring pattern: a reduction, an element-wise operation, then a broadcast. In vector FHE, a rotate-and-reduce sum leaves the sum in every slot of the vector, not just one, so the broadcast is redundant. It would be helpful to remove the broadcast, but linalg.reduce drops the reduced dimensions, which forces the broadcast back in. What would be a good way to represent a reduction that keeps its dimensions?
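To illustrate why the broadcast is redundant, here is a minimal plaintext sketch of rotate-and-reduce, modeling the ciphertext slots as a Python list. The helper names `rotate` and `rotate_and_reduce_sum` are hypothetical, and the sketch assumes a power-of-two slot count and cyclic rotations; it is not how any particular FHE backend implements the kernel.

```python
def rotate(slots, k):
    """Cyclically rotate a list left by k positions (stands in for a ciphertext rotation)."""
    k %= len(slots)
    return slots[k:] + slots[:k]


def rotate_and_reduce_sum(slots):
    """Sum all slots using log2(n) rotate-and-add steps.

    Assumption: len(slots) is a power of two. After the loop, every slot
    holds the full sum, not only slot 0.
    """
    n = len(slots)
    assert n > 0 and n & (n - 1) == 0, "sketch assumes a power-of-two slot count"
    acc = list(slots)
    shift = 1
    while shift < n:
        rotated = rotate(acc, shift)
        # Slot-wise addition of the accumulator and its rotated copy.
        acc = [a + b for a, b in zip(acc, rotated)]
        shift *= 2
    return acc


row = [1, 2, 3, 4, 5, 6, 7, 8]
summed = rotate_and_reduce_sum(row)
# Element-wise division by the row length gives the mean in every slot,
# which is already the "broadcast" shape the layernorm sample needs.
mean = [s / len(row) for s in summed]
```

Because the reduced value already occupies every slot, the element-wise division can run directly on the result without a separate broadcast step.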

Sample from torch.layernorm:

module {
  func.func @main(%arg0: tensor<2x64x768xf32> {secret.secret}) -> tensor<2x64x768xf32> {
    %cst = arith.constant dense<0.000000e+00> : tensor<2x64xf32>
    %cst_0 = arith.constant dense<7.680000e+02> : tensor<2x64xf32>
    %reduced = linalg.reduce ins(%arg0 : tensor<2x64x768xf32>) outs(%cst : tensor<2x64xf32>) dimensions = [2] 
      (%in: f32, %init: f32) {
        %2 = arith.addf %in, %init : f32
        linalg.yield %2 : f32
      }
    %0 = arith.divf %reduced, %cst_0 : tensor<2x64xf32>
    %1 = tensor.empty() : tensor<2x64x768xf32>
    // Additional broadcast
    %broadcasted = linalg.broadcast ins(%0 : tensor<2x64xf32>) outs(%1 : tensor<2x64x768xf32>) dimensions = [2] 
    return %broadcasted : tensor<2x64x768xf32>
  }
}

This could become something like:

module {
  func.func @main(%arg0: tensor<2x64x768xf32> {secret.secret}) -> tensor<2x64x768xf32> {
    %cst = arith.constant dense<0.000000e+00> : tensor<2x64x768xf32>
    %cst_0 = arith.constant dense<7.680000e+02> : tensor<2x64x768xf32>
    // Something similar that keeps dimensions
    %reduced = linalg.reduce ins(%arg0 : tensor<2x64x768xf32>) outs(%cst : tensor<2x64x768xf32>) dimensions = [2] 
      (%in: f32, %init: f32) {
        %2 = arith.addf %in, %init : f32
        linalg.yield %2 : f32
      }
    %0 = arith.divf %reduced, %cst_0 : tensor<2x64x768xf32>
    return %0 : tensor<2x64x768xf32>
  }
}
