Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Glebashnik/feed field generator #32842

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b954ab2
wip to support generate indexing statement
glebashnik Oct 28, 2024
9461c66
Initial support for generate in indexing language
glebashnik Oct 29, 2024
98073fe
draft generate expression
glebashnik Nov 5, 2024
7a293db
Refactoring. LocallLLM and OpenAI implement Generator.
glebashnik Nov 5, 2024
82ff3a9
Fix build
lesters Nov 14, 2024
baccc99
Remove dependency on document in linguistics
lesters Nov 14, 2024
ac9357b
Fix tests
lesters Nov 14, 2024
8d91388
Set input and output types for generate expression
lesters Nov 14, 2024
8e56319
Merge branch 'master' into glebashnik/feed-field-generator
lesters Nov 14, 2024
420b943
Update GenerateExpression after merge with master branch
lesters Nov 14, 2024
cc0218a
Wire in generators to indexing processor
lesters Nov 20, 2024
261bfd6
Resolve conflicts with master
lesters Nov 20, 2024
69baaa4
Use Prompt insteaad of String in Generators
lesters Nov 21, 2024
0f0d6a9
Merge branch 'master' into glebashnik/feed-field-generator
lesters Nov 29, 2024
f4869ec
ConfigurableLanguageModel implements Generator
lesters Nov 29, 2024
40b4a19
Renamed Generator interface to TextGenerator, added support for array…
glebashnik Jan 3, 2025
7790a0a
Improved input/output type inference for generate expression.
glebashnik Jan 3, 2025
e75e566
Improved input/output type error messages
glebashnik Jan 3, 2025
46e3add
Descriptive (future-proof) names for text generator component classes.
glebashnik Jan 3, 2025
fb7aed2
Added prompt template to LanguageModelTextGenerator
glebashnik Jan 3, 2025
a5d466e
Added max length to LanguageModelTextGenerator
glebashnik Jan 6, 2025
9d87bbe
Added tests with a tiny LLM, fixed issue with inference parameters.
glebashnik Jan 6, 2025
74b89fb
Merging with master
glebashnik Jan 6, 2025
63b12fa
Fixed forgotten changes in jj and abi-spec after Generator to TextGen…
glebashnik Jan 6, 2025
69d0653
Renamed Generator to TextGenerator in IndexingParser.ccc
glebashnik Jan 6, 2025
85d7bc5
Added LanguageModelTextGenerator to PlatformBundles
glebashnik Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions config-model-api/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -1812,8 +1812,8 @@
"public final java.lang.String toString()",
"public final int hashCode()",
"public final boolean equals(java.lang.Object)",
"public java.lang.String name()",
"public java.lang.String id()"
"public java.lang.String id()",
"public java.lang.String name()"
],
"fields" : [ ]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.yahoo.documentmodel.TemporaryUnknownType;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.TextGenerator;
import com.yahoo.language.simple.SimpleLinguistics;
import com.yahoo.schema.Index;
import com.yahoo.schema.Schema;
Expand Down Expand Up @@ -399,12 +400,13 @@ public boolean hasSingleAttribute() {

/** Parse an indexing expression which will use the simple linguistics implementation suitable for testing */
public void parseIndexingScript(String schemaName, String script) {
parseIndexingScript(schemaName, script, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
parseIndexingScript(schemaName, script, new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), TextGenerator.throwsOnUse.asMap());
}

public void parseIndexingScript(String schemaName, String script, Linguistics linguistics, Map<String, Embedder> embedders) {
public void parseIndexingScript(String schemaName, String script, Linguistics linguistics,
Map<String, Embedder> embedders, Map<String, TextGenerator> generators) {
try {
ScriptParserContext config = new ScriptParserContext(linguistics, embedders);
ScriptParserContext config = new ScriptParserContext(linguistics, embedders, generators);
config.setInputStream(new IndexingInput(script));
setIndexingScript(schemaName, ScriptExpression.newInstance(config));
} catch (ParseException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.TextGenerator;
import com.yahoo.language.simple.SimpleLinguistics;
import com.yahoo.schema.document.SDField;
import com.yahoo.schema.parser.ParseException;
Expand Down Expand Up @@ -34,13 +35,14 @@ public void apply(String schemaName, SDField field) {

/** Creates an indexing operation which will use the simple linguistics implementation suitable for testing */
public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine) throws ParseException {
return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse.asMap(),
TextGenerator.throwsOnUse.asMap());
}

public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine,
Linguistics linguistics, Map<String, Embedder> embedders)
throws ParseException {
ScriptParserContext config = new ScriptParserContext(linguistics, embedders);
public static IndexingOperation fromStream(
SimpleCharStream input, boolean multiLine, Linguistics linguistics, Map<String, Embedder> embedders,
Map<String, TextGenerator> generators) throws ParseException {
ScriptParserContext config = new ScriptParserContext(linguistics, embedders, generators);
config.setAnnotatorConfig(new AnnotatorConfig());
config.setInputStream(input);
ScriptExpression exp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ public ApplicationContainerCluster(TreeConfigProducer<?> parent, String configSu

addSimpleComponent("com.yahoo.language.provider.DefaultLinguisticsProvider");
addSimpleComponent("com.yahoo.language.provider.DefaultEmbedderProvider");
addSimpleComponent("com.yahoo.language.provider.DefaultGeneratorProvider");
addSimpleComponent("com.yahoo.container.jdisc.SecretStoreProvider");
addSimpleComponent("com.yahoo.container.jdisc.CertificateStoreProvider");
addSimpleComponent("com.yahoo.container.jdisc.AthenzIdentityProviderProvider");
Expand Down
13 changes: 6 additions & 7 deletions config-model/src/main/javacc/SchemaParser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,11 @@ import com.yahoo.schema.document.MatchAlgorithm;
import com.yahoo.schema.document.HnswIndexParams;
import com.yahoo.schema.document.Sorting;
import com.yahoo.schema.document.Stemming;
import com.yahoo.schema.document.SDField;
import com.yahoo.schema.FeatureNames;
import com.yahoo.schema.fieldoperation.IndexingOperation;
import com.yahoo.search.schema.RankProfile.InputType;
import com.yahoo.searchlib.rankingexpression.FeatureList;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
Expand All @@ -49,7 +46,6 @@ import java.util.Optional;
import java.util.Map;
import java.util.List;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.logging.Level;

/**
Expand Down Expand Up @@ -82,7 +78,7 @@ public class SchemaParser {
*/
@SuppressWarnings("deprecation")
private IndexingOperation newIndexingOperation(boolean multiline) throws ParseException {
return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap());
}

/**
Expand All @@ -91,13 +87,15 @@ public class SchemaParser {
* @param multiline Whether or not to allow multi-line expressions.
* @param linguistics What to use for tokenizing.
*/
private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
private IndexingOperation newIndexingOperation(
boolean multiline, Linguistics linguistics, Map<String, Embedder> embedders,
Map<String, Generator> generators) throws ParseException {
SimpleCharStream input = (SimpleCharStream)token_source.input_stream;
if (token.next != null) {
input.backup(token.next.image.length());
}
try {
return IndexingOperation.fromStream(input, multiline, linguistics, embedders);
return IndexingOperation.fromStream(input, multiline, linguistics, embedders, generators);
} finally {
token.next = null;
jj_ntk = -1;
Expand All @@ -121,6 +119,7 @@ public class SchemaParser {
token_source.input_stream.getBeginColumn() + ".").initCause(e);
}
}
}}
}

PARSER_END(SchemaParser)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.language.provider;

import com.yahoo.component.annotation.Inject;
import com.yahoo.container.di.componentgraph.Provider;
import com.yahoo.language.process.TextGenerator;

/**
* Provides the default generator implementation if no generator component has been explicitly configured
* (dependency injection will fall back to providers if no components of the requested type is found).
*
* @author lesters
*/
@SuppressWarnings("unused") // Injected
public class DefaultGeneratorProvider implements Provider<TextGenerator> {

@Inject
public DefaultGeneratorProvider() { }

@Override
public TextGenerator get() { return TextGenerator.throwsOnUse; }

@Override
public void deconstruct() {}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.TextGenerator;
import com.yahoo.language.provider.DefaultEmbedderProvider;
import com.yahoo.language.provider.DefaultGeneratorProvider;
import com.yahoo.vespa.configdefinition.IlscriptsConfig;
import com.yahoo.vespa.indexinglanguage.AdapterFactory;
import com.yahoo.vespa.indexinglanguage.SimpleAdapterFactory;
Expand Down Expand Up @@ -58,9 +60,12 @@ public Expression selectExpression(DocumentType documentType, String fieldName)
public IndexingProcessor(DocumentTypeManager documentTypeManager,
IlscriptsConfig ilscriptsConfig,
Linguistics linguistics,
ComponentRegistry<Embedder> embedders) {
ComponentRegistry<Embedder> embedders,
ComponentRegistry<TextGenerator> generators) {
this.documentTypeManager = documentTypeManager;
scriptManager = new ScriptManager(this.documentTypeManager, ilscriptsConfig, linguistics, toMap(embedders));
Map<String, Embedder> embedderMap = toMap(embedders, DefaultEmbedderProvider.class);
Map<String, TextGenerator> generatorMap = toMap(generators, DefaultGeneratorProvider.class);
scriptManager = new ScriptManager(this.documentTypeManager, ilscriptsConfig, linguistics, embedderMap, generatorMap);
adapterFactory = new SimpleAdapterFactory(new ExpressionSelector());
}

Expand Down Expand Up @@ -132,11 +137,11 @@ private void processRemove(DocumentRemove input, List<DocumentOperation> out) {
out.add(input);
}

private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
var map = embedders.allComponentsById().entrySet().stream()
.collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue));
private <T> Map<String, T> toMap(ComponentRegistry<T> registry, Class<?> defaultProviderClass) {
var map = registry.allComponentsById().entrySet().stream()
.collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue));
if (map.size() > 1) {
map.remove(DefaultEmbedderProvider.class.getName());
map.remove(defaultProviderClass.getName());
// Ideally, this should be handled by dependency injection, however for now this workaround is necessary.
}
return map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.yahoo.language.Linguistics;

