Skip to content

Commit 7b49503

Browse files
committed
Stronger model
1 parent d1140fa commit 7b49503

5 files changed

Lines changed: 426 additions & 66 deletions

File tree

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,48 @@
11
package com.embabel.grouper;
22

3+
import com.embabel.agent.api.common.autonomy.AgentInvocation;
4+
import com.embabel.agent.config.models.OpenAiModels;
35
import com.embabel.agent.core.AgentPlatform;
6+
import com.embabel.agent.core.Verbosity;
7+
import com.embabel.common.ai.model.LlmOptions;
8+
import com.embabel.grouper.agent.Domain;
9+
import com.embabel.grouper.agent.PromptedParticipant;
410
import org.springframework.shell.standard.ShellComponent;
11+
import org.springframework.shell.standard.ShellMethod;
12+
13+
import java.util.LinkedList;
14+
import java.util.List;
515

616
@ShellComponent
717
record GrouperShell(AgentPlatform agentPlatform) {
818

9-
// @ShellMethod("Demo")
10-
// String demo() {
11-
// // Illustrate calling an agent programmatically,
12-
// // as most often occurs in real applications.
13-
// var reviewedStory = AgentInvocation
14-
// .create(agentPlatform, WriteAndReviewAgent.ReviewedStory.class)
15-
// .invoke(new UserInput("Tell me a story about caterpillars"));
16-
// return reviewedStory.getContent();
17-
// }
19+
@ShellMethod("Demo")
20+
String demo() {
21+
var participants = new LinkedList<Domain.Participant>();
22+
var promptedParticipant = new PromptedParticipant(
23+
"Alice",
24+
LlmOptions.withModel(OpenAiModels.GPT_41_MINI),
25+
"""
26+
You are a 15 year old girl who lives in Richmond and loves Taylor Swift and tennis
27+
"""
28+
);
29+
participants.add(promptedParticipant);
30+
var focusGroup = new Domain.FocusGroup(participants);
31+
32+
var positioning = new Domain.Positioning(List.of(
33+
new Domain.MessageTest(
34+
new Domain.Message("nosmoke", "smoking is bad", "To deter the participant from wanting to smoke"),
35+
"Don't smoke as it will kill you",
36+
"Smoking is uncool",
37+
"Smoking will give you cancer",
38+
"Boys won't want to kiss you if you stink of cigarette smoke")
39+
));
40+
41+
var focusGroupRun = AgentInvocation.builder(agentPlatform)
42+
.options(o -> o.verbosity(new Verbosity(true, false, false, false)))
43+
.build(Domain.FocusGroupRun.class)
44+
.invoke(focusGroup, participants, positioning);
45+
return focusGroupRun.toString();
46+
}
1847

1948
}

src/main/java/com/embabel/grouper/agent/Domain.java

Lines changed: 133 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,56 @@
11
package com.embabel.grouper.agent;
22

3+
import com.embabel.common.ai.model.LlmOptions;
34
import com.embabel.common.ai.prompt.PromptContributor;
45

56
import java.time.Instant;
6-
import java.util.HashMap;
7-
import java.util.LinkedList;
8-
import java.util.List;
9-
import java.util.Map;
7+
import java.util.*;
108

