@@ -16,19 +16,22 @@ use std::env;
16
16
use std:: sync:: LazyLock ;
17
17
18
18
use base_db:: SourceDatabaseFileInputExt as _;
19
+ use either:: Either ;
19
20
use expect_test:: Expect ;
20
21
use hir_def:: {
21
22
body:: { Body , BodySourceMap } ,
22
23
db:: DefDatabase ,
23
24
hir:: { ExprId , Pat , PatId } ,
24
25
item_scope:: ItemScope ,
25
26
nameres:: DefMap ,
26
- src:: HasSource ,
27
- AssocItemId , DefWithBodyId , HasModule , LocalModuleId , Lookup , ModuleDefId , SyntheticSyntax ,
27
+ src:: { HasChildSource , HasSource } ,
28
+ AdtId , AssocItemId , DefWithBodyId , FieldId , HasModule , LocalModuleId , Lookup , ModuleDefId ,
29
+ SyntheticSyntax ,
28
30
} ;
29
31
use hir_expand:: { db:: ExpandDatabase , FileRange , InFile } ;
30
32
use itertools:: Itertools ;
31
33
use rustc_hash:: FxHashMap ;
34
+ use span:: TextSize ;
32
35
use stdx:: format_to;
33
36
use syntax:: {
34
37
ast:: { self , AstNode , HasName } ,
@@ -132,14 +135,40 @@ fn check_impl(
132
135
None => continue ,
133
136
} ;
134
137
let def_map = module. def_map ( & db) ;
135
- visit_module ( & db, & def_map, module. local_id , & mut |it| {
136
- defs. push ( match it {
137
- ModuleDefId :: FunctionId ( it) => it. into ( ) ,
138
- ModuleDefId :: EnumVariantId ( it) => it. into ( ) ,
139
- ModuleDefId :: ConstId ( it) => it. into ( ) ,
140
- ModuleDefId :: StaticId ( it) => it. into ( ) ,
141
- _ => return ,
142
- } )
138
+ visit_module ( & db, & def_map, module. local_id , & mut |it| match it {
139
+ ModuleDefId :: FunctionId ( it) => defs. push ( it. into ( ) ) ,
140
+ ModuleDefId :: EnumVariantId ( it) => {
141
+ defs. push ( it. into ( ) ) ;
142
+ let variant_id = it. into ( ) ;
143
+ let vd = db. variant_data ( variant_id) ;
144
+ defs. extend ( vd. fields ( ) . iter ( ) . filter_map ( |( local_id, fd) | {
145
+ if fd. has_default {
146
+ let field = FieldId { parent : variant_id, local_id, has_default : true } ;
147
+ Some ( DefWithBodyId :: FieldId ( field) )
148
+ } else {
149
+ None
150
+ }
151
+ } ) ) ;
152
+ }
153
+ ModuleDefId :: ConstId ( it) => defs. push ( it. into ( ) ) ,
154
+ ModuleDefId :: StaticId ( it) => defs. push ( it. into ( ) ) ,
155
+ ModuleDefId :: AdtId ( it) => {
156
+ let variant_id = match it {
157
+ AdtId :: StructId ( it) => it. into ( ) ,
158
+ AdtId :: UnionId ( it) => it. into ( ) ,
159
+ AdtId :: EnumId ( _) => return ,
160
+ } ;
161
+ let vd = db. variant_data ( variant_id) ;
162
+ defs. extend ( vd. fields ( ) . iter ( ) . filter_map ( |( local_id, fd) | {
163
+ if fd. has_default {
164
+ let field = FieldId { parent : variant_id, local_id, has_default : true } ;
165
+ Some ( DefWithBodyId :: FieldId ( field) )
166
+ } else {
167
+ None
168
+ }
169
+ } ) ) ;
170
+ }
171
+ _ => { }
143
172
} ) ;
144
173
}
145
174
defs. sort_by_key ( |def| match def {
@@ -160,12 +189,20 @@ fn check_impl(
160
189
loc. source ( & db) . value . syntax ( ) . text_range ( ) . start ( )
161
190
}
162
191
DefWithBodyId :: InTypeConstId ( it) => it. source ( & db) . syntax ( ) . text_range ( ) . start ( ) ,
163
- DefWithBodyId :: FieldId ( _) => unreachable ! ( ) ,
192
+ DefWithBodyId :: FieldId ( it) => {
193
+ let cs = it. parent . child_source ( & db) ;
194
+ match cs. value . get ( it. local_id ) {
195
+ Some ( Either :: Left ( it) ) => it. syntax ( ) . text_range ( ) . start ( ) ,
196
+ Some ( Either :: Right ( it) ) => it. syntax ( ) . text_range ( ) . end ( ) ,
197
+ None => TextSize :: new ( u32:: MAX ) ,
198
+ }
199
+ }
164
200
} ) ;
165
201
let mut unexpected_type_mismatches = String :: new ( ) ;
166
202
for def in defs {
167
203
let ( body, body_source_map) = db. body_with_source_map ( def) ;
168
204
let inference_result = db. infer ( def) ;
205
+ dbg ! ( & inference_result) ;
169
206
170
207
for ( pat, mut ty) in inference_result. type_of_pat . iter ( ) {
171
208
if let Pat :: Bind { id, .. } = body. pats [ pat] {
@@ -389,14 +426,40 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
389
426
let def_map = module. def_map ( & db) ;
390
427
391
428
let mut defs: Vec < DefWithBodyId > = Vec :: new ( ) ;
392
- visit_module ( & db, & def_map, module. local_id , & mut |it| {
393
- defs. push ( match it {
394
- ModuleDefId :: FunctionId ( it) => it. into ( ) ,
395
- ModuleDefId :: EnumVariantId ( it) => it. into ( ) ,
396
- ModuleDefId :: ConstId ( it) => it. into ( ) ,
397
- ModuleDefId :: StaticId ( it) => it. into ( ) ,
398
- _ => return ,
399
- } )
429
+ visit_module ( & db, & def_map, module. local_id , & mut |it| match it {
430
+ ModuleDefId :: FunctionId ( it) => defs. push ( it. into ( ) ) ,
431
+ ModuleDefId :: EnumVariantId ( it) => {
432
+ defs. push ( it. into ( ) ) ;
433
+ let variant_id = it. into ( ) ;
434
+ let vd = db. variant_data ( variant_id) ;
435
+ defs. extend ( vd. fields ( ) . iter ( ) . filter_map ( |( local_id, fd) | {
436
+ if fd. has_default {
437
+ let field = FieldId { parent : variant_id, local_id, has_default : true } ;
438
+ Some ( DefWithBodyId :: FieldId ( field) )
439
+ } else {
440
+ None
441
+ }
442
+ } ) ) ;
443
+ }
444
+ ModuleDefId :: ConstId ( it) => defs. push ( it. into ( ) ) ,
445
+ ModuleDefId :: StaticId ( it) => defs. push ( it. into ( ) ) ,
446
+ ModuleDefId :: AdtId ( it) => {
447
+ let variant_id = match it {
448
+ AdtId :: StructId ( it) => it. into ( ) ,
449
+ AdtId :: UnionId ( it) => it. into ( ) ,
450
+ AdtId :: EnumId ( _) => return ,
451
+ } ;
452
+ let vd = db. variant_data ( variant_id) ;
453
+ defs. extend ( vd. fields ( ) . iter ( ) . filter_map ( |( local_id, fd) | {
454
+ if fd. has_default {
455
+ let field = FieldId { parent : variant_id, local_id, has_default : true } ;
456
+ Some ( DefWithBodyId :: FieldId ( field) )
457
+ } else {
458
+ None
459
+ }
460
+ } ) ) ;
461
+ }
462
+ _ => { }
400
463
} ) ;
401
464
defs. sort_by_key ( |def| match def {
402
465
DefWithBodyId :: FunctionId ( it) => {
@@ -416,7 +479,14 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
416
479
loc. source ( & db) . value . syntax ( ) . text_range ( ) . start ( )
417
480
}
418
481
DefWithBodyId :: InTypeConstId ( it) => it. source ( & db) . syntax ( ) . text_range ( ) . start ( ) ,
419
- DefWithBodyId :: FieldId ( _) => unreachable ! ( ) ,
482
+ DefWithBodyId :: FieldId ( it) => {
483
+ let cs = it. parent . child_source ( & db) ;
484
+ match cs. value . get ( it. local_id ) {
485
+ Some ( Either :: Left ( it) ) => it. syntax ( ) . text_range ( ) . start ( ) ,
486
+ Some ( Either :: Right ( it) ) => it. syntax ( ) . text_range ( ) . end ( ) ,
487
+ None => TextSize :: new ( u32:: MAX ) ,
488
+ }
489
+ }
420
490
} ) ;
421
491
for def in defs {
422
492
let ( body, source_map) = db. body_with_source_map ( def) ;
@@ -477,7 +547,7 @@ pub(crate) fn visit_module(
477
547
let body = db. body ( it. into ( ) ) ;
478
548
visit_body ( db, & body, cb) ;
479
549
}
480
- ModuleDefId :: AdtId ( hir_def :: AdtId :: EnumId ( it) ) => {
550
+ ModuleDefId :: AdtId ( AdtId :: EnumId ( it) ) => {
481
551
db. enum_data ( it) . variants . iter ( ) . for_each ( |& ( it, _) | {
482
552
let body = db. body ( it. into ( ) ) ;
483
553
cb ( it. into ( ) ) ;
0 commit comments