Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[FLINK-33532][network] Move the serialization of ShuffleDescr… #24441

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import java.io.IOException;
import java.io.InputStream;
import java.util.Optional;

/** BlobWriter is used to upload data to the BLOB store. */
public interface BlobWriter {
Expand Down Expand Up @@ -103,24 +102,22 @@ static <T> Either<SerializedValue<T>, PermanentBlobKey> tryOffload(
if (serializedValue.getByteArray().length < blobWriter.getMinOffloadingSize()) {
return Either.Left(serializedValue);
} else {
return offloadWithException(serializedValue, jobId, blobWriter)
.map(Either::<SerializedValue<T>, PermanentBlobKey>Right)
.orElse(Either.Left(serializedValue));
return offloadWithException(serializedValue, jobId, blobWriter);
}
}

static <T> Optional<PermanentBlobKey> offloadWithException(
static <T> Either<SerializedValue<T>, PermanentBlobKey> offloadWithException(
SerializedValue<T> serializedValue, JobID jobId, BlobWriter blobWriter) {
Preconditions.checkNotNull(serializedValue);
Preconditions.checkNotNull(jobId);
Preconditions.checkNotNull(blobWriter);
try {
final PermanentBlobKey permanentBlobKey =
blobWriter.putPermanent(jobId, serializedValue.getByteArray());
return Optional.of(permanentBlobKey);
return Either.Right(permanentBlobKey);
} catch (IOException e) {
LOG.warn("Failed to offload value for job {} to BLOB store.", jobId, e);
return Optional.empty();
return Either.Left(serializedValue);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void serializeShuffleDescriptors(
new ShuffleDescriptorGroup(
toBeSerialized.toArray(new ShuffleDescriptorAndIndex[0]));
MaybeOffloaded<ShuffleDescriptorGroup> serializedShuffleDescriptorGroup =
shuffleDescriptorSerializer.trySerializeAndOffloadShuffleDescriptor(
shuffleDescriptorSerializer.serializeAndTryOffloadShuffleDescriptor(
shuffleDescriptorGroup, numConsumers);

toBeSerialized.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.apache.flink.runtime.blob.PermanentBlobKey;
import org.apache.flink.runtime.blob.PermanentBlobService;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup;
Expand Down Expand Up @@ -98,7 +98,9 @@ public InputGateDeploymentDescriptor(
new IndexRange(consumedSubpartitionIndex, consumedSubpartitionIndex),
inputChannels.length,
Collections.singletonList(
new NonOffloadedRaw<>(new ShuffleDescriptorGroup(inputChannels))));
new NonOffloaded<>(
CompressedSerializedValue.fromObject(
new ShuffleDescriptorGroup(inputChannels)))));
}

public InputGateDeploymentDescriptor(
Expand Down Expand Up @@ -145,14 +147,18 @@ public ShuffleDescriptor[] getShuffleDescriptors() {
// This is only for testing scenarios; in a production environment we always call
// tryLoadAndDeserializeShuffleDescriptors to deserialize ShuffleDescriptors first.
inputChannels = new ShuffleDescriptor[numberOfInputChannels];
for (MaybeOffloaded<ShuffleDescriptorGroup> rawShuffleDescriptors :
serializedInputChannels) {
checkState(
rawShuffleDescriptors instanceof NonOffloadedRaw,
"Trying to work with offloaded serialized shuffle descriptors.");
NonOffloadedRaw<ShuffleDescriptorGroup> nonOffloadedRawValue =
(NonOffloadedRaw<ShuffleDescriptorGroup>) rawShuffleDescriptors;
putOrReplaceShuffleDescriptors(nonOffloadedRawValue.value);
try {
for (MaybeOffloaded<ShuffleDescriptorGroup> serializedShuffleDescriptors :
serializedInputChannels) {
checkState(
serializedShuffleDescriptors instanceof NonOffloaded,
"Trying to work with offloaded serialized shuffle descriptors.");
NonOffloaded<ShuffleDescriptorGroup> nonOffloadedSerializedValue =
(NonOffloaded<ShuffleDescriptorGroup>) serializedShuffleDescriptors;
tryDeserializeShuffleDescriptorGroup(nonOffloadedSerializedValue);
}
} catch (ClassNotFoundException | IOException e) {
throw new RuntimeException("Could not deserialize shuffle descriptors.", e);
}
}
return inputChannels;
Expand Down Expand Up @@ -207,12 +213,21 @@ private void tryLoadAndDeserializeShuffleDescriptorGroup(
}
putOrReplaceShuffleDescriptors(shuffleDescriptorGroup);
} else {
NonOffloadedRaw<ShuffleDescriptorGroup> nonOffloadedSerializedValue =
(NonOffloadedRaw<ShuffleDescriptorGroup>) serializedShuffleDescriptors;
putOrReplaceShuffleDescriptors(nonOffloadedSerializedValue.value);
NonOffloaded<ShuffleDescriptorGroup> nonOffloadedSerializedValue =
(NonOffloaded<ShuffleDescriptorGroup>) serializedShuffleDescriptors;
tryDeserializeShuffleDescriptorGroup(nonOffloadedSerializedValue);
}
}

/**
 * Deserializes a non-offloaded shuffle descriptor group and hands the result to {@code
 * putOrReplaceShuffleDescriptors}.
 *
 * <p>The class loader of this object's class is used for deserialization.
 *
 * @param nonOffloadedShuffleDescriptorGroup wrapper holding the serialized group
 * @throws IOException if thrown while deserializing the value
 * @throws ClassNotFoundException if thrown while deserializing the value
 */
private void tryDeserializeShuffleDescriptorGroup(
        NonOffloaded<ShuffleDescriptorGroup> nonOffloadedShuffleDescriptorGroup)
        throws IOException, ClassNotFoundException {
    final ClassLoader classLoader = getClass().getClassLoader();
    putOrReplaceShuffleDescriptors(
            nonOffloadedShuffleDescriptorGroup.serializedValue.deserializeValue(classLoader));
}

private void putOrReplaceShuffleDescriptors(ShuffleDescriptorGroup shuffleDescriptorGroup) {
for (ShuffleDescriptorAndIndex shuffleDescriptorAndIndex :
shuffleDescriptorGroup.getShuffleDescriptors()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,6 @@ public NonOffloaded(SerializedValue<T> serializedValue) {
}
}

/**
 * Wrapper around a raw value that is held directly instead of being offloaded to the {@link
 * org.apache.flink.runtime.blob.BlobServer}.
 *
 * @param <T> type of the wrapped raw value
 */
public static class NonOffloadedRaw<T> extends MaybeOffloaded<T> {
    private static final long serialVersionUID = 1L;

    /** The wrapped raw value. */
    public T value;

    /** No-argument constructor; leaves {@link #value} unset. */
    @SuppressWarnings("unused")
    public NonOffloadedRaw() {}

    /**
     * Wraps the given raw value.
     *
     * @param value the raw value to hold; checked with {@code Preconditions.checkNotNull}
     */
    public NonOffloadedRaw(T value) {
        Preconditions.checkNotNull(value);
        this.value = value;
    }
}

/**
* Reference to a serialized value that was offloaded to the {@link
* org.apache.flink.runtime.blob.BlobServer}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,11 @@
import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor;
import org.apache.flink.types.Either;
import org.apache.flink.util.CompressedSerializedValue;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.SerializedValue;

import javax.annotation.Nullable;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -452,7 +449,7 @@ public int getIndex() {
public static class ShuffleDescriptorGroup implements Serializable {
private static final long serialVersionUID = 1L;

private ShuffleDescriptorAndIndex[] shuffleDescriptors;
private final ShuffleDescriptorAndIndex[] shuffleDescriptors;

public ShuffleDescriptorGroup(ShuffleDescriptorAndIndex[] shuffleDescriptors) {
this.shuffleDescriptors = checkNotNull(shuffleDescriptors);
Expand All @@ -461,31 +458,19 @@ public ShuffleDescriptorGroup(ShuffleDescriptorAndIndex[] shuffleDescriptors) {
public ShuffleDescriptorAndIndex[] getShuffleDescriptors() {
return shuffleDescriptors;
}

private void writeObject(ObjectOutputStream oos) throws IOException {
byte[] bytes = InstantiationUtil.serializeObjectAndCompress(shuffleDescriptors);
oos.writeObject(bytes);
}

private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException {
byte[] bytes = (byte[]) ois.readObject();
shuffleDescriptors =
InstantiationUtil.decompressAndDeserializeObject(
bytes, ClassLoader.getSystemClassLoader());
}
}

/** Offload shuffle descriptors. */
/** Serialize shuffle descriptors. */
interface ShuffleDescriptorSerializer {
/**
* Try to serialize and offload shuffle descriptors.
* Serialize and try to offload shuffle descriptors.
*
* @param shuffleDescriptorGroup to serialize and offload
* @param shuffleDescriptorGroup to serialize
* @param numConsumer number of consumers of these shuffle descriptors, i.e. how many times
* the serialized shuffle descriptors should be sent
* @return offloaded serialized or non-offloaded raw shuffle descriptors
* @return offloaded or non-offloaded serialized shuffle descriptors
*/
MaybeOffloaded<ShuffleDescriptorGroup> trySerializeAndOffloadShuffleDescriptor(
MaybeOffloaded<ShuffleDescriptorGroup> serializeAndTryOffloadShuffleDescriptor(
ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException;
}

Expand All @@ -502,24 +487,25 @@ public DefaultShuffleDescriptorSerializer(
}

@Override
public MaybeOffloaded<ShuffleDescriptorGroup> trySerializeAndOffloadShuffleDescriptor(
public MaybeOffloaded<ShuffleDescriptorGroup> serializeAndTryOffloadShuffleDescriptor(
ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException {

final Either<ShuffleDescriptorGroup, PermanentBlobKey> rawValueOrBlobKey =
shouldOffload(shuffleDescriptorGroup.getShuffleDescriptors(), numConsumer)
? BlobWriter.offloadWithException(
CompressedSerializedValue.fromObject(
shuffleDescriptorGroup),
jobID,
blobWriter)
.map(Either::<ShuffleDescriptorGroup, PermanentBlobKey>Right)
.orElse(Either.Left(shuffleDescriptorGroup))
: Either.Left(shuffleDescriptorGroup);

if (rawValueOrBlobKey.isLeft()) {
return new TaskDeploymentDescriptor.NonOffloadedRaw<>(rawValueOrBlobKey.left());
final CompressedSerializedValue<ShuffleDescriptorGroup> compressedSerializedValue =
CompressedSerializedValue.fromObject(shuffleDescriptorGroup);

final Either<SerializedValue<ShuffleDescriptorGroup>, PermanentBlobKey>
serializedValueOrBlobKey =
shouldOffload(
shuffleDescriptorGroup.getShuffleDescriptors(),
numConsumer)
? BlobWriter.offloadWithException(
compressedSerializedValue, jobID, blobWriter)
: Either.Left(compressedSerializedValue);

if (serializedValueOrBlobKey.isLeft()) {
return new TaskDeploymentDescriptor.NonOffloaded<>(serializedValueOrBlobKey.left());
} else {
return new TaskDeploymentDescriptor.Offloaded<>(rawValueOrBlobKey.right());
return new TaskDeploymentDescriptor.Offloaded<>(serializedValueOrBlobKey.right());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
Expand All @@ -39,10 +39,12 @@
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.CompressedSerializedValue;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -88,7 +90,7 @@ void testCreateAndGet() throws Exception {
assertThat(cachedShuffleDescriptors.getAllSerializedShuffleDescriptorGroups()).hasSize(1);
MaybeOffloaded<ShuffleDescriptorGroup> maybeOffloadedShuffleDescriptor =
cachedShuffleDescriptors.getAllSerializedShuffleDescriptorGroups().get(0);
assertNonOffloadedRawShuffleDescriptorAndIndexEquals(
assertNonOffloadedShuffleDescriptorAndIndexEquals(
maybeOffloadedShuffleDescriptor,
Collections.singletonList(shuffleDescriptor),
Collections.singletonList(0));
Expand Down Expand Up @@ -142,22 +144,26 @@ void testMarkPartitionFinishAndSerialize() throws Exception {
intermediateResultPartition2,
TaskDeploymentDescriptorFactory.PartitionLocationConstraint.MUST_BE_KNOWN,
false);
assertNonOffloadedRawShuffleDescriptorAndIndexEquals(
assertNonOffloadedShuffleDescriptorAndIndexEquals(
maybeOffloaded,
Arrays.asList(expectedShuffleDescriptor1, expectedShuffleDescriptor2),
Arrays.asList(0, 1));
}

private void assertNonOffloadedRawShuffleDescriptorAndIndexEquals(
private void assertNonOffloadedShuffleDescriptorAndIndexEquals(
MaybeOffloaded<ShuffleDescriptorGroup> maybeOffloaded,
List<ShuffleDescriptor> expectedDescriptors,
List<Integer> expectedIndices) {
List<Integer> expectedIndices)
throws Exception {
assertThat(expectedDescriptors).hasSameSizeAs(expectedIndices);
assertThat(maybeOffloaded).isInstanceOf(NonOffloadedRaw.class);
NonOffloadedRaw<ShuffleDescriptorGroup> nonOffloadedRaw =
(NonOffloadedRaw<ShuffleDescriptorGroup>) maybeOffloaded;
assertThat(maybeOffloaded).isInstanceOf(NonOffloaded.class);
NonOffloaded<ShuffleDescriptorGroup> nonOffloaded =
(NonOffloaded<ShuffleDescriptorGroup>) maybeOffloaded;
ShuffleDescriptorAndIndex[] shuffleDescriptorAndIndices =
nonOffloadedRaw.value.getShuffleDescriptors();
nonOffloaded
.serializedValue
.deserializeValue(getClass().getClassLoader())
.getShuffleDescriptors();
assertThat(shuffleDescriptorAndIndices).hasSameSizeAs(expectedDescriptors);
for (int i = 0; i < shuffleDescriptorAndIndices.length; i++) {
assertThat(shuffleDescriptorAndIndices[i].getIndex()).isEqualTo(expectedIndices.get(i));
Expand Down Expand Up @@ -212,9 +218,9 @@ private static class TestingShuffleDescriptorSerializer
implements TaskDeploymentDescriptorFactory.ShuffleDescriptorSerializer {

@Override
public MaybeOffloaded<ShuffleDescriptorGroup> trySerializeAndOffloadShuffleDescriptor(
ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) {
return new NonOffloadedRaw<>(shuffleDescriptorGroup);
public MaybeOffloaded<ShuffleDescriptorGroup> serializeAndTryOffloadShuffleDescriptor(
ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException {
return new NonOffloaded<>(CompressedSerializedValue.fromObject(shuffleDescriptorGroup));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.blob.TestingBlobWriter;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup;
Expand All @@ -47,8 +47,11 @@ public static ShuffleDescriptor[] deserializeShuffleDescriptors(
int maxIndex = 0;
for (MaybeOffloaded<ShuffleDescriptorGroup> sd : maybeOffloaded) {
ShuffleDescriptorGroup shuffleDescriptorGroup;
if (sd instanceof NonOffloadedRaw) {
shuffleDescriptorGroup = ((NonOffloadedRaw<ShuffleDescriptorGroup>) sd).value;
if (sd instanceof NonOffloaded) {
shuffleDescriptorGroup =
((NonOffloaded<ShuffleDescriptorGroup>) sd)
.serializedValue.deserializeValue(
ClassLoader.getSystemClassLoader());

} else {
final CompressedSerializedValue<ShuffleDescriptorGroup> compressedSerializedValue =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor;
import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor;
import org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder;
import org.apache.flink.util.CompressedSerializedValue;

import org.apache.flink.shaded.guava31.com.google.common.io.Closer;

Expand Down Expand Up @@ -1312,8 +1313,9 @@ partitionIds[2], createExecutionAttemptId())),
subpartitionIndexRange,
channelDescs.length,
Collections.singletonList(
new TaskDeploymentDescriptor.NonOffloadedRaw<>(
new ShuffleDescriptorGroup(channelDescs))));
new TaskDeploymentDescriptor.NonOffloaded<>(
CompressedSerializedValue.fromObject(
new ShuffleDescriptorGroup(channelDescs)))));

final TaskMetricGroup taskMetricGroup =
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup();
Expand Down