ImplementShiftNetwork: Reduce the number of conflicts when there are multiple sources that map to a target slot #2350

@asraa

Description

Consider the following IR:

#layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = i1 and (2i2 - i3 + slot) mod 4 = 0 and 0 <= i1 <= 1 and 0 <= i2 <= 1 and 0 <= i3 <= 1 and 0 <= slot <= 1023 }">
#layout1 = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : exists (e0, e1, e2: i0 = 0 and ct = 0 and 16e2 = -i1 + slot + 16e0 and 0 <= i1 <= 1023 and 0 <= slot <= 1023 and 0 <= e1 <= 3 and -3 + i1 - 16e0 <= 4e1 <= i1 - 16e0) }">
#original_type = #tensor_ext.original_type<originalType = tensor<1x2x2x2xf32>, layout = #layout>
module {
  func.func @main(%arg0: !secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #tensor_ext.original_type<originalType = tensor<1x1x4x4xf32>, layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-4i2 - i3 + slot) mod 16 = 0 and 0 <= i2 <= 3 and 0 <= i3 <= 3 and 0 <= slot <= 1023 }">>}, %arg1: tensor<2x1x3x3xf32>) -> (!secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #original_type}) {
    %0 = secret.generic(%arg0: !secret.secret<tensor<1x1024xf32>>) {
    ^body(%input0: tensor<1x1024xf32>):
      %1 = tensor_ext.remap %input0 {permutation = #layout1} : tensor<1x1024xf32>
      secret.yield %1 : tensor<1x1024xf32>
    } -> !secret.secret<tensor<1x1024xf32>>
    return %0 : !secret.secret<tensor<1x1024xf32>>
  }
}

In this case we are remapping the input according to a permutation generated from a slice extraction from tensor<1x1x4x4> to tensor<4x4>. Since the ciphertext size is 1024, the original 4x4 tensor was repeated 1024/16 = 64 times in the source. The slice extraction preserved the layout, so the result had the same layout (and was repeated the same way in the ciphertext). When the remap was lowered to a shift network, the Mapping generated a number of conflicts: because the original 16-element tensor was repeated, the mapping could choose any of several identical sources to map to a given target. For example, if the mapping de-duped the sources mapping to a target by taking the first one that appears in the point pair collector, it would generate the following IR:

#layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = i1 and (2i2 - i3 + slot) mod 4 = 0 and 0 <= i1 <= 1 and 0 <= i2 <= 1 and 0 <= i3 <= 1 and 0 <= slot <= 1023 }">
#original_type = #tensor_ext.original_type<originalType = tensor<1x2x2x2xf32>, layout = #layout>
module {
  func.func @main(%arg0: !secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #tensor_ext.original_type<originalType = tensor<1x1x4x4xf32>, layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-4i2 - i3 + slot) mod 16 = 0 and 0 <= i2 <= 3 and 0 <= i3 <= 3 and 0 <= slot <= 1023 }">>}, %arg1: tensor<2x1x3x3xf32>) -> (!secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #original_type}) {
    %c16 = arith.constant 16 : index
    %cst = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c32 = arith.constant 32 : index
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c64 = arith.constant 64 : index
    %cst_1 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c128 = arith.constant 128 : index
    %cst_2 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c256 = arith.constant 256 : index
    %cst_3 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %c512 = arith.constant 512 : index
    %cst_4 = arith.constant dense_resource<__elided__> : tensor<1024xf32>
    %0 = tensor.empty() : tensor<1x1024xf32>
    %1 = secret.generic(%arg0: !secret.secret<tensor<1x1024xf32>>) {
    ^body(%input0: tensor<1x1024xf32>):
      %extracted_slice = tensor.extract_slice %input0[0, 0] [1, 1024] [1, 1] : tensor<1x1024xf32> to tensor<1024xf32>
      %2 = arith.mulf %extracted_slice, %cst_4 : tensor<1024xf32>
      %3 = tensor_ext.rotate %2, %c16 : tensor<1024xf32>, index
      %4 = arith.addf %2, %3 : tensor<1024xf32>
      %5 = arith.mulf %4, %cst : tensor<1024xf32>
      %6 = tensor_ext.rotate %5, %c32 : tensor<1024xf32>, index
      %7 = arith.addf %5, %6 : tensor<1024xf32>
      %8 = arith.mulf %7, %cst_0 : tensor<1024xf32>
      %9 = tensor_ext.rotate %8, %c64 : tensor<1024xf32>, index
      %10 = arith.addf %8, %9 : tensor<1024xf32>
      %11 = arith.mulf %10, %cst_1 : tensor<1024xf32>
      %12 = tensor_ext.rotate %11, %c128 : tensor<1024xf32>, index
      %13 = arith.addf %11, %12 : tensor<1024xf32>
      %14 = arith.mulf %13, %cst_2 : tensor<1024xf32>
      %15 = tensor_ext.rotate %14, %c256 : tensor<1024xf32>, index
      %16 = arith.addf %14, %15 : tensor<1024xf32>
      %17 = arith.mulf %16, %cst_3 : tensor<1024xf32>
      %18 = tensor_ext.rotate %17, %c512 : tensor<1024xf32>, index
      %19 = arith.addf %17, %18 : tensor<1024xf32>
      %inserted_slice = tensor.insert_slice %19 into %0[0, 0] [1, 1024] [1, 1] : tensor<1024xf32> into tensor<1x1024xf32>
      secret.yield %inserted_slice : tensor<1x1024xf32>
    } -> !secret.secret<tensor<1x1024xf32>>
    return %1 : !secret.secret<tensor<1x1024xf32>>
  }
}

Here the constants are masks that select groups of 16 floating-point elements.
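
To see where the conflicts come from, here is a toy model in plain C++ (not the actual pass code; the blockSize constant and brute-force loops are only an illustration of the repeated 16-element layout above). Every target slot ends up with 64 candidate source slots holding the same value, and de-duping by first appearance always picks a copy in the first block:

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  constexpr int64_t ciphertextSize = 1024;
  constexpr int64_t blockSize = 16;  // the repeated 4x4 block from the source layout

  // For each target slot, collect every source slot that holds the value the
  // target needs, i.e. every copy of the same block entry.
  std::map<int64_t, std::vector<int64_t>> candidates;
  for (int64_t target = 0; target < ciphertextSize; ++target)
    for (int64_t source = 0; source < ciphertextSize; ++source)
      if (source % blockSize == target % blockSize)
        candidates[target].push_back(source);

  // Always taking the first candidate (roughly what de-duping by first
  // appearance in the point pair collector does) maps every target back to
  // slots 0..15, forcing long-distance shifts even though each target also
  // has a zero-distance copy in its own block.
  std::cout << "target 500 has " << candidates[500].size()
            << " candidate sources; first choice = " << candidates[500].front()
            << "\n";  // prints: 64 candidate sources; first choice = 4
}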

Instead of arbitrarily choosing the source that maps to a target, we can select the one closest to the target. For a layout like this, that means choosing the source point in the repeated block nearest the target, which de-dupes the candidates in a way that may make the chosen source for each target unique. Distance can be computed as the "virtual distance" ciphertext_diff * ciphertextSize + slot_diff. If we do that, the resulting IR would be:

#layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = i1 and (2i2 - i3 + slot) mod 4 = 0 and 0 <= i1 <= 1 and 0 <= i2 <= 1 and 0 <= i3 <= 1 and 0 <= slot <= 1023 }">
#original_type = #tensor_ext.original_type<originalType = tensor<1x2x2x2xf32>, layout = #layout>
module {
  func.func @main(%arg0: !secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #tensor_ext.original_type<originalType = tensor<1x1x4x4xf32>, layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-4i2 - i3 + slot) mod 16 = 0 and 0 <= i2 <= 3 and 0 <= i3 <= 3 and 0 <= slot <= 1023 }">>}, %arg1: tensor<2x1x3x3xf32>) -> (!secret.secret<tensor<1x1024xf32>> {tensor_ext.original_type = #original_type}) {
    %0 = tensor.empty() : tensor<1x1024xf32>
    %1 = secret.generic(%arg0: !secret.secret<tensor<1x1024xf32>>) {
    ^body(%input0: tensor<1x1024xf32>):
      %extracted_slice = tensor.extract_slice %input0[0, 0] [1, 1024] [1, 1] : tensor<1x1024xf32> to tensor<1024xf32>
      %inserted_slice = tensor.insert_slice %extracted_slice into %0[0, 0] [1, 1024] [1, 1] : tensor<1024xf32> into tensor<1x1024xf32>
      secret.yield %inserted_slice : tensor<1x1024xf32>
    } -> !secret.secret<tensor<1x1024xf32>>
    return %1 : !secret.secret<tensor<1x1024xf32>>
  }
}
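
A minimal sketch of that selection rule, using a hypothetical Point struct and chooseSource helper (the pass's real data structures may differ, and taking the absolute value of the virtual distance is an assumption about how "closest" is measured): among the candidate sources for a target, pick the one minimizing ciphertext_diff * ciphertextSize + slot_diff. For the repeated layouts above, the nearest copy sits in the target's own 16-element block, so every chosen distance is zero and the shift network collapses to the identity seen in the IR above.

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>

// Hypothetical candidate point; the real pass tracks (ct, slot) pairs.
struct Point {
  int64_t ct;
  int64_t slot;
};

// "Virtual distance" from the text: flatten (ct, slot) into a single index
// and measure how far the source is from the target (abs is an assumption).
int64_t virtualDistance(Point source, Point target, int64_t ciphertextSize) {
  int64_t ctDiff = target.ct - source.ct;
  int64_t slotDiff = target.slot - source.slot;
  return std::llabs(ctDiff * ciphertextSize + slotDiff);
}

// De-dupe by choosing, among all sources that carry the right value, the one
// with the smallest virtual distance to the target.
Point chooseSource(const std::vector<Point> &candidates, Point target,
                   int64_t ciphertextSize) {
  Point best = candidates.front();
  int64_t bestDist = std::numeric_limits<int64_t>::max();
  for (Point p : candidates) {
    int64_t d = virtualDistance(p, target, ciphertextSize);
    if (d < bestDist) {
      bestDist = d;
      best = p;
    }
  }
  return best;
}

int main() {
  constexpr int64_t ciphertextSize = 1024;
  // Candidates for target (ct=0, slot=500): every 16th slot starting at 4.
  std::vector<Point> candidates;
  for (int64_t s = 4; s < ciphertextSize; s += 16) candidates.push_back({0, s});
  Point chosen = chooseSource(candidates, {0, 500}, ciphertextSize);
  // Prints 500: the copy in the target's own block, so no rotation is needed.
  std::cout << "chosen source slot = " << chosen.slot << "\n";
}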

Obviously this is just one way to minimize conflicts in the graph, and it will only work for patterns like this where the source and target are repeated the same number of times. If we hold on to all the sources that may map to each target, there is probably a better way to prune the conflict graph.
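
As one hedged illustration of that direction (again not the pass itself; offsetVotes and the greedy choice are just an example heuristic), keeping all candidates lets a later step pick rotation offsets globally, e.g. by counting how many targets each offset could serve and committing to the offsets that cover the most targets:

#include <cstdint>
#include <iostream>
#include <map>

int main() {
  constexpr int64_t ciphertextSize = 1024;
  constexpr int64_t blockSize = 16;

  // offset -> number of targets that some candidate source reaches with it.
  std::map<int64_t, int64_t> offsetVotes;
  for (int64_t target = 0; target < ciphertextSize; ++target)
    for (int64_t source = target % blockSize; source < ciphertextSize;
         source += blockSize) {
      int64_t offset =
          ((target - source) % ciphertextSize + ciphertextSize) % ciphertextSize;
      ++offsetVotes[offset];
    }

  // For this layout every multiple of 16 ties (each offset can serve all 1024
  // targets); a greedy pass that prefers the smallest such offset lands on 0,
  // i.e. the identity, and the shift network needs no rotations at all.
  auto best = offsetVotes.begin();
  for (auto it = offsetVotes.begin(); it != offsetVotes.end(); ++it)
    if (it->second > best->second) best = it;
  std::cout << "offset " << best->first << " can serve " << best->second
            << " targets\n";
}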
