1
1
//! Builtin derives.
2
2
3
3
use base_db:: { CrateOrigin , LangCrateOrigin } ;
4
+ use either:: Either ;
4
5
use tracing:: debug;
5
6
6
7
use crate :: tt:: { self , TokenId } ;
7
8
use syntax:: {
8
- ast:: { self , AstNode , HasGenericParams , HasModuleItem , HasName } ,
9
+ ast:: { self , AstNode , HasGenericParams , HasModuleItem , HasName , HasTypeBounds } ,
9
10
match_ast,
10
11
} ;
11
12
@@ -60,8 +61,11 @@ pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander>
60
61
61
62
struct BasicAdtInfo {
62
63
name : tt:: Ident ,
63
- /// `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
64
- param_types : Vec < Option < tt:: Subtree > > ,
64
+ /// first field is the name, and
65
+ /// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
66
+ /// third fields is where bounds, if any
67
+ param_types : Vec < ( tt:: Subtree , Option < tt:: Subtree > , Option < tt:: Subtree > ) > ,
68
+ field_types : Vec < tt:: Subtree > ,
65
69
}
66
70
67
71
fn parse_adt ( tt : & tt:: Subtree ) -> Result < BasicAdtInfo , ExpandError > {
@@ -75,17 +79,34 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
75
79
ExpandError :: Other ( "no item found" . into ( ) )
76
80
} ) ?;
77
81
let node = item. syntax ( ) ;
78
- let ( name, params) = match_ast ! {
82
+ let ( name, params, fields ) = match_ast ! {
79
83
match node {
80
- ast:: Struct ( it) => ( it. name( ) , it. generic_param_list( ) ) ,
81
- ast:: Enum ( it) => ( it. name( ) , it. generic_param_list( ) ) ,
82
- ast:: Union ( it) => ( it. name( ) , it. generic_param_list( ) ) ,
84
+ ast:: Struct ( it) => {
85
+ ( it. name( ) , it. generic_param_list( ) , it. field_list( ) . into_iter( ) . collect:: <Vec <_>>( ) )
86
+ } ,
87
+ ast:: Enum ( it) => ( it. name( ) , it. generic_param_list( ) , it. variant_list( ) . into_iter( ) . flat_map( |x| x. variants( ) ) . filter_map( |x| x. field_list( ) ) . collect( ) ) ,
88
+ ast:: Union ( it) => ( it. name( ) , it. generic_param_list( ) , it. record_field_list( ) . into_iter( ) . map( |x| ast:: FieldList :: RecordFieldList ( x) ) . collect( ) ) ,
83
89
_ => {
84
90
debug!( "unexpected node is {:?}" , node) ;
85
91
return Err ( ExpandError :: Other ( "expected struct, enum or union" . into( ) ) )
86
92
} ,
87
93
}
88
94
} ;
95
+ let field_types = fields
96
+ . into_iter ( )
97
+ . flat_map ( |f| match f {
98
+ ast:: FieldList :: RecordFieldList ( x) => Either :: Left (
99
+ x. fields ( )
100
+ . filter_map ( |x| x. ty ( ) )
101
+ . map ( |x| mbe:: syntax_node_to_token_tree ( x. syntax ( ) ) . 0 ) ,
102
+ ) ,
103
+ ast:: FieldList :: TupleFieldList ( x) => Either :: Right (
104
+ x. fields ( )
105
+ . filter_map ( |x| x. ty ( ) )
106
+ . map ( |x| mbe:: syntax_node_to_token_tree ( x. syntax ( ) ) . 0 ) ,
107
+ ) ,
108
+ } )
109
+ . collect :: < Vec < _ > > ( ) ;
89
110
let name = name. ok_or_else ( || {
90
111
debug ! ( "parsed item has no name" ) ;
91
112
ExpandError :: Other ( "missing name" . into ( ) )
@@ -97,35 +118,46 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
97
118
. into_iter ( )
98
119
. flat_map ( |param_list| param_list. type_or_const_params ( ) )
99
120
. map ( |param| {
100
- if let ast:: TypeOrConstParam :: Const ( param) = param {
121
+ let name = param
122
+ . name ( )
123
+ . map ( |x| mbe:: syntax_node_to_token_tree ( x. syntax ( ) ) . 0 )
124
+ . unwrap_or_else ( tt:: Subtree :: empty) ;
125
+ let bounds = match & param {
126
+ ast:: TypeOrConstParam :: Type ( x) => {
127
+ x. type_bound_list ( ) . map ( |x| mbe:: syntax_node_to_token_tree ( x. syntax ( ) ) . 0 )
128
+ }
129
+ ast:: TypeOrConstParam :: Const ( _) => None ,
130
+ } ;
131
+ let ty = if let ast:: TypeOrConstParam :: Const ( param) = param {
101
132
let ty = param
102
133
. ty ( )
103
134
. map ( |ty| mbe:: syntax_node_to_token_tree ( ty. syntax ( ) ) . 0 )
104
135
. unwrap_or_else ( tt:: Subtree :: empty) ;
105
136
Some ( ty)
106
137
} else {
107
138
None
108
- }
139
+ } ;
140
+ ( name, ty, bounds)
109
141
} )
110
142
. collect ( ) ;
111
- Ok ( BasicAdtInfo { name : name_token, param_types } )
143
+ Ok ( BasicAdtInfo { name : name_token, param_types, field_types } )
112
144
}
113
145
114
146
fn expand_simple_derive ( tt : & tt:: Subtree , trait_path : tt:: Subtree ) -> ExpandResult < tt:: Subtree > {
115
147
let info = match parse_adt ( tt) {
116
148
Ok ( info) => info,
117
149
Err ( e) => return ExpandResult :: with_err ( tt:: Subtree :: empty ( ) , e) ,
118
150
} ;
151
+ let mut where_block = vec ! [ ] ;
119
152
let ( params, args) : ( Vec < _ > , Vec < _ > ) = info
120
153
. param_types
121
154
. into_iter ( )
122
- . enumerate ( )
123
- . map ( |( idx, param_ty) | {
124
- let ident = tt:: Leaf :: Ident ( tt:: Ident {
125
- span : tt:: TokenId :: unspecified ( ) ,
126
- text : format ! ( "T{idx}" ) . into ( ) ,
127
- } ) ;
155
+ . map ( |( ident, param_ty, bound) | {
128
156
let ident_ = ident. clone ( ) ;
157
+ if let Some ( b) = bound {
158
+ let ident = ident. clone ( ) ;
159
+ where_block. push ( quote ! { #ident : #b , } ) ;
160
+ }
129
161
if let Some ( ty) = param_ty {
130
162
( quote ! { const #ident : #ty , } , quote ! { #ident_ , } )
131
163
} else {
@@ -134,9 +166,16 @@ fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResu
134
166
}
135
167
} )
136
168
. unzip ( ) ;
169
+
170
+ where_block. extend ( info. field_types . iter ( ) . map ( |x| {
171
+ let x = x. clone ( ) ;
172
+ let bound = trait_path. clone ( ) ;
173
+ quote ! { #x : #bound , }
174
+ } ) ) ;
175
+
137
176
let name = info. name ;
138
177
let expanded = quote ! {
139
- impl < ##params > #trait_path for #name < ##args > { }
178
+ impl < ##params > #trait_path for #name < ##args > where ##where_block { }
140
179
} ;
141
180
ExpandResult :: ok ( expanded)
142
181
}
0 commit comments