import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.TextGenerator;
import com.yahoo.vespa.configdefinition.IlscriptsConfig;
import com.yahoo.vespa.indexinglanguage.ScriptParserContext;
import com.yahoo.vespa.indexinglanguage.expressions.InputExpression;
Expand All @@ -31,9 +32,9 @@ public class ScriptManager {
private final DocumentTypeManager documentTypeManager;

public ScriptManager(DocumentTypeManager documentTypeManager, IlscriptsConfig config, Linguistics linguistics,
Map<String, Embedder> embedders) {
Map<String, Embedder> embedders, Map<String, TextGenerator> generators) {
this.documentTypeManager = documentTypeManager;
documentFieldScripts = createScriptsMap(documentTypeManager, config, linguistics, embedders);
documentFieldScripts = createScriptsMap(documentTypeManager, config, linguistics, embedders, generators);
}

private Map<String, DocumentScript> getScripts(DocumentType inputType) {
Expand Down Expand Up @@ -70,9 +71,10 @@ public DocumentScript getScript(DocumentType inputType, String inputFieldName) {
private static Map<String, Map<String, DocumentScript>> createScriptsMap(DocumentTypeManager documentTypes,
IlscriptsConfig config,
Linguistics linguistics,
Map<String, Embedder> embedders) {
Map<String, Embedder> embedders,
Map<String, TextGenerator> generators) {
Map<String, Map<String, DocumentScript>> documentFieldScripts = new HashMap<>(config.ilscript().size());
ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders);
ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders, generators);
parserContext.getAnnotatorConfig().setMaxTermOccurrences(config.maxtermoccurrences());
parserContext.getAnnotatorConfig().setMaxTokenizeLength(config.fieldmatchmaxlength());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ private static IndexingProcessor newProcessor(String configId) {
return new IndexingProcessor(new DocumentTypeManager(ConfigGetter.getConfig(DocumentmanagerConfig.class, configId)),
ConfigGetter.getConfig(IlscriptsConfig.class, configId),
new SimpleLinguistics(),
new ComponentRegistry<>(),
new ComponentRegistry<>());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.yahoo.document.DocumentType;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.TextGenerator;
import com.yahoo.vespa.configdefinition.IlscriptsConfig;
import org.junit.Test;

Expand All @@ -27,7 +28,7 @@ public void requireThatScriptsAreAppliedToSubType() {
IlscriptsConfig.Builder config = new IlscriptsConfig.Builder();
config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newssummary")
.content("input title | index title"));
ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap());
ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap(), TextGenerator.throwsOnUse.asMap());
assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newsarticle")));
assertNull(scriptMgr.getScript(new DocumentType("unknown")));
}
Expand All @@ -41,22 +42,22 @@ public void requireThatScriptsAreAppliedToSuperType() {
IlscriptsConfig.Builder config = new IlscriptsConfig.Builder();
config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newsarticle")
.content("input title | index title"));
ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap());
ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap(), TextGenerator.throwsOnUse.asMap());
assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newssummary")));
assertNull(scriptMgr.getScript(new DocumentType("unknown")));
}

