|
| 1 | +import scala.quoted.* |
| 2 | + |
| 3 | +object Repro: |
| 4 | + inline def singletonValues[A]: List[A] = ${ singletonValuesImpl[A] } |
| 5 | + |
| 6 | + private def singletonValuesImpl[A: Type](using Quotes): Expr[List[A]] = |
| 7 | + import quotes.reflect.* |
| 8 | + |
| 9 | + def collectValues[T: Type]: List[Expr[T]] = |
| 10 | + val tpe = TypeRepr.of[T] |
| 11 | + tpe.dealias match |
| 12 | + case o: OrType => unionValues[T](o) |
| 13 | + case _ if tpe.isSingleton => singletonValue[T](tpe) :: Nil |
| 14 | + case _ => |
| 15 | + val sym = tpe.typeSymbol |
| 16 | + if sym.flags.is(Flags.Sealed) then sumValues[T](sym) |
| 17 | + else if sym.flags.is(Flags.Case) && sym.caseFields.nonEmpty then productValues[T](tpe, sym) |
| 18 | + else |
| 19 | + report.errorAndAbort( |
| 20 | + s"Cannot derive values for ${tpe.show}. Supported types: singleton types, " + |
| 21 | + "enums, sealed traits/classes, union types, and case classes/tuples whose fields are all enumerable." |
| 22 | + ) |
| 23 | + |
| 24 | + def singletonValue[T: Type](tpe: TypeRepr): Expr[T] = |
| 25 | + tpe.asType match |
| 26 | + case '[t] => |
| 27 | + Expr.summon[ValueOf[t]] match |
| 28 | + case Some(vo) => '{ $vo.value }.asExprOf[T] |
| 29 | + case None => report.errorAndAbort(s"Cannot determine value for singleton type: ${tpe.show}") |
| 30 | + |
| 31 | + def unionValues[T: Type](orType: OrType): List[Expr[T]] = |
| 32 | + def extract(tpe: TypeRepr): List[Expr[T]] = |
| 33 | + tpe.dealias match |
| 34 | + case o: OrType => extract(o.left) ++ extract(o.right) |
| 35 | + case s if s.isSingleton => singletonValue[T](s) :: Nil |
| 36 | + case other => report.errorAndAbort(s"Unsupported type in union: ${other.show}.") |
| 37 | + extract(orType) |
| 38 | + |
| 39 | + def sumValues[T: Type](sym: Symbol): List[Expr[T]] = |
| 40 | + sym.children.flatMap { child => |
| 41 | + if child.isTerm then |
| 42 | + child.termRef.asType match |
| 43 | + case '[t] => |
| 44 | + Expr.summon[ValueOf[t]] match |
| 45 | + case Some(vo) => '{ $vo.value }.asExprOf[T] :: Nil |
| 46 | + case None => report.errorAndAbort(s"Cannot get value for: ${child.name}") |
| 47 | + else |
| 48 | + child.typeRef.asType match |
| 49 | + case '[c] => collectValues[c].map(_.asExprOf[T]) |
| 50 | + } |
| 51 | + |
| 52 | + def productValues[T: Type](tpe: TypeRepr, sym: Symbol): List[Expr[T]] = |
| 53 | + val constructorParams = sym.primaryConstructor.paramSymss.flatten.filter(_.isTerm) |
| 54 | + val fieldTypes = constructorParams.map(f => tpe.memberType(f).widen.dealias) |
| 55 | + val fieldValueExprs: List[List[Term]] = fieldTypes.map { ft => |
| 56 | + ft.asType match |
| 57 | + case '[f] => collectValues[f].map(_.asTerm) |
| 58 | + } |
| 59 | + val cartesian = fieldValueExprs.foldRight(List(List.empty[Term])) { (vals, acc) => |
| 60 | + for v <- vals; rest <- acc yield v :: rest |
| 61 | + } |
| 62 | + cartesian.map { args => |
| 63 | + val companion = Ref(sym.companionModule) |
| 64 | + val applyMethod = sym.companionModule.methodMember("apply").head |
| 65 | + val typeParams = applyMethod.paramSymss.flatten.filter(_.isTypeParam) |
| 66 | + if typeParams.nonEmpty then Select.overloaded(companion, "apply", fieldTypes, args).asExprOf[T] |
| 67 | + else Select.overloaded(companion, "apply", Nil, args).asExprOf[T] |
| 68 | + } |
| 69 | + |
| 70 | + Expr.ofList(collectValues[A]) |
0 commit comments