119
public abstract class Domain {
1210

1311
/**
1412
* A message to be evaluated
1513
*
16-
* @param id id of the message, in case we have variants
17-
* @param content content of this instance of the message
14+
* @param id id of the message, in case we have variants
15+
* @param detail detail of this instance of the message
16+
* @param objective objective
1817
*/
1918
public record Message(
2019
String id,
21-
String content) {
20+
String detail,
21+
String objective) {
2222
}
2323

2424
/**
25-
* Map from id to Message
25+
* Expression of a message
2626
*
27-
* @param messaging
27+
* @param message
28+
* @param expression
2829
*/
29-
public record Positioning(Map<String, List<Message>> messaging) {
30+
public record MessageExpression(
31+
Message message,
32+
String expression
33+
) {
34+
}
35+
36+
public record MessageTest(
37+
Message message,
38+
List<MessageExpression> expressions
39+
) {
40+
41+
public MessageTest(Message message, String... expressions) {
42+
this(message,
43+
Arrays.stream(expressions).map(e -> new MessageExpression(message, e)).toList());
44+
}
45+
}
46+
47+
/**
48+
* Map from name to Message
49+
* This allows us to test multiple variants of the same
50+
* message name
51+
*
52+
*/
53+
public record Positioning(List<MessageTest> messageTests) {
3054
}
3155

3256
public enum Gender {
@@ -40,15 +64,17 @@ public interface Participant extends PromptContributor {
4064

4165
String name();
4266

43-
/***
44-
* A participant may run on multiple models
45-
*/
46-
String model();
67+
LlmOptions llm();
4768
}
4869

4970
public record FocusGroup(
50-
List<Participant> participants,
51-
Instant timestamp
71+
List<Participant> participants
72+
) {
73+
}
74+
75+
public record FocusGroupSubmission(
76+
FocusGroup focusGroup,
77+
Positioning positioning
5278
) {
5379
}
5480

@@ -67,11 +93,17 @@ public record Reaction(
6793
) {
6894
}
6995

96+
// public record MessageSubmission(
97+
// Message message,
98+
// Participant participant
99+
// ) {
100+
// }
101+
70102
/**
71103
* Reaction of one participant to a given message
72104
*/
73105
public record SpecificReaction(
74-
Message message,
106+
MessageExpression message,
75107
Participant participant,
76108
Reaction reaction,
77109
Instant timestamp
@@ -87,30 +119,111 @@ public record MessageScore(
87119
) {
88120
}
89121

122+
/**
123+
* Represents a combination of a participant and a message expression
124+
*/
125+
public record ParticipantMessagePresentation(
126+
Participant participant,
127+
MessageExpression messageExpression
128+
) {
129+
}
130+
131+
/**
132+
* Matrix tracking completion status for all participant/message expression combinations
133+
*/
134+
public record CompletionMatrix(
135+
Map<ParticipantMessagePresentation, Boolean> completionStatus
136+
) {
137+
public boolean isComplete() {
138+
return completionStatus.values().stream().allMatch(Boolean::booleanValue);
139+
}
140+
141+
public boolean hasReaction(Participant participant, MessageExpression messageExpression) {
142+
return completionStatus.getOrDefault(
143+
new ParticipantMessagePresentation(participant, messageExpression),
144+
false
145+
);
146+
}
147+
148+
public List<ParticipantMessagePresentation> getAllCombinations() {
149+
return List.copyOf(completionStatus.keySet());
150+
}
151+
152+
public List<ParticipantMessagePresentation> getCompletedCombinations() {
153+
return completionStatus.entrySet().stream()
154+
.filter(Map.Entry::getValue)
155+
.map(Map.Entry::getKey)
156+
.toList();
157+
}
158+
159+
public List<ParticipantMessagePresentation> getIncompleteCombinations() {
160+
return completionStatus.entrySet().stream()
161+
.filter(e -> !e.getValue())
162+
.map(Map.Entry::getKey)
163+
.toList();
164+
}
165+
}
166+
167+
/**
168+
* Built up as we return results
169+
*/
90170
public static class FocusGroupRun {
91171

92172
public final FocusGroup focusGroup;
93173

174+
public final Positioning positioning;
175+
94176
private final Map<Participant, List<SpecificReaction>> reactionsByParticipant = new HashMap<>();
95177

96-
public FocusGroupRun(FocusGroup focusGroup) {
178+
private final Map<ParticipantMessagePresentation, Boolean> matrixData = new HashMap<>();
179+
180+
public FocusGroupRun(
181+
FocusGroup focusGroup,
182+
Positioning positioning) {
97183
this.focusGroup = focusGroup;
184+
this.positioning = positioning;
185+
186+
// Initialize matrix with all combinations set to false
187+
var allMessageExpressions = positioning.messageTests().stream()
188+
.flatMap(mt -> mt.expressions().stream())
189+
.toList();
190+
191+
for (var participant : focusGroup.participants()) {
192+
for (MessageExpression messageExpression : allMessageExpressions) {
193+
matrixData.put(new ParticipantMessagePresentation(participant, messageExpression), false);
194+
}
195+
}
196+
}
197+
198+
public CompletionMatrix getCompletionMatrix() {
199+
return new CompletionMatrix(Map.copyOf(matrixData));
200+
}
201+
202+
public boolean isComplete() {
203+
return getCompletionMatrix().isComplete();
98204
}
99205

100206
public void record(SpecificReaction reaction) {
101207
reactionsByParticipant
102208
.computeIfAbsent(reaction.participant(), k -> new LinkedList<>())
103209
.add(reaction);
210+
211+
// Update matrix
212+
ParticipantMessagePresentation key = new ParticipantMessagePresentation(
213+
reaction.participant(),
214+
reaction.message()
215+
);
216+
matrixData.put(key, true);
104217
}
105218

106219
public List<SpecificReaction> getReactionsForParticipant(Participant participant) {
107220
return reactionsByParticipant.getOrDefault(participant, List.of());
108221
}
109222

110-
public MessageScore getAverageScoreForMessage(Message message) {
223+
public MessageScore getAverageScoreForMessage(MessageExpression messageExpression) {
111224
var reactions = reactionsByParticipant.values().stream()
112225
.flatMap(List::stream)
113-
.filter(r -> r.message().equals(message))
226+
.filter(r -> r.message().equals(messageExpression))
114227
.toList();
115228

116229
long count = reactions.size();
Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,78 @@
11
package com.embabel.grouper.agent;
22

3+
import com.embabel.agent.api.annotation.AchievesGoal;
4+
import com.embabel.agent.api.annotation.Action;
35
import com.embabel.agent.api.annotation.Agent;
6+
import com.embabel.agent.api.annotation.Condition;
7+
import com.embabel.agent.api.common.OperationContext;
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
10+
11+
import java.time.Instant;
412

513
@Agent(description = "Simulate a focus group")
614
class Grouper {
715

8-
// @Action
16+
private final Logger logger = LoggerFactory.getLogger(Grouper.class);
17+
18+
@Action
19+
Domain.FocusGroupRun createFocusGroupRun(
20+
Domain.FocusGroup focusGroup,
21+
Domain.Positioning positioning
22+
) {
23+
return new Domain.FocusGroupRun(focusGroup, positioning);
24+
}
25+
26+
@Action(post = {"done"})
27+
Domain.FocusGroupRun testMessages(
28+
Domain.FocusGroupRun focusGroupRun,
29+
OperationContext operationContext
30+
) {
31+
var combos = focusGroupRun.getCompletionMatrix().getAllCombinations();
32+
operationContext.parallelMap(
33+
combos,
34+
15,
35+
participantMessagePresentation ->
36+
testMessageExpressionWithParticipant(participantMessagePresentation, focusGroupRun, operationContext)
37+
);
38+
return focusGroupRun;
39+
}
40+
41+
Domain.SpecificReaction testMessageExpressionWithParticipant(
42+
Domain.ParticipantMessagePresentation messageTest,
43+
Domain.FocusGroupRun focusGroupRun,
44+
OperationContext operationContext) {
45+
var reaction = operationContext.ai()
46+
.withLlm(messageTest.participant().llm())
47+
.withPromptContributor(messageTest.participant())
48+
.creating(Domain.Reaction.class)
49+
.fromPrompt("""
50+
React to the following message given your persona:
51+
52+
<message>%s</message>
53+
54+
Assess in terms of whether it would produce the following objective in your mind:
55+
<objective>%s</objective>
56+
""".formatted(messageTest.messageExpression().expression(), messageTest.messageExpression().message().objective()));
57+
logger.info("Reaction of {} was {}", messageTest.participant(), reaction);
58+
return new Domain.SpecificReaction(
59+
messageTest.messageExpression(),
60+
messageTest.participant(),
61+
reaction,
62+
Instant.now()
63+
);
64+
}
65+
66+
@Condition
67+
boolean done(Domain.FocusGroupRun focusGroupRun) {
68+
return focusGroupRun.isComplete();
969

10-
// TODO must allow multiple LLMs
70+
}
1171

12-
// Total score?
72+
@Action(pre = {"done"})
73+
@AchievesGoal(description = "Focus group has considered positioning")
74+
Domain.FocusGroupRun results(Domain.FocusGroupRun focusGroupRun) {
75+
return focusGroupRun;
76+
}
1377

1478
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package com.embabel.grouper.agent;
2+
3+
import com.embabel.common.ai.model.LlmOptions;
4+
import org.jetbrains.annotations.NotNull;
5+
6+
/**
7+
* The identity is solely responsible for the participant's contribution.
8+
*/
9+
public record PromptedParticipant(
10+
String name,
11+
LlmOptions llm,
12+
String identity
13+
) implements Domain.Participant {
14+
15+
@NotNull
16+
@Override
17+
public String contribution() {
18+
return """
19+
NAME: %s
20+
IDENTITY:
21+
%s
22+
""".formatted(name, identity);
23+
}
24+
}

0 commit comments

Comments
 (0)