Skip to content

Commit abb98b0

Browse files
committed
Building LLM into next action decider.
1 parent 881a0ee commit abb98b0

File tree

4 files changed

+156
-71
lines changed

4 files changed

+156
-71
lines changed

ihmc-high-level-behaviors/src/libgdx/java/us/ihmc/rdx/ui/behavior/tree/RDXBehaviorTreeRootNode.java

+5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import imgui.ImGui;
55
import us.ihmc.behaviors.behaviorTree.BehaviorTreeRootNodeDefinition;
66
import us.ihmc.behaviors.behaviorTree.BehaviorTreeRootNodeState;
7+
import us.ihmc.behaviors.reasoning.BehaviorTreeLLMEncoding;
78
import us.ihmc.communication.crdt.CRDTInfo;
9+
import us.ihmc.log.LogTools;
810
import us.ihmc.rdx.imgui.ImBooleanWrapper;
911
import us.ihmc.rdx.imgui.ImGuiTools;
1012
import us.ihmc.rdx.imgui.ImGuiUniqueLabelMap;
@@ -85,6 +87,9 @@ public void renderContextMenuItems()
8587
{
8688
super.renderContextMenuItems();
8789

90+
if (ImGui.menuItem(labels.get("Print LLM Encoding")))
91+
LogTools.info("LLM Encoding:%n%s".formatted(BehaviorTreeLLMEncoding.encode(state)));
92+
8893
if (ImGui.menuItem(labels.get("Render Progress Using Plots"), null, progressWidgetsManager.getRenderAsPlots()))
8994
progressWidgetsManager.setRenderAsPlots(!progressWidgetsManager.getRenderAsPlots());
9095
}

ihmc-high-level-behaviors/src/main/java/us/ihmc/behaviors/behaviorTree/BehaviorTreeRootNodeExecutor.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import gnu.trove.map.hash.TLongObjectHashMap;
44
import org.apache.logging.log4j.Level;
5+
import us.ihmc.behaviors.reasoning.BehaviorTreeNextActionReasoning;
56
import us.ihmc.behaviors.sequence.ActionNodeExecutor;
67
import us.ihmc.behaviors.sequence.ActionNodeState;
78
import us.ihmc.behaviors.sequence.FallbackNodeExecutor;
@@ -23,6 +24,7 @@ public class BehaviorTreeRootNodeExecutor extends BehaviorTreeNodeExecutor<Behav
2324
private final List<LeafNodeExecutor<?, ?>> failedLeaves = new ArrayList<>();
2425
private final List<LeafNodeExecutor<?, ?>> successfulLeaves = new ArrayList<>();
2526
private final List<LeafNodeExecutor<?, ?>> failedLeavesWithoutFallback = new ArrayList<>();
27+
private final BehaviorTreeNextActionReasoning nextActionReasoning = new BehaviorTreeNextActionReasoning();
2628

2729
public BehaviorTreeRootNodeExecutor(long id, CRDTInfo crdtInfo, WorkspaceResourceDirectory saveFileDirectory)
2830
{
@@ -241,7 +243,9 @@ private void executeNextLeaf()
241243
leafToExecute.update();
242244
leafToExecute.triggerExecution();
243245
currentlyExecutingLeaves.add(leafToExecute);
244-
state.stepForwardNextExecutionIndex();
246+
int nextExecutionIndex = nextActionReasoning.queryNextLeafToExecuteIndex(state);
247+
state.setExecutionNextIndex(nextExecutionIndex);
248+
// state.stepForwardNextExecutionIndex();
245249
}
246250

