Skip to content

Commit 183d70f

Browse files
committed
Add support for defining top-level main functions with nested commands
Previously this was not supported since the macro needs to handle this case specially. The reason is bit subtle, but basically boils down differences in how top-level classes and methods are encoded: - classes are declared in a package - methods are declared in a synthetically-generated 'package class' (presumably, because java bytecode supports only top-level class declarations) This means that if a main method looks up annotated command classes in `this`, then the lookup paths need to be handled differently depending on whether or not the main method is top-level. There are also some differences in instantiating the command class.
1 parent be77793 commit 183d70f

File tree

4 files changed

+58
-38
lines changed

4 files changed

+58
-38
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 0.19.1
4+
5+
- Add support for defining top-level main functions with nested commands.
6+
37
## 0.19.0
48

59
- Add error handling and output printing to the annotation API.

argparse/sandbox/src/example.scala

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,21 @@ object api
88

99
override def defaultHelpFlags = Seq("-h")
1010

11-
object app:
12-
13-
/**
14-
* Hello world
15-
*
16-
* @param base the base value
17-
*/
11+
/**
12+
* Hello world
13+
*
14+
* @param base the base value
15+
*/
16+
@api.command()
17+
class wrapper(base: Int):
1818
@api.command()
19-
class wrapper(base: Int):
20-
@api.command()
21-
def add(x: Int) = println(base + x)
19+
def add(x: Int) = println(base + x)
2220

23-
/** A nested command */
21+
/** A nested command */
22+
@api.command()
23+
def nested(y: Int = 2) = foo(y)
24+
class foo(y: Int):
2425
@api.command()
25-
def nested(y: Int = 2) = foo(y)
26-
class foo(y: Int):
27-
@api.command()
28-
def ok() = println("ok")
26+
def ok() = println("ok")
2927

30-
def main(args: Array[String]): Unit = argparse.main(this, args)
28+
def main(args: Array[String]): Unit = argparse.main(this, args)

argparse/src-3/argparse/core/Command.scala

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,17 @@ object CommandMacros:
4343
import qctx.reflect.*
4444
val CommandAnnot = TypeRepr.of[MacroApi#command]
4545

46-
val methods = TypeRepr.of[Container].typeSymbol.memberMethods
47-
val classes = TypeRepr.of[Container].typeSymbol.memberTypes.filter(_.isClassDef)
46+
val containerTpe = TypeRepr.of[Container].typeSymbol
47+
48+
// need to handle the case where `findImpl` is called from a top-level
49+
// function, but commands are wrapped in a top-level annotated class
50+
val isTopLevel = containerTpe.owner.isPackageDef && containerTpe.name.endsWith("$package$")
51+
52+
val methods = containerTpe.memberMethods
53+
val classes = if isTopLevel then
54+
containerTpe.owner.memberTypes.filter(_.isClassDef)
55+
else
56+
containerTpe.memberTypes.filter(_.isClassDef)
4857

4958
for
5059
sym <- (methods ++ classes)
@@ -65,7 +74,7 @@ object CommandMacros:
6574

6675
apiType.asType match
6776
case '[t] if TypeRepr.of[t] <:< TypeRepr.of[MacroApi] =>
68-
makeCommand[Container, MacroApi](api.asExprOf[MacroApi], method, sym.name, doc)
77+
makeCommand[Container, MacroApi](api.asExprOf[MacroApi], method, sym.name, doc, isTopLevel)
6978
case '[t] =>
7079
report.error(s"wrong API ${Type.show[t]}")
7180
'{???}
@@ -112,7 +121,8 @@ object CommandMacros:
112121
api: Expr[Api],
113122
method: qctx.reflect.Symbol,
114123
name: String, // name is separate because method name is not always representative (e.g. if method is class constructor)
115-
doc: DocComment
124+
doc: DocComment,
125+
isTopLevel: Boolean
116126
): Expr[Command[Container]] =
117127
import qctx.reflect.*
118128

@@ -165,7 +175,7 @@ object CommandMacros:
165175
Implicits.search(readerType) match
166176
case iss: ImplicitSearchSuccess => iss.tree
167177
case other =>
168-
report.error( s"No ${readerType.show} available for parameter ${param.name}.", param.pos.get)
178+
report.error(s"No ${readerType.show} available for parameter ${param.name}.", param.pos.get)
169179
'{???}.asTerm
170180

171181
paramTpe match
@@ -251,18 +261,28 @@ object CommandMacros:
251261
}
252262

253263
def callOrInstantiate() = try
254-
val outer = instance()
255264
val results = args.map(_.map(_()))
256265
${
257-
if method.isClassConstructor then
266+
if method.isClassConstructor && isTopLevel then
258267
call(using qctx)(
259-
New(TypeSelect('{outer}.asTerm, rtpe.typeSymbol.name)).select(method),
268+
New(method.tree.asInstanceOf[DefDef].returnTpt).select(method),
260269
ptpes,
261270
'results
262271
).asExpr
272+
else if method.isClassConstructor && !isTopLevel then
273+
'{
274+
val outer = instance()
275+
${
276+
call(using qctx)(
277+
New(TypeSelect('{outer}.asTerm, rtpe.typeSymbol.name)).select(method),
278+
ptpes,
279+
'results
280+
).asExpr
281+
}
282+
}
263283
else
264284
call(using qctx)(
265-
Select('{outer}.asTerm, method),
285+
Select('{instance()}.asTerm, method),
266286
ptpes,
267287
'results
268288
).asExpr
Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
import argparse.default as ap
22

3-
object Main:
3+
@ap.command()
4+
class app():
45

56
@ap.command()
6-
class app:
7+
def version() = println(s"v1000")
78

8-
@ap.command()
9-
def version() = println(s"v1000")
9+
@ap.command()
10+
class op(factor: Double = 1.0):
1011

1112
@ap.command()
12-
class op(factor: Double = 1.0):
13+
def showFactor() = println(s"the factor is $factor")
1314

14-
@ap.command()
15-
def showFactor() = println(s"the factor is $factor")
16-
17-
@ap.command()
18-
def multiply(operand: Double) = println(s"result is: ${factor * operand}")
15+
@ap.command()
16+
def multiply(operand: Double) = println(s"result is: ${factor * operand}")
1917

20-
// this is boilerplate for now; it will become obsolete once macro-annotations
21-
// are released
22-
def main(args: Array[String]) = argparse.main(this, args)
18+
// this is boilerplate for now; it will become obsolete once macro-annotations
19+
// are released
20+
def main(args: Array[String]): Unit = argparse.main(this, args)

0 commit comments

Comments
 (0)