@@ -327,22 +327,28 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
 // Broadcasts input tensor based on the broadcastToShape.
 LogicalResult torch_to_linalg::broadcastToGivenShape(
     Operation *op, PatternRewriter &rewriter, Value input,
-    SmallVector<Value> broadcastToShape, Value &result,
-    SmallVector<bool> useBroadcastToShape) {
+    SmallVector<Value> broadcastToShape, RankedTensorType broadcastType,
+    Value &result, SmallVector<bool> useBroadcastToShape) {
   RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+  int64_t inputRank = inputType.getRank();
+  int64_t outputRank = broadcastToShape.size();
+  ArrayRef<int64_t> outputShape = broadcastType.getShape();
   SmallVector<int64_t> inputShape =
       makeShapeTorchCompatible(inputType.getShape());
-  if (broadcastToShape.size() < inputShape.size()) {
+  if (outputRank < inputRank) {
     return rewriter.notifyMatchFailure(
         op, "invalid shape: broadcastToShape size must not be smaller than the "
             "size of the input shape");
   }
 
   Type elementType = inputType.getElementType();
   Location loc = op->getLoc();
-  SmallVector<Value> outShape;
+  SmallVector<OpFoldResult> outShape;
   bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter);
 
+  // Vector indicating broadcasted status when assuming strict symbolic shapes.
+  SmallVector<bool> broadcastedStatus;
+
   // Create affine map and shapes for tensor initialization.
   SmallVector<AffineExpr> outExpr;
   Value zero =
@@ -351,10 +357,39 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
       rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
   Value oneIndex =
       rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-  size_t diff = broadcastToShape.size() - inputShape.size();
-  for (size_t i = 0; i < broadcastToShape.size(); i++) {
+  size_t diff = outputRank - inputRank;
+  bool hasDynamicNumpyBroadcast = false;
+  for (size_t i = 0, e = outputRank; i < e; i++) {
     Value shapeValue = broadcastToShape[i];
     size_t j = i - diff;
+    bool isDynamic = i >= diff && inputShape[j] == kUnknownSize;
+
+    // Inherit static output shapes if present.
+    if (outputShape[i] != ShapedType::kDynamic) {
+      outShape.push_back(rewriter.getIndexAttr(outputShape[i]));
+      if (i < diff) {
+        if (outputShape[i] < 0) {
+          return rewriter.notifyMatchFailure(
+              op, "invalid shape: negative values not allowed in new broadcast "
+                  "dimensions");
+        }
+        continue;
+      }
+      if (isDynamic) {
+        hasDynamicNumpyBroadcast = true;
+      } else if (inputShape[j] != outputShape[i] && inputShape[j] != 1) {
+        return rewriter.notifyMatchFailure(
+            op, "invalid shape: static mismatch in input and output broadcast "
+                "shapes");
+      }
+
+      // If strict symbolic shapes are assumed and the input shape is dynamic,
+      // we can assume that dim is not broadcasted.
+      broadcastedStatus.push_back(inputShape[j] != outputShape[i] &&
+                                  !isDynamic);
+      continue;
+    }
+
     if (i < diff) {
       if (!elideDynamicBroadcastCheck) {
         Value isValid = rewriter.create<arith::CmpIOp>(
@@ -374,24 +409,80 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
       Value select = rewriter.create<arith::SelectOp>(
           loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
       outShape.push_back(select);
-    } else {
-      // Case of dynamic input dimension wherein the shape to broadcast will
-      // yield us the dimension size of the output.
-      Value dim = getDimOp(rewriter, loc, input, j);
-      if (!useBroadcastToShape.empty()) {
-        if (useBroadcastToShape[i])
-          dim = castIntToIndex(rewriter, loc, broadcastToShape[j]);
+      broadcastedStatus.push_back(true);
+      continue;
+    }
+
+    // Case of dynamic input dimension wherein the shape to broadcast will
+    // yield us the dimension size of the output.
+    Value dim;
+    if (!useBroadcastToShape.empty() && useBroadcastToShape[j]) {
+      dim = castIntToIndex(rewriter, loc, broadcastToShape[i]);
+      if (isDynamic) {
+        hasDynamicNumpyBroadcast = true;
       }
-      outShape.push_back(dim);
+      if (!elideDynamicBroadcastCheck) {
+        Value isValid = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, shapeValue, zero);
+        rewriter.create<cf::AssertOp>(
+            loc, isValid,
+            rewriter.getStringAttr(
+                "unimplemented: dynamic negative broadcast sizes"));
+      }
+    } else {
+      dim = getDimOp(rewriter, loc, input, j);
     }
+    // We can safely assume this dimension is not broadcasted with strict
+    // symbols.
+    broadcastedStatus.push_back(false);
+    outShape.push_back(dim);
   }
 
-  Value outTensor = rewriter.create<tensor::EmptyOp>(
-      loc, getAsOpFoldResult(outShape), elementType);
+  Value outTensor =
+      rewriter.create<tensor::EmptyOp>(loc, outShape, elementType);
+
+  // If we know there are no ? -> ? broadcasted dims, or we are assuming
+  // strict symbols, we can safely use standard linalg style broadcasting
+  // semantics.
+  if (!hasDynamicNumpyBroadcast || elideDynamicBroadcastCheck) {
+    // If no dims are broadcasted and the rank doesn't change, we can just fold
+    // the op away entirely.
+    if (!llvm::any_of(broadcastedStatus, [](bool b) { return b; }) &&
+        inputRank == outputRank) {
+      result = rewriter.create<tensor::CastOp>(loc, outTensor.getType(), input);
+      return success();
+    }
+
+    SmallVector<AffineExpr> inputExprs;
+    for (int64_t i = 0, e = inputRank; i < e; ++i) {
+      if (broadcastedStatus[i]) {
+        inputExprs.push_back(rewriter.getAffineConstantExpr(0));
+        continue;
+      }
+      inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
+    }
+
+    SmallVector<AffineMap> indexingMaps = {
+        AffineMap::get(outputRank, 0, inputExprs, rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(outputRank)};
+    SmallVector<utils::IteratorType> iteratorTypes(
+        outputRank, utils::IteratorType::parallel);
+    result = rewriter
+                 .create<linalg::GenericOp>(
+                     loc, outTensor.getType(), input, outTensor, indexingMaps,
+                     iteratorTypes,
+                     [&](OpBuilder &b, Location loc, ValueRange args) {
+                       b.create<linalg::YieldOp>(loc, args[0]);
+                     })
+                 .getResult(0);
+    return success();
+  }
 
+  // Fall back to numpy-style dynamic broadcasting in the form of a single
+  // linalg op.
   SmallVector<AffineMap> indexingMaps = {
-      rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
-  SmallVector<utils::IteratorType> iteratorTypes(broadcastToShape.size(),
+      rewriter.getMultiDimIdentityMap(outputRank)};
+  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                  utils::IteratorType::parallel);
   result = rewriter
                .create<linalg::GenericOp>(
@@ -402,7 +493,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
                      // would be used to extract values from the input tensor
                      // later on.
                      SmallVector<Value> loopIndices;
-                     for (size_t i = 0; i < broadcastToShape.size(); ++i) {
+                     for (size_t i = 0, e = outputRank; i < e; ++i) {
                        if (i < diff)
                          continue;
                        loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
@@ -411,7 +502,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
                      // the i-th input dimension is not 1, else it contains a
                      // zero index.
                      SmallVector<Value> inputIndicesToExtract;
-                     for (size_t i = 0, n = inputShape.size(); i < n; i++) {
+                     for (size_t i = 0, n = inputRank; i < n; i++) {
                        if (inputShape[i] == 1) {
                          inputIndicesToExtract.push_back(zeroIndex);
                        } else {
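
For reference, a minimal caller sketch of the updated signature (an assumption about surrounding code, not taken from this commit): the pattern now passes the converted result type as `broadcastType`, so the helper can take the static/linalg-style path when output dims are known. The names `op`, `input`, `shapeValues`, and `typeConverter` stand in for values the enclosing conversion pattern would already have.

// Hypothetical usage inside a conversion pattern; `op`, `input`, `shapeValues`,
// and `typeConverter` are assumed to be provided by the surrounding pattern.
RankedTensorType resultType =
    typeConverter->convertType(op.getType()).cast<RankedTensorType>();
Value broadcasted;
if (failed(torch_to_linalg::broadcastToGivenShape(
        op, rewriter, input, shapeValues, resultType, broadcasted)))
  return rewriter.notifyMatchFailure(op, "unable to perform broadcast");
// With static sizes in `resultType`, the helper emits a plain linalg.generic
// (or folds to a tensor.cast) instead of the numpy-style dynamic fallback.
rewriter.replaceOp(op, broadcasted);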