247251
private boolean shouldExecuteNextLeaf()
@@ -299,7 +303,15 @@ public boolean isEndOfSequence()
299303
{
300304
return state.getExecutionNextIndex() >= orderedLeaves.size();
301305
}
302-
306+
307+
@Override
308+
public void destroy()
309+
{
310+
super.destroy();
311+
312+
nextActionReasoning.destroy();
313+
}
314+
303315
public TLongObjectHashMap<BehaviorTreeNodeExecutor<?, ?>> getIDToNodeMap()
304316
{
305317
return idToNodeMap;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package us.ihmc.behaviors.reasoning;
2+
3+
import us.ihmc.behaviors.behaviorTree.BehaviorTreeNodeState;
4+
import us.ihmc.behaviors.behaviorTree.BehaviorTreeRootNodeState;
5+
import us.ihmc.behaviors.sequence.ActionSequenceState;
6+
import us.ihmc.behaviors.sequence.LeafNodeState;
7+
import us.ihmc.log.LogTools;
8+
9+
public class BehaviorTreeLLMEncoding
10+
{
11+
public static String encode(BehaviorTreeRootNodeState rootNode)
12+
{
13+
StringBuilder builder = new StringBuilder();
14+
15+
builder.append("nodes: [\n");
16+
17+
encodeTree(rootNode, builder, 0);
18+
19+
builder.append(" ],%nstate: { execution_next_index: %d }".formatted(rootNode.getExecutionNextIndex()));
20+
21+
return builder.toString();
22+
}
23+
24+
private static void encodeTree(BehaviorTreeNodeState<?> node, StringBuilder builder, int indent)
25+
{
26+
builder.append("\t".repeat(indent));
27+
28+
if (node instanceof LeafNodeState<?> leafNode)
29+
{
30+
builder.append("{ type: leaf, index: %d, is_executing: %b, failed: %b, can_execute: %b }"
31+
.formatted(leafNode.getLeafIndex(),
32+
leafNode.getIsExecuting(),
33+
leafNode.getFailed(),
34+
leafNode.getCanExecute()));
35+
}
36+
else if (node instanceof ActionSequenceState sequenceNode)
37+
{
38+
builder.append("{ type: sequence, children: [\n");
39+
40+
for (BehaviorTreeNodeState<?> child : node.getChildren())
41+
{
42+
encodeTree(child, builder, indent + 1);
43+
builder.append("\n");
44+
}
45+
46+
builder.append("\t".repeat(indent));
47+
48+
builder.append("]");
49+
builder.append(" }");
50+
}
51+
else
52+
{
53+
LogTools.error("Implement node type: " + node.getClass().getSimpleName());
54+
55+
for (BehaviorTreeNodeState<?> child : node.getChildren())
56+
{
57+
encodeTree(child, builder, indent);
58+
}
59+
}
60+
61+
}
62+
}

ihmc-high-level-behaviors/src/main/java/us/ihmc/behaviors/reasoning/BehaviorTreeNextActionReasoning.java

+75-69
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import de.kherud.llama.LlamaModel;
55
import de.kherud.llama.ModelParameters;
66
import de.kherud.llama.args.MiroStat;
7+
import us.ihmc.behaviors.behaviorTree.BehaviorTreeRootNodeState;
78
import us.ihmc.commons.time.Stopwatch;
89
import us.ihmc.log.LogTools;
910
import us.ihmc.tools.IHMCCommonPaths;
@@ -13,79 +14,66 @@ public class BehaviorTreeNextActionReasoning
1314
/**
 * Prompt prefix sent before every query: a system message describing the tree encoding and the
 * decision rule, followed by three worked user/assistant examples. Turns are delimited with
 * {@code <|start_header_id|>}/{@code <|eot_id|>} markers.
 *
 * <p>The example trees are written exactly the way {@link BehaviorTreeLLMEncoding} emits them:
 * leaf lines are tab-indented (via the {@code \t} escape) and closed by a single brace.
 */
private static final String SYSTEM = """
      <|start_header_id|>system<|end_header_id|>
      You are a reasoning system that decides the next action to execute in a tree-based robotic system.
      The following is a schema for how the tree will be represented for a query.
      There is a tree of nodes, where each node's type can be leaf or sequence.
      A sequence node has 0 or more children nodes.
      A leaf node does not have any children.
      The leaves are depth-first ordered and their position in this ordering is given by the index field.
      Each leaf node also has boolean fields for whether it is currently executing, has failed, and can execute.
      The state portion of the schema gives the global state of the tree.
      The state has a field called execution next index, which is the index of the next node to execute.
      nodes: [
      { type: sequence, children: [
      \t{ type: leaf, index: int, is_executing: bool, failed: bool, can_execute: bool }
      ] } ],
      state: { execution_next_index: int }
      A sequence node defines the order of execution of the children as one after the other.
      The next node to execute should be the one after the last one that is executing.
      If no nodes are executing, the next node to execute should remain unchanged.
      Your task is to decide the next leaf to execute by providing its index.
      <|eot_id|>
      <|start_header_id|>user<|end_header_id|>
      nodes: [
      { type: sequence, children: [
      \t{ type: leaf, index: 0, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 1, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 2, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 3, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 4, is_executing: false, failed: false, can_execute: true }
      ] } ],
      state: { execution_next_index: 0 }
      <|eot_id|>
      <|start_header_id|>assistant<|end_header_id|>
      0
      <|eot_id|>
      <|start_header_id|>user<|end_header_id|>
      nodes: [
      { type: sequence, children: [
      \t{ type: leaf, index: 0, is_executing: true, failed: false, can_execute: true }
      \t{ type: leaf, index: 1, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 2, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 3, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 4, is_executing: false, failed: false, can_execute: true }
      ] } ],
      state: { execution_next_index: 0 }
      <|eot_id|>
      <|start_header_id|>assistant<|end_header_id|>
      1
      <|eot_id|>
      <|start_header_id|>user<|end_header_id|>
      nodes: [
      { type: sequence, children: [
      \t{ type: leaf, index: 0, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 1, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 2, is_executing: true, failed: false, can_execute: true }
      \t{ type: leaf, index: 3, is_executing: false, failed: false, can_execute: true }
      \t{ type: leaf, index: 4, is_executing: false, failed: false, can_execute: true }
      ] } ],
      state: { execution_next_index: 2 }
      <|eot_id|>
      <|start_header_id|>assistant<|end_header_id|>
      3
      <|eot_id|>
      """;
9078

9179

@@ -105,12 +93,21 @@ public BehaviorTreeNextActionReasoning()
10593
model = new LlamaModel(modelParams);
10694
}
10795

