From 963580e203ec144660458c9e4515868d1718ae6e Mon Sep 17 00:00:00 2001 From: shiroha <3202778076@qq.com> Date: Wed, 19 Feb 2025 05:34:37 +0800 Subject: [PATCH 01/13] [midend/lib/Conversion/LowerLinalgToGemmini] Add pass support for gemmini to run E2E LeNet and some tests. --- .../conv_2d_nhwc_fhwc_5x5_i8.mlir | 32 + .../GemminiDialect/conv_2d_nhwc_fhwc_f32.mlir | 31 + .../GemminiDialect/conv_2d_nhwc_fhwc_i8.mlir | 30 + examples/GemminiDialect/makefile | 20 + examples/GemminiDialect/print.mlir | 31 + midend/include/Dialect/Gemmini/Gemmini.td | 56 +- .../LowerGemmini/LowerGemminiPass.cpp | 90 ++ .../LowerLinalgToGemmini.cpp | 182 ++- .../Transforms/LegalizeForLLVMExport.cpp | 1099 +++++++++-------- 9 files changed, 1009 insertions(+), 562 deletions(-) create mode 100644 examples/GemminiDialect/conv_2d_nhwc_fhwc_5x5_i8.mlir create mode 100644 examples/GemminiDialect/conv_2d_nhwc_fhwc_f32.mlir create mode 100644 examples/GemminiDialect/conv_2d_nhwc_fhwc_i8.mlir create mode 100644 examples/GemminiDialect/print.mlir diff --git a/examples/GemminiDialect/conv_2d_nhwc_fhwc_5x5_i8.mlir b/examples/GemminiDialect/conv_2d_nhwc_fhwc_5x5_i8.mlir new file mode 100644 index 0000000000..a7d7662b10 --- /dev/null +++ b/examples/GemminiDialect/conv_2d_nhwc_fhwc_5x5_i8.mlir @@ -0,0 +1,32 @@ +// RUN: buddy-opt %s \ +// RUN: --convert-linalg-to-gemmini | \ +// RUN: FileCheck %s + +memref.global "private" @input : memref<1x7x7x1xi8> = dense<[[[[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]]]]> + +memref.global "private" @kernel : memref<1x5x5x1xi8> = dense<[[[[1], [1], [1], [1], [1]], + [[1], [1], [1], [1], [1]], + [[1], [1], [1], [1], [1]], + [[1], [1], [1], [1], [1]], + [[1], [1], [1], [1], [1]]]]> + +func.func @main() -> i8 { + %0 = arith.constant 0 : i8 + %input = memref.get_global @input : memref<1x7x7x1xi8> + %kernel = memref.get_global @kernel : memref<1x5x5x1xi8> + %output = memref.alloc() : memref<1x3x3x1xi8> + + // CHECK: gemmini.tile_conv %{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} : + // CHECK-SAME: memref<1x7x7x1xi8> memref<25x1xi8> memref<1xi32> memref<9x1xi8> i64 i64 + linalg.conv_2d_nhwc_fhwc + ins(%input, %kernel : memref<1x7x7x1xi8>, memref<1x5x5x1xi8>) + outs(%output : memref<1x3x3x1xi8>) + gemmini.print %output : memref<1x3x3x1xi8> + return %0 : i8 +} diff --git a/examples/GemminiDialect/conv_2d_nhwc_fhwc_f32.mlir b/examples/GemminiDialect/conv_2d_nhwc_fhwc_f32.mlir new file mode 100644 index 0000000000..998b2d4388 --- /dev/null +++ b/examples/GemminiDialect/conv_2d_nhwc_fhwc_f32.mlir @@ -0,0 +1,31 @@ +// RUN: buddy-opt %s \ +// RUN: --convert-linalg-to-gemmini="acc_t=f32" | \ +// RUN: FileCheck %s + +memref.global "private" @input : memref<1x5x5x1xf32> = dense<[[[[1.],[2.],[3.],[4.],[5.]], + [[6.],[7.],[8.],[9.],[10.]], + [[11.],[12.],[13.],[14.],[15.]], + [[16.],[17.],[18.],[19.],[20.]], + [[21.],[22.],[23.],[24.],[25.]]]]> + +memref.global "private" @kernel : memref<1x3x3x1xf32> = dense<[[[[1.], [1.], [1.]], + [[1.], [1.], [1.]], + [[1.], [1.], [1.]]]]> + + +func.func @main() -> i8 { + %0 = arith.constant 0 : i8 + // batchsize = 2 inputchannel = 2 + %input = memref.get_global @input : memref<1x5x5x1xf32> + // outputchannel = 3 + %kernel = memref.get_global @kernel : memref<1x3x3x1xf32> + // batchsize h w outputchannel + %output = memref.alloc() : memref<1x3x3x1xf32> + // CHECK: gemmini.tile_conv %{{.+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} : + // CHECK: memref<1x5x5x1xf32> memref<9x1xf32> memref<1xf32> memref<9x1xf32> i64 i64 + linalg.conv_2d_nhwc_fhwc + ins(%input, %kernel : memref<1x5x5x1xf32>, memref<1x3x3x1xf32>) + outs(%output : memref<1x3x3x1xf32>) + gemmini.print %output : memref<1x3x3x1xf32> + return %0 : i8 +} diff --git a/examples/GemminiDialect/conv_2d_nhwc_fhwc_i8.mlir b/examples/GemminiDialect/conv_2d_nhwc_fhwc_i8.mlir new file mode 100644 index 0000000000..0bfeafca19 --- /dev/null +++ b/examples/GemminiDialect/conv_2d_nhwc_fhwc_i8.mlir @@ -0,0 +1,30 @@ +// RUN: buddy-opt %s \ +// RUN: --convert-linalg-to-gemmini | \ +// RUN: FileCheck %s + +memref.global "private" @input : memref<1x7x7x1xi8> = dense<[[[[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]], + [[1],[1],[1],[1],[1],[1],[1]]]]> + +memref.global "private" @kernel : memref<1x3x3x1xi8> = dense<[[[[1], [1], [1]], + [[1], [1], [1]], + [[1], [1], [1]]]]> + +func.func @main() -> i8 { + %0 = arith.constant 0 : i8 + %input = memref.get_global @input : memref<1x7x7x1xi8> + %kernel = memref.get_global @kernel : memref<1x3x3x1xi8> + %output = memref.alloc() : memref<1x5x5x1xi8> + + // CHECK: gemmini.tile_conv %{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} : + // CHECK-SAME: memref<1x7x7x1xi8> memref<9x1xi8> memref<1xi32> memref<25x1xi8> i64 i64 + linalg.conv_2d_nhwc_fhwc + ins(%input, %kernel : memref<1x7x7x1xi8>, memref<1x3x3x1xi8>) + outs(%output : memref<1x5x5x1xi8>) + gemmini.print %output : memref<1x5x5x1xi8> + return %0 : i8 +} diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index ca30047f40..873c8d4e10 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -2491,3 +2491,23 @@ exo-matmul-4-run: -I${RISCV}/../../generators/gemmini/software/gemmini-rocc-tests \ -O2 -static -o a.out @spike --extension=gemmini pk a.out + +gemmini-print-lower: + @${BUDDY_OPT} ./print.mlir \ + -convert-linalg-to-gemmini \ + -convert-linalg-to-loops \ + -lower-gemmini \ + -o log.mlir + +gemmini-print-run: + @${BUDDY_OPT} ./print.mlir \ + -convert-linalg-to-gemmini \ + -convert-linalg-to-loops \ + -lower-gemmini | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -relocation-model=pic \ + -o log.o + @riscv64-unknown-linux-gnu-gcc -O2 -static log.o -o print + @spike --extension=gemmini pk print diff --git a/examples/GemminiDialect/print.mlir b/examples/GemminiDialect/print.mlir new file mode 100644 index 0000000000..d366a9417f --- /dev/null +++ b/examples/GemminiDialect/print.mlir @@ -0,0 +1,31 @@ +// RUN: buddy-opt %s \ +// RUN: --convert-linalg-to-gemmini \ +// --convert-linalg-to-loops \ +// --lower-gemmini | \ +// RUN: FileCheck %s + +func.func @main() -> i8 { + %c0 = arith.constant 0 : i8 + + %scalar = arith.constant 42 : i8 + // CHECK: gemmini.print_scalar %{{.*}} : i8 + gemmini.print_scalar %scalar : i8 + + %vector = memref.alloc() : memref<4xi8> // 1D向量 + %matrix = memref.alloc() : memref<2x3xi8> // 2D矩阵 + %tensor = memref.alloc() : memref<1x2x3xi8> // 3D张量 + %c1 = arith.constant 1 : i8 + linalg.fill ins(%c1 : i8) outs(%vector : memref<4xi8>) + linalg.fill ins(%c1 : i8) outs(%matrix : memref<2x3xi8>) + // CHECK: gemmini.print %{{.*}} : memref<4xi8> + gemmini.print %vector : memref<4xi8> + // CHECK: gemmini.print %{{.*}} : memref<2x3xi8> + gemmini.print %matrix : memref<2x3xi8> + // CHECK: gemmini.print %{{.*}} : memref<1x2x3xi8> + gemmini.print %tensor : memref<1x2x3xi8> + memref.dealloc %vector : memref<4xi8> + memref.dealloc %matrix : memref<2x3xi8> + memref.dealloc %tensor : memref<1x2x3xi8> + + return %c0 : i8 +} diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index e098ccc578..00d852258e 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -64,7 +64,7 @@ def ConfigStOp : Gemmini_Op<"config_st"> { }]; let arguments = (ins I64:$stride, DefaultValuedAttr:$activation, - DefaultValuedAttr:$scale); + DefaultValuedAttr:$scale); let assemblyFormat = "$stride attr-dict `:` type($stride)"; } @@ -88,19 +88,19 @@ def ConfigExOp : Gemmini_Op<"config_ex"> { ConfigExOp configures the execute pipeline. - dataflow: output-stationary (0) or weight-stationary (1) dataflow - sysAct: activation function relu (1) or no activation function (0) - - sysShift: the number of bits by which the accumulated result of a matmul + - sysShift: the number of bits by which the accumulated result of a matmul is right-shifted when leaving the systolic array. - - sysAccScale: the scalar value by which we scale the accType output of the + - sysAccScale: the scalar value by which we scale the accType output of the accumulator down to inputType values when reading from the accumulator. (In the default config, rs1[63:32] is of type float32) - cStride: TODO - - aStride: the stride (in scratchpad addresses) by which the rows of A are - fed into the systolic array. "A" in this context refers to the - left-hand matrix A in the matmul represented by A * B = C. - If this stride is 1, then we feed consecutive rows in the - scratchpad, starting from the starting address of A, into the - systolic array as the A matrix. If the stride is 2, then we feed + - aStride: the stride (in scratchpad addresses) by which the rows of A are + fed into the systolic array. "A" in this context refers to the + left-hand matrix A in the matmul represented by A * B = C. + If this stride is 1, then we feed consecutive rows in the + scratchpad, starting from the starting address of A, into the + systolic array as the A matrix. If the stride is 2, then we feed every other row into the systolic array instead. - aTranspose: transpose A - bTranspose: transpose B @@ -192,7 +192,13 @@ def MvoutOp : Gemmini_Op<"mvout"> { def PrintOp : Gemmini_Op<"print"> { let summary = "Print memref value."; - let arguments = (ins AnyTypeOf<[I8MemRef, I32MemRef, F32MemRef, F64MemRef]>:$input); + let arguments = (ins AnyTypeOf<[I8MemRef, I32MemRef, F32MemRef, F64MemRef]>:$input); + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +def PrintScalarOp : Gemmini_Op<"print_scalar"> { + let summary = "Print a scalar value."; + let arguments = (ins AnyType:$input); let assemblyFormat = "$input attr-dict `:` type($input)"; } @@ -224,7 +230,7 @@ def PreloadOp : Gemmini_Op<"preload"> { let arguments = (ins I64:$bdAddr, I64:$cAddr, I64:$bdRows, I64:$bdCols, I64:$cRows, I64:$cCols); let assemblyFormat = [{ - $bdAddr $cAddr $bdRows $bdCols $cRows $cCols attr-dict `:` type($bdAddr) + $bdAddr $cAddr $bdRows $bdCols $cRows $cCols attr-dict `:` type($bdAddr) type($cAddr) type($bdRows) type($bdCols) type($cRows) type($cCols) }]; } @@ -308,17 +314,17 @@ def TileConvOp : Gemmini_Op<"tile_conv"> { I64:$outRowDim, I64:$outColDim, I64:$kernelDim, DefaultValuedAttr:$scale, DefaultValuedAttr:$stride, - DefaultValuedAttr:$inputDilation, + DefaultValuedAttr:$inputDilation, DefaultValuedAttr:$kernelDilation, - DefaultValuedAttr:$padding, + DefaultValuedAttr:$padding, DefaultValuedAttr:$wrot180, - DefaultValuedAttr:$transOutput1203, + DefaultValuedAttr:$transOutput1203, DefaultValuedAttr:$transInput3120, - DefaultValuedAttr:$transWeight1203, + DefaultValuedAttr:$transWeight1203, DefaultValuedAttr:$transWeight0132, DefaultValuedAttr:$act, DefaultValuedAttr:$poolSize, - DefaultValuedAttr:$poolStride, + DefaultValuedAttr:$poolStride, DefaultValuedAttr:$poolPadding); let assemblyFormat = [{ $input $weights $bias $output $outRowDim $outColDim $kernelDim attr-dict `:` type($input) @@ -330,13 +336,13 @@ def TileConvOp : Gemmini_Op<"tile_conv"> { // Gemmini intrinsic operation definitions //===----------------------------------------------------------------------===// -class Gemmini_IntrOpBase traits = []> : - LLVM_IntrOpBase traits = []> : + LLVM_IntrOpBase overloadedResults=*/[], - /*list overloadedOperands=*/[], - /*list traits=*/traits, + /*list overloadedResults=*/[], + /*list overloadedOperands=*/[], + /*list traits=*/traits, /*int numResults=*/0>; def Gemmini_Mvin_IntrOp : Gemmini_IntrOpBase<"mvin">, @@ -357,13 +363,13 @@ def Gemmini_Flush_IntrOp : Gemmini_IntrOpBase<"flush">, def Gemmini_ConifgLd_IntrOp : Gemmini_IntrOpBase<"config_ld">, Arguments<(ins LLVM_Type, LLVM_Type)>; -def Gemmini_ConfigSt_IntrOp : Gemmini_IntrOpBase<"config_st">, +def Gemmini_ConfigSt_IntrOp : Gemmini_IntrOpBase<"config_st">, Arguments<(ins LLVM_Type, LLVM_Type)>; -def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">, +def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">, Arguments<(ins LLVM_Type, LLVM_Type)>; -def Gemmini_ConfigNorm_IntrOp : Gemmini_IntrOpBase<"config_norm">, +def Gemmini_ConfigNorm_IntrOp : Gemmini_IntrOpBase<"config_norm">, Arguments<(ins LLVM_Type, LLVM_Type)>; def Gemmini_Preload_IntrOp : Gemmini_IntrOpBase<"preload">, @@ -414,4 +420,4 @@ def Gemmini_LoopConvWsConfig5_IntrOp : Gemmini_IntrOpBase<"loop_conv_ws_config5" def Gemmini_LoopConvWsConfig6_IntrOp : Gemmini_IntrOpBase<"loop_conv_ws_config6">, Arguments<(ins LLVM_Type, LLVM_Type)>; -#endif +#endif diff --git a/midend/lib/Conversion/LowerGemmini/LowerGemminiPass.cpp b/midend/lib/Conversion/LowerGemmini/LowerGemminiPass.cpp index 0dcecff32f..2369848040 100644 --- a/midend/lib/Conversion/LowerGemmini/LowerGemminiPass.cpp +++ b/midend/lib/Conversion/LowerGemmini/LowerGemminiPass.cpp @@ -152,6 +152,95 @@ class PrintOpLowering : public ConversionPattern { } }; +class PrintScalarOpLowering : public ConversionPattern { +public: + explicit PrintScalarOpLowering(MLIRContext *context) + : ConversionPattern(gemmini::PrintScalarOp::getOperationName(), 1, + context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto context = rewriter.getContext(); + auto loc = op->getLoc(); + + ModuleOp parentModule = op->getParentOfType(); + + auto printfRef = getOrInsertPrintf(rewriter, parentModule); + + Type elementType = op->getOperand(0).getType(); + Value formatSpecifierCst; + + if (elementType == rewriter.getF32Type() || + elementType == rewriter.getF64Type()) { + formatSpecifierCst = getOrCreateGlobalString( + loc, rewriter, "scalar_fmt", StringRef("%f\n\0", 5), parentModule); + } else if (elementType == rewriter.getI8Type() || + elementType == rewriter.getI32Type()) { + formatSpecifierCst = getOrCreateGlobalString( + loc, rewriter, "scalar_fmt", StringRef("%d\n\0", 5), parentModule); + } + + Value valueToPrint = op->getOperand(0); + if (elementType == rewriter.getF32Type()) { + valueToPrint = rewriter.create(loc, rewriter.getF64Type(), + valueToPrint); + } else if (elementType == rewriter.getI8Type()) { + valueToPrint = rewriter.create(loc, rewriter.getI32Type(), + valueToPrint); + } + + rewriter.create( + loc, getPrintfType(context), printfRef, + ArrayRef({formatSpecifierCst, valueToPrint})); + + rewriter.eraseOp(op); + return success(); + } + +private: + static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { + auto llvmI32Ty = IntegerType::get(context, 32); + auto llvmPtr = LLVM::LLVMPointerType::get(context); + return LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtr, true); + } + + static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module) { + auto *context = module.getContext(); + if (module.lookupSymbol("printf")) + return SymbolRefAttr::get(context, "printf"); + + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), "printf", + getPrintfType(context)); + return SymbolRefAttr::get(context, "printf"); + } + + static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module) { + LLVM::GlobalOp global; + if (!(global = module.lookupSymbol(name))) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + auto type = LLVM::LLVMArrayType::get( + IntegerType::get(builder.getContext(), 8), value.size()); + global = builder.create(loc, type, true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value), 0); + } + + Value globalPtr = builder.create(loc, global); + Value cst0 = builder.create(loc, builder.getI64Type(), + builder.getIndexAttr(0)); + return builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), + globalPtr, ArrayRef({cst0, cst0})); + } +}; + namespace { class LowerGemminiToLLVMPass : public PassWrapper> { @@ -222,6 +311,7 @@ void LowerGemminiToLLVMPass::runOnOperation() { cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); populateFuncToLLVMConversionPatterns(converter, patterns); patterns.add(&getContext()); + patterns.add(&getContext()); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp index bfee320cc4..7def482508 100644 --- a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp +++ b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp @@ -47,7 +47,7 @@ class MatmulLowering : public OpRewritePattern { Value input0 = inputs[0]; Value input1 = inputs[1]; Value output0 = ouputs[0]; - MemRefType input0Type = dyn_cast(input0.getType()); + MemRefType input0Type = dyn_cast(input0.getType()); MemRefType biasType = MemRefType::get(input0Type.getShape(), rewriter.getI32Type()); TypedAttr fillOpInputAttr = rewriter.getI32IntegerAttr(0); @@ -75,6 +75,167 @@ class MatmulLowering : public OpRewritePattern { std::string accType; }; +class Conv2DNhwcFhwcLowering + : public OpRewritePattern { +public: + explicit Conv2DNhwcFhwcLowering(MLIRContext *context, std::string accType) + : OpRewritePattern(context), accType(accType) {} + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp, + PatternRewriter &rewriter) const override { + Value input = convOp.getInputs()[0]; + Value kernel = convOp.getInputs()[1]; + Value output = convOp.getOutputs()[0]; + Location loc = convOp.getLoc(); + + MemRefType inputType = dyn_cast(input.getType()); + MemRefType kernelType = dyn_cast(kernel.getType()); + MemRefType outputType = dyn_cast(output.getType()); + + Type kernelElemType = kernelType.getElementType(); + Type outputElemType = outputType.getElementType(); + + ArrayRef inputShape = inputType.getShape(); + + DenseIntElementsAttr dilationsAttr = convOp.getDilationsAttr(); + DenseIntElementsAttr stridesAttr = convOp.getStridesAttr(); + + size_t dilations = 1; + size_t strides = 1; + if (dilationsAttr) + dilations = (*dilationsAttr.begin()).getLimitedValue(); + if (stridesAttr) + strides = (*stridesAttr.begin()).getLimitedValue(); + + if (inputShape[1] != inputShape[2]) // h, w + return failure(); + ArrayRef kernelShape = kernelType.getShape(); + if (kernelShape[1] != kernelShape[2]) // h, w + return failure(); + ArrayRef outputShape = outputType.getShape(); + + // Create kernelMat(hwc, f) and outputMat(nhw, c). + SmallVector kernelMatShape = { + kernelShape[1] * kernelShape[2] * kernelShape[3], kernelShape[0]}; + MemRefType kernelMatType = MemRefType::get(kernelMatShape, kernelElemType); + Value kernelMat = rewriter.create(loc, kernelMatType); + + SmallVector outputMatShape = { + outputShape[0] * outputShape[1] * outputShape[2], outputShape[3]}; + MemRefType outputMatType = MemRefType::get(outputMatShape, outputElemType); + Value outputMat = rewriter.create(loc, outputMatType); + + MemRefType biasType = + MemRefType::get(outputShape[3], rewriter.getI32Type()); + if (accType == "f32") + biasType = MemRefType::get(outputShape[3], rewriter.getF32Type()); + Value bias = rewriter.create(loc, biasType); + + TypedAttr attr = rewriter.getI32IntegerAttr(0); + if (accType == "f32") + attr = rewriter.getF32FloatAttr(0); + Value constant0 = rewriter.create(loc, attr); + SmallVector inputs = {constant0}; + SmallVector outputs = {bias}; + rewriter.create(loc, inputs, outputs); + + // kernelShape + Operation *loopOp = nullptr; + SmallVector loopIvs; + for (size_t i = 0; i != kernelShape.size(); i++) { + Value lowerBound = rewriter.create(loc, 0); + Value upperBound = + rewriter.create(loc, kernelShape[i]); + Value step = rewriter.create(loc, 1); + auto loop = + rewriter.create(loc, lowerBound, upperBound, step); + loopIvs.push_back(loop.getInductionVar()); + if (i == 0) + loopOp = loop.getOperation(); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + Value kernelDim = rewriter.create( + loc, kernelShape[1]); // dim_h = dim_w + Value inChannels = + rewriter.create(loc, kernelShape[3]); + + // Conv kernel mapping (f,h,w,c) -> (h*w*c, f) + Value tmp0 = + rewriter.create(loc, loopIvs[1], kernelDim); // h * kW + tmp0 = rewriter.create(loc, tmp0, inChannels); // * C + Value tmp1 = + rewriter.create(loc, loopIvs[2], inChannels); // w * C + tmp0 = rewriter.create(loc, tmp0, tmp1); // + (w * C) + tmp0 = rewriter.create(loc, tmp0, loopIvs[3]); // + c + + // load kernel + Value element = rewriter.create(loc, kernel, loopIvs); + SmallVector indices = {tmp0, loopIvs[0]}; // [h*w*c, f] + rewriter.create(loc, element, kernelMat, + indices); // Store the loaded data + rewriter.setInsertionPointAfter(loopOp); + + attr = rewriter.getI64IntegerAttr(outputShape[1]); + Value outRowDim = rewriter.create(loc, attr); + attr = rewriter.getI64IntegerAttr(outputShape[2]); + Value outColDim = rewriter.create(loc, attr); + kernelDim = rewriter.create(loc, attr); + kernelDim = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(kernelShape[1])); + + rewriter.create( + loc, input, kernelMat, bias, outputMat, outRowDim, outColDim, kernelDim, + llvm::APFloat(float(1.0)), strides, dilations); + + // After the conv operation is completed, the data in outputMat needs to be + // transferred into output (2-D to 4-D). + loopIvs.clear(); + indices.clear(); + + for (size_t i = 0; i < outputShape.size(); i++) { + Value lowerBound = rewriter.create(loc, 0); + Value upperBound = + rewriter.create(loc, outputShape[i]); + Value step = rewriter.create(loc, 1); + auto loop = + rewriter.create(loc, lowerBound, upperBound, step); + loopIvs.push_back(loop.getInductionVar()); + if (i == 0) + loopOp = loop.getOperation(); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Map output from 2D (N*H*W, C) back to NHWC (n,h,w,c) + Value outH = rewriter.create(loc, outputShape[1]); + Value outW = rewriter.create(loc, outputShape[2]); + + // Calculate the row index in the 2D matrix: n * (H*W) + h * W + w + tmp0 = rewriter.create(loc, loopIvs[0], outH); // n * H + tmp0 = rewriter.create(loc, tmp0, outW); // * W + tmp1 = rewriter.create(loc, loopIvs[1], outW); // h * W + tmp0 = rewriter.create(loc, tmp0, tmp1); // + (h * W) + tmp0 = rewriter.create(loc, tmp0, loopIvs[2]); // + w + + // The index in the 2D matrix is [n*H*W + h*W + w, c] + indices.assign({tmp0, loopIvs[3]}); + + tmp0 = rewriter.create(loc, outputMat, indices); + rewriter.create(loc, tmp0, output, loopIvs); + rewriter.setInsertionPointAfter(loopOp); + + rewriter.create(loc, kernelMat); + rewriter.create(loc, outputMat); + rewriter.create(loc, bias); + + rewriter.eraseOp(convOp); + return success(); + } + +private: + std::string accType; +}; + class Conv2DNchwFchwLowering : public OpRewritePattern { public: @@ -88,9 +249,9 @@ class Conv2DNchwFchwLowering Value input1 = inputs[1]; Value output = convOp.getOutputs()[0]; Location loc = convOp.getLoc(); - MemRefType inputType = dyn_cast(input0.getType()); - MemRefType weightsType = dyn_cast(input1.getType()); - MemRefType outputType = dyn_cast(output.getType()); + MemRefType inputType = dyn_cast(input0.getType()); + MemRefType weightsType = dyn_cast(input1.getType()); + MemRefType outputType = dyn_cast(output.getType()); ArrayRef inputShape = inputType.getShape(); ArrayRef outputShape = outputType.getShape(); ArrayRef weightsShape = weightsType.getShape(); @@ -233,9 +394,9 @@ class Conv2DNhwcHwcfLowering Value kernel = convOp.getInputs()[1]; Value output = convOp.getOutputs()[0]; Location loc = convOp.getLoc(); - MemRefType inputType = dyn_cast(input.getType()); - MemRefType kernelType = dyn_cast(kernel.getType()); - MemRefType outputType = dyn_cast(output.getType()); + MemRefType inputType = dyn_cast(input.getType()); + MemRefType kernelType = dyn_cast(kernel.getType()); + MemRefType outputType = dyn_cast(output.getType()); Type kernelElemType = kernelType.getElementType(); Type outputElemType = outputType.getElementType(); ArrayRef inputShape = inputType.getShape(); @@ -359,11 +520,11 @@ class BatchMatMulOpLowering : public OpRewritePattern { Value input0 = inputs[0]; Value input1 = inputs[1]; Value output = batchMatMulOp.getOutputs()[0]; - MemRefType input0Type = dyn_cast(input0.getType()); + MemRefType input0Type = dyn_cast(input0.getType()); ArrayRef input0Shape = input0Type.getShape(); - MemRefType input1Type = dyn_cast(input1.getType()); + MemRefType input1Type = dyn_cast(input1.getType()); ArrayRef input1Shape = input1Type.getShape(); - MemRefType outputType = dyn_cast(output.getType()); + MemRefType outputType = dyn_cast(output.getType()); ArrayRef outputShape = outputType.getShape(); Type elemType = input0Type.getElementType(); for (unsigned i = 0; i != input0Shape[0]; i++) { @@ -414,6 +575,7 @@ void populateLowerLinalgToGemminiConversionPatterns(RewritePatternSet &patterns, std::string accType) { patterns.add(patterns.getContext(), accType); patterns.add(patterns.getContext(), accType); + patterns.add(patterns.getContext(), accType); patterns.add(patterns.getContext(), accType); patterns.add(patterns.getContext()); } diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 31304a913e..76d2b9bab2 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -38,7 +38,8 @@ using namespace buddy::gemmini; namespace { int64_t getNumberFromValue(Value &value) { - return dyn_cast(value.getDefiningOp()->getAttr("value")).getInt(); + return dyn_cast(value.getDefiningOp()->getAttr("value")) + .getInt(); } acc_scale_t_bits acc_scale_t_to_acc_scale_t_bits(acc_scale_t x) { @@ -249,7 +250,8 @@ struct GemminiMvinLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Value input = mvinOp.getInput(); Location loc = input.getLoc(); - MemRefType memRefType = dyn_cast(mvinOp.getOperandTypes().front()); + MemRefType memRefType = + dyn_cast(mvinOp.getOperandTypes().front()); llvm::ArrayRef memRefShape = memRefType.getShape(); TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); Value extractOp = rewriter.create( @@ -281,7 +283,8 @@ struct GemminiMvin2Lowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Value input = mvin2Op.getInput(); Location loc = input.getLoc(); - MemRefType memRefType = dyn_cast(mvin2Op.getOperandTypes().front()); + MemRefType memRefType = + dyn_cast(mvin2Op.getOperandTypes().front()); llvm::ArrayRef memRefShape = memRefType.getShape(); TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); Value extractOp = rewriter.create( @@ -313,7 +316,8 @@ struct GemminiMvin3Lowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Value input = mvin3Op.getInput(); Location loc = input.getLoc(); - MemRefType memRefType = dyn_cast(mvin3Op.getOperandTypes().front()); + MemRefType memRefType = + dyn_cast(mvin3Op.getOperandTypes().front()); llvm::ArrayRef memRefShape = memRefType.getShape(); TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); Value extractOp = rewriter.create( @@ -353,7 +357,8 @@ struct GemminiMvoutLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, i64Type, extractOp); Value spadAddr = mvoutOp.getAddr(); uint64_t number = getNumberFromValue(spadAddr); - MemRefType memRefType =dyn_cast(mvoutOp.getOperandTypes().front()); + MemRefType memRefType = + dyn_cast(mvoutOp.getOperandTypes().front()); llvm::ArrayRef memRefShape = memRefType.getShape(); uint64_t spadAddrInt = (uint64_t)memRefShape[0] << (addrLen + 16) | (uint64_t)memRefShape[1] << addrLen | number; @@ -947,9 +952,12 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { MemRefType bArrayType = dyn_cast(bArray.getType()); MemRefType cArrayType = dyn_cast(cArray.getType()); MemRefType dArrayType = dyn_cast(dArray.getType()); - StridedLayoutAttr aArrayLayout = dyn_cast(aArrayType.getLayout()); - StridedLayoutAttr bArrayLayout = dyn_cast(bArrayType.getLayout()); - StridedLayoutAttr cArrayLayout = dyn_cast(cArrayType.getLayout()); + StridedLayoutAttr aArrayLayout = + dyn_cast(aArrayType.getLayout()); + StridedLayoutAttr bArrayLayout = + dyn_cast(bArrayType.getLayout()); + StridedLayoutAttr cArrayLayout = + dyn_cast(cArrayType.getLayout()); SmallVector resultType = {rewriter.getIndexType()}; TypeRange typeRange(resultType); Location loc = tileMatMulOp.getLoc(); @@ -1145,30 +1153,26 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, rs1Value, rs2Value); } - void spTiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, - int outChannels, int outRowDim, int outColDim, - int poolOutRowDim, int poolOutColDim, int stride, - int padding, int kernelDim, int kernelDilation, int inStride, - int weightStride, int outStride, int poolSize, - int poolStride, int poolPadding, int batches, int porows, - int pocols, int pochs, int krows, int kcols, int kchs, - int lpad, int rpad, int upad, int dpad, int plpad, int prpad, - int pupad, int pdpad, Value &input, Value &weights, - Value &output, Value &bias, int act, acc_scale_t scale, - bool wrot180, bool transOutput1203, bool transInput3120, - bool transWeight1203, bool transWeight0132, bool noBias, - bool noPool, bool downsample, bool inputDilated, bool dw, - TileConvOp &tileConvOp, - ConversionPatternRewriter &rewriter) const { + void gemminiRiscConvWs( + int batchSize, int inRowDim, int inColDim, int inChannels, + int outChannels, int outRowDim, int outColDim, int poolOutRowDim, + int poolOutColDim, int stride, int padding, int kernelDim, + int kernelDilation, int inStride, int weightStride, int outStride, + int poolSize, int poolStride, int poolPadding, int batches, int porows, + int pocols, int pochs, int krows, int kcols, int kchs, int lpad, int rpad, + int upad, int dpad, int plpad, int prpad, int pupad, int pdpad, + Value &input, Value &weights, Value &output, Value &bias, int act, + acc_scale_t scale, bool wrot180, bool transOutput1203, + bool transInput3120, bool transWeight1203, bool transWeight0132, + bool noBias, bool noPool, bool downsample, bool inputDilated, + int maxPixelsPerRow, bool dw, TileConvOp &tileConvOp, + ConversionPatternRewriter &rewriter) const { Location loc = tileConvOp.getLoc(); - if (dw) { - kchs = 1; - pochs = 1; - } - const int orows = porows * poolStride + poolSize - 1 - pupad - pdpad; const int ocols = pocols * poolStride + poolSize - 1 - plpad - prpad; + + const int ichs = kchs; const int ochs = pochs; // Calculate image dimensions @@ -1180,10 +1184,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { int irowsUnpadded = irows - upad - dpad; int icolsUnpadded = icols - lpad - rpad; - const int ichs = kchs; - #define UNDILATED(x) ((inputDilated) ? (((x) + 1) / 2) : (x)) - if (inputDilated) { irowsUnpadded = (irowsUnpadded + 1) / 2; icolsUnpadded = (icolsUnpadded + 1) / 2; @@ -1192,18 +1193,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { icols = icolsUnpadded + UNDILATED(lpad) + UNDILATED(rpad); } -#ifdef HAS_FIRST_LAYER_OPTIMIZATIONS - const bool transposed = - transOutput1203 || transInput3120 || transWeight1203 || transWeight0132; - int maxPixelsPerRow = transposed || wrot180 || downsample || inputDilated || - kernelDilation > 1 || ichs > dim - ? 1 - : dim / ichs; - if (maxPixelsPerRow > kcols) - maxPixelsPerRow = kcols; -#else - const int maxPixelsPerRow = 1; -#endif // Calculate spad address offsets const int outChannelsPerBank = ochs / dim + (ochs % dim != 0); const int inChannelsPerBank = kchs / dim + (kchs % dim != 0); @@ -1226,25 +1215,13 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { if (output != 0) { cSpAddrRow = (cSpAddrRow + accRows / 2) % accRows; } - if (inRowDim == inColDim && outRowDim == outColDim && - poolOutRowDim == poolOutColDim) { - gemminiLoopConvWs( - batchSize, inRowDim, inChannels, outChannels, outRowDim, - poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, - poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols, - kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, - ocols, weights, output, bias, input, noBias, noPool, downsample, - wrot180, inputDilated, act, transOutput1203, transWeight1203, - transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp, - rewriter); - return; - } - if (!noPool) { + + if ((inRowDim == inColDim) && (outRowDim == outColDim) && + (poolOutRowDim == poolOutColDim) && !noPool) { llvm::outs() << "Pooling with rectangular convolutions is currently not " "supported.\n"; return; } - // Only rectangular convolutions will use the following C code // mvin bias const size_t maxBlockLen = MAX_BYTES / (dim * 1); const size_t maxBlockLenAcc = MAX_BYTES / (dim * 4); @@ -1357,6 +1334,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } } } + // mvin weights if (weights != NULL) { int max_chs_per_mvin = @@ -1424,6 +1402,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } } } + // Compute { const int b_it = transInput3120 ? dim : 1; @@ -1444,14 +1423,12 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { bool newWeights = true; for (int b = 0; b < batches; b += b_it) { for (int orow = 0; orow < orows; orow++) { - // Skip some kernel rows due to input-dilation if (inputDilated && ((krow * kernelDilation + orow * stride - upad) % 2 != 0)) { continue; } for (int ocol = 0; ocol < ocols;) { - // Skip some cols dimensions due to input-dilation if (inputDilated && ((kcol + ocol * stride - lpad) % 2 != 0)) { ocol++; @@ -1575,463 +1552,500 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } } } else { - printf("Pooling with rectangular convolutions is currently not " - "supported.\n"); + // TODO: need to enable pooling + printf("Pooling in RISC mode is unsupported.\n "); exit(1); } } } - void tiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, - int outChannels, int outRowDim, int outColDim, int stride, - int inputDilation, int kernelDilation, int padding, - int kernelDim, int inStride, int weightStride, int outStride, - bool wrot180, bool transOutput1203, bool transInput3120, - bool transWeight1203, bool transWeight0132, int batches, - int porows, int pocols, int pochs, int krows, int kcols, - int kchs, const Value &input, const Value &weights, - const Value &bias, Value &output, int act, acc_scale_t scale, - int poolSize, int poolStride, int poolPadding, - TileConvOp &tileConvOp, - ConversionPatternRewriter &rewriter) const { - bool noBias = false; - bool noPool = poolStride == 0; - if (noPool) { - poolSize = 1; - poolStride = 1; - poolPadding = 0; + void spTiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, + int outChannels, int outRowDim, int outColDim, + int poolOutRowDim, int poolOutColDim, int stride, + int padding, int kernelDim, int kernelDilation, int inStride, + int weightStride, int outStride, int poolSize, + int poolStride, int poolPadding, int batches, int porows, + int pocols, int pochs, int krows, int kcols, int kchs, + int lpad, int rpad, int upad, int dpad, int plpad, int prpad, + int pupad, int pdpad, Value &input, Value &weights, + Value &output, Value &bias, int act, acc_scale_t scale, + bool wrot180, bool transOutput1203, bool transInput3120, + bool transWeight1203, bool transWeight0132, bool noBias, + bool noPool, bool downsample, bool inputDilated, bool dw, + TileConvOp &tileConvOp, + ConversionPatternRewriter &rewriter) const { + + if (dw) { + kchs = 1; + pochs = 1; } - const bool downsample = stride == 2 && kernelDim == 1 && - inRowDim % 2 == 0 && inColDim % 2 == 0 && - padding == 0 && noPool && inputDilation == 1 && - !transInput3120; - const int inputDilated = inputDilation == 2; - int64_t stDramStride = transOutput1203 - ? batchSize * outChannels * sizeOfElemT - : outChannels * sizeOfElemT; - Location loc = tileConvOp.getLoc(); - Value strideValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(stDramStride)); - rewriter.create(loc, strideValue, act, llvm::APFloat(scale)); - rewriter.create( - loc, /*dataflow = */ WEIGHT_STATIONARY, /*act = */ 0, /*shift = */ 0, - /*scale = */ llvm::APFloat((float)0), /*cStride = */ inputDilation, - /*aStride = */ stride >> downsample, - /*aTranspose = */ transInput3120, /*bTranspose*/ transWeight0132, - /*setOnlyStrides = */ false); - const int poolOutRowDim = - (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; - const int poolOutColDim = - (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; - const int dilatedInRowDim = inRowDim + (inputDilation - 1) * (inRowDim - 1); - const int dilatedInColDim = inColDim + (inputDilation - 1) * (inColDim - 1); - - int porowEnd = poolOutRowDim; - - for (int b = 0; b < batchSize; b += batches) { - for (int porow = 0; porow < porowEnd; porow += porows) { - const int orow = porow * poolStride - poolPadding; - for (int pocol = 0; pocol < poolOutColDim; pocol += pocols) { - const int ocol = pocol * poolStride - poolPadding; - for (int poch = 0; poch < outChannels; poch += pochs) { - for (int krow = 0; krow < kernelDim; krow += krows) { - const int orow_floored = orow < 0 ? 0 : orow; - - int irow = - orow_floored * stride + krow * kernelDilation - padding; - for (int kcol = 0; kcol < kernelDim; kcol += kcols) { - const int ocol_floored = ocol < 0 ? 0 : ocol; - int icol = - ocol_floored * stride + kcol * kernelDilation - padding; - - for (int kch = 0; kch < inChannels; kch += kchs) { - TypedAttr offsetAttr = rewriter.getI64IntegerAttr( - ((b * poolOutRowDim * poolOutColDim + - porow * poolOutColDim + pocol) * - outChannels + - poch) * - sizeOfElemT); - Value offsetValue = - rewriter.create(loc, offsetAttr); - Value out = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), output, - offsetValue); - if (transOutput1203) { - offsetAttr = rewriter.getI64IntegerAttr( - ((porow * poolOutColDim * batchSize + - pocol * batchSize + b) * + + const int ichs = kchs; + +#ifdef HAS_FIRST_LAYER_OPTIMIZATIONS + const bool transposed = + transOutput1203 || transInput3120 || transWeight1203 || transWeight0132; + int maxPixelsPerRow = transposed || wrot180 || downsample || inputDilated || + kernelDilation > 1 || ichs > dim + ? 1 + : dim / ichs; + if (maxPixelsPerRow > kcols) + maxPixelsPerRow = kcols; +#else + const int maxPixelsPerRow = 1; +#endif + + // TODO: add an option to select between gemminiRiscConvWs and + // gemminiLoopConvWs if (inRowDim == inColDim && outRowDim == outColDim && + // poolOutRowDim == poolOutColDim) { + // gemminiLoopConvWs( + // batchSize, inRowDim, inChannels, outChannels, outRowDim, + // poolOutRowDim, stride, padding, kernelDim, kernelDilation, + // poolSize, poolStride, poolPadding, batches, porows, pocols, pochs, + // krows, kcols, kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, + // pdpad, orows, ocols, weights, output, bias, input, noBias, noPool, + // downsample, wrot180, inputDilated, act, transOutput1203, + // transWeight1203, transWeight0132, transInput3120, maxPixelsPerRow, + // dw, tileConvOp, rewriter); + // return; + // } + + gemminiRiscConvWs( + batchSize, inRowDim, inColDim, inChannels, outChannels, outRowDim, + outColDim, poolOutRowDim, poolOutColDim, stride, padding, kernelDim, + kernelDilation, inStride, weightStride, outStride, poolSize, poolStride, + poolPadding, batches, porows, pocols, pochs, krows, kcols, kchs, lpad, + rpad, upad, dpad, plpad, prpad, pupad, pdpad, input, weights, output, + bias, act, scale, wrot180, transOutput1203, transInput3120, + transWeight1203, transWeight0132, noBias, noPool, downsample, + inputDilated, maxPixelsPerRow, dw, tileConvOp, rewriter); + + void tiledConv( + int batchSize, int inRowDim, int inColDim, int inChannels, + int outChannels, int outRowDim, int outColDim, int stride, + int inputDilation, int kernelDilation, int padding, int kernelDim, + int inStride, int weightStride, int outStride, bool wrot180, + bool transOutput1203, bool transInput3120, bool transWeight1203, + bool transWeight0132, int batches, int porows, int pocols, int pochs, + int krows, int kcols, int kchs, const Value &input, + const Value &weights, const Value &bias, Value &output, int act, + acc_scale_t scale, int poolSize, int poolStride, int poolPadding, + TileConvOp &tileConvOp, ConversionPatternRewriter &rewriter) const { + bool noBias = false; + bool noPool = poolStride == 0; + if (noPool) { + poolSize = 1; + poolStride = 1; + poolPadding = 0; + } + const bool downsample = stride == 2 && kernelDim == 1 && + inRowDim % 2 == 0 && inColDim % 2 == 0 && + padding == 0 && noPool && inputDilation == 1 && + !transInput3120; + const int inputDilated = inputDilation == 2; + int64_t stDramStride = transOutput1203 + ? batchSize * outChannels * sizeOfElemT + : outChannels * sizeOfElemT; + Location loc = tileConvOp.getLoc(); + Value strideValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(stDramStride)); + rewriter.create(loc, strideValue, act, llvm::APFloat(scale)); + rewriter.create( + loc, /*dataflow = */ WEIGHT_STATIONARY, /*act = */ 0, /*shift = */ 0, + /*scale = */ llvm::APFloat((float)0), /*cStride = */ inputDilation, + /*aStride = */ stride >> downsample, + /*aTranspose = */ transInput3120, /*bTranspose*/ transWeight0132, + /*setOnlyStrides = */ false); + const int poolOutRowDim = + (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int poolOutColDim = + (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int dilatedInRowDim = + inRowDim + (inputDilation - 1) * (inRowDim - 1); + const int dilatedInColDim = + inColDim + (inputDilation - 1) * (inColDim - 1); + + int porowEnd = poolOutRowDim; + + for (int b = 0; b < batchSize; b += batches) { + for (int porow = 0; porow < porowEnd; porow += porows) { + const int orow = porow * poolStride - poolPadding; + for (int pocol = 0; pocol < poolOutColDim; pocol += pocols) { + const int ocol = pocol * poolStride - poolPadding; + for (int poch = 0; poch < outChannels; poch += pochs) { + for (int krow = 0; krow < kernelDim; krow += krows) { + const int orow_floored = orow < 0 ? 0 : orow; + + int irow = + orow_floored * stride + krow * kernelDilation - padding; + for (int kcol = 0; kcol < kernelDim; kcol += kcols) { + const int ocol_floored = ocol < 0 ? 0 : ocol; + int icol = + ocol_floored * stride + kcol * kernelDilation - padding; + + for (int kch = 0; kch < inChannels; kch += kchs) { + TypedAttr offsetAttr = rewriter.getI64IntegerAttr( + ((b * poolOutRowDim * poolOutColDim + + porow * poolOutColDim + pocol) * outChannels + poch) * sizeOfElemT); - offsetValue = + Value offsetValue = rewriter.create(loc, offsetAttr); - out = rewriter.create(tileConvOp.getLoc(), - rewriter.getI64Type(), - output, offsetValue); - } + Value out = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), output, + offsetValue); + if (transOutput1203) { + offsetAttr = rewriter.getI64IntegerAttr( + ((porow * poolOutColDim * batchSize + + pocol * batchSize + b) * + outChannels + + poch) * + sizeOfElemT); + offsetValue = + rewriter.create(loc, offsetAttr); + out = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), output, + offsetValue); + } - if (krow + krows < kernelDim || kcol + kcols < kernelDim || - kch + kchs < inChannels) { - out = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64IntegerAttr(0)); - } - Value pochValue = rewriter.create( - tileConvOp.getLoc(), - rewriter.getI64IntegerAttr(poch * sizeOfAccT)); - Value bias_ = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), bias, - pochValue); - if (krow > 0 || kcol > 0 || kch > 0) { - bias_ = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64IntegerAttr(0)); - } + if (krow + krows < kernelDim || kcol + kcols < kernelDim || + kch + kchs < inChannels) { + out = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64IntegerAttr(0)); + } + Value pochValue = rewriter.create( + tileConvOp.getLoc(), + rewriter.getI64IntegerAttr(poch * sizeOfAccT)); + Value bias_ = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), bias, + pochValue); + if (krow > 0 || kcol > 0 || kch > 0) { + bias_ = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64IntegerAttr(0)); + } - const int batches_ = - batchSize - b > batches ? batches : batchSize - b; - const int porows_ = poolOutRowDim - porow > porows - ? porows - : poolOutRowDim - porow; - const int pocols_ = poolOutColDim - pocol > pocols - ? pocols - : poolOutColDim - pocol; - const int pochs_ = - outChannels - poch > pochs ? pochs : outChannels - poch; - const int krows_ = - kernelDim - krow > krows ? krows : kernelDim - krow; - const int kcols_ = - kernelDim - kcol > kcols ? kcols : kernelDim - kcol; - const int kchs_ = - inChannels - kch > kchs ? kchs : inChannels - kch; - - const int ocols_ = pocols_ * poolStride + poolSize - 1; - const int orows_ = porows_ * poolStride + poolSize - 1; - - const int plpad = ocol < 0 ? -ocol : 0; - const int prpad = - ocol + ocols_ > outColDim ? ocol + ocols_ - outColDim : 0; - const int pupad = orow < 0 ? -orow : 0; - const int pdpad = - orow + orows_ > outRowDim ? orow + orows_ - outRowDim : 0; - - const int dilatedKrows_ = - krows_ + (kernelDilation - 1) * (krows_ - 1); - const int dilatedKcols_ = - kcols_ + (kernelDilation - 1) * (kcols_ - 1); - - const int icols_ = - (ocols_ - plpad - prpad) * stride + dilatedKcols_ - 1; - const int irows_ = - (orows_ - pupad - pdpad) * stride + dilatedKrows_ - 1; - - int lpad = icol < 0 ? -icol : 0; - int rpad = icol + icols_ > dilatedInColDim - ? icol + icols_ - dilatedInColDim - : 0; - int upad = irow < 0 ? -irow : 0; - int dpad = irow + irows_ > dilatedInRowDim - ? irow + irows_ - dilatedInRowDim - : 0; - - if (inputDilated) { - lpad += lpad == 0 && icol % 2 != 0; - rpad += rpad == 0 && (icol + icols_) % 2 != 1; - upad += upad == 0 && irow % 2 != 0; - dpad += dpad == 0 && (irow + irows_) % 2 != 1; - } + const int batches_ = + batchSize - b > batches ? batches : batchSize - b; + const int porows_ = poolOutRowDim - porow > porows + ? porows + : poolOutRowDim - porow; + const int pocols_ = poolOutColDim - pocol > pocols + ? pocols + : poolOutColDim - pocol; + const int pochs_ = + outChannels - poch > pochs ? pochs : outChannels - poch; + const int krows_ = + kernelDim - krow > krows ? krows : kernelDim - krow; + const int kcols_ = + kernelDim - kcol > kcols ? kcols : kernelDim - kcol; + const int kchs_ = + inChannels - kch > kchs ? kchs : inChannels - kch; + + const int ocols_ = pocols_ * poolStride + poolSize - 1; + const int orows_ = porows_ * poolStride + poolSize - 1; + + const int plpad = ocol < 0 ? -ocol : 0; + const int prpad = ocol + ocols_ > outColDim + ? ocol + ocols_ - outColDim + : 0; + const int pupad = orow < 0 ? -orow : 0; + const int pdpad = orow + orows_ > outRowDim + ? orow + orows_ - outRowDim + : 0; + + const int dilatedKrows_ = + krows_ + (kernelDilation - 1) * (krows_ - 1); + const int dilatedKcols_ = + kcols_ + (kernelDilation - 1) * (kcols_ - 1); + + const int icols_ = + (ocols_ - plpad - prpad) * stride + dilatedKcols_ - 1; + const int irows_ = + (orows_ - pupad - pdpad) * stride + dilatedKrows_ - 1; + + int lpad = icol < 0 ? -icol : 0; + int rpad = icol + icols_ > dilatedInColDim + ? icol + icols_ - dilatedInColDim + : 0; + int upad = irow < 0 ? -irow : 0; + int dpad = irow + irows_ > dilatedInRowDim + ? irow + irows_ - dilatedInRowDim + : 0; - int krow_ = krow; - int kcol_ = kcol; - if (wrot180) { - krow_ = kernelDim - krow - krows_; - kcol_ = kernelDim - kcol - kcols_; - } - offsetAttr = rewriter.getI64IntegerAttr( - ((krow_ * kernelDim * inChannels + kcol_ * inChannels + - kch) * - outChannels + - poch) * - sizeOfElemT); - offsetValue = rewriter.create( - tileConvOp.getLoc(), offsetAttr); - Value weightsSlice = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), weights, - offsetValue); - if (transWeight1203) { + if (inputDilated) { + lpad += lpad == 0 && icol % 2 != 0; + rpad += rpad == 0 && (icol + icols_) % 2 != 1; + upad += upad == 0 && irow % 2 != 0; + dpad += dpad == 0 && (irow + irows_) % 2 != 1; + } + + int krow_ = krow; + int kcol_ = kcol; + if (wrot180) { + krow_ = kernelDim - krow - krows_; + kcol_ = kernelDim - kcol - kcols_; + } offsetAttr = rewriter.getI64IntegerAttr( - ((kch * kernelDim * kernelDim + krow_ * kernelDim + - kcol_) * + ((krow_ * kernelDim * inChannels + kcol_ * inChannels + + kch) * outChannels + poch) * sizeOfElemT); offsetValue = rewriter.create( tileConvOp.getLoc(), offsetAttr); - weightsSlice = rewriter.create( + Value weightsSlice = rewriter.create( tileConvOp.getLoc(), rewriter.getI64Type(), weights, offsetValue); - } else if (transWeight0132) { + if (transWeight1203) { + offsetAttr = rewriter.getI64IntegerAttr( + ((kch * kernelDim * kernelDim + krow_ * kernelDim + + kcol_) * + outChannels + + poch) * + sizeOfElemT); + offsetValue = rewriter.create( + tileConvOp.getLoc(), offsetAttr); + weightsSlice = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), weights, + offsetValue); + } else if (transWeight0132) { + offsetAttr = rewriter.getI64IntegerAttr( + ((krow_ * kernelDim * outChannels + + kcol_ * outChannels + poch) * + inChannels + + kch) * + sizeOfElemT); + offsetValue = rewriter.create( + tileConvOp.getLoc(), offsetAttr); + weightsSlice = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), weights, + offsetValue); + } offsetAttr = rewriter.getI64IntegerAttr( - ((krow_ * kernelDim * outChannels + - kcol_ * outChannels + poch) * + ((b * inRowDim * inColDim + + ((irow + upad) >> inputDilated) * inColDim + + ((icol + lpad) >> inputDilated)) * inChannels + kch) * sizeOfElemT); offsetValue = rewriter.create( tileConvOp.getLoc(), offsetAttr); - weightsSlice = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), weights, + Value in = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), input, offsetValue); - } - offsetAttr = rewriter.getI64IntegerAttr( - ((b * inRowDim * inColDim + - ((irow + upad) >> inputDilated) * inColDim + - ((icol + lpad) >> inputDilated)) * - inChannels + - kch) * - sizeOfElemT); - offsetValue = rewriter.create( - tileConvOp.getLoc(), offsetAttr); - Value in = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), input, - offsetValue); - if (transInput3120) { - offsetAttr = rewriter.getI64IntegerAttr( - ((kch * inRowDim * inColDim + - ((irow + upad) >> inputDilated) * inColDim + - ((icol + lpad) >> inputDilated)) * - batchSize + - b) * - sizeOfElemT); - in = rewriter.create(tileConvOp.getLoc(), - rewriter.getI64Type(), - input, offsetValue); - } + if (transInput3120) { + offsetAttr = rewriter.getI64IntegerAttr( + ((kch * inRowDim * inColDim + + ((irow + upad) >> inputDilated) * inColDim + + ((icol + lpad) >> inputDilated)) * + batchSize + + b) * + sizeOfElemT); + in = rewriter.create(tileConvOp.getLoc(), + rewriter.getI64Type(), + input, offsetValue); + } - spTiledConv( - batchSize, inRowDim, inColDim, inChannels, outChannels, - outRowDim, outColDim, poolOutRowDim, poolOutColDim, - stride, padding, kernelDim, kernelDilation, inStride, - weightStride, outStride, poolSize, poolStride, - poolPadding, batches_, porows_, pocols_, pochs_, krows_, - kcols_, kchs_, lpad, rpad, upad, dpad, plpad, prpad, - pupad, pdpad, in, weightsSlice, out, bias_, act, scale, - wrot180, transOutput1203, transInput3120, transWeight1203, - transWeight0132, noBias, noPool, downsample, inputDilated, - false, tileConvOp, rewriter); + spTiledConv( + batchSize, inRowDim, inColDim, inChannels, outChannels, + outRowDim, outColDim, poolOutRowDim, poolOutColDim, + stride, padding, kernelDim, kernelDilation, inStride, + weightStride, outStride, poolSize, poolStride, + poolPadding, batches_, porows_, pocols_, pochs_, krows_, + kcols_, kchs_, lpad, rpad, upad, dpad, plpad, prpad, + pupad, pdpad, in, weightsSlice, out, bias_, act, scale, + wrot180, transOutput1203, transInput3120, + transWeight1203, transWeight0132, noBias, noPool, + downsample, inputDilated, false, tileConvOp, rewriter); + } } } } } } } + IntegerAttr flushAttr = rewriter.getI64IntegerAttr(0); + Value flushValue = rewriter.create( + loc, rewriter.getI64Type(), flushAttr); + rewriter.replaceOpWithNewOp(tileConvOp, flushValue, + flushValue); } - IntegerAttr flushAttr = rewriter.getI64IntegerAttr(0); - Value flushValue = rewriter.create( - loc, rewriter.getI64Type(), flushAttr); - rewriter.replaceOpWithNewOp(tileConvOp, flushValue, - flushValue); - } - int tiledConvTotalSpadRows(bool acc, int stride, int inputDilation, - int kernelDilation, bool downsample, - bool transWeight0132, bool transInput3120, - int batches, int porows, int pocols, int ochs, - int krows, int kcols, int kchs, int poolSize, - int poolStride) const { + int tiledConvTotalSpadRows( + bool acc, int stride, int inputDilation, int kernelDilation, + bool downsample, bool transWeight0132, bool transInput3120, int batches, + int porows, int pocols, int ochs, int krows, int kcols, int kchs, + int poolSize, int poolStride) const { - const int orows = porows * poolStride + poolSize - 1; - const int ocols = pocols * poolStride + poolSize - 1; + const int orows = porows * poolStride + poolSize - 1; + const int ocols = pocols * poolStride + poolSize - 1; - const int krowsDilated = krows + (kernelDilation - 1) * (krows - 1); - const int kcolsDilated = kcols + (kernelDilation - 1) * (kcols - 1); + const int krowsDilated = krows + (kernelDilation - 1) * (krows - 1); + const int kcolsDilated = kcols + (kernelDilation - 1) * (kcols - 1); - int irows = orows * stride + krowsDilated - 1; - int icols = ocols * stride + kcolsDilated - 1; - const int ichs = kchs; + int irows = orows * stride + krowsDilated - 1; + int icols = ocols * stride + kcolsDilated - 1; + const int ichs = kchs; - irows = irows / inputDilation + (irows % inputDilation != 0); - icols = icols / inputDilation + (icols % inputDilation != 0); + irows = irows / inputDilation + (irows % inputDilation != 0); + icols = icols / inputDilation + (icols % inputDilation != 0); - const int inChannelsPerBank = ichs / dim + (ichs % dim != 0); - const int outChannelsPerBank = ochs / dim + (ochs % dim != 0); - const int batchesPerBank = batches / dim + (batches % dim != 0); - - const int aRows = transInput3120 - ? (batchesPerBank * ichs * (irows >> downsample) * - (icols >> downsample)) - : (inChannelsPerBank * batches * - (irows >> downsample) * (icols >> downsample)); + const int inChannelsPerBank = ichs / dim + (ichs % dim != 0); + const int outChannelsPerBank = ochs / dim + (ochs % dim != 0); + const int batchesPerBank = batches / dim + (batches % dim != 0); - const int bRows = transWeight0132 - ? inChannelsPerBank * kcols * krows * ochs - : outChannelsPerBank * kcols * krows * kchs; + const int aRows = transInput3120 + ? (batchesPerBank * ichs * (irows >> downsample) * + (icols >> downsample)) + : (inChannelsPerBank * batches * + (irows >> downsample) * (icols >> downsample)); - const int cRows = outChannelsPerBank * batches * orows * ocols; + const int bRows = transWeight0132 + ? inChannelsPerBank * kcols * krows * ochs + : outChannelsPerBank * kcols * krows * kchs; - return acc ? cRows : aRows + bRows; - } + const int cRows = outChannelsPerBank * batches * orows * ocols; -public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - explicit GemminiTileConvLowering(LLVMTypeConverter &typeConverter, - int64_t dim, int64_t addrLen, - int64_t accRows, int64_t bankRows, - size_t sizeOfElemT, size_t sizeOfAccT) - : ConvertOpToLLVMPattern(typeConverter), dim(dim), addrLen(addrLen), - accRows(accRows), bankRows(bankRows), sizeOfElemT(sizeOfElemT), - sizeOfAccT(sizeOfAccT) {} - LogicalResult - matchAndRewrite(TileConvOp tileConvOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value input = tileConvOp.getInput(); - Value output = tileConvOp.getOutput(); - Value weights = tileConvOp.getWeights(); - Value bias = tileConvOp.getBias(); - MemRefType inputType = dyn_cast(input.getType()); - MemRefType biasType = dyn_cast(bias.getType()); - ArrayRef inputShape = inputType.getShape(); - ArrayRef biasShape = biasType.getShape(); - - Value outRowDimValue = tileConvOp.getOutRowDim(); - int outRowDim = getNumberFromValue(outRowDimValue); - Value outColDimValue = tileConvOp.getOutColDim(); - int outColDim = getNumberFromValue(outColDimValue); - Value kernelDimValue = tileConvOp.getKernelDim(); - int kernelDim = getNumberFromValue(kernelDimValue); - int batchSize = inputShape[0]; - int inRowDim = inputShape[1]; - int inColDim = inputShape[2]; - int inChannels = inputShape[3]; - int outChannels = biasShape[0]; - int stride = tileConvOp.getStride(); - int inputDilation = tileConvOp.getInputDilation(); - int kernelDilation = tileConvOp.getKernelDilation(); - int padding = tileConvOp.getPadding(); - int act = tileConvOp.getAct(); - float scale = tileConvOp.getScale().convertToFloat(); - int poolSize = tileConvOp.getPoolSize(); - int poolStride = tileConvOp.getPoolStride(); - int poolPadding = tileConvOp.getPoolPadding(); - bool wrot180 = tileConvOp.getWrot180(); - bool transOutput1203 = tileConvOp.getTransOutput1203(); - bool transInput3120 = tileConvOp.getTransInput3120(); - bool transWeight1203 = tileConvOp.getTransWeight1203(); - bool transWeight0132 = tileConvOp.getTransWeight0132(); - Location loc = tileConvOp.getLoc(); - IntegerType i64Type = rewriter.getI64Type(); - Value inputExtractOp = - rewriter.create(loc, input); - Value inputIndexCastOp = - rewriter.create(loc, i64Type, inputExtractOp); - Value outputExtractOp = - rewriter.create(loc, output); - Value outputIndexCastOp = - rewriter.create(loc, i64Type, outputExtractOp); - Value biasExtractOp = - rewriter.create(loc, bias); - Value biasIndexCastOp = - rewriter.create(loc, i64Type, biasExtractOp); - Value weightsExtractOp = - rewriter.create(loc, weights); - Value weightsIndexCastOp = - rewriter.create(loc, i64Type, weightsExtractOp); - const bool noPool = poolSize == 0; - if (noPool) { - poolSize = 1; - poolStride = 1; - poolPadding = 0; + return acc ? cRows : aRows + bRows; } - const int poolOutRowDim = - (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; - const int poolOutColDim = - (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; - const bool downsample = stride == 2 && kernelDim == 1 && padding == 0 && - noPool && inRowDim % 2 == 0 && inColDim % 2 == 0; - int args[] = {batchSize, poolOutRowDim, poolOutColDim, outChannels, - kernelDim, kernelDim, inChannels}; - const int maxArgs[] = {batchSize, poolOutRowDim, poolOutColDim, outChannels, - kernelDim, kernelDim, inChannels}; - const int orowsIdx = 1; - const int ocolsIdx = 2; - const int outChannelsIdx = 3; - const int inChannelsIdx = 6; - const int maxSpadRows = (BANK_NUM * bankRows / 2); - const int maxAccRows = (accRows / 2); - int spadRows = tiledConvTotalSpadRows( - false, stride, inputDilation, kernelDilation, downsample, - transWeight0132, transInput3120, args[0], args[1], args[2], args[3], - args[4], args[5], args[6], poolSize, poolStride); - int accRows = tiledConvTotalSpadRows( - true, stride, inputDilation, kernelDilation, downsample, - transWeight0132, transInput3120, args[0], args[1], args[2], args[3], - args[4], args[5], args[6], poolSize, poolStride); - while (spadRows > maxSpadRows || accRows > maxAccRows) { - int maxVal = -1; - int maxIdx = -1; - for (size_t i = 0; i < sizeof(args) / sizeof(args[0]); i++) { - if (!(i == ocolsIdx && args[i] <= dim && args[orowsIdx] > 1) && - args[i] > maxVal) { - maxVal = args[i]; - maxIdx = i; - } - } - if (maxIdx == outChannelsIdx || maxIdx == inChannelsIdx) { - if (args[maxIdx] % dim != 0) { - args[maxIdx] = (args[maxIdx] / dim) * dim; - } else { - args[maxIdx] -= dim; - } - args[maxIdx] = args[maxIdx] == 0 ? 1 : args[maxIdx]; - } else { - args[maxIdx]--; + public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + explicit GemminiTileConvLowering(LLVMTypeConverter & typeConverter, + int64_t dim, int64_t addrLen, + int64_t accRows, int64_t bankRows, + size_t sizeOfElemT, size_t sizeOfAccT) + : ConvertOpToLLVMPattern(typeConverter), dim(dim), addrLen(addrLen), + accRows(accRows), bankRows(bankRows), sizeOfElemT(sizeOfElemT), + sizeOfAccT(sizeOfAccT) {} + LogicalResult matchAndRewrite(TileConvOp tileConvOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) + const override { + Value input = tileConvOp.getInput(); + Value output = tileConvOp.getOutput(); + Value weights = tileConvOp.getWeights(); + Value bias = tileConvOp.getBias(); + MemRefType inputType = dyn_cast(input.getType()); + MemRefType biasType = dyn_cast(bias.getType()); + ArrayRef inputShape = inputType.getShape(); + ArrayRef biasShape = biasType.getShape(); + + Value outRowDimValue = tileConvOp.getOutRowDim(); + int outRowDim = getNumberFromValue(outRowDimValue); + Value outColDimValue = tileConvOp.getOutColDim(); + int outColDim = getNumberFromValue(outColDimValue); + Value kernelDimValue = tileConvOp.getKernelDim(); + int kernelDim = getNumberFromValue(kernelDimValue); + int batchSize = inputShape[0]; + int inRowDim = inputShape[1]; + int inColDim = inputShape[2]; + int inChannels = inputShape[3]; + int outChannels = biasShape[0]; + int stride = tileConvOp.getStride(); + int inputDilation = tileConvOp.getInputDilation(); + int kernelDilation = tileConvOp.getKernelDilation(); + int padding = tileConvOp.getPadding(); + int act = tileConvOp.getAct(); + float scale = tileConvOp.getScale().convertToFloat(); + int poolSize = tileConvOp.getPoolSize(); + int poolStride = tileConvOp.getPoolStride(); + int poolPadding = tileConvOp.getPoolPadding(); + bool wrot180 = tileConvOp.getWrot180(); + bool transOutput1203 = tileConvOp.getTransOutput1203(); + bool transInput3120 = tileConvOp.getTransInput3120(); + bool transWeight1203 = tileConvOp.getTransWeight1203(); + bool transWeight0132 = tileConvOp.getTransWeight0132(); + Location loc = tileConvOp.getLoc(); + IntegerType i64Type = rewriter.getI64Type(); + Value inputExtractOp = + rewriter.create(loc, input); + Value inputIndexCastOp = + rewriter.create(loc, i64Type, inputExtractOp); + Value outputExtractOp = + rewriter.create(loc, output); + Value outputIndexCastOp = + rewriter.create(loc, i64Type, outputExtractOp); + Value biasExtractOp = + rewriter.create(loc, bias); + Value biasIndexCastOp = + rewriter.create(loc, i64Type, biasExtractOp); + Value weightsExtractOp = + rewriter.create(loc, weights); + Value weightsIndexCastOp = + rewriter.create(loc, i64Type, weightsExtractOp); + const bool noPool = poolSize == 0; + if (noPool) { + poolSize = 1; + poolStride = 1; + poolPadding = 0; } - spadRows = tiledConvTotalSpadRows( + const int poolOutRowDim = + (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int poolOutColDim = + (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; + const bool downsample = stride == 2 && kernelDim == 1 && padding == 0 && + noPool && inRowDim % 2 == 0 && inColDim % 2 == 0; + int args[] = {batchSize, poolOutRowDim, poolOutColDim, outChannels, + kernelDim, kernelDim, inChannels}; + const int maxArgs[] = {batchSize, poolOutRowDim, poolOutColDim, + outChannels, kernelDim, kernelDim, + inChannels}; + const int orowsIdx = 1; + const int ocolsIdx = 2; + const int outChannelsIdx = 3; + const int inChannelsIdx = 6; + const int maxSpadRows = (BANK_NUM * bankRows / 2); + const int maxAccRows = (accRows / 2); + int spadRows = tiledConvTotalSpadRows( false, stride, inputDilation, kernelDilation, downsample, transWeight0132, transInput3120, args[0], args[1], args[2], args[3], args[4], args[5], args[6], poolSize, poolStride); - accRows = tiledConvTotalSpadRows( + int accRows = tiledConvTotalSpadRows( true, stride, inputDilation, kernelDilation, downsample, transWeight0132, transInput3120, args[0], args[1], args[2], args[3], args[4], args[5], args[6], poolSize, poolStride); - } - bool notIncreased = false; - while (!notIncreased) { - notIncreased = true; - - int argsCandidate[] = {args[0], args[1], args[2], args[3], - args[4], args[5], args[6]}; - argsCandidate[ocolsIdx]++; - - if (argsCandidate[ocolsIdx] > maxArgs[ocolsIdx]) - continue; - - spadRows = tiledConvTotalSpadRows( - false, stride, inputDilation, kernelDilation, downsample, - transWeight0132, transInput3120, argsCandidate[0], argsCandidate[1], - argsCandidate[2], argsCandidate[3], argsCandidate[4], - argsCandidate[5], argsCandidate[6], poolSize, poolStride); - accRows = tiledConvTotalSpadRows( - true, stride, inputDilation, kernelDilation, downsample, - transWeight0132, transInput3120, argsCandidate[0], argsCandidate[1], - argsCandidate[2], argsCandidate[3], argsCandidate[4], - argsCandidate[5], argsCandidate[6], poolSize, poolStride); + while (spadRows > maxSpadRows || accRows > maxAccRows) { + int maxVal = -1; + int maxIdx = -1; + for (size_t i = 0; i < sizeof(args) / sizeof(args[0]); i++) { + if (!(i == ocolsIdx && args[i] <= dim && args[orowsIdx] > 1) && + args[i] > maxVal) { + maxVal = args[i]; + maxIdx = i; + } + } - if (spadRows <= maxSpadRows && accRows <= maxAccRows) { - args[ocolsIdx] = argsCandidate[ocolsIdx]; - notIncreased = false; + if (maxIdx == outChannelsIdx || maxIdx == inChannelsIdx) { + if (args[maxIdx] % dim != 0) { + args[maxIdx] = (args[maxIdx] / dim) * dim; + } else { + args[maxIdx] -= dim; + } + args[maxIdx] = args[maxIdx] == 0 ? 1 : args[maxIdx]; + } else { + args[maxIdx]--; + } + spadRows = tiledConvTotalSpadRows( + false, stride, inputDilation, kernelDilation, downsample, + transWeight0132, transInput3120, args[0], args[1], args[2], args[3], + args[4], args[5], args[6], poolSize, poolStride); + accRows = tiledConvTotalSpadRows( + true, stride, inputDilation, kernelDilation, downsample, + transWeight0132, transInput3120, args[0], args[1], args[2], args[3], + args[4], args[5], args[6], poolSize, poolStride); } - } + bool notIncreased = false; + while (!notIncreased) { + notIncreased = true; - bool nothingIncreased = false; - while (!nothingIncreased) { - nothingIncreased = true; - for (size_t i = 0; i < sizeof(args) / sizeof(args[0]); i++) { int argsCandidate[] = {args[0], args[1], args[2], args[3], args[4], args[5], args[6]}; - argsCandidate[i]++; + argsCandidate[ocolsIdx]++; - if (argsCandidate[i] > maxArgs[i]) + if (argsCandidate[ocolsIdx] > maxArgs[ocolsIdx]) continue; + spadRows = tiledConvTotalSpadRows( false, stride, inputDilation, kernelDilation, downsample, transWeight0132, transInput3120, argsCandidate[0], argsCandidate[1], @@ -2044,82 +2058,113 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { argsCandidate[5], argsCandidate[6], poolSize, poolStride); if (spadRows <= maxSpadRows && accRows <= maxAccRows) { - args[i] = argsCandidate[i]; - nothingIncreased = false; + args[ocolsIdx] = argsCandidate[ocolsIdx]; + notIncreased = false; } } + + bool nothingIncreased = false; + while (!nothingIncreased) { + nothingIncreased = true; + for (size_t i = 0; i < sizeof(args) / sizeof(args[0]); i++) { + int argsCandidate[] = {args[0], args[1], args[2], args[3], + args[4], args[5], args[6]}; + argsCandidate[i]++; + + if (argsCandidate[i] > maxArgs[i]) + continue; + spadRows = tiledConvTotalSpadRows( + false, stride, inputDilation, kernelDilation, downsample, + transWeight0132, transInput3120, argsCandidate[0], + argsCandidate[1], argsCandidate[2], argsCandidate[3], + argsCandidate[4], argsCandidate[5], argsCandidate[6], poolSize, + poolStride); + accRows = tiledConvTotalSpadRows( + true, stride, inputDilation, kernelDilation, downsample, + transWeight0132, transInput3120, argsCandidate[0], + argsCandidate[1], argsCandidate[2], argsCandidate[3], + argsCandidate[4], argsCandidate[5], argsCandidate[6], poolSize, + poolStride); + + if (spadRows <= maxSpadRows && accRows <= maxAccRows) { + args[i] = argsCandidate[i]; + nothingIncreased = false; + } + } + } + const int batches = args[0]; + const int orows = args[1]; + const int ocols = args[2]; + const int ochs = args[3]; + const int krows = args[4]; + const int kcols = args[5]; + const int kchs = args[6]; + + const int inStride = inChannels; + const int outStride = outChannels; + const int weightStride = outChannels; + tiledConv(batchSize, inRowDim, inColDim, inChannels, outChannels, + outRowDim, outColDim, stride, inputDilation, kernelDilation, + padding, kernelDim, inStride, weightStride, outStride, wrot180, + transOutput1203, transInput3120, transWeight1203, + transWeight0132, batches, orows, ocols, ochs, krows, kcols, + kchs, inputIndexCastOp, weightsIndexCastOp, biasIndexCastOp, + outputIndexCastOp, act, scale, poolSize, + noPool ? 0 : poolStride, poolPadding, tileConvOp, rewriter); + return success(); } - const int batches = args[0]; - const int orows = args[1]; - const int ocols = args[2]; - const int ochs = args[3]; - const int krows = args[4]; - const int kcols = args[5]; - const int kchs = args[6]; - - const int inStride = inChannels; - const int outStride = outChannels; - const int weightStride = outChannels; - tiledConv(batchSize, inRowDim, inColDim, inChannels, outChannels, outRowDim, - outColDim, stride, inputDilation, kernelDilation, padding, - kernelDim, inStride, weightStride, outStride, wrot180, - transOutput1203, transInput3120, transWeight1203, transWeight0132, - batches, orows, ocols, ochs, krows, kcols, kchs, inputIndexCastOp, - weightsIndexCastOp, biasIndexCastOp, outputIndexCastOp, act, - scale, poolSize, noPool ? 0 : poolStride, poolPadding, tileConvOp, - rewriter); - return success(); - } -private: - int64_t dim; - int64_t addrLen; - int64_t accRows; - int64_t bankRows; - size_t sizeOfElemT; - size_t sizeOfAccT; -}; + private: + int64_t dim; + int64_t addrLen; + int64_t accRows; + int64_t bankRows; + size_t sizeOfElemT; + size_t sizeOfAccT; + }; -void mlir::populateGemminiLegalizeForLLVMExportPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns, int64_t dim, - int64_t addrLen, int64_t accRows, int64_t bankRows, size_t sizeOfElemT, - size_t sizeOfAccT) { - patterns - .add, ForwardOperands, - ForwardOperands>(converter, &converter.getContext()); - patterns.add(converter); - patterns.add(converter); - patterns.add(converter, dim); - patterns.add(converter, addrLen); - patterns.add(converter, addrLen); - patterns.add(converter, addrLen); - patterns.add(converter, addrLen); - patterns.add(converter); - patterns.add(converter); - patterns.add(converter, dim, addrLen); - patterns.add(converter, addrLen); - patterns.add(converter, addrLen); - patterns.add(converter, addrLen); - patterns.add(converter, dim, addrLen, accRows, + void mlir::populateGemminiLegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns, int64_t dim, + int64_t addrLen, int64_t accRows, int64_t bankRows, size_t sizeOfElemT, + size_t sizeOfAccT) { + patterns.add, + ForwardOperands, + ForwardOperands>(converter, + &converter.getContext()); + patterns.add(converter); + patterns.add(converter); + patterns.add(converter, dim); + patterns.add(converter, addrLen); + patterns.add(converter, addrLen); + patterns.add(converter, addrLen); + patterns.add(converter, addrLen); + patterns.add(converter); + patterns.add(converter); + patterns.add(converter, dim, addrLen); + patterns.add(converter, addrLen); + patterns.add(converter, addrLen); + patterns.add(converter, addrLen); + patterns.add(converter, dim, addrLen, accRows, + bankRows, sizeOfElemT, sizeOfAccT); + patterns.add(converter, dim, addrLen, accRows, bankRows, sizeOfElemT, sizeOfAccT); - patterns.add(converter, dim, addrLen, accRows, - bankRows, sizeOfElemT, sizeOfAccT); -} + } -void mlir::configureGemminiLegalizeForExportTarget( - LLVMConversionTarget &target) { - target.addLegalOp< - Flush_IntrOp, ConfigSt_IntrOp, ConifgLd_IntrOp, ConfigEX_IntrOp, - Mvin_IntrOp, Mvin2_IntrOp, Mvin3_IntrOp, Mvout_IntrOp, Preload_IntrOp, - ComputePreloaded_IntrOp, ComputeAccumulated_IntrOp, - LoopWsConfigBounds_IntrOp, LoopWsConfigAddrsAB_IntrOp, - LoopWsConfigAddrsDC_IntrOp, LoopWsConfigStridesAB_IntrOp, - LoopWsConfigStridesDC_IntrOp, LoopWs_IntrOp, LoopConvWsConfig1_IntrOp, - LoopConvWsConfig2_IntrOp, LoopConvWsConfig3_IntrOp, - LoopConvWsConfig4_IntrOp, LoopConvWsConfig5_IntrOp, - LoopConvWsConfig6_IntrOp, LoopConvWs_IntrOp, ConfigNorm_IntrOp>(); - target.addIllegalOp(); -} + void + mlir::configureGemminiLegalizeForExportTarget(LLVMConversionTarget &target) { + target.addLegalOp< + Flush_IntrOp, ConfigSt_IntrOp, ConifgLd_IntrOp, ConfigEX_IntrOp, + Mvin_IntrOp, Mvin2_IntrOp, Mvin3_IntrOp, Mvout_IntrOp, Preload_IntrOp, + ComputePreloaded_IntrOp, ComputeAccumulated_IntrOp, + LoopWsConfigBounds_IntrOp, LoopWsConfigAddrsAB_IntrOp, + LoopWsConfigAddrsDC_IntrOp, LoopWsConfigStridesAB_IntrOp, + LoopWsConfigStridesDC_IntrOp, LoopWs_IntrOp, LoopConvWsConfig1_IntrOp, + LoopConvWsConfig2_IntrOp, LoopConvWsConfig3_IntrOp, + LoopConvWsConfig4_IntrOp, LoopConvWsConfig5_IntrOp, + LoopConvWsConfig6_IntrOp, LoopConvWs_IntrOp, ConfigNorm_IntrOp>(); + target.addIllegalOp(); + } From 23a6260d429cdc29db8394750c2003e338f06e7c Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Sat, 22 Feb 2025 10:39:55 +0000 Subject: [PATCH 02/13] [examples] Add index cast example with vector type. --- examples/MLIRVector/makefile | 24 ++++++++++++++++++++++ examples/MLIRVector/vector-index-cast.mlir | 22 ++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 examples/MLIRVector/vector-index-cast.mlir diff --git a/examples/MLIRVector/makefile b/examples/MLIRVector/makefile index a3acbc74ab..327dfbfaad 100644 --- a/examples/MLIRVector/makefile +++ b/examples/MLIRVector/makefile @@ -745,3 +745,27 @@ vector-iteration-run: --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +vector-index-cast-lower: + @${MLIR_OPT} ./vector-index-cast.mlir \ + --lower-affine \ + -convert-vector-to-scf -convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts -o ./log.mlir + +vector-index-cast-translate: + @${MLIR_OPT} ./vector-index-cast.mlir \ + --lower-affine \ + -convert-vector-to-scf -convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +vector-index-cast-run: + @${MLIR_OPT} ./vector-index-cast.mlir \ + --lower-affine \ + -convert-vector-to-scf -convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/examples/MLIRVector/vector-index-cast.mlir b/examples/MLIRVector/vector-index-cast.mlir new file mode 100644 index 0000000000..16bb3d221f --- /dev/null +++ b/examples/MLIRVector/vector-index-cast.mlir @@ -0,0 +1,22 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-vector-to-scf -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=i32 \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +func.func @main() -> i32 { + + %v0 = arith.constant dense<[0, 0, 1, 1]> : vector<4xindex> + // CHECK: ( 0, 0, 1, 1 ) + vector.print %v0 : vector<4xindex> + + %v1 = arith.index_cast %v0 : vector<4xindex> to vector<4xi1> + // CHECK: ( 0, 0, 1, 1 ) + vector.print %v1 : vector<4xi1> + + %ret = arith.constant 0 : i32 + return %ret : i32 +} From a52ec6a8a552292d890e699b7e72ce0448d4c5ba Mon Sep 17 00:00:00 2001 From: Wu Xintong <13683168028@163.com> Date: Tue, 25 Feb 2025 19:48:03 +0800 Subject: [PATCH 03/13] [examples] Fix the issue with subgraph names in Stable Diffusion. (#466) --- .../import-stable-diffusion.py | 72 +++++++++---------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/examples/BuddyStableDiffusion/import-stable-diffusion.py b/examples/BuddyStableDiffusion/import-stable-diffusion.py index ee76d33811..09f6a96bab 100644 --- a/examples/BuddyStableDiffusion/import-stable-diffusion.py +++ b/examples/BuddyStableDiffusion/import-stable-diffusion.py @@ -29,11 +29,15 @@ from buddy.compiler.frontend import DynamoCompiler from buddy.compiler.graph import GraphDriver from buddy.compiler.graph.transform import simply_fuse +from buddy.compiler.graph.type import DeviceType from buddy.compiler.ops import tosa +from buddy.compiler.graph.operation import * from diffusers import StableDiffusionPipeline # Parse command-line arguments for output directory -parser = argparse.ArgumentParser(description="Stable Diffusion model AOT importer") +parser = argparse.ArgumentParser( + description="Stable Diffusion model AOT importer" +) parser.add_argument( "--output-dir", type=str, @@ -112,54 +116,46 @@ params_unet = dynamo_compiler_unet.imported_params[graph_unet] params_vae = dynamo_compiler_vae.imported_params[graph_vae] -pattern_list = [simply_fuse] - -graphs_text_encoder[0].fuse_ops(pattern_list) -graphs_unet[0].fuse_ops(pattern_list) -graphs_vae[0].fuse_ops(pattern_list) +group_text_encoder = [] +for op in graph_text_encoder.body: + if isinstance(op, PlaceholderOp) or isinstance(op, OutputOp): + continue + group_text_encoder.append(op) +graph_text_encoder.op_groups["subgraph0_text_encoder"] = group_text_encoder +graph_text_encoder.group_map_device["subgraph0_text_encoder"] = DeviceType.CPU + +group_unet = [] +for op in graph_unet.body: + if isinstance(op, PlaceholderOp) or isinstance(op, OutputOp): + continue + group_unet.append(op) +graph_unet.op_groups["subgraph0_unet"] = group_unet +graph_unet.group_map_device["subgraph0_unet"] = DeviceType.CPU + +group_vae = [] +for op in graph_vae.body: + if isinstance(op, PlaceholderOp) or isinstance(op, OutputOp): + continue + group_vae.append(op) +graph_vae.op_groups["subgraph0_vae"] = group_vae +graph_vae.group_map_device["subgraph0_vae"] = DeviceType.CPU driver_text_encoder = GraphDriver(graphs_text_encoder[0]) driver_unet = GraphDriver(graphs_unet[0]) driver_vae = GraphDriver(graphs_vae[0]) -driver_text_encoder._subgraphs[ - "subgraph0_text_encoder" -] = driver_text_encoder._subgraphs.pop("subgraph0") -driver_text_encoder._subgraphs_inputs[ - "subgraph0_text_encoder" -] = driver_text_encoder._subgraphs_inputs.pop("subgraph0") -driver_text_encoder._subgraphs_outputs[ - "subgraph0_text_encoder" -] = driver_text_encoder._subgraphs_outputs.pop("subgraph0") -driver_unet._subgraphs["subgraph0_unet"] = driver_unet._subgraphs.pop( - "subgraph0" -) -driver_unet._subgraphs_inputs[ - "subgraph0_unet" -] = driver_unet._subgraphs_inputs.pop("subgraph0") -driver_unet._subgraphs_outputs[ - "subgraph0_unet" -] = driver_unet._subgraphs_outputs.pop("subgraph0") -driver_vae._subgraphs["subgraph0_vae"] = driver_vae._subgraphs.pop("subgraph0") -driver_vae._subgraphs_inputs[ - "subgraph0_vae" -] = driver_vae._subgraphs_inputs.pop("subgraph0") -driver_vae._subgraphs_outputs[ - "subgraph0_vae" -] = driver_vae._subgraphs_outputs.pop("subgraph0") - -driver_text_encoder.subgraphs[0]._func_name = "subgraph0_text_encoder" -driver_unet.subgraphs[0]._func_name = "subgraph0_unet" -driver_vae.subgraphs[0]._func_name = "subgraph0_vae" - driver_text_encoder.subgraphs[0].lower_to_top_level_ir() driver_unet.subgraphs[0].lower_to_top_level_ir() driver_vae.subgraphs[0].lower_to_top_level_ir() # Save output files to specified directory -with open(os.path.join(output_dir, "subgraph0_text_encoder.mlir"), "w") as module_file: +with open( + os.path.join(output_dir, "subgraph0_text_encoder.mlir"), "w" +) as module_file: print(driver_text_encoder.subgraphs[0]._imported_module, file=module_file) -with open(os.path.join(output_dir, "forward_text_encoder.mlir"), "w") as module_file: +with open( + os.path.join(output_dir, "forward_text_encoder.mlir"), "w" +) as module_file: print(driver_text_encoder.construct_main_graph(True), file=module_file) with open(os.path.join(output_dir, "subgraph0_unet.mlir"), "w") as module_file: From 152cb3d523c36a7a99a7c08760c25d8d01c4ffe6 Mon Sep 17 00:00:00 2001 From: Junyi Mei Date: Tue, 25 Feb 2025 19:49:40 +0800 Subject: [PATCH 04/13] [Test] Add test cases for python binding (#464) * [Test] Add testcases for python binding Signed-off-by: Junyi Mei * [Test] Add testcase for conv-vectorization Signed-off-by: Junyi Mei * [Test] More detailed check for conv-vectorization Signed-off-by: Junyi Mei * [Test] Add testcase for matmul-vectorization Signed-off-by: Junyi Mei * [Test] Verify module ops in python binding Signed-off-by: Junyi Mei --------- Signed-off-by: Junyi Mei --- tests/Python/dialects/test_dap_fir.py | 86 +++++++ tests/Python/dialects/test_dap_iir.py | 101 ++++++++ tests/Python/dialects/test_dip_corr2d.py | 184 ++++++++++++++ tests/Python/dialects/test_dip_resize.py | 309 +++++++++++++++++++++++ tests/Python/dialects/test_dip_rotate.py | 80 ++++++ tests/Python/dialects/test_gemmini.py | 72 ++++++ tests/Python/dialects/test_rvv.py | 81 ++++++ tests/Python/passes/conv2d.py | 91 +++++++ tests/Python/passes/matmul.py | 77 ++++++ tests/Python/test_python.py | 49 ---- 10 files changed, 1081 insertions(+), 49 deletions(-) create mode 100644 tests/Python/dialects/test_dap_fir.py create mode 100644 tests/Python/dialects/test_dap_iir.py create mode 100644 tests/Python/dialects/test_dip_corr2d.py create mode 100644 tests/Python/dialects/test_dip_resize.py create mode 100644 tests/Python/dialects/test_dip_rotate.py create mode 100644 tests/Python/dialects/test_gemmini.py create mode 100644 tests/Python/dialects/test_rvv.py create mode 100644 tests/Python/passes/conv2d.py create mode 100644 tests/Python/passes/matmul.py delete mode 100644 tests/Python/test_python.py diff --git a/tests/Python/dialects/test_dap_fir.py b/tests/Python/dialects/test_dap_fir.py new file mode 100644 index 0000000000..df34c87409 --- /dev/null +++ b/tests/Python/dialects/test_dap_fir.py @@ -0,0 +1,86 @@ +# RUN: %PYTHON %s | FileCheck %s + +from buddy_mlir.dialects import dap, func +from buddy_mlir import ir + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +def dapFir(dtype, context: ir.Context) -> ir.Module: + with context, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + memref = ir.MemRefType.get( + [ir.ShapedType.get_dynamic_size()], dtype + ) + + @func.FuncOp.from_py_func(memref, memref, memref) + def buddy_fir(in_, filter, out): + dap.fir(in_, filter, out) + return + + return module + + +# CHECK-LABEL: TEST: testDapFirF32 +@run +def testDapFirF32(): + with ir.Context() as context: + module = dapFir(ir.F32Type.get(), context) + module.operation.verify() + # CHECK: dap.fir {{.*}} : memref, memref, memref + print(module) + + +# CHECK-LABEL TEST: testDapFirF64 +@run +def testDapFirF64(): + with ir.Context() as context: + module = dapFir(ir.F64Type.get(), context) + module.operation.verify() + # CHECK: dap.fir {{.*}} : memref, memref, memref + print(module) + + +# CHECK-LABEL: TEST: testDapFirI8 +@run +def testDapFirI8(): + with ir.Context() as context: + module = dapFir(ir.IntegerType.get_signless(8), context) + module.operation.verify() + # CHECK: dap.fir {{.*}} : memref, memref, memref + print(module) + + +# CHECK-LABEL: TEST: testDapFirI16 +@run +def testDapFirI16(): + with ir.Context() as context: + module = dapFir(ir.IntegerType.get_signless(16), context) + module.operation.verify() + # CHECK: dap.fir {{.*}} : memref, memref, memref + print(module) + + +# CHECK-LABEL: TEST: testDapFirI32 +@run +def testDapFirI32(): + with ir.Context() as context: + module = dapFir(ir.IntegerType.get_signless(32), context) + module.operation.verify() + # CHECK: dap.fir {{.*}} : memref, memref, memref + print(module) + + +# CHECK-LABEL: TEST: testDapFirI64 +@run +def testDapFirI64(): + with ir.Context() as context: + module = dapFir(ir.IntegerType.get_signless(64), context) + module.operation.verify() + # CHECK: dap.fir {{.*}} : memref, memref, memref + print(module) diff --git a/tests/Python/dialects/test_dap_iir.py b/tests/Python/dialects/test_dap_iir.py new file mode 100644 index 0000000000..5e9f00a47e --- /dev/null +++ b/tests/Python/dialects/test_dap_iir.py @@ -0,0 +1,101 @@ +# RUN: %PYTHON %s | FileCheck %s + +from buddy_mlir.dialects import dap, func +from buddy_mlir import ir +from buddy_mlir.passmanager import PassManager + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +def dapIir(dtype, context: ir.Context) -> ir.Module: + with context, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + memref1d = ir.MemRefType.get( + [ir.ShapedType.get_dynamic_size()], dtype + ) + + memref2d = ir.MemRefType.get( + [ + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ], + dtype, + ) + + @func.FuncOp.from_py_func(memref1d, memref2d, memref1d) + def buddy_iir(in_, filter, out): + dap.iir(in_, filter, out) + return + + return module + + +# CHECK-LABEL: TEST: testDapIirF32 +@run +def testDapIirF32(): + with ir.Context() as context: + module = dapIir(ir.F32Type.get(), context) + module.operation.verify() + # CHECK: dap.iir {{.*}} : memref, memref, memref + print(module) + + pm = PassManager("builtin.module") + pm.add("lower-dap") + pm.run(module.operation) + + print(module) + + +# CHECK-LABEL TEST: testDapIirF64 +@run +def testDapIirF64(): + with ir.Context() as context: + module = dapIir(ir.F64Type.get(), context) + module.operation.verify() + # CHECK: dap.iir {{.*}} : memref, memref, memref + print(module) + + +# CHECK-LABEL: TEST: testDapIirI8 +@run +def testDapIirI8(): + with ir.Context() as context: + module = dapIir(ir.IntegerType.get_signless(8), context) + module.operation.verify() + # CHECK: dap.iir {{.*}} : memref, memref, memref + print(module) + + +# CHECK-LABEL: TEST: testDapIirI16 +@run +def testDapIirI16(): + with ir.Context() as context: + module = dapIir(ir.IntegerType.get_signless(16), context) + module.operation.verify() + # CHECK: dap.iir {{.*}} : memref, memref, memref + print(module) + + +# CHECK-LABEL: TEST: testDapIirI32 +@run +def testDapIirI32(): + with ir.Context() as context: + module = dapIir(ir.IntegerType.get_signless(32), context) + module.operation.verify() + # CHECK: dap.iir {{.*}} : memref, memref, memref + print(module) + + +# CHECK-LABEL: TEST: testDapIirI64 +@run +def testDapIirI64(): + with ir.Context() as context: + module = dapIir(ir.IntegerType.get_signless(64), context) + module.operation.verify() + # CHECK: dap.iir {{.*}} : memref, memref, memref + print(module) diff --git a/tests/Python/dialects/test_dip_corr2d.py b/tests/Python/dialects/test_dip_corr2d.py new file mode 100644 index 0000000000..c74b4f1322 --- /dev/null +++ b/tests/Python/dialects/test_dip_corr2d.py @@ -0,0 +1,184 @@ +# RUN: %PYTHON %s | FileCheck %s + +from buddy_mlir.dialects import dip, func +from buddy_mlir import ir + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +def dipCorr2D( + dtype, boundary_option: ir.Attribute, context: ir.Context +) -> ir.Module: + with context, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + index = ir.IndexType.get() + memref = ir.MemRefType.get( + [ + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ], + dtype, + ) + + @func.FuncOp.from_py_func( + memref, memref, memref, index, index, dtype + ) + def buddy_corr2d( + input, identity, output, kernelAnchorX, kernelAnchorY, c + ): + dip.corr_2d( + input, + identity, + output, + kernelAnchorX, + kernelAnchorY, + c, + boundary_option, + ) + + return module + + +# CHECK-LABEL: TEST: testDipCorr2DConstantPaddingF32 +@run +def testDipCorr2DConstantPaddingF32(): + with ir.Context() as context: + module = dipCorr2D( + ir.F32Type.get(), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d {{.*}} : memref, memref, memref, index, index, f32 + print(module) + + +# CHECK-LABEL: TEST: testDipCorr2DConstantPaddingF64 +@run +def testDipCorr2DConstantPaddingF64(): + with ir.Context() as context: + module = dipCorr2D( + ir.F64Type.get(), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d {{.*}} : memref, memref, memref, index, index, f64 + print(module) + + +# CHECK-LABEL: TEST: testDipCorr2DConstantPaddingI8 +@run +def testDipCorr2DConstantPaddingI8(): + with ir.Context() as context: + module = dipCorr2D( + ir.IntegerType.get_signless(8), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d {{.*}} : memref, memref, memref, index, index, i8 + print(module) + + +# CHECK-LABEL: TEST: testDipCorr2DConstantPaddingI32 +@run +def testDipCorr2DConstantPaddingI32(): + with ir.Context() as context: + module = dipCorr2D( + ir.IntegerType.get_signless(32), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d {{.*}} : memref, memref, memref, index, index, i32 + print(module) + + +# CHECK-LABEL: TEST: testDipCorr2DConstantPaddingI64 +@run +def testDipCorr2DConstantPaddingI64(): + with ir.Context() as context: + module = dipCorr2D( + ir.IntegerType.get_signless(64), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d , memref, memref, index, index, i64 + print(module) + + +# CHECK-LABEL: TEST: testDipCorr2DReplicatePaddingF32 +@run +def testDipCorr2DReplicatePaddingF32(): + with ir.Context() as context: + module = dipCorr2D( + ir.F32Type.get(), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d {{.*}} : memref, memref, memref, index, index, f32 + print(module) + + +# CHECK-LABEL: TEST: testDipCorr2DReplicatePaddingF64 +@run +def testDipCorr2DReplicatePaddingF64(): + with ir.Context() as context: + module = dipCorr2D( + ir.F64Type.get(), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d {{.*}} : memref, memref, memref, index, index, f64 + print(module) + + +# CHECK-LABEL: TEST: testDipCorr2DReplicatePaddingI8 +@run +def testDipCorr2DReplicatePaddingI8(): + with ir.Context() as context: + module = dipCorr2D( + ir.IntegerType.get_signless(8), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d {{.*}} : memref, memref, memref, index, index, i8 + print(module) + + +# CHECK-LABEL: TEST: testDipCorr2DReplicatePaddingI32 +@run +def testDipCorr2DReplicatePaddingI32(): + with ir.Context() as context: + module = dipCorr2D( + ir.IntegerType.get_signless(32), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d {{.*}} : memref, memref, memref, index, index, i32 + print(module) + + +# CHECK-LABEL: TEST: testDipCorr2DReplicatePaddingI64 +@run +def testDipCorr2DReplicatePaddingI64(): + with ir.Context() as context: + module = dipCorr2D( + ir.IntegerType.get_signless(64), + ir.Attribute.parse("#dip>"), + context, + ) + module.operation.verify() + # CHECK: dip.corr_2d {{.*}} : memref, memref, memref, index, index, i64 + print(module) diff --git a/tests/Python/dialects/test_dip_resize.py b/tests/Python/dialects/test_dip_resize.py new file mode 100644 index 0000000000..fbf8dd33f6 --- /dev/null +++ b/tests/Python/dialects/test_dip_resize.py @@ -0,0 +1,309 @@ +# RUN: %PYTHON %s | FileCheck %s + +from buddy_mlir.dialects import dip, func +from buddy_mlir import ir + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +def dipResize2D( + dtype, interpolation_attr: ir.Attribute, context: ir.Context +) -> ir.Module: + with context, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + f32 = ir.F32Type.get() + memref = ir.MemRefType.get( + [ + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ], + dtype, + ) + + @func.FuncOp.from_py_func(memref, f32, f32, memref) + def buddy_resize2d( + input, + horizontal_scaling_factor, + vertical_scaling_factor, + output, + ): + dip.resize_2d( + input, + horizontal_scaling_factor, + vertical_scaling_factor, + output, + interpolation_attr, + ) + + return module + + +def dipResize4DNchw( + dtype, + interpolation_attr: ir.Attribute, + context: ir.Context, +) -> ir.Module: + with context, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + f32 = ir.F32Type.get() + memref = ir.MemRefType.get( + [ + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ], + dtype, + ) + + @func.FuncOp.from_py_func(memref, f32, f32, memref) + def buddy_resize4d( + input, + horizontal_scaling_factor, + vertical_scaling_factor, + output, + ): + dip.resize_4d_nchw( + input, + horizontal_scaling_factor, + vertical_scaling_factor, + output, + interpolation_attr, + ) + + return module + + +def dipResize4DNhwc( + dtype, + interpolation_attr: ir.Attribute, + context: ir.Context, +) -> ir.Module: + with context, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + f32 = ir.F32Type.get() + memref = ir.MemRefType.get( + [ + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ], + dtype, + ) + + @func.FuncOp.from_py_func(memref, f32, f32, memref) + def buddy_resize4d( + input, + horizontal_scaling_factor, + vertical_scaling_factor, + output, + ): + dip.resize_4d_nhwc( + input, + horizontal_scaling_factor, + vertical_scaling_factor, + output, + interpolation_attr, + ) + + return module + + +# CHECK-LABEL: TEST: testDipResize2DNearestInpterpolationF32 +@run +def testDipResize2DNearestInpterpolationF32(): + with ir.Context() as context: + module = dipResize2D( + ir.F32Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_2d NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize2DNearestInpterpolationF64 +@run +def testDipResize2DNearestInpterpolationF64(): + with ir.Context() as context: + module = dipResize2D( + ir.F64Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_2d NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize2DBilinearF32 +@run +def testDipResize2DBilinearF32(): + with ir.Context() as context: + module = dipResize2D( + ir.F32Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_2d BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize2DBilinearF64 +@run +def testDipResize2DBilinearF64(): + with ir.Context() as context: + module = dipResize2D( + ir.F64Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_2d BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize4DNchwNearestInterpolationF32 +@run +def testDipResize4DNchwNearestInterpolationF32(): + with ir.Context() as context: + module = dipResize4DNchw( + ir.F32Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize4DNchwNearestInterpolationF64 +@run +def testDipResize4DNchwNearestInterpolationF64(): + with ir.Context() as context: + module = dipResize4DNchw( + ir.F64Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize4DNchwBilinearF32 +@run +def testDipResize4DNchwBilinearF32(): + with ir.Context() as context: + module = dipResize4DNchw( + ir.F32Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_4d_nchw BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize4DNchwBilinearF64 +@run +def testDipResize4DNchwBilinearF64(): + with ir.Context() as context: + module = dipResize4DNchw( + ir.F64Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_4d_nchw BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize4DNhwcNearestInterpolationF32 +@run +def testDipResize4DNhwcNearestInterpolationF32(): + with ir.Context() as context: + module = dipResize4DNhwc( + ir.F32Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize4DNhwcNearestInterpolationF64 +@run +def testDipResize4DNhwcNearestInterpolationF64(): + with ir.Context() as context: + module = dipResize4DNhwc( + ir.F64Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize4DNhwcBilinearF32 +@run +def testDipResize4DNhwcBilinearF32(): + with ir.Context() as context: + module = dipResize4DNhwc( + ir.F32Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_4d_nhwc BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipResize4DNhwcBilinearF64 +@run +def testDipResize4DNhwcBilinearF64(): + with ir.Context() as context: + module = dipResize4DNhwc( + ir.F64Type.get(), + ir.Attribute.parse( + "#dip" + ), + context, + ) + module.operation.verify() + # CHECK: dip.resize_4d_nhwc BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + print(module) diff --git a/tests/Python/dialects/test_dip_rotate.py b/tests/Python/dialects/test_dip_rotate.py new file mode 100644 index 0000000000..7a7e484bde --- /dev/null +++ b/tests/Python/dialects/test_dip_rotate.py @@ -0,0 +1,80 @@ +# RUN: %PYTHON %s | FileCheck %s + +from buddy_mlir.dialects import dip, func +from buddy_mlir import ir + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +def dipRotate2D(dtype, context: ir.Context) -> ir.Module: + with context, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + f32 = ir.F32Type.get() + memref = ir.MemRefType.get( + [ + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ], + dtype, + ) + + @func.FuncOp.from_py_func(memref, f32, memref) + def buddy_resize2d(input, angle, output): + dip.rotate_2d(input, angle, output) + + return module + + +# CHECK-LABEL: TEST: testDipRotate2DF32 +@run +def testDipRotate2DF32(): + with ir.Context() as context: + module = dipRotate2D(ir.F32Type.get(), context) + module.operation.verify() + # CHECK: dip.rotate_2d {{.*}} : memref, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipRotate2DF64 +@run +def testDipRotate2DF64(): + with ir.Context() as context: + module = dipRotate2D(ir.F64Type.get(), context) + module.operation.verify() + # CHECK: dip.rotate_2d {{.*}} : memref, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipRotate2DI8 +@run +def testDipRotate2DI8(): + with ir.Context() as context: + module = dipRotate2D(ir.IntegerType.get_signless(8), context) + module.operation.verify() + # CHECK: dip.rotate_2d {{.*}} : memref, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipRotate2DI32 +@run +def testDipRotate2DI32(): + with ir.Context() as context: + module = dipRotate2D(ir.IntegerType.get_signless(32), context) + module.operation.verify() + # CHECK: dip.rotate_2d {{.*}} : memref, f32, memref + print(module) + + +# CHECK-LABEL: TEST: testDipRotate2DI64 +@run +def testDipRotate2DI64(): + with ir.Context() as context: + module = dipRotate2D(ir.IntegerType.get_signless(64), context) + module.operation.verify() + # CHECK: dip.rotate_2d {{.*}} : memref, f32, memref + print(module) diff --git a/tests/Python/dialects/test_gemmini.py b/tests/Python/dialects/test_gemmini.py new file mode 100644 index 0000000000..9d90801f59 --- /dev/null +++ b/tests/Python/dialects/test_gemmini.py @@ -0,0 +1,72 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +from buddy_mlir import ir +from buddy_mlir.dialects import arith, linalg, memref, func +from buddy_mlir.passmanager import PassManager + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +# CHECK-LABEL: TEST: testLinalgMatmulConversion +@run +def testLinalgMatmulConversion(): + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + i8 = ir.IntegerType.get_signless(8) + + memref8x8 = ir.MemRefType.get([8, 8], i8) + + with ir.InsertionPoint(module.body): + c1 = arith.ConstantOp(i8, 1) + c2 = arith.ConstantOp(i8, 2) + + mem0 = memref.alloc(memref8x8, [], []) + mem1 = memref.alloc(memref8x8, [], []) + mem2 = memref.alloc(memref8x8, [], []) + + linalg.fill(c2, outs=[mem0]) + linalg.fill(c1, outs=[mem1]) + linalg.matmul(mem0, mem1, outs=[mem2]) + + module.operation.verify() + + pm = PassManager("builtin.module") + pm.add("convert-linalg-to-gemmini") + pm.run(module.operation) + + # CHECK: gemmini.tile_matmul [[MEM0:%.*]] [[MEM1:%.*]] [[MEM2:%.*]] : + # CHECK-SAME: memref<8x8xi8> memref<8x8xi8> memref<8x8xi8> memref<8x8xi32> + print(module) + + +# CHECK-LABEL TEST: testLinalgConv2DConversion +@run +def testLinalgConv2DConversion(): + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + f32 = ir.F32Type.get() + + memref2x2x5x5 = ir.MemRefType.get([2, 2, 5, 5], f32) + memref2x2x3x3 = ir.MemRefType.get([2, 2, 3, 3], f32) + + with ir.InsertionPoint(module.body): + + @func.FuncOp.from_py_func(memref2x2x5x5, memref2x2x3x3) + def linalg_conv2d(input, weight): + mem2 = memref.alloc(memref2x2x3x3, [], []) + linalg.conv_2d_nchw_fchw(input, weight, outs=[mem2]) + + return + + module.operation.verify() + + pm = PassManager("builtin.module") + pm.add("convert-linalg-to-gemmini") + pm.run(module.operation) + # CHECK: gemmini.tile_conv {{.+}} : + # CHECK-SAME: memref<2x5x5x2xf32> memref<18x2xf32> memref<2xi32> memref<18x2xf32> i64 i64 + print(module) diff --git a/tests/Python/dialects/test_rvv.py b/tests/Python/dialects/test_rvv.py new file mode 100644 index 0000000000..e83d024b3e --- /dev/null +++ b/tests/Python/dialects/test_rvv.py @@ -0,0 +1,81 @@ +# RUN: %PYTHON %s | FileCheck %s + +from buddy_mlir.dialects import arith, func, rvv +from buddy_mlir import ir +from buddy_mlir.passmanager import PassManager + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +# CHECK-LABEL: TEST: testRVVLegalizeForLLVM +@run +def testRVVLegalizeForLLVM(): + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + index_type = ir.IndexType.get() + + @func.FuncOp.from_py_func(index_type, results=[index_type]) + def rvv_setvl(avl): + # SEW = 32 + sew = arith.ConstantOp(index_type, 2) + # LMUL = 2 + lmul = arith.ConstantOp(index_type, 2) + vl = rvv.setvl(index_type, avl, sew, lmul) + func.return_([vl]) + return + + module.operation.verify() + + # CHECK: func @rvv_setvl(%[[ARG0:.*]]: index) -> index { + # CHECK-NEXT: %[[SEW:.*]] = arith.constant 2 : index + # CHECK-NEXT: %[[LMUL:.*]] = arith.constant 2 : index + # CHECK-NEXT: rvv.setvl %[[ARG0]], %[[SEW]], %[[LMUL]] : index + # CHECK-NEXT: return %[[RESULT:.*]] : index + # CHECK-NEXT: } + print(module) + + pm = PassManager("builtin.module") + pm.add("lower-rvv") + pm.run(module.operation) + # CHECK: rvv.intr.vsetvli{{.*}} : (i64, i64, i64) -> i64 + print(module) + + +# CHECK-LABEL: TEST: testRVVRsqrtLegalizeForLLVM +@run +def testRVVRsqrtLegalizeForLLVM(): + with ir.Context(): + module = ir.Module.parse( + """ + func.func @rvv_rsqrt(%arg0: memref) { + %c0 = arith.constant 0 : index + + %sew = arith.constant 2 : index + %lmul = arith.constant 1 : index + %avl8 = arith.constant 8 : index + %vl8 = rvv.setvl %avl8, %sew, %lmul : index + + %load_vec = rvv.load %arg0[%c0], %vl8 : memref, vector<[4]xf32>, index + %rsqrt_vec = math.rsqrt %load_vec : vector<[4]xf32> + rvv.store %rsqrt_vec, %arg0[%c0], %vl8 : vector<[4]xf32>, memref, index + + return + } + """ + ) + + module.operation.verify() + + pm = PassManager("builtin.module") + pm.add("lower-rvv") + pm.run(module.operation) + + # CHECK: rvv.intr.vsetvli{{.*}} : (i64, i64, i64) -> i64 + # CHECK: rvv.intr.vle{{.*}} : (vector<[4]xf32>, !llvm.ptr>, i64) -> vector<[4]xf32> + # CHECK: rvv.intr.vse{{.*}} : (vector<[4]xf32>, !llvm.ptr>, i64) -> () + print(module) diff --git a/tests/Python/passes/conv2d.py b/tests/Python/passes/conv2d.py new file mode 100644 index 0000000000..732603df3d --- /dev/null +++ b/tests/Python/passes/conv2d.py @@ -0,0 +1,91 @@ +# RUN: %PYTHON %s | FileCheck %s + +from buddy_mlir.dialects import func, linalg +from buddy_mlir import ir +from buddy_mlir.passmanager import PassManager + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +# CHECK-LABEL: TEST: testConv2DVectorize +@run +def testConv2DVectorize(): + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + memref = ir.MemRefType.get( + [ + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ], + ir.F32Type.get(), + ) + + @func.FuncOp.from_py_func(memref, memref, memref) + def conv2d(input, kernel, output): + linalg.conv_2d(input, kernel, outs=[output]) + return + + module.operation.verify() + + pm = PassManager("builtin.module") + pm.add("conv-vectorization{strip-mining=32}") + pm.add("cse") + pm.run(module.operation) + + # Check the affine maps and function signature + # CHECK: #map = affine_map<(d0) -> (d0)> + # CHECK: #map1 = affine_map<(d0) -> (d0 ceildiv 32)> + # CHECK: func.func @conv2d(%[[IN:.+]]: memref, %[[KERNEL:.+]]: memref, %[[OUT:.+]]: memref) { + # CHECK: %[[ZERO:.+]] = arith.constant 0 : index + # CHECK: %[[ONE:.+]] = arith.constant 1 : index + # CHECK: %[[C32:.+]] = arith.constant 32 : index + # CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 + + # Check the pass-through vector for masked loads and stores + # CHECK: %[[PASS_THRU:.+]] = vector.splat %[[CST]] : vector<32xf32> + + # Check the dims + # CHECK: %[[KERNEL_DIM0:.+]] = memref.dim %[[KERNEL]], %[[ZERO]] : memref + # CHECK: %[[KERNEL_DIM1:.+]] = memref.dim %[[KERNEL]], %[[ONE]] : memref + # CHECK: %[[OUT_DIM0:.+]] = memref.dim %[[OUT]], %[[ZERO]] : memref + # CHECK: %[[OUT_DIM1:.+]] = memref.dim %[[OUT]], %[[ONE]] : memref + + # Check the vectorized loop nest + # CHECK: affine.for %[[I:.+]] = #map(%[[ZERO]]) to #map(%[[OUT_DIM0]]) { + # CHECK-NEXT: affine.for %[[J:.+]] = #map(%[[ZERO]]) to #map(%[[KERNEL_DIM0]]) { + # CHECK-NEXT: affine.for %[[K:.+]] = #map(%[[ZERO]]) to #map(%[[KERNEL_DIM1]]) { + # CHECK-NEXT: affine.for %[[L:.+]] = #map(%[[ZERO]]) to #map1(%[[OUT_DIM1]]) { + # CHECK: %[[KERNEL_VAL:.+]] = memref.load %[[KERNEL]][%[[J]], %[[K]]] : memref + # CHECK: %[[COND:.+]] = arith.cmpf one, %[[KERNEL_VAL]], %{{.+}} : f32 + # CHECK: scf.if %[[COND]] { + # CHECK-NEXT: %[[BROADCAST:.+]] = vector.broadcast %[[KERNEL_VAL]] : f32 to vector<32xf32> + # CHECK-NEXT: %[[CURRENT:.+]] = arith.muli %[[L]], %[[C32]] : index + # CHECK-NEXT: %[[REMAINDER:.+]] = arith.subi %[[OUT_DIM1]], %[[CURRENT]] : index + # CHECK-NEXT: %[[COND:.+]] = arith.cmpi sge, %[[REMAINDER]], %[[C32]] : index + # CHECK-NEXT: scf.if %[[COND]] { + # CHECK-NEXT: %[[IN_VEC:.+]] = affine.vector_load %[[IN]][%[[I]] + %[[J]], %[[K]] + %[[L]] * 32] : memref, vector<32xf32> + # CHECK-NEXT: %[[OUT_VEC:.+]] = affine.vector_load %[[OUT]][%[[I]], %[[L]] * 32] : memref, vector<32xf32> + # CHECK-NEXT: %[[FMA:.+]] = vector.fma %[[IN_VEC]], %[[BROADCAST]], %[[OUT_VEC]] : vector<32xf32> + # CHECK-NEXT: affine.vector_store %[[FMA]], %[[OUT]][%[[I]], %[[L]] * 32] : memref, vector<32xf32> + # CHECK-NEXT: } else { + # CHECK-NEXT: %[[MASK:.+]] = vector.create_mask %[[REMAINDER]] : vector<32xi1> + # CHECK-NEXT: %[[INPUT_ROW:.+]] = arith.addi %[[I]], %[[J]] : index + # CHECK-NEXT: %[[INPUT_COL:.+]] = arith.addi %[[K]], %[[CURRENT]] : index + # CHECK-NEXT: %[[IN_VEC:.+]] = vector.maskedload %[[IN]][%[[INPUT_ROW]], %[[INPUT_COL]]], %[[MASK]], %[[PASS_THRU]] : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + # CHECK-NEXT: %[[OUT_VEC:.+]] = vector.maskedload %[[OUT]][%[[I]], %[[CURRENT]]], %[[MASK]], %[[PASS_THRU]] : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + # CHECK-NEXT: %[[FMA:.+]] = vector.fma %[[IN_VEC]], %[[BROADCAST]], %[[OUT_VEC]] : vector<32xf32> + # CHECK-NEXT: vector.maskedstore %[[OUT]][%[[I]], %[[CURRENT]]], %[[MASK]], %[[FMA]] : memref, vector<32xi1>, vector<32xf32> + # CHECK-NEXT: } + # CHECK-NEXT: } + # CHECK-NEXT: } + # CHECK-NEXT: } + # CHECK-NEXT: } + # CHECK-NEXT: } + # CHECK-NEXT: return + # CHECK-NEXT: } + print(module) diff --git a/tests/Python/passes/matmul.py b/tests/Python/passes/matmul.py new file mode 100644 index 0000000000..4a41785985 --- /dev/null +++ b/tests/Python/passes/matmul.py @@ -0,0 +1,77 @@ +# RUN: %PYTHON %s | FileCheck %s + +from buddy_mlir.dialects import func, linalg +from buddy_mlir import ir +from buddy_mlir.passmanager import PassManager + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +# CHECK-LABEL: TEST: testMatmulVectorize +@run +def testMatmulVectorize(): + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + memref = ir.MemRefType.get( + [ + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + ], + ir.F32Type.get(), + ) + + @func.FuncOp.from_py_func(memref, memref, memref) + def matmul(a, b, c): + linalg.matmul(a, b, outs=[c]) + return + + module.operation.verify() + + pm = PassManager("builtin.module") + pm.add("matmul-vectorization{vector-size=32}") + pm.add("cse") + pm.run(module.operation) + + # CHECK: #map = affine_map<(d0) -> (d0)> + # CHECK: #map1 = affine_map<(d0) -> (d0 ceildiv 32)> + # CHECK: func.func @matmul(%[[A:.+]]: memref, %[[B:.+]]: memref, %[[C:.+]]: memref) { + # CHECK: %[[ZERO:.+]] = arith.constant 0 : index + # CHECK: %[[ONE:.+]] = arith.constant 1 : index + # CHECK: %[[C32:.+]] = arith.constant 32 : index + # CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 + # CHECK: %[[PASS_THRU:.+]] = vector.splat %[[CST]] : vector<32xf32> + # CHECK: %[[DIM:.+]] = memref.dim %[[A]], %[[ZERO]] : memref + # CHECK: %[[DIM_0:.+]] = memref.dim %[[B]], %[[ZERO]] : memref + # CHECK: %[[DIM_1:.+]] = memref.dim %[[B]], %[[ONE]] : memref + # CHECK-NEXT: affine.for %[[I:.+]] = #map(%[[ZERO]]) to #map(%[[DIM_0]]) { + # CHECK-NEXT: affine.for %[[J:.+]] = #map(%[[ZERO]]) to #map(%[[DIM]]) { + # CHECK-NEXT: affine.for %[[K:.+]] = #map(%[[ZERO]]) to #map1(%[[DIM_1]]) { + # CHECK-NEXT: %[[A_VAL:.+]] = memref.load %[[A]][%[[J]], %[[I]]] : memref + # CHECK-NEXT: %[[BROADCAST:.+]] = vector.broadcast %[[A_VAL]] : f32 to vector<32xf32> + # CHECK-NEXT: %[[CURRENT:.+]] = arith.muli %[[K]], %[[C32]] : index + # CHECK-NEXT: %[[REMAINDER:.+]] = arith.subi %[[DIM_1]], %[[CURRENT]] : index + # CHECK-NEXT: %[[COND:.+]] = arith.cmpi sge, %[[REMAINDER]], %[[C32]] : index + # CHECK-NEXT: scf.if %[[COND]] { + # CHECK-NEXT: %[[B_VEC:.+]] = affine.vector_load %[[B]][%[[I]], %[[K]] * 32] : memref, vector<32xf32> + # CHECK-NEXT: %[[C_VEC:.+]] = affine.vector_load %[[C]][%[[J]], %[[K]] * 32] : memref, vector<32xf32> + # CHECK-NEXT: %[[FMA:.+]] = vector.fma %[[BROADCAST]], %[[B_VEC]], %[[C_VEC]] : vector<32xf32> + # CHECK-NEXT: affine.vector_store %[[FMA]], %[[C]][%[[J]], %[[K]] * 32] : memref, vector<32xf32> + # CHECK-NEXT: } else { + # CHECK-NEXT: %[[MASK:.+]] = vector.create_mask %[[REMAINDER]] : vector<32xi1> + # CHECK-NEXT: %[[B_VEC:.+]] = vector.maskedload %[[B]][%[[I]], %[[CURRENT]]], %[[MASK]], %[[PASS_THRU]] : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + # CHECK-NEXT: %[[C_VEC:.+]] = vector.maskedload %[[C]][%[[J]], %[[CURRENT]]], %[[MASK]], %[[PASS_THRU]] : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + # CHECK-NEXT: %[[FMA:.+]] = vector.fma %[[BROADCAST]], %[[B_VEC]], %[[C_VEC]] : vector<32xf32> + # CHECK-NEXT: vector.maskedstore %[[C]][%[[J]], %[[CURRENT]]], %[[MASK]], %[[FMA]] : memref, vector<32xi1>, vector<32xf32> + # CHECK-NEXT: } + # CHECK-NEXT: } + # CHECK-NEXT: } + # CHECK-NEXT: } + # CHECK-NEXT: return + # CHECK-NEXT: } + + print(module) diff --git a/tests/Python/test_python.py b/tests/Python/test_python.py deleted file mode 100644 index 00a7eebbb4..0000000000 --- a/tests/Python/test_python.py +++ /dev/null @@ -1,49 +0,0 @@ -# RUN: %PYTHON %s 2>&1 | FileCheck %s - -from buddy_mlir.ir import Context, Module -from buddy_mlir.passmanager import PassManager - - -with Context(): - mod = Module.parse( - """ - %0 = arith.constant 0 : i8 - %1 = arith.constant 1 : i8 - %2 = arith.constant 2 : i8 - %mem0 = memref.alloc() : memref<8x8xi8> - %mem1 = memref.alloc() : memref<8x8xi8> - %mem2 = memref.alloc() : memref<8x8xi8> - linalg.fill - ins(%2 : i8) - outs(%mem0 : memref<8x8xi8>) - linalg.fill - ins(%1 : i8) - outs(%mem1 : memref<8x8xi8>) - linalg.matmul - ins(%mem0, %mem1 : memref<8x8xi8>, memref<8x8xi8>) - outs(%mem2 : memref<8x8xi8>) - gemmini.print %mem2 : memref<8x8xi8> - """ - ) - - pm = PassManager("builtin.module") - pm.add("convert-linalg-to-gemmini") - pm.run(mod.operation) - - # CHECK: module { - # CHECK: %c0_i8 = arith.constant 0 : i8 - # CHECK: %c1_i8 = arith.constant 1 : i8 - # CHECK: %c2_i8 = arith.constant 2 : i8 - # CHECK: %alloc = memref.alloc() : memref<8x8xi8> - # CHECK: %alloc_0 = memref.alloc() : memref<8x8xi8> - # CHECK: %alloc_1 = memref.alloc() : memref<8x8xi8> - # CHECK: linalg.fill ins(%c2_i8 : i8) outs(%alloc : memref<8x8xi8>) - # CHECK: linalg.fill ins(%c1_i8 : i8) outs(%alloc_0 : memref<8x8xi8>) - # CHECK: %alloc_2 = memref.alloc() : memref<8x8xi32> - # CHECK: %c0_i32 = arith.constant 0 : i32 - # CHECK: linalg.fill ins(%c0_i32 : i32) outs(%alloc_2 : memref<8x8xi32>) - # CHECK: gemmini.tile_matmul %alloc %alloc_0 %alloc_1 %alloc_2 : memref<8x8xi8> memref<8x8xi8> memref<8x8xi8> memref<8x8xi32> - # CHECK: memref.dealloc %alloc_2 : memref<8x8xi32> - # CHECK: gemmini.print %alloc_1 : memref<8x8xi8> - # CHECK: } - print(str(mod)) From c8360ca7f794f559dd05e85769a6d66b49f30ef9 Mon Sep 17 00:00:00 2001 From: Junyi Mei Date: Tue, 25 Feb 2025 19:52:38 +0800 Subject: [PATCH 05/13] [midend] Optimization on matmul-transpose-b vectorization (#465) Current vectorization pass of matmul-transpose-b reduce the vector in each iteration and accumulate it to the result element. This commit modify it into elementwise addition and do the reduction after the inner loop with reassoc enabled. Signed-off-by: Junyi Mei --- .../MatMulTransposeBVec.cpp | 316 +++++++++--------- .../matmul-transpose-b-vectorization.mlir | 101 ++++++ 2 files changed, 266 insertions(+), 151 deletions(-) create mode 100644 tests/Conversion/matmul-transpose-b-vectorization.mlir diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp index 345de4c1de..2f81414a33 100644 --- a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp +++ b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp @@ -19,12 +19,16 @@ //===----------------------------------------------------------------------===// #include +#include #include #include +#include +#include #include #include #include #include +#include #include #include "Utils/Utils.h" @@ -37,132 +41,138 @@ using namespace vector; //===----------------------------------------------------------------------===// namespace { -class MatMulTransposeBVecPattern : public ConversionPattern{ +class MatMulTransposeBVecPattern : public ConversionPattern { public: - explicit MatMulTransposeBVecPattern(MLIRContext *context,int64_t vecSizeparam) - : ConversionPattern(linalg::MatmulTransposeBOp::getOperationName(),1,context){ - vecSize = vecSizeparam; - } - - LogicalResult - matchAndRewrite(Operation *op,ArrayRef /*operands*/, - ConversionPatternRewriter &rewriter) const override{ - auto loc = op->getLoc(); - auto ctx = op->getContext(); - // Get input A, B, C. - Value A = op->getOperand(0); - Value B = op->getOperand(1); - Value C = op->getOperand(2); - - // Get shape of input and output. - ShapedType ATy = A.getType().cast(); - Type eleTy = ATy.getElementType(); - - // the element type for mask vector. - IntegerType i1 = IntegerType::get(ctx, 1); - - VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); - VectorType vectorMaskTy = VectorType::get({vecSize}, i1); - - const Value c0 = - rewriter.create(loc, rewriter.getIndexAttr(0)); - const Value c1 = - rewriter.create(loc, rewriter.getIndexAttr(1)); - const Value step = rewriter.create(loc, vecSize); - - const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); - Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); - - const Value aRow = rewriter.create(loc, A, c0); - const Value bRow = rewriter.create(loc, B, c0); - const Value bCol = rewriter.create(loc, B, c1); - - AffineExpr d0; - bindDims(ctx, d0); - AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); - SmallVector lowerBounds(2, c0); - SmallVector uperBounds{aRow, bRow}; - SmallVector steps(2, 1); - // clang-format off - affine::buildAffineLoopNest( - rewriter, loc, lowerBounds, uperBounds, steps, - [&](OpBuilder &builder, Location loc, ValueRange ivs) { - // Create loop based on vector size. - builder.create( - loc, ValueRange{c0}, builder.getDimIdentityMap(), - ValueRange{bCol}, vecTailMap, 1, std::nullopt, - [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, - ValueRange itrArgs) { - AffineExpr a,b,c; - bindDims(ctx, a,b,c); - AffineMap AVectorMap = AffineMap::get( - /*dimCount=*/3, /*symbolCount=*/0, {a, c * vecSize}, ctx); - // Check tail. - AffineExpr m, n, k; - bindDims(ctx, m, n, k); - AffineMap BVectorMap = AffineMap::get( - /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); - - // Calculate the tail. - Value bColCur = builder.create(loc, iv, step); - Value tailLen = builder.create(loc, bCol, bColCur); - Value tailFlag = rewriter.create( - loc, arith::CmpIPredicate::sge, tailLen, step); - // If the current column does not reach the tail. - builder.create(loc, tailFlag, - [&](OpBuilder &builder, Location loc) { - Value aVec = builder.create( - loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); - Value bVec = builder.create( - loc, vectorTy, B, BVectorMap, ValueRange{ivs[1], ivs[1], iv}); - Value resvec = builder.create(loc,aVec,bVec); - Value res1 = builder.create( - loc,mlir::vector::CombiningKind::ADD,resvec); - Value res2 = builder.create( - loc, C, ValueRange{ivs[0], ivs[1]}); - Value sum = builder.create(loc, res1, res2); - builder.create(loc, sum, - C, ValueRange{ivs[0], ivs[1]}); - builder.create(loc); - }, - // The else branch - [&](OpBuilder &builder, Location loc) { - // TODO: remove this value and operation? - // Value aVec = builder.create( - // loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); - builder.create( - loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); - // Create mask according to the tail. - Value maskVec = builder.create( - loc, vectorMaskTy, tailLen); - Value ColIdxTail = builder.create(loc, iv, step); - - Value aVecTail = builder.create( - loc, vectorTy, A, ValueRange{ivs[0], ColIdxTail}, - maskVec, passthruVec); - - Value bVecTail = builder.create( - loc, vectorTy, B, ValueRange{ivs[1], ColIdxTail}, - maskVec, passthruVec); - - Value resvec = builder.create(loc,aVecTail,bVecTail); - Value res1 = builder.create( - loc,mlir::vector::CombiningKind::ADD,resvec); - Value res2 = builder.create( - loc, C, ValueRange{ivs[0], ivs[1]}); - Value sum = builder.create(loc, res1, res2); - builder.create(loc, sum, C, ValueRange{ivs[0], ivs[1]}); - builder.create(loc); - }); - builder.create(loc); - }); + explicit MatMulTransposeBVecPattern(MLIRContext *context, + int64_t vecSizeparam) + : ConversionPattern(linalg::MatmulTransposeBOp::getOperationName(), 1, + context) { + vecSize = vecSizeparam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + // Get input A, B, C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Get shape of input and output. + ShapedType ATy = A.getType().cast(); + Type eleTy = ATy.getElementType(); + + // the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + + VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); + VectorType vectorMaskTy = VectorType::get({vecSize}, i1); + + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + const Value step = rewriter.create(loc, vecSize); + + const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); + Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); + + const Value aRow = rewriter.create(loc, A, c0); + const Value bRow = rewriter.create(loc, B, c0); + const Value bCol = rewriter.create(loc, B, c1); + + AffineExpr d0; + bindDims(ctx, d0); + AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); + SmallVector lowerBounds(2, c0); + SmallVector uperBounds{aRow, bRow}; + SmallVector steps(2, 1); + + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create loop based on vector size. + auto innerLoop = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{bCol}, vecTailMap, 1, ValueRange{passthruVec}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + Value acc = itrArgs[0]; + + AffineExpr a, b, c; + bindDims(ctx, a, b, c); + AffineMap AVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {a, c * vecSize}, ctx); + // Check tail. + AffineExpr m, n, k; + bindDims(ctx, m, n, k); + AffineMap BVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); + + // Calculate the tail. + Value bColCur = builder.create(loc, iv, step); + Value tailLen = + builder.create(loc, bCol, bColCur); + Value tailFlag = rewriter.create( + loc, arith::CmpIPredicate::sge, tailLen, step); + // If the current column does not reach the tail. + auto ifOp = builder.create( + loc, tailFlag, + [&](OpBuilder &builder, Location loc) { + Value aVec = builder.create( + loc, vectorTy, A, AVectorMap, + ValueRange{ivs[0], ivs[1], iv}); + Value bVec = builder.create( + loc, vectorTy, B, BVectorMap, + ValueRange{ivs[1], ivs[1], iv}); + Value resvec = + builder.create(loc, aVec, bVec); + Value newAcc = + builder.create(loc, acc, resvec); + builder.create(loc, newAcc); + }, + // The else branch + [&](OpBuilder &builder, Location loc) { + // Create mask according to the tail. + Value maskVec = builder.create( + loc, vectorMaskTy, tailLen); + Value ColIdxTail = + builder.create(loc, iv, step); + + Value aVecTail = builder.create( + loc, vectorTy, A, ValueRange{ivs[0], ColIdxTail}, + maskVec, passthruVec); + + Value bVecTail = builder.create( + loc, vectorTy, B, ValueRange{ivs[1], ColIdxTail}, + maskVec, passthruVec); + + Value resvec = builder.create( + loc, aVecTail, bVecTail); + Value newAcc = + builder.create(loc, acc, resvec); + builder.create(loc, newAcc); + }); + builder.create(loc, ifOp.getResult(0)); + }); + + Value load = builder.create( + loc, C, ValueRange{ivs[0], ivs[1]}); + Value reduction = builder.create( + loc, CombiningKind::ADD, innerLoop->getResult(0), load, + arith::FastMathFlags::reassoc); + builder.create(loc, reduction, C, + ValueRange{ivs[0], ivs[1]}); }); - // clang-format on - rewriter.eraseOp(op); - return success(); - } + + rewriter.eraseOp(op); + return success(); + } + private: - int64_t vecSize; + int64_t vecSize; }; } // end anonymous namespace @@ -170,41 +180,45 @@ class MatMulTransposeBVecPattern : public ConversionPattern{ // MatMulVectorizationPass //===----------------------------------------------------------------------===// -namespace{ - class MatMulTransposeBVecPass - :public PassWrapper>{ +namespace { +class MatMulTransposeBVecPass + : public PassWrapper> { public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulTransposeBVecPass) - StringRef getArgument() const final{ return "matmul-transpose-b-vectorization"; } - StringRef getDescription() const final { return "vectorize linalg MatmulTransposeBOp"; } - MatMulTransposeBVecPass() = default; - MatMulTransposeBVecPass(const MatMulTransposeBVecPass &) {} - void runOnOperation() override; - void getDependentDialects(DialectRegistry ®istry) const override{ - registry.insert(); - } - Option vecSize{*this,"vec-size", - llvm::cl::desc("The size of vectorization"), - llvm::cl::init(32)}; - + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulTransposeBVecPass) + StringRef getArgument() const final { + return "matmul-transpose-b-vectorization"; + } + StringRef getDescription() const final { + return "vectorize linalg MatmulTransposeBOp"; + } + MatMulTransposeBVecPass() = default; + MatMulTransposeBVecPass(const MatMulTransposeBVecPass &) {} + void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + Option vecSize{*this, "vec-size", + llvm::cl::desc("The size of vectorization"), + llvm::cl::init(32)}; }; -} +} // namespace -void MatMulTransposeBVecPass::runOnOperation(){ - MLIRContext *context = &getContext(); - ModuleOp module = getOperation(); +void MatMulTransposeBVecPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); - ConversionTarget target(*context); - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); - RewritePatternSet patterns(context); - patterns.add(context,vecSize); + RewritePatternSet patterns(context); + patterns.add(context, vecSize); - if (failed(applyPartialConversion(module, target, std::move(patterns)))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/tests/Conversion/matmul-transpose-b-vectorization.mlir b/tests/Conversion/matmul-transpose-b-vectorization.mlir new file mode 100644 index 0000000000..391c7ce84d --- /dev/null +++ b/tests/Conversion/matmul-transpose-b-vectorization.mlir @@ -0,0 +1,101 @@ +// RUN: buddy-opt %s \ +// RUN: -matmul-transpose-b-vectorization="vec-size=64" \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module{ + func.func private @printMemrefF32(memref<*xf32>) + func.func private @printMemrefF64(memref<*xf64>) + + func.func @matmul_f32(%a : memref, %b : memref, %c : memref) { + linalg.matmul_transpose_b + ins(%a, %b: memref, memref) + outs(%c:memref) + return + } + + func.func @matmul_f64(%a : memref, %b : memref, %c : memref) { + linalg.matmul_transpose_b + ins(%a, %b: memref, memref) + outs(%c:memref) + return + } + + func.func @main(){ + // Set up dims. + %cM = arith.constant 4 : index + %cN = arith.constant 4 : index + %cK = arith.constant 4 : index + + //-------------------------------------------------------------------------- + // Test f32 as element type. + //-------------------------------------------------------------------------- + + // Set Init Value. + %cf1_32 = arith.constant 1.0 : f32 + + %A_f32 = memref.alloc(%cM, %cK) : memref + %B_f32 = memref.alloc(%cK, %cN) : memref + %C_f32 = memref.alloc(%cM, %cN) : memref + + linalg.fill ins(%cf1_32 : f32) outs(%A_f32 : memref) + linalg.fill ins(%cf1_32 : f32) outs(%B_f32 : memref) + linalg.fill ins(%cf1_32 : f32) outs(%C_f32 : memref) + + call @matmul_f32(%A_f32, %B_f32, %C_f32) : (memref, memref, memref) -> () + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C_f32 = memref.cast %C_f32 : memref to memref<*xf32> + call @printMemrefF32(%print_C_f32) : (memref<*xf32>) -> () + + memref.dealloc %C_f32 : memref + memref.dealloc %B_f32 : memref + memref.dealloc %A_f32 : memref + + //-------------------------------------------------------------------------- + // Test f64 as element type. + //-------------------------------------------------------------------------- + + // Set Init Value. + %cf1_64 = arith.constant 1.0 : f64 + + %A_f64 = memref.alloc(%cM, %cK) : memref + %B_f64 = memref.alloc(%cK, %cN) : memref + %C_f64 = memref.alloc(%cM, %cN) : memref + + linalg.fill ins(%cf1_64 : f64) outs(%A_f64 : memref) + linalg.fill ins(%cf1_64 : f64) outs(%B_f64 : memref) + linalg.fill ins(%cf1_64 : f64) outs(%C_f64 : memref) + + call @matmul_f64(%A_f64, %B_f64, %C_f64) : (memref, memref, memref) -> () + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C_f64 = memref.cast %C_f64 : memref to memref<*xf64> + call @printMemrefF64(%print_C_f64) : (memref<*xf64>) -> () + + memref.dealloc %C_f64 : memref + memref.dealloc %B_f64 : memref + memref.dealloc %A_f64 : memref + + return + } +} From ed1ee24ea80b87865ce4adf426c5cc24d0d7d8ad Mon Sep 17 00:00:00 2001 From: lishiyang <3158015337@qq.com> Date: Tue, 25 Feb 2025 20:06:05 +0800 Subject: [PATCH 06/13] [examples] Add RVV code generation makefile targets. (#453) * [examples] Add MLIR vector dialect to RVV code generation makefile target Signed-off-by: asdf1113 <3158015337@qq.com> * [docs] Add RVV instruction support docs Mapping MLIR Vector Operations to RVV Instructions --------- Signed-off-by: asdf1113 <3158015337@qq.com> --- docs/RVVInstructionSupport.md | 329 +++++++++++++++++++++++++++++++++ examples/MLIRVector/makefile | 336 ++++++++++++++++++++++++++++++++++ 2 files changed, 665 insertions(+) create mode 100644 docs/RVVInstructionSupport.md diff --git a/docs/RVVInstructionSupport.md b/docs/RVVInstructionSupport.md new file mode 100644 index 0000000000..2fd89f5c73 --- /dev/null +++ b/docs/RVVInstructionSupport.md @@ -0,0 +1,329 @@ +[TOC] + +## Vector Loads and Stores Instructions + +| MLIR Operation | Generated RVV Instruction | Remarks | +| -------------------- | ------------------------- | ----------------------------------------------- | +| `vector.load` | `vle` | Supports unit stride vector load | +| `vector.maskedload` | Masked `vle` | Supports unit stride vector load with mask | +| `vector.gather` | `vluxei` | Indexed load, but `vloxei` is unsupported | +| `vector.store` | `vse` | Supports unit stride vector store | +| `vector.maskedstore` | Masked `vse` | Supports unit stride vector store with mask | +| `vector.scatter` | `vsoxei` | Indexed store, but `vsuxei` is unsupported | + +- `vector.scatter` -> `llvm.masked.scatter` -> `vsoxei` + +> Scatter with overlapping addresses is guaranteed to be ordered from least-significant to most-significant element. + +- `vector.gather` -> `llvm.masked.gather` -> `vluxei` + + + +### Unsupported Instructions + +- **Masked Register Load:** + - `vlm` + - `vsm` +- **Vector Strided:** + - `vlse` + - `vsse` +- **Unit-stride Fault-Only-First Loads:** + - `vleff` +- **Vector Load/Store Segment Instructions:** + - `vlsege` + - `vlssege` + - `vluxsegei` + - `vloxsegei` + - `vssege` + - `vsssege` + - `vsuxsegei` + - `vsoxsegei` +- **Indexed Load/Store:** + - `vsuxei` + - `vloxei` +- **Vector Load/Store Whole Register:** + - `vlre` + - `vsr` + + + +## Vector Reduction Instructions + +- `vector.reduction` with `vector.mask` supports masked reduction instructions: + - `vector.reduction` -> `llvm.vector.reduce` + + - `vector.reduction` + `vector.mask` -> `llvm.vp.reduce` + +- By setting the attribute `fastmath`, `vector.reduction` can generate `vfredusum` and `vfredosum` instructions. + + + +### Unsupported Instructions + +- **Vector Widening Integer Reduction:** + - `vwredsumu` + - `vwredsum` +- **Vector Widening Floating-Point Reduction:** + - `vfwredusum` + - `vfwredosum` + + + +## Vector Arithmetic Instructions + +- `vector.mask` implements the `MaskingOpInterface` to predicate another operation: + + > The `vector.mask` is a `MaskingOpInterface` operation that predicates the execution of another operation. It takes an `i1` vector mask and an optional passthru vector as arguments. + > + > A implicitly `vector.yield`-terminated region encloses the operation to be masked. Values used within the region are captured from above. Only one *maskable* operation can be masked with a `vector.mask` operation at a time. An operation is *maskable* if it implements the `MaskableOpInterface`. The terminator yields all results of the maskable operation to the result of this operation. + +- However, the following code doesn't lower to a valid implementation: + + ``` + %0 = vector.mask %mask, %passthru { arith.divsi %a, %b : vector<8xi32> } : vector<8xi1> -> vector<8xi32> + ``` + + + +- All vector arithmetic instructions support only `.vv` format; widening or narrowing operations are unsupported. For example, the `addi` operation: + + > The `addi` operation takes two operands and returns one result, each of these is required to be the same type. + + + +- `vector.fma` implements fused multiply-add, with parameters restricted to floating-point values. The conversion flow is::`vector.fma` → `llvm.fmuladd` → `vfmadd.vv` + + + +### Vector Integer Arithmetic Instructions + +- **Vector Single-Width Integer Add and Subtract:** + - `vadd` + - `vsub` + - `vrsub` +- **Vector Integer Extension:** + - `vzext` + - `vsext` +- **Vector Integer Add-with-Carry / Subtract-with-Borrow:** + - `vadc` + - `vmadc` + - `vsbc` + - `vmsbc` +- **Vector Bitwise Logical Instructions:** + - `vand` + - `vor` + - `vxor` +- **Vector Single-Width Shift Instructions:** + - `vsll` + - `vsrl` + - `vsra` +- **Vector Integer Compare Instructions:** + - `vmseq` + - `vmsne` + - `vmsltu` + - `vmslt` + - `vmsleu` + - `vmsle` + - `vmsgtu` + - `vmsgt` +- **Vector Integer Min/Max Instructions:** + - `vminu` + - `vmin` + - `vmaxu` + - `vmax` +- **Vector Single-Width Integer Multiply Instructions:** + - `vmulhu` + - `vmul` + - `vmulhsu` + - `vmulh` +- **Vector Integer Divide Instructions:** + - `vdivu` + - `vdiv` + - `vremu` + - `vrem` +- **Vector Single-Width Integer Multiply-Add Instructions:** + - `vmacc` + - `vnmsac` + - `vmadd` +- **Vector Narrowing Instructions:** + - `vnmsub` + - `vnsrl` + - `vnsra` +- **Vector Widening Instructions:** + - `vwaddu` + - `vwadd` + - `vwsubu` + - `vwsub` + - `vwmulu` + - `vwmulsu` + - `vwmul` + - `vwmaccu` + - `vwmacc` + + + +### Vector Fixed-Point Arithmetic Instructions + +- **Vector Single-Width Saturating Add and Subtract:** + - `vsadd` + - `vssubu` + - `vssub` + - `vsaddu` +- **Vector Single-Width Averaging Add and Subtract:** + - `vaaddu` + - `vaadd` + - `vasubu` + - `vasub` +- **Vector Single-Width Fractional Multiply with Rounding and Saturation:** + - `vsmul` + +- **Vector Single-Width Scaling Shift Instructions:** + - `vssrl` + - `vssra` + +- **Vector Narrowing Fixed-Point Clip Instructions:** + - `vnclipu` + - `vnclip` + + + +### Vector Floating-Point Instructions + +- **Vector Single-Width Floating-Point Add/Subtract:** + - `vfadd` + - `vfsub` + - `vfrsub` +- **Vector Single-Width Floating-Point Multiply/Divide:** + - `vfdiv` + - `vfrdiv` + - `vfmul` +- **Vector Single-Width Floating-Point Fused Multiply-Add:** + - `vfmadd` + - `vfnmadd` + - `vfmsub` + - `vfnmsub` + - `vfmacc` + - `vfnmacc` + - `vfmsac` + - `vfnmsac` +- **Vector Floating-Point Square-Root:** + - `vfsqrt` +- **Vector Floating-Point Reciprocal Square-Root Estimate:** + - `vfrsqrt7` +- **Vector Floating-Point Reciprocal Estimate:** + - `vfrec7` +- **Vector Floating-Point Min/Max:** + - `vfmin` + - `vfmax` +- **Vector Floating-Point Sign-Injection:** + - `vfsgnj` + - `vfsgnjn` + - `vfsgnjx` +- **Vector Floating-Point Compare Instructions:** + - `vmfeq` + - `vmfle` + - `vmflt` + - `vmfne` + - `vmfgt` + - `vmfge` +- **Vector Floating-Point Classify:** + - `vfclass` +- **Vector Widening Instructions:** + - `vfwadd` + - `vfwsub` + - `vfwmul` + - `vfwmacc` + - `vfwnmacc` +- **Single-Width Floating-Point/Integer Type-Convert:** + - `vfcvt` +- **Widening Floating-Point/Integer Type-Convert:** + - `vfwcvt` +- **Narrowing Floating-Point/Integer Type-Convert:** + - `vfncvt` + + + +## Vector Mask Instructions + +- `vector.step` vs RVV `vid` instructions : + + | | `vector.step` | `vid.v` | + | ---------------- | -------------------------- | -------------------------------- | + | **Function** | Generate a linear sequence | Generate a linear index sequence | + | **Range** | `[0, N-1]` | `[0, vl-1]` | + | **Mask Support** | Not supported | Supported | + + > In `LowerVectorStep.cpp`, `arith.constant` replaced the `vector.step` operation, so it's untested whether `vector.step` can generate `vid.v` instructions. + + + +### Unsupported Instructions + +- **Vector Element Index Instruction:** + - `vid` +- **Vector Mask-Register Logical Instructions:** + - `vmandnot` + - `vmand` + - `vmor` + - `vmxor` + - `vmornot` + - `vmnand` + - `vmnor` + - `vmxnor` +- **Vector Count Population in Mask:** + - `vcpop.m` +- **Find-First-Set Mask Bit:** + - `vfirst` +- **Set-Before-First Mask Bit:** + - `vmsbf` +- **Set-Including-First Mask Bit:** + - `vmsif` +- **Set-Only-First Mask Bit:** + - `vmsof` +- **Vector Iota Instruction:** + - `viota` + + + +## Vector Permutation Instructions + +- `vector.broadcast` supports broadcasting from scalar or lower-dimensional vectors to higher-dimensional vectors. `vector.splat` extends a scalar value to all elements of a result vector. Both share overlapping functionality, as they map to `vmv.v.x` and `vfmv.v.f` instructions. + + + +### Unsupported Instructions + +- **Scalar Move Instructions:** + - `vmv.x.s` + - `vfmv.f.s` + - `vmv.s.x` + - `vfmv.s.f` +- **Vector Slide Instructions:** + - `vslideup` + - `vslide1up` + - `vfslide1up` + - `vslidedown` + - `vslide1down` + - `vfslide1down` +- **Vector Register Gather Instructions:** + - `vrgather` + - `vrgatherei16` +- **Vector Compress Instruction:** + - `vcompress` +- **Whole Vector Register Move:** + - `vmvr` +- **Vector Move Instructions:** + - `vmerge` + - `vmv.v.v` + - `vfmerge` + - `vfmv.v.v` + + + +## Configuration-Setting Instructions + +The vector dialect doesn't have corresponding operations, so they cannot be directly generated. + +- `vsetvli` +- `vsetivli` +- `vsetvl` diff --git a/examples/MLIRVector/makefile b/examples/MLIRVector/makefile index 327dfbfaad..4ee5fc99df 100644 --- a/examples/MLIRVector/makefile +++ b/examples/MLIRVector/makefile @@ -92,6 +92,18 @@ vector-broadcast-asm-rv: -mattr=+m,+d,+v -riscv-v-vector-bits-min=128 \ --filetype=asm -o log.s +vector-broadcast-asm-rvv: + @${MLIR_OPT} ./vector-broadcast.mlir \ + --convert-vector-to-scf --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-broadcast-run vector-broadcast-run: @${MLIR_OPT} ./vector-broadcast.mlir \ @@ -133,6 +145,18 @@ vector-fma-asm-rv: -mattr=+m,+d,+v -riscv-v-vector-bits-min=128 \ --filetype=asm -o log.s +vector-fma-asm-rvv: + @${MLIR_OPT} ./vector-fma.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-fma-run vector-fma-run: @${MLIR_OPT} ./vector-fma.mlir \ @@ -174,6 +198,18 @@ vector-long-asm-rv: -mattr=+m,+d,+v -riscv-v-vector-bits-min=128 \ --filetype=asm -o log.s +vector-long-asm-rvv: + @${MLIR_OPT} ./vector-long.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-long-run vector-long-run: @${MLIR_OPT} ./vector-long.mlir \ @@ -196,6 +232,18 @@ vector-transpose-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-transpose-asm-rvv: + @${MLIR_OPT} ./vector-transpose.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-transpose-run vector-transpose-run: @${MLIR_OPT} ./vector-transpose.mlir \ @@ -218,6 +266,18 @@ vector-shape-cast-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-shape-cast-asm-rvv: + @${MLIR_OPT} ./vector-shape-cast.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-shape-cast-run vector-shape-cast-run: @${MLIR_OPT} ./vector-shape-cast.mlir \ @@ -241,6 +301,18 @@ vector-type-cast-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-type-cast-asm-rvv: + @${MLIR_OPT} ./vector-type-cast.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-type-cast-run vector-type-cast-run: @${MLIR_OPT} ./vector-type-cast.mlir \ @@ -264,6 +336,18 @@ vector-bitcast-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-bitcast-asm-rvv: + @${MLIR_OPT} ./vector-bitcast.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-bitcast-run vector-bitcast-run: @${MLIR_OPT} ./vector-bitcast.mlir \ @@ -286,6 +370,18 @@ vector-shuffle-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-shuffle-asm-rvv: + @${MLIR_OPT} ./vector-shuffle.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-shuffle-run vector-shuffle-run: @${MLIR_OPT} ./vector-shuffle.mlir \ @@ -309,6 +405,18 @@ vector-splat-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-splat-asm-rvv: + @${MLIR_OPT} ./vector-splat.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-splat-run vector-splat-run: @${MLIR_OPT} ./vector-splat.mlir \ @@ -332,6 +440,18 @@ vector-insert-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-insert-asm-rvv: + @${MLIR_OPT} ./vector-insert.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-insert-run vector-insert-run: @${MLIR_OPT} ./vector-insert.mlir \ @@ -355,6 +475,18 @@ vector-reduction-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-reduction-asm-rvv: + @${MLIR_OPT} ./vector-reduction.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-reduction-run vector-reduction-run: @${MLIR_OPT} ./vector-reduction.mlir \ @@ -378,6 +510,18 @@ vector-outerproduct-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-outerproduct-asm-rvv: + @${MLIR_OPT} ./vector-outerproduct.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-outerproduct-run vector-outerproduct-run: @${MLIR_OPT} ./vector-outerproduct.mlir \ @@ -401,6 +545,18 @@ vector-create-mask-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-create-mask-asm-rvv: + @${MLIR_OPT} ./vector-create-mask.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-create-mask-run vector-create-mask-run: @${MLIR_OPT} ./vector-create-mask.mlir \ @@ -423,6 +579,18 @@ vector-extract-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-extract-asm-rvv: + @${MLIR_OPT} ./vector-extract.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-extract-run vector-extract-run: @${MLIR_OPT} ./vector-extract.mlir \ @@ -445,6 +613,18 @@ vector-maskedload-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-maskedload-asm-rvv: + @${MLIR_OPT} ./vector-maskedload.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-maskedload-run vector-maskedload-run: @${MLIR_OPT} ./vector-maskedload.mlir \ @@ -468,6 +648,18 @@ vector-maskedstore-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-maskedstore-asm-rvv: + @${MLIR_OPT} ./vector-maskedstore.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-maskedstore-run vector-maskedstore-run: @${MLIR_OPT} ./vector-maskedstore.mlir \ @@ -491,6 +683,18 @@ vector-extract-strided-slice-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-extract-strided-slice-asm-rvv: + @${MLIR_OPT} ./vector-maskedstore.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-extract-strided-slice-run vector-extract-strided-slice-run: @${MLIR_OPT} ./vector-extract-strided-slice.mlir \ @@ -513,6 +717,18 @@ vector-constant-mask-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-constant-mask-asm-rvv: + @${MLIR_OPT} ./vector-constant-mask.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-constant-mask-run vector-constant-mask-run: @${MLIR_OPT} ./vector-constant-mask.mlir \ @@ -535,6 +751,18 @@ vector-expandload-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-expandload-asm-rvv: + @${MLIR_OPT} ./vector-expandload.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-expandload-run vector-expandload-run: @${MLIR_OPT} ./vector-expandload.mlir \ @@ -557,6 +785,18 @@ vector-compressstore-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-compressstore-asm-rvv: + @${MLIR_OPT} ./vector-compressstore.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-compressstore-run vector-compressstore-run: @${MLIR_OPT} ./vector-compressstore.mlir \ @@ -579,6 +819,18 @@ vector-insert-strided-slice-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-insert-strided-slice-asm-rvv: + @${MLIR_OPT} ./vector-insert-strided-slice.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-insert-strided-slice-run vector-insert-strided-slice-run: @${MLIR_OPT} ./vector-insert-strided-slice.mlir \ @@ -601,6 +853,18 @@ vector-scatter-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-scatter-asm-rvv: + @${MLIR_OPT} ./vector-scatter.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-scatter-run vector-scatter-run: @${MLIR_OPT} ./vector-scatter.mlir \ @@ -624,6 +888,18 @@ vector-gather-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-gather-asm-rvv: + @${MLIR_OPT} ./vector-gather.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-gather-run vector-gather-run: @${MLIR_OPT} ./vector-gather.mlir \ @@ -647,6 +923,18 @@ vector-transfer-read-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-transfer-read-asm-rvv: + @${MLIR_OPT} ./vector-transfer-read.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-transfer-read-run vector-transfer-read-run: @${MLIR_OPT} ./vector-transfer-read.mlir \ @@ -669,6 +957,18 @@ vector-transfer-write-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-transfer-write-asm-rvv: + @${MLIR_OPT} ./vector-transfer-write.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-load-run vector-transfer-write-run: @${MLIR_OPT} ./vector-transfer-write.mlir \ @@ -691,6 +991,18 @@ vector-contract-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-contract-asm-rvv: + @${MLIR_OPT} ./vector-contract.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O0 -S \ + -o log.s + run-targets += vector-load-run vector-contract-run: @${MLIR_OPT} ./vector-contract.mlir \ @@ -713,6 +1025,18 @@ vector-store-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-store-asm-rvv: + @${MLIR_OPT} ./vector-store.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-store-run vector-store-run: @${MLIR_OPT} ./vector-store.mlir \ @@ -737,6 +1061,18 @@ vector-iteration-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-iteration-asm-rvv: + @${MLIR_OPT} ./vector-iteration.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + vector-iteration-run: @${MLIR_OPT} ./vector-iteration.mlir \ --lower-affine \ From 277c4b8a16d8e12ee9cd81d03574d69599be2151 Mon Sep 17 00:00:00 2001 From: xlinsist Date: Tue, 25 Feb 2025 15:31:39 +0000 Subject: [PATCH 07/13] [frontend] Fix tokenizeDeepSeekR1 in TextContainer --- frontend/Interfaces/buddy/LLM/TextContainer.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/Interfaces/buddy/LLM/TextContainer.h b/frontend/Interfaces/buddy/LLM/TextContainer.h index c2870f4cd3..943bad043f 100644 --- a/frontend/Interfaces/buddy/LLM/TextContainer.h +++ b/frontend/Interfaces/buddy/LLM/TextContainer.h @@ -433,11 +433,11 @@ void Text::tokenizeDeepSeekR1(const std::string &vocab, size_t length) { const int userToken = 151644; const int assistantToken = 151645; const int thinkToken = 151648; - + + tokenCnt = 0; this->aligned[tokenCnt++] = bos; this->aligned[tokenCnt++] = bos; this->aligned[tokenCnt++] = userToken; - tokenCnt = 2; // Load Vocab loadVocab(vocab); From 1ec1dde371b9ef21b01ac0d469a6ebbb63b04052 Mon Sep 17 00:00:00 2001 From: HarryZ <122276274+hharryz@users.noreply.github.com> Date: Wed, 26 Feb 2025 12:45:47 +0800 Subject: [PATCH 08/13] [midend] Add Rotate4D to DIP dialect. (#368) * [feat] accelerate rotate2d, add rotate4d affine part. * [feat] finish rotate_4d(nhwc and nchw format) * [modify] add test, fix code format. * clean up code format using pre-commit * [fix] force inline in DIP header file to avoid link failure * [fix] format code with pre-commit --- examples/DIPDialect/CMakeLists.txt | 3 + examples/DIPDialect/rotation4D.cpp | 104 +++++ frontend/Interfaces/buddy/DIP/DIP.h | 109 +++++ frontend/Interfaces/buddy/DIP/ImgContainer.h | 69 +-- frontend/Interfaces/lib/DIP.mlir | 12 + midend/include/Dialect/DIP/DIPOps.td | 46 ++ midend/include/Utils/AffineTransformUtils.h | 22 +- midend/include/Utils/DIPUtils.h | 10 +- .../lib/Conversion/LowerDIP/LowerDIPPass.cpp | 97 +++-- midend/lib/Utils/AffineTransformUtils.cpp | 411 ++++++++++++++---- midend/lib/Utils/DIPUtils.cpp | 70 ++- tests/Dialect/DIP/rotate4D_lowering.mlir | 14 + tests/Dialect/DIP/rotate4D_roundtrip.mlir | 62 +++ 13 files changed, 878 insertions(+), 151 deletions(-) create mode 100644 examples/DIPDialect/rotation4D.cpp create mode 100644 tests/Dialect/DIP/rotate4D_lowering.mlir create mode 100644 tests/Dialect/DIP/rotate4D_roundtrip.mlir diff --git a/examples/DIPDialect/CMakeLists.txt b/examples/DIPDialect/CMakeLists.txt index 2f897ad633..7b2f075c9c 100644 --- a/examples/DIPDialect/CMakeLists.txt +++ b/examples/DIPDialect/CMakeLists.txt @@ -23,6 +23,9 @@ endif() add_executable(rotation2D rotation2D.cpp) target_link_libraries(rotation2D ${DIP_LIBS}) +add_executable(rotation4D rotation4D.cpp) +target_link_libraries(rotation4D ${DIP_LIBS}) + add_executable(resize2D resize2D.cpp) target_link_libraries(resize2D ${DIP_LIBS}) diff --git a/examples/DIPDialect/rotation4D.cpp b/examples/DIPDialect/rotation4D.cpp new file mode 100644 index 0000000000..074b7cf4c1 --- /dev/null +++ b/examples/DIPDialect/rotation4D.cpp @@ -0,0 +1,104 @@ +//====- rotation4D.cpp - Example of buddy-opt tool ===========================// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements a 4D rotation example with dip.rotate_4d operation. +// The dip.rotate_4d operation will be compiled into an object file with the +// buddy-opt tool. +// This file will be linked with the object file to generate the executable +// file. +// +//===----------------------------------------------------------------------===// + +#include "buddy/DIP/imgcodecs/loadsave.h" +#include +#include +#include +#include +#include + +using namespace std; + +bool testImplementation(int argc, char *argv[]) { + const int inputBatch = 1; + + // Read as color image in [HWC] format + Img input = dip::imread(argv[1], dip::IMGRD_COLOR); + const int inputHeight = input.getSizes()[0]; + const int inputWidth = input.getSizes()[1]; + const int inputChannels = input.getSizes()[2]; + const int inputStride = inputHeight * inputWidth * inputChannels; + + // Image format is [NHWC] + intptr_t inputSizes_NHWC[4] = {inputBatch, inputHeight, inputWidth, + inputChannels}; + Img inputImages_NHWC(inputSizes_NHWC); + + auto imagePtr = inputImages_NHWC.getData(); + memcpy(imagePtr, input.getData(), inputStride * sizeof(float)); + for (int i = 1; i < inputBatch; i++) { + Img input = dip::imread(argv[1], dip::IMGRD_COLOR); + memcpy(imagePtr + i * inputStride, input.getData(), + inputStride * sizeof(float)); + } + + MemRef output = dip::Rotate4D( + &inputImages_NHWC, 30, dip::ANGLE_TYPE::DEGREE, dip::IMAGE_FORMAT::NHWC); + + const int outoutHeight = output.getSizes()[1]; + const int outputWidth = output.getSizes()[2]; + intptr_t outputSizes_NHWC[4] = {inputBatch, outoutHeight, outputWidth, + inputChannels}; + const int outputStride = outoutHeight * outputWidth * inputChannels; + Img outputImages_NHWC(output.getData(), outputSizes_NHWC); + + for (int i = 0; i < inputBatch; i++) { + intptr_t imageSizes[3] = {outoutHeight, outputWidth, inputChannels}; + Img outputImage(outputImages_NHWC.getData() + i * outputStride, + imageSizes); + dip::imwrite(argv[2], outputImage); + } + + // Image Format is [NCHW] + // Rearrange memory layout + intptr_t inputSizes_NCHW[4] = {inputBatch, inputChannels, inputHeight, + inputWidth}; + Img inputImages_NCHW(inputSizes_NCHW); + dip::detail::Transpose(&inputImages_NCHW, &inputImages_NHWC, + {0, 3, 1, 2}); + output = dip::Rotate4D(&inputImages_NCHW, 30, dip::ANGLE_TYPE::DEGREE, + dip::IMAGE_FORMAT::NCHW); + + intptr_t outputSizes_NCHW[4] = {inputBatch, inputChannels, outoutHeight, + outputWidth}; + Img outputImages_NCHW(output.getData(), outputSizes_NCHW); + dip::detail::Transpose(&outputImages_NHWC, &outputImages_NCHW, + {0, 2, 3, 1}); + + for (int i = 0; i < inputBatch; i++) { + intptr_t imageSizes[3] = {outoutHeight, outputWidth, inputChannels}; + Img outputImage(outputImages_NHWC.getData() + i * outputStride, + imageSizes); + dip::imwrite(argv[2], outputImage); + } + + return 1; +} + +int main(int argc, char *argv[]) { + testImplementation(argc, argv); + + return 0; +} diff --git a/frontend/Interfaces/buddy/DIP/DIP.h b/frontend/Interfaces/buddy/DIP/DIP.h index 8598b61fc5..84ab596341 100644 --- a/frontend/Interfaces/buddy/DIP/DIP.h +++ b/frontend/Interfaces/buddy/DIP/DIP.h @@ -41,6 +41,9 @@ enum class INTERPOLATION_TYPE { BILINEAR_INTERPOLATION }; +// Available formats for 4D images. +enum class IMAGE_FORMAT { NHWC, NCHW }; + namespace detail { // Functions present inside dip::detail are not meant to be called by users // directly. @@ -66,6 +69,13 @@ void _mlir_ciface_corrfft_2d(MemRef *inputReal, void _mlir_ciface_rotate_2d(Img *input, float angleValue, MemRef *output); +// Declare the Rotate4D C interface. +void _mlir_ciface_rotate_4d_nhwc(Img *input, float angleValue, + MemRef *output); + +void _mlir_ciface_rotate_4d_nchw(Img *input, float angleValue, + MemRef *output); + // Declare the Resize2D C interface. void _mlir_ciface_resize_2d_nearest_neighbour_interpolation( Img *input, float horizontalScalingFactor, @@ -200,6 +210,59 @@ inline void padKernel(MemRef *kernel, unsigned int centerX, } } +template +void Transpose(MemRef *output, MemRef *input, + const std::vector &axes) { + std::vector inputDims(D); + for (int i = 0; i < D; ++i) { + inputDims[i] = input->getSizes()[i]; + } + + std::vector outputDims(D); + for (int i = 0; i < D; ++i) { + outputDims[i] = inputDims[axes[i]]; + } + + const T *inputData = input->getData(); + T *outputData = output->getData(); + + std::vector inputStrides(D); + inputStrides[D - 1] = 1; + for (int i = D - 2; i >= 0; --i) { + inputStrides[i] = inputStrides[i + 1] * inputDims[i + 1]; + } + + std::vector outputStrides(D); + outputStrides[D - 1] = 1; + for (int i = D - 2; i >= 0; --i) { + outputStrides[i] = outputStrides[i + 1] * outputDims[i + 1]; + } + + std::vector indices(D, 0); + std::vector outputIndices(D, 0); + + while (true) { + intptr_t inputIndex = 0; + intptr_t outputIndex = 0; + for (int i = 0; i < D; ++i) { + inputIndex += indices[i] * inputStrides[i]; + outputIndex += indices[axes[i]] * outputStrides[i]; + } + outputData[outputIndex] = inputData[inputIndex]; + + int i = D - 1; + while (i >= 0) { + indices[i]++; + if (indices[i] < inputDims[i]) + break; + indices[i] = 0; + i--; + } + if (i < 0) + break; + } +} + // Helper function for applying 2D resize operation on images. inline MemRef Resize2D_Impl(Img *input, INTERPOLATION_TYPE type, @@ -386,6 +449,52 @@ inline MemRef Rotate2D(Img *input, float angle, return output; } +inline MemRef Rotate4D(Img *input, float angle, + ANGLE_TYPE angleType, IMAGE_FORMAT format) { + float angleRad; + + if (angleType == ANGLE_TYPE::DEGREE) + angleRad = M_PI * angle / 180; + else + angleRad = angle; + + float sinAngle = std::sin(angleRad); + float cosAngle = std::cos(angleRad); + + int outputRows, outputCols; + intptr_t sizesOutput[4]; + if (format == IMAGE_FORMAT::NHWC) { + outputRows = std::round(std::abs(input->getSizes()[1] * cosAngle) + + std::abs(input->getSizes()[2] * sinAngle)); + outputCols = std::round(std::abs(input->getSizes()[2] * cosAngle) + + std::abs(input->getSizes()[1] * sinAngle)); + sizesOutput[0] = input->getSizes()[0]; + sizesOutput[1] = outputRows; + sizesOutput[2] = outputCols; + sizesOutput[3] = input->getSizes()[3]; + } else { + // format == IMAGE_FORMAT::NCHW + outputRows = std::round(std::abs(input->getSizes()[2] * cosAngle) + + std::abs(input->getSizes()[3] * sinAngle)); + outputCols = std::round(std::abs(input->getSizes()[3] * cosAngle) + + std::abs(input->getSizes()[2] * sinAngle)); + sizesOutput[0] = input->getSizes()[0]; + sizesOutput[1] = input->getSizes()[1]; + sizesOutput[2] = outputRows; + sizesOutput[3] = outputCols; + } + + MemRef output(sizesOutput); + if (format == IMAGE_FORMAT::NHWC) { + detail::_mlir_ciface_rotate_4d_nhwc(input, angleRad, &output); + } else { + // format == FORMAT_4D_IMAGE::NCHW + detail::_mlir_ciface_rotate_4d_nchw(input, angleRad, &output); + } + + return output; +} + // User interface for 2D Resize. inline MemRef Resize2D(Img *input, INTERPOLATION_TYPE type, std::vector size) { diff --git a/frontend/Interfaces/buddy/DIP/ImgContainer.h b/frontend/Interfaces/buddy/DIP/ImgContainer.h index 82a6ca5ad0..45df0e2b66 100644 --- a/frontend/Interfaces/buddy/DIP/ImgContainer.h +++ b/frontend/Interfaces/buddy/DIP/ImgContainer.h @@ -22,10 +22,10 @@ #define FRONTEND_INTERFACES_BUDDY_DIP_IMGCONTAINER #include "buddy/Core/Container.h" +#include #include #include #include -#include #ifdef BUDDY_ENABLE_PNG #include #endif @@ -53,7 +53,7 @@ struct PaletteBlock { }; // file bmp image palette -void FillPalette(PaletteBlock *palette, int bpp, bool negative = false) { +inline void FillPalette(PaletteBlock *palette, int bpp, bool negative = false) { int i, length = 1 << bpp; int xor_mask = negative ? 255 : 0; @@ -128,7 +128,7 @@ template class Image : public MemRef { }; template -Image::Image(T *data, intptr_t sizes[N]): MemRef(data, sizes) {} +Image::Image(T *data, intptr_t sizes[N]) : MemRef(data, sizes) {} // Image Container Constructor // Constructs an image container object from the image file path. @@ -650,42 +650,42 @@ bool Image::decodePNG(const std::vector &fileData) { } #endif -template -int findFormat(const std::string &_ext){ +template int findFormat(const std::string &_ext) { if (_ext.size() <= 1) return 0; const char *ext = strrchr(_ext.c_str(), '.'); if (!ext) return 0; - + if (strcmp(ext, ".bmp") == 0) { return 1; } else { - std::cerr << "Unsupported to generate" << ext << "format image" << std::endl; + std::cerr << "Unsupported to generate" << ext << "format image" + << std::endl; return 0; } } template -static void imageWrite(const std::string &filename, Image &image){ +static void imageWrite(const std::string &filename, Image &image) { int imformat = findFormat(filename); switch (imformat) { - case 1: - BMPEncode(filename, image); - break; - default: - return; + case 1: + BMPEncode(filename, image); + break; + default: + return; } return; } template -void BMPEncode(const std::string &filename, Image &image){ +void BMPEncode(const std::string &filename, Image &image) { std::ofstream bmp(filename, std::ios::binary); if (!bmp) { - std::cerr << "Failed to create file" << std::endl; - return; + std::cerr << "Failed to create file" << std::endl; + return; } int width = image.getSizes()[3]; int height = image.getSizes()[2]; @@ -705,28 +705,28 @@ void BMPEncode(const std::string &filename, Image &image){ // Write file header bmp.write("BM", 2); int fileSizeInt = validToInt(fileSize); - bmp.write(reinterpret_cast(&fileSizeInt), 4); - bmp.write(reinterpret_cast(&zero), 4); - bmp.write(reinterpret_cast(&headerSize), 4); + bmp.write(reinterpret_cast(&fileSizeInt), 4); + bmp.write(reinterpret_cast(&zero), 4); + bmp.write(reinterpret_cast(&headerSize), 4); // Write bitmap header - bmp.write(reinterpret_cast(&bitmapHeaderSize), 4); - bmp.write(reinterpret_cast(&width), 4); - bmp.write(reinterpret_cast(&height), 4); - bmp.write(reinterpret_cast(&one), 2); + bmp.write(reinterpret_cast(&bitmapHeaderSize), 4); + bmp.write(reinterpret_cast(&width), 4); + bmp.write(reinterpret_cast(&height), 4); + bmp.write(reinterpret_cast(&one), 2); int bitDepth = channels << 3; - bmp.write(reinterpret_cast(&(bitDepth)), 2); - bmp.write(reinterpret_cast(&zero), 4); - bmp.write(reinterpret_cast(&zero), 4); - bmp.write(reinterpret_cast(&zero), 4); - bmp.write(reinterpret_cast(&zero), 4); - bmp.write(reinterpret_cast(&zero), 4); - bmp.write(reinterpret_cast(&zero), 4); + bmp.write(reinterpret_cast(&(bitDepth)), 2); + bmp.write(reinterpret_cast(&zero), 4); + bmp.write(reinterpret_cast(&zero), 4); + bmp.write(reinterpret_cast(&zero), 4); + bmp.write(reinterpret_cast(&zero), 4); + bmp.write(reinterpret_cast(&zero), 4); + bmp.write(reinterpret_cast(&zero), 4); // Write palette if (channels == 1) { FillPalette(palette, 8); - bmp.write(reinterpret_cast(&palette), sizeof(palette)); + bmp.write(reinterpret_cast(&palette), sizeof(palette)); } // Write image data @@ -735,14 +735,15 @@ void BMPEncode(const std::string &filename, Image &image){ for (int y = height - 1; y >= 0; y--) { for (int i = 0; i < width; i++) { for (int t = channels - 1; t >= 0; t--) { - unsigned char pixel= static_cast(data[y * width + i + t * step]); - bmp.write(reinterpret_cast(&pixel), 1); + unsigned char pixel = + static_cast(data[y * width + i + t * step]); + bmp.write(reinterpret_cast(&pixel), 1); } } if (fileStep > width * channels) bmp.write(zeropad, fileStep - width * channels); } - + bmp.close(); } diff --git a/frontend/Interfaces/lib/DIP.mlir b/frontend/Interfaces/lib/DIP.mlir index 3153d1ebe8..29b58cee18 100644 --- a/frontend/Interfaces/lib/DIP.mlir +++ b/frontend/Interfaces/lib/DIP.mlir @@ -42,6 +42,18 @@ func.func @rotate_2d(%inputImage : memref, %angle : f32, %outputImage : return } +func.func @rotate_4d_nhwc(%inputImage : memref, %angle : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +{ + dip.rotate_4d NHWC %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @rotate_4d_nchw(%inputImage : memref, %angle : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +{ + dip.rotate_4d NCHW %inputImage, %angle, %outputImage : memref, f32, memref + return +} + func.func @resize_2d_nearest_neighbour_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} { dip.resize_2d NEAREST_NEIGHBOUR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref diff --git a/midend/include/Dialect/DIP/DIPOps.td b/midend/include/Dialect/DIP/DIPOps.td index b5f928b888..090830780d 100644 --- a/midend/include/Dialect/DIP/DIPOps.td +++ b/midend/include/Dialect/DIP/DIPOps.td @@ -55,11 +55,24 @@ def DIP_InterpolationType : I32EnumAttr<"InterpolationType", let cppNamespace = "::buddy::dip"; } +def DIP_ImageFormat : I32EnumAttr<"ImageFormat", + "Specifies the format of image.", + [ + I32EnumAttrCase<"NHWC", 0, "NHWC">, + I32EnumAttrCase<"NCHW", 1, "NCHW">, + I32EnumAttrCase<"HW", 2, "HW">, + ]>{ + let genSpecializedAttr = 0; + let cppNamespace = "::buddy::dip"; +} + def DIP_BoundaryOptionAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } def DIP_InterpolationAttr : EnumAttr; +def DIP_ImageFormatAttr : EnumAttr; + def DIP_Corr2DOp : DIP_Op<"corr_2d"> { let summary = [{This operation is used for performing 2D correlation on an image. The 2D correlation API provided by the linalg dialect is more suited for @@ -169,6 +182,39 @@ def DIP_Rotate2DOp : DIP_Op<"rotate_2d"> { }]; } +def DIP_Rotate4DOp : DIP_Op<"rotate_4d"> { + let summary = [{This operation intends to provide utility for rotating 4D images via the DIP dialect. + Image rotation has many applications such as data augmentation, alignment adjustment, etc. and + can thus be used in native MLIR pipelines involving above mentioned uses. + + Standard affine rotation matrix [cosθ sinθ] is used whenever |tan(θ/2)| > 8.1 for the user defined θ. + [-sinθ cosθ] + In all other cases, the 3 shear method is used to simplify calculations and minimize generated + artifacts. The equivalent 3 shear matrix combination is as follows : + [1 -tan(θ/2)] * [ 1 0] * [1 -tan(θ/2)] + [0 1 ] [sinθ 1] [0 1 ] + + For example: + + ```mlir + dip.rotate_4d IMAGE_FORMAT %inputImage, %angle, %outputImage : memref, f32, memref + ``` + + where ```IMAGE_FORMAT``` can be ```NHWC``` or ```NCHW```. + }]; + + let arguments = (ins Arg:$memrefI, + F32 : $angle, + Arg:$memrefO, + DIP_ImageFormatAttr:$image_format); + + let assemblyFormat = [{ + $image_format $memrefI `,` $angle `,` $memrefO attr-dict `:` type($memrefI) `,` type($angle) `,` type($memrefO) + }]; +} + def DIP_Resize2DOp : DIP_Op<"resize_2d"> { let summary = [{ diff --git a/midend/include/Utils/AffineTransformUtils.h b/midend/include/Utils/AffineTransformUtils.h index 08f0602d39..85561ca125 100644 --- a/midend/include/Utils/AffineTransformUtils.h +++ b/midend/include/Utils/AffineTransformUtils.h @@ -26,16 +26,22 @@ using namespace mlir; namespace buddy { // Given x*m0+m2(and x*m3+m5) and m1(and m4), compute new x and y, then remap // origin pixels to new pixels -void affineTransformCore(OpBuilder &builder, Location loc, Value input, - Value output, Value yStart, Value yEnd, Value xStart, - Value xEnd, Value m1, Value m4, Value xAddr1, - Value xAddr2, int64_t stride, const int &RSV_BITS, - int interp_type); +void affineTransformCore(OpBuilder &builder, Location loc, MLIRContext *ctx, + Value input, Value output, Value yStart, Value yEnd, + Value xStart, Value xEnd, Value m1, Value m4, + Value xAddr1, Value xAddr2, int64_t stride, + const int &RSV_BITS, int interp_type, + dip::ImageFormat format); // remap using nearest neighbor interpolation -void remapNearest(OpBuilder &builder, Location loc, Value input, Value output, - Value mapInt, Value yStart, Value xStart, Value rows, - Value cols); +void remapNearest2D(OpBuilder &builder, Location loc, MLIRContext *ctx, + Value input, Value output, Value mapInt, Value yStart, + Value xStart, Value rows, Value cols); + +void remapNearest3D(OpBuilder &builder, Location loc, MLIRContext *ctx, + Value input, Value output, Value mapInt, Value yStart, + Value xStart, Value rows, Value cols, + dip::ImageFormat format, Value niv); // remap using bilinear interpolation void remapBilinear(OpBuilder &builder, Location loc, Value input, Value output, diff --git a/midend/include/Utils/DIPUtils.h b/midend/include/Utils/DIPUtils.h index a8b77e8f23..c5bd3104d1 100644 --- a/midend/include/Utils/DIPUtils.h +++ b/midend/include/Utils/DIPUtils.h @@ -115,6 +115,14 @@ void fillPixelsNearestNeighbour4D( // Calculate tan(angle / 2) where angle is a function parameter. Value customTanVal(OpBuilder &builder, Location loc, Value angleVal); +// Calculate the real affine matrix for rotation by +// getting the rotation matrix and modfiying it to +// preserve the full original image . +SmallVector calculateRotationMatrix(OpBuilder &builder, Location loc, + Value inputCol, Value inputRow, + Value outputCol, Value outputRow, + Value angleVal); + // Get affine matrix used in rotation. SmallVector getRotationMatrix(OpBuilder &builder, Location loc, Value centerX, Value centerY, @@ -124,7 +132,7 @@ SmallVector getRotationMatrix(OpBuilder &builder, Location loc, void affineTransformController(OpBuilder &builder, Location loc, MLIRContext *ctx, Value input, Value output, SmallVector affineMatrix, - int64_t stride); + int64_t stride, dip::ImageFormat format); // Controls shear transform application. void shearTransformController( diff --git a/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp b/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp index ebbf3ffa48..ef7fc9b9b0 100644 --- a/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp +++ b/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp @@ -191,8 +191,6 @@ class DIPRotate2DOpLowering : public OpRewritePattern { Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); - Value c1F32 = indexToF32(rewriter, loc, c1); - // Get input image dimensions. Value inputRow = rewriter.create(loc, input, c0); Value inputCol = rewriter.create(loc, input, c1); @@ -201,38 +199,84 @@ class DIPRotate2DOpLowering : public OpRewritePattern { Value outputRow = rewriter.create(loc, output, c0); Value outputCol = rewriter.create(loc, output, c1); - // let alpha = scale * cos(angle), beta = scale * sin(angle) - // the affine matrix would be as follow: - // [[alpha, beta, (1 - alpha) * centerx - beta * centery], - // [-beta, alpha, beta * centerx + (1 - alpha) * centery]] - Value centerX = rewriter.create(loc, inputCol, c1); - Value centerY = rewriter.create(loc, inputRow, c1); - Value centerXF32 = indexToF32(rewriter, loc, centerX); - Value centerYF32 = indexToF32(rewriter, loc, centerY); - - auto affineMatrix = dip::getRotationMatrix(rewriter, loc, centerXF32, - centerYF32, angleVal, c1F32); - - Value deltaXI = rewriter.create(loc, outputCol, inputCol); - Value deltaYI = rewriter.create(loc, outputRow, inputRow); - Value deltaXIDiv2 = rewriter.create(loc, deltaXI, c1); - Value deltaYIDiv2 = rewriter.create(loc, deltaYI, c1); - Value deltaXFDiv2 = indexToF32(rewriter, loc, deltaXIDiv2); - Value deltaYFDiv2 = indexToF32(rewriter, loc, deltaYIDiv2); - - affineMatrix[2] = - rewriter.create(loc, affineMatrix[2], deltaXFDiv2); - affineMatrix[5] = - rewriter.create(loc, affineMatrix[5], deltaYFDiv2); + auto rotationMatrix = dip::calculateRotationMatrix( + rewriter, loc, inputCol, inputRow, outputCol, outputRow, angleVal); + + dip::affineTransformController(rewriter, loc, ctx, input, output, + rotationMatrix, stride, + dip::ImageFormat::HW); + + // Remove the origin rotation operation. + rewriter.eraseOp(op); + return success(); + } + + int64_t stride; +}; + +class DIPRotate4DOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit DIPRotate4DOpLowering(MLIRContext *context, int64_t strideParam) + : OpRewritePattern(context) { + stride = strideParam; + } + + LogicalResult matchAndRewrite(dip::Rotate4DOp op, + PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + + // Register operand values. + Value input = op->getOperand(0); + Value angleValue = op->getOperand(1); + Value output = op->getOperand(2); + auto imageFormatAttr = op.getImageFormat(); + + auto inElemTy = input.getType().cast().getElementType(); + dip::DIP_ERROR error = + dip::checkDIPCommonTypes(op, {input, output}); + + if (error == dip::DIP_ERROR::INCONSISTENT_TYPES) { + return op->emitOpError() + << "input, and output must have the same element type"; + } else if (error == dip::DIP_ERROR::UNSUPPORTED_TYPE) { + return op->emitOpError() << "supports only f32, f64 and integer types. " + << inElemTy << "is passed"; + } + + // Create constant indices. + Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); + + // Get image dimensions. + Value inputRow, inputCol, outputRow, outputCol; + if (imageFormatAttr == dip::ImageFormat::NHWC) { + inputRow = rewriter.create(loc, input, c1); + inputCol = rewriter.create(loc, input, c2); + outputRow = rewriter.create(loc, output, c1); + outputCol = rewriter.create(loc, output, c2); + } else if (imageFormatAttr == dip::ImageFormat::NCHW) { + inputRow = rewriter.create(loc, input, c2); + inputCol = rewriter.create(loc, input, c3); + outputRow = rewriter.create(loc, output, c2); + outputCol = rewriter.create(loc, output, c3); + } + + auto rotationMatrix = dip::calculateRotationMatrix( + rewriter, loc, inputCol, inputRow, outputCol, outputRow, angleValue); dip::affineTransformController(rewriter, loc, ctx, input, output, - affineMatrix, stride); + rotationMatrix, stride, imageFormatAttr); // Remove the origin rotation operation. rewriter.eraseOp(op); return success(); } +private: int64_t stride; }; @@ -1589,6 +1633,7 @@ void populateLowerDIPConversionPatterns(RewritePatternSet &patterns, patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); + patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); diff --git a/midend/lib/Utils/AffineTransformUtils.cpp b/midend/lib/Utils/AffineTransformUtils.cpp index 9b3bfc2839..54f328c5a2 100644 --- a/midend/lib/Utils/AffineTransformUtils.cpp +++ b/midend/lib/Utils/AffineTransformUtils.cpp @@ -30,6 +30,8 @@ #include #include +#include "DIP/DIPDialect.h" +#include "DIP/DIPOps.h" #include "Utils/AffineTransformUtils.h" #include "Utils/Utils.h" @@ -38,21 +40,21 @@ using namespace mlir; namespace buddy { // compute core(tiled) void affineTransformCoreTiled(OpBuilder &builder, Location loc, - Value resIntPart, Value resFracPart, Value yStart, - Value yEnd, Value xStart, Value xEnd, Value m1, - Value m4, Value xAddr1, Value xAddr2, - Value rsvValVec, Value strideVal, Value c0, - Value c1, Value c_rsv, int64_t stride) { + Value resIntPart, Value yStart, Value yEnd, + Value xStart, Value xEnd, Value m1, Value m4, + Value xAddr1, Value xAddr2, Value rsvValVec, + Value strideVal, Value c0, Value c1, Value c_rsv, + int64_t stride) { VectorType vectorTyI32 = VectorType::get({stride}, IntegerType::get(builder.getContext(), 32)); VectorType vectorTyI16 = VectorType::get({stride}, IntegerType::get(builder.getContext(), 16)); - VectorType vectorTyI8 = - VectorType::get({stride}, IntegerType::get(builder.getContext(), 8)); + builder.create( loc, yStart, yEnd, c1, std::nullopt, [&](OpBuilder &yBuilder, Location yLoc, Value yiv, ValueRange) { Value yOffset = yBuilder.create(yLoc, yiv, yStart); + Value yF32 = indexToF32(yBuilder, yLoc, yiv); Value yF32_0 = yBuilder.create(yLoc, yF32, m1); Value yF32_1 = yBuilder.create(yLoc, yF32, m4); @@ -62,6 +64,7 @@ void affineTransformCoreTiled(OpBuilder &builder, Location loc, yF32_0_rsv); Value y1 = yBuilder.create(yLoc, yBuilder.getI32Type(), yF32_1_rsv); + Value y0Vec = yBuilder.create(yLoc, vectorTyI32, y0); Value y1Vec = yBuilder.create(yLoc, vectorTyI32, y1); @@ -73,34 +76,30 @@ void affineTransformCoreTiled(OpBuilder &builder, Location loc, xAddr1, xiv); Value x1Vec = xBuilder.create(xLoc, vectorTyI32, xAddr2, xiv); + Value srcXVec = xBuilder.create(xLoc, x0Vec, y0Vec); Value srcYVec = xBuilder.create(xLoc, x1Vec, y1Vec); - Value srcXVecFrac = - xBuilder.create(loc, vectorTyI8, srcXVec); - Value srcYVecFrac = - xBuilder.create(loc, vectorTyI8, srcYVec); - xBuilder.create( - loc, srcXVecFrac, resFracPart, - ValueRange{c0, yOffset, xOffset}); - xBuilder.create( - loc, srcYVecFrac, resFracPart, - ValueRange{c1, yOffset, xOffset}); + Value srcXVecShifted = - xBuilder.create(loc, srcXVec, rsvValVec); + xBuilder.create(xLoc, srcXVec, rsvValVec); Value srcYVecShifted = - xBuilder.create(loc, srcYVec, rsvValVec); + xBuilder.create(xLoc, srcYVec, rsvValVec); Value srcXVecInt = xBuilder.create( - loc, vectorTyI16, srcXVecShifted); + xLoc, vectorTyI16, srcXVecShifted); Value srcYVecInt = xBuilder.create( - loc, vectorTyI16, srcYVecShifted); - xBuilder.create( - loc, srcXVecInt, resIntPart, - ValueRange{c0, yOffset, xOffset}); + xLoc, vectorTyI16, srcYVecShifted); + + SmallVector maskVec; + for (int i = 0; i < stride; i++) { + maskVec.push_back(i); + maskVec.push_back(i + stride); + } + Value res2Store = xBuilder.create( + loc, srcXVecInt, srcYVecInt, maskVec); xBuilder.create( - loc, srcYVecInt, resIntPart, - ValueRange{c1, yOffset, xOffset}); + loc, res2Store, resIntPart, ValueRange{yOffset, xOffset, c0}); xBuilder.create(xLoc); }); @@ -109,13 +108,12 @@ void affineTransformCoreTiled(OpBuilder &builder, Location loc, }); } -// Given x*m0+m2(and x*m3+m5) and m1(and m4), compute new x and y, then remap -// origin pixels to new pixels -void affineTransformCore(OpBuilder &builder, Location loc, Value input, - Value output, Value yStart, Value yEnd, Value xStart, - Value xEnd, Value m1, Value m4, Value xAddr1, - Value xAddr2, int64_t stride, const int &RSV_BITS, - int interp_type) { +void affineTransformCore(OpBuilder &builder, Location loc, MLIRContext *ctx, + Value input, Value output, Value yStart, Value yEnd, + Value xStart, Value xEnd, Value m1, Value m4, + Value xAddr1, Value xAddr2, int64_t stride, + const int &RSV_BITS, int interp_type, + dip::ImageFormat format) { Value c0 = builder.create(loc, 0); Value c1 = builder.create(loc, 1); Value c_rsv = builder.create( @@ -131,52 +129,92 @@ void affineTransformCore(OpBuilder &builder, Location loc, Value input, // TODO: auto config BLOCK_SZ by input type. float->32, uchar->64 #define BLOCK_SZ 32 MemRefType resIntPartType = - MemRefType::get({2, BLOCK_SZ / 2, BLOCK_SZ * 2}, + MemRefType::get({BLOCK_SZ / 2, BLOCK_SZ * 2, 2}, IntegerType::get(builder.getContext(), 16)); - MemRefType resFracPartType = - MemRefType::get({2, BLOCK_SZ / 2, BLOCK_SZ * 2}, - IntegerType::get(builder.getContext(), 8)); - Value resIntPart = builder.create(loc, resIntPartType); - Value resFracPart = builder.create(loc, resFracPartType); + Value resIntPart = builder.create(loc, resIntPartType); Value rowStride = builder.create(loc, BLOCK_SZ / 2); Value colStride = builder.create(loc, BLOCK_SZ * 2); #undef BLOCK_SZ - builder.create( - loc, yStart, yEnd, rowStride, std::nullopt, - [&](OpBuilder &yBuilder, Location yLoc, Value yiv, ValueRange) { - Value realYEnd = yBuilder.create( - yLoc, yEnd, yBuilder.create(yLoc, yiv, rowStride)); - Value rows = yBuilder.create(yLoc, realYEnd, yiv); - yBuilder.create( - yLoc, xStart, xEnd, colStride, std::nullopt, - [&](OpBuilder &xBuilder, Location xLoc, Value xiv, ValueRange) { - Value realXEnd = xBuilder.create( - xLoc, xEnd, - xBuilder.create(xLoc, xiv, colStride)); - Value cols = xBuilder.create(xLoc, realXEnd, xiv); - affineTransformCoreTiled(xBuilder, xLoc, resIntPart, resFracPart, - yiv, realYEnd, xiv, realXEnd, m1, m4, - xAddr1, xAddr2, rsvValVec, strideVal, c0, - c1, c_rsv, stride); - - // remap - remapNearest(xBuilder, xLoc, input, output, resIntPart, yiv, xiv, - rows, cols); + if (format == dip::ImageFormat::HW) { + builder.create( + loc, yStart, yEnd, rowStride, std::nullopt, + [&](OpBuilder &yBuilder, Location yLoc, Value yiv, ValueRange) { + Value realYEnd = yBuilder.create( + yLoc, yEnd, yBuilder.create(yLoc, yiv, rowStride)); + Value rows = yBuilder.create(yLoc, realYEnd, yiv); - xBuilder.create(xLoc); - }); - yBuilder.create(yLoc); - }); + yBuilder.create( + yLoc, xStart, xEnd, colStride, std::nullopt, + [&](OpBuilder &xBuilder, Location xLoc, Value xiv, ValueRange) { + Value realXEnd = xBuilder.create( + xLoc, xEnd, + xBuilder.create(xLoc, xiv, colStride)); + Value cols = + xBuilder.create(xLoc, realXEnd, xiv); + + affineTransformCoreTiled(xBuilder, xLoc, resIntPart, yiv, + realYEnd, xiv, realXEnd, m1, m4, + xAddr1, xAddr2, rsvValVec, strideVal, + c0, c1, c_rsv, stride); + // remap + remapNearest2D(xBuilder, xLoc, ctx, input, output, resIntPart, + yiv, xiv, rows, cols); + + xBuilder.create(xLoc); + }); + yBuilder.create(yLoc); + }); + + } else if (format == dip::ImageFormat::NCHW || + format == dip::ImageFormat::NHWC) { + Value inputBatch = builder.create(loc, input, c0); + builder.create( + loc, c0, inputBatch, c1, std::nullopt, + [&](OpBuilder &nBuilder, Location nLoc, Value niv, ValueRange) { + nBuilder.create( + loc, yStart, yEnd, rowStride, std::nullopt, + [&](OpBuilder &yBuilder, Location yLoc, Value yiv, ValueRange) { + Value realYEnd = yBuilder.create( + yLoc, yEnd, + yBuilder.create(yLoc, yiv, rowStride)); + Value rows = + yBuilder.create(yLoc, realYEnd, yiv); + + yBuilder.create( + yLoc, xStart, xEnd, colStride, std::nullopt, + [&](OpBuilder &xBuilder, Location xLoc, Value xiv, + ValueRange) { + Value realXEnd = xBuilder.create( + xLoc, xEnd, + xBuilder.create(xLoc, xiv, colStride)); + Value cols = + xBuilder.create(xLoc, realXEnd, xiv); + + affineTransformCoreTiled( + xBuilder, xLoc, resIntPart, yiv, realYEnd, xiv, + realXEnd, m1, m4, xAddr1, xAddr2, rsvValVec, + strideVal, c0, c1, c_rsv, stride); + // remap + remapNearest3D(xBuilder, xLoc, ctx, input, output, + resIntPart, yiv, xiv, rows, cols, format, + niv); + + xBuilder.create(xLoc); + }); + yBuilder.create(yLoc); + }); + nBuilder.create(nLoc); + }); + } builder.create(loc, resIntPart); - builder.create(loc, resFracPart); } -void remapNearest(OpBuilder &builder, Location loc, Value input, Value output, - Value mapInt, Value yStart, Value xStart, Value rows, - Value cols) { +void remapNearest2D(OpBuilder &builder, Location loc, MLIRContext *ctx, + Value input, Value output, Value mapInt, Value yStart, + Value xStart, Value rows, Value cols) { Value c0 = builder.create(loc, 0); Value c1 = builder.create(loc, 1); Value inputRow = builder.create(loc, input, c0); @@ -185,14 +223,16 @@ void remapNearest(OpBuilder &builder, Location loc, Value input, Value output, loc, c0, rows, c1, std::nullopt, [&](OpBuilder &yBuilder, Location yLoc, Value yiv, ValueRange) { Value dstY = yBuilder.create(yLoc, yiv, yStart); + yBuilder.create( yLoc, c0, cols, c1, std::nullopt, [&](OpBuilder &xBuilder, Location xLoc, Value xiv, ValueRange) { Value dstX = xBuilder.create(xLoc, xiv, xStart); Value srcXI16 = xBuilder.create( - xLoc, mapInt, ValueRange{c0, yiv, xiv}); + xLoc, mapInt, ValueRange{yiv, xiv, c0}); Value srcYI16 = xBuilder.create( - xLoc, mapInt, ValueRange{c1, yiv, xiv}); + xLoc, mapInt, ValueRange{yiv, xiv, c1}); + Value srcX = xBuilder.create( xLoc, IndexType::get(xBuilder.getContext()), srcXI16); Value srcY = xBuilder.create( @@ -203,12 +243,21 @@ void remapNearest(OpBuilder &builder, Location loc, Value input, Value output, xBuilder.create(xLoc, xInBound, yInBound); xBuilder.create( xLoc, pixelInBound, - [&](OpBuilder &ifBuilder, Location ifLoc) { - Value pixel = ifBuilder.create( - ifLoc, input, ValueRange{srcY, srcX}); - ifBuilder.create(ifLoc, pixel, output, - ValueRange{dstY, dstX}); - ifBuilder.create(ifLoc); + [&](OpBuilder &thenBuilder, Location thenLoc) { + Value pixel = thenBuilder.create( + thenLoc, input, ValueRange{srcY, srcX}); + thenBuilder.create(thenLoc, pixel, output, + ValueRange{dstY, dstX}); + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + auto inElemTy = + input.getType().cast().getElementType(); + Value pixel = insertZeroConstantOp(ctx, elseBuilder, + elseLoc, inElemTy); + elseBuilder.create(elseLoc, pixel, output, + ValueRange{dstY, dstX}); + elseBuilder.create(elseLoc); }); xBuilder.create(xLoc); @@ -217,4 +266,214 @@ void remapNearest(OpBuilder &builder, Location loc, Value input, Value output, yBuilder.create(yLoc); }); } + +void remapNearest3D(OpBuilder &builder, Location loc, MLIRContext *ctx, + Value input, Value output, Value mapInt, Value yStart, + Value xStart, Value rows, Value cols, + dip::ImageFormat format, Value niv) { + Value c0 = builder.create(loc, 0); + Value c1 = builder.create(loc, 1); + Value c2 = builder.create(loc, 2); + Value c3 = builder.create(loc, 3); + + Value inputRow, inputCol, inputChannel; + if (format == dip::ImageFormat::NHWC) { + inputRow = builder.create(loc, input, c1); + inputCol = builder.create(loc, input, c2); + inputChannel = builder.create(loc, input, c3); + } else if (format == dip::ImageFormat::NCHW) { + inputRow = builder.create(loc, input, c2); + inputCol = builder.create(loc, input, c3); + inputChannel = builder.create(loc, input, c1); + } + Value is3Channel = builder.create( + loc, arith::CmpIPredicate::eq, inputChannel, c3); + + builder.create( + loc, c0, rows, c1, std::nullopt, + [&](OpBuilder &yBuilder, Location yLoc, Value yiv, ValueRange) { + Value dstY = yBuilder.create(yLoc, yiv, yStart); + + yBuilder.create( + yLoc, is3Channel, + [&](OpBuilder &thenBuilder, Location thenLoc) { + // 3 channels, make common case fast + thenBuilder.create( + thenLoc, c0, cols, c1, std::nullopt, + [&](OpBuilder &xBuilder, Location xLoc, Value xiv, + ValueRange) { + Value dstX = + xBuilder.create(xLoc, xiv, xStart); + Value srcXI16 = xBuilder.create( + xLoc, mapInt, ValueRange{yiv, xiv, c0}); + Value srcYI16 = xBuilder.create( + xLoc, mapInt, ValueRange{yiv, xiv, c1}); + + Value srcX = xBuilder.create( + xLoc, IndexType::get(xBuilder.getContext()), srcXI16); + Value srcY = xBuilder.create( + xLoc, IndexType::get(xBuilder.getContext()), srcYI16); + Value xInBound = + inBound(xBuilder, xLoc, srcX, c0, inputCol); + Value yInBound = + inBound(xBuilder, xLoc, srcY, c0, inputRow); + Value pixelInBound = xBuilder.create( + xLoc, xInBound, yInBound); + xBuilder.create( + xLoc, pixelInBound, + [&](OpBuilder &thenBuilder, Location thenLoc) { + if (format == dip::ImageFormat::NCHW) { + Value srcC0 = thenBuilder.create( + thenLoc, input, + ValueRange{niv, c0, srcY, srcX}); + Value srcC1 = thenBuilder.create( + thenLoc, input, + ValueRange{niv, c1, srcY, srcX}); + Value srcC2 = thenBuilder.create( + thenLoc, input, + ValueRange{niv, c2, srcY, srcX}); + thenBuilder.create( + thenLoc, srcC0, output, + ValueRange{niv, c0, dstY, dstX}); + thenBuilder.create( + thenLoc, srcC1, output, + ValueRange{niv, c1, dstY, dstX}); + thenBuilder.create( + thenLoc, srcC2, output, + ValueRange{niv, c2, dstY, dstX}); + } else if (format == dip::ImageFormat::NHWC) { + Value srcC0 = thenBuilder.create( + thenLoc, input, + ValueRange{niv, srcY, srcX, c0}); + Value srcC1 = thenBuilder.create( + thenLoc, input, + ValueRange{niv, srcY, srcX, c1}); + Value srcC2 = thenBuilder.create( + thenLoc, input, + ValueRange{niv, srcY, srcX, c2}); + thenBuilder.create( + thenLoc, srcC0, output, + ValueRange{niv, dstY, dstX, c0}); + thenBuilder.create( + thenLoc, srcC1, output, + ValueRange{niv, dstY, dstX, c1}); + thenBuilder.create( + thenLoc, srcC2, output, + ValueRange{niv, dstY, dstX, c2}); + } + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + auto inElemTy = input.getType() + .cast() + .getElementType(); + Value pixel = insertZeroConstantOp(ctx, elseBuilder, + elseLoc, inElemTy); + if (format == dip::ImageFormat::NCHW) { + elseBuilder.create( + elseLoc, pixel, output, + ValueRange{niv, c0, dstY, dstX}); + elseBuilder.create( + elseLoc, pixel, output, + ValueRange{niv, c1, dstY, dstX}); + elseBuilder.create( + elseLoc, pixel, output, + ValueRange{niv, c2, dstY, dstX}); + } else if (format == dip::ImageFormat::NHWC) { + elseBuilder.create( + elseLoc, pixel, output, + ValueRange{niv, dstY, dstX, c0}); + elseBuilder.create( + elseLoc, pixel, output, + ValueRange{niv, dstY, dstX, c1}); + elseBuilder.create( + elseLoc, pixel, output, + ValueRange{niv, dstY, dstX, c2}); + } + elseBuilder.create(elseLoc); + }); + thenBuilder.create(thenLoc); + }); + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + elseBuilder.create( + elseLoc, c0, cols, c1, std::nullopt, + [&](OpBuilder &xBuilder, Location xLoc, Value xiv, + ValueRange) { + Value dstX = + xBuilder.create(xLoc, xiv, xStart); + Value srcXI16 = xBuilder.create( + xLoc, mapInt, ValueRange{yiv, xiv, c0}); + Value srcYI16 = xBuilder.create( + xLoc, mapInt, ValueRange{yiv, xiv, c1}); + + Value srcX = xBuilder.create( + xLoc, IndexType::get(xBuilder.getContext()), srcXI16); + Value srcY = xBuilder.create( + xLoc, IndexType::get(xBuilder.getContext()), srcYI16); + Value xInBound = + inBound(xBuilder, xLoc, srcX, c0, inputCol); + Value yInBound = + inBound(xBuilder, xLoc, srcY, c0, inputRow); + Value pixelInBound = xBuilder.create( + xLoc, xInBound, yInBound); + xBuilder.create( + xLoc, pixelInBound, + [&](OpBuilder &thenBuilder, Location thenLoc) { + thenBuilder.create( + thenLoc, c0, inputChannel, c1, std::nullopt, + [&](OpBuilder &cBuilder, Location cLoc, Value civ, + ValueRange) { + if (format == dip::ImageFormat::NCHW) { + Value srcC = cBuilder.create( + cLoc, input, + ValueRange{niv, civ, srcY, srcX}); + cBuilder.create( + cLoc, srcC, output, + ValueRange{niv, civ, dstY, dstX}); + } else if (format == dip::ImageFormat::NHWC) { + Value srcC = cBuilder.create( + cLoc, input, + ValueRange{niv, srcY, srcX, civ}); + cBuilder.create( + cLoc, srcC, output, + ValueRange{niv, dstY, dstX, civ}); + } + cBuilder.create(cLoc); + }); + thenBuilder.create(elseLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + auto inElemTy = input.getType() + .cast() + .getElementType(); + Value pixel = insertZeroConstantOp(ctx, elseBuilder, + elseLoc, inElemTy); + elseBuilder.create( + elseLoc, c0, inputChannel, c1, std::nullopt, + [&](OpBuilder &cBuilder, Location cLoc, Value civ, + ValueRange) { + if (format == dip::ImageFormat::NCHW) { + cBuilder.create( + cLoc, pixel, output, + ValueRange{niv, civ, dstY, dstX}); + } else if (format == dip::ImageFormat::NHWC) { + cBuilder.create( + cLoc, pixel, output, + ValueRange{niv, dstY, dstX, civ}); + } + cBuilder.create(cLoc); + }); + elseBuilder.create(elseLoc); + }); + xBuilder.create(xLoc); + }); + elseBuilder.create(elseLoc); + }); + + yBuilder.create(yLoc); + }); +} + } // namespace buddy diff --git a/midend/lib/Utils/DIPUtils.cpp b/midend/lib/Utils/DIPUtils.cpp index da41b65cd6..d68451bb51 100644 --- a/midend/lib/Utils/DIPUtils.cpp +++ b/midend/lib/Utils/DIPUtils.cpp @@ -50,6 +50,11 @@ checkDIPCommonTypes(dip::Corr2DOp, template DIP_ERROR checkDIPCommonTypes(dip::Rotate2DOp, const std::vector &args); + +template DIP_ERROR +checkDIPCommonTypes(dip::Rotate4DOp, + const std::vector &args); + template DIP_ERROR checkDIPCommonTypes(dip::Resize2DOp, const std::vector &args); @@ -116,6 +121,7 @@ DIP_ERROR checkDIPCommonTypes(DIPOP op, const std::vector &args) { return DIP_ERROR::UNSUPPORTED_TYPE; } } else if (op->getName().stripDialect() == "rotate_2d" || + op->getName().stripDialect() == "rotate_4d" || op->getName().stripDialect() == "resize_2d" || op->getName().stripDialect() == "resize_4d_nhwc" || op->getName().stripDialect() == "resize_4d_nchw") { @@ -459,6 +465,47 @@ Value customTanVal(OpBuilder &builder, Location loc, Value angleVal) { return builder.create(loc, sinVal, cosVal); } +// Calculate the real affine matrix for rotation by +// getting the rotation matrix and modfiying it to +// preserve the full original image . +SmallVector calculateRotationMatrix(OpBuilder &builder, Location loc, + Value inputCol, Value inputRow, + Value outputCol, Value outputRow, + Value angleVal) { + Value c1 = builder.create(loc, 1); + + // let alpha = scale * cos(angle), beta = scale * sin(angle) + // the affine matrix would be as follow: + // [[alpha, beta, (1 - alpha) * centerx - beta * centery], + // [-beta, alpha, beta * centerx + (1 - alpha) * centery]] + Value centerX = builder.create(loc, inputCol, c1); + Value centerY = builder.create(loc, inputRow, c1); + Value centerXF32 = indexToF32(builder, loc, centerX); + Value centerYF32 = indexToF32(builder, loc, centerY); + + // scaling ratio = 1. + Value scale = indexToF32(builder, loc, c1); + + auto affineMatrix = dip::getRotationMatrix(builder, loc, centerXF32, + centerYF32, angleVal, scale); + + // modify the affine matrix to preserve the full original + // image after rotation + Value deltaXI = builder.create(loc, outputCol, inputCol); + Value deltaYI = builder.create(loc, outputRow, inputRow); + Value deltaXIDiv2 = builder.create(loc, deltaXI, c1); + Value deltaYIDiv2 = builder.create(loc, deltaYI, c1); + Value deltaXFDiv2 = indexToF32(builder, loc, deltaXIDiv2); + Value deltaYFDiv2 = indexToF32(builder, loc, deltaYIDiv2); + + affineMatrix[2] = + builder.create(loc, affineMatrix[2], deltaXFDiv2); + affineMatrix[5] = + builder.create(loc, affineMatrix[5], deltaYFDiv2); + + return affineMatrix; +} + // Get affine matrix used in rotation. SmallVector getRotationMatrix(OpBuilder &builder, Location loc, Value centerX, Value centerY, @@ -544,7 +591,7 @@ inline void inverseAffineMatrix(OpBuilder &builder, Location loc, void affineTransformController(OpBuilder &builder, Location loc, MLIRContext *ctx, Value input, Value output, SmallVector affineMatrix, - int64_t stride) { + int64_t stride, dip::ImageFormat format) { VectorType vectorTyF32 = VectorType::get({stride}, FloatType::getF32(ctx)); VectorType vectorTyI32 = VectorType::get({stride}, IntegerType::get(ctx, 32)); @@ -562,8 +609,19 @@ void affineTransformController(OpBuilder &builder, Location loc, Value m5Vec = builder.create(loc, vectorTyF32, affineMatrix[5]); - Value outputRow = builder.create(loc, output, c0Index); - Value outputCol = builder.create(loc, output, c1Index); + // get the output image dimensions + int dimIndex = -1; + if (format == dip::ImageFormat::HW) { + dimIndex = 0; + } else if (format == dip::ImageFormat::NHWC) { + dimIndex = 1; + } else if (format == dip::ImageFormat::NCHW) { + dimIndex = 2; + } + Value rowIndex = builder.create(loc, dimIndex); + Value colIndex = builder.create(loc, dimIndex + 1); + Value outputRow = builder.create(loc, output, rowIndex); + Value outputCol = builder.create(loc, output, colIndex); Value strideVal = builder.create(loc, stride); Value outputColStrideRatio = @@ -630,9 +688,9 @@ void affineTransformController(OpBuilder &builder, Location loc, builderFor.create(locFor); }); - affineTransformCore(builder, loc, input, output, c0Index, outputRow, c0Index, - outputCol, affineMatrix[1], affineMatrix[4], xMm0, xMm3, - stride, RSV_BITS, 0); + affineTransformCore(builder, loc, ctx, input, output, c0Index, outputRow, + c0Index, outputCol, affineMatrix[1], affineMatrix[4], + xMm0, xMm3, stride, RSV_BITS, 0, format); builder.create(loc, xMm0); builder.create(loc, xMm3); diff --git a/tests/Dialect/DIP/rotate4D_lowering.mlir b/tests/Dialect/DIP/rotate4D_lowering.mlir new file mode 100644 index 0000000000..ad949ba90b --- /dev/null +++ b/tests/Dialect/DIP/rotate4D_lowering.mlir @@ -0,0 +1,14 @@ +// RUN: buddy-opt -verify-diagnostics %s | buddy-opt | FileCheck %s + + +func.func @buddy_rotate4d_nhwc_f32(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NHWC {{.*}} : memref, f32, memref + dip.rotate_4d NHWC %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nchw(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NCHW {{.*}} : memref, f32, memref + dip.rotate_4d NCHW %inputImage, %angle, %outputImage : memref, f32, memref + return +} diff --git a/tests/Dialect/DIP/rotate4D_roundtrip.mlir b/tests/Dialect/DIP/rotate4D_roundtrip.mlir new file mode 100644 index 0000000000..e35ae88243 --- /dev/null +++ b/tests/Dialect/DIP/rotate4D_roundtrip.mlir @@ -0,0 +1,62 @@ +// RUN: buddy-opt -verify-diagnostics %s | buddy-opt | FileCheck %s + + +func.func @buddy_rotate4d_nhwc_f32(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NHWC {{.*}} : memref, f32, memref + dip.rotate_4d NHWC %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nhwc_f64(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NHWC {{.*}} : memref, f32, memref + dip.rotate_4d NHWC %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nhwc_i8(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NHWC {{.*}} : memref, f32, memref + dip.rotate_4d NHWC %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nhwc_i32(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NHWC {{.*}} : memref, f32, memref + dip.rotate_4d NHWC %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nhwc_i64(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NHWC {{.*}} : memref, f32, memref + dip.rotate_4d NHWC %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nchw_f32(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NCHW {{.*}} : memref, f32, memref + dip.rotate_4d NCHW %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nchw_f64(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NCHW {{.*}} : memref, f32, memref + dip.rotate_4d NCHW %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nchw_i8(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NCHW {{.*}} : memref, f32, memref + dip.rotate_4d NCHW %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nchw_i32(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NCHW {{.*}} : memref, f32, memref + dip.rotate_4d NCHW %inputImage, %angle, %outputImage : memref, f32, memref + return +} + +func.func @buddy_rotate4d_nchw_i64(%inputImage : memref, %angle : f32, %outputImage : memref) -> () { + // CHECK: dip.rotate_4d NCHW {{.*}} : memref, f32, memref + dip.rotate_4d NCHW %inputImage, %angle, %outputImage : memref, f32, memref + return +} From a6542f26a0e26c3554e95393ced9ef2ad8d6ba7b Mon Sep 17 00:00:00 2001 From: Hongbin Zhang Date: Wed, 5 Mar 2025 17:39:26 +0800 Subject: [PATCH 09/13] [CI] Fix LLVM version bumping issue in CI pipeline. (#475) --- .github/workflows/TestBuild.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/TestBuild.yml b/.github/workflows/TestBuild.yml index 28395a06e0..a314db74bd 100644 --- a/.github/workflows/TestBuild.yml +++ b/.github/workflows/TestBuild.yml @@ -44,7 +44,9 @@ jobs: # 5. If the cache is not found, pull the LLVM submodule. - name: Checkout LLVM submodule - run: git submodule update --init --recursive llvm + run: | + rm -rf llvm + git submodule update --init --recursive llvm if: steps.cache-llvm-source.outputs.cache-hit != 'true' # 6. Cache the LLVM build directory. @@ -76,7 +78,7 @@ jobs: if: steps.cache-llvm-build-dir.outputs.cache-hit != 'true' # 8. Check the build process of buddy-mlir repository. - - name: Check budddy-mlir build + - name: Check buddy-mlir build run: | source ~/miniconda3/bin/activate buddy pip install -r requirements.txt From c36dafe8cfe937b9555bd0cbed991c5635adf96c Mon Sep 17 00:00:00 2001 From: Junyi Mei Date: Thu, 6 Mar 2025 00:01:26 +0800 Subject: [PATCH 10/13] [midend] Support scalable vector for matmul-transpos-b vectorization (#471) * [midend] Support scalable vector for matmul-transpos-b vectorization * [midend] Fix matmul-transpose-b scalable vector problem * [Test] Disable scalable in matmul-transpose-b vectorization testcase --- .../MatMulTransposeBVec.cpp | 93 +++++++++++++------ .../matmul-transpose-b-vectorization.mlir | 2 +- 2 files changed, 68 insertions(+), 27 deletions(-) diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp index 2f81414a33..91ef00d4d2 100644 --- a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp +++ b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp @@ -43,11 +43,12 @@ using namespace vector; namespace { class MatMulTransposeBVecPattern : public ConversionPattern { public: - explicit MatMulTransposeBVecPattern(MLIRContext *context, - int64_t vecSizeparam) + explicit MatMulTransposeBVecPattern(MLIRContext *context, int64_t vfParam, + bool scalableParam) : ConversionPattern(linalg::MatmulTransposeBOp::getOperationName(), 1, context) { - vecSize = vecSizeparam; + vf = vfParam; + scalable = scalableParam; } LogicalResult @@ -67,14 +68,18 @@ class MatMulTransposeBVecPattern : public ConversionPattern { // the element type for mask vector. IntegerType i1 = IntegerType::get(ctx, 1); - VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); - VectorType vectorMaskTy = VectorType::get({vecSize}, i1); + VectorType vectorTy = mlir::VectorType::get({vf}, eleTy, {scalable}); + VectorType vectorMaskTy = VectorType::get({vf}, i1, {scalable}); const Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); const Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); - const Value step = rewriter.create(loc, vecSize); + Value step = rewriter.create(loc, vf); + if (scalable) { + Value vscale = rewriter.create(loc); + step = rewriter.create(loc, step, vscale); + } const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); @@ -83,9 +88,20 @@ class MatMulTransposeBVecPattern : public ConversionPattern { const Value bRow = rewriter.create(loc, B, c0); const Value bCol = rewriter.create(loc, B, c1); + AffineMap vecTailMap; AffineExpr d0; bindDims(ctx, d0); - AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); + if (scalable) { + auto s0 = getAffineSymbolExpr(0, ctx); + vecTailMap = AffineMap::get(1, 1, {d0.ceilDiv(s0)}, ctx); + } else { + vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vf)}, ctx); + } + SmallVector innerUpperBoundOperands{bCol}; + if (scalable) { + innerUpperBoundOperands.push_back(step); + } + SmallVector lowerBounds(2, c0); SmallVector uperBounds{aRow, bRow}; SmallVector steps(2, 1); @@ -96,21 +112,34 @@ class MatMulTransposeBVecPattern : public ConversionPattern { // Create loop based on vector size. auto innerLoop = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), - ValueRange{bCol}, vecTailMap, 1, ValueRange{passthruVec}, + innerUpperBoundOperands, vecTailMap, 1, ValueRange{passthruVec}, [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange itrArgs) { Value acc = itrArgs[0]; - AffineExpr a, b, c; - bindDims(ctx, a, b, c); - AffineMap AVectorMap = AffineMap::get( - /*dimCount=*/3, /*symbolCount=*/0, {a, c * vecSize}, ctx); + AffineExpr a, b; + bindDims(ctx, a, b); + AffineMap AVectorMap; + if (scalable) { + auto s0 = getAffineSymbolExpr(0, ctx); + AVectorMap = AffineMap::get( + /*dimCount=*/2, /*symbolCount=*/1, {a, b * s0}, ctx); + } else { + AVectorMap = AffineMap::get( + /*dimCount=*/2, /*symbolCount=*/0, {a, b * vf}, ctx); + } // Check tail. - AffineExpr m, n, k; - bindDims(ctx, m, n, k); - AffineMap BVectorMap = AffineMap::get( - /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); - + AffineExpr m, k; + bindDims(ctx, m, k); + AffineMap BVectorMap; + if (scalable) { + auto s0 = getAffineSymbolExpr(0, ctx); + BVectorMap = AffineMap::get( + /*dimCount=*/2, /*symbolCount=*/1, {m, k * s0}, ctx); + } else { + BVectorMap = AffineMap::get( + /*dimCount=*/2, /*symbolCount=*/0, {m, k * vf}, ctx); + } // Calculate the tail. Value bColCur = builder.create(loc, iv, step); Value tailLen = @@ -121,12 +150,16 @@ class MatMulTransposeBVecPattern : public ConversionPattern { auto ifOp = builder.create( loc, tailFlag, [&](OpBuilder &builder, Location loc) { + SmallVector aVecMapOperands{ivs[0], iv}; + SmallVector bVecMapOperands{ivs[1], iv}; + if (scalable) { + aVecMapOperands.push_back(step); + bVecMapOperands.push_back(step); + } Value aVec = builder.create( - loc, vectorTy, A, AVectorMap, - ValueRange{ivs[0], ivs[1], iv}); + loc, vectorTy, A, AVectorMap, aVecMapOperands); Value bVec = builder.create( - loc, vectorTy, B, BVectorMap, - ValueRange{ivs[1], ivs[1], iv}); + loc, vectorTy, B, BVectorMap, bVecMapOperands); Value resvec = builder.create(loc, aVec, bVec); Value newAcc = @@ -172,7 +205,11 @@ class MatMulTransposeBVecPattern : public ConversionPattern { } private: - int64_t vecSize; + /// Vectorization factor. This is the vector length when not scalable, and + /// the minimum vector length when scalable. + int64_t vf; + /// If use scalable vector. + bool scalable; }; } // end anonymous namespace @@ -198,9 +235,13 @@ class MatMulTransposeBVecPass registry.insert(); } - Option vecSize{*this, "vec-size", - llvm::cl::desc("The size of vectorization"), - llvm::cl::init(32)}; + Option vf{*this, "vf", + llvm::cl::desc("Specify vectorization factor."), + llvm::cl::init(32)}; + Option scalable{ + *this, "scalable", + llvm::cl::desc("Specify whether the vectorization factor is scalable."), + llvm::cl::init(false)}; }; } // namespace @@ -216,7 +257,7 @@ void MatMulTransposeBVecPass::runOnOperation() { target.addLegalOp(); RewritePatternSet patterns(context); - patterns.add(context, vecSize); + patterns.add(context, vf, scalable); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/tests/Conversion/matmul-transpose-b-vectorization.mlir b/tests/Conversion/matmul-transpose-b-vectorization.mlir index 391c7ce84d..2201666176 100644 --- a/tests/Conversion/matmul-transpose-b-vectorization.mlir +++ b/tests/Conversion/matmul-transpose-b-vectorization.mlir @@ -1,5 +1,5 @@ // RUN: buddy-opt %s \ -// RUN: -matmul-transpose-b-vectorization="vec-size=64" \ +// RUN: -matmul-transpose-b-vectorization="vf=8 scalable=false" \ // RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ // RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ // RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ From dc45fbad5fbfebbc2db1dcbf7e88f7c62e9db2d6 Mon Sep 17 00:00:00 2001 From: Xinye_Zhu <62270271+R-Tars@users.noreply.github.com> Date: Thu, 6 Mar 2025 00:27:08 +0800 Subject: [PATCH 11/13] Update vocab.txt for DeepSeek-R1 (#468) --- .../buddy-deepseek-r1-main.cpp | 5 +- .../BuddyDeepSeekR1/import-deepseek-r1.py | 3 +- examples/BuddyDeepSeekR1/vocab.txt | 113740 +++++++-------- 3 files changed, 56873 insertions(+), 56875 deletions(-) diff --git a/examples/BuddyDeepSeekR1/buddy-deepseek-r1-main.cpp b/examples/BuddyDeepSeekR1/buddy-deepseek-r1-main.cpp index 975abcc374..7269cc584a 100644 --- a/examples/BuddyDeepSeekR1/buddy-deepseek-r1-main.cpp +++ b/examples/BuddyDeepSeekR1/buddy-deepseek-r1-main.cpp @@ -50,7 +50,7 @@ void getUserInput(std::string &inputStr) { void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } /// Print information for each iteration. -void printIterInfo2(size_t iterIdx, std::string str, double time) { +void printIterInfo(size_t iterIdx, std::string str, double time) { std::cout << "\033[32;1m[Iteration " << iterIdx << "] \033[0m"; std::cout << "Token: " << str << " | " << "Time: " << time << "s" << std::endl; @@ -169,14 +169,13 @@ int main() { int maxIndex = findMaxIndex(startPtr, endPtr); std::string tok = inputContainer.getStr(maxIndex); // Print the generated token and inference time. - printIterInfo2(i, tok, inferenceTime.count() / 1000); + printIterInfo(i, tok, inferenceTime.count() / 1000); // Stop if a <|end▁of▁sentence|> token is generated. if (maxIndex == 151643) { break; } // Append the generated token into the input and output container. - // inputContainer.appendTokenIdx(maxIndex); inputContainer.appendTokenIdx(maxIndex); attention_mask.getData()[MaxTokenLength - generateLen + i] = 1; outputContainer.appendTokenIdx(maxIndex); diff --git a/examples/BuddyDeepSeekR1/import-deepseek-r1.py b/examples/BuddyDeepSeekR1/import-deepseek-r1.py index a14d88ea90..b0b8673b79 100644 --- a/examples/BuddyDeepSeekR1/import-deepseek-r1.py +++ b/examples/BuddyDeepSeekR1/import-deepseek-r1.py @@ -51,8 +51,7 @@ if model_path is None: model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" -# Initialize the tokenizer and model from the specified model path. -tokenizer = AutoTokenizer.from_pretrained(model_path) +# Initialize the model from the specified model path. model = AutoModelForCausalLM.from_pretrained( model_path, torchscript=True ).eval() diff --git a/examples/BuddyDeepSeekR1/vocab.txt b/examples/BuddyDeepSeekR1/vocab.txt index 103cb497fe..33b9f91c23 100644 --- a/examples/BuddyDeepSeekR1/vocab.txt +++ b/examples/BuddyDeepSeekR1/vocab.txt @@ -317,7 +317,7 @@ am om );Ċ im -čĊ +Ċ Ġ( il // @@ -339,7 +339,7 @@ ol ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ th )Ċ -Ġ{Ċ +Ġ{ Ġg ig iv @@ -454,14 +454,14 @@ de op up get -Ġ}Ċ +Ġ} ile Ġan ata ore ri Ġpro -;čĊ +;Ċ ĉĉĉĉ ter ain @@ -525,7 +525,7 @@ ave ang Ġare Ġint -âĢĻ +’ _t ert ial @@ -553,7 +553,7 @@ ize ure Ġby ire -Ġ}ĊĊ +Ġ} .p Ġsh ice @@ -713,7 +713,7 @@ ON per ich Ġbut -ĠĊ +Ġ Ġ_ _m add @@ -733,8 +733,8 @@ ivate Ġim Ġconst .t -Ġ*/Ċ -);čĊ +Ġ*/ +);Ċ Ġvoid Ġset ĠSystem @@ -746,7 +746,7 @@ li ally set ep -âĢĻs +’s bo def ',Ċ @@ -852,7 +852,7 @@ ator View List ĉreturn -âĢĿ +” Ġpre Ġx clude @@ -869,7 +869,7 @@ LE OR Ġprivate tem -čĊčĊ +ĊĊ user Ġ) com @@ -961,7 +961,7 @@ utton ose Ġ!= ater -é +é reate oll pos @@ -970,7 +970,7 @@ ng AL using ames -Ġ{čĊ +Ġ{ ates ely Ġwork @@ -1034,7 +1034,7 @@ date Ġwere Ġfile Ġwould -ĠâĢľ +Ġ“ ven iss Ġour @@ -1064,7 +1064,7 @@ Res Ġcomm ise min -ĠĠĠĠĊ +ĠĠĠĠ #include ethod .P @@ -1151,7 +1151,7 @@ let DE red Ġfe -Ġ},Ċ +Ġ}, Ġ, (t Ġfirst @@ -1211,12 +1211,12 @@ ttp ': ics Ġunder -Ġ*Ċ +Ġ* .L ); ices Ġreg -)čĊ +)Ċ ĉpublic SS Ġthen @@ -1363,7 +1363,7 @@ ities uff play .add -ĠâĢĵ +Ġ– Ġwant Ġcomp ments @@ -1403,7 +1403,7 @@ _F AM ility eter -âĢĻt +’t ĊĊĊ ayout -------------------------------- @@ -1431,13 +1431,13 @@ Im Ġtry Ġnow rough ->čĊ +>Ċ ackage Ġhim ._ ify Ġbreak -Ġ);Ċ +Ġ); ren #define itt @@ -1454,7 +1454,7 @@ ocument cription Error -b -о +о ][ trans Ġpoint @@ -1474,9 +1474,9 @@ ql Ġshow User ased -Ġ{ĊĊ +Ġ{ Ġfind -а +а ED span enu @@ -1502,7 +1502,7 @@ from Ġbl label else -е +е Ġ(! ized (), @@ -1552,7 +1552,7 @@ lear html Index uthor -Ġ/**Ċ +Ġ/** Ġline Event _D @@ -1570,7 +1570,7 @@ ature ected ES ister -ĉĊ +ĉ Ġbefore ale other @@ -1623,7 +1623,7 @@ vert Ġdec lease oun -Ġ});Ċ +Ġ}); fr formation etail @@ -1645,7 +1645,7 @@ aint ĠAr mon til -();čĊ +();Ċ ): Set atter @@ -1664,7 +1664,7 @@ ender ational ware .log -{čĊ +{Ċ Ġusing _B Ġ:= @@ -1684,7 +1684,7 @@ amespace Ġà lob Ġparam -Ġ}čĊ +Ġ} Ġecho function ******************************** @@ -1725,7 +1725,7 @@ urs Ġprot ./ pre -Ġ)Ċ +Ġ) ma Ġsur Ġfound @@ -1769,9 +1769,9 @@ px vice _data ĠNULL -}čĊ +}Ċ idd -ãĢĤ +。 Ġmed org ider @@ -1787,12 +1787,12 @@ ants -c Ġopen Ġest -ĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠ Ġnext IM -ÑĤ +т OT -ó +ó Ġfollow content ĠĠĠĠĠĠĠĠĠĠĠĠ @@ -1800,7 +1800,7 @@ content HE ĠRes Ġhref -и +и Ġcar ypes image @@ -1869,7 +1869,7 @@ AB Ġmark rid ified -,čĊ +,Ċ yn press Ġgroup @@ -1924,7 +1924,7 @@ not her icon On -;čĊčĊ +;ĊĊ ivity mand .Windows @@ -1938,7 +1938,7 @@ raph leg assword ?ĊĊ -â̦ +… ook uck Ġmessage @@ -1951,13 +1951,13 @@ Get enter ground ene -á +á .length Node (i Class for -ĠâĢĶ +Ġ— ten oin Ġke @@ -1985,7 +1985,7 @@ head rg Ġproduct This -.âĢĿ +.” ĠBut loy Ġdouble @@ -2071,7 +2071,7 @@ OD Ġfield iven oto -âĢľ +“ col (x ght @@ -2106,7 +2106,7 @@ CE \" irt Ġwrit -н +н ĉm ftware ond @@ -2182,7 +2182,7 @@ aterial iled Ġput Qu -ÑĢ +р ung map ĉĉĉĉĉĉĉĉ @@ -2245,7 +2245,7 @@ format Ġgreat inter cale -Ñģ +с ron iving Ent @@ -2276,7 +2276,7 @@ _v UI ocation md -Ġ[Ċ +Ġ[ Ġ] sw Ġincre @@ -2291,7 +2291,7 @@ ways ĠĠĠĠĠĠĠĠĠĠĠĠĠĠ Ġpay plit -âĢĶ +— Ġcoun obj .php @@ -2301,9 +2301,9 @@ ething aster los lation -ĠĠĊ +ĠĠ Le -ä +ä ({ ready ĠNo @@ -2389,15 +2389,15 @@ anel sc .to Ġproject -ü +ü Ġelement Ġsuccess -ĉĉĊ +ĉĉ .sh ram ched ())Ċ -Ġ(Ċ +Ġ( Ġdate Ġtot _ST @@ -2417,14 +2417,14 @@ Act era cope .$ -,âĢĿ +,” Ġpop Ġfew Ġlen uid eters ules -ÃŃ +í source https Ġdem @@ -2527,7 +2527,7 @@ Instance Ġcustom location model -ĠčĊ +Ġ Ġsource Ġeas .out @@ -2551,7 +2551,7 @@ HT Ġoutput Ġemail .push -Ġ}čĊčĊ +Ġ} ination atrix Table @@ -2568,7 +2568,7 @@ field Ġrequired _R Ġgovern -}čĊčĊ +}ĊĊ lex ., ĠSet @@ -2595,7 +2595,7 @@ mod annel Ġnp ugg -Ġ/>Ċ +Ġ/> Ġcalled body Ġcho @@ -2603,7 +2603,7 @@ body _set ird Ġ>= -Ġ};Ċ +Ġ}; Ġoptions ĠGener Ġheight @@ -2749,7 +2749,7 @@ arge ../../ EL Ġvalues -Ġ})Ċ +Ġ}) pen No icro @@ -2758,7 +2758,7 @@ icro acy rec ()-> -ĉĠĠĠ +ĠĠĠ ")) Content _W @@ -2789,7 +2789,7 @@ ks .text atures Ġtotal -Ġ*/ĊĊ +Ġ*/ ope Ġstat UM @@ -2813,7 +2813,7 @@ irl Ġoverride Ġcompany Ġdone -");čĊ +");Ċ Ġgre .Re Ġbelie @@ -2832,10 +2832,10 @@ func uments -h Number -:čĊ +:Ċ ĠLog erver -Ġ),Ċ +Ġ), ament Ġobj inc @@ -2877,7 +2877,7 @@ options ĠMar (a Ġwithin -.âĢĿĊĊ +.”ĊĊ ODE _DE admin @@ -2954,7 +2954,7 @@ oud Ġmen AX ĠCopyright -ö +ö avig req Client @@ -3009,7 +3009,7 @@ output ĠOF _time Ġoffer -Ġ});ĊĊ +Ġ}); HER egin "" @@ -3036,13 +3036,13 @@ product ajor And Ġdisplay -л +л Ġtimes Ġfour Ġfar Ġpresent ĠNS -Ġ\Ċ +Ġ\ uest Ġbas echo @@ -3077,7 +3077,7 @@ omic Ġmean ips Ġaut -);čĊčĊ +);ĊĊ Ġuntil Ġmarket Ġarea @@ -3127,9 +3127,9 @@ UP OS iod ĠMon -âĢĻre +’re Ġlik -ç +ç ively .v imer @@ -3163,7 +3163,7 @@ env Ġsoftware Ġimp Ġwin -ón +ón Ġthing Trans ĠTHE @@ -3188,7 +3188,7 @@ unch isk Ġrights (M -Ġ"""Ċ +Ġ""" aper .model Ġpo @@ -3200,7 +3200,7 @@ artment ĠEd Ġseason Ġdest -ã +ã (h Ġpossible Ġsever @@ -3209,7 +3209,7 @@ artment Ġsent Ġenc Ġcommand -Ġ],Ċ +Ġ], _x Ġrecent olution @@ -3222,8 +3222,8 @@ vector INT bsite ĉp -.čĊ - +.Ċ + sl attern ĠClass @@ -3367,12 +3367,12 @@ ids aken Ġmeet Ġmom -ĠâĢĺ +Ġ‘ Ġ?> Ġden obile change -ĠĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠ ici na ĠForm @@ -3405,13 +3405,13 @@ annot itive Ġquestion ĠQu -ãĢĤĊĊ +。ĊĊ gle Ġword Ġprovide ĠReturn Ġresearch -ão +ão ustr Ġpublish chema @@ -3454,7 +3454,7 @@ ades _se ause Ġemploy -Ġ*/čĊ +Ġ*/ Ġfre Ġ'@ Ġcomplet @@ -3473,7 +3473,7 @@ eneric while ias BUG -Ġ);ĊĊ +Ġ); Ġrole Reg ĠColor @@ -3532,7 +3532,7 @@ ios ignment ULT Pr -";čĊ +";Ċ Ġunderstand uary Ġhappen @@ -3566,7 +3566,7 @@ New Ġdetails Ġprob ĠAND -()čĊ +()Ċ ilar Ġ${ rypt @@ -3608,7 +3608,7 @@ amed Ġfood Source (string -Ġ+Ċ +Ġ+ ites dr Ġmembers @@ -3632,7 +3632,7 @@ ros ugin fa Ġconnection -Ġ};ĊĊ +Ġ}; Ġbecome Mode Ġev @@ -3653,11 +3653,11 @@ float Child typ Ġcertain -ión +ión OUT Ġimpro iles -Ġ-->Ċ +Ġ--> ĠPart values oss @@ -3675,7 +3675,7 @@ login heet Default @" -ĉĠ +Ġ click (value ĠAb @@ -3775,15 +3775,15 @@ ctx Ġunsigned .Point ĠOne -ı +ı iple aid -Ñĥ +у Vector byte Ġwait -ĠÃł -Ã¥ +Ġà +å Ġtogether Ġthrows FO @@ -3835,7 +3835,7 @@ password ny Ġesc .write -ï¼Į +, What .H Ġhistory @@ -3868,7 +3868,7 @@ icult Ġtook Ġgames Ġ}} -Ġ?>Ċ +Ġ?> Ġproducts Is Ġbad @@ -3956,7 +3956,7 @@ styles ico Ġess .Control -Ġé +Ġé ball Ġlearn inding @@ -3980,7 +3980,7 @@ Base ________ Ġcomment INE -âĢĻve +’ve But ĠEl ĠUs @@ -4026,14 +4026,14 @@ amples ensive font stream -using +using .springframework server Ġbill ACK ilename Ġframe -Ġ=Ċ +Ġ= Edit adius Ġdraw @@ -4100,7 +4100,7 @@ _SIZE Ġwor Ġprintf rag -Âł + DD ĠVal Ġactiv @@ -4140,7 +4140,7 @@ _ch Ġdownload (T aved -âĢĵ +– Ġstudents Ġfig light @@ -4165,7 +4165,7 @@ clus _date Ġ/** Ġauth -Ġ[]Ċ +Ġ[] Ġperiod nown Ġvot @@ -4233,7 +4233,7 @@ ha Ġreq OST angular -Ñı +я Ġfive Ġdistributed Ġfriend @@ -4247,7 +4247,7 @@ dis zy Ġheader ĠCheck -âĢĻm +’m just holder ="čĊ +">Ċ .annot Ġcollection '. @@ -4453,10 +4453,10 @@ urrenc _test Ġentire Down -Ġ}ĊĊĊ +Ġ} (result ĠRead -è +è Mod Ġtrying "),Ċ @@ -4472,7 +4472,7 @@ _on Ġphys }/ Ġnamespace -ĉčĊ +ĉ acc Player ARE @@ -4550,12 +4550,12 @@ Widget Ġarticle rodu andid -Ñĭ +ы ĠCr ka (): lood -ĉĉĉĊ +ĉĉĉ Ġalmost Ġsell ervlet @@ -4591,8 +4591,8 @@ ores (res Ġreserved SP -Ġâ̦ -ÅĤ +Ġ… +ł Ġsignific Off ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ @@ -4698,7 +4698,7 @@ aves ĠMain _k eries -âĢĻll +’ll transform imestamp Pre @@ -4708,7 +4708,7 @@ stant Location _NAME Ġloss -ĠĊĊ +Ġ net Ġengine Block @@ -4721,7 +4721,7 @@ Block airs wner Ġlower -",čĊ +",Ċ ĠDem ufact Ġps @@ -4791,7 +4791,7 @@ istrib _class ello That -к +к pecially ĠPresident Ġcampaign @@ -4822,7 +4822,7 @@ fg ];ĊĊ Ġcallback ĠHttp -ÑĮ +ь long MS ATH @@ -4858,7 +4858,7 @@ Each (key SELECT Pos -));čĊ +));Ċ Ġrecomm Ġtraining ĠEnt @@ -5029,7 +5029,7 @@ sec aset Ġparameters Ġcrush -"čĊ +"Ċ ILITY igration Ġcout @@ -5061,7 +5061,7 @@ ara refix asc Reader -Ġп +Ġп astic (() Cl @@ -5120,10 +5120,10 @@ _length .Component ev .Ex -ï¼ļ +: "; ĠHigh -Ġ)ĊĊ +Ġ) ĠPoint oph Ġlines @@ -5131,7 +5131,7 @@ oph ")ĊĊ ox application -Ġ]Ċ +Ġ] ĊĊĊĊĊĊ Ġsoon ctions @@ -5154,7 +5154,7 @@ Function endar .index Ġfill -ÄĻ +ę Ġchoose how ĠAmerica @@ -5184,7 +5184,7 @@ Provider den ĠReturns Ġnote -ür +ür pm ideos Ġspecified @@ -5257,7 +5257,7 @@ ourse Ġaddition Ġvarious Ġreceive -ен +ен ĠHT Obj DF @@ -5371,7 +5371,7 @@ IES JSON IE iant -ãĢģ +、 _j ĠSept _map @@ -5395,7 +5395,7 @@ Changed Ġok Ġfeed IX -és +és ĠNews remove erry @@ -5407,7 +5407,7 @@ Current .content .Group ustral -ĠÑģ +Ġс }) Ġpopular Ġstre @@ -5432,11 +5432,11 @@ rs aries (D (get -â̦ĊĊ +…ĊĊ Ġrelated Ġvers Ġsil -Ġ"";Ċ +Ġ""; Ġcmd Ġtechnology .width @@ -5466,13 +5466,13 @@ common imation :@" chie -Ġ...ĊĊ +Ġ... river ĠMarch category fin Ġcourt -в +в Server Ġcontainer -st @@ -5496,7 +5496,7 @@ Oper outes Ġchannel Ġchanged -ê +ê Ġfinally _number Please @@ -5548,7 +5548,7 @@ II ĠWork ĠURL ĠUpdate -',čĊ +',Ċ Ġimmedi close ados @@ -5611,7 +5611,7 @@ RA Ġcontains Ġstack mar -Ġ{}Ċ +Ġ{} Ġundefined Ass ĠChina @@ -5651,7 +5651,7 @@ afe ĠMod Ġtab ano -ñ +ñ ipping -e Ġinsert @@ -5701,7 +5701,7 @@ ambda Ġmovie Ġsec Ġactivity -ا +ا Ġsql _all incip @@ -5719,7 +5719,7 @@ elcome ĠSy ĠCent ALSE -ación +ación EXT Ġlicense ĠLong @@ -5758,7 +5758,7 @@ _PRO QU åı antity -ÂŃ +­ words Ġreadonly Ġflex @@ -5781,8 +5781,8 @@ role Ġgoes MP white -):čĊ -))čĊ +):Ċ +))Ċ Ġreference Ġmis ĠProject @@ -5803,7 +5803,7 @@ Part Ġworth hib game -Ġв +Ġв acion ĠWhite (type @@ -5859,7 +5859,7 @@ gs Ġlogin atives ']);Ċ -Äħ +ą Ġill IA children @@ -5870,7 +5870,7 @@ DO Ġ"# ToString Ġnecessary -ĠĠĠĊ +ĠĠĠ cell Entry Ġ'# @@ -5901,11 +5901,11 @@ _input (in Strip ìĿ -ção +ção Ġevidence )); ĠBro -Ġ[];Ċ +Ġ[]; Ġou buf Script @@ -5926,7 +5926,7 @@ ober Ġlogger Ġrecently Ġreturned -ččĊ +Ċ )))Ċ itions Ġseek @@ -5949,7 +5949,7 @@ anning ĠLink ĠResponse Ġstri -ż +ż ĠDB æĹ android @@ -5958,13 +5958,13 @@ otion (@ .test ĊĊĊĊĊĊĊĊ -];čĊ +];Ċ Ġdirectly Ġ"% ris elta AIL -){čĊ +){Ċ mine ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ (k @@ -6018,7 +6018,7 @@ pragma ragment ĠNode ening -Ñĩ +ч Ġroute ĠSchool hi @@ -6058,7 +6058,7 @@ Access _arg ĠJanuary ĠDay -")čĊ +")Ċ uple document gorith @@ -6105,14 +6105,14 @@ Vert cout Ġenv _label -Ġ>Ċ +Ġ> run Ġscene (array device _title agon -]čĊ +]Ċ aby Ġbecame boolean @@ -6149,7 +6149,7 @@ NG XX Ġmiddle chor -ø +ø erval .Column reading @@ -6199,10 +6199,10 @@ rael rics np Ġcore -());čĊ +());Ċ Main Ġexpert -ĉĉčĊ +ĉĉ _en Ġ/> utter @@ -6226,7 +6226,7 @@ haps .pre cm Values -Ġ"Ċ +Ġ" column ivil Login @@ -6258,7 +6258,7 @@ ita {} ĉC Ġstuff -Ġ:Ċ +Ġ: ĠAR Task hidden @@ -6272,7 +6272,7 @@ Enter ĠTwitter ĠCounty scribe -Ġ=>Ċ +Ġ=> Ġhy fit Ġmilitary @@ -6311,7 +6311,7 @@ enge Ġcustomers Ġcast udget -ï¼ģ +! icens Ġdetermin Selected @@ -6358,9 +6358,9 @@ mo ito Ġanalysis Ġdeliver -ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ idx -Ãł +à ongo ĠEnglish čĊ +Ġ--> Ġrelief lap quer @@ -16036,7 +16036,7 @@ yo Ġmemor ĠED Ġjur -æį® +据 _TABLE Ġuuid Expr @@ -16061,7 +16061,7 @@ _conf (df Ġlocked Ġrising -ãĥ»ãĥ» +・・ ĠMs Ġscenes _EXT @@ -16072,8 +16072,8 @@ people ĠFun Ġbless ĠUpdated -ün -ĠĠĠĠĠĠĠĠĠĠĠĠčĊ +ün +ĠĠĠĠĠĠĠĠĠĠĠĠ pection Release .logger @@ -16093,7 +16093,7 @@ failed ĠINCLUDING Ġwriters {}Ċ -ÃŃt +ít _copy }: ĠBat @@ -16102,20 +16102,20 @@ eding placement ĠHost Sound -им +им Ġsought mid Ġsalary ogg -âĦ¢ +™ bul Ġwir validator _STAT .store ĠBattle -ın -Ġ-->ĊĊ +ın +Ġ--> Trump dot ĠCONT @@ -16155,11 +16155,11 @@ intage big ologist ennis -Ùĩ +ه Ġchicken -ĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠ çĽ -ãģ§ +で Ġpeak Ġdrinking Ġencode @@ -16184,7 +16184,7 @@ parameter Ġthousand Ġcoordinate -generated -íķĺ +하 generated Ġadmitted Ġpussy @@ -16204,7 +16204,7 @@ raction Blueprint Ġattrs Ġsmoke -ÐĴ +В .Equals FB ĠResources @@ -16216,11 +16216,11 @@ etype \", Ġsensitive Ġtall -?âĢĿĊĊ +?”ĊĊ Proxy iy _section -âĢĶâĢĶâĢĶâĢĶ +———— brid Ġcircuit atan @@ -16238,7 +16238,7 @@ cookie _div ĠUILabel vely -});čĊ +});Ċ _ENT #+#+ articles @@ -16337,7 +16337,7 @@ emen ASSWORD $s ĠCirc -ой +ой etric /P Ġepoch @@ -16358,7 +16358,7 @@ XY >The ĠAk Ġgrass -/*čĊ +/*Ċ (dis Ġguns Ġtb @@ -16446,7 +16446,7 @@ ISH Ġfeaturing Returns ĠKr -Ġ.Ċ +Ġ. Ġnam _cb Testing @@ -16474,7 +16474,7 @@ gal ĠCatholic ARGET cpu -çłģ +码 .scroll ISING ifestyle @@ -16509,14 +16509,14 @@ UPPORT deg rgb Ġstrongly -Ġ};čĊ +Ġ}; Ġ): Ġlect ursive ROL ĠWeight Ġentertainment -Ġ));Ċ +Ġ)); Ġgonna Ġbb .do @@ -16528,7 +16528,7 @@ earning Limit issions [v -ä¸į +不 irty Del Ġunderlying @@ -16558,7 +16558,7 @@ wan mount Ġmonitoring Ġtout -ëĬĶ +는 },{ ................................ =int @@ -16584,7 +16584,7 @@ enta Tools Ġbab Ġcareful -ãģĦ +い Ġcrucial Ġcalculated ĠSA @@ -16623,10 +16623,10 @@ lobals Ġproven .onCreate Ġalarm -Ġ§ +Ġ§ Ġcommonly icos -æĸ° +新 ĠStation }). ĠFilm @@ -16639,13 +16639,13 @@ Stats _property Ġages ('-- -Ġför +Ġför ĠProfessor Ġhydro Push Ġorganized Accept -ém +ém _cell Ġnb pb @@ -16683,7 +16683,7 @@ SB _contents iseconds verty -át +át Guid nom Ġconclusion @@ -16691,7 +16691,7 @@ nom Ġlovely Ġemit bec -ĉĉĉĉĠ +Ġ Ġintellect Ġbrew ecycle @@ -16712,7 +16712,7 @@ Cur .ar mock ĠAdministration -ãģ¾ +ま Ġelectron flate Ġlombok @@ -16742,11 +16742,11 @@ alling roduction ĠTransport ĠNOTE -æĸĩ +文 Ġfewer _TIM ì§ -ки +ки Age FIN ĠìĿ @@ -16800,9 +16800,9 @@ season CU Ġintroduction Ġmatplotlib -Åij +ő Ġnewspaper -âĢĶand +—and (Ċ ?>" -Ġ///Ċ +Ġ/// Ġeiner Ġweekly ĉlogger @@ -17089,7 +17089,7 @@ Bad HR ĠJordan iza -ĠÂł +Ġ ĠSher .header (other @@ -17105,7 +17105,7 @@ regon ĠMPI Ġanchor aca -ør +ør Ġade anchor quee @@ -17164,7 +17164,7 @@ projects Ġcheese EMPL aro -ĠاÙĦ +Ġال Ġconsists refresh ureau @@ -17173,9 +17173,9 @@ ureau Ġflavor DataSource Execute -ение +ение Ġshit -åĪĨ +分 ĊĊ Ġterrible bean @@ -17521,7 +17521,7 @@ fade ĠArts .Application Ġbehalf -æĪ· +户 Ġmere (`${ Ġawareness @@ -17558,21 +17558,21 @@ thetic .Top .Page ={` -Ġ;čĊ +Ġ; depth mann WD ĠSom .Right -Ġ)}Ċ +Ġ)} Ġtrait -ÃĹ +× iac Ġrv Sample .Xml opped -ĠÑĦ +Ġф lists Ġtear iversary @@ -17605,7 +17605,7 @@ BT Invoke Ġlucky rat -Ġ?Ċ +Ġ? Ġhandled (fd contents @@ -17640,7 +17640,7 @@ _comment Ġcolleagues maps âĺ -ĊĉĊ +ĉ (al _req Ġfut @@ -17684,7 +17684,7 @@ _br Ġlosses ĠAdded charg -Ġпо +Ġпо _system ĠSometimes ĠSpain @@ -17699,7 +17699,7 @@ quires usage Ġjun imiter -ï¼ģĊĊ +!ĊĊ Ġfifth toggle Ġdecline @@ -17712,7 +17712,7 @@ inge Ġpodcast Ġnaturally Pages -为 +为 ĠDespite Ġlighting Ġcrate @@ -17742,7 +17742,7 @@ _inst Criterion ĠTIM .Height -ĠâĢĻ +Ġ’ ();ĊĊĊ Products _SP @@ -17751,7 +17751,7 @@ _SP este Ġdatos dit -ав +ав IGNAL Ġlesson ">' @@ -17785,7 +17785,7 @@ UA riendly tech .gameObject -иÑĤÑĮ +ить Ġmoon ftime Ġnoch @@ -17823,12 +17823,12 @@ pus Ġpocket Ġram igrations -.čĊčĊ +.ĊĊ Ġ[( Ġadopted Ġreportedly ĠDream -Ġ}));Ċ +Ġ})); losing Ġteeth ĠBooks @@ -17837,19 +17837,19 @@ enny LEMENT Ġgel ĠPlant -!âĢĿ +!” .host ĠReply rength Ġrecognition -Ġ}}>Ċ +Ġ}}> LA Ġmirror Ġassistant (device Ġspiritual builder -§ +§ Ġoutr Ġtt ĠPER @@ -17875,11 +17875,11 @@ elter _GR ĠčĊ +Ġ/> metic Ġtransformation -åı· +号 Ġrgb istributions Ġimplicit /in destination -аÑĤÑĮ +ать Zero Ġunset .where .go Ġformation Ġdeclaration -()čĊčĊ +()ĊĊ ĠExpl -ĉĉĉĠĠ +ĠĠ /pro .JSON Ġdesk @@ -18084,7 +18084,7 @@ nb ..... ĠNull mx -Ġç +Ġç Ġpause ----------- _MO @@ -18106,7 +18106,7 @@ csv >/ ĠGOP lad -ĠÑĢ +Ġр ĠindexPath matrix =f @@ -18156,7 +18156,7 @@ FUNCTION ellite Ġdent ĠMicro -åıĸ +取 '][$ ĠIE imension @@ -18204,7 +18204,7 @@ Tests .admin ultipart (lambda -namespace +namespace ĠSport Ġ!( acles @@ -18217,9 +18217,9 @@ acles supported Ġpink Ġinvited -ños +ños _enabled -Ġ-Ċ +Ġ- FW eners ĠMY @@ -18234,7 +18234,7 @@ untu ivals Donald limited -ĉĉĉĉĉĉĊ +ĉĉĉĉĉĉ Ġanalyst (entry Ġrepresentative @@ -18290,7 +18290,7 @@ notes ĠTE ĉerror <' -Ġ»ĊĊ +Ġ» Ġfiltered ĠMach Ġhung @@ -18322,8 +18322,8 @@ Jan apa ĠNSLog _lines -ña -ĉĉĠĠĠĠĠĠĠ +ña +ĠĠĠĠĠĠĠ .Sc Rep etroit @@ -18332,7 +18332,7 @@ MIT compat owned _indices -],čĊ +],Ċ Ġdiscovery ĠDiego obi @@ -18360,8 +18360,8 @@ may ercise ĠLu Ġrg -ĠÑģÑĤ -ĉĉĊĉĉĊ +Ġст +ĉĉĉĉ (un TERNAL Ġlessons @@ -18394,8 +18394,8 @@ tb Ġsalope ByteArray Original -Ġ[{Ċ -åĽŀ +Ġ[{ +回 ĠClin oenix ĠSamsung @@ -18409,7 +18409,7 @@ fail Ġpromot Ġincl _only -를 +를 ĠAttorney -date Ġlandscape @@ -18419,7 +18419,7 @@ SY ĠArr pag ParallelGroup -':čĊ +':Ċ Ġlogs aunch unci @@ -18443,7 +18443,7 @@ Rotation Ġpleased itage .Wh -ĉĉĠĠĠĠ +ĠĠĠĠ MR ĠMORE ĠNatural @@ -18471,7 +18471,7 @@ trigger WORK declare Ġdecrease -ÅĽci +ści loom .None ĠMI @@ -18490,8 +18490,8 @@ opper -target _FUNCTION Ġoct -ениÑı -åľ¨ +ения +在 Ġwestern Ġcomputers ĠRET @@ -18501,11 +18501,11 @@ getValue _DATE .Next ĠFif -él +él icked æİ -MM -Ġ{ĊĊĊ +Ġ{ Ġcontacts Ġdigits Produ @@ -18528,21 +18528,21 @@ works Ġcontribution ĠTony Ġsquad -ай -Ġîn +ай +Ġîn there outed ĉq ĻĤ good LI -页 +页 ĠLiving izabeth Ġkt ĠDallas ]],Ċ -Ġ/>ĊĊ +Ġ/> Ġraising /router _game @@ -18554,7 +18554,7 @@ zens notification Ġ'../../../ Ġblame -ãĢĤĊĊĊĊ +。ĊĊĊĊ anco Identity follow @@ -18572,7 +18572,7 @@ xs _gshared ĠCT Force -ĊĠĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠ Ġorange Ġlp Ġanswered @@ -18589,16 +18589,16 @@ AIT _factor -one ĠHAVE -"čĊčĊ +"ĊĊ Prof -Ġär +Ġär strings Ġdirty ĠFace ĠBegin ĠBus Ġwis -åŃĹ +字 Ġspeaker Ġcarrier ĠOm @@ -18609,7 +18609,7 @@ Allow ĠComplete ĠEasy Ġbills -ĠĠĊĊ +ĠĠ Vertical Ġpron ĠDefine @@ -18623,7 +18623,7 @@ umes artist ĠCType Foundation -à¹Ī +่ ĠSetup Ġrecipes ĠUIColor @@ -18656,7 +18656,7 @@ also Ġopacity ĠUnfortunately ĠIllinois -Ġне +Ġне ĠTemple ĠTrail ĠKelly @@ -18669,9 +18669,9 @@ igits Ġib ĠMOD attery -аз +аз Ġvend -енÑĤ +ент ĠHttpClient safe _ASS @@ -18696,7 +18696,7 @@ ned _% Ġfavourite ĠBru -Ġá +Ġá secondary Ġmast Ġsoph @@ -18742,17 +18742,17 @@ iu Ġzo (ID _required -Ġsé +Ġsé ĠQueue AO Ġgem pton ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ ijk -({čĊ +({Ċ Ġcollision ĠUkraine -Ġ-*-Ċ +Ġ-*- NSInteger _BLOCK ĠTexture @@ -18770,7 +18770,7 @@ abling }/>Ċ Ġinnovation _" -Ġ);čĊčĊ +Ġ); Ġspots Ġchoosing .cs @@ -18789,7 +18789,7 @@ dney ''' !", Ġparticle -Ãĥ +à [MAX IVER ERENCE @@ -18828,7 +18828,7 @@ edException Ġwiring Ġcourts WEB -æľī +有 \. illance Ġbrows @@ -18861,7 +18861,7 @@ igne Ġindicating keeper Ġcada -ég +ég consin ĠGB Ġlb @@ -18912,7 +18912,7 @@ Leg =null Keyboard ')). -Ġ"";čĊ +Ġ""; Ġattitude .navigate -error @@ -18932,7 +18932,7 @@ Fixture ĠHashSet Nombre _month -ư +ư -start xygen ĉft @@ -18941,11 +18941,11 @@ iagnostics Ġconcepts Ġconstr .State -ин +ин Nov -α +α ĠPanel -个 +个 compare >()Ċ Ġapplying @@ -18985,7 +18985,7 @@ _reset family WW Ġsavings -ĠâĢĿ +Ġ” _enable sidebar Running @@ -19001,7 +19001,7 @@ _dataset ĠDas Ġhan Getty -ál +ál Ġny Ġpoverty Ġresulted @@ -19009,7 +19009,7 @@ Getty ĠVisit Ġobtaining /'.$ -ĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠ shall _LEFT UIImage @@ -19041,7 +19041,7 @@ _stop Ġreporters Ġversus aja -Ġα +Ġα Ġgovernor ListItem Ġsealed @@ -19068,14 +19068,14 @@ Nil Ġsecretary ĠjPanel vez -³³³³ + direction ĠEP Ġhunt JsonProperty ĠPORT ]", -ап +ап ĠForeign panic Ġtrials @@ -19106,7 +19106,7 @@ Put roadcast Icons )")Ċ -æĪIJåĬŁ +成功 gui Ġassumed Ġrx @@ -19148,7 +19148,7 @@ __) HD Modified Ġpredicted -ÅĦ +ń anie Sorry (doc @@ -19180,7 +19180,7 @@ otime aug ĠHong _norm -ãģ¨ +と Ġsecre (Build ĠContract @@ -19203,7 +19203,7 @@ Utility Ready Ġgall Ġallegedly -ĉĉĉĉĠĠĠ +ĠĠĠ ĠMetal ĠPersonal ĠborderRadius @@ -19243,7 +19243,7 @@ Summary ĠInstall ĠFab itmap -Ġ))Ċ +Ġ)) Ġintersection ighbor ĠBry @@ -19269,7 +19269,7 @@ akan flash Ġdelet boards -ĠĠĉ +ĠĠ ROP ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ Ġacqu @@ -19304,12 +19304,12 @@ market Ġaug period ĠConstant -Ġ?>">Ċ +Ġ?>"> Ġlobby pal Ġsink iah -С +С urname Ġconver Ġinvestigate @@ -19353,10 +19353,10 @@ latest Ġluxury Ġrefr ĠKitchen -ÑĦ +ф (at Final -ück +ück _zero ĠABC ĠManchester @@ -19374,7 +19374,7 @@ anga Ġig uploads Ġpacked -Ġ}];Ċ +Ġ}]; (sender ĠWire isons @@ -19389,7 +19389,7 @@ ackson _member Turn ĠSoviet -ìĹIJ +에 auge Ġincoming Ġjak @@ -19401,7 +19401,7 @@ Stage OwnProperty .setItem Ġdc -ä½ľ +作 Ġbrut Ġattempting .len @@ -19544,7 +19544,7 @@ aware Ġdependencies ĠVideos -row -Ġ**/Ċ +Ġ**/ Ġnou Ġhover æŀ @@ -19572,7 +19572,7 @@ incess Ġactors ĠQual _clean -ãĢijãĢIJ +】【 MSG Green ĠOfficer @@ -19585,7 +19585,7 @@ olygon Ġdrama Ġexceptions osed -Ġ+čĊ +Ġ+ Ġlegacy CV Ġcontributed @@ -19607,8 +19607,8 @@ bury GLOBALS urrencies Ġtons -âĢĻ, -Ġê +’, +Ġê (col ĠSymbol Ġstayed @@ -19620,7 +19620,7 @@ nr Ġgains Ġshortly .Menu -ý +ý KNOWN Ġoperators -V @@ -19643,7 +19643,7 @@ Three valuate ĠMini bu -оз +оз
    ";čĊ +>";Ċ ĠSav .Bold Ġenables @@ -20106,8 +20106,8 @@ ocomplete via upo Ġabortion -ière -ï¼ij +ière +1 _BUTTON _domain Ġbra @@ -20182,7 +20182,7 @@ atz Ġvisitor ĠFly Seq -à¸Ļ +น ĠVisual Ġlibraries atoes @@ -20193,12 +20193,12 @@ VRTX ĠDM Split Ġletting -ÐĿ +Н _errors epoch PARAM cu -ÑģÑĤв +ств olutions Editing fonts @@ -20208,7 +20208,7 @@ fonts ĠJudge Ġbrothers FILES -ço +ço wb _PI '^ @@ -20219,7 +20219,7 @@ Tim igg ĠMoore Ġcryptoc -åĩº +出 _posts otate ?' @@ -20227,7 +20227,7 @@ otate Ġkl ="$ Ġdecoration -ạ +ạ ĠDIRECT GUI )=>{Ċ @@ -20262,7 +20262,7 @@ contract /components Ġnr ĠIndones -ĠоÑĤ +Ġот ĠVolume .files (resp @@ -20285,7 +20285,7 @@ pie Ġexpense ication ĠLarge -Ġ± +Ġ± ĠBowl (models /N @@ -20293,7 +20293,7 @@ Pa .reload Ġwondering Execution -ĉĠĠĠĠĠĠ +ĠĠĠĠĠĠ ĠGraphics ĠContin _job @@ -20336,7 +20336,7 @@ _HEADER Hex Ġfarmers Ġmaintaining -ĠĠĠčĊ +ĠĠĠ syn [T rus @@ -20351,8 +20351,8 @@ slider Ġsuppliers scriber pes -Ðŀ -":čĊ +О +":Ċ \Controller ))ĊĊĊ Ġlua @@ -20373,7 +20373,7 @@ zing endant [Ċ oler Ġlibert -Ġ`Ċ +Ġ` Ġwenn lated Ġimmune @@ -22082,7 +22082,7 @@ lated logs Ġ../ ĠADC -Ġ}}">Ċ +Ġ}}"> >');Ċ =b ĠWind @@ -22105,7 +22105,7 @@ _many ĠFetch ĠMarvel Ġresist -ого +ого bidden ĠRunnable :false @@ -22125,12 +22125,12 @@ _fail ĠThrowable .router ĠRevolution -ÑĢа +ра _NON Ł¥ Ġelder Ġabroad -Ġе +Ġе ĠAdult blr glyphicon @@ -22161,12 +22161,12 @@ _bool ulative Ġcone ĠMult -Ġmö +Ġmö ĠForward ]):Ċ Ġconvinced acted -ãģĵ +こ ĠConfigure Ġceiling Der @@ -22188,7 +22188,7 @@ fac ButtonItem Ġblocking strar -ò +ò ĠExport Ġthrew otta @@ -22205,7 +22205,7 @@ Police boxes Ġdiamond ,l -Ġĉĉĉ +Ġ Ġcurious tv Ġerotische @@ -22224,7 +22224,7 @@ ikan [N (Qt (Base -æģ¯ +息 beat ĠEmpty ĉo @@ -22237,11 +22237,11 @@ Cent ĠTIME Management -sp -ême +ême Ġnotion unifu PK -è¡Į +行 ĠCURLOPT \"\ UV @@ -22285,7 +22285,7 @@ Include ROLL ĠdataType David -ร +ร lop -month Ġscar @@ -22310,7 +22310,7 @@ regular ĠKa MAN Ġastr -Ġ'')Ċ +Ġ'') Ġfed Ġparsing ĠYears @@ -22341,7 +22341,7 @@ Orientation grades ropol basic -Ġ");čĊ +Ġ"); Ġawards (range -all @@ -22365,7 +22365,7 @@ piece ĠPok celona mutex -;čĊčĊčĊ +;ĊĊĊ Ġstrikes Loaded )arg @@ -22380,9 +22380,9 @@ apple acon Ġprinter ĠGC -å®ļ +定 Ġrendered -,âĢĻ +,’ heit social .ge @@ -22416,8 +22416,8 @@ IDEO ĠMoscow ,this ĠVictoria -æĶ¹ -ĠÐŁ +改 +ĠП .stack ĠBarn paredStatement @@ -22439,7 +22439,7 @@ _txt Ġ(< averse Ġdevast -ãĢĢ + .Dec ĠGard /ui @@ -22482,7 +22482,7 @@ ERRU _filters Preferred scene -еÑģ +ес ĠAffairs Ġ"#{ ĠonSubmit @@ -22494,7 +22494,7 @@ hit Jo .getC Initialized -ÑĤи +ти cuts (Type ĠAgreement @@ -22605,7 +22605,7 @@ seed olk ĠAsset reach -'),čĊ +'),Ċ navigation LF /util @@ -22619,7 +22619,7 @@ TagName permission ifiable xFFFFFFFF -ни +ни .Buffer _irq dark @@ -22695,14 +22695,14 @@ rena Ġrd Ġdeserve Ġwheels -å¸Ĥ +市 Ġcritics Namespace ĠFra -ĠĊĊĊĊ +Ġ Ġalla Ġrequiring -æľŁ +期 utation Ġdelayed Ġadministrative @@ -22710,7 +22710,7 @@ utation .hidden Tex Ġboundaries -Ġ]);ĊĊ +Ġ]); ĠFollowing ~/ Fi @@ -22744,7 +22744,7 @@ rez Used wear Ġlegitimate -Ġ"ĊĊ +Ġ" Ġhv Std ĠHold @@ -22763,7 +22763,7 @@ EqualTo å± linear observ -Ġpiù +Ġpiù Ġcomplement WithValue (password @@ -22785,7 +22785,7 @@ PEED ĉE _tool Ġladies -оÑģ +ос ))))Ċ ;;;; .dot @@ -22831,7 +22831,7 @@ amiento Ġfiring NaN ĉtemplate -ад +ад .En Ġdefence ĠTel @@ -22855,7 +22855,7 @@ LIST Ġtrash Ġregistr Ġseller ->';čĊ +>';Ċ ĠstartTime çĻ sy @@ -22909,8 +22909,8 @@ adesh atty Ġcertified sj -Ġêtre -ÅĤo +Ġêtre +ło Ġpublishing ĠMalays .getUser @@ -22927,11 +22927,11 @@ ande _front ĠMcG TestMethod -à¸Ń +อ Ġoccasionally ĠWales Ġexercises -ĠÐĴ +ĠВ -plus Ġvalidator Ġprayer @@ -22981,7 +22981,7 @@ astro owners .mode Ġdiagnosis -Ġ_Ċ +Ġ_ ĠKnight ĉA Ġobserve @@ -23029,7 +23029,7 @@ BOX Ġfecha Ġvide ĠLeader -以 +以 $(". Ġdiameter Ġmild @@ -23062,7 +23062,7 @@ _SELECT ĠArabia _clock Ġvoy -Ġиз +Ġиз Ġstir isible -effect @@ -23079,10 +23079,10 @@ _timestamp Ġweakness Throw ĠAngel -ä¿® +修 Ġuncert -ï¼īĊ -ĠìĿ´ +)Ċ +Ġ이 Which Ġ[-]: Something @@ -23145,7 +23145,7 @@ Lat ĠComput Ġterrorism Ġsweep -Ġ[]čĊ +Ġ[] Ġpassenger Ġeastern Ġtweets @@ -23222,7 +23222,7 @@ Views Ġmedication ĠWy ĠAnna -ع +ع ĠVertex .types Organ @@ -23257,11 +23257,11 @@ kernel States ĠTeen _components -ìĪĺ +수 Received Ġlyrics rites -ĉĉĉĉĉĠ +Ġ -American [num /python @@ -23269,7 +23269,7 @@ rites Ġapple ĠJonathan Ġmomentum -ั +ั Ĥ¹ Ġmich andra @@ -23303,7 +23303,7 @@ quet /get /master WIN -åħĥ +元 West argc Ġproducers @@ -23339,14 +23339,14 @@ INESS _begin -heading Course -ĠčĊčĊ +Ġ ombie graded ĠGPS -Ġże +Ġże Fit caption -ön +ön /image lia (mod @@ -23362,7 +23362,7 @@ teacher ĠSent support jectory -ĠÙħ +Ġم Registration ĠGray ,false @@ -23382,7 +23382,7 @@ Ip ĠMario ĠQuestions PACE -æĸ¹ +方 eor }}" Ġoven @@ -23417,10 +23417,10 @@ _TIMEOUT .createSequentialGroup _person Ġbeam -ĉĠĠĠĠĠĠĠĠ +ĠĠĠĠĠĠĠĠ ĠNotFound .'Ċ -ÃŃs +ís .TextView PDF Ġkar @@ -23453,11 +23453,11 @@ agonal ungeon Adv carousel -ÃŁe +ße _DESC Ġhammer -áºŃ -ĠĠĠĠĠĠĠĠĊĊ +ậ +ĠĠĠĠĠĠĠĠ -core -service Ġcorners @@ -23472,11 +23472,11 @@ osc asures _internal Ġprints -Ġ])Ċ +Ġ]) ĠCleveland repo Disc -Ġ">Ċ +Ġ"> ���� Ġnearest _tb @@ -23529,14 +23529,14 @@ Pod Ġparticipating clusions (ByVal -ì +ì ĠHOW _setopt Ġaccompanying aton Ġ/\ ĠAuthentication -ién +ién ĠBarack /*. Ġeager @@ -23569,9 +23569,9 @@ _axis Ġexamination '.Ċ mons -++){čĊ +++){Ċ ĠForms -íķľ +한 CppMethod _trace Ġengineer @@ -23598,7 +23598,7 @@ JSONObject Desktop _SYMBOL (resource -ĠĊ +Ġ Ġnewest uli Ġdesert @@ -23614,7 +23614,7 @@ etherlands Ġcement .auto _AN -âĢĻ. +’. selection ĠBond Den @@ -23635,9 +23635,9 @@ icular ĉkey >\< ENSION -Ġ[čĊ +Ġ[ Ġprecisely -Ġété +Ġété ĠPast ĠCambridge -full @@ -23654,7 +23654,7 @@ RM ossible Ġactress Ġdolor -å½ķ +录 Need .toggle ĠRace @@ -23694,14 +23694,14 @@ mitt lando Ġpig inals -ência +ência Surface ĠUUID Ġbeneficial Ġsequences ĉmemset Ġmagical -« +« Ġworn ASC popup @@ -23752,7 +23752,7 @@ Period Ġphotography ĠFramework Ġspecialist -Ġ?ĊĊ +Ġ? _selected .Player Ġallocation @@ -23761,7 +23761,7 @@ _selected vable -offset .AppCompatActivity -ам +ам .AddWithValue Ġicons Ġshutdown @@ -23834,7 +23834,7 @@ VALUE ĠNat _Ad @@ -24916,7 +24916,7 @@ sigma Ġdeparture Ġcelebration ĠSay -ï¼Ĵ +2 ĠHills .hasOwnProperty Ġtypings @@ -24966,7 +24966,7 @@ slice Ġinvesting irable Ġxmlns -ï¼Ľ +; arta Ġtheories _city @@ -24983,7 +24983,7 @@ ismatch Ġecosystem Ġtempt Ġ\\ -Ġ//{Ċ +Ġ//{ ĠChristopher ĠKentucky ĠHttpServletResponse @@ -25050,7 +25050,7 @@ extend ']). FFECT ĠPinterest -úmero +úmero ricane ĉsession Ġcrystal @@ -25065,14 +25065,14 @@ ften Ġlease scr Ġrefuse -ãĢĭ +》 ftp information Ġevaluated Ġinjection Ġjack Ġworkshop -注 +注 PTH ĠTs offer @@ -25160,7 +25160,7 @@ _properties .Unit _CLK Ġgt -Ġ();ĊĊ +Ġ(); Ġhandy ĠThompson Ġunnecessary @@ -25175,7 +25175,7 @@ ieu Ġthy Ġlt _mail -ä¿®æĶ¹ +修改 ailand ĠPhilip Ġbitter @@ -25198,11 +25198,11 @@ uka ĠConsumer Ġaggreg Circle -à¸ģ +ก _blocks Ġlegally Ġ"| -ãĥĥ +ッ .board .Ab Functions @@ -25230,7 +25230,7 @@ caled .Controllers ĠWolf Ġcrushers -á»ĩ +ệ .Auth .addAttribute his @@ -25247,7 +25247,7 @@ MISSION jamin ĠSB Ġdetermination -Ġ'');Ċ +Ġ''); ĠBeng Ġvos Ġinhab @@ -25287,7 +25287,7 @@ htags Mc Shell rin -{čĊčĊ +{ĊĊ .pow ĉclient Ġconspiracy @@ -25313,13 +25313,13 @@ equ human .messages ĉtyp -Ġ(čĊ +Ġ( ĠSSL LEN ĠRomney (grid ĉmin -Ġ>ĊĊ +Ġ> Ġfruits Ġvoter Inline @@ -25421,19 +25421,19 @@ _bus unto Ġfoss ĠLinks -äng +äng /forms prises Ġachievement CALL -елÑĮ +ель ĠVerify _SOURCE aptcha IDD _reference Gold -ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ Receiver Ġaj _direction @@ -25445,7 +25445,7 @@ _direction techn ĠJerusalem longitude -');čĊčĊ +');ĊĊ Ġwinners Tasks ĠDMA @@ -25458,7 +25458,7 @@ parents ---- -());čĊčĊ +());ĊĊ Extract izens Ġsolver @@ -28972,11 +28972,11 @@ REQ ĠConservative @Column Ġshifted -Ġ:čĊ +Ġ: Ġfich Ġdla Ġshoe -"),čĊ +"),Ċ ularity _RESP Weather @@ -28992,7 +28992,7 @@ owie .prev .IsValid .Fat -ĠsÄĥ +Ġsă keywords without Ġsovere @@ -29041,7 +29041,7 @@ SUB Ġdonde sales llvm -Ġ}],Ċ +Ġ}], OTTOM ĠPurpose Lab @@ -29054,7 +29054,7 @@ asil ĠModified ationally ĠMeeting -误 +误 #region Ġrouting .focus @@ -29075,14 +29075,14 @@ outedEventArgs Ġrename CFG ("// -æİ¥ +接 /pages -Ġprés +Ġprés ĠSpell .Allow ĠINTERRU Ġ(# -âĢĻĊĊ +’ĊĊ _Generic .imshow _tim @@ -29108,7 +29108,7 @@ ScrollView peek _RATE Ġdorm -/čĊ +/Ċ IVITY .Controller (part @@ -29184,7 +29184,7 @@ esta ĠMorris ."""ĊĊ Wrong -ĠÅĽ +Ġś Ray .ec ĠBind @@ -29195,7 +29195,7 @@ isValid _LIMIT Ġdynamics Ġdistinction -ãģĨ +う < Ġunto @@ -29378,7 +29378,7 @@ atables Ġrefriger Ġcoordin avorites -ÑĪи +ши Ġcompassion ĠPOSSIBILITY -secondary @@ -29410,10 +29410,10 @@ dw _SRC Ġnormalized ĠJag -ãĤĴ +を zeit rpc -ÃŃc +íc .inline Ġtravers _numeric @@ -29435,7 +29435,7 @@ trait ĠTodd Ġskeleton Ġoptimize -第 +第 ĠUpon ĠStObject Ġaplic @@ -29454,12 +29454,12 @@ INAL _mesh skill ĠViol -² +² ĠEOF ĠKi ymmetric Ġmaxlength -Å£ +ţ friends ĠEvans Ġlemon @@ -29488,8 +29488,8 @@ preg (serial ifica uming -åľ° -ãģĤ +地 +あ -op UCH ĠHend @@ -29522,12 +29522,12 @@ _provider Ġdiplom Ġmessaging _mut -å¦Ĥ +如 Ġkw ONS arians RPC -)]čĊ +)]Ċ -ray ĠSor mall @@ -29554,7 +29554,7 @@ opia -Ch ĠdataSource /"Ċ -екÑĤ +ект ĠRequestMethod ĠReplace -do @@ -29576,7 +29576,7 @@ development Ġdeserves Ġcurriculum _CONTEXT -ÅĤy +ły HITE ĉID /uploads @@ -29595,7 +29595,7 @@ Phil tok Ġjar Los -âĢĶâĢĶâĢĶâĢĶâĢĶâĢĶâĢĶâĢĶ +———————— .queue -speed Mal @@ -29605,10 +29605,10 @@ umblr ĠDance (filePath Ġattributed -à¥į +् ĠBund coins -Ġsão +Ġsão Ġpir personal Ġprelim @@ -29633,7 +29633,7 @@ hir Ġchapters (angle ĠVlad -设 +设 '.ĊĊ ResponseBody ĠAbd @@ -29663,7 +29663,7 @@ shake chip Ġuv Ġalliance -пиÑģ +пис ĠGOODS zione ĠVI @@ -29711,7 +29711,7 @@ ampling ("\\ Ġsag _proxy -ãģķ +さ pdo .getElementsByTagName Ġdemonstration @@ -29730,7 +29730,7 @@ ahun Ġdesigning ĠGDP Ġlifted -缮 +目 ĠJoint ĠInclude ĠGiants @@ -29746,7 +29746,7 @@ _CPU Ġboom yecto Ġmanufacture -ĠâĢĭ +Ġ​ Ġbbox Ġearthqu ollectors @@ -29760,17 +29760,17 @@ alking ARGE _pixel Ġsuspects -ι +ι usp ĠBMW ieces (person -å¼Ģ +开 é» ĠPodcast Ġbou (Item -û +û (Input HttpGet Ġburg @@ -29787,7 +29787,7 @@ ugu Ġpled ,"% hape -Ġзап +Ġзап ĠMaine .real Ġdalam @@ -29800,7 +29800,7 @@ disp Ġfg tees ĠRecomm -äl +äl Ġchemistry Blocks OID @@ -29823,8 +29823,8 @@ iefs _FULL ernetes ĠPred -ØŃ -äºĭ +ح +事 ubernetes ĠLaura Ġlabeled @@ -29919,7 +29919,7 @@ $row Ġpsychology vh Ġseverity -âĢIJ +‐ Ġstrips AH vertising @@ -29934,7 +29934,7 @@ sson ĠSusan .tile eded -ĠĠĠĠĉĉĉ +ĠĠĠĠ uelle ĠMitchell based @@ -29989,7 +29989,7 @@ Later ĠUE Ġclue comed -åIJįç§° +名称 -main Ġpts Ġcounted @@ -29999,11 +29999,11 @@ icts Ġping ANCEL Ġpec -Ñħод +ход antom ĠBlueprint ĠEventEmitter -Ġlä +Ġlä æ² Ġstraw (comp @@ -30014,9 +30014,9 @@ esModule -base Ġretreat _simple -ĉĉĉĉĉĉĠ +Ġ fee -')čĊčĊ +')ĊĊ ControlItem Ġsubscribers please @@ -30093,7 +30093,7 @@ _FLAGS Sun FROM ĠDir -ãĥ»ãĥ»ãĥ» +・・・ _coord ĠOptim Monitor @@ -30101,7 +30101,7 @@ Monitor XXX Ġtodas feld -ÑĢи +ри imir Ġpolitically Ġmolecular @@ -30141,7 +30141,7 @@ AILABLE .IContainer poll ĠCorps -ε +ε aru ĠKay .range @@ -30164,7 +30164,7 @@ activated .mkdir =user Ġrede -fü +fü _SYSTEM pv Ġcongr @@ -30172,18 +30172,18 @@ pv Ġpractition University Ġtabindex -Ðĺ +И Sets Ġcounties guest fan Ġworden .di -наÑĩ -¿ +нач +¿ igDecimal Ġshore -Ġgö +Ġgö Ġrepairs Ġhelpers Ġcentered @@ -30201,7 +30201,7 @@ cales -season Ġfunctioning _LOCATION -üss +üss bery Para ominator @@ -30209,14 +30209,14 @@ ominator Ġethical hashtags emplo -Ġnúmero +Ġnúmero (activity .Stop .strftime ILD Ġtoe ĉNode -")čĊčĊ +")ĊĊ ĠPuerto Ġexecuting ĠGUID @@ -30231,7 +30231,7 @@ ainted _DOM Ġwil Ġslope -ĠmÃ¥ +Ġmå ĠIraqi Ġorganize ĉjQuery @@ -30245,7 +30245,7 @@ ponsor Ġreluct named ĠOliver -Ġ//}Ċ +Ġ//} -looking Ġfog ĠHO @@ -30276,7 +30276,7 @@ armed _ES Ġfossil ĠAnc -âĢľThis +“This lodash Python Ġhistogram @@ -30321,15 +30321,15 @@ Editable redential ĠPerry kie -Ġ----------Ċ +Ġ---------- .stroke (Intent Ġunity umlah Further Ġprze -Ġsø -ãĤĬ +Ġsø +り ĠPROCUREMENT ĠHousing Ġattorneys @@ -30341,7 +30341,7 @@ draul Instant .JTextField Ġtrades -ла +ла Ġ{! Ġlately IMG @@ -30361,7 +30361,7 @@ _comments ENV ĠConnecticut -FIRST -ĉĉĉĠĠĠĠĠ +ĠĠĠĠĠ achi .Msg rection @@ -30370,7 +30370,7 @@ rection Ġef ĠAdding Ġbreach -Ġï¼ļ +Ġ: rama Ġconducting Ġ(; @@ -30388,7 +30388,7 @@ proj Ġseventh EMPLARY (mock -'],čĊ +'],Ċ _SPEED >false Ġspa @@ -30415,7 +30415,7 @@ ystick fulness apos Da -ĉĉĉĉĉĠĠĠ +ĠĠĠ Ġenrich unordered hole @@ -30432,13 +30432,13 @@ uba arel Ġacted -details -à¸ĩ +ง ĠTheory ĠPun ĠAnonymous ..."Ċ -ères -åı¯ +ères +可 ĠVision _sem asha @@ -30510,7 +30510,7 @@ mag anggal ',[ ropolitan -ĠÃľ +ĠÜ ĠUC .desc -LAST @@ -30532,17 +30532,17 @@ commend Soft Ġpartir wealth -è¦ģ +要 (dataset ĠClimate -show Ġreliability _chunk -代 +代 _stock ĠEXEMPLARY -ï¸ı -ĠvÃŃ +️ +Ġví Ġsmiled Ġdrill .Function @@ -30632,7 +30632,7 @@ horizontal (option Ġweiter ĉsql -Ġ=>{Ċ +Ġ=>{ Ġgarlic Ġrepr Ġreplies @@ -30643,7 +30643,7 @@ horizontal .reject Ġhints Ġpolling -ĉĠĊ +Ġ _rating Ġcath avier @@ -30655,7 +30655,7 @@ avier training ESTAMP ognition -Äģ +ā SENT ventions Ġconsultant @@ -30667,7 +30667,7 @@ Dear _BAD itations Ġmetaph -'é +'é andise -font .chart @@ -30705,9 +30705,9 @@ eri Ġ**** Ġoverlook .Non -Ġrés +Ġrés Ġegy -å°ı +小 Ġattacker ĉĉĉĉĉĉĉĉĉĉĉĉĉĉĉ .sync @@ -30734,12 +30734,12 @@ REAT Phase (ii ĠSUM ->ččĊ +>Ċ Ġsud ĉbackground Ġscholars -muted -ará +ará Ġ===== Ġ____ Creat @@ -30765,7 +30765,7 @@ _REQ ĉpanic psi oka -éĢī +选 >[ Ġunderstands ĠJunior @@ -30777,13 +30777,13 @@ serv ĠCREATE .au Ġsells -ĠĠĊĠĠĊ +ĠĠĠĠ Europe zw preh ĠNSA Ġxy -ิ +ิ ĠBeyond Instead NonQuery @@ -30842,7 +30842,7 @@ repr ĠDry .ro ĠObserv -æłĩ +标 Former ĠBalance ĉjson @@ -30856,7 +30856,7 @@ ISS Fun Ġschemes Ġinterven -æĺİ +明 Ġadverse quotelev Ġsacrific @@ -30866,14 +30866,14 @@ AGIC Ġoccurring ĠCommunication umar -ç¼ĸ +编 ĠTreatment .person ĠLC Ġech ((" ĠDisease -äd +äd ĠAZ .Account Ġcontinuously @@ -30916,7 +30916,7 @@ iolet Ġnerve çĦ ")] -æ±Ĥ +求 ĠSugar _SIM jpeg @@ -30928,14 +30928,14 @@ bove Ġworkforce ĠExecution errer -ĉĠĠĠĠĉ +ĠĠĠĠ Ġprescribed .TextAlign OPEN ĠPB imity ĠExternal -°C +°C ĠApplicationController Ġbarr implicit @@ -30960,8 +30960,8 @@ chemical :UIControlState toInt ] addEventListener IALOG -åIJ¦ +否 .Compare Album ĠKu @@ -32384,7 +32384,7 @@ argest Ġconfigurations Ġaccidentally _photo -Ġ'';čĊ +Ġ''; Ġverse Bob Ġfarming @@ -32401,7 +32401,7 @@ creator ĠEaster .-- UIButton -ãĤī +ら ometers Ġshine Ġhogy @@ -32439,11 +32439,11 @@ oriented Ġlightning fid ĠPle -ãģ¾ãģĻ +ます tro .True Observable -×Ļ +י umbing Ġprospective -filter @@ -32452,7 +32452,7 @@ umbing .Bind Ġpalm clearfix -ös +ös ĠGonz Ġweaken Drive @@ -32461,7 +32461,7 @@ lld obox anean Got -ä¿Ŀ +保 Regex æĥ Ġsalad @@ -32471,7 +32471,7 @@ inheritDoc ĠRV quier Ġclazz -Ä±ÅŁ +ış osterone Ġairline .listdir @@ -32509,7 +32509,7 @@ rah .tags _tests stones -âĢĿ) +”) [g rtype Ġvu @@ -32543,17 +32543,17 @@ _Test Soup ~~~~~~~~~~~~~~~~ (files -ĉĉĉĉĉčĊ +ĉĉĉĉĉ .spark Ġvalued -Ġ%Ċ +Ġ% .controls ĠXCTAssertEqual Ġfame ĠRic DOT ĠAlberta -使 +使 osal .WebControls Ġ------------ @@ -32569,8 +32569,8 @@ _operation uffs *m Ġavant -次 -âĢľYou +次 +“You .permission ...) ĠLic @@ -32586,7 +32586,7 @@ clk (layer pit Ġguided -ĠâĸĪ +Ġ█ Ġsuperb Ġsupplements _cent @@ -32596,7 +32596,7 @@ INARY falls ")); Wall -).čĊ +).Ċ ĠDanny irmingham IALIZ @@ -32608,7 +32608,7 @@ macro amac .box ----Ċ -ãĥ« +ル ĠSuit urst bru @@ -32621,7 +32621,7 @@ uder ?\ fu [B -Ġ:)ĊĊ +Ġ:) (inter brains Ġattitudes @@ -32640,13 +32640,13 @@ ICLE ĠWhatever Ġoutlined sprite -ев +ев _AB _DEPTH Ġcrushed aaa (ev -æľº +机 Anti ICO isEqualTo @@ -32662,8 +32662,8 @@ lista .UndefOr Ġautomation Nor -对 -åıĤæķ° +对 +参数 Ġreflex ĠLaure .showMessageDialog @@ -32675,7 +32675,7 @@ ARED agle Energy Ġquantities -âĢĻé +’é ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ Ġcitizenship mouth @@ -32691,7 +32691,7 @@ Ten ĠYemen (call Ġslavery -Ñģп +сп ĠLam _BITS omega @@ -32715,7 +32715,7 @@ Hel *d ĠLegisl _Pr -ĉĉĉĠĠĠĠĠĠĠ +ĠĠĠĠĠĠĠ Ġsympath Ġchess Ġmam @@ -32741,7 +32741,7 @@ actly Ġcontractors ']} draulic -ódigo +ódigo ĠTT ĠWide ĠARG @@ -32755,7 +32755,7 @@ School manent Overlay ('" -éĩı +量 ĠTimestamp Ġmailing ĠCake @@ -32788,22 +32788,22 @@ Interceptor ĠDH iales Ġvillages -Ø´ +ش ĠENV Sys .XR Ġpoem -ÃĤ + cade plots Ġ{( .git /svg ncmp -ĠÄį +Ġč aines -åĩ½æķ° -Ġ()ĊĊ +函数 +Ġ() opsis ĠRelationship _aut @@ -32812,7 +32812,7 @@ _aut *sizeof official _payload -ĉĉĉĉĉĠĠ +ĠĠ .manager ĠAround ĉsend @@ -32829,7 +32829,7 @@ encrypt Ġasm ICH ĠCGRectMake -ìĦ± +성 ulong Ġitr ĠGST @@ -32865,23 +32865,23 @@ ellt Ġbrace Ġtrails published -å¯Ĩçłģ +密码 }')Ċ Ġacids Ġ!!! _direct >());Ċ -ajÄħ +ają _OCC Ġplanets -æŁ¥ +查 ĠDublin Ġserie .printf deep `) Ġ\$ -Ġμ +Ġμ _VIDEO endors ĠCrypto @@ -32896,7 +32896,7 @@ _training ĠKath ĠIndexPath Ġachievements -Ġserá +Ġserá interopRequire Ġdisse .If @@ -32943,10 +32943,10 @@ formatter Ha vangst Ġemerge -ãĢĤâĢĿ +。” ĠCabinet -square -éĥ¨ +部 Ġrage ĠAJ ĠVT @@ -32980,7 +32980,7 @@ _numbers /X Ġfonts trip -иÑĩ +ич Ġtubes clamation Ġë§ @@ -33046,7 +33046,7 @@ azar Ġpulls ngx Ġinspiring -ÑĥÑİ +ую -direction Ġexplosive ĠcreatedAt @@ -33069,7 +33069,7 @@ olate Ġcatching -password ouched -æĢ§ +性 eous Ġxrange Quality @@ -33088,7 +33088,7 @@ esi ']/ .savefig (trans -ج +ج nee Correct ...")Ċ @@ -33106,7 +33106,7 @@ _sock Ġdestinations emption ĠFAIL -åĴĮ +和 Ġrp fact ĉlen @@ -33124,7 +33124,7 @@ $i Ġthou ogene Ġscholarship -æĽ´ +更 Ġswo aginator eni @@ -33169,7 +33169,7 @@ _POL Ġteenage .binding postal -Ġiçin +Ġiçin ĠDataType éĸ yclerview @@ -33177,8 +33177,8 @@ yclerview _identifier ".$ Ġrelies @@ -33702,7 +33702,7 @@ ottage Ġker Ġappliances rowave -ìĿĢ +은 ematics ĠOrg oping @@ -33760,7 +33760,7 @@ asser -small Ġrealiz (Entity -ús +ús ĠActually ĠElite Ġhelm @@ -33841,7 +33841,7 @@ _DOUBLE ĠSoph Ġelectoral _disable -ĠÑģо +Ġсо ĠLightning Ġmentions ocy @@ -33859,7 +33859,7 @@ Camp .Author Ġdirective -hook -íĦ° +터 }ĊĊĊĊĊ @pytest _rand @@ -33870,7 +33870,7 @@ lasses ĠClasses .have %), -é¢ĺ +题 Ġdisturbing substring ĠKoh @@ -33881,7 +33881,7 @@ purchase ierarchy Ġfps .checkBox -íķ´ +해 _material ducation Ġfw @@ -33914,7 +33914,7 @@ approved Ġdenial /share LinkedList -)čĊčĊčĊ +)ĊĊĊ uddy Ġfines Ġry @@ -33928,11 +33928,11 @@ cuda ĠCock ,:) (folder -Ġméd +Ġméd drag Ġtalents -ĠĠĠĊĊ -еÑģÑĤв +ĠĠĠ +еств mob .yml Ġaster @@ -33954,7 +33954,7 @@ iversal ĠNevertheless -led Ġ(%) -ç¡® +确 Ġtimezone Ġstranger (render @@ -33978,12 +33978,12 @@ adays recv Working Jump -ĠÃ¥r +Ġår ĠAutomatic _Base -æł¼ +格 aurants -¯ +¯ æ¸ (CType IFI @@ -33993,7 +33993,7 @@ IFI Ġfir Ġrestoration ereco -Т +Т _'+ Ġebook Ġdebris @@ -34009,7 +34009,7 @@ lander ĠRPC _EXIT (queue -иÑģÑĤ +ист Dll Ġskull _pub @@ -34035,7 +34035,7 @@ cerpt ĠVote Ġconcurrent ĠMessageBoxIcon -ĠÃĸ +ĠÖ ĠDubai ĠRetail :number @@ -34045,7 +34045,7 @@ _origin _WORK Frames Ġnotably -.âĢľ +.“ Ġtropical Ġniche amina @@ -34081,7 +34081,7 @@ borough heimer (move (Text -});čĊčĊ +});ĊĊ welcome ĠComponents Ġgovernance @@ -34090,7 +34090,7 @@ closed Ġlaundry ĠTerminal izards -.âĢĶ +.— .remote .radius ĠQuebec @@ -34133,13 +34133,13 @@ dar /xhtml vinc _mock -ĊĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ ĠPill .LayoutStyle ĠCommander ]< signature -Ġ{}čĊ +Ġ{} Ġhatred Ġëĭ olesterol @@ -34147,7 +34147,7 @@ olesterol ancellor crop TIM -ĉĉĊĊ +ĉĉ ysqli uitive ĉunset @@ -34182,10 +34182,10 @@ Styled ĠMarco Gallery dale -.âĢĿĊĊĊĊ -érie +.”ĊĊĊĊ +érie /service -äºĨ +了 Ġambient _SETTINGS .Adapter @@ -34195,14 +34195,14 @@ Notice Ġcleans ĠFem chair -Ñĥн +ун /my _bad ĠEconomics ISA _CNT (Menu -äºİ +于 ĠRidge Ġlengthy Dot @@ -34228,7 +34228,7 @@ ongsTo Ġvaries Ġpositioned 'il -éĩij +金 Ġhike (done playlist @@ -34269,7 +34269,7 @@ Letter EXPECT ĉRE .longitude -ünd +ünd Ġstatue .addWidget ĠCaribbean @@ -34282,10 +34282,10 @@ ursion Ġmandate Ġpromotional Ġvk -iaÅĤ +iał Ġpyl ĠCreation -озд +озд Ġsimpler .what ĠRecent @@ -34307,7 +34307,7 @@ _fee Ġabsorb ĠVincent Ġsounded -ÃŃst +íst Ġpharmaceutical htag ĠKindle @@ -34315,17 +34315,17 @@ italize ĠEmperor oustic Ġspecialists -åħ¬ +公 BorderStyle /\ RELATED (',', (expr Ġht -åįĪ +午 _Create Ġspecially -Ġ[];čĊ +Ġ[]; Ġheel Ġsept _arch @@ -34339,16 +34339,16 @@ _arch .hand ĠMAIN ĠDenmark -Ġ],čĊ +Ġ], Ġcryst Ġnack Coords _inner Ġmidst Ġawake -ĠÐŀ +ĠО -break -ÃŃvel +ível _PASS ĠParams Ġdetr @@ -34362,13 +34362,13 @@ CHED _SESSION Ġnouvel oauth -ĠданнÑĭ +Ġданны rink .HeaderText aturated Ġerst Ġåħ -à¥ĩ +े _visible eyer Ġliable @@ -34393,12 +34393,12 @@ _ins Ġrumors Ġrr ĠQuarter -ê³ł +고 Ġfeeds -óg +óg Ġenvelope Ġlear -Ġkø +Ġkø developer Similar :")Ċ @@ -34481,7 +34481,7 @@ igrate (score Keyword "os -ĠĠĠĠĉĊ +ĠĠĠĠ analysis Ġreplay .pass @@ -34490,14 +34490,14 @@ tls Ġsanct .light _mobile -ÑģÑĤÑĮ +сть ĉtotal uity Ġpaused NAS Ġencore loe -Ġ-*-ĊĊ +Ġ-*- .high ampler ĠSecure @@ -34507,13 +34507,13 @@ illary ĠStein ĠDawn Ġmaximize -ย +ย Ġ/^ Ġcontinually Ġshadows -ĉĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ +ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ ĠIActionResult -Ġinformación +Ġinformación CHECK .SelectedItem bundle @@ -34555,7 +34555,7 @@ _NAMESPACE soever _CENTER Interest -ôt +ôt temperature Viewport getResource @@ -34565,7 +34565,7 @@ getResource Ġcylinder Ġtroubles nod -Ñĭв +ыв games _gl Plane @@ -34580,7 +34580,7 @@ many ="",Ċ _Pin uese @@ -35475,7 +35475,7 @@ _tail Ġblown ,// Ġbackgrounds -âĢĻune +’une -sdk ĠsetInterval Ġincentives @@ -35492,7 +35492,7 @@ Positive Ġspinner Ġinvented ĠGates -оÑĤоÑĢ +отор Ġcomparisons è· .primary @@ -35507,7 +35507,7 @@ halten ( >B iteration -ãĥª +リ Ġshirts ounty ->$ @@ -36743,7 +36743,7 @@ houses ritional Ġproximity Ġdiesem -áºŃp +ập Ġdrought .audio ĠLeo @@ -36779,7 +36779,7 @@ ammen Ġ"); amentos etched -Ġ/>}Ċ +Ġ/>} .Users Ġinterrupted Contacts @@ -36804,7 +36804,7 @@ ioxide Ġ''; unset addWidget -лÑİ +лю elles alker Arc @@ -36816,26 +36816,26 @@ _where Ġ\/ ĠTib _AX -]čĊčĊ +]ĊĊ ĠBir Ġbend ĠMAKE ĠMET Ġfutures Ġweighted -"""čĊ +"""Ċ Ġauthorize (program },{" Ġcoefficients -ês +ês PerPage ĠBathroom ĠPublishing GPL Ġsubmissions ĠNUMBER -jÄħ +ją Ġadditionally empre ĠShel @@ -36843,7 +36843,7 @@ otyp Solution Ġthunder _ec -ĠĊĠĠĠĠĊ +ĠĠĠĠĠ ĠFellow Ġkay ĠnewState @@ -36883,7 +36883,7 @@ antd *);Ċ ,u (gen -ç»ĵ +结 reator ĠCord oupper @@ -36912,13 +36912,13 @@ _regs .". Ġfeminist Codec -Ġ**Ċ +Ġ** (labels _MARK FAILED Ġadministered WN -ĠĠĠĠĠĠĠĠĉĉ +ĠĠĠĠĠĠĠĠ Ġnoun wig Ġgotta @@ -36945,7 +36945,7 @@ nv (make Ġbenefici -black -iÃŁ +iß Ġundoubtedly Ġmex ĠAncient @@ -36954,7 +36954,7 @@ iÃŁ Pick Ġreplica $obj -ähr +ähr Ġarrows fty ĠLibya @@ -36977,26 +36977,26 @@ Third _present ĠPierre Ġëª -Ġ[...]ĊĊ +Ġ[...] Prob ĠTraffic icao doctor -Ġ),ĊĊ +Ġ), Tabs alu -ï¼ļâĢľ +:“ Ġinherent _No ritis ĠProof .basename -ä¼ļ +会 Ġchim ĠProtected crit Ġprone -Ġкон +Ġкон ĠHeroes Ġanxious Ġanos @@ -37019,7 +37019,7 @@ acceptable frac Ġboasts Five -± +± ĠTemperature >): Ġcharter @@ -37027,11 +37027,11 @@ REATED Ġsubjected Ġopc healthy -使ç͍ +使用 ĠScientific Ġfrau riages -à¸Ķ +ด .inventory ationale Mad @@ -37042,22 +37042,22 @@ minutes Ġsuspicion sqlite ĉread -ãģ¦ +て Ġworries .putString ĠShanghai (uid rer -ĠvÃŃde +Ġvíde "): Ġmethodology -ĠкоÑĤоÑĢ +Ġкотор ccc avad Ġinduction ĉThread ,string -ại +ại nehmen uition Ġ*__ @@ -37083,9 +37083,9 @@ iw ĠDegree ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ Ġ$('< -ários +ários toUpperCase -ìłľ +제 ĠEUR Ġoversight Ġtablesp @@ -37112,7 +37112,7 @@ istically Ġtester Ġadministrators Ġtagged -Ðĵ +Г Ġshortcut ĠResolution Ġsupervision @@ -37125,12 +37125,12 @@ isten (diff ANTS Ġrider -ĠsÄħ +Ġsą .Series _orders ORIZONTAL Ġretention -ãĢĤčĊčĊ +Ġ★ +">ĊĊ Ġdiagonal ĠCancellationToken _Internal @@ -37263,7 +37263,7 @@ dv -Compatible Originally ,function -ãĢĤčĊ +。Ċ ĠRepresentative asily ircuit @@ -37292,7 +37292,7 @@ _lists Ġhottest .jdbc .Customer -Ġâī¤ +Ġ≤ Ġwaar _scene +'/ @@ -37302,7 +37302,7 @@ _scene Ġ`/ Cases ĠYoutube -ım +ım Ġbalcon ,G MetaData @@ -37336,7 +37336,7 @@ _running -he (named ĠSach -оÑĩ +оч campaign .Abstract (wrapper @@ -37366,7 +37366,7 @@ _errno ('')Ċ ",@" Ġwit -rá +rá ologie ĠStyles ĠBrowserModule @@ -37409,7 +37409,7 @@ _rgb Ġprizes Ġeditable ĉform -ını +ını .decor Demo lices @@ -37419,7 +37419,7 @@ ratulations _chars ĠJahr partial -ÑĥÑĤ +ут ĠReceive ĠLands APTER @@ -37440,7 +37440,7 @@ Ess OND Ĭ¶ (packet -âĢĶbut +—but Invocation ĠNuclear ?;Ċ @@ -37472,7 +37472,7 @@ uncan ĠTheater éĴ ategorie -段 +段 Ġappetite square ĠAlexand @@ -37503,7 +37503,7 @@ ucus ĠClaim ĠRams ĠmodelBuilder -Ġné +Ġné userID =json .ResponseWriter @@ -37522,7 +37522,7 @@ IFA Ġprivat pell emoji -ĠÙĪ +Ġو Genre Ġconcentrated jang @@ -37539,7 +37539,7 @@ _CTL herent rex interactive -ãģ§ãģĻ +です ĠKas Ġdesperately (ar @@ -37555,7 +37555,7 @@ receiver Matcher dependent Ġexcellence -аж +аж LOS Aspect Ġadalah @@ -37570,7 +37570,7 @@ Cols boarding .Children ANGLE -ï +ï ĠYoga Ġhated Adam @@ -37580,9 +37580,9 @@ IMAL _DISPLAY Ġevolve Ġfridge -Ġrég +Ġrég Ġemotionally -âĢľIf +“If awei eresa '," @@ -37619,8 +37619,8 @@ _encoding .configureTestingModule Polygon _DBG -"],čĊ -аб +"],Ċ +аб Ġsimilarity Ġprzez ĠFirm @@ -37639,9 +37639,9 @@ IDGET ĠCommonModule Ġ"'" (Auth -ãĢĤï¼Į +。, ĠStatefulWidget -计 +计 /open inally .Round @@ -37658,7 +37658,7 @@ spell ĠHEL airro bled -ĠбÑĭ +Ġбы Ġsensible ĠLua |(Ċ @@ -37676,7 +37676,7 @@ gue -degree _sound Clone -á»Ļ +ộ aksi >${ _confirmation @@ -37690,7 +37690,7 @@ _metrics _proto Ġpear baseUrl -ĉĉĉĉĉĉĉĉĊ +ĉĉĉĉĉĉĉĉ Ġcoordination :N .animate @@ -37706,7 +37706,7 @@ ificacion ĠLiu >equals ĠAce -ÑĢам +рам ĠSuperman ĠGarcia Ġarrests @@ -37714,13 +37714,13 @@ agar Ġ{}) Ġmacros roupe -être +être Ġtwisted struments _(" _vertices ĠTransition -ик +ик [max mind ĠaccessToken @@ -37870,7 +37870,7 @@ antis Ġempir Ġpathway Ġoak -мен +мен -induced Ġimpair ĠCalgary @@ -37913,7 +37913,7 @@ _Collections .DropDown è° Ġhh -ĠlÃł +Ġlà .pb Ġmemorial ĠATTR @@ -37941,11 +37941,11 @@ atorial Ġparen .lineTo Ġkidney -Ġça +Ġça Ġcui -ï¼Į请 +,请 XC -Ġmoż +Ġmoż Ġnominated lung ImGui @@ -37982,17 +37982,17 @@ ubuntu Incre burse ĠAmateur -æºIJ +源 Blob Ġcholesterol DES minimum Ġrefusing unned -Ðľ +М ĠRD .Servlet -Ġ*/;Ċ +Ġ*/; udden ĠviewBox Ġmetabolism @@ -38001,7 +38001,7 @@ udden agnetic VERRIDE _AUDIO -ÑĢÑĭ +ры Ġarchives .linear ={< @@ -38025,7 +38025,7 @@ elong Ġano Dan Ġpsi -алÑĮ +аль .getChild ĠREF -ab @@ -38033,7 +38033,7 @@ Dan Ċ Ġplag pine @@ -38075,15 +38075,15 @@ yi METHOD ĠEg mapper -æĻĤ +時 .asarray -Ïģ -ição +ρ +ição Reuse _rev ĠPRODUCT _Code -ĠĠĠĠĠčĊ +ĠĠĠĠĠ ĠSERVICE _cover .,Ċ @@ -38107,7 +38107,7 @@ Little Resolution .health ĉfclose -交 +交 Ġstakeholders Ġarchae Digital @@ -38116,7 +38116,7 @@ _pen ĠItemStack ĠCanon ĠKend -Ġø +Ġø _ajax ingredients Delivery @@ -38131,7 +38131,7 @@ _fmt nah Washington zung -ĠÑĨ +Ġц ycz ieves .DEBUG @@ -38142,13 +38142,13 @@ flows ĠdidReceiveMemoryWarning Ġaccountability COUNT -леменÑĤ +лемент blo /id ĠSlow izzard .removeEventListener -Ġìŀħ +Ġ입 /I isma ĠHudson @@ -38174,13 +38174,13 @@ _nat Ġtaxpayer ĠFoster Ġsexuality -ç³» +系 ë° -\čĊ +\Ċ .seek -аниÑı +ания /article -è¿ĩ +过 ĠUhr Ġgrandmother ĠBle @@ -38193,7 +38193,7 @@ oki (Array Ġautonomous Ġobr -¯¯ +¯¯ Ġbasename Ġunveiled sol @@ -38210,7 +38210,7 @@ _RESOURCE (spec (cv Ġnada -ç͵ +电 Ġcrowded Below ĠZach @@ -38229,7 +38229,7 @@ Thunk Ġdataframe _reason gomery -ìĬµëĭĪëĭ¤ +습니다 Ġneglect ĠLines Ġmemb @@ -38263,7 +38263,7 @@ _reward setVisible ĠJsonResponse ICY -询 +询 VarChar aat -green @@ -38302,7 +38302,7 @@ coupon edList ĠStores _malloc -符 +符 ĠAwesome Ġlamb REST @@ -38318,12 +38318,12 @@ PLEMENT .INFO Ġexotic ĠCASE -ĉĠĠĊ +ĠĠ ĠGand theses Ġnovo ĠDell -â̦â̦â̦â̦ +………… _soft Ġagreeing cents @@ -38340,7 +38340,7 @@ DEL Ġ,' ĠLOAD Ġplanted -æľª +未 FormControl _matches Ġperiodic @@ -38377,7 +38377,7 @@ Already ishi Ġtym ĠArmen -ĠÑĢаз +Ġраз -format _Read (columns @@ -38406,7 +38406,7 @@ _SPACE Ġotros Compiler ĉEnd -Ġ]),Ċ +Ġ]), Gravity Ġtensions Ġsmoothly @@ -38419,7 +38419,7 @@ oothing zenie ëŀ ĠChocolate -Ġİ +Ġİ "No ĠALS ĠProgramming @@ -38456,7 +38456,7 @@ BUTTON Ġperspectives Mixin _minus -ĉĉĉĉĠĠĠĠ +ĠĠĠĠ "))) normalized .lastName @@ -38468,7 +38468,7 @@ paginate elig Ġposters nings -ĠÏĦ +Ġτ Ġapost ĠIhre DllImport @@ -38481,7 +38481,7 @@ neapolis Ġlend ĠSHOW _codes -Ġaté +Ġaté ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ -case chte @@ -38495,7 +38495,7 @@ intColor Connell Manifest teams -Ġ};ĊĊĊ +Ġ}; Ġplural Ġovertime ĠEuropa @@ -38505,7 +38505,7 @@ teams itime inston .shadow -ç¨ĭ +程 ĠUSS ServerError IVERS @@ -38513,7 +38513,7 @@ IVERS Ġhumble autoload arez -â̲ +′ ĠAstr icolon .ViewModels @@ -38603,9 +38603,9 @@ Statistics ovenant ============== .Absolute -ĠfÃ¥ +Ġfå Handling -Ġ-------Ċ +Ġ------- (directory ").Ċ anol @@ -38621,7 +38621,7 @@ afety Ġwives ooo Ġprostitu -Ġoù +Ġoù ifty Ġlitigation ĠEz @@ -38695,7 +38695,7 @@ herits ĠDesk _machine .netty -ında +ında =< ĠQR ĠSidebar @@ -38713,7 +38713,7 @@ pectrum .linspace Ġ"... Listen -Æ¡ +ơ .Channel -defined Repeat @@ -38729,7 +38729,7 @@ _application Ġmultiplayer Ġregistering until -Ã¥n +ån (:: ussions Ġpotato @@ -38748,7 +38748,7 @@ ussions ĠHier _RET _bucket -ãĥ¡ +メ avs Ġroz flower @@ -38758,12 +38758,12 @@ WriteBarrier ĠDoll Ġproving .concatenate -âķIJ +═ Ġgchar cdnjs bles ĠListing -ло +ло .xrLabel ĠSak justice @@ -38778,10 +38778,10 @@ ANA recated ĠRuntimeMethod Ġconqu -ãĤ¢ +ア Ġtissues ailer -été +été -Star Ġflames .setIcon @@ -38798,10 +38798,10 @@ agens Ġspill ĠJur Ġdispatcher -ного +ного emonic (dirname -ĠÐĶ +ĠД Ġpasse Ġganz ricing @@ -38810,7 +38810,7 @@ EU essen .attribute jj -ĉĉĠĊ +Ġ [^ Ġstrtolower lexer @@ -38824,11 +38824,11 @@ Market ĠUses ivas .Business -ãģĹãģ¦ +して DIV Ġwasted Ġavoir -êm +êm _ACCOUNT .et ĉSDL @@ -38842,14 +38842,14 @@ PEND ĠAES }), Ġdeduction -ĠpolÃŃt +Ġpolít ĠcomponentWill ĠTelerik _SELF Ġmuse Craft Ġdens -ि +ि (tp Ġtasty Ġbalances @@ -38863,7 +38863,7 @@ emente Ġsoda Republic asmine -éric +éric (Status ĠJsonConvert ĠDisk @@ -38886,7 +38886,7 @@ ximo ĠUSART _super _DECREF -ной +ной _ROW Ġpromotes ĠTA @@ -38940,7 +38940,7 @@ _press .Rel angled /templates --->čĊ +-->Ċ lime Ġsufficiently _nt @@ -38951,7 +38951,7 @@ Expand Ġmulher acob George -常 +常 Ġassim aso Ġcomprised @@ -38979,7 +38979,7 @@ abei jb ĠPit Utf -Ġ/Ċ +Ġ/ reload ĠsetObject /global @@ -38987,7 +38987,7 @@ reload Ġsocks Couldn Ġerotisk -æĿ¡ +条 ĠPressure ĠMaz npos @@ -39022,7 +39022,7 @@ flip assigned Ġabc ĠCOLUMN -ĠðŁĻĤĊĊ +Ġ🙂 )... Ġensemble Ġnewline @@ -39060,7 +39060,7 @@ sess VERTISE ĠFoods Ġtournaments -Ãĵ +Ó ĠMarsh Ġwonders Longitude @@ -39093,7 +39093,7 @@ ickers Ġgarant Ġbf Ġwipe -Ġä¸ĭ +Ġ下 _TRA adox çķ @@ -39109,7 +39109,7 @@ venile ĠGlenn .pattern .DataBind -Ñĥм +ум LayoutInflater chet ĠTestament @@ -39129,7 +39129,7 @@ ivation Ġoutright azu loyment -иÑı +ия aldo ĠPublisher Education @@ -39149,7 +39149,7 @@ tection -large Mel Ġthreaten -нÑı +ня Ġfetish otine _dic @@ -39163,14 +39163,14 @@ jos ĉimg ĉWHERE _lt -å½ĵ +当 .cost ĠTue .labels ĠLV wcsstore ĠJesse -ห +ห Trade Ġpredecessor ëĤ @@ -39188,7 +39188,7 @@ nement maximum .Unlock _SYNC -ágina +ágina Ġdowns ĠWii ])/ @@ -39204,7 +39204,7 @@ unication _mouse urrection (no -Ġ>čĊ +Ġ> Ġaggression Ġbreeding .symbol @@ -39255,7 +39255,7 @@ Physical =v Ġdriv ĠErrors -ĠcÄĥ +Ġcă Death ĠWINDOW Ġpoet @@ -39288,7 +39288,7 @@ _CATEGORY _wp ĠEvaluation starting -Ġ)),Ċ +Ġ)), episode ĠVariant Ġdaemon @@ -39332,7 +39332,7 @@ bins .tele ĠVeterans _ALLOC -олÑĮзоваÑĤ +ользоват innamon ;width ohl @@ -39350,7 +39350,7 @@ imagen .mybatis Seek WER -管çIJĨ +管理 Ġinteress _Event ederland @@ -39372,7 +39372,7 @@ _ZERO Ġ[]; /scripts Ġ-------------------------------------------------------------------------------- -æĥħ +情 Ġweed NBC Ġraped @@ -39390,7 +39390,7 @@ acin Ġtribal .apple ĠBlo -ân +ân ibi rov ĠLives @@ -39421,10 +39421,10 @@ Development movies Ġidentities Ġpromptly -اÙĨ +ان Ġante Ġ"',' -åı£ +口 impse Ġyap TypeName @@ -39432,7 +39432,7 @@ TypeName Ġassociates HEME -empty -Ġت +Ġت olvers Ġpistol Scoped @@ -39463,13 +39463,13 @@ _library Offer located Ġ(_, -âĢľHe +“He ĠOwners )).Ċ Ġbri .Admin ktion -лÑİÑĩ +люч Ġerotici Cancelled Ġagr @@ -39488,7 +39488,7 @@ UserData ĉsort Ġcongrat Ġdioxide -да +да .area ĠJoshua ĠKoch @@ -39509,16 +39509,16 @@ ERTICAL ieron Ġlinger /doc -ź +ź ĠAlready asio -Ġ--Ċ +Ġ-- Ġabbrev ĠAtom him ĠINSERT sun -âĻª +♪ CONNECT erator ĠManning @@ -39526,7 +39526,7 @@ erator gas =>' Ġqueryset -;}čĊ +;}Ċ ĠPopulation utedString resident @@ -39548,7 +39548,7 @@ SOR venience ĠJong Ġwhistle -ĠзнаÑĩ +Ġзнач Ġlending Ġdestructive ĠonDelete @@ -39571,7 +39571,7 @@ exchange .fast Samples London -'])čĊ +'])Ċ ĠIonic Ġpesso ĠKnights @@ -39600,14 +39600,14 @@ textBox Ġsperm Ġcough Hor -âĢĻS +’S .ComponentResourceManager Ġregulator Ġpartnerships /projects trys ĠLaser -⣩ +⟩ ĠFunk Ġunconscious Ġcrust @@ -39636,7 +39636,7 @@ YNAM ==( .UUID _KERNEL -Ġvidé +Ġvidé Ġpq ĠQtGui ĠVarious @@ -39646,12 +39646,12 @@ _patch ĠFail Ġsurviving ("${ -ĠĠĠĠĠĠĠčĊ +ĠĠĠĠĠĠĠ ĠimageUrl .wordpress sources ĉglVertex -âĢĻa +’a Ġescol RARY ĠSnake @@ -39683,7 +39683,7 @@ permit ĠImmigration Ġpathname ffective -âĻĢâĻĢ +♀♀ Ġexams -event ĠTill @@ -39695,7 +39695,7 @@ _traits ĠorderBy Ġsunt ĠNicholas -ز +ز Ġsunny iners Ġaccessibility @@ -39740,14 +39740,14 @@ _fin Ġsheriff -invalid ĠFULL -Ġпод +Ġпод elas "strings ĠRepresentatives surface resolved htdocs -)):čĊ +)):Ċ Ġpressures Ġnorms Ġpla @@ -39760,19 +39760,19 @@ orida Ġdesar compact _LANG -åIJĪ +合 opoly _rad ĠSTDMETHOD Lazy -ĠĠĠĉ +ĠĠĠ ..., (web ĠPont Ġetwas Ġupward _hat -Ġ],ĊĊ +Ġ], ĠbaseUrl Ġworrying -addon @@ -39830,7 +39830,7 @@ persist `s supplier (Form -¡ +¡ _so ĮĢ ĠLegion @@ -39853,9 +39853,9 @@ uminium opa .schedule smtp -à¸ķ +ต urry -ük +ük goog _signature .into @@ -39863,7 +39863,7 @@ _signature Ġhomeowners ĠNSURL ĠPAC -ĠĠĠĠĠĠĠĠĠĠĠĠĊĊ +ĠĠĠĠĠĠĠĠĠĠĠĠ >')Ċ enh Ġincap @@ -39885,8 +39885,8 @@ Discount legacy ĠCapture Ġarising -Ġ");ĊĊ -ÑĪиб +Ġ"); +шиб ĠInfinity Advertisements ĠComing @@ -39905,7 +39905,7 @@ _INC _without .keySet Ġreceivers -æĸ¹æ³ķ +方法 (mem ĠHorizontal Ġcocktail @@ -39934,9 +39934,9 @@ ERVED hea plist _PLUGIN -Ñģл +сл .lookup -á»ģ +ề Ġenlarg Ġpiss Ham @@ -39951,10 +39951,10 @@ Handlers Ġstacks .getFullYear =[];Ċ -车 +车 ,V (split -Ñĥнк +унк Ġbakeca Ġ~/. pez @@ -39971,7 +39971,7 @@ _STEP ĠCorrect rina Ġconcaten -å®ŀ +实 ():ĊĊ Ġunanim lli @@ -40025,8 +40025,8 @@ _COMPONENT endi IMUM ĠGF -ç»Ħ -âĢĶthat +组 +—that bk Mozilla Ġdefenders @@ -40074,7 +40074,7 @@ obia Expose Ġ'} .COLOR -ĠÑĩиÑģ +Ġчис Ajax Ġthru Movies @@ -40088,7 +40088,7 @@ ModelProperty (required ĠPrel eled -æĵįä½ľ +操作 .TRA MAS Ġrealised @@ -40102,7 +40102,7 @@ vidia Ln Ġlust Asc -ĉĉĉĉĉĉĉĠ +Ġ isle -care _INV @@ -40134,7 +40134,7 @@ _frequency ĠuseRef ĠGrove ĠXia -ĠÃŃ +Ġí essenger -cost .fc @@ -40189,7 +40189,7 @@ qi cken Ġsocialist ĠInvoice -ĠпÑĢо +Ġпро %", ennen Ġvivo @@ -40223,7 +40223,7 @@ arius Twig Ġswept -tool -ÄIJ +Đ chapter -grade Ġcuriosity @@ -40252,7 +40252,7 @@ Travel _SEL -pop Ġemission -âĢĻ.ĊĊ +’.ĊĊ .switch otions .photo @@ -40279,7 +40279,7 @@ Css ĠFIELD .relu Ġlis -ìļ° +우 .RELATED Ġlok ĠFlip @@ -40299,10 +40299,10 @@ erro Ġstretched .HasValue ;;;;;;;; -çīĪ +版 Ġfinals .getChildren -Ġ--}}Ċ +Ġ--}} ĠCowboys ĠEdinburgh ĠPlaza @@ -40315,7 +40315,7 @@ _noise .Objects Expressions Ġanthrop -'))čĊ +'))Ċ )." criptive Ġsalmon @@ -40325,7 +40325,7 @@ rho Ġexplores ĠAlgorithm CharArray -à¸Ħ +ค _PACKET JE "]];Ċ @@ -40335,7 +40335,7 @@ Backing reich ĠZion /gr -ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ Motion ĠTribune Ġcritically @@ -40498,7 +40498,7 @@ udeau \">'+ -åĿĢ +址 acency (URL _half @@ -40651,12 +40651,12 @@ _RELEASE (ST AMIL rike -Ġ(){Ċ +Ġ(){ (sprintf ĠAccounts ĠVIEW ĠAj -ãĤ° +グ Ġwhisk Ġidi Ġrode @@ -40669,7 +40669,7 @@ Jobs ĉoffset ĠAhmed ĠTaliban -Ġèİ·åıĸ +Ġ获取 Ġinjected .Authentication _linear @@ -40705,7 +40705,7 @@ ehen ĠShi _BYTES REA -ản +ản _CONNECTION Gateway ĠTravis @@ -40718,14 +40718,14 @@ ermo kor Female _attach -ĠìĤ¬ìļ© +Ġ사용 Ġpoz ==============Ċ (symbol ĠSector __)ĊĊ _padding -ï¼ļ" +:" Ġfabs Ġranged setName @@ -40771,7 +40771,7 @@ appear ĠProceed Ġê¸ anked -из +из ansk ĠHang ĠCler @@ -40806,19 +40806,19 @@ dbh Ġisso ĠACE startDate -ĠbÄĻd +Ġbęd ĠAUTHOR ĠGlobe Ġinsects _Al ushing -è®° +记 /Home ĠLocalDate needed hesive Ġillusion -äºĮ +二 Ġtrat xo /detail @@ -40838,7 +40838,7 @@ Devices ĠOg ĠSEL udiant -Ġ++;Ċ +Ġ++; Ġexplanations occo Ġdiets @@ -40851,7 +40851,7 @@ GD Ġcarbohydr Ġfried ĠEmployment -ìŀ¥ +장 ĠLeonard _${ quares @@ -40865,7 +40865,7 @@ quares otional ĠLite ĠKosten -Ġó +Ġó _attachment orphic Ġdamit @@ -40888,7 +40888,7 @@ illet Ġcentro Ġdims _initialize -ık +ık ĠCenters REN Ġevolutionary @@ -40914,7 +40914,7 @@ ucumber ĠRide Ġzoo Ġchecker -åIJĮ +同 =C Ġgrit ");// @@ -40945,24 +40945,24 @@ Christian ĉGame Ġinstrumental Animations -дал +дал ĠMoses -ĉĉčĊĉĉčĊ +ĉĉĉĉ zs kte -ä¸ļ +业 _DIST bitmap dB Ġpersistence -ÑĢоÑģ +рос $l Bron Ġ{| _chart ĠConsum Ġhemp -Ġ"))Ċ +Ġ")) Ġattackers Ġknowledgeable Ġcet @@ -40976,10 +40976,10 @@ aptops Ġinstructed ĠRus benhavn -Ġин +Ġин Sports Ġonset -æĿĥ +权 .RED _si ĠPST @@ -41012,7 +41012,7 @@ Guide Ġpsycho Ġglam Elim -ädchen +ädchen _plain ĠSau -four @@ -41035,8 +41035,8 @@ parsed Ġtrademarks ĠCoordinate ĠViv -Ġ//}ĊĊ -Ġaprès +Ġ//} +Ġaprès .getPosition (KeyCode ĠSilva @@ -41073,7 +41073,7 @@ Interfaces Autom Ġlw ĠNW -Ġ&&čĊ +Ġ&& Ġproblema ĠManufacturing limits @@ -41090,10 +41090,10 @@ crollView __$ Ġsidewalk (that -ื +ื [q grammar -Ġtë +Ġtë quito Ġspiral extended @@ -41144,7 +41144,7 @@ national .Timer ĉsrc elsen -åħ¶ +其 Ġcommunicating ĠQuiz Ġteng @@ -41173,7 +41173,7 @@ accessToken Ġplum adir .setMessage -Ġï¼Į +Ġ, Ġswallow ĠLamp Ġqw @@ -41189,19 +41189,19 @@ turned .gridColumn Ġpuppy Ġpam -Ġ){čĊ +Ġ){ Ġinviting Ġfrench vim Ġwrapping -Ġ#-}Ċ +Ġ#-} ([- Early Ġshiny .faces Ġrebell abcdef -ält +ält Ġestimation phys losures @@ -41231,7 +41231,7 @@ otti .flag ĉrs _generic -Ġ```Ċ +Ġ``` ACHINE Ġmein (Application @@ -41252,14 +41252,14 @@ umericUpDown Ġterrifying .MODE ĠGW -ár +ár Ġfic Ġcommitments -tech ĠLiquid opez zheimer -aña +aña -media (animated _goal @@ -41287,7 +41287,7 @@ _TRI Ġrecovering ĠGLOBAL .Par -Ġ/>;Ċ +Ġ/>; Ġmarble ulators ĠCycle @@ -41297,7 +41297,7 @@ _metric _CLOCK _Button Harry -è¿Ľ +进 Ġstrains ĠAppBar ĠChan @@ -41310,8 +41310,8 @@ lemen ĠDuncan ĠMint -video -া -ówn +া +ówn ĠEMPTY Ġstacked ĠHA @@ -41319,7 +41319,7 @@ _cut Ġwherein ĠWays (counter -è¯ķ +试 FormGroup Ġblew courses @@ -41335,14 +41335,14 @@ getRepository _devices layui Ġhalfway -Ġfranç +Ġfranç Ġtuning OA _Node arde Ġfierce licted -#čĊ +#Ċ Ġbreakthrough ĠErik Ġbride @@ -41360,7 +41360,7 @@ _accuracy _epi queda /org -éªĮ +验 Ġcompte ))[ Outside @@ -41399,12 +41399,12 @@ resume Ġseus .(* Ġamino -Ġ[]);ĊĊ +Ġ[]); Ġprovoc nox .GetEnumerator =======Ċ -æĸĻ +料 _scroll Ġfilmed ĠSoci @@ -41414,7 +41414,7 @@ Vote "But _RC Animal -ÂĢ +€ ibile Ġawaken orest @@ -41422,12 +41422,12 @@ inja ĠIvan (Command Ġ***** -η +η Ġkvinder /helpers _cases tg -ìĦ¸ +세 Registered ĉpass _digits @@ -41448,10 +41448,10 @@ _equ Ġcommittees istema +". -ÃŃan +ían mant Ġsoutheast -ï¼ĮĊ +,Ċ dialogs PROJECT charger @@ -41498,12 +41498,12 @@ forcing Ġriot prox THON -ización +ización ĠNI rost Ġdispro _instances -ï¼ĮâĢľ +,“ ographer endas ĠIsaac @@ -41523,7 +41523,7 @@ dao _complex Ġgravel _DIP -ément +ément ĠAri _bitmap .quit @@ -41532,7 +41532,7 @@ _bitmap Ġrespiratory Ġrebound DefaultValue -ãĥŃ +ロ Ġcommits .tests _fr @@ -41576,7 +41576,7 @@ Organization Ġconstitutes Ġquand (chunk -"/>čĊ +"/>Ċ ĠLakes mainwindow Carthy @@ -41584,14 +41584,14 @@ spin (csv :red -commerce -ู +ู Ġdiscovering Ġeco _fac inceton ĠGreens jwt -ص +ص ĠBroncos ĠGoods (GTK @@ -41601,7 +41601,7 @@ jwt went ĠNatal Ġenthusiastic -á»į +ọ FN /database Catalog @@ -41618,7 +41618,7 @@ exam @Controller uliar .getParent -Ġ";ĊĊ +Ġ"; :size issors Ġfis @@ -41634,7 +41634,7 @@ umbling -address Ġheroin YTE -ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ Friend Ġave ĠPNG @@ -41663,13 +41663,13 @@ ellites Ġfrankly ĠErf CEL -ĠpaÃŃs +Ġpaís Ġhedge Ġlatent ĠIRQ ĠHerald ĠPrec -ë³´ +보 .TEXT Salary Ġautumn @@ -41681,7 +41681,7 @@ Mor Ġjournals _IT ĠTrou -ä¼ł +传 HasColumnName Composite Ġspice @@ -41691,7 +41691,7 @@ _CODES iona Ġnuestra oct -ĠĠĠĠĊĠĠĠĠĊĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠ (parameter Ġstudios ĠprojectId @@ -41723,7 +41723,7 @@ Exc Ġsph Ġcheating andro -ÃŃo +ío Ġprince oire ĠDestination @@ -41760,7 +41760,7 @@ Milliseconds ĠAmanda _np jury -Ġkön +Ġkön Ġtherapist Ġhomosexual ĠDrake @@ -41789,7 +41789,7 @@ WillAppear aji Ġreproductive ĠCAS -ãģ£ +っ FUNC ĠRuth )+( @@ -41837,10 +41837,10 @@ hos /sys colm (pool -Ġestán +Ġestán ĠPending -emás -Ġktóry +emás +Ġktóry ));ĊĊĊ transactions Ġwield @@ -41851,7 +41851,7 @@ _ss Ġprisoner .ReadAll Ġbesch ---;čĊ +--;Ċ Ġcrisp _SCAN Ġae @@ -41871,7 +41871,7 @@ between ĠLt _inline ethyl -¼ +¼ _packages Ġbarrels _he @@ -41930,7 +41930,7 @@ onen Ġbah Ġmolecule Rad -è¿° +述 ANCH -background -agent @@ -41938,7 +41938,7 @@ ANCH :boolean Ġtide erializer -_;čĊ +_;Ċ Fee **) ergy @@ -41980,7 +41980,7 @@ nal particle Ġsignaling Ġaccessory -ĉĉĉĉĉĉĠĠ +ĠĠ Ġviele ĠNoah -ag @@ -42003,18 +42003,18 @@ ampton Ġplaintiff );Ċ invest .*ĊĊ -Ġtélé +Ġtélé Ġsuperf Ġcascade DTD @@ -42037,12 +42037,12 @@ stackoverflow ĠRaiders Ġ#' olicies -ìľ¼ë¡ľ +으로 emap Ġkj Ġquota ĠGardens -ë²Ī +번 ĠAngels Ġoft Ġlowercase @@ -42065,7 +42065,7 @@ arked GTK Ġmaid :disable -éĽĨ +集 ĠPf Ġalbeit openh @@ -42101,7 +42101,7 @@ inp Ġrang .Dynamic _Render -аÑĤа +ата Waiting ĠWik Ġoverwhelmed @@ -42120,7 +42120,7 @@ uggestion Ġeros _tri Ġissuing -Ġhá +Ġhá Ġisolate Overflow ,E @@ -42133,12 +42133,12 @@ _zip ")}Ċ Ġamat ĠCisco -ĠnÃ¥ +Ġnå PLEX Ġsei foto .toJson -å¤ļ +多 ĠKlein Ġlibc Ġminers @@ -42190,12 +42190,12 @@ ridden etz ĉfirst Ġmilestone -æĹł -ÑĥÑī +无 +ущ (success < -Ġ"}Ċ +Ġ"} timezone Ġeer maxcdn @@ -43284,9 +43284,9 @@ _reverse ĠWinds ')));Ċ Ġcongest -ģı +ğı Ġprolonged -è¿Ļ +这 ĠCrossAxisAlignment LEEP ĠVALID @@ -43313,8 +43313,8 @@ HomePage completion Ġsupplying YPES -ável -åζ +ável +制 (click \Contracts /questions @@ -43322,11 +43322,11 @@ YPES AMS .mesh Ġ'");čĊ +>");Ċ dropIfExists ĠBeg _HAL @@ -43699,7 +43699,7 @@ _HAL Ġinstitute veis Ġfft -Ãģ +Á Ġzoekt analy ĠHomeland @@ -43756,7 +43756,7 @@ contro |max _guid levation -наÑı +ная (undefined ĠDDR Ġshootings @@ -43765,7 +43765,7 @@ ENDOR Ġaveraging Ġgreeted Ġtheaters -ое +ое ĠdB Ġgst Ġdefinite @@ -43805,7 +43805,7 @@ Las _PROJECT ucceeded olu -ÄŁi +ği Ġpersonalities Ġreshape Ġenclosed @@ -43813,7 +43813,7 @@ olu Ġtutorials Ġexploded _DIRECTORY -åĨħ容 +内容 Ġcanon Ġrecognise PAD @@ -43836,7 +43836,7 @@ qua Ġdwell ossa Ġrewarded -ий +ий (topic _partition Ġ__________________ @@ -43844,7 +43844,7 @@ Keywords ĠFranco Lite Ġnaken -Ġза +Ġза OBJECT Ġcrafts ĠSwap @@ -43864,12 +43864,12 @@ Upgrade -second Ġneph .pres -ìŀħ +입 .seq Ġpadded "? jl -ãĥ¬ +レ ')R -üssen +üssen efs Ġuncover Ġlud @@ -43980,7 +43980,7 @@ youtube :Int ĠHindi ĠCAT -Ġع +Ġع rar omore /per @@ -43991,7 +43991,7 @@ omore ĠEF rounded ĠPlatinum -ĠвÑģе +Ġвсе .coords .Device /item @@ -44020,13 +44020,13 @@ _MANAGER Ġinfring ĠERA _party -Ñij +ё Ġinici _Request Ġmiracle ĠcancelButton Spy -ató +ató Ġpolish ĠNicole .displayName @@ -44035,7 +44035,7 @@ ató RouterModule Ġstared IDER -ÑĥнкÑĨи +ункци Ġnota $arr pecified @@ -44057,13 +44057,13 @@ Binder +", .att ĠEthi -Ġcódigo +Ġcódigo ='\ .lines (Of -å°Ĩ +将 missible -Ġvé +Ġvé Ġacoustic Ġcrafting nit @@ -44078,7 +44078,7 @@ _wr Ġdns ĠReferences Ġundertaken -Ġkøbenhavn +Ġkøbenhavn Ġchai ĠCroat _Log @@ -44089,14 +44089,14 @@ _med Ġcostumes ĠRequires affle -çĬ¶æĢģ +状态 -Semit elaide -еÑĤод +етод Ġpestic Ġdra DOCUMENT -Ġ...čĊ +Ġ... }`}Ċ ĠAuction ĠDock @@ -44119,8 +44119,8 @@ BERS Ġinstallations .Async Ġrays -=âĢĿ -;ččĊ +=” +;Ċ .crypto _dbg ĠEnumerable @@ -44156,7 +44156,7 @@ _adjust (glm StatusBar filepath -?âĢĻ +?’ Ġdetective Ġunserer ĠTibet @@ -44175,9 +44175,9 @@ allax Ġlungs -current ĠBooking -åĪĹ表 +列表 Ġenjoyment -र +र JA typed .Btn @@ -44224,7 +44224,7 @@ Quant Ġsibling Ġcass -vous -öt +öt _PATTERN _SECTION estimated @@ -44288,7 +44288,7 @@ spaces Ġsegundo _strlen .Firebase -å¤Ħ +处 Ġmentioning \( ĠValve @@ -44321,14 +44321,14 @@ cookies Ġfichier Ġenforced ABB -noÅĽci +ności _ALLOW Ġrecruited Ġexpenditure -night ĠassertNotNull _execute -Ġد +Ġد INDEX _FMT Ġrescued @@ -44356,9 +44356,9 @@ _until Ġhypoth cheduling translator -ĠÐľ +ĠМ Rom -ãĢijĊĊ +】ĊĊ ĠXamarin Ġviolating .anchor @@ -44370,7 +44370,7 @@ ADVERTISEMENT Ġblond ĠPAT .glob -Ġè¾ĵ +Ġ输 Ġsplitting Ġunsubscribe Ġatmospheric @@ -44398,7 +44398,7 @@ coholic ĠassertFalse ĠPatrol ensem -ÅĤÄħ +łą ¨¡ WIDTH ĠRescue @@ -44414,7 +44414,7 @@ _BODY Ġroi cust (tc -ï¼ģ");Ċ +!");Ċ Ġfestivals Ġperformers Ġclimbed @@ -44481,7 +44481,7 @@ Typed statusCode Ġ()) ĠMW -Ġмож +Ġмож ROSS .buf Ġfairy @@ -44516,7 +44516,7 @@ etical ĠForums ĠCharacters _met -Ġìĭľ +Ġ시 Ġkings achie ĠLambda @@ -44529,7 +44529,7 @@ andex ĠHip ĠPrincip StartDate -ĠãĢĮ +Ġ「 tres Ġ&# .MaxValue @@ -44556,7 +44556,7 @@ tracking _BIND ITOR -urlencoded -ĠÑħ +Ġх ĠTrinity Ġtraps Ġ|- @@ -44634,7 +44634,7 @@ calls %@", respect (mp -é«ĺ +高 -if Ġcushion obot @@ -44653,7 +44653,7 @@ feat pies NotBlank (term -ÈĽi +ți _Params .normalize Bullet @@ -44713,7 +44713,7 @@ _backup .getResource Ġdefinitive .EditText -ĠsÃŃ +Ġsí .CONT ĠPLAYER .cards @@ -44727,10 +44727,10 @@ WebDriver å£ ĠNF _clip -åŃIJ +子 Ġinteracting .tmp -Ġ'''ĊĊ +Ġ''' Ġdee Ġfrost "]))Ċ @@ -44791,7 +44791,7 @@ William ĠDocuments .Xaml Ġbatches -éģĵ +道 ĠReleased Tail COOKIE @@ -44814,14 +44814,14 @@ itary _rotation Repositories Ġtart -ĠÑģв +Ġсв Ġmappings èª Cu Cycle Ġbun ĉlua -ãĥī +ド Ġ((! Ġcollectively ĠCond @@ -44869,7 +44869,7 @@ dig Ġapproximate _dirs liga -ÅĤad +ład Ġkindness Ġcontre ĠEVERY @@ -44914,7 +44914,7 @@ _bins ĠOUR rieb Pros -ĠwiÄĻ +Ġwię "d Ġasyncio zeigen @@ -44926,13 +44926,13 @@ Chinese Ġunsuccessful ĠSeahawks ORG -竳 +章 Ġprofessionally ĠCoupon -åŃĹæ®µ +字段 Convention Ġpolym -æīĭ +手 Ġsalvation Ġengineered ĠWrest @@ -44954,7 +44954,7 @@ pac efe ĠBuddha --------------Ċ -åºĵ +库 (forKey Ġlumin Ġ(? @@ -44986,12 +44986,12 @@ _popup ĠDirectors {: [R -ĠÑįлеменÑĤ +Ġэлемент Ġplat Ġdirecting -ä¸ī +三 ĠGilbert -â̦.ĊĊ +….ĊĊ .qml Ġthereafter Ġdisposition @@ -45075,10 +45075,10 @@ ooled .comments Ġchin ).* -Ġили +Ġили tgl udos -ĠdÃŃas +Ġdías chai .program Ġpsz @@ -45131,13 +45131,13 @@ yclopedia Ġê³ =# Ġlectures -âĢľIn +“In Ġ!_ Ġhb ĠVendor Recently _notes -æıIJ示 +提示 "My HeadersHeight _SO @@ -45179,7 +45179,7 @@ EPS izoph exual IRD -ä»İ +从 Ġlith Ġsanitize Ġfeminine @@ -45237,7 +45237,7 @@ archical _quote Ġfoolish Ġcomprising -Ġоп +Ġоп -selected vf maid @@ -45276,7 +45276,7 @@ CLIENT [z Ġbothered ĠBBQ -ças +ças _examples _FIN ĠwhiteColor @@ -45293,7 +45293,7 @@ jsx ĠMohammed .Job -toggler -ĠполÑĮзоваÑĤ +Ġпользоват ardon Ġnewborn Ġnaval @@ -45308,7 +45308,7 @@ Flight -sk ĠElle _exports -ĠÑı +Ġя ĠIH izophren Ġíģ @@ -45318,7 +45318,7 @@ _primary Ġsystemic Ġdiferentes INCT -Ġ''ĊĊ +Ġ'' $q WidgetItem clide @@ -45329,9 +45329,9 @@ agrid ĠMongoDB inte Ġapprent -ÂŃing +­ing .Db -ĠÃĤ +Ġ hammer ='';Ċ Ġbrokers @@ -45347,14 +45347,14 @@ _band .readFileSync ====== .rx -?čĊ +?Ċ Ġmetaphor Ti conte Ġdebit Ġcontempt CppType -æĶ¯ +支 FormField ratio osopher @@ -45391,7 +45391,7 @@ asan _moves -margin Ġingen -³³³ + Ġprojet Ġotra Ġbras @@ -45424,13 +45424,13 @@ accion nas Ġtrouver NONE -"},čĊ +"},Ċ Ġftp WithIdentifier polate FileInfo Ġpursued -ĠĠĠĠčĊĠĠĠĠčĊ +ĠĠĠĠĠĠĠĠ DESCRIPTION }*/Ċ FromNib @@ -45447,7 +45447,7 @@ imson /library (_ -ÐĴÑĭ -.*;čĊ +Вы +.*;Ċ =j -cor Son @@ -45624,14 +45624,14 @@ endDate Ġencontrar _skill houette -!čĊ +!Ċ .weather Ġemphasized -å®¶ -ĠÑģпиÑģ +家 +Ġспис ĠCompiler (android -ĠâĢº +Ġ› .turn Ġsuppression _calls @@ -45640,7 +45640,7 @@ _calls .hex ĠBills ĠRSA -ÏĤ +ς ĠEscape ementia Ġfrontend @@ -45708,7 +45708,7 @@ imized .orientation .compareTo Ġmassaggi -ĠìľĦ +Ġ위 Ġelbow Ġantioxid undreds @@ -45736,7 +45736,7 @@ Guess Ġaccumulation Ġsimulated ĠDrivers -Ġdés +Ġdés curring Ġelephant Ġadvertised @@ -45746,7 +45746,7 @@ SHIFT Ġanc Ġwardrobe Ingredients -Ġ||čĊ +Ġ|| ippy Ġantibiotics avings @@ -45757,7 +45757,7 @@ avings removed orderby Ġcres -ocê +ocê Ġpym ĠCircular @index @@ -45783,7 +45783,7 @@ _fire ĠBEST Ġstressful _PE -æĹ¥æľŁ +日期 ĠDataFrame ĉInteger _Print @@ -45805,7 +45805,7 @@ fontsize ĠDepression Ġacne ***ĊĊ -ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ .contents ynth ĠStraight @@ -45835,9 +45835,9 @@ upa (( survey Ġíĺ @@ -47462,7 +47462,7 @@ _prepare stin ĠHeath .PrimaryKey -ĠâĨIJ +Ġ← ĠLocalDateTime Ġcooperative Learning @@ -47532,7 +47532,7 @@ _changes (Task _subset ĠTRANS -åĬĽ +力 ĠScout -popup Ġsmoked @@ -47578,7 +47578,7 @@ GAME ĠBrock Ġoccupy Ġdecorations -ánd +ánd Ġcot Ġparan Disk @@ -47604,7 +47604,7 @@ ypress //{Ċ Ġsyll PTR -åŃĺåľ¨ +存在 Ġdidnt Mailer Ġacademics @@ -47630,17 +47630,17 @@ ettle slashes -pound recht -ات +ات .onClick Ġnord -ständ +ständ _when UTERS icc Ġcapsule ĠWid Marc -ุ +ุ rored UGE LOUD @@ -47683,12 +47683,12 @@ Kevin Ġannounces ])* reservation -Ġæķ° +Ġ数 Ġprejudice ĠStringComparison Ġbeard -win -ĠSão +ĠSão ĉms jal ĠEarn @@ -47721,10 +47721,10 @@ _tf UIImageView populate bab -ĠÏĥ +Ġσ [++ Ġopioid -Ġ##Ċ +Ġ## dtype ĠStarts ('/') @@ -47733,14 +47733,14 @@ dtype Ġredundant ĠEssential Ġscrapy -Ġим +Ġим acl Ġcrear ĠBend Ġrelieve -room wife -ĠvÃł +Ġvà ĠQPoint Ġquasi ĠmethodName @@ -47756,16 +47756,16 @@ Located Ġhypert Ġdj ĠUserInfo -ĠåĪĽå»º +Ġ创建 \xb (sim -Ġ==Ċ +Ġ== Ġstaging Ġdrastically -åѦ +学 lords .less -ведиÑĤе +ведите ĠBucket ĠMam .term @@ -47778,7 +47778,7 @@ precio itat Lex _infos -İ +İ .other VELO Ġponder @@ -47807,25 +47807,25 @@ _fact .chunk Ġlent Ġaller -Ġà¤ķ +Ġक _idle Ġadmissions JSONArray Ġvibration .helpers -å¤ĸ +外 Ġhen john -ĠìĥĿ +Ġ생 Ġjudgement Ġgeen terra ^{ ĠIz -Ġcâ +Ġcâ instances Ġthreatens -Ġmüssen +Ġmüssen KindOfClass Ġstorytelling _demo @@ -47834,10 +47834,10 @@ Privacy hift ĠYi esor -íķł +할 ensitivity .Writer -à¸Ĥ +ข District .getJSONObject Impro @@ -47870,17 +47870,17 @@ Calculate Ġdislike ĠPreferences _EXTERNAL -è°ĥ +调 Ġdodge -æľįåĬ¡ +服务 .names .drawImage _prom uckland Ġ<$> -ız +ız /site -项 +项 rophe Ġcompelled Ġlaptops @@ -47917,7 +47917,7 @@ urai _major Ġentirety ingerprint -ços +ços /account ĉright ursos @@ -47928,7 +47928,7 @@ _INSERT EdgeInsets Ġcolonies .IM -ĉĠĉ +Ġ ROAD CCCC placing @@ -47953,7 +47953,7 @@ Decorator shield ressive .did -请è¾ĵåħ¥ +请输入 Ġshutter Dam Ġparenting @@ -47983,7 +47983,7 @@ Ho Ġcaster ĠclientId Ġpdb -ëıĦ +도 itic ĠGameState ĠnewItem @@ -48001,7 +48001,7 @@ _VECTOR Ġfrontier Ġlacked JUST -ĠÑįÑĤ +Ġэт Ġtint ĠMystery dateTime @@ -48062,7 +48062,7 @@ constraint terminate ĠBelgian assium -Ġ]čĊ +Ġ] Systems ousedown .bus @@ -48076,7 +48076,7 @@ daily ĠCoding (destination #$ -ujÄħ +ują Ġemergence _para _INCLUDE @@ -48104,7 +48104,7 @@ lies ĠMeaning ĠOFFSET ensing -ĠfrÃ¥n +Ġfrån .localStorage Ġë© ({});Ċ @@ -48113,12 +48113,12 @@ decoder Ġdismant Ir Ġinsurg -Ġ'':Ċ -.âĢĿĊ +Ġ'': +.”Ċ Ġbrunette .assets _NETWORK -à¸Ĭ +ช nym _Source \Tests @@ -48182,7 +48182,7 @@ IDDLE ĠLon Extras Transient -веÑĢ +вер /module Ġendurance _tex @@ -48243,7 +48243,7 @@ Conversion $/,Ċ [o ĠConservatives -ÏĢ +π lates _Exception Ġmeilleur @@ -48269,8 +48269,8 @@ Blocking Sin (author Ġcortex -'){čĊ -ï¼īï¼Į +'){Ċ +), Ġdumped ĠShut ĠKeyEvent @@ -48289,7 +48289,7 @@ Bitcoin Anyway ayette Ġ'[' -Ãłnh +ành mgr Ġcorrelated Ġnause @@ -48307,7 +48307,7 @@ Placeholder =text ĠManaging ocalypse -åĮĹ +北 _mag fld âij @@ -48333,10 +48333,10 @@ CAM _sentence Ġbites ĠonFailure -ĠâĪĪ +Ġ∈ Kim .gender -Ġλ +Ġλ Ġ[. "]); landing @@ -48385,11 +48385,11 @@ SETTING ĠAttach /run .tabs -ĠogsÃ¥ +Ġogså Brown .DATE Ġfos -åŃĹ符 +字符 Wood -three herited @@ -48406,7 +48406,7 @@ _digit (ne budget CSI -ĠìķĦ +Ġ아 ĠASP GroupId _COUNTER @@ -48418,18 +48418,18 @@ Sharper ĠFriendly ulet -command -ĠÐł +ĠР cycles ĠWaste Ġtapped ĉBuffer -âĢĶin -ĠĊĠĠĊ +—in +ĠĠĠ ĠIdeal ĠCandy _Syntax -êt -ìĿĮ +êt +음 above ĠNazis Ġfst @@ -48441,7 +48441,7 @@ wik ĠDeserialize ourg .attrib -ï¼ļĊĊ +:ĊĊ ĠWins .eql Ryan @@ -48489,7 +48489,7 @@ _animation Ġtha taboola ĠTHC -ÃŃculo +ículo Ġglowing Ġhonors bstract @@ -48500,8 +48500,8 @@ ITES /Desktop ĉglm Ġzinc -ática -Ġ<<Ċ +ática +Ġ<< VML ĠUnlimited vre @@ -48514,18 +48514,18 @@ travel Firestore Ġemailed _FLASH -ĠfÃ¥r -âĺħâĺħ +Ġfår +★★ Ġ:] Hum .reserve -üm +üm Ġkostenlose ĠSCP utan ĠGore Ġchats -/>čĊ +/>Ċ .getResources Ġlump _consts @@ -48604,7 +48604,7 @@ _FREQ (verbose Ġlongitud ĠCharter -ê·¸ +그 Ġbundles .ignore umbo @@ -48620,7 +48620,7 @@ _Checked Ġfax ĠGust itchens -Ġ))ĊĊ +Ġ)) Ġremarkably /XML -remove @@ -48642,7 +48642,7 @@ coeff ĠHealing Ġordin !), -Ġ'',čĊ +Ġ'', (md ĠSask čĊ -Ġrá +Ġ=> +Ġrá Ġblunt ĠImageIcon ifik @@ -48680,17 +48680,17 @@ $app Ġmedio Ġgranting Ġtslint -ĠMö +ĠMö (figsize Ġhurricane Ġlifes -ĠÃĦ +ĠÄ rocessing _standard -option '))) Ġvacant -å·¥ +工 ĠHollow handleChange Ġdivider @@ -48719,7 +48719,7 @@ pections them -hop Ġscreenshots -Ġ/*!Ċ +Ġ/*! Ġconversions Ġnormalization (configuration @@ -48736,7 +48736,7 @@ itre Ġshoppers Ġdisadvantage Ġliking -ç¬ij +笑 Ġunderstandable SEE Ġhoy @@ -48744,7 +48744,7 @@ SEE Ġconfer Ġnowrap ĠVern -,čĊčĊ +,ĊĊ imestep LayoutManager à· @@ -48753,7 +48753,7 @@ PLETED Japan Ġinduce Ġå¯ -озв +озв _ENDPOINT .horizontal Ġaccelerated @@ -48797,7 +48797,7 @@ licate forces .extra .authenticate -вод +вод ¡° ĠforControlEvents Ġsenha @@ -48805,13 +48805,13 @@ forces Ġminist ĠPreference ĠTelegraph -Ñĥп +уп strpos Ġillnesses Ġpigs ĠgetIntent Sol -Ġ¡ +Ġ¡ (cpu [prop screens @@ -48826,7 +48826,7 @@ GroupBox /comments Ġnumbered Ġbroadcasting -çĽij +监 .nativeElement .mu ĠupdatedAt @@ -48875,7 +48875,7 @@ MARY mens motion Ġsampled -âĢľThat +“That iday quipment getInt @@ -48886,23 +48886,23 @@ uned Ġ})( mmm ĠRising -ä»» +任 Ġunemployed xfa .follow -ĉĉĉĉĠĠĠĠĠĠ +ĠĠĠĠĠĠ slt .Phone Ġknives Ġeve onClick -]))čĊ +]))Ċ ĠWitness ĉNS ĠEOS ĠStefan ĠPriest -âĢĶwhich +—which GetString .By Ġupstairs @@ -48919,20 +48919,20 @@ odor Ġdecir ĠnewName +. -缸 +相 igslist ĠGithub Ġsuccessive racial Ġenviron -éªĮè¯ģ +验证 Ġredirected TOTAL Ġgrabbing ĠLance Ġforfe _CB -å¾® +微 Elapsed _way (DialogInterface @@ -48949,7 +48949,7 @@ _shell Ġbeginner ("%. (tool -Ġнов +Ġнов :init (API ĠMorrison @@ -48983,7 +48983,7 @@ dbcTemplate Ġacclaimed Histor Ġmeses -über +über ĠRenew Ġgras ĠEk @@ -49024,7 +49024,7 @@ Provides Ġnutrient .Timestamp IZATION -åĨĮ +册 Ġfats ĠXxx ctica @@ -49041,16 +49041,16 @@ hb Portal ĠBread .which -ÂŃt +­t asInstanceOf Ġjobject ĉlength _MT -;">čĊ +;">Ċ _EXIST Ġmaternal REL -Ġê²½ìļ° +Ġ경우 hee Ġlayouts ĠLap @@ -49080,7 +49080,7 @@ spir Tips ĠIoT ĠGan -èģĶ +联 Ġbiases Ġconsultants pled @@ -49088,7 +49088,7 @@ _ht associated ],ĊĊ Ġdelightful -ĠÑĤек +Ġтек Helvetica (load -expand @@ -49104,12 +49104,12 @@ _registration Ġrestoring Ġunreal OVER -ĉĊĉĊĉĊ +ĉĉĉ ATS _probe Ġdivisor .updateDynamic -å¹³ +平 Produces stamp .jboss @@ -49121,7 +49121,7 @@ Martin ĠPassed clarations hel -аÑĩ +ач ĉcopy -bin zan @@ -49138,14 +49138,14 @@ outlined HexString +c .Public -ẩ +ẩ Ġconveyor ĠEB Ġselects Ġknocking ĠCec IBUTES -owaÄĩ +ować gatsby *v entropy @@ -49183,13 +49183,13 @@ ummies _growth Ġaktiv Ġgrouping -å¢ŀ +增 _truth -åIJ¬ +听 todos iset TexCoord -ätt +ätt ĠZur roys _MAGIC @@ -49204,7 +49204,7 @@ Lu Ġfick undles _loaded -ие +ие Poll ritic ELY @@ -49215,7 +49215,7 @@ ELY scrollView Ġcommunist /problems -}čĊčĊčĊčĊ +}ĊĊĊĊ ,o Ġudp Ġobese @@ -49236,7 +49236,7 @@ getContext (rel ĠBrotherhood )`Ċ -è§£ +解 .Information OutOfRangeException ĠSek @@ -49268,7 +49268,7 @@ idenav ĠdateFormat .joda veys -Ġ).ĊĊ +Ġ). Ġcareg ĠParallel _translation @@ -49289,7 +49289,7 @@ Upon ĉmove (Un Ġqr -׾ +ל _beta Ġskies ĉme @@ -49321,7 +49321,7 @@ skins Ġdansk ĠPrinceton acist -Ġ());Ċ +Ġ()); tracks imonial adecimal @@ -49334,7 +49334,7 @@ cantidad Ġseekers Ġplausible tier -еж +еж Ġrapper ĠMana ĠHttpStatusCode @@ -49350,7 +49350,7 @@ Instagram TERM ĠCbd ĠParagraph -Ġtravés +Ġtravés Ġconstructing Ġswal Ġpige @@ -49365,15 +49365,15 @@ imas ĠAo ĠPerez ĠDAL -Ġëĭ¤ +Ġ다 Ġdivis StoryboardSegue ĠModify -ĠÃľber +ĠÜber _OVERRIDE .pem untos -Ġespañ +Ġespañ Ġ{? ĠPAY _ipv @@ -49386,7 +49386,7 @@ _Reg -Javadoc ĉload ĠLikewise -اÙħ +ام UNE .sem xcb @@ -49399,16 +49399,16 @@ _sleep Ġcue ĠQByteArray Ġcorrupted -ĠDé +ĠDé Ġimped GetName Ġinaccurate Ġsober -ее +ее Ġbarcode --){Ċ inki -Ġép +Ġép Ġdri ĠALT >>>>>>>> @@ -49448,7 +49448,7 @@ _regular .tabControl Ġpuppet Ġutilization -Ġâĸł +Ġ■ Ġsucces Ġlamps _proj @@ -49501,7 +49501,7 @@ XS Ġinfectious ĠMons _LOOP -Ġzurück +Ġzurück Ġobtener /repos Vel @@ -49515,9 +49515,9 @@ recursive _chip ominated ĠNit -âĢĶto +—to ĠBuddh -омеÑĢ +омер ĠMAG ĠCHE _den @@ -49541,13 +49541,13 @@ igel ,W ADS (panel -ì²´ +체 itating .palette Ġmosquito Ġtego (parseInt -Ġdespués +Ġdespués promise Ġwij typescript @@ -49565,7 +49565,7 @@ _thresh ĠFreeman ,DB _rw -çŃī +等 Ub _statistics ="">< @@ -49579,7 +49579,7 @@ ylko Forest Ġheadset Ġgallon -ÑĢем +рем Ġwithdrawn ĠCandidate Ġmelting @@ -49591,7 +49591,7 @@ mime Ġthirst $return memberof -еб +еб ĠHttpServletRequest (ob _Result @@ -49616,7 +49616,7 @@ chedules Ġparity ĉdest ĠDoors -čĊĉčĊ +ĉ _dimension Ġaload .StoredProcedure @@ -49662,7 +49662,7 @@ _iso updates halb udiante -ë¡Ŀ +록 Ġnaive ĠPeg ĠLounge @@ -49705,8 +49705,8 @@ ottenham ascular Ġtrailers ĠCLOSE -ами -âĢĻai +ами +’ai ĠGain wor Ġplanner @@ -49717,7 +49717,7 @@ xlabel HF Viol .BASELINE -еÑĤÑģÑı +ется ĠRotate Ġtxn :bold @@ -49732,7 +49732,7 @@ their hecy getActiveSheet .clients -ãģį +き _hide [word Cb @@ -49754,7 +49754,7 @@ atif ĠFeld _BINARY itous -à¹Ħ +ไ Ġflashing -sided Ġcontradiction @@ -49775,7 +49775,7 @@ Fall LinearLayout /photos Ġfeather -Ġ|čĊ +Ġ| Downloads .StartsWith Ġ//# @@ -49800,7 +49800,7 @@ _ATTACH ĠSolomon jin ografia -öl +öl _design culated ĠLuna @@ -49826,7 +49826,7 @@ itational Ġadvancement ouro ĠPredicate -å¾Ĺ +得 eria ĠPierce orio @@ -49843,7 +49843,7 @@ _SENSOR ĠReach _decoder .Exp -ĠÑĤак +Ġтак pill ,Q ĠGrill @@ -49853,7 +49853,7 @@ pill Ġmileage Ġecological ]]);Ċ -ĠÂŃ +Ġ­ subplot acad ĠTrying @@ -49884,7 +49884,7 @@ Som .non Ġ'). ĠgetView -ạn +ạn prus Matthew Ġsia @@ -49915,7 +49915,7 @@ ypass -more JD ĠBurns -'>čĊ +'>Ċ .Dependency .QueryString .Owner @@ -49978,14 +49978,14 @@ MLS Ġstimulate Partition Ġmun -óm +óm erala -account .Binary -cé +cé Ġseize connections -ĠĊĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠ ĠDiagnostic VISIBLE ĠRuns @@ -50005,9 +50005,9 @@ arrison shaw blood AJOR -æĽ´æĸ° +更新 ĠMuse -æĶ¶ +收 Ġretaining otte Ġmosque @@ -50040,7 +50040,7 @@ tera ĠRack ĠcurrentIndex Ġallen -Ġç͍æĪ· +Ġ用户 ĠCubs [X _SEQ @@ -50097,7 +50097,7 @@ licher ĠSimmons Taking ĠClaims -Ġdifférent +Ġdifférent ActivityResult Ġsns éĢīæĭ @@ -50121,7 +50121,7 @@ ovich uish (bot Ġgallons -ĠRé +ĠRé ĠSaid ĠSTDMETHODCALLTYPE aising @@ -50177,7 +50177,7 @@ getPath ĠGlobals ')}}' kinson -Ġкол +Ġкол ognitive _li Ġimminent @@ -50459,7 +50459,7 @@ KK ĠEugene _PWM roi -ĠâĹı +Ġ● ĠHamburg .Must Ġaxe @@ -50468,11 +50468,11 @@ enef ĠSpecies ĠStress Ġawhile -ĠбÑĥд +Ġбуд Ġwithstand ĠDecoder _inventory -Ġ{ččĊ +Ġ{ Ġtgt Ġrailroad WASHINGTON @@ -50481,7 +50481,7 @@ NST -phone ,U Ġexercising -ụ +ụ _PIXEL avors iterated @@ -50507,9 +50507,9 @@ _Ent _ini ĠEuropeans ĠBelle -åij½ +命 )[' -åºĶ +应 ĠUseful .reference ()", @@ -50523,9 +50523,9 @@ _LAYER wk ĠNoise ###ĊĊ -Ġpréc +Ġpréc otle -ÑĤе +те auf ibal Ġconquer @@ -50536,7 +50536,7 @@ OAD ĠFI .fixture Ġterse -ĠĠĠĠĉĉĉĉ +ĠĠĠĠ Ġsanctuary ugi ĠComparator @@ -50616,7 +50616,7 @@ _TCP ĠOx _CHO Ġprostituerte -Ġvä +Ġvä Ġsito Ġconstituents ĠContinued @@ -50634,7 +50634,7 @@ verification ~= .hp Iterable -Ñĭе +ые atori Ġctr Rx @@ -50643,19 +50643,19 @@ dag .pin Ġpseud Ġinvo -ÑģÑĤÑĢ +стр _pix -为空 +为空 Ġsworn -âĢĶor +—or _registry Ġdisasters ĠROI -ĠâĢķ +Ġ― aktu forest beiten -âĢĶI +—I ueva egt Ġspikes @@ -50672,7 +50672,7 @@ ROOM ĠPASSWORD Cookies .El -á»Ń +ử ĠBert Ġhashed icester @@ -50684,7 +50684,7 @@ otope -Americ ĠMatthews URAL -âĢľ, +“, Summer fos _CONTAINER @@ -50693,8 +50693,8 @@ _ACK _disp _Re Ġfacile -аÑĪ -ĠìķĬ +аш +Ġ않 Ġeben Ġsprink ĠQuint @@ -50719,7 +50719,7 @@ HEL (asset Ġhvor FileSystem -<>();čĊ +<>();Ċ ocoder ĠCannon )x @@ -50744,21 +50744,21 @@ $IFn atsu Josh Equality -Ġ}()Ċ +Ġ}() _less ĠRatio ĠCats ĠStern Monster Ġmercury -ühr +ühr Ġplusieurs .deserialize scopy .False )animated ĠExperts -Ġ""){Ċ +Ġ""){ .When seealso .unpack @@ -50788,7 +50788,7 @@ EDIUM ĠHistoric _holder ĠMarines -Ġtä +Ġtä .Light quirer asonry @@ -50797,7 +50797,7 @@ divider _fb restricted ĠEverybody -Não +Não Ġknot ĠTwitch Ġhallway @@ -50811,7 +50811,7 @@ played Ġbatting _dl Ġcomedian -Ġév +Ġév ĠDEM ĠEden :white @@ -50826,7 +50826,7 @@ nz _BGR ValidateAntiForgeryToken _air -âĢľWhen +“When Ġglfw ĠConversation _TOTAL @@ -50835,11 +50835,11 @@ _TOTAL Ġiterable ĠPASS Ġadvertise -Ġmöglich +Ġmöglich /train ĠVolkswagen Ġcreepy -Ġ")čĊ +Ġ") QUENCE Ġaltar Ġedits @@ -50871,13 +50871,13 @@ _usb ĠMartha ĠCatholics ĠMond -обÑĭ +обы _absolute Ġashamed ponsors tal Ġsadness -Ġpuò +Ġpuò Fade -preview ĠRequests @@ -50916,7 +50916,7 @@ aille FLOAT Ġmak Ġgcc -âķIJâķIJ +══ ("~/ SCRIPTOR Ġtonnes @@ -50928,17 +50928,17 @@ Pred .githubusercontent (print ĠHole -çľĭ +看 adget Ġprompts Ġgenetically ĠHod Ġvertically _controls -ÑģÑĤан -"){čĊ +стан +"){Ċ $title -Ġ}),ĊĊ +Ġ}), Ġstatewide ĠCorrespond ĠAttr @@ -50948,7 +50948,7 @@ ElementType Ġfamilia (article Ġblat -ÂłĊ +Ċ ĠglGet ĠReceiver Ġ%- @@ -50968,9 +50968,9 @@ assist Ġchatting ĠMant Ġ%@ -Ġ"");ĊĊ +Ġ""); Ġdgv -Ġíķ¨ +Ġ함 .repeat _Message Ġadvisers @@ -50982,7 +50982,7 @@ Misc Ġtrimmed ĠAck VertexAttrib -ç´¢ +索 uates .mysql Ġdestin @@ -50994,7 +50994,7 @@ _AREA __*/ []( ĠsignIn -Äij +đ xr ahir .firestore @@ -51016,8 +51016,8 @@ _alignment .Args dob Ġvn -âĨĴ -Ġdü +→ +Ġdü ĠXM Ġalumni Ġhone @@ -51052,7 +51052,7 @@ _define Sorted 'in Logs -á»ĩn +ện Ġnylon Dump Imagine @@ -51075,7 +51075,7 @@ bursement Province ĠApproved ()<< -ória +ória usch ĠJenny arrants @@ -51091,7 +51091,7 @@ etting Ġgoddess ripe Ġmuscular -ĉĉĉĉĉĉĉĉĠ +Ġ ĠHugo Ġmejores loid @@ -51107,7 +51107,7 @@ BOOLE ahi -END Ġiff -ób +ób ĠBruno rowsable ĠPoison @@ -51122,8 +51122,8 @@ urtles eful $value fed -ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĊ -èµĦ +ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ +资 (cm Ġvulnerabilities Ġ[(' @@ -51133,13 +51133,13 @@ entieth Ġpraying Claims Ġkaufen -né +né Ġpoisoning collections ĠinitState ĠSeverity Ġcontention -ĠĊĉĊ +Ġ .controllers structured ictim @@ -51152,7 +51152,7 @@ Produto .multi Ġgrape beg -æŁ¥è¯¢ +查询 Ġquartz ĠRomance ĠMidwest @@ -51161,7 +51161,7 @@ beg icont .unshift otre -Ġún +Ġún ipple Ġsuburb uali @@ -51172,9 +51172,9 @@ Voice IDA ĉpost ptoms -vé +vé ĠIngredients -öff +öff .operator Ġ<<= lastic @@ -51187,7 +51187,7 @@ _READY nowledge Ġappended ungan -âĢĻen +’en ĠLoren publisher ĠMG @@ -51222,15 +51222,15 @@ INavigationController Ġlawful Ġlore ĠLoads -ĠÑģозд +Ġсозд .promise ĠFaces .Platform .getLocation Ġtroubling -ĠvÃŃdeo +Ġvídeo ĠFeaturing -产 +产 qed ĠonBind Ġtoddler @@ -51251,13 +51251,13 @@ _excel ĠLuxury _KEYS .npy -ů +ů Ġforehead -β +β Ġendangered /the pipeline -ű +ű neo Explore SpecWarn @@ -51306,13 +51306,13 @@ _sat ĠGHz LONG ();čĊčĊ +>();ĊĊ ĠSims ĠDates ĉConnection @@ -51642,7 +51642,7 @@ NUMBER ĠGlad sla ĠReb -еÑģÑĤво +ество arbon /controllers Slots @@ -51650,7 +51650,7 @@ Slots FULL uire @student -à¹īà¸Ń +้อ Translator Ġpreferably chemistry @@ -51668,7 +51668,7 @@ iges kov examples __(" -Ġкак +Ġкак ĠBoris (dx spr @@ -51676,7 +51676,7 @@ spr atoon ĠHarley icamente -âĸĪâĸĪâĸĪâĸĪ +████ evity usher .VisualStudio @@ -51720,7 +51720,7 @@ creative Ġmadness OrNil Ġhin -Åĵ +œ .GetKey _console "Our @@ -51743,7 +51743,7 @@ _ACCEPT Ġforc ĠFrau Ġthresh -ĠÏĢ +Ġπ (BASE _Open Wunused @@ -51764,10 +51764,10 @@ $params .Warning Ġneutrality zhou -ÑĢаÑī +ращ akter ĠConstructors -ÃĵN +ÓN ĠProgressive ĠBurger Ġincurred @@ -51782,7 +51782,7 @@ _Block _employee ĠPepper laughter -ãĥĸ +ブ '];?> ='. (rename @@ -51793,7 +51793,7 @@ _gap xampp OMIC Ġpedido -Ġdévelop +Ġdévelop __(/*! _od were @@ -51812,10 +51812,10 @@ ueil central ĠABOUT Ġincorporating -Ġ-----------------------------------------------------------------------------Ċ +Ġ----------------------------------------------------------------------------- _widgets ĠsystemFontOfSize -ört +ört /jpeg ĠSMTP (browser @@ -51825,11 +51825,11 @@ _AVAILABLE Ġincorporates /android yx -å¸ĥ +布 _lab Ġleaking ĠHint -ünchen +ünchen .Scale Ġfireworks ĠlParam @@ -51874,7 +51874,7 @@ bh Ġtreffen Ġunderline _nums -íķľëĭ¤ +한다 )v usize Ġdisappearance @@ -51915,7 +51915,7 @@ squeeze ptune _FRONT mh -ĠìĥĿìĦ± +Ġ생성 RunWith Ġturnout siblings @@ -51935,7 +51935,7 @@ Expanded -ro ĠWorldwide ĠCork -ól +ól Lim Ġdenn Pretty @@ -51944,11 +51944,11 @@ Triangle Featured (Common _eff -Ġ""čĊ -Ỽi +Ġ"" +ới _LINEAR ĠRica -Ġcafé +Ġcafé Ġappell Ġniveau Ġ&, @@ -51988,7 +51988,7 @@ becue Ġchaotic Ġani ĠAnnie -ưá»Ŀ +ườ .dx disconnect Ġarchived @@ -52052,7 +52052,7 @@ Entr ĉdst ":- .mon -Ġ(ĊĊ +Ġ( Ġcapita ĠinitComponents Ġswords @@ -52076,7 +52076,7 @@ _partial ĉperror ĠReligious -"+ -ĉĉĉĠĠĠĠĠĠĠĠĠĠĠ +ĠĠĠĠĠĠĠĠĠĠĠ ĠSecrets (normal ACES @@ -52126,8 +52126,8 @@ orde Ġ'"+ Ġpumping ĠClement -ÃĥO -åİŁ +ÃO +原 >n ĠstrSql jdbc @@ -52155,13 +52155,13 @@ HM (mouse ĠReSharper -routing -ĠØ´ +Ġش Ġjointly ĠFamil // ĠXV -à§į +্ Ġseja .visual kker @@ -52654,13 +52654,13 @@ yyval NSObject Ġescaping ĠNullable -Ġhä +Ġhä want Eliminar ĠCLLocation ĠreuseIdentifier BufferSize -ÃŁer +ßer ĠAsked ']],Ċ Ġshields @@ -52700,9 +52700,9 @@ _xlabel (gulp ĠButtons ĠBroker -çĽijåIJ¬ +监听 $email -ÙIJ +ِ Ġclassics compose (bs @@ -52718,13 +52718,13 @@ afort Ġroyalty serializer ieux -ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĊ +ĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ execution ĠviewController Ġrepro .pe Ġcapitalize -åĩ» +击 Ġtunnels .DATA pirit @@ -52798,8 +52798,8 @@ nice (lista à± ployment -ãģ¾ãģŁ -好 +また +好 subst ']][' abol @@ -52851,7 +52851,7 @@ arend .createTextNode ĠRAW Ġinflux -çī© +物 Tok -board Recording @@ -52883,7 +52883,7 @@ natural :w .safe Ġtowels -áºŃt +ật .gsub ë£ inqu @@ -52954,10 +52954,10 @@ Aud ĠKiller .getAbsolutePath _caps -Å« +ū Ġsubstrate .assertIn -ìķĦ +아 Ġthyroid ĠDeluxe Ġfactorial @@ -53020,7 +53020,7 @@ UCCEEDED Ġhyster esthes distinct -Ãły +ày ĠCombo ĉsf ĠâĬ @@ -53035,10 +53035,10 @@ Dur ĠRELEASE -dollar .Commit -Ġkhông +Ġkhông Ġlaunder .=" -Ġæĸĩ +Ġ文 Ġbye .GetKeyDown Ġgio @@ -53052,14 +53052,14 @@ reuse Ġjerseys _MP patibility -Ġ设置 +Ġ设置 Ġreplacements Ġprecedence Ġbuffered .bs _GREEN brain -ách +ách availability ĠETF Ġfret @@ -53087,7 +53087,7 @@ _WEB Ġdrm Ġcolumnist Markup -ĠaquÃŃ +Ġaquí ĠDiane Ġcw ĠTick @@ -53102,7 +53102,7 @@ InBackground Iteration defaultValue attention -ĠÑĢабоÑĤ +Ġработ Ġwaiver Ġproduit ĠGradient @@ -53151,7 +53151,7 @@ posium ĠScripture bbb uchs -ä¸įèĥ½ +不能 .BigDecimal sizes _solver @@ -53159,7 +53159,7 @@ _From _joint Ġpathlib Ġgears -ĠÑĦоÑĢм +Ġформ Ġconceal Ġdifferentiate NN Ġtand @@ -53584,10 +53584,10 @@ bis ĠBroken ĉfs ĠmView -аÑĨии +ации -facebook Ġcaches -ãĢĤãĢĤĊĊ +。。ĊĊ ĠORM ĠDistrib ĠSceneManager @@ -53611,12 +53611,12 @@ edido apollo Ġutmost openssl -ĠhÃ¥ +Ġhå ('& .Standard Ġdistraction ifax -ĠëķĮ +Ġ때 those ispens vak @@ -53636,16 +53636,16 @@ _All Ġuninstall Ġfluor liquid -Ġlá +Ġlá Ġfrightening adan ĠAUT Ġtattoos Ġpropagation .translation -ÐŁÑĢ +Пр _scheduler -ãĢĤâĢľ +。“ Ġcairo ĠHttpClientModule ĠNDP @@ -53678,7 +53678,7 @@ Attrs ĠCurve LAST ĠSCRIPT -ê³¼ +과 Malloc .groupby ĠLeslie @@ -53708,7 +53708,7 @@ _FACTOR .fb Ġounce _saved -Ġر +Ġر Ġdeeds ĠDolphins Ġbuen @@ -53749,7 +53749,7 @@ YouTube NCY Club Ein ---čĊ +--Ċ Ġconstrained ETwitter YG @@ -53770,7 +53770,7 @@ StartPosition Ġproblemas _INTERRUPT ĠSTORE -模 +模 iliated ĠRPM [temp @@ -53813,7 +53813,7 @@ atk _mb .Div Ġendeavor -Ġ(£ +Ġ(£ Ġclutter Ġurgency Ġinstructors @@ -53824,7 +53824,7 @@ cem .ft Stephen Ron -ãģĻãĤĭ +する sci ĠAtmos Ġcatering @@ -53837,7 +53837,7 @@ xdf .showToast OOT -result -Ìģ +́ Ġghosts ĠBuen ĠRider @@ -53893,12 +53893,12 @@ abras ĠBey getClient eken -Ġ'''čĊ +Ġ''' Wiki (HttpStatus Stretch ĠGest -Ġíķĺ +Ġ하 Ġentitlement Ġdoen blogs @@ -53906,7 +53906,7 @@ blogs "Oh ĠSummon ĠBackbone -Ġgü +Ġgü getColumn ĠWINAPI ĉva @@ -53919,24 +53919,24 @@ elsinki _Per flies Ġincompet -Ġjuż +Ġjuż ()% -Ġ---Ċ +Ġ--- umas ĠOlder Ġdisputed _REQUIRE .matmul unken -ä¹ĭ -ãģĭãĤī +之 +から Ġttl underscore ĠPatricia Ġtaper Ġseiner Ġsaya -åı° +台 ieri .secret Ġxor @@ -53947,13 +53947,13 @@ ieri Ġdavid oulos ĠPetersburg -Ġ"",čĊ +Ġ"", shelf -water -byte -ĠобÑĬекÑĤ +Ġобъект Ġstirring -ìĹ´ +열 Ġcompt ĠPotential RAFT @@ -53972,13 +53972,13 @@ _BT renched cors (itemView -ĠgÃ¥ +Ġgå .Contact ViewChild indsay configs Duplicate -â̦I +…I zyst (todo .RemoveAt @@ -53993,7 +53993,7 @@ Lee ĠYun Ġunderwent icom -Ġ''){Ċ +Ġ''){ -pol flammatory Mutation @@ -54019,8 +54019,8 @@ ESSAGES ĉraise oultry ĉmodule -æĺ¾ç¤º -nÃŃ +显示 +ní Ġyrs Ġphysic -platform @@ -54031,13 +54031,13 @@ nÃŃ ĠIncontri Scenario Amb -Ġpremière +Ġpremière /articles ĠMajority CLUSIVE onor -ĠhabÃŃa -å·ŀ +Ġhabía +州 Ġmidi ĠLac .findIndex @@ -54068,7 +54068,7 @@ MYSQL