@@ -50,6 +50,10 @@ pub fn token_tree_to_expr(tt: &tt::Subtree) -> Result<TreeArc<ast::Expr>, Expand
50
50
let token_source = SubtreeTokenSource :: new ( tt) ;
51
51
let mut tree_sink = TtTreeSink :: new ( token_source. querier ( ) ) ;
52
52
ra_parser:: parse_expr ( & token_source, & mut tree_sink) ;
53
+ if tree_sink. roots . len ( ) != 1 {
54
+ return Err ( ExpandError :: ConversionError ) ;
55
+ }
56
+
53
57
let syntax = tree_sink. inner . finish ( ) ;
54
58
ast:: Expr :: cast ( & syntax)
55
59
. map ( |m| m. to_owned ( ) )
@@ -61,6 +65,10 @@ pub fn token_tree_to_pat(tt: &tt::Subtree) -> Result<TreeArc<ast::Pat>, ExpandEr
61
65
let token_source = SubtreeTokenSource :: new ( tt) ;
62
66
let mut tree_sink = TtTreeSink :: new ( token_source. querier ( ) ) ;
63
67
ra_parser:: parse_pat ( & token_source, & mut tree_sink) ;
68
+ if tree_sink. roots . len ( ) != 1 {
69
+ return Err ( ExpandError :: ConversionError ) ;
70
+ }
71
+
64
72
let syntax = tree_sink. inner . finish ( ) ;
65
73
ast:: Pat :: cast ( & syntax) . map ( |m| m. to_owned ( ) ) . ok_or_else ( || ExpandError :: ConversionError )
66
74
}
@@ -70,6 +78,9 @@ pub fn token_tree_to_ty(tt: &tt::Subtree) -> Result<TreeArc<ast::TypeRef>, Expan
70
78
let token_source = SubtreeTokenSource :: new ( tt) ;
71
79
let mut tree_sink = TtTreeSink :: new ( token_source. querier ( ) ) ;
72
80
ra_parser:: parse_ty ( & token_source, & mut tree_sink) ;
81
+ if tree_sink. roots . len ( ) != 1 {
82
+ return Err ( ExpandError :: ConversionError ) ;
83
+ }
73
84
let syntax = tree_sink. inner . finish ( ) ;
74
85
ast:: TypeRef :: cast ( & syntax) . map ( |m| m. to_owned ( ) ) . ok_or_else ( || ExpandError :: ConversionError )
75
86
}
@@ -81,6 +92,9 @@ pub fn token_tree_to_macro_stmts(
81
92
let token_source = SubtreeTokenSource :: new ( tt) ;
82
93
let mut tree_sink = TtTreeSink :: new ( token_source. querier ( ) ) ;
83
94
ra_parser:: parse_macro_stmts ( & token_source, & mut tree_sink) ;
95
+ if tree_sink. roots . len ( ) != 1 {
96
+ return Err ( ExpandError :: ConversionError ) ;
97
+ }
84
98
let syntax = tree_sink. inner . finish ( ) ;
85
99
ast:: MacroStmts :: cast ( & syntax) . map ( |m| m. to_owned ( ) ) . ok_or_else ( || ExpandError :: ConversionError )
86
100
}
@@ -92,6 +106,9 @@ pub fn token_tree_to_macro_items(
92
106
let token_source = SubtreeTokenSource :: new ( tt) ;
93
107
let mut tree_sink = TtTreeSink :: new ( token_source. querier ( ) ) ;
94
108
ra_parser:: parse_macro_items ( & token_source, & mut tree_sink) ;
109
+ if tree_sink. roots . len ( ) != 1 {
110
+ return Err ( ExpandError :: ConversionError ) ;
111
+ }
95
112
let syntax = tree_sink. inner . finish ( ) ;
96
113
ast:: MacroItems :: cast ( & syntax) . map ( |m| m. to_owned ( ) ) . ok_or_else ( || ExpandError :: ConversionError )
97
114
}
@@ -268,6 +285,10 @@ struct TtTreeSink<'a, Q: Querier> {
268
285
text_pos : TextUnit ,
269
286
token_pos : usize ,
270
287
inner : SyntaxTreeBuilder ,
288
+
289
+ // Number of roots
290
+ // Use for detect ill-form tree which is not single root
291
+ roots : smallvec:: SmallVec < [ usize ; 1 ] > ,
271
292
}
272
293
273
294
impl < ' a , Q : Querier > TtTreeSink < ' a , Q > {
@@ -278,6 +299,7 @@ impl<'a, Q: Querier> TtTreeSink<'a, Q> {
278
299
text_pos : 0 . into ( ) ,
279
300
token_pos : 0 ,
280
301
inner : SyntaxTreeBuilder :: default ( ) ,
302
+ roots : smallvec:: SmallVec :: new ( ) ,
281
303
}
282
304
}
283
305
}
@@ -323,10 +345,16 @@ impl<'a, Q: Querier> TreeSink for TtTreeSink<'a, Q> {
323
345
324
346
fn start_node ( & mut self , kind : SyntaxKind ) {
325
347
self . inner . start_node ( kind) ;
348
+
349
+ match self . roots . last_mut ( ) {
350
+ None | Some ( 0 ) => self . roots . push ( 1 ) ,
351
+ Some ( ref mut n) => * * n += 1 ,
352
+ } ;
326
353
}
327
354
328
355
fn finish_node ( & mut self ) {
329
356
self . inner . finish_node ( ) ;
357
+ * self . roots . last_mut ( ) . unwrap ( ) -= 1 ;
330
358
}
331
359
332
360
fn error ( & mut self , error : ParseError ) {
@@ -375,4 +403,22 @@ mod tests {
375
403
assert_eq ! ( query. token( 2 + 15 + 3 ) . 1 , "\" rust1\" " ) ;
376
404
assert_eq ! ( query. token( 2 + 15 + 3 ) . 0 , STRING ) ;
377
405
}
406
+
407
+ #[ test]
408
+ fn stmts_token_trees_to_expr_is_err ( ) {
409
+ let rules = create_rules (
410
+ r#"
411
+ macro_rules! stmts {
412
+ () => {
413
+ let a = 0;
414
+ let b = 0;
415
+ let c = 0;
416
+ let d = 0;
417
+ }
418
+ }
419
+ "# ,
420
+ ) ;
421
+ let expansion = expand ( & rules, "stmts!()" ) ;
422
+ assert ! ( token_tree_to_expr( & expansion) . is_err( ) ) ;
423
+ }
378
424
}
0 commit comments