Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import ai.koog.agents.core.dsl.extension.replaceHistoryWithTLDR
import ai.koog.agents.core.prompt.Prompts.selectRelevantTools
import ai.koog.agents.core.tools.ToolDescriptor
import ai.koog.agents.core.tools.annotations.LLMDescription
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.llm.LLModel
import ai.koog.prompt.params.LLMParams
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredRequest
import ai.koog.prompt.structure.StructuredRequestConfig
import ai.koog.prompt.structure.json.JsonStructure
Expand Down Expand Up @@ -137,8 +137,8 @@ public open class AIAgentSubgraph<TInput, TOutput>(
examples = listOf(SelectedTools(listOf()), SelectedTools(tools.map { it.name }.take(3))),
),
),
fixingParser = toolSelectionStrategy.fixingParser,
)
),
fixingParser = toolSelectionStrategy.fixingParser,
).getOrThrow()

prompt = initialPrompt
Expand All @@ -157,7 +157,15 @@ public open class AIAgentSubgraph<TInput, TOutput>(
*/
@OptIn(InternalAgentsApi::class, DetachedPromptExecutorAPI::class, ExperimentalUuidApi::class)
override suspend fun execute(context: AIAgentGraphContextBase, input: TInput): TOutput? =
withContext(NodeInfoContextElement(Uuid.random().toString(), getNodeInfoElement()?.id, name, input, inputType)) {
withContext(
NodeInfoContextElement(
Uuid.random().toString(),
getNodeInfoElement()?.id,
name,
input,
inputType
)
) {
val newTools = selectTools(context)

// Copy inner context with new tools, model and LLM params.
Expand Down Expand Up @@ -201,7 +209,14 @@ public open class AIAgentSubgraph<TInput, TOutput>(
}

runIfNonRootContext(context) {
pipeline.onSubgraphExecutionCompleted(this@AIAgentSubgraph, innerContext, input, inputType, result, outputType)
pipeline.onSubgraphExecutionCompleted(
this@AIAgentSubgraph,
innerContext,
input,
inputType,
result,
outputType
)
}

result
Expand Down Expand Up @@ -284,7 +299,10 @@ public open class AIAgentSubgraph<TInput, TOutput>(
* effectively skipping execution for root contexts.
*/
@OptIn(InternalAgentsApi::class)
private suspend fun runIfNonRootContext(context: AIAgentGraphContextBase, action: suspend AIAgentGraphContextBase.() -> Unit) {
private suspend fun runIfNonRootContext(
context: AIAgentGraphContextBase,
action: suspend AIAgentGraphContextBase.() -> Unit
) {
if (context.parentContext == null) return
action(context)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ import ai.koog.agents.core.utils.ActiveProperty
import ai.koog.prompt.dsl.ModerationResult
import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.executor.model.PromptExecutor
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.executor.model.executeStructured
import ai.koog.prompt.executor.model.parseResponseToStructuredResponse
import ai.koog.prompt.llm.LLModel
import ai.koog.prompt.message.LLMChoice
import ai.koog.prompt.message.Message
import ai.koog.prompt.params.LLMParams
import ai.koog.prompt.streaming.StreamFrame
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredRequestConfig
import ai.koog.prompt.structure.StructuredResponse
import ai.koog.prompt.structure.executeStructured
import ai.koog.prompt.structure.parseResponseToStructuredResponse
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.KSerializer
import kotlinx.serialization.serializer
Expand Down Expand Up @@ -264,6 +264,7 @@ public sealed class AIAgentLLMSession(
*/
public open suspend fun <T> requestLLMStructured(
config: StructuredRequestConfig<T>,
fixingParser: StructureFixingParser? = null
): Result<StructuredResponse<T>> {
validateSession()

Expand All @@ -273,6 +274,7 @@ public sealed class AIAgentLLMSession(
prompt = preparedPrompt,
model = model,
config = config,
fixingParser = fixingParser
)
}

Expand Down Expand Up @@ -346,8 +348,9 @@ public sealed class AIAgentLLMSession(
*/
public suspend fun <T> parseResponseToStructuredResponse(
response: Message.Assistant,
config: StructuredRequestConfig<T>
): StructuredResponse<T> = executor.parseResponseToStructuredResponse(response, config, model)
config: StructuredRequestConfig<T>,
fixingParser: StructureFixingParser? = null
): StructuredResponse<T> = executor.parseResponseToStructuredResponse(response, config, model, fixingParser)

/**
* Sends a request to the language model, potentially receiving multiple choices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.dsl.PromptBuilder
import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.executor.model.PromptExecutor
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.llm.LLModel
import ai.koog.prompt.message.Message
import ai.koog.prompt.params.LLMParams
import ai.koog.prompt.streaming.StreamFrame
import ai.koog.prompt.structure.StructureDefinition
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredRequestConfig
import ai.koog.prompt.structure.StructuredResponse
import kotlinx.coroutines.flow.Flow
Expand Down Expand Up @@ -444,8 +444,9 @@ public class AIAgentLLMWriteSession internal constructor(
*/
override suspend fun <T> requestLLMStructured(
config: StructuredRequestConfig<T>,
fixingParser: StructureFixingParser?
): Result<StructuredResponse<T>> {
return super.requestLLMStructured(config).also {
return super.requestLLMStructured(config, fixingParser).also {
it.onSuccess { response ->
appendPrompt {
message(response.message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import ai.koog.agents.core.tools.Tool
import ai.koog.agents.core.tools.ToolArgs
import ai.koog.agents.core.tools.ToolDescriptor
import ai.koog.agents.core.tools.ToolResult
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.message.Message
import ai.koog.prompt.streaming.StreamFrame
import ai.koog.prompt.structure.StructureDefinition
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredResponse
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.serializer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ import ai.koog.agents.core.tools.ToolDescriptor
import ai.koog.prompt.dsl.ModerationResult
import ai.koog.prompt.dsl.PromptBuilder
import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.llm.LLModel
import ai.koog.prompt.message.Message
import ai.koog.prompt.streaming.StreamFrame
import ai.koog.prompt.streaming.toMessageResponses
import ai.koog.prompt.structure.StructureDefinition
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredOutputPrompts.appendStructuredOutputInstructions
import ai.koog.prompt.structure.StructuredRequestConfig
import ai.koog.prompt.structure.StructuredResponse
import kotlinx.coroutines.flow.Flow
Expand Down Expand Up @@ -199,14 +200,15 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMModerateMessage(
public inline fun <reified T> AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestStructured(
name: String? = null,
config: StructuredRequestConfig<T>,
fixingParser: StructureFixingParser? = null,
): AIAgentNodeDelegate<String, Result<StructuredResponse<T>>> =
node(name) { message ->
llm.writeSession {
appendPrompt {
user(message)
}

requestLLMStructured(config)
requestLLMStructured(config, fixingParser)
}
}

Expand Down Expand Up @@ -540,7 +542,7 @@ public inline fun <reified TInput, T> AIAgentSubgraphBuilderBase<*, *>.nodeSetSt
): AIAgentNodeDelegate<TInput, TInput> =
node(name) { message ->
llm.writeSession {
prompt = config.updatePrompt(model, prompt)
prompt = appendStructuredOutputInstructions(prompt, config.structuredRequest(model))
message
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import ai.koog.agents.core.dsl.extension.onMultipleToolCalls
import ai.koog.agents.core.dsl.extension.onToolCall
import ai.koog.agents.core.environment.ReceivedToolResult
import ai.koog.agents.core.environment.result
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.message.Message
import ai.koog.prompt.structure.StructuredRequestConfig

Expand Down Expand Up @@ -184,9 +185,11 @@ public fun reActStrategy(
*/
public inline fun <reified Output> structuredOutputWithToolsStrategy(
config: StructuredRequestConfig<Output>,
fixingParser: StructureFixingParser? = null,
parallelTools: Boolean = false
): AIAgentGraphStrategy<String, Output> = structuredOutputWithToolsStrategy(
config,
fixingParser,
parallelTools
) { it }

Expand All @@ -204,6 +207,7 @@ public inline fun <reified Output> structuredOutputWithToolsStrategy(
*/
public inline fun <reified Input, reified Output> structuredOutputWithToolsStrategy(
config: StructuredRequestConfig<Output>,
fixingParser: StructureFixingParser? = null,
parallelTools: Boolean = false,
noinline transform: suspend AIAgentGraphContextBase.(input: Input) -> String
): AIAgentGraphStrategy<Input, Output> = strategy<Input, Output>("structured_output_with_tools_strategy") {
Expand All @@ -214,7 +218,7 @@ public inline fun <reified Input, reified Output> structuredOutputWithToolsStrat
val sendToolResult by nodeLLMSendMultipleToolResults()
val transformToStructuredOutput by node<Message.Assistant, Output> { response ->
llm.writeSession {
parseResponseToStructuredResponse(response, config).data
parseResponseToStructuredResponse(response, config, fixingParser).data
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
import ai.koog.agents.core.tools.annotations.LLMDescription
import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.executor.clients.openai.OpenAIModels
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.llm.LLModel
import ai.koog.prompt.message.Message
import ai.koog.prompt.structure.StructureFixingParser
import kotlinx.serialization.Serializable

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import ai.koog.agents.ext.tool.AskUser
import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.executor.clients.openai.OpenAIModels
import ai.koog.prompt.executor.llms.all.simpleOpenAIExecutor
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.executor.model.StructureFixingParser
import kotlinx.coroutines.runBlocking

suspend fun main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import ai.koog.prompt.executor.clients.google.structure.GoogleBasicJsonSchemaGen
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
import ai.koog.prompt.executor.clients.openai.base.structure.OpenAIBasicJsonSchemaGenerator
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.llm.LLMProvider
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredRequest
import ai.koog.prompt.structure.StructuredRequestConfig
import ai.koog.prompt.structure.json.JsonStructure
Expand Down Expand Up @@ -212,13 +212,12 @@ suspend fun main() {

// Fallback manual structured output mode, via explicit prompting with additional message, not native model support
default = StructuredRequest.Manual(genericWeatherStructure),

// Helper parser to attempt a fix if a malformed output is produced.
fixingParser = StructureFixingParser(
model = AnthropicModels.Haiku_3_5,
retries = 2,
),
)
),
// Helper parser to attempt a fix if a malformed output is produced.
fixingParser = StructureFixingParser(
model = AnthropicModels.Haiku_3_5,
retries = 2,
),
)

nodeStart then prepareRequest then getStructuredForecast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import ai.koog.prompt.executor.clients.google.structure.GoogleStandardJsonSchema
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
import ai.koog.prompt.executor.clients.openai.base.structure.OpenAIStandardJsonSchemaGenerator
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.llm.LLMProvider
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredRequest
import ai.koog.prompt.structure.StructuredRequestConfig
import ai.koog.prompt.structure.json.JsonStructure
Expand Down Expand Up @@ -81,13 +81,13 @@ suspend fun main() {

// Fallback manual structured output mode, via explicit prompting with additional message, not native model support
default = StructuredRequest.Manual(genericWeatherStructure),
),

// Helper parser to attempt a fix if a malformed output is produced.
fixingParser = StructureFixingParser(
model = AnthropicModels.Haiku_3_5,
retries = 2,
),
)
// Helper parser to attempt a fix if a malformed output is produced.
fixingParser = StructureFixingParser(
model = AnthropicModels.Haiku_3_5,
retries = 2,
),
)

nodeStart then prepareRequest then getStructuredForecast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
import ai.koog.prompt.executor.clients.openai.OpenAIModels
import ai.koog.prompt.executor.clients.openai.base.structure.OpenAIStandardJsonSchemaGenerator
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.llm.LLMProvider
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredRequest
import ai.koog.prompt.structure.StructuredRequestConfig
import ai.koog.prompt.structure.json.JsonStructure
Expand Down Expand Up @@ -71,16 +71,17 @@ suspend fun main() {

// Fallback manual structured output mode, via explicit prompting with additional message, not native model support
default = StructuredRequest.Manual(genericWeatherStructure),
)

// Helper parser to attempt a fix if a malformed output is produced.
fixingParser = StructureFixingParser(
model = AnthropicModels.Haiku_3_5,
retries = 2,
),
// Helper parser to attempt a fix if a malformed output is produced.
val fixingParser = StructureFixingParser(
model = AnthropicModels.Haiku_3_5,
retries = 2,
)

val agentStrategy = structuredOutputWithToolsStrategy<FullWeatherForecastRequest, FullWeatherForecast>(
config
config,
fixingParser
) { request ->
text {
+"Requesting forecast for"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import ai.koog.prompt.executor.clients.google.GoogleModels
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
import ai.koog.prompt.executor.clients.openai.OpenAIModels
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.llm.LLMProvider
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.text.text
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package ai.koog.integration.tests.utils.structuredOutput

import ai.koog.agents.core.tools.annotations.LLMDescription
import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.executor.model.StructureFixingParser
import ai.koog.prompt.llm.LLModel
import ai.koog.prompt.structure.RegisteredStandardJsonSchemaGenerators
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredRequest
import ai.koog.prompt.structure.StructuredRequestConfig
import ai.koog.prompt.structure.StructuredResponse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ kotlin {
dependencies {
api(project(":prompt:prompt-model"))
api(project(":agents:agents-tools"))
api(project(":prompt:prompt-structure"))
api(project(":prompt:prompt-executor:prompt-executor-model"))
api(libs.kotlinx.coroutines.core)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ public class DashscopeLLMClient(

private companion object {
private val staticLogger = KotlinLogging.logger { }

init {
// On class load register custom OpenAI JSON schema generators for structured output.
registerOpenAIJsonSchemaGenerators(LLMProvider.Alibaba)
}
}

override fun llmProvider(): LLMProvider = LLMProvider.Alibaba
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ public class DeepSeekLLMClient(

private companion object {
private val staticLogger = KotlinLogging.logger { }

init {
// On class load register custom OpenAI JSON schema generators for structured output.
registerOpenAIJsonSchemaGenerators(LLMProvider.DeepSeek)
}
}

/**
Expand Down
Loading
Loading