Commit f2cade4

darkcoderrises authored and shivaji-kharse committed
perf(vector): Improve hnsw by sharding vectors
1 parent 9d4cb77 commit f2cade4

File tree

17 files changed: +1369 -71 lines changed

.github/workflows/ci-dgraph-vector-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ jobs:
   dgraph-vector-tests:
     if: github.event.pull_request.draft == false
     runs-on: warp-ubuntu-latest-x64-4x
-    timeout-minutes: 30
+    timeout-minutes: 120
     steps:
       - uses: actions/checkout@v5
      - name: Set up Go

.vscode/launch.json

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
         "--security",
         "whitelist=0.0.0.0/0;"
       ],
-      "showLog": true
+      "showLog": false
     },
     {
       "name": "Zero",
@@ -25,7 +25,7 @@
       "program": "${workspaceRoot}/dgraph/",
       "env": {},
       "args": ["zero"],
-      "showLog": true
+      "showLog": false
     },
     {
       "name": "AlphaACL",

posting/index.go

Lines changed: 284 additions & 0 deletions
@@ -33,6 +33,9 @@ import (
     "github.com/hypermodeinc/dgraph/v25/schema"
     "github.com/hypermodeinc/dgraph/v25/tok"
     "github.com/hypermodeinc/dgraph/v25/tok/hnsw"
+    tokIndex "github.com/hypermodeinc/dgraph/v25/tok/index"
+    "github.com/hypermodeinc/dgraph/v25/tok/kmeans"
+
     "github.com/hypermodeinc/dgraph/v25/types"
     "github.com/hypermodeinc/dgraph/v25/x"
 )
@@ -1412,6 +1415,284 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) {
     return prefixes, nil
 }
 
