Skip to content

Commit a53c640

Browse files
fix BlockMatrixAgg lowering
1 parent e2a420c commit a53c640

File tree

3 files changed

+116
-101
lines changed

3 files changed

+116
-101
lines changed

hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -746,13 +746,13 @@ case class BlockMatrixAgg(
746746
axesToSumOut match {
747747
case IndexedSeq(0, 1) => BlockMatrixSparsity.dense
748748
case IndexedSeq(0) => // col vector result; agg over row
749-
BlockMatrixSparsity.constructFromShapeAndFunction(child.typ.nRowBlocks, 1) { (i, _) =>
750-
(0 until child.typ.nColBlocks).exists(j => child.typ.hasBlock(i -> j))
751-
}
752-
case IndexedSeq(1) => // row vector result; agg over col
753749
BlockMatrixSparsity.constructFromShapeAndFunction(1, child.typ.nColBlocks) { (_, j) =>
754750
(0 until child.typ.nRowBlocks).exists(i => child.typ.hasBlock(i -> j))
755751
}
752+
case IndexedSeq(1) => // row vector result; agg over col
753+
BlockMatrixSparsity.constructFromShapeAndFunction(child.typ.nRowBlocks, 1) { (i, _) =>
754+
(0 until child.typ.nColBlocks).exists(j => child.typ.hasBlock(i -> j))
755+
}
756756
}
757757
} else BlockMatrixSparsity.dense
758758

hail/src/main/scala/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala

Lines changed: 100 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,107 +1293,110 @@ object LowerBlockMatrixIR {
12931293

12941294
bmir match {
12951295

1296-
case BlockMatrixAgg(child, axesToSumOut) =>
1297-
val loweredChild = lower(child)
1296+
case BlockMatrixAgg(child@_, axesToSumOut) =>
1297+
// val loweredChild = lower(child)
12981298
axesToSumOut match {
12991299
case IndexedSeq(0, 1) =>
1300-
val summedChild = loweredChild.mapBody { (ctx, body) =>
1301-
NDArrayReshape(
1302-
NDArrayAgg(body, IndexedSeq(0, 1)),
1303-
MakeTuple.ordered(FastSeq(I64(1), I64(1))),
1304-
ErrorIDs.NO_ERROR,
1305-
)
1306-
}
1307-
val summedChildType = BlockMatrixType(
1308-
child.typ.elementType,
1309-
IndexedSeq[Long](child.typ.nRowBlocks, child.typ.nColBlocks),
1310-
child.typ.nRowBlocks == 1,
1311-
1,
1312-
BlockMatrixSparsity.dense,
1313-
)
1314-
val res = NDArrayAgg(
1315-
summedChild.collectLocal(summedChildType, "block_matrix_agg"),
1316-
IndexedSeq[Int](0, 1),
1317-
)
1318-
new BlockMatrixStage(summedChild.broadcastVals, TStruct.empty) {
1319-
override def blockContext(idx: (Int, Int)): IR = makestruct()
1320-
override def blockBody(ctxRef: Ref): IR =
1321-
NDArrayReshape(res, MakeTuple.ordered(FastSeq(I64(1L), I64(1L))), ErrorIDs.NO_ERROR)
1322-
}
1300+
throw new NotImplementedError("BlockMatrixAgg (0, 1)")
1301+
// val summedChild = loweredChild.mapBody { (ctx, body) =>
1302+
// NDArrayReshape(
1303+
// NDArrayAgg(body, IndexedSeq(0, 1)),
1304+
// MakeTuple.ordered(FastSeq(I64(1), I64(1))),
1305+
// ErrorIDs.NO_ERROR,
1306+
// )
1307+
// }
1308+
// val summedChildType = BlockMatrixType(
1309+
// child.typ.elementType,
1310+
// IndexedSeq[Long](child.typ.nRowBlocks, child.typ.nColBlocks),
1311+
// child.typ.nRowBlocks == 1,
1312+
// 1,
1313+
// BlockMatrixSparsity.dense,
1314+
// )
1315+
// val res = NDArrayAgg(
1316+
// summedChild.collectLocal(summedChildType, "block_matrix_agg"),
1317+
// IndexedSeq[Int](0, 1),
1318+
// )
1319+
// new BlockMatrixStage(summedChild.broadcastVals, TStruct.empty) {
1320+
// override def blockContext(idx: (Int, Int)): IR = makestruct()
1321+
// override def blockBody(ctxRef: Ref): IR =
1322+
// NDArrayReshape(res, MakeTuple.ordered(FastSeq(I64(1L), I64(1L))), ErrorIDs.NO_ERROR)
1323+
// }
13231324
case IndexedSeq(0) => // Number of rows goes to 1. Number of cols remains the same.
1324-
new BlockMatrixStage(loweredChild.broadcastVals, TArray(loweredChild.ctxType)) {
1325-
override def blockContext(idx: (Int, Int)): IR = {
1326-
val (row, col) = idx
1327-
assert(row == 0, s"Asked for idx $idx")
1328-
MakeArray(
1329-
(0 until child.typ.nRowBlocks).map(childRow =>
1330-
loweredChild.blockContext((childRow, col))
1331-
),
1332-
TArray(loweredChild.ctxType),
1333-
)
1334-
}
1335-
override def blockBody(ctxRef: Ref): IR = {
1336-
val summedChildBlocks = mapIR(ToStream(ctxRef)) { singleChildCtx =>
1337-
bindIR(NDArrayAgg(loweredChild.blockBody(singleChildCtx), axesToSumOut))(
1338-
aggedND =>
1339-
NDArrayReshape(
1340-
aggedND,
1341-
MakeTuple.ordered(FastSeq(
1342-
I64(1),
1343-
GetTupleElement(NDArrayShape(aggedND), 0),
1344-
)),
1345-
ErrorIDs.NO_ERROR,
1346-
)
1347-
)
1348-
}
1349-
val aggVar = freshName()
1350-
StreamAgg(
1351-
summedChildBlocks,
1352-
aggVar,
1353-
ApplyAggOp(NDArraySum())(Ref(
1354-
aggVar,
1355-
summedChildBlocks.typ.asInstanceOf[TStream].elementType,
1356-
)),
1357-
)
1358-
}
1359-
}
1325+
throw new NotImplementedError("BlockMatrixAgg (0)")
1326+
// new BlockMatrixStage(loweredChild.broadcastVals, TArray(loweredChild.ctxType)) {
1327+
// override def blockContext(idx: (Int, Int)): IR = {
1328+
// val (row, col) = idx
1329+
// assert(row == 0, s"Asked for idx $idx")
1330+
// MakeArray(
1331+
// (0 until child.typ.nRowBlocks).map(childRow =>
1332+
// loweredChild.blockContext((childRow, col))
1333+
// ),
1334+
// TArray(loweredChild.ctxType),
1335+
// )
1336+
// }
1337+
// override def blockBody(ctxRef: Ref): IR = {
1338+
// val summedChildBlocks = mapIR(ToStream(ctxRef)) { singleChildCtx =>
1339+
// bindIR(NDArrayAgg(loweredChild.blockBody(singleChildCtx), axesToSumOut))(
1340+
// aggedND =>
1341+
// NDArrayReshape(
1342+
// aggedND,
1343+
// MakeTuple.ordered(FastSeq(
1344+
// I64(1),
1345+
// GetTupleElement(NDArrayShape(aggedND), 0),
1346+
// )),
1347+
// ErrorIDs.NO_ERROR,
1348+
// )
1349+
// )
1350+
// }
1351+
// val aggVar = freshName()
1352+
// StreamAgg(
1353+
// summedChildBlocks,
1354+
// aggVar,
1355+
// ApplyAggOp(NDArraySum())(Ref(
1356+
// aggVar,
1357+
// summedChildBlocks.typ.asInstanceOf[TStream].elementType,
1358+
// )),
1359+
// )
1360+
// }
1361+
// }
13601362
case IndexedSeq(1) => // Number of cols goes to 1. Number of rows remains the same.
1361-
new BlockMatrixStage(loweredChild.broadcastVals, TArray(loweredChild.ctxType)) {
1362-
override def blockContext(idx: (Int, Int)): IR = {
1363-
val (row, col) = idx
1364-
assert(col == 0, s"Asked for idx $idx")
1365-
MakeArray(
1366-
(0 until child.typ.nColBlocks).map(childCol =>
1367-
loweredChild.blockContext((row, childCol))
1368-
),
1369-
TArray(loweredChild.ctxType),
1370-
)
1371-
}
1372-
override def blockBody(ctxRef: Ref): IR = {
1373-
val summedChildBlocks = mapIR(ToStream(ctxRef)) { singleChildCtx =>
1374-
bindIR(NDArrayAgg(loweredChild.blockBody(singleChildCtx), axesToSumOut)) {
1375-
aggedND =>
1376-
NDArrayReshape(
1377-
aggedND,
1378-
MakeTuple(FastSeq(
1379-
0 -> GetTupleElement(NDArrayShape(aggedND), 0),
1380-
1 -> I64(1),
1381-
)),
1382-
ErrorIDs.NO_ERROR,
1383-
)
1384-
}
1385-
}
1386-
val aggVar = freshName()
1387-
StreamAgg(
1388-
summedChildBlocks,
1389-
aggVar,
1390-
ApplyAggOp(NDArraySum())(Ref(
1391-
aggVar,
1392-
summedChildBlocks.typ.asInstanceOf[TStream].elementType,
1393-
)),
1394-
)
1395-
}
1396-
}
1363+
throw new NotImplementedError("BlockMatrixAgg (0)")
1364+
// new BlockMatrixStage(loweredChild.broadcastVals, TArray(loweredChild.ctxType)) {
1365+
// override def blockContext(idx: (Int, Int)): IR = {
1366+
// val (row, col) = idx
1367+
// assert(col == 0, s"Asked for idx $idx")
1368+
// MakeArray(
1369+
// (0 until child.typ.nColBlocks).map(childCol =>
1370+
// loweredChild.blockContext((row, childCol))
1371+
// ),
1372+
// TArray(loweredChild.ctxType),
1373+
// )
1374+
// }
1375+
// override def blockBody(ctxRef: Ref): IR = {
1376+
// val summedChildBlocks = mapIR(ToStream(ctxRef)) { singleChildCtx =>
1377+
// bindIR(NDArrayAgg(loweredChild.blockBody(singleChildCtx), axesToSumOut)) {
1378+
// aggedND =>
1379+
// NDArrayReshape(
1380+
// aggedND,
1381+
// MakeTuple(FastSeq(
1382+
// 0 -> GetTupleElement(NDArrayShape(aggedND), 0),
1383+
// 1 -> I64(1),
1384+
// )),
1385+
// ErrorIDs.NO_ERROR,
1386+
// )
1387+
// }
1388+
// }
1389+
// val aggVar = freshName()
1390+
// StreamAgg(
1391+
// summedChildBlocks,
1392+
// aggVar,
1393+
// ApplyAggOp(NDArraySum())(Ref(
1394+
// aggVar,
1395+
// summedChildBlocks.typ.asInstanceOf[TStream].elementType,
1396+
// )),
1397+
// )
1398+
// }
1399+
// }
13971400
}
13981401

13991402
case x @ BlockMatrixSlice(

hail/src/test/scala/is/hail/expr/ir/BlockMatrixIRSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,18 @@ class BlockMatrixIRSuite extends HailSuite {
249249
)
250250
}
251251

252+
@Test def foo(): Unit = {
253+
var bm: BlockMatrixIR = BlockMatrixRead(BlockMatrixNativeReader(fs, "tmp.bm"))
254+
bm = BlockMatrixAgg(bm, FastSeq(0))
255+
assertEvalsTo(
256+
BlockMatrixWrite(
257+
bm,
258+
BlockMatrixNativeWriter("tmp2.bm", true, true, false),
259+
),
260+
(),
261+
)
262+
}
263+
252264
@Test def readWriteBlockMatrix(): Unit = {
253265
val original = "src/test/resources/blockmatrix_example/0"
254266
val expected = BlockMatrix.read(ctx.fs, original).toBreezeMatrix()

0 commit comments

Comments
 (0)