108-
public int queryNextLeafToExecuteIndex()
96+
public int queryNextLeafToExecuteIndex(BehaviorTreeRootNodeState rootNode)
97+
{
98+
String treeEncoding = BehaviorTreeLLMEncoding.encode(rootNode);
99+
return queryNextLeafToExecuteIndex(treeEncoding);
100+
}
101+
102+
public int queryNextLeafToExecuteIndex(String treeEncoding)
109103
{
110104
String prompt = SYSTEM;
111-
// prompt += """
112-
// Hello!
113-
// """;
105+
prompt += """
106+
<|start_header_id|>user<|end_header_id|>
107+
%s
108+
<|eot_id|>
109+
<|start_header_id|>assistant<|end_header_id|>
110+
""".formatted(treeEncoding);
114111

115112
InferenceParameters inferParams = new InferenceParameters(prompt);
116113
inferParams.setPenalizeNl(true);
@@ -123,13 +120,12 @@ public int queryNextLeafToExecuteIndex()
123120

124121
String reponse = model.complete(inferParams);
125122

126-
// LogTools.info(prompt + reponse);
127-
//
128-
// LogTools.info("Response: {}", reponse);
123+
LogTools.info(prompt + reponse);
129124

130125
return Integer.parseInt(reponse.trim());
131126
}
132127

128+
// FIXME: Doesn't work yet
133129
public void destroy()
134130
{
135131
model.close();
@@ -142,7 +138,17 @@ public static void main(String[] args)
142138
for (int i = 0; i < 10; i++)
143139
{
144140
Stopwatch stopwatch = new Stopwatch().start();
145-
int leafIndex = reasoning.queryNextLeafToExecuteIndex();
141+
int leafIndex = reasoning.queryNextLeafToExecuteIndex("""
142+
nodes: [
143+
{ type: sequence, children: [
144+
{ type: leaf, index: 0, is_executing: false, failed: false, can_execute: true } }
145+
{ type: leaf, index: 1, is_executing: false, failed: false, can_execute: true } }
146+
{ type: leaf, index: 2, is_executing: false, failed: false, can_execute: true } }
147+
{ type: leaf, index: 3, is_executing: true, failed: false, can_execute: true } }
148+
{ type: leaf, index: 4, is_executing: false, failed: false, can_execute: true } }
149+
] } ],
150+
state: { execution_next_index: 2 }
151+
""");
146152
LogTools.info("Returned {} in {} seconds", leafIndex, stopwatch.totalElapsed());
147153
}
148154

0 commit comments

Comments
 (0)