Skip to content

Commit 45e7904

Browse files
authored
Add support for content parts and image URLs in AI Guard (#10449)
1 parent 02cc483 commit 45e7904

File tree

8 files changed

+811
-28
lines changed

8 files changed

+811
-28
lines changed

communication/src/main/java/datadog/communication/serialization/custom/aiguard/MessageWriter.java

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,51 @@ public void write(
1414
final AIGuard.Message value, final Writable writable, final EncodingCache encodingCache) {
1515
final int[] size = {0};
1616
final boolean hasRole = isNotBlank(value.getRole(), size);
17-
final boolean hasContent = isNotBlank(value.getContent(), size);
1817
final boolean hasToolCallId = isNotBlank(value.getToolCallId(), size);
1918
final boolean hasToolCalls = isNotEmpty(value.getToolCalls(), size);
19+
20+
final boolean hasContentParts = isNotEmpty(value.getContentParts(), size);
21+
final boolean hasContentString = !hasContentParts && isNotBlank(value.getContent(), size);
22+
2023
writable.startMap(size[0]);
2124
writeString(hasRole, "role", value.getRole(), writable, encodingCache);
22-
writeString(hasContent, "content", value.getContent(), writable, encodingCache);
25+
26+
if (hasContentParts) {
27+
writeContentParts("content", value.getContentParts(), writable, encodingCache);
28+
} else {
29+
writeString(hasContentString, "content", value.getContent(), writable, encodingCache);
30+
}
31+
2332
writeString(hasToolCallId, "tool_call_id", value.getToolCallId(), writable, encodingCache);
2433
writeToolCallArray(hasToolCalls, "tool_calls", value.getToolCalls(), writable, encodingCache);
2534
}
2635

36+
private static void writeContentParts(
37+
final String key,
38+
final List<AIGuard.ContentPart> contentParts,
39+
final Writable writable,
40+
final EncodingCache encodingCache) {
41+
writable.writeString(key, encodingCache);
42+
writable.startArray(contentParts.size());
43+
44+
for (final AIGuard.ContentPart part : contentParts) {
45+
writable.startMap(2);
46+
47+
writable.writeString("type", encodingCache);
48+
writable.writeString(part.getType().toString(), encodingCache);
49+
50+
if (part.getType() == AIGuard.ContentPart.Type.TEXT) {
51+
writable.writeString("text", encodingCache);
52+
writable.writeString(part.getText(), encodingCache);
53+
} else if (part.getType() == AIGuard.ContentPart.Type.IMAGE_URL) {
54+
writable.writeString("image_url", encodingCache);
55+
writable.startMap(1);
56+
writable.writeString("url", encodingCache);
57+
writable.writeString(part.getImageUrl().getUrl(), encodingCache);
58+
}
59+
}
60+
}
61+
2762
private static void writeString(
2863
final boolean present,
2964
final String key,

communication/src/test/groovy/datadog/communication/serialization/aiguard/MessageWriterTest.groovy

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,127 @@ class MessageWriterTest extends DDSpecification {
116116
private static String asString(final Value value) {
117117
return value.asStringValue().asString()
118118
}
119+
120+
void 'test write message with text content parts'() {
121+
given:
122+
final message = AIGuard.Message.message('user', [
123+
AIGuard.ContentPart.text('Hello world')
124+
])
125+
126+
when:
127+
writer.writeObject(message, encodingCache)
128+
129+
then:
130+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
131+
final value = asStringKeyMap(unpacker.unpackValue())
132+
value.size() == 2
133+
asString(value.role) == 'user'
134+
135+
final contentParts = value.content.asArrayValue().list()
136+
contentParts.size() == 1
137+
138+
final part = asStringKeyMap(contentParts[0])
139+
asString(part.type) == 'text'
140+
asString(part.text) == 'Hello world'
141+
}
142+
}
143+
144+
void 'test write message with image_url content parts'() {
145+
given:
146+
final message = AIGuard.Message.message('user', [
147+
AIGuard.ContentPart.imageUrl('https://example.com/image.jpg')
148+
])
149+
150+
when:
151+
writer.writeObject(message, encodingCache)
152+
153+
then:
154+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
155+
final value = asStringKeyMap(unpacker.unpackValue())
156+
value.size() == 2
157+
asString(value.role) == 'user'
158+
159+
final contentParts = value.content.asArrayValue().list()
160+
contentParts.size() == 1
161+
162+
final part = asStringKeyMap(contentParts[0])
163+
asString(part.type) == 'image_url'
164+
165+
final imageUrl = asStringKeyMap(part.image_url)
166+
asString(imageUrl.url) == 'https://example.com/image.jpg'
167+
}
168+
}
169+
170+
void 'test write message with mixed content parts'() {
171+
given:
172+
final message = AIGuard.Message.message('user', [
173+
AIGuard.ContentPart.text('Describe this:'),
174+
AIGuard.ContentPart.imageUrl('https://example.com/image.jpg'),
175+
AIGuard.ContentPart.text('What is it?')
176+
])
177+
178+
when:
179+
writer.writeObject(message, encodingCache)
180+
181+
then:
182+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
183+
final value = asStringKeyMap(unpacker.unpackValue())
184+
value.size() == 2
185+
asString(value.role) == 'user'
186+
187+
final contentParts = value.content.asArrayValue().list()
188+
contentParts.size() == 3
189+
190+
final part1 = asStringKeyMap(contentParts[0])
191+
asString(part1.type) == 'text'
192+
asString(part1.text) == 'Describe this:'
193+
194+
final part2 = asStringKeyMap(contentParts[1])
195+
asString(part2.type) == 'image_url'
196+
final imageUrl = asStringKeyMap(part2.image_url)
197+
asString(imageUrl.url) == 'https://example.com/image.jpg'
198+
199+
final part3 = asStringKeyMap(contentParts[2])
200+
asString(part3.type) == 'text'
201+
asString(part3.text) == 'What is it?'
202+
}
203+
}
204+
205+
void 'test content parts type serializes as string not integer'() {
206+
given:
207+
final message = AIGuard.Message.message('user', [
208+
AIGuard.ContentPart.text('Test')
209+
])
210+
211+
when:
212+
writer.writeObject(message, encodingCache)
213+
214+
then:
215+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
216+
final value = asStringKeyMap(unpacker.unpackValue())
217+
final contentParts = value.content.asArrayValue().list()
218+
final part = asStringKeyMap(contentParts[0])
219+
220+
// Verify type is a string value, not an integer
221+
part.type.isStringValue()
222+
!part.type.isIntegerValue()
223+
asString(part.type) == 'text'
224+
}
225+
}
226+
227+
void 'test backward compatibility with string content'() {
228+
given:
229+
final message = AIGuard.Message.message('user', 'Plain text message')
230+
231+
when:
232+
writer.writeObject(message, encodingCache)
233+
234+
then:
235+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
236+
final value = asStringValueMap(unpacker.unpackValue())
237+
value.size() == 2
238+
value.role == 'user'
239+
value.content == 'Plain text message'
240+
}
241+
}
119242
}

dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import datadog.trace.api.aiguard.AIGuard.AIGuardAbortError;
1717
import datadog.trace.api.aiguard.AIGuard.AIGuardClientError;
1818
import datadog.trace.api.aiguard.AIGuard.Action;
19+
import datadog.trace.api.aiguard.AIGuard.ContentPart;
1920
import datadog.trace.api.aiguard.AIGuard.Evaluation;
2021
import datadog.trace.api.aiguard.AIGuard.Message;
2122
import datadog.trace.api.aiguard.AIGuard.Options;
@@ -136,16 +137,37 @@ private static List<Message> messagesForMetaStruct(List<Message> messages) {
136137
boolean contentTruncated = false;
137138
for (int i = messages.size() - size; i < messages.size(); i++) {
138139
final Message source = messages.get(i);
139-
String content = source.getContent();
140-
if (content != null && content.length() > maxContent) {
141-
contentTruncated = true;
142-
content = content.substring(0, maxContent);
143-
}
140+
144141
List<ToolCall> toolCalls = source.getToolCalls();
145142
if (toolCalls != null) {
146143
toolCalls = new ArrayList<>(toolCalls);
147144
}
148-
result.add(new Message(source.getRole(), content, toolCalls, source.getToolCallId()));
145+
146+
List<ContentPart> contentParts = source.getContentParts();
147+
if (contentParts != null) {
148+
final List<ContentPart> truncatedParts = new ArrayList<>(contentParts.size());
149+
for (final ContentPart part : contentParts) {
150+
if (part.getType() == ContentPart.Type.TEXT
151+
&& part.getText() != null
152+
&& part.getText().length() > maxContent) {
153+
contentTruncated = true;
154+
final String text = part.getText().substring(0, maxContent);
155+
truncatedParts.add(ContentPart.text(text));
156+
} else {
157+
truncatedParts.add(part);
158+
}
159+
}
160+
161+
result.add(
162+
new Message(source.getRole(), truncatedParts, toolCalls, source.getToolCallId()));
163+
} else {
164+
String content = source.getContent();
165+
if (content != null && content.length() > maxContent) {
166+
contentTruncated = true;
167+
content = content.substring(0, maxContent);
168+
}
169+
result.add(new Message(source.getRole(), content, toolCalls, source.getToolCallId()));
170+
}
149171
}
150172
if (contentTruncated) {
151173
WafMetricCollector.get().aiGuardTruncated(CONTENT);
@@ -333,12 +355,45 @@ public Message fromJson(JsonReader reader) throws IOException {
333355
public void toJson(final JsonWriter writer, final Message value) throws IOException {
334356
writer.beginObject();
335357
writeValue(writer, "role", value.getRole());
336-
writeValue(writer, "content", value.getContent());
358+
359+
if (value.getContentParts() != null) {
360+
writeContentParts(writer, "content", value.getContentParts());
361+
} else {
362+
writeValue(writer, "content", value.getContent());
363+
}
364+
337365
writeArray(writer, "tool_calls", value.getToolCalls());
338366
writeValue(writer, "tool_call_id", value.getToolCallId());
339367
writer.endObject();
340368
}
341369

370+
private void writeContentParts(
371+
final JsonWriter writer, final String name, final List<ContentPart> contentParts)
372+
throws IOException {
373+
writer.name(name);
374+
writer.beginArray();
375+
for (final ContentPart part : contentParts) {
376+
writer.beginObject();
377+
378+
writer.name("type");
379+
writer.value(part.getType().toString());
380+
381+
if (part.getType() == ContentPart.Type.TEXT) {
382+
writer.name("text");
383+
writer.value(part.getText());
384+
} else if (part.getType() == ContentPart.Type.IMAGE_URL) {
385+
writer.name("image_url");
386+
writer.beginObject();
387+
writer.name("url");
388+
writer.value(part.getImageUrl().getUrl());
389+
writer.endObject();
390+
}
391+
392+
writer.endObject();
393+
}
394+
writer.endArray();
395+
}
396+
342397
private void writeValue(final JsonWriter writer, final String name, final Object value)
343398
throws IOException {
344399
if (value != null) {

0 commit comments

Comments
 (0)