Skip to content

Commit 397ae19

Browse files
committed
Add JSON Schema integration tests
1 parent 5ebb768 commit 397ae19

File tree

1 file changed

+248
-22
lines changed

1 file changed

+248
-22
lines changed

FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift

Lines changed: 248 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,34 @@ struct SchemaTests {
7373
#expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)")
7474
}
7575

76+
@Test(arguments: InstanceConfig.allConfigs)
77+
func generateContentJSONSchemaItems(_ config: InstanceConfig) async throws {
78+
let model = FirebaseAI.componentInstance(config).generativeModel(
79+
modelName: ModelNames.gemini2_5_FlashLite,
80+
generationConfig: GenerationConfig(
81+
responseMIMEType: "application/json",
82+
responseJSONSchema: [
83+
"type": .string("array"),
84+
"description": .string("A list of city names"),
85+
"items": .object([
86+
"type": .string("string"),
87+
"description": .string("The name of the city"),
88+
]),
89+
"minItems": .number(3),
90+
"maxItems": .number(5),
91+
]
92+
),
93+
safetySettings: safetySettings
94+
)
95+
let prompt = "What are the biggest cities in Canada?"
96+
let response = try await model.generateContent(prompt)
97+
let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
98+
let jsonData = try #require(text.data(using: .utf8))
99+
let decodedJSON = try JSONDecoder().decode([String].self, from: jsonData)
100+
#expect(decodedJSON.count >= 3, "Expected at least 3 cities, but got \(decodedJSON.count)")
101+
#expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)")
102+
}
103+
76104
@Test(arguments: InstanceConfig.allConfigs)
77105
func generateContentSchemaNumberRange(_ config: InstanceConfig) async throws {
78106
let model = FirebaseAI.componentInstance(config).generativeModel(
@@ -96,14 +124,41 @@ struct SchemaTests {
96124
#expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)")
97125
}
98126

127+
@Test(arguments: InstanceConfig.allConfigs)
128+
func generateContentJSONSchemaNumberRange(_ config: InstanceConfig) async throws {
129+
let model = FirebaseAI.componentInstance(config).generativeModel(
130+
modelName: ModelNames.gemini2_5_FlashLite,
131+
generationConfig: GenerationConfig(
132+
responseMIMEType: "application/json",
133+
responseJSONSchema: [
134+
"type": .string("integer"),
135+
"description": .string("A number"),
136+
"minimum": .number(110),
137+
"maximum": .number(120),
138+
]
139+
),
140+
safetySettings: safetySettings
141+
)
142+
let prompt = "Give me a number"
143+
144+
let response = try await model.generateContent(prompt)
145+
146+
let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
147+
let jsonData = try #require(text.data(using: .utf8))
148+
let decodedNumber = try JSONDecoder().decode(Double.self, from: jsonData)
149+
#expect(decodedNumber >= 110.0, "Expected a number >= 110, but got \(decodedNumber)")
150+
#expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)")
151+
}
152+
153+
private struct ProductInfo: Codable {
154+
let productName: String
155+
let rating: Int
156+
let price: Double
157+
let salePrice: Float
158+
}
159+
99160
@Test(arguments: InstanceConfig.allConfigs)
100161
func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws {
101-
struct ProductInfo: Codable {
102-
let productName: String
103-
let rating: Int // Will correspond to .integer in schema
104-
let price: Double // Will correspond to .double in schema
105-
let salePrice: Float // Will correspond to .float in schema
106-
}
107162
let model = FirebaseAI.componentInstance(config).generativeModel(
108163
modelName: ModelNames.gemini2FlashLite,
109164
generationConfig: GenerationConfig(
@@ -150,28 +205,95 @@ struct SchemaTests {
150205
}
151206

152207
@Test(arguments: InstanceConfig.allConfigs)
153-
func generateContentAnyOfSchema(_ config: InstanceConfig) async throws {
154-
struct MailingAddress: Decodable {
155-
let streetAddress: String
156-
let city: String
208+
func generateContentJSONSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws {
209+
let model = FirebaseAI.componentInstance(config).generativeModel(
210+
modelName: ModelNames.gemini2_5_FlashLite,
211+
generationConfig: GenerationConfig(
212+
responseMIMEType: "application/json",
213+
responseJSONSchema: [
214+
"type": .string("object"),
215+
"title": .string("ProductInfo"),
216+
"properties": .object([
217+
"productName": .object([
218+
"type": .string("string"),
219+
"description": .string("The name of the product"),
220+
]),
221+
"price": .object([
222+
"type": .string("number"),
223+
"description": .string("A price"),
224+
"minimum": .number(10.00),
225+
"maximum": .number(120.00),
226+
]),
227+
"salePrice": .object([
228+
"type": .string("number"),
229+
"description": .string("A sale price"),
230+
"minimum": .number(5.00),
231+
"maximum": .number(90.00),
232+
]),
233+
"rating": .object([
234+
"type": .string("integer"),
235+
"description": .string("A rating"),
236+
"minimum": .number(1),
237+
"maximum": .number(5),
238+
]),
239+
]),
240+
"required": .array([
241+
.string("productName"),
242+
.string("price"),
243+
.string("salePrice"),
244+
.string("rating"),
245+
]),
246+
"propertyOrdering": .array([
247+
.string("salePrice"),
248+
.string("rating"),
249+
.string("price"),
250+
.string("productName"),
251+
]),
252+
]
253+
),
254+
safetySettings: safetySettings
255+
)
256+
let prompt = "Describe a premium wireless headphone, including a user rating and price."
157257

158-
// Canadian-specific
159-
let province: String?
160-
let postalCode: String?
258+
let response = try await model.generateContent(prompt)
161259

162-
// U.S.-specific
163-
let state: String?
164-
let zipCode: String?
260+
let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
261+
let jsonData = try #require(text.data(using: .utf8))
262+
let decodedProduct = try JSONDecoder().decode(ProductInfo.self, from: jsonData)
263+
let price = decodedProduct.price
264+
let salePrice = decodedProduct.salePrice
265+
let rating = decodedProduct.rating
266+
#expect(price >= 10.0, "Expected a price >= 10.00, but got \(price)")
267+
#expect(price <= 120.0, "Expected a price <= 120.00, but got \(price)")
268+
#expect(salePrice >= 5.0, "Expected a salePrice >= 5.00, but got \(salePrice)")
269+
#expect(salePrice <= 90.0, "Expected a salePrice <= 90.00, but got \(salePrice)")
270+
#expect(rating >= 1, "Expected a rating >= 1, but got \(rating)")
271+
#expect(rating <= 5, "Expected a rating <= 5, but got \(rating)")
272+
}
273+
274+
private struct MailingAddress: Decodable {
275+
let streetAddress: String
276+
let city: String
277+
278+
// Canadian-specific
279+
let province: String?
280+
let postalCode: String?
165281

166-
var isCanadian: Bool {
167-
return province != nil && postalCode != nil && state == nil && zipCode == nil
168-
}
282+
// U.S.-specific
283+
let state: String?
284+
let zipCode: String?
169285

170-
var isAmerican: Bool {
171-
return province == nil && postalCode == nil && state != nil && zipCode != nil
172-
}
286+
var isCanadian: Bool {
287+
return province != nil && postalCode != nil && state == nil && zipCode == nil
173288
}
174289

290+
var isAmerican: Bool {
291+
return province == nil && postalCode == nil && state != nil && zipCode != nil
292+
}
293+
}
294+
295+
@Test(arguments: InstanceConfig.allConfigs)
296+
func generateContentAnyOfSchema(_ config: InstanceConfig) async throws {
175297
let streetSchema = Schema.string(description:
176298
"The civic number and street name, for example, '123 Main Street'.")
177299
let citySchema = Schema.string(description: "The name of the city.")
@@ -232,4 +354,108 @@ struct SchemaTests {
232354
"Expected Canadian Queen's University address, got \(queensAddress)."
233355
)
234356
}
357+
358+
@Test(arguments: InstanceConfig.allConfigs)
359+
func generateContentAnyOfJSONSchema(_ config: InstanceConfig) async throws {
360+
let streetSchema: JSONValue = .object([
361+
"type": .string("string"),
362+
"description": .string("The civic number and street name, for example, '123 Main Street'."),
363+
])
364+
let citySchema: JSONValue = .object([
365+
"type": .string("string"),
366+
"description": .string("The name of the city."),
367+
])
368+
let canadianAddressSchema: JSONObject = [
369+
"type": .string("object"),
370+
"description": .string("A Canadian mailing address"),
371+
"properties": .object([
372+
"streetAddress": streetSchema,
373+
"city": citySchema,
374+
"province": .object([
375+
"type": .string("string"),
376+
"description": .string(
377+
"The 2-letter province or territory code, for example, 'ON', 'QC', or 'NU'."
378+
),
379+
]),
380+
"postalCode": .object([
381+
"type": .string("string"),
382+
"description": .string("The postal code, for example, 'A1A 1A1'."),
383+
]),
384+
]),
385+
"required": .array([
386+
.string("streetAddress"),
387+
.string("city"),
388+
.string("province"),
389+
.string("postalCode"),
390+
]),
391+
]
392+
let americanAddressSchema: JSONObject = [
393+
"type": .string("object"),
394+
"description": .string("A U.S. mailing address"),
395+
"properties": .object([
396+
"streetAddress": streetSchema,
397+
"city": citySchema,
398+
"state": .object([
399+
"type": .string("string"),
400+
"description": .string(
401+
"The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'."
402+
),
403+
]),
404+
"zipCode": .object([
405+
"type": .string("string"),
406+
"description": .string("The 5-digit ZIP code, for example, '12345'."),
407+
]),
408+
]),
409+
"required": .array([
410+
.string("streetAddress"),
411+
.string("city"),
412+
.string("state"),
413+
.string("zipCode"),
414+
]),
415+
]
416+
let model = FirebaseAI.componentInstance(config).generativeModel(
417+
modelName: ModelNames.gemini2_5_Flash,
418+
generationConfig: GenerationConfig(
419+
temperature: 0.0,
420+
topP: 0.0,
421+
topK: 1,
422+
responseMIMEType: "application/json",
423+
responseJSONSchema: [
424+
"type": .string("array"),
425+
"items": .object([
426+
"anyOf": .array([
427+
.object(canadianAddressSchema),
428+
.object(americanAddressSchema),
429+
]),
430+
]),
431+
]
432+
),
433+
safetySettings: safetySettings
434+
)
435+
let prompt = """
436+
What are the mailing addresses for the University of Waterloo, UC Berkeley and Queen's U?
437+
"""
438+
439+
let response = try await model.generateContent(prompt)
440+
441+
let text = try #require(response.text)
442+
let jsonData = try #require(text.data(using: .utf8))
443+
let decodedAddresses = try JSONDecoder().decode([MailingAddress].self, from: jsonData)
444+
try #require(decodedAddresses.count == 3, "Expected 3 JSON addresses, got \(text).")
445+
let waterlooAddress = decodedAddresses[0]
446+
#expect(
447+
waterlooAddress.isCanadian,
448+
"Expected Canadian University of Waterloo address, got \(waterlooAddress)."
449+
)
450+
let berkeleyAddress = decodedAddresses[1]
451+
#expect(
452+
berkeleyAddress.isAmerican,
453+
"Expected American UC Berkeley address, got \(berkeleyAddress)."
454+
)
455+
let queensAddress = decodedAddresses[2]
456+
#expect(
457+
queensAddress.isCanadian,
458+
"Expected Canadian Queen's University address, got \(queensAddress)."
459+
)
460+
}
235461
}

0 commit comments

Comments
 (0)