@@ -33,6 +33,9 @@ import (
3333 "github.com/hypermodeinc/dgraph/v25/schema"
3434 "github.com/hypermodeinc/dgraph/v25/tok"
3535 "github.com/hypermodeinc/dgraph/v25/tok/hnsw"
36+ tokIndex "github.com/hypermodeinc/dgraph/v25/tok/index"
37+ "github.com/hypermodeinc/dgraph/v25/tok/kmeans"
38+
3639 "github.com/hypermodeinc/dgraph/v25/types"
3740 "github.com/hypermodeinc/dgraph/v25/x"
3841)
@@ -1412,6 +1415,284 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) {
14121415 return prefixes , nil
14131416}
14141417
1418+ func rebuildVectorIndex (ctx context.Context , factorySpecs []* tok.FactoryCreateSpec , rb * IndexRebuild ) error {
1419+ pk := x.ParsedKey {Attr : rb .Attr }
1420+
1421+ indexer , err := factorySpecs [0 ].CreateIndex (pk .Attr )
1422+ if err != nil {
1423+ return err
1424+ }
1425+
1426+ dimension := indexer .Dimension ()
1427+ // If dimension is -1, it means that the dimension is not set through options in case of partitioned hnsw.
1428+ if dimension == - 1 {
1429+ numVectorsToCheck := 100
1430+ lenFreq := make (map [int ]int , numVectorsToCheck )
1431+ maxFreq := 0
1432+ MemLayerInstance .IterateDisk (ctx , IterateDiskArgs {
1433+ Prefix : pk .DataPrefix (),
1434+ ReadTs : rb .StartTs ,
1435+ AllVersions : false ,
1436+ Reverse : false ,
1437+ CheckInclusion : func (uid uint64 ) error {
1438+ return nil
1439+ },
1440+ Function : func (l * List , pk x.ParsedKey ) error {
1441+ val , err := l .Value (rb .StartTs )
1442+ if err != nil {
1443+ return err
1444+ }
1445+ inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1446+ lenFreq [len (inVec )] += 1
1447+ if lenFreq [len (inVec )] > maxFreq {
1448+ maxFreq = lenFreq [len (inVec )]
1449+ dimension = len (inVec )
1450+ }
1451+ numVectorsToCheck -= 1
1452+ if numVectorsToCheck <= 0 {
1453+ return ErrStopIteration
1454+ }
1455+ return nil
1456+ },
1457+ StartKey : x .DataKey (rb .Attr , 0 ),
1458+ })
1459+
1460+ indexer .SetDimension (rb .CurrentSchema , dimension )
1461+ }
1462+
1463+ fmt .Println ("Selecting vector dimension to be:" , dimension )
1464+
1465+ norm := rebuilder {attr : rb .Attr , prefix : pk .DataPrefix (), startTs : rb .StartTs }
1466+ norm .fn = func (uid uint64 , pl * List , txn * Txn ) ([]* pb.DirectedEdge , error ) {
1467+ val , err := pl .Value (rb .StartTs )
1468+ if err != nil {
1469+ return nil , err
1470+ }
1471+ if val .Tid == types .VFloatID {
1472+ return nil , nil
1473+ }
1474+
1475+ // Convert to VFloatID and persist as binary bytes.
1476+ sv , err := types .Convert (val , types .VFloatID )
1477+ if err != nil {
1478+ return nil , err
1479+ }
1480+ b := types .ValueForType (types .BinaryID )
1481+ if err = types .Marshal (sv , & b ); err != nil {
1482+ return nil , err
1483+ }
1484+
1485+ edge := & pb.DirectedEdge {
1486+ Attr : rb .Attr ,
1487+ Entity : uid ,
1488+ Value : b .Value .([]byte ),
1489+ ValueType : types .VFloatID .Enum (),
1490+ }
1491+ inKey := x .DataKey (edge .Attr , uid )
1492+ p , err := txn .Get (inKey )
1493+ if err != nil {
1494+ return []* pb.DirectedEdge {}, err
1495+ }
1496+
1497+ if err := p .addMutation (ctx , txn , edge ); err != nil {
1498+ return []* pb.DirectedEdge {}, err
1499+ }
1500+ return nil , nil
1501+ }
1502+
1503+ if err := norm .RunWithoutTemp (ctx ); err != nil {
1504+ return err
1505+ }
1506+
1507+ count := 0
1508+
1509+ if indexer .NumSeedVectors () > 0 {
1510+ err := MemLayerInstance .IterateDisk (ctx , IterateDiskArgs {
1511+ Prefix : pk .DataPrefix (),
1512+ ReadTs : rb .StartTs ,
1513+ AllVersions : false ,
1514+ Reverse : false ,
1515+ CheckInclusion : func (uid uint64 ) error {
1516+ return nil
1517+ },
1518+ Function : func (l * List , pk x.ParsedKey ) error {
1519+ val , err := l .Value (rb .StartTs )
1520+ if err != nil {
1521+ return err
1522+ }
1523+
1524+ if val .Tid != types .VFloatID {
1525+ // Here, we convert the defaultID type vector into vfloat.
1526+ sv , err := types .Convert (val , types .VFloatID )
1527+ if err != nil {
1528+ return err
1529+ }
1530+ b := types .ValueForType (types .BinaryID )
1531+ if err = types .Marshal (sv , & b ); err != nil {
1532+ return err
1533+ }
1534+
1535+ val .Value = b .Value
1536+ val .Tid = types .VFloatID
1537+ }
1538+
1539+ inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1540+ if len (inVec ) != dimension {
1541+ return fmt .Errorf ("vector dimension mismatch expected dimension %d but got %d" , dimension , len (inVec ))
1542+ }
1543+ count += 1
1544+ indexer .AddSeedVector (inVec )
1545+ if count == indexer .NumSeedVectors () {
1546+ return ErrStopIteration
1547+ }
1548+ return nil
1549+ },
1550+ StartKey : x .DataKey (rb .Attr , 0 ),
1551+ })
1552+ if err != nil {
1553+ return err
1554+ }
1555+ }
1556+
1557+ txns := make ([]* Txn , indexer .NumThreads ())
1558+ for i := range txns {
1559+ txns [i ] = NewTxn (rb .StartTs )
1560+ }
1561+ caches := make ([]tokIndex.CacheType , indexer .NumThreads ())
1562+ for i := range caches {
1563+ caches [i ] = hnsw .NewTxnCache (NewViTxn (txns [i ]), rb .StartTs )
1564+ }
1565+
1566+ if count < indexer .NumSeedVectors () {
1567+ indexer .SetNumPasses (0 )
1568+ }
1569+
1570+ for pass_idx := range indexer .NumBuildPasses () {
1571+ fmt .Println ("Building pass" , pass_idx )
1572+
1573+ indexer .StartBuild (caches )
1574+
1575+ builder := rebuilder {attr : rb .Attr , prefix : pk .DataPrefix (), startTs : rb .StartTs }
1576+ builder .fn = func (uid uint64 , pl * List , txn * Txn ) ([]* pb.DirectedEdge , error ) {
1577+ val , err := pl .Value (rb .StartTs )
1578+ if err != nil {
1579+ return []* pb.DirectedEdge {}, err
1580+ }
1581+
1582+ inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1583+ if len (inVec ) != dimension {
1584+ return []* pb.DirectedEdge {}, nil
1585+ }
1586+ indexer .BuildInsert (ctx , uid , inVec )
1587+ return []* pb.DirectedEdge {}, nil
1588+ }
1589+
1590+ err := builder .RunWithoutTemp (ctx )
1591+ if err != nil {
1592+ return err
1593+ }
1594+
1595+ indexer .EndBuild ()
1596+ }
1597+
1598+ centroids := indexer .GetCentroids ()
1599+
1600+ if centroids != nil {
1601+ txn := NewTxn (rb .StartTs )
1602+
1603+ bCentroids , err := json .Marshal (centroids )
1604+ if err != nil {
1605+ return err
1606+ }
1607+
1608+ if err := addCentroidInDB (ctx , rb .Attr , bCentroids , txn ); err != nil {
1609+ return err
1610+ }
1611+ txn .Update ()
1612+ writer := NewTxnWriter (pstore )
1613+ if err := txn .CommitToDisk (writer , rb .StartTs ); err != nil {
1614+ return err
1615+ }
1616+ }
1617+
1618+ numIndexPasses := indexer .NumIndexPasses ()
1619+
1620+ if count < indexer .NumSeedVectors () {
1621+ numIndexPasses = 1
1622+ }
1623+
1624+ for pass_idx := range numIndexPasses {
1625+ fmt .Println ("Indexing pass" , pass_idx )
1626+
1627+ indexer .StartBuild (caches )
1628+
1629+ builder := rebuilder {attr : rb .Attr , prefix : pk .DataPrefix (), startTs : rb .StartTs }
1630+ builder .fn = func (uid uint64 , pl * List , txn * Txn ) ([]* pb.DirectedEdge , error ) {
1631+ val , err := pl .Value (rb .StartTs )
1632+ if err != nil {
1633+ return []* pb.DirectedEdge {}, err
1634+ }
1635+
1636+ inVec := types .BytesAsFloatArray (val .Value .([]byte ))
1637+ if len (inVec ) != dimension && centroids != nil {
1638+ if pass_idx == 0 {
1639+ glog .Warningf ("Skipping vector with invalid dimension uid: %d, dimension: %d" , uid , len (inVec ))
1640+ }
1641+ return []* pb.DirectedEdge {}, nil
1642+ }
1643+
1644+ indexer .BuildInsert (ctx , uid , inVec )
1645+
1646+ return []* pb.DirectedEdge {}, nil
1647+ }
1648+
1649+ err := builder .RunWithoutTemp (ctx )
1650+ if err != nil {
1651+ return err
1652+ }
1653+
1654+ for _ , idx := range indexer .EndBuild () {
1655+ txns [idx ].Update ()
1656+ writer := NewTxnWriter (pstore )
1657+
1658+ x .ExponentialRetry (int (x .Config .MaxRetries ),
1659+ 20 * time .Millisecond , func () error {
1660+ err := txns [idx ].CommitToDisk (writer , rb .StartTs )
1661+ if err == badger .ErrBannedKey {
1662+ glog .Errorf ("Error while writing to banned namespace." )
1663+ return nil
1664+ }
1665+ return err
1666+ })
1667+
1668+ txns [idx ].cache .plists = nil
1669+ txns [idx ] = nil
1670+ }
1671+ }
1672+
1673+ return nil
1674+ }
1675+
1676+ func addCentroidInDB (ctx context.Context , attr string , vec []byte , txn * Txn ) error {
1677+ indexCountAttr := hnsw .ConcatStrings (attr , kmeans .CentroidPrefix )
1678+ countKey := x .DataKey (indexCountAttr , 1 )
1679+ pl , err := txn .Get (countKey )
1680+ if err != nil {
1681+ return err
1682+ }
1683+
1684+ edge := & pb.DirectedEdge {
1685+ Entity : 1 ,
1686+ Attr : indexCountAttr ,
1687+ Value : vec ,
1688+ ValueType : pb .Posting_ValType (12 ),
1689+ }
1690+ if err := pl .addMutation (ctx , txn , edge ); err != nil {
1691+ return err
1692+ }
1693+ return nil
1694+ }
1695+
14151696// rebuildTokIndex rebuilds index for a given attribute.
14161697// We commit mutations with startTs and ignore the errors.
14171698func rebuildTokIndex (ctx context.Context , rb * IndexRebuild ) error {
@@ -1443,6 +1724,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
14431724 }
14441725
14451726 runForVectors := (len (factorySpecs ) != 0 )
1727+ if runForVectors {
1728+ return rebuildVectorIndex (ctx , factorySpecs , rb )
1729+ }
14461730
14471731 pk := x.ParsedKey {Attr : rb .Attr }
14481732 builder := rebuilder {attr : rb .Attr , prefix : pk .DataPrefix (), startTs : rb .StartTs }
0 commit comments