diff --git a/posting/list.go b/posting/list.go index 34f75880c93..3f7c51254c6 100644 --- a/posting/list.go +++ b/posting/list.go @@ -1709,45 +1709,87 @@ func (l *List) Uids(opt ListOptions) (*pb.List, error) { if opt.First == 0 { opt.First = math.MaxInt32 } - // Pre-assign length to make it faster. - l.RLock() - // Use approximate length for initial capacity. - res := make([]uint64, 0, l.mutationMap.len()+codec.ApproxLen(l.plist.Pack)) - out := &pb.List{} - if l.mutationMap.len() == 0 && opt.Intersect != nil && len(l.plist.Splits) == 0 { - if opt.ReadTs < l.minTs { - l.RUnlock() - return out, errors.Wrapf(ErrTsTooOld, "While reading UIDs") + + getUidList := func() (*pb.List, error, bool) { + // Pre-assign length to make it faster. + l.RLock() + defer l.RUnlock() + // Use approximate length for initial capacity. + res := make([]uint64, 0, l.ApproxLen()) + out := &pb.List{} + + if l.mutationMap.len() == 0 && opt.Intersect != nil && len(l.plist.Splits) == 0 { + if opt.ReadTs < l.minTs { + return out, errors.Wrapf(ErrTsTooOld, "While reading UIDs"), false + } + algo.IntersectCompressedWith(l.plist.Pack, opt.AfterUid, opt.Intersect, out) + return out, nil, false } - algo.IntersectCompressedWith(l.plist.Pack, opt.AfterUid, opt.Intersect, out) - l.RUnlock() - return out, nil - } - err := l.iterate(opt.ReadTs, opt.AfterUid, func(p *pb.Posting) error { - if p.PostingType == pb.Posting_REF { - res = append(res, p.Uid) - if opt.First < 0 { - // We need the last N. - // TODO: This could be optimized by only considering some of the last UidBlocks. - if len(res) > -opt.First { - res = res[1:] + // If we need to intersect and the number of elements are small, in that case it's better to + // just check each item is present or not. + if opt.Intersect != nil && len(opt.Intersect.Uids) < l.ApproxLen() { + // Cache the iterator as it makes the search space smaller each time. + var pitr pIterator + for _, uid := range opt.Intersect.Uids { + ok, _, err := l.findPostingWithItr(opt.ReadTs, uid, pitr) + if err != nil { + return nil, err, false + } + if ok { + res = append(res, uid) } - } else if len(res) > opt.First { - return ErrStopIteration } + + out.Uids = res + return out, nil, false } - return nil - }) - l.RUnlock() - if err != nil { - return out, errors.Wrapf(err, "cannot retrieve UIDs from list with key %s", - hex.EncodeToString(l.key)) + + // If we are going to iterate over the list, in that case we only need to read between min and max + // of opt.Intersect. + var uidMin, uidMax uint64 = 0, 0 + if opt.Intersect != nil && len(opt.Intersect.Uids) > 0 { + uidMin = opt.Intersect.Uids[0] + uidMax = opt.Intersect.Uids[len(opt.Intersect.Uids)-1] + } + + err := l.iterate(opt.ReadTs, opt.AfterUid, func(p *pb.Posting) error { + if p.PostingType == pb.Posting_REF { + if p.Uid < uidMin { + return nil + } + if p.Uid > uidMax && uidMax > 0 { + return ErrStopIteration + } + res = append(res, p.Uid) + + if opt.First < 0 { + // We need the last N. + // TODO: This could be optimized by only considering some of the last UidBlocks. + if len(res) > -opt.First { + res = res[1:] + } + } else if len(res) > opt.First { + return ErrStopIteration + } + } + return nil + }) + if err != nil { + return out, errors.Wrapf(err, "cannot retrieve UIDs from list with key %s", + hex.EncodeToString(l.key)), false + } + out.Uids = res + return out, nil, true } // Do The intersection here as it's optimized. - out.Uids = res - lenBefore := len(res) + out, err, applyIntersectWith := getUidList() + if err != nil || !applyIntersectWith { + return out, err + } + + lenBefore := len(out.Uids) if opt.Intersect != nil { algo.IntersectWith(out, opt.Intersect, out) } @@ -2045,7 +2087,7 @@ func (l *List) FindPosting(readTs uint64, uid uint64) (found bool, pos *pb.Posti return l.findPosting(readTs, uid) } -func (l *List) findPosting(readTs uint64, uid uint64) (found bool, pos *pb.Posting, err error) { +func (l *List) findPostingWithItr(readTs uint64, uid uint64, pitr pIterator) (found bool, pos *pb.Posting, err error) { // Iterate starts iterating after the given argument, so we pass UID - 1 // TODO Find what happens when uid = math.MaxUint64 searchFurther, pos := l.mutationMap.findPosting(readTs, uid) @@ -2056,7 +2098,6 @@ func (l *List) findPosting(readTs uint64, uid uint64) (found bool, pos *pb.Posti return false, nil, nil } - var pitr pIterator err = pitr.seek(l, uid-1, 0) if err != nil { return false, nil, errors.Wrapf(err, @@ -2080,6 +2121,11 @@ func (l *List) findPosting(readTs uint64, uid uint64) (found bool, pos *pb.Posti return false, nil, nil } +func (l *List) findPosting(readTs uint64, uid uint64) (found bool, pos *pb.Posting, err error) { + var pitr pIterator + return l.findPostingWithItr(readTs, uid, pitr) +} + // Facets gives facets for the posting representing value. func (l *List) Facets(readTs uint64, param *pb.FacetParams, langs []string, listType bool) ([]*pb.Facets, error) {