@Test
public void requireThatEmptyConfigurationDoesNotThrow() {
var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg");
ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap());
ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap(), TextGenerator.throwsOnUse.asMap());
assertNull(scriptMgr.getScript(new DocumentType("unknown")));
}

@Test
public void requireThatUnknownDocumentTypeReturnsNull() {
var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg");
ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap());
ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap(), TextGenerator.throwsOnUse.asMap());
for (Iterator<DocumentType> it = typeMgr.documentTypeIterator(); it.hasNext(); ) {
assertNull(scriptMgr.getScript(it.next()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ private static <T extends Expression> T parse(ScriptParserContext context, Parse
parser.setDefaultFieldName(context.getDefaultFieldName());
parser.setLinguistics(context.getLinguistcs());
parser.setEmbedders(context.getEmbedders());
parser.setGenerators(context.getGenerators());

try {
return method.call(parser);
} catch (ParseException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.TextGenerator;
import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig;
import com.yahoo.vespa.indexinglanguage.parser.CharStream;

Expand All @@ -17,12 +18,14 @@ public class ScriptParserContext {
private AnnotatorConfig annotatorConfig = new AnnotatorConfig();
private Linguistics linguistics;
private final Map<String, Embedder> embedders;
private final Map<String, TextGenerator> generators;
private String defaultFieldName = null;
private CharStream inputStream = null;

public ScriptParserContext(Linguistics linguistics, Map<String, Embedder> embedders) {
public ScriptParserContext(Linguistics linguistics, Map<String, Embedder> embedders, Map<String, TextGenerator> generators) {
this.linguistics = linguistics;
this.embedders = embedders;
this.generators = generators;
}

public AnnotatorConfig getAnnotatorConfig() {
Expand All @@ -46,6 +49,9 @@ public ScriptParserContext setLinguistics(Linguistics linguistics) {
public Map<String, Embedder> getEmbedders() {
return Collections.unmodifiableMap(embedders);
}

public Map<String, TextGenerator> getGenerators() { return Collections.unmodifiableMap(generators);
}

public String getDefaultFieldName() {
return defaultFieldName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.TextGenerator;
import com.yahoo.language.simple.SimpleLinguistics;
import com.yahoo.vespa.indexinglanguage.*;
import com.yahoo.vespa.indexinglanguage.parser.IndexingInput;
Expand Down Expand Up @@ -301,7 +302,11 @@ public static Expression fromString(String expression) throws ParseException {
}

public static Expression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression)));
return newInstance(new ScriptParserContext(linguistics, embedders, Map.of()).setInputStream(new IndexingInput(expression)));
}

public static Expression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders, Map<String, TextGenerator> generators) throws ParseException {
return newInstance(new ScriptParserContext(linguistics, embedders, generators).setInputStream(new IndexingInput(expression)));
}

public static Expression newInstance(ScriptParserContext context) throws ParseException {
Expand Down
Loading
Loading