
tensor_ext.remap doesn't handle shape expansion properly #2847

@asraa


This is derived from #2844

Reproducer (run with --convert-to-ciphertext-semantics=ciphertext-size=16384):

#layout = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : ct = i0 and (-i1 + slot) mod 1024 = 0 and 0 <= i0 <= 127 and 0 <= i1 <= 767 and 0 <= slot <= 16383 }">
#layout1 = #tensor_ext.layout<"{ [i0, i1, i2] -> [ct, slot] : i0 = 0 and (-768i1 - i2 + slot + 16384*floor((768i1 + i2)/16384)) mod 131072 = 0 and 0 <= i1 <= 127 and 0 <= i2 <= 767 and 0 <= ct <= 5 and -16383 + 768i1 + i2 <= 16384ct <= 768i1 + i2 and 0 <= slot <= 16383 }">
#layout2 = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : (-768i0 - i1 + slot + 16384*floor((768i0 + i1)/16384)) mod 131072 = 0 and 0 <= i0 <= 127 and 0 <= i1 <= 767 and -16383 + 768i0 + i1 <= 16384ct <= 768i0 + i1 and 0 <= slot <= 16383 }">
module {
  func.func @forward(%arg0: !secret.secret<tensor<1x128x768xf32>> {tensor_ext.layout = #layout1}) -> (!secret.secret<tensor<128x768xf32>> {tensor_ext.layout = #layout}) {
    %0 = secret.generic(%arg0: !secret.secret<tensor<1x128x768xf32>> {tensor_ext.layout = #layout1}) {
    ^body(%input0: tensor<1x128x768xf32>):
      %collapsed = tensor.collapse_shape %input0 [[0, 1], [2]] {tensor_ext.layout = #layout2} : tensor<1x128x768xf32> into tensor<128x768xf32>
      %1 = tensor_ext.convert_layout %collapsed {from_layout = #layout2, tensor_ext.layout = #layout, to_layout = #layout} : tensor<128x768xf32>
      secret.yield %1 : tensor<128x768xf32>
    } -> (!secret.secret<tensor<128x768xf32>> {tensor_ext.layout = #layout})
    return %0 : !secret.secret<tensor<128x768xf32>>
  }
}

The layout conversion here takes the row-major layout, which needs 6 ciphertexts for a tensor<128x768xf32> (128 * 768 = 98304 elements at 16384 slots per ciphertext), and converts it to a per-row layout with one ciphertext per row, which needs 128 ciphertexts. When the ConvertConvertLayout pattern runs, it lowers this to a remap operation followed by a shape-changing extract_slice:

#layout = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : (-i1 - 256ct + slot) mod 1024 = 0 and 0 <= i1 <= 16383 and 0 <= ct <= 127 and -767 + 16384i0 + i1 <= 768ct <= 16384i0 + i1 and 0 <= slot <= 16383 }">
#layout1 = #tensor_ext.layout<"{ [i0, i1, i2] -> [ct, slot] : i0 = 0 and (-768i1 - i2 + slot + 16384*floor((768i1 + i2)/16384)) mod 131072 = 0 and 0 <= i1 <= 127 and 0 <= i2 <= 767 and 0 <= ct <= 5 and -16383 + 768i1 + i2 <= 16384ct <= 768i1 + i2 and 0 <= slot <= 16383 }">
#layout2 = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : ct = i0 and (-i1 + slot) mod 1024 = 0 and 0 <= i0 <= 127 and 0 <= i1 <= 767 and 0 <= slot <= 16383 }">
#original_type = #tensor_ext.original_type<originalType = tensor<1x128x768xf32>, layout = #layout1>
#original_type1 = #tensor_ext.original_type<originalType = tensor<128x768xf32>, layout = #layout2>
"builtin.module"() ({
  "func.func"() <{arg_attrs = [{tensor_ext.original_type = #original_type}], function_type = (!secret.secret<tensor<6x16384xf32>>) -> !secret.secret<tensor<128x16384xf32>>, res_attrs = [{tensor_ext.original_type = #original_type1}], sym_name = "forward"}> ({
  ^bb0(%arg0: !secret.secret<tensor<6x16384xf32>>):
    %0 = "secret.generic"(%arg0) ({
    ^bb0(%arg1: tensor<6x16384xf32>):
      %1 = "tensor_ext.remap"(%arg1) <{permutation = #layout}> : (tensor<6x16384xf32>) -> tensor<6x16384xf32>
      %2 = "tensor.extract_slice"(%1) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0>, static_sizes = array<i64: 128, 16384>, static_strides = array<i64: 1, 1>}> : (tensor<6x16384xf32>) -> tensor<128x16384xf32>
      "secret.yield"(%2) : (tensor<128x16384xf32>) -> ()
    }) : (!secret.secret<tensor<6x16384xf32>>) -> !secret.secret<tensor<128x16384xf32>>
    "func.return"(%0) : (!secret.secret<tensor<128x16384xf32>>) -> ()
  }) : () -> ()
}) : () -> ()
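To make the mismatch concrete, here is a small Python sketch (plain Python modeling the index math, not HEIR code; the helpers src_position, dst_positions and in_remap_relation are hypothetical names). It places each element of the tensor<128x768xf32> under the row-major layout (#layout2 in the input IR) and under the per-row layout (#layout in the input IR). It then checks every source/destination pair against the constraints of the remap permutation (#layout in the lowered IR). It assumes, as the mod 1024 constraint suggests, that the per-row layout replicates each 768-element row every 1024 slots.

```python
# Sizes from the reproducer: tensor<128x768xf32>, ciphertext-size=16384,
# and the per-row layout's replication period of 1024 slots.
ROWS, COLS, N, PERIOD = 128, 768, 16384, 1024

def src_position(i, j):
    """(ct, slot) of element (i, j) under the row-major layout (#layout2 in the input IR)."""
    return divmod(COLS * i + j, N)

def dst_positions(i, j):
    """All (ct, slot) copies of element (i, j) under the per-row layout (#layout in the input IR)."""
    return [(i, j + PERIOD * k) for k in range(N // PERIOD)]

def in_remap_relation(i0, i1, ct, slot):
    """Constraints of the remap `permutation` (#layout in the lowered IR).

    (i0, i1) is the source (ct, slot); constants are copied from the attribute."""
    f = N * i0 + i1
    return (0 <= i1 < N and 0 <= ct <= 127 and 0 <= slot < N
            and (-i1 - 256 * ct + slot) % PERIOD == 0
            and f - 767 <= 768 * ct <= f)

# Distinct ciphertext indices used on each side of the conversion.
src_cts = {src_position(i, j)[0] for i in range(ROWS) for j in range(COLS)}
dst_cts = {ct for i in range(ROWS) for ct, _ in dst_positions(i, 0)}

# Every element's move from source to each destination copy should satisfy the relation.
relation_ok = all(in_remap_relation(*src_position(i, j), ct, slot)
                  for i in range(ROWS) for j in range(COLS)
                  for ct, slot in dst_positions(i, j))

print(len(src_cts), len(dst_cts), relation_ok)  # 6 128 True
```

Under these assumptions the remap reads from 6 source ciphertexts but writes destination ciphertext indices 0 through 127, so a result typed tensor<6x16384xf32> cannot hold it.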

But I think the remap should have output type tensor<128x16384xf32>. As lowered, the extract_slice asks for 128 rows out of a 6-row tensor, which is out of bounds. So either remap needs to change types at some point (right now it requires the input and output types to be the same), OR the input should first be inserted into a larger, zero-padded tensor with 128 rows and then remapped.
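A minimal sketch of the second option, under the same size assumptions: zero-pad the 6 row-major ciphertexts up to 128 so the remap's input and output are both 128x16384, then gather each element into its per-row position. This is plain Python standing in for the IR semantics. pack_row_major and remap_padded are hypothetical names, not HEIR APIs.

```python
ROWS, COLS, N, PERIOD = 128, 768, 16384, 1024  # sizes from the reproducer

def pack_row_major(x, n):
    """Pack a 2-D tensor into ceil(size / n) ciphertexts of n slots, row-major (like #layout2)."""
    flat = [v for row in x for v in row]
    num_cts = -(-len(flat) // n)  # ceiling division
    flat += [0.0] * (num_cts * n - len(flat))  # zero-fill the tail of the last ciphertext
    return [flat[c * n:(c + 1) * n] for c in range(num_cts)]

def remap_padded(cts, rows, cols, n, period):
    """Zero-pad cts to `rows` ciphertexts, then gather into a per-row layout (like #layout).

    After padding, input and output are both rows x n, so the remap keeps its type."""
    assert cols <= period and n % period == 0 and len(cts) <= rows
    padded = cts + [[0.0] * n for _ in range(rows - len(cts))]
    out = [[0.0] * n for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            src_ct, src_slot = divmod(cols * i + j, n)  # row-major position of (i, j)
            value = padded[src_ct][src_slot]
            for k in range(n // period):  # replicate row i every `period` slots
                out[i][j + period * k] = value
    return out

x = [[float(COLS * i + j) for j in range(COLS)] for i in range(ROWS)]
src = pack_row_major(x, N)
out = remap_padded(src, ROWS, COLS, N, PERIOD)
print(len(src), len(out), len(out[0]))  # 6 128 16384
```

In this sketch, input ciphertexts 6 through 127 are never read. They exist only so that a same-type remap can produce the 128-ciphertext result, after which the trailing extract_slice would presumably become a no-op.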
