@@ -7168,8 +7168,17 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
7168
7168
// Remainder = Sum % Constant
7169
7169
// This is based on "Remainder by Summing Digits" from Hacker's Delight.
7170
7170
//
7171
- // For division, we can compute the remainder, subtract it from the dividend,
7172
- // and then multiply by the multiplicative inverse modulo (1 << (BitWidth / 2)).
7171
+ // For division, we can compute the remainder using the algorithm described
7172
+ // above, subtract it from the dividend to get an exact multiple of Constant.
7173
+ // Then multiply that extact multiply by the multiplicative inverse modulo
7174
+ // (1 << (BitWidth / 2)) to get the quotient.
7175
+
7176
+ // If Constant is even, we can shift right the dividend and the divisor by the
7177
+ // number of trailing zeros in Constant before applying the remainder algorithm.
7178
+ // If we're after the quotient, we can subtract this value from the shifted
7179
+ // dividend and multiply by the multiplicative inverse of the shifted divisor.
7180
+ // If we want the remainder, we shift the value left by the number of trailing
7181
+ // zeros and add the bits that were shifted out of the dividend.
7173
7182
bool TargetLowering::expandDIVREMByConstant(SDNode *N,
7174
7183
SmallVectorImpl<SDValue> &Result,
7175
7184
EVT HiLoVT, SelectionDAG &DAG,
@@ -7188,7 +7197,7 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
7188
7197
if (!CN)
7189
7198
return false;
7190
7199
7191
- const APInt & Divisor = CN->getAPIntValue();
7200
+ APInt Divisor = CN->getAPIntValue();
7192
7201
unsigned BitWidth = Divisor.getBitWidth();
7193
7202
unsigned HBitWidth = BitWidth / 2;
7194
7203
assert(VT.getScalarSizeInBits() == BitWidth &&
@@ -7209,12 +7218,20 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
7209
7218
if (DAG.shouldOptForSize())
7210
7219
return false;
7211
7220
7212
- // Early out for 0, 1 or even divisors.
7213
- if (Divisor.ule(1) || Divisor[0] == 0 )
7221
+ // Early out for 0 or 1 divisors.
7222
+ if (Divisor.ule(1))
7214
7223
return false;
7215
7224
7225
+ // If the divisor is even, shift it until it becomes odd.
7226
+ unsigned TrailingZeros = 0;
7227
+ if (!Divisor[0]) {
7228
+ TrailingZeros = Divisor.countTrailingZeros();
7229
+ Divisor.lshrInPlace(TrailingZeros);
7230
+ }
7231
+
7216
7232
SDLoc dl(N);
7217
7233
SDValue Sum;
7234
+ SDValue PartialRem;
7218
7235
7219
7236
// If (1 << HBitWidth) % divisor == 1, we can add the two halves together and
7220
7237
// then add in the carry.
@@ -7229,6 +7246,27 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
7229
7246
DAG.getIntPtrConstant(1, dl));
7230
7247
}
7231
7248
7249
+ // Shift the input by the number of TrailingZeros in the divisor. The
7250
+ // shifted out bits will be added to the remainder later.
7251
+ if (TrailingZeros) {
7252
+ LL = DAG.getNode(
7253
+ ISD::OR, dl, HiLoVT,
7254
+ DAG.getNode(ISD::SRL, dl, HiLoVT, LL,
7255
+ DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl)),
7256
+ DAG.getNode(ISD::SHL, dl, HiLoVT, LH,
7257
+ DAG.getShiftAmountConstant(HBitWidth - TrailingZeros,
7258
+ HiLoVT, dl)));
7259
+ LH = DAG.getNode(ISD::SRL, dl, HiLoVT, LH,
7260
+ DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl));
7261
+
7262
+ // Save the shifted off bits if we need the remainder.
7263
+ if (Opcode != ISD::UDIV) {
7264
+ APInt Mask = APInt::getLowBitsSet(HBitWidth, TrailingZeros);
7265
+ PartialRem = DAG.getNode(ISD::AND, dl, HiLoVT, LL,
7266
+ DAG.getConstant(Mask, dl, HiLoVT));
7267
+ }
7268
+ }
7269
+
7232
7270
// Use addcarry if we can, otherwise use a compare to detect overflow.
7233
7271
EVT SetCCType =
7234
7272
getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), HiLoVT);
@@ -7260,45 +7298,45 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
7260
7298
SDValue RemL =
7261
7299
DAG.getNode(ISD::UREM, dl, HiLoVT, Sum,
7262
7300
DAG.getConstant(Divisor.trunc(HBitWidth), dl, HiLoVT));
7263
- // High half of the remainder is 0.
7264
7301
SDValue RemH = DAG.getConstant(0, dl, HiLoVT);
7265
7302
7266
- // If we only want remainder, we're done.
7267
- if (Opcode == ISD::UREM) {
7268
- Result.push_back(RemL);
7269
- Result.push_back(RemH);
7270
- return true;
7271
- }
7272
-
7273
- // Otherwise, we need to compute the quotient.
7274
-
7275
- // Join the remainder halves.
7276
- SDValue Rem = DAG.getNode(ISD::BUILD_PAIR, dl, VT, RemL, RemH);
7277
-
7278
- // Subtract the remainder from the input.
7279
- SDValue In = DAG.getNode(ISD::SUB, dl, VT, N->getOperand(0), Rem);
7280
-
7281
- // Multiply by the multiplicative inverse of the divisor modulo
7282
- // (1 << BitWidth).
7283
- APInt Mod = APInt::getSignedMinValue(BitWidth + 1);
7284
- APInt MulFactor = Divisor.zext(BitWidth + 1);
7285
- MulFactor = MulFactor.multiplicativeInverse(Mod);
7286
- MulFactor = MulFactor.trunc(BitWidth);
7287
-
7288
- SDValue Quotient =
7289
- DAG.getNode(ISD::MUL, dl, VT, In, DAG.getConstant(MulFactor, dl, VT));
7290
-
7291
- // Split the quotient into low and high parts.
7292
- SDValue QuotL = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, Quotient,
7293
- DAG.getIntPtrConstant(0, dl));
7294
- SDValue QuotH = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, Quotient,
7295
- DAG.getIntPtrConstant(1, dl));
7296
- Result.push_back(QuotL);
7297
- Result.push_back(QuotH);
7298
- // For DIVREM, also return the remainder parts.
7299
- if (Opcode == ISD::UDIVREM) {
7303
+ if (Opcode != ISD::UREM) {
7304
+ // Subtract the remainder from the shifted dividend.
7305
+ SDValue Dividend = DAG.getNode(ISD::BUILD_PAIR, dl, VT, LL, LH);
7306
+ SDValue Rem = DAG.getNode(ISD::BUILD_PAIR, dl, VT, RemL, RemH);
7307
+
7308
+ Dividend = DAG.getNode(ISD::SUB, dl, VT, Dividend, Rem);
7309
+
7310
+ // Multiply by the multiplicative inverse of the divisor modulo
7311
+ // (1 << BitWidth).
7312
+ APInt Mod = APInt::getSignedMinValue(BitWidth + 1);
7313
+ APInt MulFactor = Divisor.zext(BitWidth + 1);
7314
+ MulFactor = MulFactor.multiplicativeInverse(Mod);
7315
+ MulFactor = MulFactor.trunc(BitWidth);
7316
+
7317
+ SDValue Quotient = DAG.getNode(ISD::MUL, dl, VT, Dividend,
7318
+ DAG.getConstant(MulFactor, dl, VT));
7319
+
7320
+ // Split the quotient into low and high parts.
7321
+ SDValue QuotL = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, Quotient,
7322
+ DAG.getIntPtrConstant(0, dl));
7323
+ SDValue QuotH = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, Quotient,
7324
+ DAG.getIntPtrConstant(1, dl));
7325
+ Result.push_back(QuotL);
7326
+ Result.push_back(QuotH);
7327
+ }
7328
+
7329
+ if (Opcode != ISD::UDIV) {
7330
+ // If we shifted the input, shift the remainder left and add the bits we
7331
+ // shifted off the input.
7332
+ if (TrailingZeros) {
7333
+ APInt Mask = APInt::getLowBitsSet(HBitWidth, TrailingZeros);
7334
+ RemL = DAG.getNode(ISD::SHL, dl, HiLoVT, RemL,
7335
+ DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl));
7336
+ RemL = DAG.getNode(ISD::ADD, dl, HiLoVT, RemL, PartialRem);
7337
+ }
7300
7338
Result.push_back(RemL);
7301
- Result.push_back(RemH );
7339
+ Result.push_back(DAG.getConstant(0, dl, HiLoVT) );
7302
7340
}
7303
7341
7304
7342
return true;
0 commit comments