Skip to content

Commit

Permalink
wip: workflow spi [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
octonato committed Nov 29, 2024
1 parent a215300 commit ff3feba
Show file tree
Hide file tree
Showing 9 changed files with 317 additions and 470 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
package akka.javasdk.testkit;

import akka.actor.typed.ActorSystem;
import akka.annotation.InternalApi;
import akka.http.javadsl.Http;
import akka.http.javadsl.model.HttpRequest;
import akka.javasdk.DependencyProvider;
import akka.javasdk.Metadata;
import akka.javasdk.client.ComponentClient;
import akka.javasdk.http.HttpClient;
import akka.javasdk.http.HttpClientProvider;
import akka.javasdk.impl.ApplicationConfig;
import akka.javasdk.impl.ErrorHandling;
import akka.javasdk.impl.JsonMessageCodec;
import akka.javasdk.impl.MessageCodec;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public WorkflowDef<TransferState> definition() {
.addStep(deposit);
}

public Effect<Message> startTransfer(Transfer transfer) {
public Effect<Message> startTransfer(Transfer transfer) {vl
if (transfer.amount <= 0.0) {
return effects().reply(new Message("Transfer amount should be greater than zero"));
} else {
Expand Down
34 changes: 0 additions & 34 deletions akka-javasdk/src/main/java/akka/javasdk/workflow/StepBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,40 +60,6 @@ public <Output> AsyncCallStepBuilder<Void, Output> asyncCall(Supplier<Completion
}


public static class CallStepBuilder<Input, DefCallInput, DefCallOutput> {

final private String name;

final private Class<Input> callInputClass;
/* callFactory builds the DeferredCall that will be passed to runtime for execution */
final private Function<Input, DeferredCall<DefCallInput, DefCallOutput>> callFunc;

public CallStepBuilder(String name, Class<Input> callInputClass, Function<Input, DeferredCall<DefCallInput, DefCallOutput>> callFunc) {
this.name = name;
this.callInputClass = callInputClass;
this.callFunc = callFunc;
}

/**
* Transition to the next step based on the result of the step call.
* <p>
* The {@link Function} passed to this method should receive the return type of the step call and return
* an {@link Workflow.Effect.TransitionalEffect} describing the next step to transition to.
* <p>
* When defining the Effect, you can update the workflow state and indicate the next step to transition to.
* This can be another step, or a pause or end of the workflow.
* <p>
* When transition to another step, you can also pass an input parameter to the next step.
*
* @param transitionInputClass Input class for transition.
* @param transitionFunc Function that transform the action result to a {@link Workflow.Effect.TransitionalEffect}
* @return CallStep
*/
public Workflow.CallStep<Input, DefCallInput, DefCallOutput, ?> andThen(Class<DefCallOutput> transitionInputClass, Function<DefCallOutput, Workflow.Effect.TransitionalEffect<Void>> transitionFunc) {
return new Workflow.CallStep<>(name, callInputClass, callFunc, transitionInputClass, transitionFunc);
}
}

public static class AsyncCallStepBuilder<CallInput, CallOutput> {

final private String name;
Expand Down
162 changes: 91 additions & 71 deletions akka-javasdk/src/main/scala/akka/javasdk/impl/SdkRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import java.util.concurrent.CompletionStage
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.jdk.CollectionConverters._
import scala.jdk.FutureConverters._
import scala.jdk.OptionConverters.RichOptional
import scala.reflect.ClassTag
import scala.util.control.NonFatal

Expand All @@ -21,15 +23,18 @@ import akka.annotation.InternalApi
import akka.http.scaladsl.model.headers.RawHeader
import akka.javasdk.BuildInfo
import akka.javasdk.DependencyProvider
import akka.javasdk.JwtClaims
import akka.javasdk.Principals
import akka.javasdk.ServiceSetup
import akka.javasdk.Tracing
import akka.javasdk.annotations.ComponentId
import akka.javasdk.annotations.Setup
import akka.javasdk.annotations.http.HttpEndpoint
import akka.javasdk.client.ComponentClient
import akka.javasdk.consumer.Consumer
import akka.javasdk.eventsourcedentity.EventSourcedEntity
import akka.javasdk.eventsourcedentity.EventSourcedEntityContext
import akka.javasdk.http.AbstractHttpEndpoint
import akka.javasdk.http.HttpClientProvider
import akka.javasdk.http.RequestContext
import akka.javasdk.impl.Sdk.StartupContext
Expand All @@ -42,10 +47,13 @@ import akka.javasdk.impl.consumer.ConsumerService
import akka.javasdk.impl.eventsourcedentity.EventSourcedEntitiesImpl
import akka.javasdk.impl.eventsourcedentity.EventSourcedEntityService
import akka.javasdk.impl.http.HttpClientProviderImpl
import akka.javasdk.impl.http.JwtClaimsImpl
import akka.javasdk.impl.keyvalueentity.KeyValueEntitiesImpl
import akka.javasdk.impl.keyvalueentity.KeyValueEntityService
import akka.javasdk.impl.reflection.Reflect
import akka.javasdk.impl.reflection.Reflect.Syntax.AnnotatedElementOps
import akka.javasdk.impl.telemetry.SpanTracingImpl
import akka.javasdk.impl.telemetry.TraceInstrumentation
import akka.javasdk.impl.timedaction.TimedActionService
import akka.javasdk.impl.timer.TimerSchedulerImpl
import akka.javasdk.impl.view.ViewService
Expand All @@ -59,13 +67,8 @@ import akka.javasdk.timer.TimerScheduler
import akka.javasdk.view.View
import akka.javasdk.workflow.Workflow
import akka.javasdk.workflow.WorkflowContext
import akka.javasdk.JwtClaims
import akka.javasdk.http.AbstractHttpEndpoint
import akka.javasdk.Tracing
import akka.javasdk.impl.http.JwtClaimsImpl
import akka.javasdk.impl.telemetry.SpanTracingImpl
import akka.javasdk.impl.telemetry.TraceInstrumentation
import akka.runtime.sdk.spi.ComponentClients
import akka.runtime.sdk.spi.EventSourcedEntityDescriptor
import akka.runtime.sdk.spi.HttpEndpointConstructionContext
import akka.runtime.sdk.spi.HttpEndpointDescriptor
import akka.runtime.sdk.spi.RemoteIdentification
Expand All @@ -74,7 +77,9 @@ import akka.runtime.sdk.spi.SpiDevModeSettings
import akka.runtime.sdk.spi.SpiEventingSupportSettings
import akka.runtime.sdk.spi.SpiMockedEventingSettings
import akka.runtime.sdk.spi.SpiSettings
import akka.runtime.sdk.spi.SpiWorkflow
import akka.runtime.sdk.spi.StartContext
import akka.runtime.sdk.spi.WorkflowDescriptor
import akka.stream.Materializer
import com.google.protobuf.Descriptors
import com.typesafe.config.Config
Expand Down Expand Up @@ -308,27 +313,30 @@ private final class Sdk(
private val componentFactories: Map[Descriptors.ServiceDescriptor, Service] = componentClasses
.filter(hasComponentId)
.foldLeft(Map[Descriptors.ServiceDescriptor, Service]()) { (factories, clz) =>
val service = if (classOf[TimedAction].isAssignableFrom(clz)) {
logger.debug(s"Registering TimedAction [${clz.getName}]")
timedActionService(clz.asInstanceOf[Class[TimedAction]])
} else if (classOf[Consumer].isAssignableFrom(clz)) {
logger.debug(s"Registering Consumer [${clz.getName}]")
consumerService(clz.asInstanceOf[Class[Consumer]])
} else if (classOf[EventSourcedEntity[_, _]].isAssignableFrom(clz)) {
logger.debug(s"Registering EventSourcedEntity [${clz.getName}]")
eventSourcedEntityService(clz.asInstanceOf[Class[EventSourcedEntity[Nothing, Nothing]]])
} else if (classOf[Workflow[_]].isAssignableFrom(clz)) {
logger.debug(s"Registering Workflow [${clz.getName}]")
workflowService(clz.asInstanceOf[Class[Workflow[Nothing]]])
} else if (classOf[KeyValueEntity[_]].isAssignableFrom(clz)) {
logger.debug(s"Registering KeyValueEntity [${clz.getName}]")
keyValueEntityService(clz.asInstanceOf[Class[KeyValueEntity[Nothing]]])
} else if (Reflect.isView(clz)) {
logger.debug(s"Registering View [${clz.getName}]")
viewService(clz.asInstanceOf[Class[View]])
} else throw new IllegalArgumentException(s"Component class of unknown component type [$clz]")

factories.updated(service.descriptor, service)

val serviceOpt =
if (classOf[TimedAction].isAssignableFrom(clz)) {
logger.debug(s"Registering TimedAction [${clz.getName}]")
Some(timedActionService(clz.asInstanceOf[Class[TimedAction]]))
} else if (classOf[Consumer].isAssignableFrom(clz)) {
logger.debug(s"Registering Consumer [${clz.getName}]")
Some(consumerService(clz.asInstanceOf[Class[Consumer]]))
} else if (classOf[EventSourcedEntity[_, _]].isAssignableFrom(clz)) {
logger.debug(s"Registering EventSourcedEntity [${clz.getName}]")
Some(eventSourcedEntityService(clz.asInstanceOf[Class[EventSourcedEntity[Nothing, Nothing]]]))
} else if (classOf[Workflow[_]].isAssignableFrom(clz)) {
None
} else if (classOf[KeyValueEntity[_]].isAssignableFrom(clz)) {
logger.debug(s"Registering KeyValueEntity [${clz.getName}]")
Some(keyValueEntityService(clz.asInstanceOf[Class[KeyValueEntity[Nothing]]]))
} else if (Reflect.isView(clz)) {
logger.debug(s"Registering View [${clz.getName}]")
Some(viewService(clz.asInstanceOf[Class[View]]))
} else throw new IllegalArgumentException(s"Component class of unknown component type [$clz]")

serviceOpt
.map(service => factories.updated(service.descriptor, service))
.getOrElse(factories)
}

private def hasComponentId(clz: Class[_]): Boolean = {
Expand All @@ -337,7 +345,7 @@ private final class Sdk(
} else {
//additional check to skip logging for endpoints
if (!clz.hasAnnotation[HttpEndpoint]) {
//this could happened when we remove the @ComponentId annotation from the class,
//this could happen when we remove the @ComponentId annotation from the class,
//the file descriptor generated by annotation processor might still have this class entry,
//for instance when working with IDE and incremental compilation (without clean)
logger.warn("Ignoring component [{}] as it does not have the @ComponentId annotation", clz.getName)
Expand All @@ -353,6 +361,55 @@ private final class Sdk(
HttpEndpointDescriptorFactory(httpEndpointClass, httpEndpointFactory(httpEndpointClass))
}

private val workflowDescriptors: Seq[WorkflowDescriptor] = componentClasses
.filter(hasComponentId)
.filter(Reflect.isWorkflow)
.map {
case clz if Reflect.isWorkflow(clz) =>
val componentId = clz.getAnnotation(classOf[ComponentId]).value
new WorkflowDescriptor(
componentId,
id => workflowInstanceFactory(id, clz.asInstanceOf[Class[Workflow[Nothing]]]))
}

private def workflowInstanceFactory[S, W <: Workflow[S]](workflowId: String, clz: Class[W]): SpiWorkflow = {
logger.debug(s"Registering Workflow [${clz.getName}]")
new WorkflowImpl[S, W](
workflowId,
clz,
messageCodec,
timerClient = runtimeComponentClients.timerClient,
sdkExecutionContext,
sdkTracerFactory,
{ context =>

val workflow = wiredInstance(clz) {
sideEffectingComponentInjects(None).orElse {
// remember to update component type API doc and docs if changing the set of injectables
case p if p == classOf[WorkflowContext] => context
}
}

// FIXME pull this inline setup stuff out of SdkRunner and into some workflow class
val workflowStateType: Class[_] = Reflect.workflowStateType[S, W](workflow)
messageCodec.registerTypeHints(workflowStateType)

workflow
.definition()
.getSteps
.asScala
.flatMap {
case asyncCallStep: Workflow.AsyncCallStep[_, _, _] =>
List(asyncCallStep.callInputClass, asyncCallStep.transitionInputClass)
case callStep: Workflow.CallStep[_, _, _, _] =>
List(callStep.callInputClass, callStep.transitionInputClass)
}
.foreach(messageCodec.registerTypeHints)

workflow
})
}

private val eventSourcedEntityDescriptors =
componentClasses
.filter(hasComponentId)
Expand Down Expand Up @@ -412,7 +469,6 @@ private final class Sdk(
var eventSourcedEntitiesEndpoint: Option[EventSourcedEntities] = None
var valueEntitiesEndpoint: Option[ValueEntities] = None
var viewsEndpoint: Option[Views] = None
var workflowEntitiesEndpoint: Option[WorkflowEntities] = None

val classicSystem = system.classicSystem

Expand Down Expand Up @@ -453,15 +509,9 @@ private final class Sdk(
valueEntitiesEndpoint = Some(
new KeyValueEntitiesImpl(classicSystem, entityServices, sdkSettings, sdkDispatcherName, sdkTracerFactory))

case (serviceClass, workflowServices: Map[String, WorkflowService[_, _]] @unchecked)
case (serviceClass, _: Map[String, WorkflowService[_, _]] @unchecked)
if serviceClass == classOf[WorkflowService[_, _]] =>
workflowEntitiesEndpoint = Some(
new WorkflowImpl(
workflowServices,
runtimeComponentClients.timerClient,
sdkExecutionContext,
sdkDispatcherName,
sdkTracerFactory))
//ignore

case (serviceClass, _: Map[String, TimedActionService[_]] @unchecked)
if serviceClass == classOf[TimedActionService[_]] =>
Expand Down Expand Up @@ -535,12 +585,15 @@ private final class Sdk(
Sdk.this.eventSourcedEntityDescriptors
override def valueEntities: Option[ValueEntities] = valueEntitiesEndpoint
override def views: Option[Views] = viewsEndpoint
override def workflowEntities: Option[WorkflowEntities] = workflowEntitiesEndpoint
override def workflowEntities: Option[WorkflowEntities] = None
override def httpEndpointDescriptors: Seq[HttpEndpointDescriptor] =
Sdk.this.httpEndpointDescriptors

override def timedActionsDescriptors: Seq[TimedActionDescriptor] =
Sdk.this.timedActionDescriptors

override def workflowDescriptors: Seq[WorkflowDescriptor] = Sdk.this.workflowDescriptors

}
}

Expand All @@ -550,39 +603,6 @@ private final class Sdk(
private def consumerService[A <: Consumer](clz: Class[A]): ConsumerService[A] =
new ConsumerService[A](clz, messageCodec, () => wiredInstance(clz)(sideEffectingComponentInjects(None)))

private def workflowService[S, W <: Workflow[S]](clz: Class[W]): WorkflowService[S, W] = {
new WorkflowService[S, W](
clz,
messageCodec,
{ context =>

val workflow = wiredInstance(clz) {
sideEffectingComponentInjects(None).orElse {
// remember to update component type API doc and docs if changing the set of injectables
case p if p == classOf[WorkflowContext] => context
}
}

// FIXME pull this inline setup stuff out of SdkRunner and into some workflow class
val workflowStateType: Class[S] = Reflect.workflowStateType(workflow)
messageCodec.registerTypeHints(workflowStateType)

workflow
.definition()
.getSteps
.asScala
.flatMap {
case asyncCallStep: Workflow.AsyncCallStep[_, _, _] =>
List(asyncCallStep.callInputClass, asyncCallStep.transitionInputClass)
case callStep: Workflow.CallStep[_, _, _, _] =>
List(callStep.callInputClass, callStep.transitionInputClass)
}
.foreach(messageCodec.registerTypeHints)

workflow
})
}

private def eventSourcedEntityService[S, E, ES <: EventSourcedEntity[S, E]](
clz: Class[ES]): EventSourcedEntityService[S, E, ES] =
EventSourcedEntityService(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import com.google.protobuf.any.{ Any => ScalaPbAny }
*/
@InternalApi
class ReflectiveWorkflowRouter[S, W <: Workflow[S]](
override protected val workflow: W,
override val workflow: W,
commandHandlers: Map[String, CommandHandler])
extends WorkflowRouter[S, W](workflow) {

Expand All @@ -38,7 +38,7 @@ class ReflectiveWorkflowRouter[S, W <: Workflow[S]](

val scalaPbAnyCommand = command.asInstanceOf[ScalaPbAny]
if (AnySupport.isJson(scalaPbAnyCommand)) {
// special cased component client calls, lets json commands trough all the way
// special cased component client calls, lets json commands through all the way
val methodInvoker = commandHandler.getSingleNameInvoker()
val deserializedCommand =
CommandSerialization.deserializeComponentClientCommand(methodInvoker.method, scalaPbAnyCommand)
Expand Down
Loading

0 comments on commit ff3feba

Please sign in to comment.