Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use std::sync::Arc;

use emmylua_parser::{
LuaAst, LuaAstNode, LuaDocAttributeType, LuaDocBinaryType, LuaDocConditionalType,
LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericDecl, LuaDocGenericType,
LuaDocIndexAccessType, LuaDocInferType, LuaDocMappedType, LuaDocMultiLineUnionType,
LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, LuaDocUnaryType,
LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator,
LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericDecl, LuaDocGenericDeclList,
LuaDocGenericType, LuaDocIndexAccessType, LuaDocInferType, LuaDocMappedType,
LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType,
LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator,
LuaTypeUnaryOperator, LuaVarExpr,
};
use internment::ArcIntern;
Expand Down Expand Up @@ -469,6 +469,10 @@ fn infer_unary_type(analyzer: &mut DocAnalyzer, unary_type: &LuaDocUnaryType) ->
}

fn infer_func_type(analyzer: &mut DocAnalyzer, func: &LuaDocFuncType) -> LuaType {
if let Some(generic_list) = func.get_generic_decl_list() {
register_inline_func_generics(analyzer, func, generic_list);
}

let mut params_result = Vec::new();
for param in func.get_params() {
let name = if let Some(param) = param.get_name_token() {
Expand Down Expand Up @@ -544,6 +548,33 @@ fn infer_func_type(analyzer: &mut DocAnalyzer, func: &LuaDocFuncType) -> LuaType
)
}

fn register_inline_func_generics(
analyzer: &mut DocAnalyzer,
func: &LuaDocFuncType,
generic_list: LuaDocGenericDeclList,
) {
let mut generics = Vec::new();
for param in generic_list.get_generic_decl() {
let Some(name_token) = param.get_name_token() else {
continue;
};

let constraint = param.get_type().map(|ty| infer_type(analyzer, ty));
generics.push(GenericParam::new(
SmolStr::new(name_token.get_name_text()),
constraint,
None,
));
}
if generics.is_empty() {
return;
}

analyzer
.generic_index
.add_generic_scope(vec![func.get_range()], generics, true);
}

fn get_colon_define(analyzer: &mut DocAnalyzer) -> Option<bool> {
let owner = analyzer.comment.get_owner()?;
if let LuaAst::LuaFuncStat(func_stat) = owner {
Expand Down
47 changes: 47 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/generic_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,4 +631,51 @@ mod test {
"#,
));
}

#[test]
fn test_issue_846() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@alias Parameters<T extends function> T extends (fun(...: infer P): any) and P or never

---@param x number
---@param y number
---@return number
function pow(x, y) end

---@generic F
---@param f F
---@return Parameters<F>
function return_params(f) end
"#,
);
assert!(ws.check_code_for(
DiagnosticCode::ParamTypeMismatch,
r#"
result = return_params(pow)
"#,
));
let result_ty = ws.expr_ty("result");
assert_eq!(ws.humanize_type(result_ty), "(number,number)");
}

#[test]
fn test_overload() {
let mut ws = VirtualWorkspace::new();

assert!(ws.check_code_for(
DiagnosticCode::ParamTypeMismatch,
r#"
---@class Expect
---@overload fun<T>(actual: T): T
local expect = {}

result = expect("")
"#,
));
let result_ty = ws.expr_ty("result");
assert_eq!(ws.humanize_type(result_ty), "string");
}
}
7 changes: 5 additions & 2 deletions crates/emmylua_code_analysis/src/db_index/type/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -717,8 +717,11 @@ impl LuaFunctionType {
{
return false;
}

semantic_model.type_check(owner_type, t).is_ok()
if semantic_model.type_check(owner_type, t).is_ok() {
return true;
}
// 如果名称是`self`, 则做更宽泛的检查
name == "self" && semantic_model.type_check(t, owner_type).is_ok()
}
None => name == "self",
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ fn check_table_expr_content(
let Some(member_key) = semantic_model.get_member_key(&field_key) else {
continue;
};

let source_type = match semantic_model.infer_member_type(table_type, &member_key) {
Ok(typ) => typ,
Err(_) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -966,4 +966,26 @@ return t
"#
));
}

#[test]
fn test_object_table() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias A {[string]: string}

---@param matchers A
function name(matchers)
end
"#,
);
assert!(!ws.check_code_for(
DiagnosticCode::AssignTypeMismatch,
r#"
name({
toBe = 1,
})
"#
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1355,4 +1355,67 @@ mod test {
"#,
));
}

#[test]
fn test_fix_issue_844() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias Tester fun(customTesters: Tester[]): boolean?

---@generic V
---@param t V[]
---@return fun(tbl: any):int, V
function ipairs(t) end
"#,
);
assert!(ws.check_code_for(
DiagnosticCode::ParamTypeMismatch,
r#"
---@param newTesters Tester[]
local function addMatchers(newTesters)
for _, tester in ipairs(newTesters) do
end
end
"#
));
}

#[test]
fn test_pairs_1() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@param value string
function aaaa(value)
end

---@generic K, V
---@param t {[K]: V} | V[]
---@return fun(tbl: any):K, V
function pairs(t) end
"#,
);
assert!(!ws.check_code_for(
DiagnosticCode::ParamTypeMismatch,
r#"
---@type {[string]: number}
local matchers = {}
for _, matcher in pairs(matchers) do
aaaa(matcher)
end
"#
));
assert!(!ws.check_code_for(
DiagnosticCode::ParamTypeMismatch,
r#"
---@alias MatchersObject {[string]: number}
---@type MatchersObject
local matchers = {}
for _, matcher in pairs(matchers) do
aaaa(matcher)
end
"#
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ pub fn instantiate_func_generic(
if let Some(type_list) = call_expr.get_call_generic_type_list() {
apply_call_generic_type_list(db, file_id, &mut context, &type_list);
} else {
// 没有指定泛型, 从调用参数中推断
infer_generic_types_from_call(
db,
&mut context,
Expand Down Expand Up @@ -160,7 +161,6 @@ fn infer_generic_types_from_call(
}

let arg_type = infer_expr(db, context.cache, call_arg_expr.clone())?;

match (func_param_type, &arg_type) {
(LuaType::Variadic(variadic), _) => {
let mut arg_types = vec![];
Expand Down
Loading