@@ -1293,107 +1293,110 @@ object LowerBlockMatrixIR {
12931293
12941294 bmir match {
12951295
1296- case BlockMatrixAgg (child, axesToSumOut) =>
1297- val loweredChild = lower(child)
1296+ case BlockMatrixAgg (child@ _ , axesToSumOut) =>
1297+ // val loweredChild = lower(child)
12981298 axesToSumOut match {
12991299 case IndexedSeq (0 , 1 ) =>
1300- val summedChild = loweredChild.mapBody { (ctx, body) =>
1301- NDArrayReshape (
1302- NDArrayAgg (body, IndexedSeq (0 , 1 )),
1303- MakeTuple .ordered(FastSeq (I64 (1 ), I64 (1 ))),
1304- ErrorIDs .NO_ERROR ,
1305- )
1306- }
1307- val summedChildType = BlockMatrixType (
1308- child.typ.elementType,
1309- IndexedSeq [Long ](child.typ.nRowBlocks, child.typ.nColBlocks),
1310- child.typ.nRowBlocks == 1 ,
1311- 1 ,
1312- BlockMatrixSparsity .dense,
1313- )
1314- val res = NDArrayAgg (
1315- summedChild.collectLocal(summedChildType, " block_matrix_agg" ),
1316- IndexedSeq [Int ](0 , 1 ),
1317- )
1318- new BlockMatrixStage (summedChild.broadcastVals, TStruct .empty) {
1319- override def blockContext (idx : (Int , Int )): IR = makestruct()
1320- override def blockBody (ctxRef : Ref ): IR =
1321- NDArrayReshape (res, MakeTuple .ordered(FastSeq (I64 (1L ), I64 (1L ))), ErrorIDs .NO_ERROR )
1322- }
1300+ throw new NotImplementedError (" BlockMatrixAgg (0, 1)" )
1301+ // val summedChild = loweredChild.mapBody { (ctx, body) =>
1302+ // NDArrayReshape(
1303+ // NDArrayAgg(body, IndexedSeq(0, 1)),
1304+ // MakeTuple.ordered(FastSeq(I64(1), I64(1))),
1305+ // ErrorIDs.NO_ERROR,
1306+ // )
1307+ // }
1308+ // val summedChildType = BlockMatrixType(
1309+ // child.typ.elementType,
1310+ // IndexedSeq[Long](child.typ.nRowBlocks, child.typ.nColBlocks),
1311+ // child.typ.nRowBlocks == 1,
1312+ // 1,
1313+ // BlockMatrixSparsity.dense,
1314+ // )
1315+ // val res = NDArrayAgg(
1316+ // summedChild.collectLocal(summedChildType, "block_matrix_agg"),
1317+ // IndexedSeq[Int](0, 1),
1318+ // )
1319+ // new BlockMatrixStage(summedChild.broadcastVals, TStruct.empty) {
1320+ // override def blockContext(idx: (Int, Int)): IR = makestruct()
1321+ // override def blockBody(ctxRef: Ref): IR =
1322+ // NDArrayReshape(res, MakeTuple.ordered(FastSeq(I64(1L), I64(1L))), ErrorIDs.NO_ERROR)
1323+ // }
13231324 case IndexedSeq (0 ) => // Number of rows goes to 1. Number of cols remains the same.
1324- new BlockMatrixStage (loweredChild.broadcastVals, TArray (loweredChild.ctxType)) {
1325- override def blockContext (idx : (Int , Int )): IR = {
1326- val (row, col) = idx
1327- assert(row == 0 , s " Asked for idx $idx" )
1328- MakeArray (
1329- (0 until child.typ.nRowBlocks).map(childRow =>
1330- loweredChild.blockContext((childRow, col))
1331- ),
1332- TArray (loweredChild.ctxType),
1333- )
1334- }
1335- override def blockBody (ctxRef : Ref ): IR = {
1336- val summedChildBlocks = mapIR(ToStream (ctxRef)) { singleChildCtx =>
1337- bindIR(NDArrayAgg (loweredChild.blockBody(singleChildCtx), axesToSumOut))(
1338- aggedND =>
1339- NDArrayReshape (
1340- aggedND,
1341- MakeTuple .ordered(FastSeq (
1342- I64 (1 ),
1343- GetTupleElement (NDArrayShape (aggedND), 0 ),
1344- )),
1345- ErrorIDs .NO_ERROR ,
1346- )
1347- )
1348- }
1349- val aggVar = freshName()
1350- StreamAgg (
1351- summedChildBlocks,
1352- aggVar,
1353- ApplyAggOp (NDArraySum ())(Ref (
1354- aggVar,
1355- summedChildBlocks.typ.asInstanceOf [TStream ].elementType,
1356- )),
1357- )
1358- }
1359- }
1325+ throw new NotImplementedError (" BlockMatrixAgg (0)" )
1326+ // new BlockMatrixStage(loweredChild.broadcastVals, TArray(loweredChild.ctxType)) {
1327+ // override def blockContext(idx: (Int, Int)): IR = {
1328+ // val (row, col) = idx
1329+ // assert(row == 0, s"Asked for idx $idx")
1330+ // MakeArray(
1331+ // (0 until child.typ.nRowBlocks).map(childRow =>
1332+ // loweredChild.blockContext((childRow, col))
1333+ // ),
1334+ // TArray(loweredChild.ctxType),
1335+ // )
1336+ // }
1337+ // override def blockBody(ctxRef: Ref): IR = {
1338+ // val summedChildBlocks = mapIR(ToStream(ctxRef)) { singleChildCtx =>
1339+ // bindIR(NDArrayAgg(loweredChild.blockBody(singleChildCtx), axesToSumOut))(
1340+ // aggedND =>
1341+ // NDArrayReshape(
1342+ // aggedND,
1343+ // MakeTuple.ordered(FastSeq(
1344+ // I64(1),
1345+ // GetTupleElement(NDArrayShape(aggedND), 0),
1346+ // )),
1347+ // ErrorIDs.NO_ERROR,
1348+ // )
1349+ // )
1350+ // }
1351+ // val aggVar = freshName()
1352+ // StreamAgg(
1353+ // summedChildBlocks,
1354+ // aggVar,
1355+ // ApplyAggOp(NDArraySum())(Ref(
1356+ // aggVar,
1357+ // summedChildBlocks.typ.asInstanceOf[TStream].elementType,
1358+ // )),
1359+ // )
1360+ // }
1361+ // }
13601362 case IndexedSeq (1 ) => // Number of cols goes to 1. Number of rows remains the same.
1361- new BlockMatrixStage (loweredChild.broadcastVals, TArray (loweredChild.ctxType)) {
1362- override def blockContext (idx : (Int , Int )): IR = {
1363- val (row, col) = idx
1364- assert(col == 0 , s " Asked for idx $idx" )
1365- MakeArray (
1366- (0 until child.typ.nColBlocks).map(childCol =>
1367- loweredChild.blockContext((row, childCol))
1368- ),
1369- TArray (loweredChild.ctxType),
1370- )
1371- }
1372- override def blockBody (ctxRef : Ref ): IR = {
1373- val summedChildBlocks = mapIR(ToStream (ctxRef)) { singleChildCtx =>
1374- bindIR(NDArrayAgg (loweredChild.blockBody(singleChildCtx), axesToSumOut)) {
1375- aggedND =>
1376- NDArrayReshape (
1377- aggedND,
1378- MakeTuple (FastSeq (
1379- 0 -> GetTupleElement (NDArrayShape (aggedND), 0 ),
1380- 1 -> I64 (1 ),
1381- )),
1382- ErrorIDs .NO_ERROR ,
1383- )
1384- }
1385- }
1386- val aggVar = freshName()
1387- StreamAgg (
1388- summedChildBlocks,
1389- aggVar,
1390- ApplyAggOp (NDArraySum ())(Ref (
1391- aggVar,
1392- summedChildBlocks.typ.asInstanceOf [TStream ].elementType,
1393- )),
1394- )
1395- }
1396- }
1363+ throw new NotImplementedError (" BlockMatrixAgg (0)" )
1364+ // new BlockMatrixStage(loweredChild.broadcastVals, TArray(loweredChild.ctxType)) {
1365+ // override def blockContext(idx: (Int, Int)): IR = {
1366+ // val (row, col) = idx
1367+ // assert(col == 0, s"Asked for idx $idx")
1368+ // MakeArray(
1369+ // (0 until child.typ.nColBlocks).map(childCol =>
1370+ // loweredChild.blockContext((row, childCol))
1371+ // ),
1372+ // TArray(loweredChild.ctxType),
1373+ // )
1374+ // }
1375+ // override def blockBody(ctxRef: Ref): IR = {
1376+ // val summedChildBlocks = mapIR(ToStream(ctxRef)) { singleChildCtx =>
1377+ // bindIR(NDArrayAgg(loweredChild.blockBody(singleChildCtx), axesToSumOut)) {
1378+ // aggedND =>
1379+ // NDArrayReshape(
1380+ // aggedND,
1381+ // MakeTuple(FastSeq(
1382+ // 0 -> GetTupleElement(NDArrayShape(aggedND), 0),
1383+ // 1 -> I64(1),
1384+ // )),
1385+ // ErrorIDs.NO_ERROR,
1386+ // )
1387+ // }
1388+ // }
1389+ // val aggVar = freshName()
1390+ // StreamAgg(
1391+ // summedChildBlocks,
1392+ // aggVar,
1393+ // ApplyAggOp(NDArraySum())(Ref(
1394+ // aggVar,
1395+ // summedChildBlocks.typ.asInstanceOf[TStream].elementType,
1396+ // )),
1397+ // )
1398+ // }
1399+ // }
13971400 }
13981401
13991402 case x @ BlockMatrixSlice (
0 commit comments