Scala 3 : Finding functions with the given annotation

Time:02-12

For Scala 3 macros, does anyone know of a way to find all functions with a given annotation?

For instance:

@fruit
def apple(): Int = ???

@fruit
def banana(): Int = ???

@fruit
def coconut(): Int = ???

@fruit
def durian(): Int = ???

def elephant(): Int = ???

@fruit
def fig(): Int = ???

I would like to get a list containing apple, banana, coconut, durian and fig. They could be defined anywhere, but in my case they will all be in a single package.

Answer:


This solution extracts all the definitions carrying a given annotation from a given package, using compile-time reflection. To solve your problem, we can divide it into three steps:

  • gather the methods defined in a package;
  • keep only the methods with the given annotation;
  • transform the resulting symbols into function applications.

I assume that you can pass the package and the annotation (and also the return type) as type arguments, so the macro signature is something like this:
import scala.annotation.ConstantAnnotation

inline def findAllFunction[P, A <: ConstantAnnotation, R]: List[() => R] =
  ${ Implementation.myMacroImpl[P, A, R]() }
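The signature bounds the annotation type by ConstantAnnotation, and the usage example further down imports org.tests.Macros.fruit without showing how fruit is defined. A minimal sketch of that annotation, assuming it is declared in the same Macros object that holds findAllFunction (this object layout is an assumption, not something the answer shows):

```scala
package org.tests

import scala.annotation.ConstantAnnotation

object Macros:
  // Assumed definition: a marker annotation with no arguments.
  // Extending ConstantAnnotation satisfies the bound A <: ConstantAnnotation
  // of findAllFunction, which would be declared in this same object.
  final class fruit extends ConstantAnnotation
```

Because fruit takes no arguments, writing @fruit above a def is enough to mark it.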

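The first point is straightforward. Scala 3 compiles top-level definitions into a synthetic object (named after the source file, with a $package suffix), so we collect the methods declared in every class defined in the package: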

def methodsFromPackage(packageSymbol: Symbol): List[Symbol] =
  packageSymbol.declaredTypes
    .filter(_.isClassDef)
    .flatMap(_.declaredMethods)

The second point is also easy: Symbol has a hasAnnotation method that we can use for the filtering:

def methodsAnnotatedWith(
    methods: List[Symbol],
    annotation: Symbol
): List[Symbol] =
  methods.filter(_.hasAnnotation(annotation))

The last point is a bit more challenging: we need to build the AST that corresponds to each method call. Inspired by this example, we can call each definition using Apply, while Select and This pick the correct method to be called:

def transformToFunctionApplication(methods: List[Symbol]): Expr[List[() => R]] =
  val appliedDef = methods
    .map(definition => Select(This(definition.owner), definition))
    .map(select => Apply(select, List.empty))
    .map(apply => '{ () => ${ apply.asExprOf[R] } })
  Expr.ofList(appliedDef)

Here I wrapped each call in a lambda. If you want to return the values directly, change the return type and the last mapping step:

def transformToFunctionApplication(methods: List[Symbol]): Expr[List[R]] =
  val appliedDef = methods
    .map(definition => Select(This(definition.owner), definition))
    .map(select => Apply(select, List.empty))
    .map(apply => apply.asExprOf[R])

  Expr.ofList(appliedDef)
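To make the difference between the two variants concrete, the following hand-written code approximates what each expansion evaluates to for two annotated methods. It is an illustration written without any macro: the names ExpansionSketch and calls are hypothetical, and the real expansion selects the methods through their enclosing object rather than calling them unqualified.

```scala
// Hand-written approximation of the two macro expansions (no macro involved).
object ExpansionSketch:
  // Counts how many times the "annotated" methods have actually run.
  var calls = 0

  def check(): Int = { calls += 1; 10 }
  def other(): Int = { calls += 1; 11 }

  // Shape of the List[() => R] variant: building the list calls nothing.
  def thunks: List[() => Int] = List(() => check(), () => other())

  // Shape of the List[R] variant: both methods run while the list is built.
  def values: List[Int] = List(check(), other())
```

Evaluating thunks leaves calls at 0, and the methods only run when each function is applied, for example with thunks.map(_.apply()). Evaluating values runs every method immediately. The lambda variant is therefore the one to pick if the calls are expensive or have side effects.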

To sum up, the whole macro implementation can be written as:

import scala.quoted.*

// Defined inside object Implementation, which findAllFunction refers to.
def myMacroImpl[P: Type, A: Type, R: Type]()(using
    Quotes
): Expr[List[() => R]] = {
  import quotes.reflect.*
  val annotation = TypeRepr.of[A].typeSymbol
  val moduleTarget = TypeRepr.of[P].typeSymbol

  // Step 1: look inside every class declared in the package, including the
  // synthetic module class that holds the top-level definitions.
  def methodsFromPackage(packageSymbol: Symbol): List[Symbol] =
    packageSymbol.declaredTypes
      .filter(_.isClassDef)
      .flatMap(_.declaredMethods)

  // Step 2: keep only the methods carrying the annotation.
  def methodsAnnotatedWith(
      methods: List[Symbol],
      annotation: Symbol
  ): List[Symbol] =
    methods.filter(_.hasAnnotation(annotation))

  // Step 3: build a `() => owner.method()` lambda for each method.
  def transformToFunctionApplication(
      methods: List[Symbol]
  ): Expr[List[() => R]] =
    val appliedDef = methods
      .map(definition => Select(This(definition.owner), definition))
      .map(select => Apply(select, List.empty))
      .map(apply => '{ () => ${ apply.asExprOf[R] } })
    Expr.ofList(appliedDef)

  val methods = methodsFromPackage(moduleTarget)
  val annotatedMethods = methodsAnnotatedWith(methods, annotation)
  transformToFunctionApplication(annotatedMethods)
}

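Finally, you can use the macro like this (in a different source file from the macro definition, since Scala 3 does not allow calling a macro in the file where it is defined):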

package org.tests
import org.tests.Macros.fruit

package foo {
  @fruit
  def check(): Int = 10
  @fruit
  def other(): Int = 11
}


@main def hello: Unit = 
  println("Hello world!")
  println(Macros.findAllFunction[org.tests.foo, fruit, Int].map(_.apply())) // prints List(10, 11)

Scastie
