Skip to content

Commit c28ac5d

Browse files
committed
fix remove_type
1 parent e732ca9 commit c28ac5d

File tree

4 files changed

+72
-17
lines changed

4 files changed

+72
-17
lines changed

crates/emmylua_code_analysis/src/db_index/type/type_ops/remove_type.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,55 +9,55 @@ pub fn remove_type(db: &DbIndex, source: LuaType, removed_type: LuaType) -> Opti
99
}
1010
}
1111

12-
let source = get_real_type(db, &source).unwrap_or(&source);
12+
let real_type = get_real_type(db, &source).unwrap_or(&source);
1313

1414
match &removed_type {
1515
LuaType::Nil => {
16-
if source.is_nil() {
16+
if real_type.is_nil() {
1717
return None;
1818
}
1919
}
2020
LuaType::Boolean => {
21-
if source.is_boolean() {
21+
if real_type.is_boolean() {
2222
return None;
2323
}
2424
}
2525
LuaType::Integer => {
26-
if source.is_integer() {
26+
if real_type.is_integer() {
2727
return None;
2828
}
2929
}
3030
LuaType::Number => {
31-
if source.is_number() {
31+
if real_type.is_number() {
3232
return None;
3333
}
3434
}
3535
LuaType::String => {
36-
if source.is_string() {
36+
if real_type.is_string() {
3737
return None;
3838
}
3939
}
4040
LuaType::Io => {
41-
if source.is_io() {
41+
if real_type.is_io() {
4242
return None;
4343
}
4444
}
4545
LuaType::Function => {
46-
if source.is_function() {
46+
if real_type.is_function() {
4747
return None;
4848
}
4949
}
5050
LuaType::Thread => {
51-
if source.is_thread() {
51+
if real_type.is_thread() {
5252
return None;
5353
}
5454
}
5555
LuaType::Userdata => {
56-
if source.is_userdata() {
56+
if real_type.is_userdata() {
5757
return None;
5858
}
5959
}
60-
LuaType::Table => match &source {
60+
LuaType::Table => match &real_type {
6161
LuaType::TableConst(_)
6262
| LuaType::Table
6363
| LuaType::Userdata
@@ -74,7 +74,7 @@ pub fn remove_type(db: &DbIndex, source: LuaType, removed_type: LuaType) -> Opti
7474
return Some(source.clone());
7575
}
7676
if type_decl.is_alias() {
77-
if let Some(alias_ref) = get_real_type(db, &source) {
77+
if let Some(alias_ref) = get_real_type(db, &real_type) {
7878
return remove_type(db, alias_ref.clone(), removed_type);
7979
}
8080
}
@@ -91,7 +91,7 @@ pub fn remove_type(db: &DbIndex, source: LuaType, removed_type: LuaType) -> Opti
9191
}
9292
_ => {}
9393
},
94-
LuaType::DocStringConst(s) | LuaType::StringConst(s) => match &source {
94+
LuaType::DocStringConst(s) | LuaType::StringConst(s) => match &real_type {
9595
LuaType::DocStringConst(s2) => {
9696
if s == s2 {
9797
return None;
@@ -104,7 +104,7 @@ pub fn remove_type(db: &DbIndex, source: LuaType, removed_type: LuaType) -> Opti
104104
}
105105
_ => {}
106106
},
107-
LuaType::DocIntegerConst(i) | LuaType::IntegerConst(i) => match &source {
107+
LuaType::DocIntegerConst(i) | LuaType::IntegerConst(i) => match &real_type {
108108
LuaType::DocIntegerConst(i2) => {
109109
if i == i2 {
110110
return None;
@@ -117,7 +117,7 @@ pub fn remove_type(db: &DbIndex, source: LuaType, removed_type: LuaType) -> Opti
117117
}
118118
_ => {}
119119
},
120-
LuaType::DocBooleanConst(b) | LuaType::BooleanConst(b) => match &source {
120+
LuaType::DocBooleanConst(b) | LuaType::BooleanConst(b) => match &real_type {
121121
LuaType::DocBooleanConst(b2) => {
122122
if b == b2 {
123123
return None;
@@ -133,7 +133,7 @@ pub fn remove_type(db: &DbIndex, source: LuaType, removed_type: LuaType) -> Opti
133133
_ => {}
134134
}
135135

136-
if let LuaType::Union(u) = &source {
136+
if let LuaType::Union(u) = &real_type {
137137
let types = u
138138
.into_vec()
139139
.iter()
@@ -144,7 +144,7 @@ pub fn remove_type(db: &DbIndex, source: LuaType, removed_type: LuaType) -> Opti
144144
let types = u
145145
.into_vec()
146146
.iter()
147-
.filter_map(|t| remove_type(db, source.clone(), t.clone()))
147+
.filter_map(|t| remove_type(db, real_type.clone(), t.clone()))
148148
.collect::<Vec<_>>();
149149
return Some(LuaType::from_vec(types));
150150
}

crates/emmylua_code_analysis/src/db_index/type/type_ops/union_type.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,28 @@ pub fn union_type(source: LuaType, target: LuaType) -> LuaType {
4848
LuaType::from_vec(vec![source.clone(), target.clone()])
4949
}
5050
}
51+
(LuaType::MultiLineUnion(left), right) => {
52+
let include = match right {
53+
LuaType::StringConst(v) => {
54+
left.get_unions().iter().any(|(t, _)| match (t, right) {
55+
(LuaType::DocStringConst(a), _) => a == v,
56+
_ => false,
57+
})
58+
}
59+
LuaType::IntegerConst(v) => {
60+
left.get_unions().iter().any(|(t, _)| match (t, right) {
61+
(LuaType::DocIntegerConst(a), _) => a == v,
62+
_ => false,
63+
})
64+
}
65+
_ => false,
66+
};
67+
68+
if include {
69+
return source;
70+
}
71+
LuaType::from_vec(vec![source, target])
72+
}
5173
// union
5274
(LuaType::Union(left), right) if !right.is_union() => {
5375
let left = left.deref().clone();

crates/emmylua_code_analysis/src/db_index/type/types.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,10 @@ impl LuaType {
427427
matches!(self, LuaType::TypeGuard(_))
428428
}
429429

430+
pub fn is_multi_line_union(&self) -> bool {
431+
matches!(self, LuaType::MultiLineUnion(_))
432+
}
433+
430434
pub fn from_vec(types: Vec<LuaType>) -> Self {
431435
return match types.len() {
432436
0 => LuaType::Nil,

crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,4 +1279,33 @@ mod test {
12791279
"#
12801280
));
12811281
}
1282+
1283+
#[test]
1284+
fn test_alias_branch_label_flow() {
1285+
let mut ws = VirtualWorkspace::new();
1286+
ws.def_file(
1287+
"test.lua",
1288+
r#"
1289+
---@alias EditorAttrTypeAlias
1290+
---| 'ATTR_BASE'
1291+
---| 'ATTR_BASE_RATIO'
1292+
---| 'ATTR_ALL_RATIO'
1293+
1294+
---@param attr_element string
1295+
function test(attr_element) end
1296+
"#,
1297+
);
1298+
1299+
assert!(ws.check_code_for_namespace(
1300+
DiagnosticCode::ParamTypeNotMatch,
1301+
r#"
1302+
---@param attr_type EditorAttrTypeAlias
1303+
function add_attr(attr_type)
1304+
if attr_type ~= 'ATTR_BASE' then
1305+
end
1306+
test(attr_type)
1307+
end
1308+
"#
1309+
));
1310+
}
12821311
}

0 commit comments

Comments
 (0)