Skip to content

Commit 4d96e4f

Browse files
authored
feat: improve searchtools (#146)
1 parent 5c022be commit 4d96e4f

File tree

3 files changed

+50
-38
lines changed

3 files changed

+50
-38
lines changed

examples/http_client/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func main() {
125125
}
126126

127127
// Discover tools
128-
tools, err := client.SearchTools("", 10)
128+
tools, err := client.SearchTools("http", 10)
129129
if err != nil {
130130
log.Fatalf("search: %v", err)
131131
}

src/tag/tag_search.go

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func NewTagSearchStrategy(repo ToolRepository, descriptionWeight float64) *TagSe
3030
// SearchTools returns tools ordered by relevance to the query, using explicit tags and description keywords.
3131
func (s *TagSearchStrategy) SearchTools(ctx context.Context, query string, limit int) ([]Tool, error) {
3232
// Normalize query
33-
queryLower := strings.ToLower(query)
33+
queryLower := strings.ToLower(strings.TrimSpace(query))
3434
words := s.wordRegex.FindAllString(queryLower, -1)
3535
queryWordSet := make(map[string]struct{}, len(words))
3636
for _, w := range words {
@@ -43,57 +43,74 @@ func (s *TagSearchStrategy) SearchTools(ctx context.Context, query string, limit
4343
return nil, err
4444
}
4545

46-
// SUTCP each tool
47-
type sUTCPdTool struct {
48-
t Tool
49-
sUTCP float64
46+
// Compute SUTCP score for each tool
47+
type scoredTool struct {
48+
tool Tool
49+
score float64
5050
}
51-
var sUTCPd []sUTCPdTool
51+
var scored []scoredTool
52+
5253
for _, t := range tools {
53-
var sUTCP float64
54+
var score float64
5455

55-
// SUTCP from tags
56+
// Match against tags
5657
for _, tag := range t.Tags {
5758
tagLower := strings.ToLower(tag)
59+
60+
// Direct substring match
5861
if strings.Contains(queryLower, tagLower) {
59-
sUTCP += 1.0
62+
score += 1.0
6063
}
61-
// Partial matches on tag words
64+
65+
// Word-level overlap
6266
tagWords := s.wordRegex.FindAllString(tagLower, -1)
6367
for _, w := range tagWords {
6468
if _, ok := queryWordSet[w]; ok {
65-
sUTCP += s.descriptionWeight
69+
score += s.descriptionWeight
6670
}
6771
}
6872
}
6973

70-
// SUTCP from description
74+
// Match against description
7175
if t.Description != "" {
7276
descWords := s.wordRegex.FindAllString(strings.ToLower(t.Description), -1)
7377
for _, w := range descWords {
7478
if len(w) > 2 {
7579
if _, ok := queryWordSet[w]; ok {
76-
sUTCP += s.descriptionWeight
80+
score += s.descriptionWeight
7781
}
7882
}
7983
}
8084
}
8185

82-
sUTCPd = append(sUTCPd, sUTCPdTool{t: t, sUTCP: sUTCP})
86+
scored = append(scored, scoredTool{tool: t, score: score})
8387
}
8488

85-
// Sort by descending sUTCP
86-
sort.Slice(sUTCPd, func(i, j int) bool {
87-
return sUTCPd[i].sUTCP > sUTCPd[j].sUTCP
89+
// Sort descending by score
90+
sort.Slice(scored, func(i, j int) bool {
91+
return scored[i].score > scored[j].score
8892
})
8993

90-
// Return up to limit
94+
// Collect only positive matches
9195
var result []Tool
92-
for i, st := range sUTCPd {
93-
if i >= limit {
94-
break
96+
for _, st := range scored {
97+
if st.score > 0 {
98+
result = append(result, st.tool)
99+
if len(result) >= limit {
100+
break
101+
}
95102
}
96-
result = append(result, st.t)
97103
}
104+
105+
// If no matches, fallback to top N (for discoverability)
106+
if len(result) == 0 && len(scored) > 0 {
107+
for i, st := range scored {
108+
if i >= limit {
109+
break
110+
}
111+
result = append(result, st.tool)
112+
}
113+
}
114+
98115
return result, nil
99116
}

utcp_client.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -517,26 +517,21 @@ func (c *UtcpClient) CallTool(
517517
return fn(ctx, args)
518518
}
519519

520-
func (c *UtcpClient) SearchTools(query string, limit int) ([]Tool, error) {
521-
tools, err := c.searchStrategy.SearchTools(context.Background(), query, limit)
520+
func (c *UtcpClient) SearchTools(providerPrefix string, limit int) ([]Tool, error) {
521+
all, err := c.toolRepository.GetTools(context.Background())
522522
if err != nil {
523523
return nil, err
524524
}
525-
526-
// Convert []*Tool to []Tool if needed
527-
result := make([]Tool, len(tools))
528-
for i, tool := range tools {
529-
switch t := any(tool).(type) {
530-
case Tool:
531-
result[i] = t
532-
case *Tool:
533-
result[i] = *t
534-
default:
535-
// fallback (shouldn't happen)
536-
result[i] = Tool{}
525+
var filtered []Tool
526+
for _, t := range all {
527+
if strings.HasPrefix(t.Name, providerPrefix+".") {
528+
filtered = append(filtered, t)
537529
}
538530
}
539-
return result, nil
531+
if len(filtered) == 0 {
532+
return nil, fmt.Errorf("no tools found for provider %q", providerPrefix)
533+
}
534+
return filtered, nil
540535
}
541536

542537
// ----- variable substitution src -----

0 commit comments

Comments
 (0)