+func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error {
+    pk := x.ParsedKey{Attr: rb.Attr}
+
+    indexer, err := factorySpecs[0].CreateIndex(pk.Attr)
+    if err != nil {
+        return err
+    }
+
+    dimension := indexer.Dimension()
+    // If dimension is -1, it means that the dimension is not set through options in case of partitioned hnsw.
+    if dimension == -1 {
+        numVectorsToCheck := 100
+        lenFreq := make(map[int]int, numVectorsToCheck)
+        maxFreq := 0
+        MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
+            Prefix:      pk.DataPrefix(),
+            ReadTs:      rb.StartTs,
+            AllVersions: false,
+            Reverse:     false,
+            CheckInclusion: func(uid uint64) error {
+                return nil
+            },
+            Function: func(l *List, pk x.ParsedKey) error {
+                val, err := l.Value(rb.StartTs)
+                if err != nil {
+                    return err
+                }
+                inVec := types.BytesAsFloatArray(val.Value.([]byte))
+                lenFreq[len(inVec)] += 1
+                if lenFreq[len(inVec)] > maxFreq {
+                    maxFreq = lenFreq[len(inVec)]
+                    dimension = len(inVec)
+                }
+                numVectorsToCheck -= 1
+                if numVectorsToCheck <= 0 {
+                    return ErrStopIteration
+                }
+                return nil
+            },
+            StartKey: x.DataKey(rb.Attr, 0),
+        })
+
+        indexer.SetDimension(rb.CurrentSchema, dimension)
+    }
+
+    fmt.Println("Selecting vector dimension to be:", dimension)
+
+    norm := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
+    norm.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
+        val, err := pl.Value(rb.StartTs)
+        if err != nil {
+            return nil, err
+        }
+        if val.Tid == types.VFloatID {
+            return nil, nil
+        }
+
+        // Convert to VFloatID and persist as binary bytes.
+        sv, err := types.Convert(val, types.VFloatID)
+        if err != nil {
+            return nil, err
+        }
+        b := types.ValueForType(types.BinaryID)
+        if err = types.Marshal(sv, &b); err != nil {
+            return nil, err
+        }
+
+        edge := &pb.DirectedEdge{
+            Attr:      rb.Attr,
+            Entity:    uid,
+            Value:     b.Value.([]byte),
+            ValueType: types.VFloatID.Enum(),
+        }
+        inKey := x.DataKey(edge.Attr, uid)
+        p, err := txn.Get(inKey)
+        if err != nil {
+            return []*pb.DirectedEdge{}, err
+        }
+
+        if err := p.addMutation(ctx, txn, edge); err != nil {
+            return []*pb.DirectedEdge{}, err
+        }
+        return nil, nil
+    }
+
+    if err := norm.RunWithoutTemp(ctx); err != nil {
+        return err
+    }
+
+    count := 0
+
+    if indexer.NumSeedVectors() > 0 {
+        err := MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
+            Prefix:      pk.DataPrefix(),
+            ReadTs:      rb.StartTs,
+            AllVersions: false,
+            Reverse:     false,
+            CheckInclusion: func(uid uint64) error {
+                return nil
+            },
+            Function: func(l *List, pk x.ParsedKey) error {
+                val, err := l.Value(rb.StartTs)
+                if err != nil {
+                    return err
+                }
+
+                if val.Tid != types.VFloatID {
+                    // Here, we convert the defaultID type vector into vfloat.
+                    sv, err := types.Convert(val, types.VFloatID)
+                    if err != nil {
+                        return err
+                    }
+                    b := types.ValueForType(types.BinaryID)
+                    if err = types.Marshal(sv, &b); err != nil {
+                        return err
+                    }
+
+                    val.Value = b.Value
+                    val.Tid = types.VFloatID
+                }
+
+                inVec := types.BytesAsFloatArray(val.Value.([]byte))
+                if len(inVec) != dimension {
+                    return fmt.Errorf("vector dimension mismatch expected dimension %d but got %d", dimension, len(inVec))
+                }
+                count += 1
+                indexer.AddSeedVector(inVec)
+                if count == indexer.NumSeedVectors() {
+                    return ErrStopIteration
+                }
+                return nil
+            },
+            StartKey: x.DataKey(rb.Attr, 0),
+        })
+        if err != nil {
+            return err
+        }
+    }
+
+    txns := make([]*Txn, indexer.NumThreads())
+    for i := range txns {
+        txns[i] = NewTxn(rb.StartTs)
+    }
+    caches := make([]tokIndex.CacheType, indexer.NumThreads())
+    for i := range caches {
+        caches[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
+    }
+
+    if count < indexer.NumSeedVectors() {
+        indexer.SetNumPasses(0)
+    }
+
+    for pass_idx := range indexer.NumBuildPasses() {
+        fmt.Println("Building pass", pass_idx)
+
+        indexer.StartBuild(caches)
+
+        builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
+        builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
+            val, err := pl.Value(rb.StartTs)
+            if err != nil {
+                return []*pb.DirectedEdge{}, err
+            }
+
+            inVec := types.BytesAsFloatArray(val.Value.([]byte))
+            if len(inVec) != dimension {
+                return []*pb.DirectedEdge{}, nil
+            }
+            indexer.BuildInsert(ctx, uid, inVec)
+            return []*pb.DirectedEdge{}, nil
+        }
+
+        err := builder.RunWithoutTemp(ctx)
+        if err != nil {
+            return err
+        }
+
+        indexer.EndBuild()
+    }
+
+    centroids := indexer.GetCentroids()
+
+    if centroids != nil {
+        txn := NewTxn(rb.StartTs)
+
+        bCentroids, err := json.Marshal(centroids)
+        if err != nil {
+            return err
+        }
+
+        if err := addCentroidInDB(ctx, rb.Attr, bCentroids, txn); err != nil {
+            return err
+        }
+        txn.Update()
+        writer := NewTxnWriter(pstore)
+        if err := txn.CommitToDisk(writer, rb.StartTs); err != nil {
+            return err
+        }
+    }
+
+    numIndexPasses := indexer.NumIndexPasses()
+
+    if count < indexer.NumSeedVectors() {
+        numIndexPasses = 1
+    }
+
+    for pass_idx := range numIndexPasses {
+        fmt.Println("Indexing pass", pass_idx)
+
+        indexer.StartBuild(caches)
+
+        builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
+        builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
+            val, err := pl.Value(rb.StartTs)
+            if err != nil {
+                return []*pb.DirectedEdge{}, err
+            }
+
+            inVec := types.BytesAsFloatArray(val.Value.([]byte))
+            if len(inVec) != dimension && centroids != nil {
+                if pass_idx == 0 {
+                    glog.Warningf("Skipping vector with invalid dimension uid: %d, dimension: %d", uid, len(inVec))
+                }
+                return []*pb.DirectedEdge{}, nil
+            }
+
+            indexer.BuildInsert(ctx, uid, inVec)
+
+            return []*pb.DirectedEdge{}, nil
+        }
+
+        err := builder.RunWithoutTemp(ctx)
+        if err != nil {
+            return err
+        }
+
+        for _, idx := range indexer.EndBuild() {
+            txns[idx].Update()
+            writer := NewTxnWriter(pstore)
+
+            x.ExponentialRetry(int(x.Config.MaxRetries),
+                20*time.Millisecond, func() error {
+                    err := txns[idx].CommitToDisk(writer, rb.StartTs)
+                    if err == badger.ErrBannedKey {
+                        glog.Errorf("Error while writing to banned namespace.")
+                        return nil
+                    }
+                    return err
+                })
+
+            txns[idx].cache.plists = nil
+            txns[idx] = nil
+        }
+    }
+
+    return nil
+}
+
+func addCentroidInDB(ctx context.Context, attr string, vec []byte, txn *Txn) error {
+    indexCountAttr := hnsw.ConcatStrings(attr, kmeans.CentroidPrefix)
+    countKey := x.DataKey(indexCountAttr, 1)
+    pl, err := txn.Get(countKey)
+    if err != nil {
+        return err
+    }
+
+    edge := &pb.DirectedEdge{
+        Entity:    1,
+        Attr:      indexCountAttr,
+        Value:     vec,
+        ValueType: pb.Posting_ValType(12),
+    }
+    if err := pl.addMutation(ctx, txn, edge); err != nil {
+        return err
+    }
+    return nil
+}
+
 // rebuildTokIndex rebuilds index for a given attribute.
 // We commit mutations with startTs and ignore the errors.
 func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
@@ -1443,6 +1724,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
     }
 
     runForVectors := (len(factorySpecs) != 0)
+    if runForVectors {
+        return rebuildVectorIndex(ctx, factorySpecs, rb)
+    }
 
     pk := x.ParsedKey{Attr: rb.Attr}
     builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}