Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add support for int keys in maps #1180

Draft
wants to merge 8 commits into
base: canary
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use crate::{
};
use anyhow::Result;
use baml_types::{
BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue,
TypeValue,
BamlMap, BamlMapKey, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType,
LiteralValue, TypeValue,
};
pub use to_baml_arg::ArgCoercer;

Expand Down Expand Up @@ -186,7 +186,7 @@ impl IRHelper for IntermediateRepr {
if let Ok(baml_arg) =
coerce_settings.coerce_arg(self, param_type, param_value, &mut scope)
{
baml_arg_map.insert(param_name.to_string(), baml_arg);
baml_arg_map.insert(BamlMapKey::string(param_name), baml_arg);
}
} else {
// Check if the parameter is optional.
Expand Down Expand Up @@ -291,14 +291,15 @@ impl IRHelper for IntermediateRepr {
if !map_type.is_subtype_of(&field_type) {
anyhow::bail!("Could not unify {:?} with {:?}", map_type, field_type);
} else {
let mapped_fields: BamlMap<String, BamlValueWithMeta<FieldType>> =
pairs
let mapped_fields: BamlMap<BamlMapKey, BamlValueWithMeta<FieldType>> =
pairs
.into_iter()
.map(|(key, val)| {
let sub_value = self.distribute_type(val, item_type.clone())?;
let sub_value =
self.distribute_type(val, item_type.clone())?;
Ok((key, sub_value))
})
.collect::<anyhow::Result<BamlMap<String,BamlValueWithMeta<FieldType>>>>()?;
.collect::<anyhow::Result<_>>()?;
Ok(BamlValueWithMeta::Map(mapped_fields, field_type))
}
}
Expand Down Expand Up @@ -515,7 +516,11 @@ mod tests {
}

fn mk_map_1() -> BamlValue {
BamlValue::Map(vec![("a".to_string(), mk_int(1))].into_iter().collect())
BamlValue::Map(
vec![(BamlMapKey::string("a"), mk_int(1))]
.into_iter()
.collect(),
)
}

fn mk_ir() -> IntermediateRepr {
Expand Down Expand Up @@ -557,7 +562,7 @@ mod tests {
#[test]
fn infer_map_map() {
let my_map_map = BamlValue::Map(
vec![("map_a".to_string(), mk_map_1())]
vec![(BamlMapKey::string("map_a"), mk_map_1())]
.into_iter()
.collect(),
);
Expand Down Expand Up @@ -629,12 +634,12 @@ mod tests {
let map_1 = BamlValue::Map(
vec![
(
"1_string".to_string(),
BamlMapKey::string("1_string"),
BamlValue::String("1_string_value".to_string()),
),
("1_int".to_string(), mk_int(1)),
(BamlMapKey::string("1_int"), mk_int(1)),
(
"1_foo".to_string(),
BamlMapKey::string("1_foo"),
BamlValue::Class(
"Foo".to_string(),
vec![
Expand Down Expand Up @@ -719,9 +724,9 @@ mod tests {
// The compound value we want to test.
let map = BamlValue::Map(
vec![
("a".to_string(), list_1.clone()),
("b".to_string(), list_1),
("c".to_string(), list_2),
(BamlMapKey::string("a"), list_1.clone()),
(BamlMapKey::string("b"), list_1),
(BamlMapKey::string("c"), list_2),
]
.into_iter()
.collect(),
Expand Down
79 changes: 69 additions & 10 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use baml_types::{
BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue,
TypeValue,
BamlMap, BamlMapKey, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType,
LiteralValue, TypeValue,
};
use core::result::Result;
use std::path::PathBuf;
use std::{collections::VecDeque, path::PathBuf};

use crate::ir::IntermediateRepr;

Expand Down Expand Up @@ -84,7 +84,10 @@ impl ArgCoercer {
};

for key in kv.keys() {
if !vec!["file", "media_type"].contains(&key.as_str()) {
if !["file", "media_type"]
.map(BamlMapKey::string)
.contains(&key)
{
scope.push_error(format!(
"Invalid property `{}` on file {}: `media_type` is the only supported property",
key,
Expand Down Expand Up @@ -118,7 +121,7 @@ impl ArgCoercer {
None => None,
};
for key in kv.keys() {
if !vec!["url", "media_type"].contains(&key.as_str()) {
if !["url", "media_type"].map(BamlMapKey::string).contains(&key) {
scope.push_error(format!(
"Invalid property `{}` on url {}: `media_type` is the only supported property",
key,
Expand All @@ -143,7 +146,10 @@ impl ArgCoercer {
None => None,
};
for key in kv.keys() {
if !vec!["base64", "media_type"].contains(&key.as_str()) {
if !["base64", "media_type"]
.map(BamlMapKey::string)
.contains(&key)
{
scope.push_error(format!(
"Invalid property `{}` on base64 {}: `media_type` is the only supported property",
key,
Expand Down Expand Up @@ -215,7 +221,7 @@ impl ArgCoercer {
}),
(FieldType::Class(name), _) => match value {
BamlValue::Class(n, _) if n == name => Ok(value.clone()),
BamlValue::Class(_, obj) | BamlValue::Map(obj) => match ir.find_class(name) {
BamlValue::Class(_, obj) /*BamlValue::Map(obj)*/ => match ir.find_class(name) {
Ok(c) => {
let mut fields = BamlMap::new();

Expand Down Expand Up @@ -285,20 +291,73 @@ impl ArgCoercer {
(FieldType::Map(k, v), _) => {
if let BamlValue::Map(kv) = value {
let mut map = BamlMap::new();
let mut failed_parsing_int_err = None;

let mut is_union_of_literal_ints = false;

// TODO: Can we avoid this loop here? Won't hit performance
// by a lot unless the user defines a giant union.
if let FieldType::Union(items) = k.as_ref() {
let mut found_types_other_than_literal_ints = false;
let mut queue = VecDeque::from_iter(items.iter());
while let Some(item) = queue.pop_front() {
match item {
FieldType::Literal(LiteralValue::Int(_)) => continue,
FieldType::Union(nested) => queue.extend(nested.iter()),
_ => {
found_types_other_than_literal_ints = true;
break;
}
}
}
if !found_types_other_than_literal_ints {
is_union_of_literal_ints = true;
}
}

for (key, value) in kv {
scope.push("<key>".to_string());
let k = self.coerce_arg(ir, k, &BamlValue::String(key.clone()), scope);

let target_baml_key = if matches!(**k, FieldType::Primitive(TypeValue::Int))
|| is_union_of_literal_ints
{
let BamlMapKey::String(str_int) = key else {
todo!();
};

match str_int.parse::<i64>() {
Ok(i) => BamlValue::Int(i),
Err(e) => {
failed_parsing_int_err = Some(key);
// Stop here and let the code below return
// the error.
break;
}
}
} else {
BamlValue::String(key.to_string())
};

let coerced_key = self.coerce_arg(ir, k, &target_baml_key, scope);
scope.pop(false);

if k.is_ok() {
if coerced_key.is_ok() {
scope.push(key.to_string());
if let Ok(v) = self.coerce_arg(ir, v, value, scope) {
map.insert(key.clone(), v);
}
scope.pop(false);
}
}
Ok(BamlValue::Map(map))

if let Some(failed_int) = failed_parsing_int_err {
scope.push_error(format!(
"Expected int for map with int keys, got `{failed_int}`"
));
Err(())
} else {
Ok(BamlValue::Map(map))
}
} else {
scope.push_error(format!("Expected map, got `{}`", value));
Err(())
Expand Down
7 changes: 5 additions & 2 deletions engine/baml-lib/baml-core/src/ir/walker.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use baml_types::BamlValue;
use baml_types::{BamlMapKey, BamlValue};
use indexmap::IndexMap;

use internal_baml_parser_database::RetryPolicyStrategy;
Expand Down Expand Up @@ -237,7 +237,10 @@ impl Expression {
Expression::Map(m) => {
let mut map = baml_types::BamlMap::new();
for (k, v) in m {
map.insert(k.as_string_value(env_values)?, v.normalize(env_values)?);
map.insert(
BamlMapKey::String(k.as_string_value(env_values)?),
v.normalize(env_values)?,
);
}
Ok(BamlValue::Map(map))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,42 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) {
// Literal string key.
FieldType::Literal(FieldArity::Required, LiteralValue::String(_), ..) => {}

// Literal string union.
// Literal int key.
FieldType::Literal(FieldArity::Required, LiteralValue::Int(_), ..) => {}

// Literal union.
FieldType::Union(FieldArity::Required, items, ..) => {
let mut queue = VecDeque::from_iter(items.iter());

// Little hack to keep track of data types in the union with
// a single pass and no allocations. Unions that contain
// literals of different types are not allowed as map keys.
//
// TODO: Same code is used at `coerce_map` function in
// baml-lib/jsonish/src/deserializer/coercer/coerce_map.rs
//
// Should figure out how to reuse this.
let mut literal_types_found = [0, 0, 0];
let [strings, ints, bools] = &mut literal_types_found;

while let Some(item) = queue.pop_front() {
match item {
// Ok, literal string.
FieldType::Literal(
FieldArity::Required,
LiteralValue::String(_),
..,
) => {}
) => *strings += 1,

// Ok, literal int.
FieldType::Literal(FieldArity::Required, LiteralValue::Int(_), ..) => {
*ints += 1
}

// Ok, literal bool.
FieldType::Literal(FieldArity::Required, LiteralValue::Bool(_), ..) => {
*bools += 1
}

// Nested union, "recurse" but it's iterative.
FieldType::Union(FieldArity::Required, nested, ..) => {
Expand All @@ -101,6 +125,13 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) {
}
}
}

if literal_types_found.iter().filter(|&&t| t > 0).count() > 1 {
ctx.push_error(DatamodelError::new_validation_error(
"Unions in map keys may only contain literals of the same type.",
kv_types.0.span().clone(),
));
}
}

other => {
Expand Down
Loading
Loading