Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,6 @@ abstract class TaskHandler {
record.env = task.getEnvironmentStr()
record.executorName = task.processor.executor.getName()
record.containerMeta = task.containerMeta()
record.accelerator = task.config.getAccelerator()?.request
record.accelerator_type = task.config.getAccelerator()?.type

if( isCompleted() ) {
record.error_action = task.errorAction?.toString()
Expand Down
13 changes: 10 additions & 3 deletions modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ class TraceRecord implements Serializable {
vol_ctxt: 'num', // -- /proc/$pid/status field 'voluntary_ctxt_switches'
inv_ctxt: 'num', // -- /proc/$pid/status field 'nonvoluntary_ctxt_switches'
hostname: 'str',
cpu_model: 'str',
accelerator: 'num',
accelerator_type: 'str'
Comment on lines -107 to -108
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These fields denoted the requested accelerators. Now that you are moving them into the allocated resources, will they denote the accelerators allocated by sched?

cpu_model: 'str'
]

static public Map<String,Closure<String>> FORMATTER = [
Expand All @@ -125,6 +123,7 @@ class TraceRecord implements Serializable {
transient private ContainerMeta containerMeta
transient private Integer numSpotInterruptions
transient private String logStreamId
transient private Map<String,Object> resourceAllocation
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not blocking, but the naming is a bit awkward. how about resourcesAllocated ?


/**
* Convert the given value to a string
Expand Down Expand Up @@ -638,4 +637,12 @@ class TraceRecord implements Serializable {
void setContainerMeta(ContainerMeta meta) {
this.containerMeta = meta
}

Map<String,Object> getResourceAllocation() {
return resourceAllocation
}

void setResourceAllocation(Map<String,Object> value) {
this.resourceAllocation = value
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ class TaskHandlerTest extends Specification {
cpus: 2,
time: '1 hour',
disk: '100 GB',
memory: '4 GB',
accelerator: [request: 3, type: 'v100']
memory: '4 GB'
]
def task = new TaskRun(id: new TaskId(100), workDir: folder, name:'task1', exitStatus: 127, config: config )
task.metaClass.getHashLog = { "5d5d7ds" }
Expand Down Expand Up @@ -101,8 +100,6 @@ class TaskHandlerTest extends Specification {
trace.memory == MemoryUnit.of('4 GB').toBytes()
trace.disk == MemoryUnit.of('100 GB').toBytes()
trace.env == 'FOO=hola\nBAR=mundo\nAWS_SECRET=[secure]\n'
trace.accelerator == 3
trace.accelerator_type == 'v100'

// check get method
trace.getFmtStr('%cpu') == '1.0%'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,6 @@ class TraceRecordTest extends Specification {
record.cpus = 4
record.time = 3_600_000L
record.memory = 1024L * 1024L * 1024L * 8L
record.accelerator = 3
record.accelerator_type = 'v100'

when:
def json = new JsonSlurper().parseText(record.renderJson().toString())
Expand All @@ -263,8 +261,6 @@ class TraceRecordTest extends Specification {
json.cpus == '4'
json.time == '1h'
json.memory == '8 GB'
json.accelerator == '3'
json.accelerator_type == 'v100'

}

Expand Down
2 changes: 1 addition & 1 deletion plugins/nf-seqera/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ dependencies {
compileOnly project(':nextflow')
compileOnly 'org.slf4j:slf4j-api:2.0.17'
compileOnly 'org.pf4j:pf4j:3.14.1'
api 'io.seqera:sched-client:0.41.0-SNAPSHOT'
api 'io.seqera:sched-client:0.46.0-SNAPSHOT'

testImplementation(testFixtures(project(":nextflow")))
testImplementation "org.apache.groovy:groovy:4.0.30"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,45 @@ class SeqeraTaskHandler extends TaskHandler implements FusionAwareTask {
return cachedTaskState?.getId()
}

/**
* Get the allocated resources for this task from the last task attempt.
* Falls back to the resource requirement from the task state if no attempts exist.
*
* @return a map of allocated resource fields, or null if not available
*/
protected Map<String,Object> getResourceAllocation() {
if (!cachedTaskState)
return null

def resources = null
final attempts = cachedTaskState.getAttempts()
if (attempts && !attempts.isEmpty()) {
resources = attempts.get(attempts.size() - 1).getResources()
}
if (!resources) {
resources = cachedTaskState.getResourceAllocation()
}
if (!resources)
return null

final result = new LinkedHashMap<String,Object>()
if (resources.getCpuShares() != null)
result.put('cpuShares', resources.getCpuShares())
if (resources.getMemoryMiB() != null)
result.put('memoryMiB', resources.getMemoryMiB())
if (resources.getAcceleratorCount() != null)
result.put('acceleratorCount', resources.getAcceleratorCount())
if (resources.getAcceleratorType() != null)
result.put('acceleratorType', resources.getAcceleratorType().toString())
if (resources.getAcceleratorName() != null)
result.put('acceleratorName', resources.getAcceleratorName())
if (resources.getTime() != null)
result.put('time', resources.getTime())
return result.isEmpty() ? null : result
}

protected Long getGrantedTime() {
final time = cachedTaskState?.getResourceRequirement()?.getTime()
final time = cachedTaskState?.getResourceAllocation()?.getTime()
return time != null ? Duration.of(time).toMillis() : task.config.getTime()?.toMillis()
}

Expand All @@ -375,6 +412,7 @@ class SeqeraTaskHandler extends TaskHandler implements FusionAwareTask {
result.machineInfo = getMachineInfo()
result.numSpotInterruptions = getNumSpotInterruptions()
result.logStreamId = getLogStreamId()
result.resourceAllocation = getResourceAllocation()
// Override executor name to include cloud backend for cost tracking
result.executorName = "${SeqeraExecutor.SEQERA}/aws"
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,13 @@ class SeqeraTaskHandlerTest extends Specification {
.type('m5.large')
.zone('us-east-1a')
.priceModel(SchedPriceModel.SPOT)
def resources = new ResourceRequirement().cpuShares(2048).memoryMiB(4096)
def attempt = new TaskAttempt()
.index(1)
.nativeId('arn:aws:ecs:us-east-1:123:task/abc')
.status(SchedTaskStatus.SUCCEEDED)
.machineInfo(machineInfo)
.resources(resources)
handler.cachedTaskState = new SchedTaskState()
.id('tsk-xyz789')
.attempts([attempt])
Expand All @@ -252,6 +254,7 @@ class SeqeraTaskHandlerTest extends Specification {
trace.getMachineInfo().type == 'm5.large'
trace.getNumSpotInterruptions() == 2
trace.getLogStreamId() == 'log-stream-xyz'
trace.getResourceAllocation() == [cpuShares: 2048, memoryMiB: 4096]
trace.getExecutorName() == 'seqera/aws'
}

Expand Down Expand Up @@ -579,11 +582,92 @@ class SeqeraTaskHandlerTest extends Specification {
result == [MY_VAR: 'my_val', SHARED_KEY: 'fusion_val']
}

def 'should return null for getResourceAllocation when cachedTaskState is null'() {
given:
def handler = createHandler()

expect:
handler.getResourceAllocation() == null
}

def 'should return allocated resources from last task attempt'() {
given:
def handler = createHandler()
def resources = new ResourceRequirement()
.cpuShares(2048)
.memoryMiB(4096)
def attempt = new TaskAttempt()
.index(1)
.nativeId('task-1')
.status(SchedTaskStatus.SUCCEEDED)
.resources(resources)
handler.cachedTaskState = new SchedTaskState().attempts([attempt])

when:
def result = handler.getResourceAllocation()

then:
result != null
result.cpuShares == 2048
result.memoryMiB == 4096
}

def 'should use last attempt resources when multiple attempts exist'() {
given:
def handler = createHandler()
def resources1 = new ResourceRequirement().cpuShares(1024).memoryMiB(2048)
def resources2 = new ResourceRequirement().cpuShares(4096).memoryMiB(8192)
def attempt1 = new TaskAttempt().index(1).nativeId('task-1').status(SchedTaskStatus.FAILED).resources(resources1)
def attempt2 = new TaskAttempt().index(2).nativeId('task-2').status(SchedTaskStatus.SUCCEEDED).resources(resources2)
handler.cachedTaskState = new SchedTaskState().attempts([attempt1, attempt2])

when:
def result = handler.getResourceAllocation()

then:
result.cpuShares == 4096
result.memoryMiB == 8192
}

def 'should fallback to taskState resourceRequirement when no attempts'() {
given:
def handler = createHandler()
def resources = new ResourceRequirement().cpuShares(1024).memoryMiB(2048).time('1h')
handler.cachedTaskState = new SchedTaskState()
.attempts([])
.resourceAllocation(resources)

when:
def result = handler.getResourceAllocation()

then:
result.cpuShares == 1024
result.memoryMiB == 2048
result.time == '1h'
}

def 'should fallback to taskState resourceRequirement when last attempt has no resources'() {
given:
def handler = createHandler()
def attempt = new TaskAttempt().index(1).nativeId('task-1').status(SchedTaskStatus.SUCCEEDED)
def resources = new ResourceRequirement().cpuShares(512).memoryMiB(1024)
handler.cachedTaskState = new SchedTaskState()
.attempts([attempt])
.resourceAllocation(resources)

when:
def result = handler.getResourceAllocation()

then:
result.cpuShares == 512
result.memoryMiB == 1024
}

def 'should return granted time from resource requirement'() {
given:
def handler = createHandler()
handler.cachedTaskState = new SchedTaskState()
.resourceRequirement(new ResourceRequirement().time('2h'))
.resourceAllocation(new ResourceRequirement().time('2h'))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should ResourceRequirement be renamed to something like ResourceRequest ? Seems like here it is being used to model things other than hard requirements


expect:
handler.getGrantedTime() == Duration.of('2h').toMillis()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ class TowerClient implements TraceObserverV2 {
record.priceModel = trace.getMachineInfo()?.priceModel?.toString()
record.numSpotInterruptions = trace.getNumSpotInterruptions()
record.logStreamId = trace.getLogStreamId()
record.resourceAllocation = trace.getResourceAllocation()

return record
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ class TowerClientTest extends Specification {
req.tasks[0].logStreamId == 'arn:aws:logs:us-east-1:123456789:log-group:/ecs/task:log-stream:abc123'
}

def 'should include accelerator request in task map'() {
def 'should include resourceAllocation in task map'() {
given:
def client = Spy(new TowerClient())
client.getWorkflowProgress(true) >> new WorkflowProgress()
Expand All @@ -681,18 +681,16 @@ class TowerClientTest extends Specification {
cpus: 1,
submit: now-2000,
start: now-1000,
complete: now,
accelerator: 2,
acceleratorType: 'v100'
complete: now
])
trace.setResourceAllocation([cpuShares: 2048, memoryMiB: 4096, time: '1h'])

when:
def req = client.makeTasksReq([trace])

then:
req.tasks.size() == 1
req.tasks[0].accelerator == 2
req.tasks[0].acceleratorType == 'v100'
req.tasks[0].resourceAllocation == [cpuShares: 2048, memoryMiB: 4096, time: '1h']
}

def 'should return error response on http request timeout' () {
Expand Down
Loading