diff --git a/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala b/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala index 68375e91..37b8981f 100644 --- a/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala +++ b/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala @@ -730,21 +730,24 @@ package """ :+ packageName :+ """ getFunctionMapping(signature, returnType, ScalaCompat(None)) private[compiler] def getFunctionMapping( - signature: String, + signature: String, // can also contain type params like [A, B, C](...) returnType: String, sc: ScalaCompat - ): (String, String, String) = { + ): (String, String, String) = { // renderCall, f, templateType - val params: List[List[Term.Param]] = + val (tparams, params): (List[Type.Param], List[List[Term.Param]]) = try { val dialect = Dialect.current.withAllowGivenUsing(true) val input = Input.String(s"object FT { def signature$signature }") val obj = implicitly[Parse[Stat]].apply(input, dialect).get.asInstanceOf[Defn.Object] val templ = obj.templ val defdef = templ.body.stats.head.asInstanceOf[Decl.Def] - defdef.paramClauseGroups.headOption.map(_.paramClauses.map(_.values)).getOrElse(Nil) + ( + defdef.paramClauseGroups.headOption.map(_.tparamClause.values).getOrElse(Nil), + defdef.paramClauseGroups.headOption.map(_.paramClauses.map(_.values)).getOrElse(Nil) + ) } catch { - case e: ParseException => Nil + case e: ParseException => (Nil, Nil) } def filterType(p: Term.Param) = @@ -782,7 +785,10 @@ package """ :+ packageName :+ """ }.mkString } - val renderCall = "def render%s: %s = apply%s".format( + val typeParams = if (tparams.isEmpty) "" else tparams.mkString("[", ",", "]") + + val renderCall = "def render%s%s: %s = apply%s%s".format( + typeParams, "(" + params.flatten .map { case p @ ByNameParam(_, paramType) => p.name.toString + ":" + paramType @@ -790,28 +796,34 @@ package """ :+ packageName :+ """ } .mkString(",") + ")", returnType, + typeParams, applyArgs ) - val f = "def f:%s = %s => apply%s".format( + val f = "def f%s:%s = %s => apply%s%s".format( + typeParams, functionType, params.map(group => "(" + group.map(_.name.toString).mkString(",") + ")").mkString(" => "), + typeParams, applyArgs ) - val templateType = sc.valueOrEmptyIfScala3Exceeding22Params( - params.flatten.size, - "_root_.play.twirl.api.Template%s[%s%s]".format( - params.flatten.size, - params.flatten - .map { - case ByNameParam(_, paramType) => paramType - case p => filterType(p) - } - .mkString(","), - (if (params.flatten.isEmpty) "" else ",") + returnType - ) - ) + val templateType = + if (tparams.isEmpty) + sc.valueOrEmptyIfScala3Exceeding22Params( + params.flatten.size, + "_root_.play.twirl.api.Template%s[%s%s]".format( + params.flatten.size, + params.flatten + .map { + case ByNameParam(_, paramType) => paramType + case p => filterType(p) + } + .mkString(","), + (if (params.flatten.isEmpty) "" else ",") + returnType + ) + ) + else "" // If type params -> no TemplateN trait. Too many possibilities. TODO: Generate TemplateN on the fly? (renderCall, f, templateType) } diff --git a/parser/src/main/scala/play/twirl/parser/TwirlParser.scala b/parser/src/main/scala/play/twirl/parser/TwirlParser.scala index 336e999f..85073385 100644 --- a/parser/src/main/scala/play/twirl/parser/TwirlParser.scala +++ b/parser/src/main/scala/play/twirl/parser/TwirlParser.scala @@ -1254,8 +1254,8 @@ class TwirlParser(val shouldParseInclusiveDot: Boolean) { * Parse the template arguments, if they exist */ private def maybeTemplateArgs(): Option[PosString] = { - if (check("@(")) { - input.regress(1) + val reset = input.offset() + def parseArgs(): Option[PosString] = { val p = input.offset() val args = templateArgs() if (args != null) { @@ -1265,6 +1265,30 @@ class TwirlParser(val shouldParseInclusiveDot: Boolean) { } else { None } + } + if (check("@(")) { + input.regress(1) + parseArgs() + } else if (check("@[")) { + input.regress(1) + Option(squareBrackets()) match { + case Some(value) if value.replaceAll("\\s", "") == "[]" => + input.regressTo(reset) // don't consume @ + error(s"identifier expected but ']' found", reset) // TODO: really reset hier? + None + case Some(types) => + parseArgs() match { + case Some(value) => Some(position(PosString(types + value.str), reset)) + case None => + val result = Some(position(PosString(types + "()"), reset)) + check("\n") + result + } + case None => + input.regressTo(reset) // don't consume @ + error(s"Type parameter(s) expected", reset) // TODO: really reset hier? + None + } } else None }