diff --git a/bson/src/main/org/bson/AbstractBsonWriter.java b/bson/src/main/org/bson/AbstractBsonWriter.java index a7cc978f8ba..8a8b238e8af 100644 --- a/bson/src/main/org/bson/AbstractBsonWriter.java +++ b/bson/src/main/org/bson/AbstractBsonWriter.java @@ -748,6 +748,15 @@ protected void throwInvalidState(final String methodName, final State... validSt methodName, validStatesString, state)); } + /** + * {@inheritDoc} + *

+ * The {@link #flush()} method of {@link AbstractBsonWriter} does nothing.

+ */ + @Override + public void flush() { + } + @Override public void close() { closed = true; diff --git a/bson/src/main/org/bson/BSONCallbackAdapter.java b/bson/src/main/org/bson/BSONCallbackAdapter.java index d00a2eaecbf..1d8b5ffe746 100644 --- a/bson/src/main/org/bson/BSONCallbackAdapter.java +++ b/bson/src/main/org/bson/BSONCallbackAdapter.java @@ -39,11 +39,6 @@ protected BSONCallbackAdapter(final BsonWriterSettings settings, final BSONCallb this.bsonCallback = bsonCallback; } - @Override - public void flush() { - //Looks like should be no-op? - } - @Override public void doWriteStartDocument() { BsonContextType contextType = getState() == State.SCOPE_DOCUMENT diff --git a/bson/src/main/org/bson/BsonBinaryWriter.java b/bson/src/main/org/bson/BsonBinaryWriter.java index d9301fd5cb3..e6255ea8478 100644 --- a/bson/src/main/org/bson/BsonBinaryWriter.java +++ b/bson/src/main/org/bson/BsonBinaryWriter.java @@ -108,10 +108,6 @@ public BsonBinaryWriterSettings getBinaryWriterSettings() { return binaryWriterSettings; } - @Override - public void flush() { - } - @Override protected Context getContext() { return (Context) super.getContext(); diff --git a/bson/src/main/org/bson/BsonDocumentWriter.java b/bson/src/main/org/bson/BsonDocumentWriter.java index 7c36a368336..a34188645cd 100644 --- a/bson/src/main/org/bson/BsonDocumentWriter.java +++ b/bson/src/main/org/bson/BsonDocumentWriter.java @@ -194,10 +194,6 @@ public void doWriteUndefined() { write(new BsonUndefined()); } - @Override - public void flush() { - } - @Override protected Context getContext() { return (Context) super.getContext(); diff --git a/bson/src/main/org/bson/io/OutputBuffer.java b/bson/src/main/org/bson/io/OutputBuffer.java index 8793acad9a2..00f88cea706 100644 --- a/bson/src/main/org/bson/io/OutputBuffer.java +++ b/bson/src/main/org/bson/io/OutputBuffer.java @@ -41,6 +41,16 @@ public void write(final byte[] b) { public void close() { } + /** + * {@inheritDoc} + *

+ * The {@link #flush()} method of {@link OutputBuffer} does nothing.

+ */ + @Override + public void flush() throws IOException { + super.flush(); + } + @Override public void write(final byte[] bytes, final int offset, final int length) { writeBytes(bytes, offset, length); diff --git a/driver-core/src/main/com/mongodb/internal/connection/AsyncConnection.java b/driver-core/src/main/com/mongodb/internal/connection/AsyncConnection.java index 2891bc28732..befc0d9aac2 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AsyncConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AsyncConnection.java @@ -50,8 +50,7 @@ void commandAsync(String database, BsonDocument command, FieldNameValidator void commandAsync(String database, BsonDocument command, FieldNameValidator commandFieldNameValidator, @Nullable ReadPreference readPreference, Decoder commandResultDecoder, - OperationContext operationContext, boolean responseExpected, @Nullable SplittablePayload payload, - @Nullable FieldNameValidator payloadFieldNameValidator, SingleResultCallback callback); + OperationContext operationContext, boolean responseExpected, MessageSequences sequences, SingleResultCallback callback); void markAsPinned(Connection.PinningMode pinningMode); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/BsonWriterHelper.java b/driver-core/src/main/com/mongodb/internal/connection/BsonWriterHelper.java index 12c006c7b97..63ccbf62a04 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/BsonWriterHelper.java +++ b/driver-core/src/main/com/mongodb/internal/connection/BsonWriterHelper.java @@ -16,31 +16,56 @@ package com.mongodb.internal.connection; +import com.mongodb.internal.connection.DualMessageSequences.EncodeDocumentsResult; +import com.mongodb.internal.connection.DualMessageSequences.WritersProviderAndLimitsChecker; +import com.mongodb.lang.Nullable; +import org.bson.BsonBinaryWriter; +import org.bson.BsonBinaryWriterSettings; +import org.bson.BsonContextType; import org.bson.BsonDocument; import org.bson.BsonElement; import org.bson.BsonMaximumSizeExceededException; import org.bson.BsonValue; import org.bson.BsonWriter; +import org.bson.BsonWriterSettings; +import org.bson.FieldNameValidator; import org.bson.codecs.BsonValueCodecProvider; -import org.bson.codecs.Codec; +import org.bson.codecs.Encoder; import org.bson.codecs.EncoderContext; import org.bson.codecs.configuration.CodecRegistry; import org.bson.io.BsonOutput; import java.util.List; +import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.connection.DualMessageSequences.WritersProviderAndLimitsChecker.WriteResult.FAIL_LIMIT_EXCEEDED; +import static com.mongodb.internal.connection.DualMessageSequences.WritersProviderAndLimitsChecker.WriteResult.OK_LIMIT_NOT_REACHED; +import static com.mongodb.internal.connection.DualMessageSequences.WritersProviderAndLimitsChecker.WriteResult.OK_LIMIT_REACHED; +import static com.mongodb.internal.connection.MessageSettings.DOCUMENT_HEADROOM_SIZE; import static java.lang.String.format; import static org.bson.codecs.configuration.CodecRegistries.fromProviders; -final class BsonWriterHelper { - private static final int DOCUMENT_HEADROOM = 1024 * 16; +/** + * This class is not part of the public API and may be removed or changed at any time. + */ +public final class BsonWriterHelper { private static final CodecRegistry REGISTRY = fromProviders(new BsonValueCodecProvider()); private static final EncoderContext ENCODER_CONTEXT = EncoderContext.builder().build(); - static void writeElements(final BsonWriter writer, final List bsonElements) { - for (BsonElement bsonElement : bsonElements) { - writer.writeName(bsonElement.getName()); - getCodec(bsonElement.getValue()).encode(writer, bsonElement.getValue(), ENCODER_CONTEXT); + static void appendElementsToDocument( + final BsonOutput bsonOutputWithDocument, + final int documentStartPosition, + @Nullable final List bsonElements) { + if ((bsonElements == null) || bsonElements.isEmpty()) { + return; + } + try (AppendingBsonWriter writer = new AppendingBsonWriter(bsonOutputWithDocument, documentStartPosition)) { + for (BsonElement element : bsonElements) { + String name = element.getName(); + BsonValue value = element.getValue(); + writer.writeName(name); + encodeUsingRegistry(writer, value); + } } } @@ -65,16 +90,86 @@ static void writePayload(final BsonWriter writer, final BsonOutput bsonOutput, f } if (payload.getPosition() == 0) { - throw new BsonMaximumSizeExceededException(format("Payload document size is larger than maximum of %d.", - payloadSettings.getMaxDocumentSize())); + throw createBsonMaximumSizeExceededException(payloadSettings.getMaxDocumentSize()); } } + /** + * @return See {@link DualMessageSequences#encodeDocuments(WritersProviderAndLimitsChecker)}. + */ + static EncodeDocumentsResult writeDocumentsOfDualMessageSequences( + final DualMessageSequences dualMessageSequences, + final int commandDocumentSizeInBytes, + final BsonOutput firstOutput, + final BsonOutput secondOutput, + final MessageSettings messageSettings) { + BsonBinaryWriter firstWriter = createBsonBinaryWriter(firstOutput, dualMessageSequences.getFirstFieldNameValidator(), null); + BsonBinaryWriter secondWriter = createBsonBinaryWriter(secondOutput, dualMessageSequences.getSecondFieldNameValidator(), null); + // the size of operation-agnostic command fields (a.k.a. extra elements) is counted towards `messageOverheadInBytes` + int messageOverheadInBytes = 1000; + int maxSizeInBytes = messageSettings.getMaxMessageSize() - (messageOverheadInBytes + commandDocumentSizeInBytes); + int firstStart = firstOutput.getPosition(); + int secondStart = secondOutput.getPosition(); + int maxBatchCount = messageSettings.getMaxBatchCount(); + return dualMessageSequences.encodeDocuments(writeAction -> { + int firstBeforeWritePosition = firstOutput.getPosition(); + int secondBeforeWritePosition = secondOutput.getPosition(); + int batchCountAfterWrite = writeAction.doAndGetBatchCount(firstWriter, secondWriter); + assertTrue(batchCountAfterWrite <= maxBatchCount); + int writtenSizeInBytes = + firstOutput.getPosition() - firstStart + + secondOutput.getPosition() - secondStart; + if (writtenSizeInBytes < maxSizeInBytes && batchCountAfterWrite < maxBatchCount) { + return OK_LIMIT_NOT_REACHED; + } else if (writtenSizeInBytes > maxSizeInBytes) { + firstOutput.truncateToPosition(firstBeforeWritePosition); + secondOutput.truncateToPosition(secondBeforeWritePosition); + if (batchCountAfterWrite == 1) { + // we have failed to write a single document + throw createBsonMaximumSizeExceededException(messageSettings.getMaxDocumentSize()); + } + return FAIL_LIMIT_EXCEEDED; + } else { + return OK_LIMIT_REACHED; + } + }); + } + + /** + * @param messageSettings Non-{@code null} iff the document size limit must be validated. + */ + static BsonBinaryWriter createBsonBinaryWriter( + final BsonOutput out, + final FieldNameValidator validator, + @Nullable final MessageSettings messageSettings) { + return new BsonBinaryWriter( + new BsonWriterSettings(), + messageSettings == null + ? new BsonBinaryWriterSettings() + : new BsonBinaryWriterSettings(messageSettings.getMaxDocumentSize() + DOCUMENT_HEADROOM_SIZE), + out, + validator); + } + + /** + * Backpatches the document/message/sequence length into the beginning of the document/message/sequence. + * + * @param startPosition The start position of the document/message/sequence in {@code bsonOutput}. + */ + static void backpatchLength(final int startPosition, final BsonOutput bsonOutput) { + int messageLength = bsonOutput.getPosition() - startPosition; + bsonOutput.writeInt32(startPosition, messageLength); + } + + private static BsonMaximumSizeExceededException createBsonMaximumSizeExceededException(final int maxSize) { + return new BsonMaximumSizeExceededException(format("Payload document size is larger than maximum of %d.", maxSize)); + } + private static boolean writeDocument(final BsonWriter writer, final BsonOutput bsonOutput, final MessageSettings settings, final BsonDocument document, final int messageStartPosition, final int batchItemCount, final int maxSplittableDocumentSize) { int currentPosition = bsonOutput.getPosition(); - getCodec(document).encode(writer, document, ENCODER_CONTEXT); + encodeUsingRegistry(writer, document); int messageSize = bsonOutput.getPosition() - messageStartPosition; int documentSize = bsonOutput.getPosition() - currentPosition; if (exceedsLimits(settings, messageSize, documentSize, batchItemCount) @@ -85,16 +180,17 @@ private static boolean writeDocument(final BsonWriter writer, final BsonOutput b return true; } - @SuppressWarnings({"unchecked"}) - private static Codec getCodec(final BsonValue bsonValue) { - return (Codec) REGISTRY.get(bsonValue.getClass()); + static void encodeUsingRegistry(final BsonWriter writer, final BsonValue value) { + @SuppressWarnings("unchecked") + Encoder encoder = (Encoder) REGISTRY.get(value.getClass()); + encoder.encode(writer, value, ENCODER_CONTEXT); } private static MessageSettings getPayloadMessageSettings(final SplittablePayload.Type type, final MessageSettings settings) { MessageSettings payloadMessageSettings = settings; if (type != SplittablePayload.Type.INSERT) { payloadMessageSettings = createMessageSettingsBuilder(settings) - .maxDocumentSize(settings.getMaxDocumentSize() + DOCUMENT_HEADROOM) + .maxDocumentSize(settings.getMaxDocumentSize() + DOCUMENT_HEADROOM_SIZE) .build(); } return payloadMessageSettings; @@ -102,7 +198,7 @@ private static MessageSettings getPayloadMessageSettings(final SplittablePayload private static MessageSettings getDocumentMessageSettings(final MessageSettings settings) { return createMessageSettingsBuilder(settings) - .maxMessageSize(settings.getMaxDocumentSize() + DOCUMENT_HEADROOM) + .maxMessageSize(settings.getMaxDocumentSize() + DOCUMENT_HEADROOM_SIZE) .build(); } @@ -126,8 +222,50 @@ private static boolean exceedsLimits(final MessageSettings settings, final int m return false; } + /** + * A {@link BsonWriter} that allows appending key/value pairs to a document that has been fully written to a {@link BsonOutput}. + */ + private static final class AppendingBsonWriter extends LevelCountingBsonWriter implements AutoCloseable { + private static final int INITIAL_LEVEL = DEFAULT_INITIAL_LEVEL + 1; - private BsonWriterHelper() { + /** + * @param bsonOutputWithDocument A {@link BsonOutput} {@linkplain BsonOutput#getPosition() positioned} + * immediately after the end of the document. + * @param documentStartPosition The {@linkplain BsonOutput#getPosition() position} of the start of the document + * in {@code bsonOutputWithDocument}. + */ + AppendingBsonWriter(final BsonOutput bsonOutputWithDocument, final int documentStartPosition) { + super( + new InternalAppendingBsonBinaryWriter(bsonOutputWithDocument, documentStartPosition), + INITIAL_LEVEL); + } + + @Override + public void writeEndDocument() { + assertTrue(getCurrentLevel() > INITIAL_LEVEL); + super.writeEndDocument(); + } + + @Override + public void close() { + try (InternalAppendingBsonBinaryWriter writer = (InternalAppendingBsonBinaryWriter) getBsonWriter()) { + writer.writeEndDocument(); + } + } + + private static final class InternalAppendingBsonBinaryWriter extends BsonBinaryWriter { + InternalAppendingBsonBinaryWriter(final BsonOutput bsonOutputWithDocument, final int documentStartPosition) { + super(bsonOutputWithDocument); + int documentEndPosition = bsonOutputWithDocument.getPosition(); + int bsonDocumentEndingSize = 1; + int appendFromPosition = documentEndPosition - bsonDocumentEndingSize; + bsonOutputWithDocument.truncateToPosition(appendFromPosition); + setState(State.NAME); + setContext(new Context(null, BsonContextType.DOCUMENT, documentStartPosition)); + } + } } + private BsonWriterHelper() { + } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java index 5cd2000d879..40df1b867fd 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.List; +import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.notNull; /** @@ -52,6 +53,19 @@ public ByteBufferBsonOutput(final BufferProvider bufferProvider) { this.bufferProvider = notNull("bufferProvider", bufferProvider); } + /** + * Creates a new empty {@link ByteBufferBsonOutput.Branch}, + * which gets merged into this {@link ByteBufferBsonOutput} on {@link ByteBufferBsonOutput.Branch#close()} + * by appending its data without copying it. + * If multiple branches are created, they are merged in the order they are {@linkplain ByteBufferBsonOutput.Branch#close() closed}. + * {@linkplain #close() Closing} this {@link ByteBufferBsonOutput} does not {@linkplain ByteBufferBsonOutput.Branch#close() close} the branch. + * + * @return A new {@link ByteBufferBsonOutput.Branch}. + */ + public ByteBufferBsonOutput.Branch branch() { + return new ByteBufferBsonOutput.Branch(this); + } + @Override public void writeBytes(final byte[] bytes, final int offset, final int length) { ensureOpen(); @@ -156,7 +170,9 @@ public int pipe(final OutputStream out) throws IOException { @Override public void truncateToPosition(final int newPosition) { ensureOpen(); - + if (newPosition == position) { + return; + } if (newPosition > position || newPosition < 0) { throw new IllegalArgumentException(); } @@ -174,36 +190,89 @@ public void truncateToPosition(final int newPosition) { position = newPosition; } + /** + * The {@link #flush()} method of {@link ByteBufferBsonOutput} and of its subclasses does nothing.

+ */ + @Override + public final void flush() throws IOException { + } + + /** + * {@inheritDoc} + *

+ * Idempotent.

+ */ @Override public void close() { - for (final ByteBuf cur : bufferList) { - cur.release(); + if (isOpen()) { + for (final ByteBuf cur : bufferList) { + cur.release(); + } + bufferList.clear(); + closed = true; } - bufferList.clear(); - closed = true; } private BufferPositionPair getBufferPositionPair(final int absolutePosition) { int positionInBuffer = absolutePosition; int bufferIndex = 0; - int bufferSize = INITIAL_BUFFER_SIZE; + int bufferSize = bufferList.get(bufferIndex).position(); int startPositionOfBuffer = 0; while (startPositionOfBuffer + bufferSize <= absolutePosition) { bufferIndex++; startPositionOfBuffer += bufferSize; positionInBuffer -= bufferSize; - bufferSize = bufferList.get(bufferIndex).limit(); + bufferSize = bufferList.get(bufferIndex).position(); } return new BufferPositionPair(bufferIndex, positionInBuffer); } private void ensureOpen() { - if (closed) { + if (!isOpen()) { throw new IllegalStateException("The output is closed"); } } + boolean isOpen() { + return !closed; + } + + /** + * @see #branch() + */ + private void merge(final ByteBufferBsonOutput branch) { + assertTrue(branch instanceof ByteBufferBsonOutput.Branch); + branch.bufferList.forEach(ByteBuf::retain); + bufferList.addAll(branch.bufferList); + curBufferIndex += branch.curBufferIndex + 1; + position += branch.position; + } + + public static final class Branch extends ByteBufferBsonOutput { + private final ByteBufferBsonOutput parent; + + private Branch(final ByteBufferBsonOutput parent) { + super(parent.bufferProvider); + this.parent = parent; + } + + /** + * @see #branch() + */ + @Override + public void close() { + if (isOpen()) { + try { + assertTrue(parent.isOpen()); + parent.merge(this); + } finally { + super.close(); + } + } + } + } + private static final class BufferPositionPair { private final int bufferIndex; private int position; diff --git a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java index bac2a86e61d..46eabab21bb 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java @@ -23,6 +23,7 @@ import com.mongodb.ServerApi; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences; import com.mongodb.internal.session.SessionContext; import com.mongodb.lang.Nullable; import org.bson.BsonArray; @@ -34,7 +35,6 @@ import org.bson.BsonString; import org.bson.ByteBuf; import org.bson.FieldNameValidator; -import org.bson.io.BsonOutput; import java.io.ByteArrayOutputStream; import java.io.UnsupportedEncodingException; @@ -45,12 +45,17 @@ import static com.mongodb.ReadPreference.primary; import static com.mongodb.ReadPreference.primaryPreferred; import static com.mongodb.assertions.Assertions.assertFalse; +import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.assertions.Assertions.fail; import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.connection.ClusterConnectionMode.LOAD_BALANCED; import static com.mongodb.connection.ClusterConnectionMode.SINGLE; import static com.mongodb.connection.ServerType.SHARD_ROUTER; import static com.mongodb.connection.ServerType.STANDALONE; +import static com.mongodb.internal.connection.BsonWriterHelper.appendElementsToDocument; +import static com.mongodb.internal.connection.BsonWriterHelper.backpatchLength; +import static com.mongodb.internal.connection.BsonWriterHelper.writeDocumentsOfDualMessageSequences; import static com.mongodb.internal.connection.BsonWriterHelper.writePayload; import static com.mongodb.internal.connection.ByteBufBsonDocument.createList; import static com.mongodb.internal.connection.ByteBufBsonDocument.createOne; @@ -64,43 +69,57 @@ *

This class is not part of the public API and may be removed or changed at any time

*/ public final class CommandMessage extends RequestMessage { + /** + * Specifies that the `OP_MSG` section payload is a BSON document. + */ + private static final byte PAYLOAD_TYPE_0_DOCUMENT = 0; + /** + * Specifies that the `OP_MSG` section payload is a sequence of BSON documents. + */ + private static final byte PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE = 1; + private final MongoNamespace namespace; private final BsonDocument command; private final FieldNameValidator commandFieldNameValidator; private final ReadPreference readPreference; private final boolean exhaustAllowed; - private final SplittablePayload payload; - private final FieldNameValidator payloadFieldNameValidator; + private final MessageSequences sequences; private final boolean responseExpected; + /** + * {@code null} iff either {@link #sequences} is not of the {@link DualMessageSequences} type, + * or it is of that type, but it has not been {@linkplain #encodeMessageBodyWithMetadata(ByteBufferBsonOutput, OperationContext) encoded}. + */ + @Nullable + private Boolean dualMessageSequencesRequireResponse; private final ClusterConnectionMode clusterConnectionMode; private final ServerApi serverApi; CommandMessage(final MongoNamespace namespace, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, final ReadPreference readPreference, final MessageSettings settings, final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { - this(namespace, command, commandFieldNameValidator, readPreference, settings, true, null, null, + this(namespace, command, commandFieldNameValidator, readPreference, settings, true, EmptyMessageSequences.INSTANCE, clusterConnectionMode, serverApi); } CommandMessage(final MongoNamespace namespace, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, final ReadPreference readPreference, final MessageSettings settings, final boolean exhaustAllowed, final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { - this(namespace, command, commandFieldNameValidator, readPreference, settings, true, exhaustAllowed, null, null, + this(namespace, command, commandFieldNameValidator, readPreference, settings, true, exhaustAllowed, EmptyMessageSequences.INSTANCE, clusterConnectionMode, serverApi); } CommandMessage(final MongoNamespace namespace, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, final ReadPreference readPreference, final MessageSettings settings, final boolean responseExpected, - @Nullable final SplittablePayload payload, @Nullable final FieldNameValidator payloadFieldNameValidator, + final MessageSequences sequences, final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { - this(namespace, command, commandFieldNameValidator, readPreference, settings, responseExpected, false, payload, - payloadFieldNameValidator, clusterConnectionMode, serverApi); + this(namespace, command, commandFieldNameValidator, readPreference, settings, responseExpected, false, + sequences, clusterConnectionMode, serverApi); } CommandMessage(final MongoNamespace namespace, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, final ReadPreference readPreference, final MessageSettings settings, final boolean responseExpected, final boolean exhaustAllowed, - @Nullable final SplittablePayload payload, @Nullable final FieldNameValidator payloadFieldNameValidator, + final MessageSequences sequences, final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { super(namespace.getFullName(), getOpCode(settings, clusterConnectionMode, serverApi), settings); this.namespace = namespace; @@ -108,9 +127,9 @@ public final class CommandMessage extends RequestMessage { this.commandFieldNameValidator = commandFieldNameValidator; this.readPreference = readPreference; this.responseExpected = responseExpected; + this.dualMessageSequencesRequireResponse = null; this.exhaustAllowed = exhaustAllowed; - this.payload = payload; - this.payloadFieldNameValidator = payloadFieldNameValidator; + this.sequences = sequences; this.clusterConnectionMode = notNull("clusterConnectionMode", clusterConnectionMode); this.serverApi = serverApi; assertTrue(useOpMsg() || responseExpected); @@ -119,8 +138,8 @@ public final class CommandMessage extends RequestMessage { /** * Create a BsonDocument representing the logical document encoded by an OP_MSG. *

- * The returned document will contain all the fields from the Body (Kind 0) Section, as well as all fields represented by - * OP_MSG Document Sequence (Kind 1) Sections. + * The returned document will contain all the fields from the `PAYLOAD_TYPE_0_DOCUMENT` section, as well as all fields represented by + * `PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE` sections. */ BsonDocument getCommandDocument(final ByteBufferBsonOutput bsonOutput) { List byteBuffers = bsonOutput.getByteBuffers(); @@ -130,14 +149,14 @@ BsonDocument getCommandDocument(final ByteBufferBsonOutput bsonOutput) { byteBuf.position(getEncodingMetadata().getFirstDocumentPosition()); ByteBufBsonDocument byteBufBsonDocument = createOne(byteBuf); - // If true, it means there is at least one Kind 1:Document Sequence in the OP_MSG + // If true, it means there is at least one `PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE` section in the OP_MSG if (byteBuf.hasRemaining()) { BsonDocument commandBsonDocument = byteBufBsonDocument.toBaseBsonDocument(); // Each loop iteration processes one Document Sequence // When there are no more bytes remaining, there are no more Document Sequences while (byteBuf.hasRemaining()) { - // skip reading the payload type, we know it is 1 + // skip reading the payload type, we know it is `PAYLOAD_TYPE_1` byteBuf.position(byteBuf.position() + 1); int sequenceStart = byteBuf.position(); int sequenceSizeInBytes = byteBuf.getInt(); @@ -170,7 +189,7 @@ BsonDocument getCommandDocument(final ByteBufferBsonOutput bsonOutput) { /** * Get the field name from a buffer positioned at the start of the document sequence identifier of an OP_MSG Section of type - * Document Sequence (Kind 1). + * `PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE`. *

* Upon normal completion of the method, the buffer will be positioned at the start of the first BSON object in the sequence. */ @@ -192,7 +211,15 @@ boolean isResponseExpected() { if (responseExpected) { return true; } else { - return payload != null && payload.isOrdered() && payload.hasAnotherSplit(); + if (sequences instanceof SplittablePayload) { + SplittablePayload payload = (SplittablePayload) sequences; + return payload.isOrdered() && payload.hasAnotherSplit(); + } else if (sequences instanceof DualMessageSequences) { + return assertNotNull(dualMessageSequencesRequireResponse); + } else if (!(sequences instanceof EmptyMessageSequences)) { + fail(sequences.toString()); + } + return false; } } @@ -201,47 +228,74 @@ MongoNamespace getNamespace() { } @Override - protected EncodingMetadata encodeMessageBodyWithMetadata(final BsonOutput bsonOutput, final OperationContext operationContext) { + protected EncodingMetadata encodeMessageBodyWithMetadata(final ByteBufferBsonOutput bsonOutput, final OperationContext operationContext) { + int commandStartPosition = useOpMsg() ? writeOpMsg(bsonOutput, operationContext) : writeOpQuery(bsonOutput); + return new EncodingMetadata(commandStartPosition); + } + + @SuppressWarnings("try") + private int writeOpMsg(final ByteBufferBsonOutput bsonOutput, final OperationContext operationContext) { int messageStartPosition = bsonOutput.getPosition() - MESSAGE_PROLOGUE_LENGTH; - int commandStartPosition; - if (useOpMsg()) { - int flagPosition = bsonOutput.getPosition(); - bsonOutput.writeInt32(0); // flag bits - bsonOutput.writeByte(0); // payload type - commandStartPosition = bsonOutput.getPosition(); - - addDocument(command, bsonOutput, commandFieldNameValidator, getExtraElements(operationContext)); - - if (payload != null) { - bsonOutput.writeByte(1); // payload type - int payloadBsonOutputStartPosition = bsonOutput.getPosition(); - bsonOutput.writeInt32(0); // size - bsonOutput.writeCString(payload.getPayloadName()); - writePayload(new BsonBinaryWriter(bsonOutput, payloadFieldNameValidator), bsonOutput, getSettings(), - messageStartPosition, payload, getSettings().getMaxDocumentSize()); - - int payloadBsonOutputLength = bsonOutput.getPosition() - payloadBsonOutputStartPosition; - bsonOutput.writeInt32(payloadBsonOutputStartPosition, payloadBsonOutputLength); + int flagPosition = bsonOutput.getPosition(); + bsonOutput.writeInt32(0); // flag bits + bsonOutput.writeByte(PAYLOAD_TYPE_0_DOCUMENT); + int commandStartPosition = bsonOutput.getPosition(); + List extraElements = getExtraElements(operationContext); + + int commandDocumentSizeInBytes = writeDocument(command, bsonOutput, commandFieldNameValidator); + if (sequences instanceof SplittablePayload) { + appendElementsToDocument(bsonOutput, commandStartPosition, extraElements); + SplittablePayload payload = (SplittablePayload) sequences; + try (FinishOpMsgSectionWithPayloadType1 finishSection = startOpMsgSectionWithPayloadType1( + bsonOutput, payload.getPayloadName())) { + writePayload( + new BsonBinaryWriter(bsonOutput, payload.getFieldNameValidator()), + bsonOutput, getSettings(), messageStartPosition, payload, getSettings().getMaxDocumentSize()); } - - // Write the flag bits - bsonOutput.writeInt32(flagPosition, getOpMsgFlagBits()); + } else if (sequences instanceof DualMessageSequences) { + DualMessageSequences dualMessageSequences = (DualMessageSequences) sequences; + try (ByteBufferBsonOutput.Branch bsonOutputBranch2 = bsonOutput.branch(); + ByteBufferBsonOutput.Branch bsonOutputBranch1 = bsonOutput.branch()) { + DualMessageSequences.EncodeDocumentsResult encodeDocumentsResult; + try (FinishOpMsgSectionWithPayloadType1 finishSection1 = startOpMsgSectionWithPayloadType1( + bsonOutputBranch1, dualMessageSequences.getFirstSequenceId()); + FinishOpMsgSectionWithPayloadType1 finishSection2 = startOpMsgSectionWithPayloadType1( + bsonOutputBranch2, dualMessageSequences.getSecondSequenceId())) { + encodeDocumentsResult = writeDocumentsOfDualMessageSequences( + dualMessageSequences, commandDocumentSizeInBytes, bsonOutputBranch1, + bsonOutputBranch2, getSettings()); + } + dualMessageSequencesRequireResponse = encodeDocumentsResult.isServerResponseRequired(); + extraElements.addAll(encodeDocumentsResult.getExtraElements()); + appendElementsToDocument(bsonOutput, commandStartPosition, extraElements); + } + } else if (sequences instanceof EmptyMessageSequences) { + appendElementsToDocument(bsonOutput, commandStartPosition, extraElements); } else { - bsonOutput.writeInt32(0); - bsonOutput.writeCString(namespace.getFullName()); - bsonOutput.writeInt32(0); - bsonOutput.writeInt32(-1); + fail(sequences.toString()); + } + + // Write the flag bits + bsonOutput.writeInt32(flagPosition, getOpMsgFlagBits()); + return commandStartPosition; + } - commandStartPosition = bsonOutput.getPosition(); + private int writeOpQuery(final ByteBufferBsonOutput bsonOutput) { + bsonOutput.writeInt32(0); + bsonOutput.writeCString(namespace.getFullName()); + bsonOutput.writeInt32(0); + bsonOutput.writeInt32(-1); - List elements = null; - if (serverApi != null) { - elements = new ArrayList<>(3); - addServerApiElements(elements); - } - addDocument(command, bsonOutput, commandFieldNameValidator, elements); + int commandStartPosition = bsonOutput.getPosition(); + + List elements = null; + if (serverApi != null) { + elements = new ArrayList<>(3); + addServerApiElements(elements); } - return new EncodingMetadata(commandStartPosition); + writeDocument(command, bsonOutput, commandFieldNameValidator); + appendElementsToDocument(bsonOutput, commandStartPosition, elements); + return commandStartPosition; } private int getOpMsgFlagBits() { @@ -269,7 +323,7 @@ private List getExtraElements(final OperationContext operationConte SessionContext sessionContext = operationContext.getSessionContext(); TimeoutContext timeoutContext = operationContext.getTimeoutContext(); - List extraElements = new ArrayList<>(); + ArrayList extraElements = new ArrayList<>(); if (!getSettings().isCryptd()) { timeoutContext.runMaxTimeMS(maxTimeMS -> extraElements.add(new BsonElement("maxTimeMS", new BsonInt64(maxTimeMS))) @@ -341,6 +395,19 @@ private void addReadConcernDocument(final List extraElements, final } } + /** + * @param sequenceId The identifier of the sequence contained in the {@code OP_MSG} section to be written. + * @see OP_MSG + */ + private FinishOpMsgSectionWithPayloadType1 startOpMsgSectionWithPayloadType1(final ByteBufferBsonOutput bsonOutput, final String sequenceId) { + bsonOutput.writeByte(PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE); + int sequenceStart = bsonOutput.getPosition(); + // size to be patched back later + bsonOutput.writeInt32(0); + bsonOutput.writeCString(sequenceId); + return () -> backpatchLength(sequenceStart, bsonOutput); + } + private static OpCode getOpCode(final MessageSettings settings, final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { return isServerVersionKnown(settings) || clusterConnectionMode == LOAD_BALANCED || serverApi != null @@ -351,4 +418,9 @@ private static OpCode getOpCode(final MessageSettings settings, final ClusterCon private static boolean isServerVersionKnown(final MessageSettings settings) { return settings.getMaxWireVersion() >= FOUR_DOT_ZERO_WIRE_VERSION; } + + @FunctionalInterface + private interface FinishOpMsgSectionWithPayloadType1 extends AutoCloseable { + void close(); + } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/CommandProtocolImpl.java b/driver-core/src/main/com/mongodb/internal/connection/CommandProtocolImpl.java index de9e0666d40..eb4d6d49516 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CommandProtocolImpl.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CommandProtocolImpl.java @@ -26,17 +26,15 @@ import org.bson.FieldNameValidator; import org.bson.codecs.Decoder; -import static com.mongodb.assertions.Assertions.isTrueArgument; import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.connection.ProtocolHelper.getMessageSettings; class CommandProtocolImpl implements CommandProtocol { private final MongoNamespace namespace; private final BsonDocument command; - private final SplittablePayload payload; + private final MessageSequences sequences; private final ReadPreference readPreference; private final FieldNameValidator commandFieldNameValidator; - private final FieldNameValidator payloadFieldNameValidator; private final Decoder commandResultDecoder; private final boolean responseExpected; private final ClusterConnectionMode clusterConnectionMode; @@ -44,8 +42,7 @@ class CommandProtocolImpl implements CommandProtocol { CommandProtocolImpl(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, final boolean responseExpected, - @Nullable final SplittablePayload payload, @Nullable final FieldNameValidator payloadFieldNameValidator, - final ClusterConnectionMode clusterConnectionMode, final OperationContext operationContext) { + final MessageSequences sequences, final ClusterConnectionMode clusterConnectionMode, final OperationContext operationContext) { notNull("database", database); this.namespace = new MongoNamespace(notNull("database", database), MongoNamespace.COMMAND_COLLECTION_NAME); this.command = notNull("command", command); @@ -53,13 +50,9 @@ class CommandProtocolImpl implements CommandProtocol { this.readPreference = readPreference; this.commandResultDecoder = notNull("commandResultDecoder", commandResultDecoder); this.responseExpected = responseExpected; - this.payload = payload; - this.payloadFieldNameValidator = payloadFieldNameValidator; + this.sequences = sequences; this.clusterConnectionMode = notNull("clusterConnectionMode", clusterConnectionMode); this.operationContext = operationContext; - - isTrueArgument("payloadFieldNameValidator cannot be null if there is a payload.", - payload == null || payloadFieldNameValidator != null); } @Nullable @@ -87,13 +80,13 @@ public void executeAsync(final InternalConnection connection, final SingleResult @Override public CommandProtocolImpl withSessionContext(final SessionContext sessionContext) { return new CommandProtocolImpl<>(namespace.getDatabaseName(), command, commandFieldNameValidator, readPreference, - commandResultDecoder, responseExpected, payload, payloadFieldNameValidator, clusterConnectionMode, + commandResultDecoder, responseExpected, sequences, clusterConnectionMode, operationContext.withSessionContext(sessionContext)); } private CommandMessage getCommandMessage(final InternalConnection connection) { return new CommandMessage(namespace, command, commandFieldNameValidator, readPreference, - getMessageSettings(connection.getDescription(), connection.getInitialServerDescription()), responseExpected, payload, - payloadFieldNameValidator, clusterConnectionMode, operationContext.getServerApi()); + getMessageSettings(connection.getDescription(), connection.getInitialServerDescription()), responseExpected, + sequences, clusterConnectionMode, operationContext.getServerApi()); } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/CompressedMessage.java b/driver-core/src/main/com/mongodb/internal/connection/CompressedMessage.java index 9880ef3fb0b..6764135daa1 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CompressedMessage.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CompressedMessage.java @@ -17,7 +17,6 @@ package com.mongodb.internal.connection; import org.bson.ByteBuf; -import org.bson.io.BsonOutput; import java.util.List; @@ -37,7 +36,7 @@ class CompressedMessage extends RequestMessage { } @Override - protected EncodingMetadata encodeMessageBodyWithMetadata(final BsonOutput bsonOutput, final OperationContext operationContext) { + protected EncodingMetadata encodeMessageBodyWithMetadata(final ByteBufferBsonOutput bsonOutput, final OperationContext operationContext) { bsonOutput.writeInt32(wrappedOpcode.getValue()); bsonOutput.writeInt32(getWrappedMessageSize(wrappedMessageBuffers) - MESSAGE_HEADER_LENGTH); bsonOutput.writeByte(compressor.getId()); diff --git a/driver-core/src/main/com/mongodb/internal/connection/Connection.java b/driver-core/src/main/com/mongodb/internal/connection/Connection.java index 95094b240c1..219fb9ae6b9 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Connection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Connection.java @@ -51,7 +51,7 @@ T command(String database, BsonDocument command, FieldNameValidator fieldNam @Nullable T command(String database, BsonDocument command, FieldNameValidator commandFieldNameValidator, @Nullable ReadPreference readPreference, Decoder commandResultDecoder, OperationContext operationContext, - boolean responseExpected, @Nullable SplittablePayload payload, @Nullable FieldNameValidator payloadFieldNameValidator); + boolean responseExpected, MessageSequences sequences); enum PinningMode { diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java index 8f3d0f09fd9..008cdbefcb7 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java @@ -302,9 +302,9 @@ public T command(final String database, final BsonDocument command, final Fi public T command(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, final OperationContext operationContext, final boolean responseExpected, - @Nullable final SplittablePayload payload, @Nullable final FieldNameValidator payloadFieldNameValidator) { + final MessageSequences sequences) { return wrapped.command(database, command, commandFieldNameValidator, readPreference, commandResultDecoder, operationContext, - responseExpected, payload, payloadFieldNameValidator); + responseExpected, sequences); } @Override @@ -364,10 +364,10 @@ public void commandAsync(final String database, final BsonDocument command, @Override public void commandAsync(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, - final OperationContext operationContext, final boolean responseExpected, @Nullable final SplittablePayload payload, - @Nullable final FieldNameValidator payloadFieldNameValidator, final SingleResultCallback callback) { + final OperationContext operationContext, final boolean responseExpected, final MessageSequences sequences, + final SingleResultCallback callback) { wrapped.commandAsync(database, command, commandFieldNameValidator, readPreference, commandResultDecoder, - operationContext, responseExpected, payload, payloadFieldNameValidator, callback); + operationContext, responseExpected, sequences, callback); } @Override diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultServerConnection.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultServerConnection.java index 01d5f587fdc..143ef5b76ae 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DefaultServerConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultServerConnection.java @@ -20,6 +20,7 @@ import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences; import com.mongodb.internal.diagnostics.logging.Logger; import com.mongodb.internal.diagnostics.logging.Loggers; import com.mongodb.internal.session.SessionContext; @@ -70,18 +71,17 @@ public ConnectionDescription getDescription() { @Override public T command(final String database, final BsonDocument command, final FieldNameValidator fieldNameValidator, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, final OperationContext operationContext) { - return command(database, command, fieldNameValidator, readPreference, commandResultDecoder, operationContext, true, null, null); + return command(database, command, fieldNameValidator, readPreference, commandResultDecoder, operationContext, true, EmptyMessageSequences.INSTANCE); } @Nullable @Override public T command(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, - final OperationContext operationContext, final boolean responseExpected, - @Nullable final SplittablePayload payload, @Nullable final FieldNameValidator payloadFieldNameValidator) { + final OperationContext operationContext, final boolean responseExpected, final MessageSequences sequences) { return executeProtocol( new CommandProtocolImpl<>(database, command, commandFieldNameValidator, readPreference, commandResultDecoder, - responseExpected, payload, payloadFieldNameValidator, clusterConnectionMode, operationContext), + responseExpected, sequences, clusterConnectionMode, operationContext), operationContext.getSessionContext()); } @@ -90,16 +90,15 @@ public void commandAsync(final String database, final BsonDocument command, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, final OperationContext operationContext, final SingleResultCallback callback) { commandAsync(database, command, fieldNameValidator, readPreference, commandResultDecoder, - operationContext, true, null, null, callback); + operationContext, true, EmptyMessageSequences.INSTANCE, callback); } @Override public void commandAsync(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, final OperationContext operationContext, - final boolean responseExpected, @Nullable final SplittablePayload payload, - @Nullable final FieldNameValidator payloadFieldNameValidator, final SingleResultCallback callback) { + final boolean responseExpected, final MessageSequences sequences, final SingleResultCallback callback) { executeProtocolAsync(new CommandProtocolImpl<>(database, command, commandFieldNameValidator, readPreference, - commandResultDecoder, responseExpected, payload, payloadFieldNameValidator, clusterConnectionMode, operationContext), + commandResultDecoder, responseExpected, sequences, clusterConnectionMode, operationContext), operationContext.getSessionContext(), callback); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/DualMessageSequences.java b/driver-core/src/main/com/mongodb/internal/connection/DualMessageSequences.java new file mode 100644 index 00000000000..0c5a3430c22 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/connection/DualMessageSequences.java @@ -0,0 +1,122 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import org.bson.BsonBinaryWriter; +import org.bson.BsonElement; +import org.bson.FieldNameValidator; + +import java.util.List; + +/** + * Two sequences that may either be coupled or independent. + *

+ * This class is not part of the public API and may be removed or changed at any time.

+ */ +public abstract class DualMessageSequences extends MessageSequences { + + private final String firstSequenceId; + private final FieldNameValidator firstFieldNameValidator; + private final String secondSequenceId; + private final FieldNameValidator secondFieldNameValidator; + + protected DualMessageSequences( + final String firstSequenceId, + final FieldNameValidator firstFieldNameValidator, + final String secondSequenceId, + final FieldNameValidator secondFieldNameValidator) { + this.firstSequenceId = firstSequenceId; + this.firstFieldNameValidator = firstFieldNameValidator; + this.secondSequenceId = secondSequenceId; + this.secondFieldNameValidator = secondFieldNameValidator; + } + + FieldNameValidator getFirstFieldNameValidator() { + return firstFieldNameValidator; + } + + FieldNameValidator getSecondFieldNameValidator() { + return secondFieldNameValidator; + } + + String getFirstSequenceId() { + return firstSequenceId; + } + + String getSecondSequenceId() { + return secondSequenceId; + } + + protected abstract EncodeDocumentsResult encodeDocuments(WritersProviderAndLimitsChecker writersProviderAndLimitsChecker); + + /** + * @see #tryWrite(WriteAction) + */ + public interface WritersProviderAndLimitsChecker { + /** + * Provides writers to the specified {@link WriteAction}, + * {@linkplain WriteAction#doAndGetBatchCount(BsonBinaryWriter, BsonBinaryWriter) executes} it, + * checks the {@linkplain MessageSettings limits}. + *

+ * May be called multiple times per {@link #encodeDocuments(WritersProviderAndLimitsChecker)}.

+ */ + WriteResult tryWrite(WriteAction write); + + /** + * @see #doAndGetBatchCount(BsonBinaryWriter, BsonBinaryWriter) + */ + interface WriteAction { + /** + * Writes documents to the sequences using the provided writers. + * + * @return The resulting batch count since the beginning of {@link #encodeDocuments(WritersProviderAndLimitsChecker)}. + * It is generally allowed to be greater than {@link MessageSettings#getMaxBatchCount()}. + */ + int doAndGetBatchCount(BsonBinaryWriter firstWriter, BsonBinaryWriter secondWriter); + } + + enum WriteResult { + FAIL_LIMIT_EXCEEDED, + OK_LIMIT_REACHED, + OK_LIMIT_NOT_REACHED + } + } + + public static final class EncodeDocumentsResult { + private final boolean serverResponseRequired; + private final List extraElements; + + /** + * @param extraElements See {@link #getExtraElements()}. + */ + public EncodeDocumentsResult(final boolean serverResponseRequired, final List extraElements) { + this.serverResponseRequired = serverResponseRequired; + this.extraElements = extraElements; + } + + boolean isServerResponseRequired() { + return serverResponseRequired; + } + + /** + * {@linkplain BsonElement Key/value pairs} to be added to the document contained in the {@code OP_MSG} section with payload type 0. + */ + List getExtraElements() { + return extraElements; + } + } +} diff --git a/driver-core/src/main/com/mongodb/internal/connection/ElementExtendingBsonWriter.java b/driver-core/src/main/com/mongodb/internal/connection/ElementExtendingBsonWriter.java deleted file mode 100644 index d0ed5234d50..00000000000 --- a/driver-core/src/main/com/mongodb/internal/connection/ElementExtendingBsonWriter.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.internal.connection; - -import org.bson.BsonBinaryWriter; -import org.bson.BsonElement; -import org.bson.BsonReader; - -import java.util.List; - -import static com.mongodb.internal.connection.BsonWriterHelper.writeElements; - -/** - *

This class is not part of the public API and may be removed or changed at any time

- */ -public class ElementExtendingBsonWriter extends LevelCountingBsonWriter { - private final BsonBinaryWriter writer; - private final List extraElements; - - - public ElementExtendingBsonWriter(final BsonBinaryWriter writer, final List extraElements) { - super(writer); - this.writer = writer; - this.extraElements = extraElements; - } - - @Override - public void writeEndDocument() { - if (getCurrentLevel() == 0) { - writeElements(writer, extraElements); - } - super.writeEndDocument(); - } - - @Override - public void pipe(final BsonReader reader) { - if (getCurrentLevel() == -1) { - writer.pipe(reader, extraElements); - } else { - writer.pipe(reader); - } - } -} diff --git a/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java b/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java index 4120dbdfb17..c73a5d4fd86 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java +++ b/driver-core/src/main/com/mongodb/internal/connection/IdHoldingBsonWriter.java @@ -92,11 +92,11 @@ public void writeStartDocument() { @Override public void writeEndDocument() { if (isWritingId()) { - if (getIdBsonWriterCurrentLevel() >= 0) { + if (getIdBsonWriterCurrentLevel() > DEFAULT_INITIAL_LEVEL) { getIdBsonWriter().writeEndDocument(); } - if (getIdBsonWriterCurrentLevel() == -1) { + if (getIdBsonWriterCurrentLevel() == DEFAULT_INITIAL_LEVEL) { if (id != null && id.isJavaScriptWithScope()) { id = new BsonJavaScriptWithScope(id.asJavaScriptWithScope().getCode(), new RawBsonDocument(getBytes())); } else if (id == null) { @@ -105,7 +105,7 @@ public void writeEndDocument() { } } - if (getCurrentLevel() == 0 && id == null) { + if (getCurrentLevel() == DEFAULT_INITIAL_LEVEL + 1 && id == null) { id = fallbackId == null ? new BsonObjectId() : fallbackId; writeObjectId(ID_FIELD_NAME, id.asObjectId().getValue()); } @@ -115,7 +115,7 @@ public void writeEndDocument() { @Override public void writeStartArray() { if (isWritingId()) { - if (getIdBsonWriterCurrentLevel() == -1) { + if (getIdBsonWriterCurrentLevel() == DEFAULT_INITIAL_LEVEL) { idFieldIsAnArray = true; getIdBsonWriter().writeStartDocument(); getIdBsonWriter().writeName(ID_FIELD_NAME); @@ -129,7 +129,7 @@ public void writeStartArray() { public void writeStartArray(final String name) { setCurrentFieldName(name); if (isWritingId()) { - if (getIdBsonWriterCurrentLevel() == -1) { + if (getIdBsonWriterCurrentLevel() == DEFAULT_INITIAL_LEVEL) { getIdBsonWriter().writeStartDocument(); } getIdBsonWriter().writeStartArray(name); @@ -141,7 +141,7 @@ public void writeStartArray(final String name) { public void writeEndArray() { if (isWritingId()) { getIdBsonWriter().writeEndArray(); - if (getIdBsonWriterCurrentLevel() == 0 && idFieldIsAnArray) { + if (getIdBsonWriterCurrentLevel() == DEFAULT_INITIAL_LEVEL + 1 && idFieldIsAnArray) { getIdBsonWriter().writeEndDocument(); id = new RawBsonDocument(getBytes()).get(ID_FIELD_NAME); } @@ -308,7 +308,7 @@ public void writeMinKey() { @Override public void writeName(final String name) { setCurrentFieldName(name); - if (getIdBsonWriterCurrentLevel() >= 0) { + if (getIdBsonWriterCurrentLevel() > DEFAULT_INITIAL_LEVEL) { getIdBsonWriter().writeName(name); } super.writeName(name); @@ -433,13 +433,13 @@ private void setCurrentFieldName(final String name) { } private boolean isWritingId() { - return getIdBsonWriterCurrentLevel() >= 0 || (getCurrentLevel() == 0 && currentFieldName != null + return getIdBsonWriterCurrentLevel() > DEFAULT_INITIAL_LEVEL || (getCurrentLevel() == DEFAULT_INITIAL_LEVEL + 1 && currentFieldName != null && currentFieldName.equals(ID_FIELD_NAME)); } private void addBsonValue(final Supplier value, final Runnable writeValue) { if (isWritingId()) { - if (getIdBsonWriterCurrentLevel() >= 0) { + if (getIdBsonWriterCurrentLevel() > DEFAULT_INITIAL_LEVEL) { writeValue.run(); } else { id = value.get(); @@ -448,7 +448,7 @@ private void addBsonValue(final Supplier value, final Runnable writeV } private int getIdBsonWriterCurrentLevel() { - return idBsonBinaryWriter == null ? -1 : idBsonBinaryWriter.getCurrentLevel(); + return idBsonBinaryWriter == null ? DEFAULT_INITIAL_LEVEL : idBsonBinaryWriter.getCurrentLevel(); } private LevelCountingBsonWriter getIdBsonWriter() { diff --git a/driver-core/src/main/com/mongodb/internal/connection/LevelCountingBsonWriter.java b/driver-core/src/main/com/mongodb/internal/connection/LevelCountingBsonWriter.java index 44889765fbf..3e9d0324bd7 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/LevelCountingBsonWriter.java +++ b/driver-core/src/main/com/mongodb/internal/connection/LevelCountingBsonWriter.java @@ -20,10 +20,21 @@ abstract class LevelCountingBsonWriter extends BsonWriterDecorator { - private int level = -1; + static final int DEFAULT_INITIAL_LEVEL = -1; + + private int level; LevelCountingBsonWriter(final BsonWriter bsonWriter) { + this(bsonWriter, DEFAULT_INITIAL_LEVEL); + } + + /** + * @param initialLevel This parameter allows initializing the {@linkplain #getCurrentLevel() current level} + * with a value different from {@link #DEFAULT_INITIAL_LEVEL}. + */ + LevelCountingBsonWriter(final BsonWriter bsonWriter, final int initialLevel) { super(bsonWriter); + level = initialLevel; } int getCurrentLevel() { diff --git a/driver-core/src/main/com/mongodb/internal/connection/MessageSequences.java b/driver-core/src/main/com/mongodb/internal/connection/MessageSequences.java new file mode 100644 index 00000000000..19600007404 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/connection/MessageSequences.java @@ -0,0 +1,31 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.internal.connection; + +/** + * Zero or more identifiable sequences contained in the {@code OP_MSG} section with payload type 1. + *

+ * This class is not part of the public API and may be removed or changed at any time.

+ * @see OP_MSG + */ +public abstract class MessageSequences { + public static final class EmptyMessageSequences extends MessageSequences { + public static final EmptyMessageSequences INSTANCE = new EmptyMessageSequences(); + + private EmptyMessageSequences() { + } + } +} diff --git a/driver-core/src/main/com/mongodb/internal/connection/MessageSettings.java b/driver-core/src/main/com/mongodb/internal/connection/MessageSettings.java index 7a5734bc140..91fd862ce4b 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/MessageSettings.java +++ b/driver-core/src/main/com/mongodb/internal/connection/MessageSettings.java @@ -42,6 +42,13 @@ public final class MessageSettings { * {@code maxWriteBatchSize}. */ private static final int DEFAULT_MAX_BATCH_COUNT = 1000; + /** + * The headroom for documents that are not intended to be stored in a database. + * A command document is an example of such a document. + * This headroom allows a command document to specify a document that is intended to be stored in a database, + * even if the specified document is of the maximum size. + */ + static final int DOCUMENT_HEADROOM_SIZE = 16 * (1 << 10); private final int maxDocumentSize; private final int maxMessageSize; diff --git a/driver-core/src/main/com/mongodb/internal/connection/RequestMessage.java b/driver-core/src/main/com/mongodb/internal/connection/RequestMessage.java index 86e2ebd1dbe..dd09a59f763 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/RequestMessage.java +++ b/driver-core/src/main/com/mongodb/internal/connection/RequestMessage.java @@ -18,24 +18,16 @@ import com.mongodb.lang.Nullable; import org.bson.BsonBinaryWriter; -import org.bson.BsonBinaryWriterSettings; import org.bson.BsonDocument; -import org.bson.BsonElement; -import org.bson.BsonWriter; -import org.bson.BsonWriterSettings; import org.bson.FieldNameValidator; -import org.bson.codecs.BsonValueCodecProvider; -import org.bson.codecs.Codec; -import org.bson.codecs.Encoder; -import org.bson.codecs.EncoderContext; -import org.bson.codecs.configuration.CodecRegistry; import org.bson.io.BsonOutput; -import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import static com.mongodb.assertions.Assertions.notNull; -import static org.bson.codecs.configuration.CodecRegistries.fromProviders; +import static com.mongodb.internal.connection.BsonWriterHelper.backpatchLength; +import static com.mongodb.internal.connection.BsonWriterHelper.createBsonBinaryWriter; +import static com.mongodb.internal.connection.BsonWriterHelper.encodeUsingRegistry; /** * Abstract base class for all MongoDB Wire Protocol request messages. @@ -46,12 +38,6 @@ abstract class RequestMessage { static final int MESSAGE_PROLOGUE_LENGTH = 16; - // Allow an extra 16K to the maximum allowed size of a query or command document, so that, for example, - // a 16M document can be upserted via findAndModify - private static final int DOCUMENT_HEADROOM = 16 * 1024; - - private static final CodecRegistry REGISTRY = fromProviders(new BsonValueCodecProvider()); - private final String collectionName; private final MessageSettings settings; private final int id; @@ -128,12 +114,12 @@ public MessageSettings getSettings() { * @param bsonOutput the output * @param operationContext the session context */ - public void encode(final BsonOutput bsonOutput, final OperationContext operationContext) { + public void encode(final ByteBufferBsonOutput bsonOutput, final OperationContext operationContext) { notNull("operationContext", operationContext); int messageStartPosition = bsonOutput.getPosition(); writeMessagePrologue(bsonOutput); EncodingMetadata encodingMetadata = encodeMessageBodyWithMetadata(bsonOutput, operationContext); - backpatchMessageLength(messageStartPosition, bsonOutput); + backpatchLength(messageStartPosition, bsonOutput); this.encodingMetadata = encodingMetadata; } @@ -165,23 +151,13 @@ protected void writeMessagePrologue(final BsonOutput bsonOutput) { * @param operationContext the session context * @return the encoding metadata */ - protected abstract EncodingMetadata encodeMessageBodyWithMetadata(BsonOutput bsonOutput, OperationContext operationContext); - - protected void addDocument(final BsonDocument document, final BsonOutput bsonOutput, - final FieldNameValidator validator, @Nullable final List extraElements) { - addDocument(document, getCodec(document), EncoderContext.builder().build(), bsonOutput, validator, - settings.getMaxDocumentSize() + DOCUMENT_HEADROOM, extraElements); - } + protected abstract EncodingMetadata encodeMessageBodyWithMetadata(ByteBufferBsonOutput bsonOutput, OperationContext operationContext); - /** - * Backpatches the message length into the beginning of the message. - * - * @param startPosition the start position of the message - * @param bsonOutput the output - */ - protected void backpatchMessageLength(final int startPosition, final BsonOutput bsonOutput) { - int messageLength = bsonOutput.getPosition() - startPosition; - bsonOutput.writeInt32(bsonOutput.getPosition() - messageLength, messageLength); + protected int writeDocument(final BsonDocument document, final BsonOutput bsonOutput, final FieldNameValidator validator) { + BsonBinaryWriter writer = createBsonBinaryWriter(bsonOutput, validator, getSettings()); + int documentStart = bsonOutput.getPosition(); + encodeUsingRegistry(writer, document); + return bsonOutput.getPosition() - documentStart; } /** @@ -192,20 +168,4 @@ protected void backpatchMessageLength(final int startPosition, final BsonOutput protected String getCollectionName() { return collectionName; } - - @SuppressWarnings("unchecked") - Codec getCodec(final BsonDocument document) { - return (Codec) REGISTRY.get(document.getClass()); - } - - private void addDocument(final T obj, final Encoder encoder, final EncoderContext encoderContext, - final BsonOutput bsonOutput, final FieldNameValidator validator, final int maxDocumentSize, - @Nullable final List extraElements) { - BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter(new BsonWriterSettings(), new BsonBinaryWriterSettings(maxDocumentSize), - bsonOutput, validator); - BsonWriter bsonWriter = extraElements == null - ? bsonBinaryWriter - : new ElementExtendingBsonWriter(bsonBinaryWriter, extraElements); - encoder.encode(bsonWriter, obj, encoderContext); - } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java index d628a39238d..b091d654aa7 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java @@ -26,6 +26,7 @@ import org.bson.BsonObjectId; import org.bson.BsonValue; import org.bson.BsonWriter; +import org.bson.FieldNameValidator; import org.bson.codecs.BsonValueCodecProvider; import org.bson.codecs.Codec; import org.bson.codecs.Encoder; @@ -54,8 +55,9 @@ * *

This class is not part of the public API and may be removed or changed at any time

*/ -public final class SplittablePayload { +public final class SplittablePayload extends MessageSequences { private static final CodecRegistry REGISTRY = fromProviders(new BsonValueCodecProvider()); + private final FieldNameValidator fieldNameValidator; private final WriteRequestEncoder writeRequestEncoder = new WriteRequestEncoder(); private final Type payloadType; private final List writeRequestWithIndexes; @@ -94,10 +96,19 @@ public enum Type { * @param payloadType the payload type * @param writeRequestWithIndexes the writeRequests */ - public SplittablePayload(final Type payloadType, final List writeRequestWithIndexes, final boolean ordered) { + public SplittablePayload( + final Type payloadType, + final List writeRequestWithIndexes, + final boolean ordered, + final FieldNameValidator fieldNameValidator) { this.payloadType = notNull("batchType", payloadType); this.writeRequestWithIndexes = notNull("writeRequests", writeRequestWithIndexes); this.ordered = ordered; + this.fieldNameValidator = notNull("fieldNameValidator", fieldNameValidator); + } + + public FieldNameValidator getFieldNameValidator() { + return fieldNameValidator; } /** @@ -175,7 +186,7 @@ boolean isOrdered() { public SplittablePayload getNextSplit() { isTrue("hasAnotherSplit", hasAnotherSplit()); List nextPayLoad = writeRequestWithIndexes.subList(position, writeRequestWithIndexes.size()); - return new SplittablePayload(payloadType, nextPayLoad, ordered); + return new SplittablePayload(payloadType, nextPayLoad, ordered, fieldNameValidator); } /** @@ -204,7 +215,8 @@ public void encode(final BsonWriter writer, final WriteRequestWithIndex writeReq writer, // Reuse `writeRequestDocumentId` if it may have been generated // by `IdHoldingBsonWriter` in a previous attempt. - // If its type is not `BsonObjectId`, we know it could not have been generated. + // If its type is not `BsonObjectId`, which happens only if `_id` was specified by the application, + // we know it could not have been generated. writeRequestDocumentId instanceof BsonObjectId ? writeRequestDocumentId.asObjectId() : null); getCodec(document).encode(idHoldingBsonWriter, document, EncoderContext.builder().isEncodingCollectibleDocument(true).build()); diff --git a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayloadBsonWriter.java b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayloadBsonWriter.java index ecff2c95a0c..e679a3b557c 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayloadBsonWriter.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayloadBsonWriter.java @@ -63,7 +63,7 @@ public void writeStartDocument() { @Override public void writeEndDocument() { - if (getCurrentLevel() == 0 && payload.hasPayload()) { + if (getCurrentLevel() == DEFAULT_INITIAL_LEVEL + 1 && payload.hasPayload()) { writePayloadArray(writer, bsonOutput, settings, messageStartPosition, payload, maxSplittableDocumentSize); } super.writeEndDocument(); diff --git a/driver-core/src/main/com/mongodb/internal/operation/BulkWriteBatch.java b/driver-core/src/main/com/mongodb/internal/operation/BulkWriteBatch.java index 8da0f13e312..1064bee14d3 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/BulkWriteBatch.java +++ b/driver-core/src/main/com/mongodb/internal/operation/BulkWriteBatch.java @@ -154,7 +154,7 @@ private BulkWriteBatch(final MongoNamespace namespace, final ConnectionDescripti this.indexMap = indexMap; this.unprocessed = unprocessedItems; - this.payload = new SplittablePayload(getPayloadType(batchType), payloadItems, ordered); + this.payload = new SplittablePayload(getPayloadType(batchType), payloadItems, ordered, getFieldNameValidator()); this.operationContext = operationContext; this.comment = comment; this.variables = variables; @@ -270,7 +270,7 @@ BulkWriteBatch getNextBatch() { } } - FieldNameValidator getFieldNameValidator() { + private FieldNameValidator getFieldNameValidator() { if (batchType == UPDATE || batchType == REPLACE) { Map rootMap; if (batchType == REPLACE) { diff --git a/driver-core/src/main/com/mongodb/internal/operation/ClientBulkWriteOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ClientBulkWriteOperation.java index b2f6b131723..eb68a5e1983 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ClientBulkWriteOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ClientBulkWriteOperation.java @@ -16,7 +16,6 @@ package com.mongodb.internal.operation; import com.mongodb.ClientBulkWriteException; -import com.mongodb.Function; import com.mongodb.MongoClientSettings; import com.mongodb.MongoCommandException; import com.mongodb.MongoException; @@ -31,27 +30,32 @@ import com.mongodb.bulk.WriteConcernError; import com.mongodb.client.cursor.TimeoutMode; import com.mongodb.client.model.bulk.ClientBulkWriteOptions; -import com.mongodb.client.model.bulk.ClientNamespacedReplaceOneModel; -import com.mongodb.client.model.bulk.ClientNamespacedUpdateOneModel; -import com.mongodb.client.model.bulk.ClientNamespacedWriteModel; import com.mongodb.client.model.bulk.ClientBulkWriteResult; import com.mongodb.client.model.bulk.ClientDeleteResult; import com.mongodb.client.model.bulk.ClientInsertOneResult; +import com.mongodb.client.model.bulk.ClientNamespacedReplaceOneModel; +import com.mongodb.client.model.bulk.ClientNamespacedUpdateOneModel; +import com.mongodb.client.model.bulk.ClientNamespacedWriteModel; import com.mongodb.client.model.bulk.ClientUpdateResult; import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.VisibleForTesting; import com.mongodb.internal.async.function.RetryState; import com.mongodb.internal.binding.ConnectionSource; import com.mongodb.internal.binding.WriteBinding; import com.mongodb.internal.client.model.bulk.AbstractClientDeleteModel; import com.mongodb.internal.client.model.bulk.AbstractClientNamespacedWriteModel; import com.mongodb.internal.client.model.bulk.AbstractClientUpdateModel; +import com.mongodb.internal.client.model.bulk.AcknowledgedSummaryClientBulkWriteResult; +import com.mongodb.internal.client.model.bulk.AcknowledgedVerboseClientBulkWriteResult; import com.mongodb.internal.client.model.bulk.ClientWriteModel; import com.mongodb.internal.client.model.bulk.ConcreteClientBulkWriteOptions; import com.mongodb.internal.client.model.bulk.ConcreteClientDeleteManyModel; import com.mongodb.internal.client.model.bulk.ConcreteClientDeleteOneModel; import com.mongodb.internal.client.model.bulk.ConcreteClientDeleteOptions; +import com.mongodb.internal.client.model.bulk.ConcreteClientDeleteResult; import com.mongodb.internal.client.model.bulk.ConcreteClientInsertOneModel; +import com.mongodb.internal.client.model.bulk.ConcreteClientInsertOneResult; import com.mongodb.internal.client.model.bulk.ConcreteClientNamespacedDeleteManyModel; import com.mongodb.internal.client.model.bulk.ConcreteClientNamespacedDeleteOneModel; import com.mongodb.internal.client.model.bulk.ConcreteClientNamespacedInsertOneModel; @@ -63,26 +67,26 @@ import com.mongodb.internal.client.model.bulk.ConcreteClientUpdateManyModel; import com.mongodb.internal.client.model.bulk.ConcreteClientUpdateOneModel; import com.mongodb.internal.client.model.bulk.ConcreteClientUpdateOptions; -import com.mongodb.internal.client.model.bulk.AcknowledgedSummaryClientBulkWriteResult; -import com.mongodb.internal.client.model.bulk.AcknowledgedVerboseClientBulkWriteResult; -import com.mongodb.internal.client.model.bulk.ConcreteClientDeleteResult; -import com.mongodb.internal.client.model.bulk.ConcreteClientInsertOneResult; import com.mongodb.internal.client.model.bulk.ConcreteClientUpdateResult; import com.mongodb.internal.client.model.bulk.UnacknowledgedClientBulkWriteResult; import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.DualMessageSequences; import com.mongodb.internal.connection.IdHoldingBsonWriter; import com.mongodb.internal.connection.MongoWriteConcernWithResponseException; import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.operation.retry.AttachmentKeys; import com.mongodb.internal.session.SessionContext; -import com.mongodb.internal.validator.MappedFieldNameValidator; import com.mongodb.internal.validator.NoOpFieldNameValidator; import com.mongodb.internal.validator.ReplacingDocumentFieldNameValidator; import com.mongodb.internal.validator.UpdateFieldNameValidator; import com.mongodb.lang.Nullable; import org.bson.BsonArray; +import org.bson.BsonBinaryWriter; +import org.bson.BsonBoolean; import org.bson.BsonDocument; -import org.bson.BsonDocumentWrapper; +import org.bson.BsonElement; +import org.bson.BsonInt32; +import org.bson.BsonInt64; import org.bson.BsonObjectId; import org.bson.BsonValue; import org.bson.BsonWriter; @@ -106,11 +110,15 @@ import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.fail; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PACKAGE; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; +import static com.mongodb.internal.connection.DualMessageSequences.WritersProviderAndLimitsChecker.WriteResult.FAIL_LIMIT_EXCEEDED; +import static com.mongodb.internal.connection.DualMessageSequences.WritersProviderAndLimitsChecker.WriteResult.OK_LIMIT_NOT_REACHED; import static com.mongodb.internal.operation.BulkWriteBatch.logWriteModelDoesNotSupportRetries; +import static com.mongodb.internal.operation.CommandOperationHelper.commandWriteConcern; import static com.mongodb.internal.operation.CommandOperationHelper.initialRetryState; import static com.mongodb.internal.operation.CommandOperationHelper.shouldAttemptToRetryWriteAndAddRetryableLabel; import static com.mongodb.internal.operation.CommandOperationHelper.transformWriteException; -import static com.mongodb.internal.operation.CommandOperationHelper.commandWriteConcern; import static com.mongodb.internal.operation.CommandOperationHelper.validateAndGetEffectiveWriteConcern; import static com.mongodb.internal.operation.OperationHelper.isRetryableWrite; import static com.mongodb.internal.operation.SyncOperationHelper.cursorDocumentToBatchCursor; @@ -118,7 +126,8 @@ import static com.mongodb.internal.operation.SyncOperationHelper.withSourceAndConnection; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; -import static java.util.Collections.singletonMap; +import static java.util.Collections.singletonList; +import static java.util.Optional.ofNullable; import static java.util.Spliterator.IMMUTABLE; import static java.util.Spliterator.ORDERED; import static java.util.Spliterators.spliteratorUnknownSize; @@ -193,8 +202,8 @@ private void executeAllBatches( /** * @return The start model index of the next batch, provided that the operation - * {@linkplain ExhaustiveBulkWriteCommandOkResponse#operationMayContinue(ConcreteClientBulkWriteOptions) may continue} - * and there are unexecuted models left. + * {@linkplain ExhaustiveClientBulkWriteCommandOkResponse#operationMayContinue(ConcreteClientBulkWriteOptions) may continue} + * and there are unexecuted {@linkplain ClientNamespacedWriteModel models} left. */ @Nullable private Integer executeBatch( @@ -203,12 +212,13 @@ private Integer executeBatch( final WriteBinding binding, final ResultAccumulator resultAccumulator) { List unexecutedModels = models.subList(batchStartModelIndex, models.size()); + assertFalse(unexecutedModels.isEmpty()); OperationContext operationContext = binding.getOperationContext(); SessionContext sessionContext = operationContext.getSessionContext(); TimeoutContext timeoutContext = operationContext.getTimeoutContext(); RetryState retryState = initialRetryState(retryWritesSetting, timeoutContext); BatchEncoder batchEncoder = new BatchEncoder(); - Supplier retryingBatchExecutor = decorateWriteWithRetries( + Supplier retryingBatchExecutor = decorateWriteWithRetries( retryState, operationContext, // Each batch re-selects a server and re-checks out a connection because this is simpler, // and it is allowed by https://jira.mongodb.org/browse/DRIVERS-2502. @@ -216,21 +226,21 @@ private Integer executeBatch( // and `ClientSession`, `TransactionContext` are aware of that. () -> withSourceAndConnection(binding::getWriteConnectionSource, true, (connectionSource, connection) -> { ConnectionDescription connectionDescription = connection.getDescription(); - boolean effectiveRetryWrites = isRetryableWrite(retryWritesSetting, effectiveWriteConcern, connectionDescription, sessionContext); + boolean effectiveRetryWrites = isRetryableWrite( + retryWritesSetting, effectiveWriteConcern, connectionDescription, sessionContext); retryState.breakAndThrowIfRetryAnd(() -> !effectiveRetryWrites); resultAccumulator.onNewServerAddress(connectionDescription.getServerAddress()); retryState.attach(AttachmentKeys.maxWireVersion(), connectionDescription.getMaxWireVersion(), true) .attach(AttachmentKeys.commandDescriptionSupplier(), () -> BULK_WRITE_COMMAND_NAME, false); - BsonDocumentWrapper lazilyEncodedBulkWriteCommand = createBulkWriteCommand( - effectiveRetryWrites, effectiveWriteConcern, sessionContext, unexecutedModels, batchEncoder, + ClientBulkWriteCommand bulkWriteCommand = createBulkWriteCommand( + retryState, effectiveRetryWrites, effectiveWriteConcern, sessionContext, unexecutedModels, batchEncoder, () -> retryState.attach(AttachmentKeys.retryableCommandFlag(), true, true)); return executeBulkWriteCommandAndExhaustOkResponse( - retryState, connectionSource, connection, lazilyEncodedBulkWriteCommand, unexecutedModels, - effectiveWriteConcern, operationContext); + retryState, connectionSource, connection, bulkWriteCommand, effectiveWriteConcern, operationContext); }) ); try { - ExhaustiveBulkWriteCommandOkResponse bulkWriteCommandOkResponse = retryingBatchExecutor.get(); + ExhaustiveClientBulkWriteCommandOkResponse bulkWriteCommandOkResponse = retryingBatchExecutor.get(); return resultAccumulator.onBulkWriteCommandOkResponseOrNoResponse( batchStartModelIndex, bulkWriteCommandOkResponse, batchEncoder.intoEncodedBatchInfo()); } catch (MongoWriteConcernWithResponseException mongoWriteConcernWithOkResponseException) { @@ -251,36 +261,34 @@ private Integer executeBatch( /** * @throws MongoWriteConcernWithResponseException This internal exception must be handled to avoid it being observed by an application. - * It {@linkplain MongoWriteConcernWithResponseException#getResponse() bears} the OK response to the {@code lazilyEncodedCommand}, + * It {@linkplain MongoWriteConcernWithResponseException#getResponse() bears} the OK response to the {@code bulkWriteCommand}, * which must be * {@linkplain ResultAccumulator#onBulkWriteCommandOkResponseWithWriteConcernError(int, MongoWriteConcernWithResponseException, BatchEncoder.EncodedBatchInfo) accumulated} * iff this exception is the failed result of retries. */ @Nullable - private ExhaustiveBulkWriteCommandOkResponse executeBulkWriteCommandAndExhaustOkResponse( + private ExhaustiveClientBulkWriteCommandOkResponse executeBulkWriteCommandAndExhaustOkResponse( final RetryState retryState, final ConnectionSource connectionSource, final Connection connection, - final BsonDocumentWrapper lazilyEncodedCommand, - final List unexecutedModels, + final ClientBulkWriteCommand bulkWriteCommand, final WriteConcern effectiveWriteConcern, final OperationContext operationContext) throws MongoWriteConcernWithResponseException { BsonDocument bulkWriteCommandOkResponse = connection.command( "admin", - lazilyEncodedCommand, - FieldNameValidators.createUpdateModsFieldValidator(unexecutedModels), + bulkWriteCommand.getCommandDocument(), + NoOpFieldNameValidator.INSTANCE, null, CommandResultDocumentCodec.create(codecRegistry.get(BsonDocument.class), CommandBatchCursorHelper.FIRST_BATCH), operationContext, effectiveWriteConcern.isAcknowledged(), - null, - null); + bulkWriteCommand.getOpsAndNsInfo()); if (bulkWriteCommandOkResponse == null) { return null; } List> cursorExhaustBatches = doWithRetriesDisabledForCommand(retryState, "getMore", () -> exhaustBulkWriteCommandOkResponseCursor(connectionSource, connection, bulkWriteCommandOkResponse)); - ExhaustiveBulkWriteCommandOkResponse exhaustiveBulkWriteCommandOkResponse = new ExhaustiveBulkWriteCommandOkResponse( + ExhaustiveClientBulkWriteCommandOkResponse exhaustiveBulkWriteCommandOkResponse = new ExhaustiveClientBulkWriteCommandOkResponse( bulkWriteCommandOkResponse, cursorExhaustBatches); // `Connection.command` does not throw `MongoWriteConcernException`, so we have to construct it ourselves MongoWriteConcernException writeConcernException = Exceptions.createWriteConcernException( @@ -313,7 +321,7 @@ private List> exhaustBulkWriteCommandOkResponseCursor( final Connection connection, final BsonDocument response) { int serverDefaultCursorBatchSize = 0; - try (BatchCursor cursor = cursorDocumentToBatchCursor( + try (CommandBatchCursor cursor = cursorDocumentToBatchCursor( TimeoutMode.CURSOR_LIFETIME, response, serverDefaultCursorBatchSize, @@ -321,78 +329,42 @@ private List> exhaustBulkWriteCommandOkResponseCursor( options.getComment().orElse(null), connectionSource, connection)) { + cursor.setCloseWithoutTimeoutReset(true); return stream(spliteratorUnknownSize(cursor, ORDERED | IMMUTABLE), false).collect(toList()); } } - private BsonDocumentWrapper createBulkWriteCommand( + private ClientBulkWriteCommand createBulkWriteCommand( + final RetryState retryState, final boolean effectiveRetryWrites, final WriteConcern effectiveWriteConcern, final SessionContext sessionContext, final List unexecutedModels, final BatchEncoder batchEncoder, - final Runnable ifCommandIsRetryable) { - return new BsonDocumentWrapper<>( - BULK_WRITE_COMMAND_NAME, - new Encoder() { - @Override - public void encode(final BsonWriter writer, final String commandName, final EncoderContext encoderContext) { - batchEncoder.reset(); - writer.writeStartDocument(); - writer.writeInt32(commandName, 1); - writer.writeBoolean("errorsOnly", !options.isVerboseResults()); - writer.writeBoolean("ordered", options.isOrdered()); - options.isBypassDocumentValidation().ifPresent(value -> writer.writeBoolean("bypassDocumentValidation", value)); - options.getComment().ifPresent(value -> { - writer.writeName("comment"); - encodeUsingRegistry(writer, value); - }); - options.getLet().ifPresent(value -> { - writer.writeName("let"); - encodeUsingRegistry(writer, value); - }); - Function modelSupportsRetries = model -> - !(model instanceof ConcreteClientUpdateManyModel || model instanceof ConcreteClientDeleteManyModel); - assertFalse(unexecutedModels.isEmpty()); - LinkedHashMap indexedNamespaces = new LinkedHashMap<>(); - writer.writeStartArray("ops"); - boolean commandIsRetryable = effectiveRetryWrites; - for (int modelIndexInBatch = 0; modelIndexInBatch < unexecutedModels.size(); modelIndexInBatch++) { - AbstractClientNamespacedWriteModel modelWithNamespace = getNamespacedModel(unexecutedModels, modelIndexInBatch); - ClientWriteModel model = modelWithNamespace.getModel(); - if (commandIsRetryable && !modelSupportsRetries.apply(model)) { - commandIsRetryable = false; - logWriteModelDoesNotSupportRetries(); - } - int namespaceIndexInBatch = indexedNamespaces.computeIfAbsent( - modelWithNamespace.getNamespace(), k -> indexedNamespaces.size()); - batchEncoder.encodeWriteModel(writer, model, modelIndexInBatch, namespaceIndexInBatch); - } - writer.writeEndArray(); - writer.writeStartArray("nsInfo"); - indexedNamespaces.keySet().forEach(namespace -> { - writer.writeStartDocument(); - writer.writeString("ns", namespace.getFullName()); - writer.writeEndDocument(); - }); - writer.writeEndArray(); - if (commandIsRetryable) { - batchEncoder.encodeTxnNumber(writer, sessionContext); - ifCommandIsRetryable.run(); - } - commandWriteConcern(effectiveWriteConcern, sessionContext).ifPresent(value -> { - writer.writeName("writeConcern"); - encodeUsingRegistry(writer, value.asDocument()); - }); - writer.writeEndDocument(); - } - - @Override - public Class getEncoderClass() { - throw fail(); - } - } - ); + final Runnable retriesEnabler) { + BsonDocument commandDocument = new BsonDocument(BULK_WRITE_COMMAND_NAME, new BsonInt32(1)) + .append("errorsOnly", BsonBoolean.valueOf(!options.isVerboseResults())) + .append("ordered", BsonBoolean.valueOf(options.isOrdered())); + options.isBypassDocumentValidation().ifPresent(value -> + commandDocument.append("bypassDocumentValidation", BsonBoolean.valueOf(value))); + options.getComment().ifPresent(value -> + commandDocument.append("comment", value)); + options.getLet().ifPresent(let -> + commandDocument.append("let", let.toBsonDocument(BsonDocument.class, codecRegistry))); + commandWriteConcern(effectiveWriteConcern, sessionContext).ifPresent(value-> + commandDocument.append("writeConcern", value.asDocument())); + return new ClientBulkWriteCommand( + commandDocument, + new ClientBulkWriteCommand.OpsAndNsInfo( + effectiveRetryWrites, unexecutedModels, + batchEncoder, + options, + () -> { + retriesEnabler.run(); + return retryState.isFirstAttempt() + ? sessionContext.advanceTransactionNumber() + : sessionContext.getTransactionNumber(); + })); } private void encodeUsingRegistry(final BsonWriter writer, final T value) { @@ -401,8 +373,8 @@ private void encodeUsingRegistry(final BsonWriter writer, final T value) { private void encodeUsingRegistry(final BsonWriter writer, final T value, final EncoderContext encoderContext) { @SuppressWarnings("unchecked") - Encoder collationEncoder = (Encoder) codecRegistry.get(value.getClass()); - collationEncoder.encode(writer, value, encoderContext); + Encoder encoder = (Encoder) codecRegistry.get(value.getClass()); + encoder.encode(writer, value, encoderContext); } private static AbstractClientNamespacedWriteModel getNamespacedModel( @@ -418,7 +390,7 @@ public static Optional serverAddressFromException(@Nullable final } else if (exception instanceof MongoSocketException) { serverAddress = ((MongoSocketException) exception).getServerAddress(); } - return Optional.ofNullable(serverAddress); + return ofNullable(serverAddress); } @Nullable @@ -438,139 +410,7 @@ private static MongoWriteConcernException createWriteConcernException( } } - private static final class FieldNameValidators { - /** - * The server supports only the {@code update} individual write operation in the {@code ops} array field, while the driver supports - * {@link ClientNamespacedUpdateOneModel}, {@link ClientNamespacedUpdateOneModel}, {@link ClientNamespacedReplaceOneModel}. - * The difference between updating and replacing is only in the document specified via the {@code updateMods} field: - *
    - *
  • if the name of the first field starts with {@code '$'}, then the document is interpreted as specifying update operators;
  • - *
  • if the name of the first field does not start with {@code '$'}, then the document is interpreted as a replacement.
  • - *
- */ - private static FieldNameValidator createUpdateModsFieldValidator(final List models) { - return new MappedFieldNameValidator( - NoOpFieldNameValidator.INSTANCE, - singletonMap("ops", new FieldNameValidators.OpsArrayFieldValidator(models))); - } - - static final class OpsArrayFieldValidator implements FieldNameValidator { - private static final Set OPERATION_DISCRIMINATOR_FIELD_NAMES = Stream.of("insert", "update", "delete").collect(toSet()); - - private final List models; - private final ReplacingUpdateModsFieldValidator replacingValidator; - private final UpdatingUpdateModsFieldValidator updatingValidator; - private int currentIndividualOperationIndex; - - OpsArrayFieldValidator(final List models) { - this.models = models; - replacingValidator = new ReplacingUpdateModsFieldValidator(); - updatingValidator = new UpdatingUpdateModsFieldValidator(); - currentIndividualOperationIndex = -1; - } - - @Override - public boolean validate(final String fieldName) { - if (OPERATION_DISCRIMINATOR_FIELD_NAMES.contains(fieldName)) { - currentIndividualOperationIndex++; - } - return true; - } - - @Override - public FieldNameValidator getValidatorForField(final String fieldName) { - if (fieldName.equals("updateMods")) { - return currentIndividualOperationIsReplace() ? replacingValidator.reset() : updatingValidator.reset(); - } - return NoOpFieldNameValidator.INSTANCE; - } - - private boolean currentIndividualOperationIsReplace() { - return getNamespacedModel(models, currentIndividualOperationIndex) instanceof ConcreteClientNamespacedReplaceOneModel; - } - } - - private static final class ReplacingUpdateModsFieldValidator implements FieldNameValidator { - private boolean firstFieldSinceLastReset; - - ReplacingUpdateModsFieldValidator() { - firstFieldSinceLastReset = true; - } - - @Override - public boolean validate(final String fieldName) { - if (firstFieldSinceLastReset) { - // we must validate only the first field, and leave the rest up to the server - firstFieldSinceLastReset = false; - return ReplacingDocumentFieldNameValidator.INSTANCE.validate(fieldName); - } - return true; - } - - @Override - public String getValidationErrorMessage(final String fieldName) { - return ReplacingDocumentFieldNameValidator.INSTANCE.getValidationErrorMessage(fieldName); - } - - @Override - public FieldNameValidator getValidatorForField(final String fieldName) { - return NoOpFieldNameValidator.INSTANCE; - } - - ReplacingUpdateModsFieldValidator reset() { - firstFieldSinceLastReset = true; - return this; - } - } - - private static final class UpdatingUpdateModsFieldValidator implements FieldNameValidator { - private final UpdateFieldNameValidator delegate; - private boolean firstFieldSinceLastReset; - - UpdatingUpdateModsFieldValidator() { - delegate = new UpdateFieldNameValidator(); - firstFieldSinceLastReset = true; - } - - @Override - public boolean validate(final String fieldName) { - if (firstFieldSinceLastReset) { - // we must validate only the first field, and leave the rest up to the server - firstFieldSinceLastReset = false; - return delegate.validate(fieldName); - } - return true; - } - - @Override - public String getValidationErrorMessage(final String fieldName) { - return delegate.getValidationErrorMessage(fieldName); - } - - @Override - public FieldNameValidator getValidatorForField(final String fieldName) { - return NoOpFieldNameValidator.INSTANCE; - } - - @Override - public void start() { - delegate.start(); - } - - @Override - public void end() { - delegate.end(); - } - - UpdatingUpdateModsFieldValidator reset() { - delegate.reset(); - firstFieldSinceLastReset = true; - return this; - } - } - } - - private static final class ExhaustiveBulkWriteCommandOkResponse { + private static final class ExhaustiveClientBulkWriteCommandOkResponse { /** * The number of unsuccessful individual write operations. */ @@ -582,7 +422,7 @@ private static final class ExhaustiveBulkWriteCommandOkResponse { private final int nDeleted; private final List cursorExhaust; - ExhaustiveBulkWriteCommandOkResponse( + ExhaustiveClientBulkWriteCommandOkResponse( final BsonDocument bulkWriteCommandOkResponse, final List> cursorExhaustBatches) { this.nErrors = bulkWriteCommandOkResponse.getInt32("nErrors").getValue(); @@ -656,8 +496,8 @@ private final class ResultAccumulator { */ ClientBulkWriteResult build(@Nullable final MongoException topLevelError, final WriteConcern effectiveWriteConcern) throws MongoException { boolean verboseResultsSetting = options.isVerboseResults(); - boolean haveResponses = false; - boolean haveSuccessfulIndividualOperations = false; + boolean batchResultsHaveResponses = false; + boolean batchResultsHaveInfoAboutSuccessfulIndividualOperations = false; long insertedCount = 0; long upsertedCount = 0; long matchedCount = 0; @@ -670,15 +510,18 @@ ClientBulkWriteResult build(@Nullable final MongoException topLevelError, final Map writeErrors = new HashMap<>(); for (BatchResult batchResult : batchResults) { if (batchResult.hasResponse()) { - haveResponses = true; + batchResultsHaveResponses = true; MongoWriteConcernException writeConcernException = batchResult.getWriteConcernException(); if (writeConcernException != null) { writeConcernErrors.add(writeConcernException.getWriteConcernError()); } int batchStartModelIndex = batchResult.getBatchStartModelIndex(); - ExhaustiveBulkWriteCommandOkResponse response = batchResult.getResponse(); - haveSuccessfulIndividualOperations = haveSuccessfulIndividualOperations - || response.getNErrors() < batchResult.getBatchModelsCount(); + ExhaustiveClientBulkWriteCommandOkResponse response = batchResult.getResponse(); + boolean orderedSetting = options.isOrdered(); + int nErrors = response.getNErrors(); + batchResultsHaveInfoAboutSuccessfulIndividualOperations = batchResultsHaveInfoAboutSuccessfulIndividualOperations + || (orderedSetting && nErrors == 0) + || (!orderedSetting && nErrors < batchResult.getBatchModelsCount()); insertedCount += response.getNInserted(); upsertedCount += response.getNUpserted(); matchedCount += response.getNMatched(); @@ -714,6 +557,8 @@ ClientBulkWriteResult build(@Nullable final MongoException topLevelError, final fail(writeModel.getClass().toString()); } } else { + batchResultsHaveInfoAboutSuccessfulIndividualOperations = batchResultsHaveInfoAboutSuccessfulIndividualOperations + || (orderedSetting && individualOperationIndexInBatch > 0); WriteError individualOperationWriteError = new WriteError( individualOperationResponse.getInt32("code").getValue(), individualOperationResponse.getString("errmsg").getValue(), @@ -733,8 +578,8 @@ ClientBulkWriteResult build(@Nullable final MongoException topLevelError, final } else { return UnacknowledgedClientBulkWriteResult.INSTANCE; } - } else if (haveResponses) { - AcknowledgedSummaryClientBulkWriteResult partialSummaryResult = haveSuccessfulIndividualOperations + } else if (batchResultsHaveResponses) { + AcknowledgedSummaryClientBulkWriteResult partialSummaryResult = batchResultsHaveInfoAboutSuccessfulIndividualOperations ? new AcknowledgedSummaryClientBulkWriteResult(insertedCount, upsertedCount, matchedCount, modifiedCount, deletedCount) : null; throw new ClientBulkWriteException( @@ -758,7 +603,7 @@ void onNewServerAddress(final ServerAddress serverAddress) { Integer onBulkWriteCommandOkResponseOrNoResponse( final int batchStartModelIndex, @Nullable - final ExhaustiveBulkWriteCommandOkResponse response, + final ExhaustiveClientBulkWriteCommandOkResponse response, final BatchEncoder.EncodedBatchInfo encodedBatchInfo) { return onBulkWriteCommandOkResponseOrNoResponse(batchStartModelIndex, response, null, encodedBatchInfo); } @@ -773,7 +618,7 @@ Integer onBulkWriteCommandOkResponseWithWriteConcernError( final BatchEncoder.EncodedBatchInfo encodedBatchInfo) { MongoWriteConcernException writeConcernException = (MongoWriteConcernException) exception.getCause(); onNewServerAddress(writeConcernException.getServerAddress()); - ExhaustiveBulkWriteCommandOkResponse response = (ExhaustiveBulkWriteCommandOkResponse) exception.getResponse(); + ExhaustiveClientBulkWriteCommandOkResponse response = (ExhaustiveClientBulkWriteCommandOkResponse) exception.getResponse(); return onBulkWriteCommandOkResponseOrNoResponse(batchStartModelIndex, response, writeConcernException, encodedBatchInfo); } @@ -784,7 +629,7 @@ Integer onBulkWriteCommandOkResponseWithWriteConcernError( private Integer onBulkWriteCommandOkResponseOrNoResponse( final int batchStartModelIndex, @Nullable - final ExhaustiveBulkWriteCommandOkResponse response, + final ExhaustiveClientBulkWriteCommandOkResponse response, @Nullable final MongoWriteConcernException writeConcernException, final BatchEncoder.EncodedBatchInfo encodedBatchInfo) { @@ -807,18 +652,236 @@ void onBulkWriteCommandErrorWithoutResponse(final MongoException exception) { } } + public static final class ClientBulkWriteCommand { + private final BsonDocument commandDocument; + private final OpsAndNsInfo opsAndNsInfo; + + ClientBulkWriteCommand( + final BsonDocument commandDocument, + final OpsAndNsInfo opsAndNsInfo) { + this.commandDocument = commandDocument; + this.opsAndNsInfo = opsAndNsInfo; + } + + BsonDocument getCommandDocument() { + return commandDocument; + } + + OpsAndNsInfo getOpsAndNsInfo() { + return opsAndNsInfo; + } + + public static final class OpsAndNsInfo extends DualMessageSequences { + private final boolean effectiveRetryWrites; + private final List models; + private final BatchEncoder batchEncoder; + private final ConcreteClientBulkWriteOptions options; + private final Supplier doIfCommandIsRetryableAndAdvanceGetTxnNumber; + + @VisibleForTesting(otherwise = PACKAGE) + public OpsAndNsInfo( + final boolean effectiveRetryWrites, + final List models, + final BatchEncoder batchEncoder, + final ConcreteClientBulkWriteOptions options, + final Supplier doIfCommandIsRetryableAndAdvanceGetTxnNumber) { + super("ops", new OpsFieldNameValidator(models), "nsInfo", NoOpFieldNameValidator.INSTANCE); + this.effectiveRetryWrites = effectiveRetryWrites; + this.models = models; + this.batchEncoder = batchEncoder; + this.options = options; + this.doIfCommandIsRetryableAndAdvanceGetTxnNumber = doIfCommandIsRetryableAndAdvanceGetTxnNumber; + } + + @Override + public EncodeDocumentsResult encodeDocuments(final WritersProviderAndLimitsChecker writersProviderAndLimitsChecker) { + // We must call `batchEncoder.reset` lazily, that is here, and not eagerly before a command retry attempt, + // because a retry attempt may fail before encoding, + // in which case we need the information gathered by `batchEncoder` at a previous attempt. + batchEncoder.reset(); + LinkedHashMap indexedNamespaces = new LinkedHashMap<>(); + WritersProviderAndLimitsChecker.WriteResult writeResult = OK_LIMIT_NOT_REACHED; + boolean commandIsRetryable = effectiveRetryWrites; + int maxModelIndexInBatch = -1; + for (int modelIndexInBatch = 0; modelIndexInBatch < models.size() && writeResult == OK_LIMIT_NOT_REACHED; modelIndexInBatch++) { + AbstractClientNamespacedWriteModel namespacedModel = getNamespacedModel(models, modelIndexInBatch); + MongoNamespace namespace = namespacedModel.getNamespace(); + int indexedNamespacesSizeBeforeCompute = indexedNamespaces.size(); + int namespaceIndexInBatch = indexedNamespaces.computeIfAbsent(namespace, k -> indexedNamespacesSizeBeforeCompute); + boolean writeNewNamespace = indexedNamespaces.size() != indexedNamespacesSizeBeforeCompute; + int finalModelIndexInBatch = modelIndexInBatch; + writeResult = writersProviderAndLimitsChecker.tryWrite((opsWriter, nsInfoWriter) -> { + batchEncoder.encodeWriteModel(opsWriter, namespacedModel.getModel(), finalModelIndexInBatch, namespaceIndexInBatch); + if (writeNewNamespace) { + nsInfoWriter.writeStartDocument(); + nsInfoWriter.writeString("ns", namespace.getFullName()); + nsInfoWriter.writeEndDocument(); + } + return finalModelIndexInBatch + 1; + }); + if (writeResult == FAIL_LIMIT_EXCEEDED) { + batchEncoder.reset(finalModelIndexInBatch); + } else { + maxModelIndexInBatch = finalModelIndexInBatch; + if (commandIsRetryable && doesNotSupportRetries(namespacedModel)) { + commandIsRetryable = false; + logWriteModelDoesNotSupportRetries(); + } + } + } + return new EncodeDocumentsResult( + // we will execute more batches, so we must request a response to maintain the order of individual write operations + options.isOrdered() && maxModelIndexInBatch < models.size() - 1, + commandIsRetryable + ? singletonList(new BsonElement("txnNumber", new BsonInt64(doIfCommandIsRetryableAndAdvanceGetTxnNumber.get()))) + : emptyList()); + } + + private static boolean doesNotSupportRetries(final AbstractClientNamespacedWriteModel model) { + return model instanceof ConcreteClientNamespacedUpdateManyModel || model instanceof ConcreteClientNamespacedDeleteManyModel; + } + + /** + * The server supports only the {@code update} individual write operation in the {@code ops} array field, while the driver supports + * {@link ClientNamespacedUpdateOneModel}, {@link ClientNamespacedUpdateOneModel}, {@link ClientNamespacedReplaceOneModel}. + * The difference between updating and replacing is only in the document specified via the {@code updateMods} field: + *
    + *
  • if the name of the first field starts with {@code '$'}, then the document is interpreted as specifying update operators;
  • + *
  • if the name of the first field does not start with {@code '$'}, then the document is interpreted as a replacement.
  • + *
+ * + * @see + * Update vs. replace document validation + */ + private static final class OpsFieldNameValidator implements FieldNameValidator { + private static final Set OPERATION_DISCRIMINATOR_FIELD_NAMES = Stream.of("insert", "update", "delete").collect(toSet()); + + private final List models; + private final ReplacingUpdateModsFieldValidator replacingValidator; + private final UpdatingUpdateModsFieldValidator updatingValidator; + private int currentIndividualOperationIndex; + + OpsFieldNameValidator(final List models) { + this.models = models; + replacingValidator = new ReplacingUpdateModsFieldValidator(); + updatingValidator = new UpdatingUpdateModsFieldValidator(); + currentIndividualOperationIndex = -1; + } + + @Override + public boolean validate(final String fieldName) { + if (OPERATION_DISCRIMINATOR_FIELD_NAMES.contains(fieldName)) { + currentIndividualOperationIndex++; + } + return true; + } + + @Override + public FieldNameValidator getValidatorForField(final String fieldName) { + if (fieldName.equals("updateMods")) { + return currentIndividualOperationIsReplace() ? replacingValidator.reset() : updatingValidator.reset(); + } + return NoOpFieldNameValidator.INSTANCE; + } + + private boolean currentIndividualOperationIsReplace() { + return getNamespacedModel(models, currentIndividualOperationIndex) instanceof ConcreteClientNamespacedReplaceOneModel; + } + + private static final class ReplacingUpdateModsFieldValidator implements FieldNameValidator { + private boolean firstFieldSinceLastReset; + + ReplacingUpdateModsFieldValidator() { + firstFieldSinceLastReset = true; + } + + @Override + public boolean validate(final String fieldName) { + if (firstFieldSinceLastReset) { + // we must validate only the first field, and leave the rest up to the server + firstFieldSinceLastReset = false; + return ReplacingDocumentFieldNameValidator.INSTANCE.validate(fieldName); + } + return true; + } + + @Override + public String getValidationErrorMessage(final String fieldName) { + return ReplacingDocumentFieldNameValidator.INSTANCE.getValidationErrorMessage(fieldName); + } + + @Override + public FieldNameValidator getValidatorForField(final String fieldName) { + return NoOpFieldNameValidator.INSTANCE; + } + + ReplacingUpdateModsFieldValidator reset() { + firstFieldSinceLastReset = true; + return this; + } + } + + private static final class UpdatingUpdateModsFieldValidator implements FieldNameValidator { + private final UpdateFieldNameValidator delegate; + private boolean firstFieldSinceLastReset; + + UpdatingUpdateModsFieldValidator() { + delegate = new UpdateFieldNameValidator(); + firstFieldSinceLastReset = true; + } + + @Override + public boolean validate(final String fieldName) { + if (firstFieldSinceLastReset) { + // we must validate only the first field, and leave the rest up to the server + firstFieldSinceLastReset = false; + return delegate.validate(fieldName); + } + return true; + } + + @Override + public String getValidationErrorMessage(final String fieldName) { + return delegate.getValidationErrorMessage(fieldName); + } + + @Override + public FieldNameValidator getValidatorForField(final String fieldName) { + return NoOpFieldNameValidator.INSTANCE; + } + + @Override + public void start() { + delegate.start(); + } + + @Override + public void end() { + delegate.end(); + } + + UpdatingUpdateModsFieldValidator reset() { + delegate.reset(); + firstFieldSinceLastReset = true; + return this; + } + } + } + } + } + static final class BatchResult { private final int batchStartModelIndex; private final BatchEncoder.EncodedBatchInfo encodedBatchInfo; @Nullable - private final ExhaustiveBulkWriteCommandOkResponse response; + private final ExhaustiveClientBulkWriteCommandOkResponse response; @Nullable private final MongoWriteConcernException writeConcernException; static BatchResult okResponse( final int batchStartModelIndex, final BatchEncoder.EncodedBatchInfo encodedBatchInfo, - final ExhaustiveBulkWriteCommandOkResponse response, + final ExhaustiveClientBulkWriteCommandOkResponse response, @Nullable final MongoWriteConcernException writeConcernException) { return new BatchResult(batchStartModelIndex, encodedBatchInfo, assertNotNull(response), writeConcernException); } @@ -830,7 +893,7 @@ static BatchResult noResponse(final int batchStartModelIndex, final BatchEncoder private BatchResult( final int batchStartModelIndex, final BatchEncoder.EncodedBatchInfo encodedBatchInfo, - @Nullable final ExhaustiveBulkWriteCommandOkResponse response, + @Nullable final ExhaustiveClientBulkWriteCommandOkResponse response, @Nullable final MongoWriteConcernException writeConcernException) { this.batchStartModelIndex = batchStartModelIndex; this.encodedBatchInfo = encodedBatchInfo; @@ -853,7 +916,7 @@ boolean hasResponse() { return response != null; } - ExhaustiveBulkWriteCommandOkResponse getResponse() { + ExhaustiveClientBulkWriteCommandOkResponse getResponse() { return assertNotNull(response); } @@ -875,16 +938,19 @@ Map getInsertModelDocumentIds() { /** * Exactly one instance must be used per {@linkplain #executeBatch(int, WriteConcern, WriteBinding, ResultAccumulator) batch}. */ - private final class BatchEncoder { + @VisibleForTesting(otherwise = PRIVATE) + public final class BatchEncoder { private EncodedBatchInfo encodedBatchInfo; - BatchEncoder() { + @VisibleForTesting(otherwise = PACKAGE) + public BatchEncoder() { encodedBatchInfo = new EncodedBatchInfo(); } /** * Must be called at most once. - * Must not be called before calling {@link #encodeWriteModel(BsonWriter, ClientWriteModel, int, int)} at least once. + * Must not be called before calling + * {@link #encodeWriteModel(BsonBinaryWriter, ClientWriteModel, int, int)} at least once. * Renders {@code this} unusable. */ EncodedBatchInfo intoEncodedBatchInfo() { @@ -899,16 +965,13 @@ void reset() { assertNotNull(encodedBatchInfo).modelsCount = 0; } - void encodeTxnNumber(final BsonWriter writer, final SessionContext sessionContext) { - EncodedBatchInfo localEncodedBatchInfo = assertNotNull(encodedBatchInfo); - if (localEncodedBatchInfo.txnNumber == EncodedBatchInfo.UNINITIALIZED_TXN_NUMBER) { - localEncodedBatchInfo.txnNumber = sessionContext.advanceTransactionNumber(); - } - writer.writeInt64("txnNumber", localEncodedBatchInfo.txnNumber); + void reset(final int modelIndexInBatch) { + assertNotNull(encodedBatchInfo).modelsCount -= 1; + encodedBatchInfo.insertModelDocumentIds.remove(modelIndexInBatch); } void encodeWriteModel( - final BsonWriter writer, + final BsonBinaryWriter writer, final ClientWriteModel model, final int modelIndexInBatch, final int namespaceIndexInBatch) { @@ -942,18 +1005,20 @@ void encodeWriteModel( writer.writeEndDocument(); } - private void encodeWriteModelInternals(final BsonWriter writer, final ConcreteClientInsertOneModel model, final int modelIndexInBatch) { + private void encodeWriteModelInternals( + final BsonBinaryWriter writer, + final ConcreteClientInsertOneModel model, + final int modelIndexInBatch) { writer.writeName("document"); Object document = model.getDocument(); - @SuppressWarnings("unchecked") - Encoder documentEncoder = (Encoder) codecRegistry.get(document.getClass()); assertNotNull(encodedBatchInfo).insertModelDocumentIds.compute(modelIndexInBatch, (k, knownModelDocumentId) -> { IdHoldingBsonWriter documentIdHoldingBsonWriter = new IdHoldingBsonWriter( writer, // Reuse `knownModelDocumentId` if it may have been generated by `IdHoldingBsonWriter` in a previous attempt. - // If its type is not `BsonObjectId`, we know it could not have been generated. + // If its type is not `BsonObjectId`, which happens only if `_id` was specified by the application, + // we know it could not have been generated. knownModelDocumentId instanceof BsonObjectId ? knownModelDocumentId.asObjectId() : null); - documentEncoder.encode(documentIdHoldingBsonWriter, document, COLLECTIBLE_DOCUMENT_ENCODER_CONTEXT); + encodeUsingRegistry(documentIdHoldingBsonWriter, document, COLLECTIBLE_DOCUMENT_ENCODER_CONTEXT); return documentIdHoldingBsonWriter.getId(); }); } @@ -988,7 +1053,7 @@ private void encodeWriteModelInternals(final BsonWriter writer, final AbstractCl options.isUpsert().ifPresent(value -> writer.writeBoolean("upsert", value)); } - private void encodeWriteModelInternals(final BsonWriter writer, final ConcreteClientReplaceOneModel model) { + private void encodeWriteModelInternals(final BsonBinaryWriter writer, final ConcreteClientReplaceOneModel model) { writer.writeBoolean("multi", false); writer.writeName("filter"); encodeUsingRegistry(writer, model.getFilter()); @@ -1023,16 +1088,12 @@ private void encodeWriteModelInternals(final BsonWriter writer, final AbstractCl } final class EncodedBatchInfo { - private static final long UNINITIALIZED_TXN_NUMBER = -1; - - private long txnNumber; private final HashMap insertModelDocumentIds; private int modelsCount; private EncodedBatchInfo() { insertModelDocumentIds = new HashMap<>(); modelsCount = 0; - txnNumber = UNINITIALIZED_TXN_NUMBER; } /** diff --git a/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java b/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java index 28742574eb4..06d392bceb2 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java @@ -421,8 +421,7 @@ private BsonDocument executeCommand( final Connection connection, final BulkWriteBatch batch) { return connection.command(namespace.getDatabaseName(), batch.getCommand(), NoOpFieldNameValidator.INSTANCE, null, batch.getDecoder(), - operationContext, shouldExpectResponse(batch, effectiveWriteConcern), - batch.getPayload(), batch.getFieldNameValidator()); + operationContext, shouldExpectResponse(batch, effectiveWriteConcern), batch.getPayload()); } private void executeCommandAsync( @@ -432,8 +431,7 @@ private void executeCommandAsync( final BulkWriteBatch batch, final SingleResultCallback callback) { connection.commandAsync(namespace.getDatabaseName(), batch.getCommand(), NoOpFieldNameValidator.INSTANCE, null, batch.getDecoder(), - operationContext, shouldExpectResponse(batch, effectiveWriteConcern), - batch.getPayload(), batch.getFieldNameValidator(), callback); + operationContext, shouldExpectResponse(batch, effectiveWriteConcern), batch.getPayload(), callback); } private boolean shouldExpectResponse(final BulkWriteBatch batch, final WriteConcern effectiveWriteConcern) { diff --git a/driver-core/src/main/com/mongodb/internal/operation/SyncOperationHelper.java b/driver-core/src/main/com/mongodb/internal/operation/SyncOperationHelper.java index 0d50768d668..6d013df59ba 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/SyncOperationHelper.java +++ b/driver-core/src/main/com/mongodb/internal/operation/SyncOperationHelper.java @@ -334,7 +334,7 @@ static CommandReadTransformer> singleBatchCurso connection.getDescription().getServerAddress()); } - static BatchCursor cursorDocumentToBatchCursor(final TimeoutMode timeoutMode, final BsonDocument cursorDocument, + static CommandBatchCursor cursorDocumentToBatchCursor(final TimeoutMode timeoutMode, final BsonDocument cursorDocument, final int batchSize, final Decoder decoder, @Nullable final BsonValue comment, final ConnectionSource source, final Connection connection) { return new CommandBatchCursor<>(timeoutMode, cursorDocument, batchSize, 0, decoder, comment, source, connection); diff --git a/driver-core/src/test/functional/com/mongodb/client/syncadapter/SyncConnection.java b/driver-core/src/test/functional/com/mongodb/client/syncadapter/SyncConnection.java index 1cc3904749d..0a96d5ab0cf 100644 --- a/driver-core/src/test/functional/com/mongodb/client/syncadapter/SyncConnection.java +++ b/driver-core/src/test/functional/com/mongodb/client/syncadapter/SyncConnection.java @@ -19,8 +19,8 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.connection.AsyncConnection; import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.MessageSequences; import com.mongodb.internal.connection.OperationContext; -import com.mongodb.internal.connection.SplittablePayload; import org.bson.BsonDocument; import org.bson.FieldNameValidator; import org.bson.codecs.Decoder; @@ -65,11 +65,10 @@ public T command(final String database, final BsonDocument command, final Fi @Override public T command(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, final ReadPreference readPreference, final Decoder commandResultDecoder, - final OperationContext operationContext, final boolean responseExpected, final SplittablePayload payload, - final FieldNameValidator payloadFieldNameValidator) { + final OperationContext operationContext, final boolean responseExpected, final MessageSequences sequences) { SupplyingCallback callback = new SupplyingCallback<>(); wrapped.commandAsync(database, command, commandFieldNameValidator, readPreference, commandResultDecoder, operationContext, - responseExpected, payload, payloadFieldNameValidator, callback); + responseExpected, sequences, callback); return callback.get(); } diff --git a/driver-core/src/test/resources/unified-test-format/crud/client-bulkWrite-partialResults.json b/driver-core/src/test/resources/unified-test-format/crud/client-bulkWrite-partialResults.json new file mode 100644 index 00000000000..b35e94a2ea2 --- /dev/null +++ b/driver-core/src/test/resources/unified-test-format/crud/client-bulkWrite-partialResults.json @@ -0,0 +1,540 @@ +{ + "description": "client bulkWrite partial results", + "schemaVersion": "1.4", + "runOnRequirements": [ + { + "minServerVersion": "8.0", + "serverless": "forbid" + } + ], + "createEntities": [ + { + "client": { + "id": "client0" + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + } + ] + } + ], + "_yamlAnchors": { + "namespace": "crud-tests.coll0", + "newDocument": { + "_id": 2, + "x": 22 + } + }, + "tests": [ + { + "description": "partialResult is unset when first operation fails during an ordered bulk write (verbose)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + } + ], + "ordered": true, + "verboseResults": true + }, + "expectError": { + "expectResult": { + "$$unsetOrMatches": { + "insertedCount": { + "$$exists": false + }, + "upsertedCount": { + "$$exists": false + }, + "matchedCount": { + "$$exists": false + }, + "modifiedCount": { + "$$exists": false + }, + "deletedCount": { + "$$exists": false + }, + "insertResults": { + "$$exists": false + }, + "updateResults": { + "$$exists": false + }, + "deleteResults": { + "$$exists": false + } + } + } + } + } + ] + }, + { + "description": "partialResult is unset when first operation fails during an ordered bulk write (summary)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + } + ], + "ordered": true, + "verboseResults": false + }, + "expectError": { + "expectResult": { + "$$unsetOrMatches": { + "insertedCount": { + "$$exists": false + }, + "upsertedCount": { + "$$exists": false + }, + "matchedCount": { + "$$exists": false + }, + "modifiedCount": { + "$$exists": false + }, + "deletedCount": { + "$$exists": false + }, + "insertResults": { + "$$exists": false + }, + "updateResults": { + "$$exists": false + }, + "deleteResults": { + "$$exists": false + } + } + } + } + } + ] + }, + { + "description": "partialResult is set when second operation fails during an ordered bulk write (verbose)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": true, + "verboseResults": true + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "0": { + "insertedId": 2 + } + }, + "updateResults": {}, + "deleteResults": {} + } + } + } + ] + }, + { + "description": "partialResult is set when second operation fails during an ordered bulk write (summary)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": true, + "verboseResults": false + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "$$unsetOrMatches": {} + }, + "updateResults": { + "$$unsetOrMatches": {} + }, + "deleteResults": { + "$$unsetOrMatches": {} + } + } + } + } + ] + }, + { + "description": "partialResult is unset when all operations fail during an unordered bulk write", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": false + }, + "expectError": { + "expectResult": { + "$$unsetOrMatches": { + "insertedCount": { + "$$exists": false + }, + "upsertedCount": { + "$$exists": false + }, + "matchedCount": { + "$$exists": false + }, + "modifiedCount": { + "$$exists": false + }, + "deletedCount": { + "$$exists": false + }, + "insertResults": { + "$$exists": false + }, + "updateResults": { + "$$exists": false + }, + "deleteResults": { + "$$exists": false + } + } + } + } + } + ] + }, + { + "description": "partialResult is set when first operation fails during an unordered bulk write (verbose)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + } + ], + "ordered": false, + "verboseResults": true + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "1": { + "insertedId": 2 + } + }, + "updateResults": {}, + "deleteResults": {} + } + } + } + ] + }, + { + "description": "partialResult is set when first operation fails during an unordered bulk write (summary)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + } + ], + "ordered": false, + "verboseResults": false + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "$$unsetOrMatches": {} + }, + "updateResults": { + "$$unsetOrMatches": {} + }, + "deleteResults": { + "$$unsetOrMatches": {} + } + } + } + } + ] + }, + { + "description": "partialResult is set when second operation fails during an unordered bulk write (verbose)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": false, + "verboseResults": true + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "0": { + "insertedId": 2 + } + }, + "updateResults": {}, + "deleteResults": {} + } + } + } + ] + }, + { + "description": "partialResult is set when first operation fails during an unordered bulk write (summary)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": false, + "verboseResults": false + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "$$unsetOrMatches": {} + }, + "updateResults": { + "$$unsetOrMatches": {} + }, + "deleteResults": { + "$$unsetOrMatches": {} + } + } + } + } + ] + } + ] +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputSpecification.groovy deleted file mode 100644 index 311279038ea..00000000000 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputSpecification.groovy +++ /dev/null @@ -1,420 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.internal.connection - -import util.spock.annotations.Slow -import org.bson.BsonSerializationException -import org.bson.types.ObjectId -import spock.lang.Specification - -import java.security.SecureRandom - -class ByteBufferBsonOutputSpecification extends Specification { - def 'constructor should throw if buffer provider is null'() { - when: - new ByteBufferBsonOutput(null) - - then: - thrown(IllegalArgumentException) - } - - def 'position and size should be 0 after constructor'() { - when: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - then: - bsonOutput.position == 0 - bsonOutput.size == 0 - } - - def 'should write a byte'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeByte(11) - - then: - getBytes(bsonOutput) == [11] as byte[] - bsonOutput.position == 1 - bsonOutput.size == 1 - } - - def 'should write bytes'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeBytes([1, 2, 3, 4] as byte[]) - - then: - getBytes(bsonOutput) == [1, 2, 3, 4] as byte[] - bsonOutput.position == 4 - bsonOutput.size == 4 - } - - def 'should write bytes from offset until length'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeBytes([0, 1, 2, 3, 4, 5] as byte[], 1, 4) - - then: - getBytes(bsonOutput) == [1, 2, 3, 4] as byte[] - bsonOutput.position == 4 - bsonOutput.size == 4 - } - - def 'should write a little endian Int32'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeInt32(0x1020304) - - then: - getBytes(bsonOutput) == [4, 3, 2, 1] as byte[] - bsonOutput.position == 4 - bsonOutput.size == 4 - } - - def 'should write a little endian Int64'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeInt64(0x102030405060708L) - - then: - getBytes(bsonOutput) == [8, 7, 6, 5, 4, 3, 2, 1] as byte[] - bsonOutput.position == 8 - bsonOutput.size == 8 - } - - def 'should write a double'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeDouble(Double.longBitsToDouble(0x102030405060708L)) - - then: - getBytes(bsonOutput) == [8, 7, 6, 5, 4, 3, 2, 1] as byte[] - bsonOutput.position == 8 - bsonOutput.size == 8 - } - - def 'should write an ObjectId'() { - given: - def objectIdAsByteArray = [12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] as byte[] - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeObjectId(new ObjectId(objectIdAsByteArray)) - - then: - getBytes(bsonOutput) == objectIdAsByteArray - bsonOutput.position == 12 - bsonOutput.size == 12 - } - - def 'should write an empty string'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeString('') - - then: - getBytes(bsonOutput) == [1, 0, 0 , 0, 0] as byte[] - bsonOutput.position == 5 - bsonOutput.size == 5 - } - - def 'should write an ASCII string'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeString('Java') - - then: - getBytes(bsonOutput) == [5, 0, 0, 0, 0x4a, 0x61, 0x76, 0x61, 0] as byte[] - bsonOutput.position == 9 - bsonOutput.size == 9 - } - - def 'should write a UTF-8 string'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeString('\u0900') - - then: - getBytes(bsonOutput) == [4, 0, 0, 0, 0xe0, 0xa4, 0x80, 0] as byte[] - bsonOutput.position == 8 - bsonOutput.size == 8 - } - - def 'should write an empty CString'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeCString('') - - then: - getBytes(bsonOutput) == [0] as byte[] - bsonOutput.position == 1 - bsonOutput.size == 1 - } - - def 'should write an ASCII CString'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeCString('Java') - - then: - getBytes(bsonOutput) == [0x4a, 0x61, 0x76, 0x61, 0] as byte[] - bsonOutput.position == 5 - bsonOutput.size == 5 - } - - def 'should write a UTF-8 CString'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeCString('\u0900') - - then: - getBytes(bsonOutput) == [0xe0, 0xa4, 0x80, 0] as byte[] - bsonOutput.position == 4 - bsonOutput.size == 4 - } - - def 'should get byte buffers as little endian'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeBytes([1, 0, 0, 0] as byte[]) - - then: - bsonOutput.getByteBuffers()[0].getInt() == 1 - } - - def 'null character in CString should throw SerializationException'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeCString('hell\u0000world') - - then: - thrown(BsonSerializationException) - } - - def 'null character in String should not throw SerializationException'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - bsonOutput.writeString('h\u0000i') - - then: - getBytes(bsonOutput) == [4, 0, 0, 0, (byte) 'h', 0, (byte) 'i', 0] as byte[] - } - - def 'write Int32 at position should throw with invalid position'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - bsonOutput.writeBytes([1, 2, 3, 4] as byte[]) - - when: - bsonOutput.writeInt32(-1, 0x1020304) - - then: - thrown(IllegalArgumentException) - - when: - bsonOutput.writeInt32(1, 0x1020304) - - then: - thrown(IllegalArgumentException) - } - - def 'should write Int32 at position'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - bsonOutput.writeBytes([0, 0, 0, 0, 1, 2, 3, 4] as byte[]) - - when: 'the position is in the first buffer' - bsonOutput.writeInt32(0, 0x1020304) - - then: - getBytes(bsonOutput) == [4, 3, 2, 1, 1, 2, 3, 4] as byte[] - bsonOutput.position == 8 - bsonOutput.size == 8 - - when: 'the position is at the end of the first buffer' - bsonOutput.writeInt32(4, 0x1020304) - - then: - getBytes(bsonOutput) == [4, 3, 2, 1, 4, 3, 2, 1] as byte[] - bsonOutput.position == 8 - bsonOutput.size == 8 - - when: 'the position is not in the first buffer' - bsonOutput.writeBytes(new byte[1024]) - bsonOutput.writeInt32(1023, 0x1020304) - - then: - getBytes(bsonOutput)[1023..1026] as byte[] == [4, 3, 2, 1] as byte[] - bsonOutput.position == 1032 - bsonOutput.size == 1032 - } - - def 'truncate should throw with invalid position'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - bsonOutput.writeBytes([1, 2, 3, 4] as byte[]) - - when: - bsonOutput.truncateToPosition(5) - - then: - thrown(IllegalArgumentException) - - when: - bsonOutput.truncateToPosition(-1) - - then: - thrown(IllegalArgumentException) - } - - def 'should truncate to position'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - bsonOutput.writeBytes([1, 2, 3, 4] as byte[]) - bsonOutput.writeBytes(new byte[1024]) - - when: - bsonOutput.truncateToPosition(2) - - then: - getBytes(bsonOutput) == [1, 2] as byte[] - bsonOutput.position == 2 - bsonOutput.size == 2 - } - - def 'should grow'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - def bytes = new byte[1023] - bsonOutput.writeBytes(bytes) - - when: - bsonOutput.writeInt32(0x1020304) - - then: - getBytes(bsonOutput)[0..1022] as byte[] == bytes - getBytes(bsonOutput)[1023..1026] as byte[] == [4, 3, 2, 1] as byte[] - bsonOutput.position == 1027 - bsonOutput.size == 1027 - } - - @Slow - def 'should grow to maximum allowed size of byte buffer'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - def bytes = new byte[0x2000000] - def random = new SecureRandom() - random.nextBytes(bytes) - - when: - bsonOutput.writeBytes(bytes) - - then: - bsonOutput.size == 0x2000000 - bsonOutput.getByteBuffers()*.capacity() == - [1 << 10, 1 << 11, 1 << 12, 1 << 13, 1 << 14, 1 << 15, 1 << 16, 1 << 17, 1 << 18, 1 << 19, - 1 << 20, 1 << 21, 1 << 22, 1 << 23, 1 << 24, 1 << 24] - - when: - def stream = new ByteArrayOutputStream(bsonOutput.size) - bsonOutput.pipe(stream) - - then: - Arrays.equals(bytes, stream.toByteArray()) // faster than using Groovy's == implementation - } - - def 'should pipe'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - def bytes = new byte[1027] - bsonOutput.writeBytes(bytes) - - when: - def baos = new ByteArrayOutputStream() - bsonOutput.pipe(baos) - - then: - bytes == baos.toByteArray() - bsonOutput.position == 1027 - bsonOutput.size == 1027 - - when: - baos = new ByteArrayOutputStream() - bsonOutput.pipe(baos) - - then: - bytes == baos.toByteArray() - bsonOutput.position == 1027 - bsonOutput.size == 1027 - } - - - def 'should close'() { - given: - def bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider()) - bsonOutput.writeBytes(new byte[1027]) - - when: - bsonOutput.close() - bsonOutput.writeByte(11) - - then: - thrown(IllegalStateException) - } - - def getBytes(final ByteBufferBsonOutput byteBufferBsonOutput) { - ByteArrayOutputStream baos = new ByteArrayOutputStream(byteBufferBsonOutput.size) - - for (def cur : byteBufferBsonOutput.byteBuffers) { - while (cur.hasRemaining()) { - baos.write(cur.get()) - } - } - - baos.toByteArray() - } -} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java new file mode 100644 index 00000000000..3a8a2c83acb --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java @@ -0,0 +1,625 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.assertions.Assertions; +import org.bson.BsonSerializationException; +import org.bson.ByteBuf; +import org.bson.types.ObjectId; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static com.mongodb.internal.connection.ByteBufferBsonOutput.INITIAL_BUFFER_SIZE; +import static com.mongodb.internal.connection.ByteBufferBsonOutput.MAX_BUFFER_SIZE; +import static java.util.Arrays.asList; +import static java.util.Arrays.copyOfRange; +import static java.util.stream.Collectors.toList; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +final class ByteBufferBsonOutputTest { + @DisplayName("constructor should throw if buffer provider is null") + @Test + @SuppressWarnings("try") + void constructorShouldThrowIfBufferProviderIsNull() { + assertThrows(IllegalArgumentException.class, () -> { + try (ByteBufferBsonOutput ignored = new ByteBufferBsonOutput(null)) { + // nothing to do + } + }); + } + + @DisplayName("position and size should be 0 after constructor") + @ParameterizedTest + @ValueSource(strings = {"none", "empty", "truncated"}) + void positionAndSizeShouldBe0AfterConstructor(final String branchState) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + switch (branchState) { + case "none": { + break; + } + case "empty": { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + assertEquals(0, branch.getPosition()); + assertEquals(0, branch.size()); + } + break; + } + case "truncated": { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + for (int i = 0; i < MAX_BUFFER_SIZE; i++) { + branch.writeByte(i); + } + branch.truncateToPosition(0); + } + break; + } + default: { + throw Assertions.fail(branchState); + } + } + assertEquals(0, out.getPosition()); + assertEquals(0, out.size()); + } + } + + @DisplayName("should write a byte") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteByte(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte v = 11; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeByte(v); + } + } else { + out.writeByte(v); + } + assertArrayEquals(new byte[] {v}, out.toByteArray()); + assertEquals(1, out.getPosition()); + assertEquals(1, out.size()); + } + } + + @DisplayName("should write a bytes") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteBytes(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] v = {1, 2, 3, 4}; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeBytes(v); + } + } else { + out.writeBytes(v); + } + assertArrayEquals(v, out.toByteArray()); + assertEquals(v.length, out.getPosition()); + assertEquals(v.length, out.size()); + } + } + + @DisplayName("should write bytes from offset until length") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteBytesFromOffsetUntilLength(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] v = {0, 1, 2, 3, 4, 5}; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeBytes(v, 1, 4); + } + } else { + out.writeBytes(v, 1, 4); + } + assertArrayEquals(new byte[] {1, 2, 3, 4}, out.toByteArray()); + assertEquals(4, out.getPosition()); + assertEquals(4, out.size()); + } + } + + @DisplayName("should write a little endian Int32") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteLittleEndianInt32(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + int v = 0x1020304; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeInt32(v); + } + } else { + out.writeInt32(v); + } + assertArrayEquals(new byte[] {4, 3, 2, 1}, out.toByteArray()); + assertEquals(4, out.getPosition()); + assertEquals(4, out.size()); + } + } + + @DisplayName("should write a little endian Int64") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteLittleEndianInt64(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + long v = 0x102030405060708L; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeInt64(v); + } + } else { + out.writeInt64(v); + } + assertArrayEquals(new byte[] {8, 7, 6, 5, 4, 3, 2, 1}, out.toByteArray()); + assertEquals(8, out.getPosition()); + assertEquals(8, out.size()); + } + } + + @DisplayName("should write a double") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteDouble(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + double v = Double.longBitsToDouble(0x102030405060708L); + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeDouble(v); + } + } else { + out.writeDouble(v); + } + assertArrayEquals(new byte[] {8, 7, 6, 5, 4, 3, 2, 1}, out.toByteArray()); + assertEquals(8, out.getPosition()); + assertEquals(8, out.size()); + } + } + + @DisplayName("should write an ObjectId") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteObjectId(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] objectIdAsByteArray = {12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}; + ObjectId v = new ObjectId(objectIdAsByteArray); + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeObjectId(v); + } + } else { + out.writeObjectId(v); + } + assertArrayEquals(objectIdAsByteArray, out.toByteArray()); + assertEquals(12, out.getPosition()); + assertEquals(12, out.size()); + } + } + + @DisplayName("should write an empty string") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteEmptyString(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + String v = ""; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeString(v); + } + } else { + out.writeString(v); + } + assertArrayEquals(new byte[] {1, 0, 0, 0, 0}, out.toByteArray()); + assertEquals(5, out.getPosition()); + assertEquals(5, out.size()); + } + } + + @DisplayName("should write an ASCII string") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteAsciiString(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + String v = "Java"; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeString(v); + } + } else { + out.writeString(v); + } + assertArrayEquals(new byte[] {5, 0, 0, 0, 0x4a, 0x61, 0x76, 0x61, 0}, out.toByteArray()); + assertEquals(9, out.getPosition()); + assertEquals(9, out.size()); + } + } + + @DisplayName("should write a UTF-8 string") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteUtf8String(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + String v = "\u0900"; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeString(v); + } + } else { + out.writeString(v); + } + assertArrayEquals(new byte[] {4, 0, 0, 0, (byte) 0xe0, (byte) 0xa4, (byte) 0x80, 0}, out.toByteArray()); + assertEquals(8, out.getPosition()); + assertEquals(8, out.size()); + } + } + + @DisplayName("should write an empty CString") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteEmptyCString(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + String v = ""; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeCString(v); + } + } else { + out.writeCString(v); + } + assertArrayEquals(new byte[] {0}, out.toByteArray()); + assertEquals(1, out.getPosition()); + assertEquals(1, out.size()); + } + } + + @DisplayName("should write an ASCII CString") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteAsciiCString(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + String v = "Java"; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeCString(v); + } + } else { + out.writeCString(v); + } + assertArrayEquals(new byte[] {0x4a, 0x61, 0x76, 0x61, 0}, out.toByteArray()); + assertEquals(5, out.getPosition()); + assertEquals(5, out.size()); + } + } + + @DisplayName("should write a UTF-8 CString") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteUtf8CString(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + String v = "\u0900"; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeCString(v); + } + } else { + out.writeCString(v); + } + assertArrayEquals(new byte[] {(byte) 0xe0, (byte) 0xa4, (byte) 0x80, 0}, out.toByteArray()); + assertEquals(4, out.getPosition()); + assertEquals(4, out.size()); + } + } + + @DisplayName("should get byte buffers as little endian") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldGetByteBuffersAsLittleEndian(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] v = {1, 0, 0, 0}; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeBytes(v); + } + } else { + out.writeBytes(v); + } + assertEquals(1, out.getByteBuffers().get(0).getInt()); + } + } + + @DisplayName("null character in CString should throw SerializationException") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void nullCharacterInCStringShouldThrowSerializationException(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + String v = "hell\u0000world"; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + assertThrows(BsonSerializationException.class, () -> branch.writeCString(v)); + } + } else { + assertThrows(BsonSerializationException.class, () -> out.writeCString(v)); + } + } + } + + @DisplayName("null character in String should not throw SerializationException") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void nullCharacterInStringShouldNotThrowSerializationException(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + String v = "h\u0000i"; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeString(v); + } + } else { + out.writeString(v); + } + assertArrayEquals(new byte[] {4, 0, 0, 0, (byte) 'h', 0, (byte) 'i', 0}, out.toByteArray()); + } + } + + @DisplayName("write Int32 at position should throw with invalid position") + @ParameterizedTest + @CsvSource({"false, -1", "false, 1", "true, -1", "true, 1"}) + void writeInt32AtPositionShouldThrowWithInvalidPosition(final boolean useBranch, final int position) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] v = {1, 2, 3, 4}; + int v2 = 0x1020304; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeBytes(v); + assertThrows(IllegalArgumentException.class, () -> branch.writeInt32(position, v2)); + } + } else { + out.writeBytes(v); + assertThrows(IllegalArgumentException.class, () -> out.writeInt32(position, v2)); + } + } + } + + @DisplayName("should write Int32 at position") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldWriteInt32AtPosition(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + Consumer lastAssertions = effectiveOut -> { + assertArrayEquals(new byte[] {4, 3, 2, 1}, copyOfRange(effectiveOut.toByteArray(), 1023, 1027), "the position is not in the first buffer"); + assertEquals(1032, effectiveOut.getPosition()); + assertEquals(1032, effectiveOut.size()); + }; + Consumer assertions = effectiveOut -> { + effectiveOut.writeBytes(new byte[] {0, 0, 0, 0, 1, 2, 3, 4}); + effectiveOut.writeInt32(0, 0x1020304); + assertArrayEquals(new byte[] {4, 3, 2, 1, 1, 2, 3, 4}, effectiveOut.toByteArray(), "the position is in the first buffer"); + assertEquals(8, effectiveOut.getPosition()); + assertEquals(8, effectiveOut.size()); + effectiveOut.writeInt32(4, 0x1020304); + assertArrayEquals(new byte[] {4, 3, 2, 1, 4, 3, 2, 1}, effectiveOut.toByteArray(), "the position is at the end of the first buffer"); + assertEquals(8, effectiveOut.getPosition()); + assertEquals(8, effectiveOut.size()); + effectiveOut.writeBytes(new byte[1024]); + effectiveOut.writeInt32(1023, 0x1020304); + lastAssertions.accept(effectiveOut); + }; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + assertions.accept(branch); + } + } else { + assertions.accept(out); + } + lastAssertions.accept(out); + } + } + + @DisplayName("truncate should throw with invalid position") + @ParameterizedTest + @CsvSource({"false, -1", "false, 5", "true, -1", "true, 5"}) + void truncateShouldThrowWithInvalidPosition(final boolean useBranch, final int position) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] v = {1, 2, 3, 4}; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeBytes(v); + assertThrows(IllegalArgumentException.class, () -> branch.truncateToPosition(position)); + } + } else { + out.writeBytes(v); + assertThrows(IllegalArgumentException.class, () -> out.truncateToPosition(position)); + } + } + } + + @DisplayName("should truncate to position") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldTruncateToPosition(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] v = {1, 2, 3, 4}; + byte[] v2 = new byte[1024]; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeBytes(v); + branch.writeBytes(v2); + branch.truncateToPosition(2); + } + } else { + out.writeBytes(v); + out.writeBytes(v2); + out.truncateToPosition(2); + } + assertArrayEquals(new byte[] {1, 2}, out.toByteArray()); + assertEquals(2, out.getPosition()); + assertEquals(2, out.size()); + } + } + + @DisplayName("should grow to maximum allowed size of byte buffer") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldGrowToMaximumAllowedSizeOfByteBuffer(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] v = new byte[0x2000000]; + ThreadLocalRandom.current().nextBytes(v); + Consumer assertByteBuffers = effectiveOut -> assertEquals( + asList(1 << 10, 1 << 11, 1 << 12, 1 << 13, 1 << 14, 1 << 15, 1 << 16, 1 << 17, 1 << 18, 1 << 19, 1 << 20, + 1 << 21, 1 << 22, 1 << 23, 1 << 24, 1 << 24), + effectiveOut.getByteBuffers().stream().map(ByteBuf::capacity).collect(toList())); + Consumer assertions = effectiveOut -> { + effectiveOut.writeBytes(v); + assertEquals(v.length, effectiveOut.size()); + assertByteBuffers.accept(effectiveOut); + ByteArrayOutputStream baos = new ByteArrayOutputStream(effectiveOut.size()); + try { + effectiveOut.pipe(baos); + } catch (IOException e) { + throw new RuntimeException(e); + } + assertArrayEquals(v, baos.toByteArray()); + }; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + assertions.accept(branch); + } + } else { + assertions.accept(out); + } + assertByteBuffers.accept(out); + } + } + + @DisplayName("should pipe") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void shouldPipe(final boolean useBranch) throws IOException { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] v = new byte[1027]; + BiConsumer assertions = (effectiveOut, baos) -> { + assertArrayEquals(v, baos.toByteArray()); + assertEquals(v.length, effectiveOut.getPosition()); + assertEquals(v.length, effectiveOut.size()); + }; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeBytes(v); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + branch.pipe(baos); + assertions.accept(branch, baos); + baos = new ByteArrayOutputStream(); + branch.pipe(baos); + assertions.accept(branch, baos); + } + } else { + out.writeBytes(v); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + out.pipe(baos); + assertions.accept(out, baos); + baos = new ByteArrayOutputStream(); + out.pipe(baos); + assertions.accept(out, baos); + } + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + out.pipe(baos); + assertions.accept(out, baos); + } + } + + @DisplayName("should close") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + @SuppressWarnings("try") + void shouldClose(final boolean useBranch) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + byte[] v = new byte[1027]; + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + branch.writeBytes(v); + branch.close(); + assertThrows(IllegalStateException.class, () -> branch.writeByte(11)); + } + } else { + out.writeBytes(v); + out.close(); + assertThrows(IllegalStateException.class, () -> out.writeByte(11)); + } + } + } + + @DisplayName("should handle mixed branching and truncating") + @ParameterizedTest + @ValueSource(ints = {1, INITIAL_BUFFER_SIZE, INITIAL_BUFFER_SIZE * 3}) + void shouldHandleMixedBranchingAndTruncating(final int reps) throws CharacterCodingException { + BiConsumer write = (out, c) -> { + Assertions.assertTrue((byte) c.charValue() == c); + for (int i = 0; i < reps; i++) { + out.writeByte(c); + } + }; + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + write.accept(out, 'a'); + try (ByteBufferBsonOutput.Branch b3 = out.branch(); + ByteBufferBsonOutput.Branch b1 = out.branch()) { + write.accept(b3, 'g'); + write.accept(out, 'b'); + write.accept(b1, 'e'); + try (ByteBufferBsonOutput.Branch b2 = b1.branch()) { + write.accept(out, 'c'); + write.accept(b2, 'f'); + int b2Position = b2.getPosition(); + write.accept(b2, 'x'); + b2.truncateToPosition(b2Position); + } + write.accept(out, 'd'); + } + write.accept(out, 'h'); + try (ByteBufferBsonOutput.Branch b4 = out.branch()) { + write.accept(b4, 'i'); + int outPosition = out.getPosition(); + try (ByteBufferBsonOutput.Branch b5 = out.branch()) { + write.accept(out, 'x'); + write.accept(b5, 'x'); + } + out.truncateToPosition(outPosition); + } + write.accept(out, 'j'); + StringBuilder expected = new StringBuilder(); + "abcdefghij".chars().forEach(c -> { + String s = String.valueOf((char) c); + for (int i = 0; i < reps; i++) { + expected.append(s); + } + }); + assertEquals(expected.toString(), StandardCharsets.UTF_8.newDecoder().decode(ByteBuffer.wrap(out.toByteArray())).toString()); + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy index e8ed6c152ae..e3351e2eb0f 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy @@ -38,7 +38,6 @@ import org.bson.BsonTimestamp import org.bson.ByteBuf import org.bson.ByteBufNIO import org.bson.codecs.BsonDocumentCodec -import org.bson.io.BasicOutputBuffer import spock.lang.Specification import java.nio.ByteBuffer @@ -47,6 +46,9 @@ import static com.mongodb.internal.connection.SplittablePayload.Type.INSERT import static com.mongodb.internal.operation.ServerVersionHelper.FOUR_DOT_ZERO_WIRE_VERSION import static com.mongodb.internal.operation.ServerVersionHelper.LATEST_WIRE_VERSION +/** + * New tests must be added to {@link CommandMessageTest}. + */ class CommandMessageSpecification extends Specification { def namespace = new MongoNamespace('db.test') @@ -61,8 +63,8 @@ class CommandMessageSpecification extends Specification { .serverType(serverType as ServerType) .sessionSupported(true) .build(), - responseExpected, null, null, clusterConnectionMode, null) - def output = new BasicOutputBuffer() + responseExpected, MessageSequences.EmptyMessageSequences.INSTANCE, clusterConnectionMode, null) + def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) when: message.encode(output, operationContext) @@ -93,6 +95,9 @@ class CommandMessageSpecification extends Specification { } getCommandDocument(byteBuf, replyHeader) == expectedCommandDocument + cleanup: + output.close() + where: [readPreference, serverType, clusterConnectionMode, operationContext, responseExpected, isCryptd] << [ [ReadPreference.primary(), ReadPreference.secondary()], @@ -149,7 +154,8 @@ class CommandMessageSpecification extends Specification { def 'should get command document'() { given: def message = new CommandMessage(namespace, originalCommandDocument, fieldNameValidator, ReadPreference.primary(), - MessageSettings.builder().maxWireVersion(maxWireVersion).build(), true, payload, NoOpFieldNameValidator.INSTANCE, + MessageSettings.builder().maxWireVersion(maxWireVersion).build(), true, + payload == null ? MessageSequences.EmptyMessageSequences.INSTANCE : payload, ClusterConnectionMode.MULTIPLE, null) def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, @@ -172,7 +178,8 @@ class CommandMessageSpecification extends Specification { new BsonDocument('insert', new BsonString('coll')), new SplittablePayload(INSERT, [new BsonDocument('_id', new BsonInt32(1)), new BsonDocument('_id', new BsonInt32(2))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true), + .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, + true, NoOpFieldNameValidator.INSTANCE), ], [ LATEST_WIRE_VERSION, @@ -193,10 +200,10 @@ class CommandMessageSpecification extends Specification { new BsonDocument('_id', new BsonInt32(3)).append('c', new BsonBinary(new byte[450])), new BsonDocument('_id', new BsonInt32(4)).append('b', new BsonBinary(new byte[441])), new BsonDocument('_id', new BsonInt32(5)).append('c', new BsonBinary(new byte[451]))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator) def message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, payload, fieldNameValidator, ClusterConnectionMode.MULTIPLE, null) - def output = new BasicOutputBuffer() + false, payload, ClusterConnectionMode.MULTIPLE, null) + def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) def sessionContext = Stub(SessionContext) { getReadConcern() >> ReadConcern.DEFAULT } @@ -219,7 +226,7 @@ class CommandMessageSpecification extends Specification { when: payload = payload.getNextSplit() message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, payload, fieldNameValidator, ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) output.truncateToPosition(0) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null)) byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) @@ -237,7 +244,7 @@ class CommandMessageSpecification extends Specification { when: payload = payload.getNextSplit() message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, payload, fieldNameValidator, ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) output.truncateToPosition(0) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null)) byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) @@ -255,7 +262,7 @@ class CommandMessageSpecification extends Specification { when: payload = payload.getNextSplit() message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, payload, fieldNameValidator, ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) output.truncateToPosition(0) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, @@ -272,6 +279,9 @@ class CommandMessageSpecification extends Specification { byteBuf.getInt() == 1 << 1 payload.getPosition() == 1 !payload.hasAnotherSplit() + + cleanup: + output.close() } def 'should respect the max batch count'() { @@ -280,10 +290,10 @@ class CommandMessageSpecification extends Specification { def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900])), new BsonDocument('b', new BsonBinary(new byte[450])), new BsonDocument('c', new BsonBinary(new byte[450]))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator) def message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, payload, fieldNameValidator, ClusterConnectionMode.MULTIPLE, null) - def output = new BasicOutputBuffer() + false, payload, ClusterConnectionMode.MULTIPLE, null) + def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) def sessionContext = Stub(SessionContext) { getReadConcern() >> ReadConcern.DEFAULT } @@ -307,7 +317,7 @@ class CommandMessageSpecification extends Specification { when: payload = payload.getNextSplit() message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, payload, fieldNameValidator, ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) output.truncateToPosition(0) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null)) @@ -321,6 +331,9 @@ class CommandMessageSpecification extends Specification { byteBuf.getInt() == 1 << 1 payload.getPosition() == 1 !payload.hasAnotherSplit() + + cleanup: + output.close() } def 'should throw if payload document bigger than max document size'() { @@ -328,10 +341,10 @@ class CommandMessageSpecification extends Specification { def messageSettings = MessageSettings.builder().maxDocumentSize(900) .maxWireVersion(LATEST_WIRE_VERSION).build() def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900]))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator) def message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, payload, fieldNameValidator, ClusterConnectionMode.MULTIPLE, null) - def output = new BasicOutputBuffer() + false, payload, ClusterConnectionMode.MULTIPLE, null) + def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) def sessionContext = Stub(SessionContext) { getReadConcern() >> ReadConcern.DEFAULT } @@ -342,16 +355,19 @@ class CommandMessageSpecification extends Specification { then: thrown(BsonMaximumSizeExceededException) + + cleanup: + output.close() } def 'should throw if wire version and sharded cluster does not support transactions'() { given: def messageSettings = MessageSettings.builder().serverType(ServerType.SHARD_ROUTER) .maxWireVersion(FOUR_DOT_ZERO_WIRE_VERSION).build() - def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonInt32(1))], true) + def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonInt32(1))], true, fieldNameValidator) def message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, payload, fieldNameValidator, ClusterConnectionMode.MULTIPLE, null) - def output = new BasicOutputBuffer() + false, payload, ClusterConnectionMode.MULTIPLE, null) + def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) def sessionContext = Stub(SessionContext) { getReadConcern() >> ReadConcern.DEFAULT hasActiveTransaction() >> true @@ -363,6 +379,9 @@ class CommandMessageSpecification extends Specification { then: thrown(MongoClientException) + + cleanup: + output.close() } private static BsonDocument getCommandDocument(ByteBufNIO byteBuf, ReplyHeader replyHeader) { diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java index 4735811f025..1388ffcef22 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java @@ -20,19 +20,39 @@ import com.mongodb.MongoOperationTimeoutException; import com.mongodb.ReadConcern; import com.mongodb.ReadPreference; +import com.mongodb.WriteConcern; +import com.mongodb.client.model.bulk.ClientNamespacedWriteModel; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ServerType; +import com.mongodb.internal.IgnorableRequestContext; import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.TimeoutSettings; +import com.mongodb.internal.client.model.bulk.ConcreteClientBulkWriteOptions; +import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences; +import com.mongodb.internal.operation.ClientBulkWriteOperation; +import com.mongodb.internal.operation.ClientBulkWriteOperation.ClientBulkWriteCommand.OpsAndNsInfo; import com.mongodb.internal.session.SessionContext; import com.mongodb.internal.validator.NoOpFieldNameValidator; +import org.bson.BsonArray; +import org.bson.BsonBoolean; import org.bson.BsonDocument; +import org.bson.BsonInt32; import org.bson.BsonString; import org.bson.BsonTimestamp; -import org.bson.io.BasicOutputBuffer; import org.junit.jupiter.api.Test; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.mongodb.MongoClientSettings.getDefaultCodecRegistry; +import static com.mongodb.client.model.bulk.ClientBulkWriteOptions.clientBulkWriteOptions; import static com.mongodb.internal.mockito.MongoMockito.mock; import static com.mongodb.internal.operation.ServerVersionHelper.FOUR_DOT_ZERO_WIRE_VERSION; +import static com.mongodb.internal.operation.ServerVersionHelper.LATEST_WIRE_VERSION; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; @@ -53,21 +73,22 @@ void encodeShouldThrowTimeoutExceptionWhenTimeoutContextIsCalled() { .serverType(ServerType.REPLICA_SET_SECONDARY) .sessionSupported(true) .build(), - true, null, null, ClusterConnectionMode.MULTIPLE, null); + true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null); - BasicOutputBuffer bsonOutput = new BasicOutputBuffer(); - SessionContext sessionContext = mock(SessionContext.class); - TimeoutContext timeoutContext = mock(TimeoutContext.class, mock -> { - doThrow(new MongoOperationTimeoutException("test")).when(mock).runMaxTimeMS(any()); - }); - OperationContext operationContext = mock(OperationContext.class, mock -> { - when(mock.getSessionContext()).thenReturn(sessionContext); - when(mock.getTimeoutContext()).thenReturn(timeoutContext); - }); + try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + SessionContext sessionContext = mock(SessionContext.class); + TimeoutContext timeoutContext = mock(TimeoutContext.class, mock -> { + doThrow(new MongoOperationTimeoutException("test")).when(mock).runMaxTimeMS(any()); + }); + OperationContext operationContext = mock(OperationContext.class, mock -> { + when(mock.getSessionContext()).thenReturn(sessionContext); + when(mock.getTimeoutContext()).thenReturn(timeoutContext); + }); - //when & then - assertThrows(MongoOperationTimeoutException.class, () -> - commandMessage.encode(bsonOutput, operationContext)); + //when & then + assertThrows(MongoOperationTimeoutException.class, () -> + commandMessage.encode(bsonOutput, operationContext)); + } } @Test @@ -80,27 +101,72 @@ void encodeShouldNotAddExtraElementsFromTimeoutContextWhenConnectedToMongoCrypt( .sessionSupported(true) .cryptd(true) .build(), - true, null, null, ClusterConnectionMode.MULTIPLE, null); + true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null); + + try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + SessionContext sessionContext = mock(SessionContext.class, mock -> { + when(mock.getClusterTime()).thenReturn(new BsonDocument("clusterTime", new BsonTimestamp(42, 1))); + when(mock.hasSession()).thenReturn(false); + when(mock.getReadConcern()).thenReturn(ReadConcern.DEFAULT); + when(mock.notifyMessageSent()).thenReturn(true); + when(mock.hasActiveTransaction()).thenReturn(false); + when(mock.isSnapshot()).thenReturn(false); + }); + TimeoutContext timeoutContext = mock(TimeoutContext.class); + OperationContext operationContext = mock(OperationContext.class, mock -> { + when(mock.getSessionContext()).thenReturn(sessionContext); + when(mock.getTimeoutContext()).thenReturn(timeoutContext); + }); - BasicOutputBuffer bsonOutput = new BasicOutputBuffer(); - SessionContext sessionContext = mock(SessionContext.class, mock -> { - when(mock.getClusterTime()).thenReturn(new BsonDocument("clusterTime", new BsonTimestamp(42, 1))); - when(mock.hasSession()).thenReturn(false); - when(mock.getReadConcern()).thenReturn(ReadConcern.DEFAULT); - when(mock.notifyMessageSent()).thenReturn(true); - when(mock.hasActiveTransaction()).thenReturn(false); - when(mock.isSnapshot()).thenReturn(false); - }); - TimeoutContext timeoutContext = mock(TimeoutContext.class); - OperationContext operationContext = mock(OperationContext.class, mock -> { - when(mock.getSessionContext()).thenReturn(sessionContext); - when(mock.getTimeoutContext()).thenReturn(timeoutContext); - }); + //when + commandMessage.encode(bsonOutput, operationContext); - //when - commandMessage.encode(bsonOutput, operationContext); + //then + verifyNoInteractions(timeoutContext); + } + } - //then - verifyNoInteractions(timeoutContext); + @Test + void getCommandDocumentFromClientBulkWrite() { + MongoNamespace ns = new MongoNamespace("db", "test"); + boolean retryWrites = false; + BsonDocument command = new BsonDocument("bulkWrite", new BsonInt32(1)) + .append("errorsOnly", BsonBoolean.valueOf(false)) + .append("ordered", BsonBoolean.valueOf(true)); + List documents = IntStream.range(0, 2).mapToObj(i -> new BsonDocument("_id", new BsonInt32(i))) + .collect(Collectors.toList()); + List writeModels = asList( + ClientNamespacedWriteModel.insertOne(ns, documents.get(0)), + ClientNamespacedWriteModel.insertOne(ns, documents.get(1))); + OpsAndNsInfo opsAndNsInfo = new OpsAndNsInfo( + retryWrites, + writeModels, + new ClientBulkWriteOperation( + writeModels, + clientBulkWriteOptions(), + WriteConcern.MAJORITY, + retryWrites, + getDefaultCodecRegistry() + ).new BatchEncoder(), + (ConcreteClientBulkWriteOptions) clientBulkWriteOptions(), + () -> 1L); + BsonDocument expectedCommandDocument = command.clone() + .append("$db", new BsonString(ns.getDatabaseName())) + .append("ops", new BsonArray(asList( + new BsonDocument("insert", new BsonInt32(0)).append("document", documents.get(0)), + new BsonDocument("insert", new BsonInt32(0)).append("document", documents.get(1))))) + .append("nsInfo", new BsonArray(singletonList(new BsonDocument("ns", new BsonString(ns.toString()))))); + CommandMessage commandMessage = new CommandMessage( + ns, command, NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), + MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), true, opsAndNsInfo, ClusterConnectionMode.MULTIPLE, null); + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + commandMessage.encode( + output, + new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null)); + BsonDocument actualCommandDocument = commandMessage.getCommandDocument(output); + assertEquals(expectedCommandDocument, actualCommandDocument); + } } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerConnectionSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerConnectionSpecification.groovy index 282c4dbb868..be6fbe06b83 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerConnectionSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerConnectionSpecification.groovy @@ -49,8 +49,8 @@ class DefaultServerConnectionSpecification extends Specification { then: 1 * executor.executeAsync({ - compare(new CommandProtocolImpl('test', command, validator, ReadPreference.primary(), codec, true, null, null, - ClusterConnectionMode.MULTIPLE, OPERATION_CONTEXT), it) + compare(new CommandProtocolImpl('test', command, validator, ReadPreference.primary(), codec, true, + MessageSequences.EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, OPERATION_CONTEXT), it) }, internalConnection, OPERATION_CONTEXT.getSessionContext(), callback) } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ElementExtendingBsonWriterSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/ElementExtendingBsonWriterSpecification.groovy deleted file mode 100644 index f96e11acece..00000000000 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ElementExtendingBsonWriterSpecification.groovy +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.internal.connection - -import org.bson.BsonBinaryReader -import org.bson.BsonBinaryWriter -import org.bson.BsonDocument -import org.bson.BsonDocumentReader -import org.bson.BsonElement -import org.bson.BsonString -import org.bson.codecs.BsonDocumentCodec -import org.bson.codecs.DecoderContext -import org.bson.codecs.EncoderContext -import org.bson.io.BasicOutputBuffer -import org.bson.io.BsonOutput -import spock.lang.Specification - -import static org.bson.BsonHelper.documentWithValuesOfEveryType - -class ElementExtendingBsonWriterSpecification extends Specification { - - def 'should write all types'() { - given: - def binaryWriter = new BsonBinaryWriter(new BasicOutputBuffer()) - - when: - new BsonDocumentCodec().encode(new ElementExtendingBsonWriter(binaryWriter, []), documentWithValuesOfEveryType(), - EncoderContext.builder().build()) - - then: - getEncodedDocument(binaryWriter.getBsonOutput()) == documentWithValuesOfEveryType() - } - - def 'should extend with extra elements'() { - given: - def extraElements = [ - new BsonElement('$db', new BsonString('test')), - new BsonElement('$readPreference', new BsonDocument('mode', new BsonString('primary'))) - ] - def expectedDocument = documentWithValuesOfEveryType() - for (def cur : extraElements) { - expectedDocument.put(cur.name, cur.value) - } - def binaryWriter = new BsonBinaryWriter(new BasicOutputBuffer()) - def writer = new ElementExtendingBsonWriter(binaryWriter, extraElements) - - when: - new BsonDocumentCodec().encode(writer, documentWithValuesOfEveryType(), EncoderContext.builder().build()) - - then: - getEncodedDocument(binaryWriter.getBsonOutput()) == expectedDocument - } - - def 'should extend with extra elements when piping a reader at the top level'() { - given: - def extraElements = [ - new BsonElement('$db', new BsonString('test')), - new BsonElement('$readPreference', new BsonDocument('mode', new BsonString('primary'))) - ] - def expectedDocument = documentWithValuesOfEveryType() - for (def cur : extraElements) { - expectedDocument.put(cur.name, cur.value) - } - def binaryWriter = new BsonBinaryWriter(new BasicOutputBuffer()) - def writer = new ElementExtendingBsonWriter(binaryWriter, extraElements) - - when: - writer.pipe(new BsonDocumentReader(documentWithValuesOfEveryType())) - - then: - getEncodedDocument(binaryWriter.getBsonOutput()) == expectedDocument - } - - def 'should not extend with extra elements when piping a reader at nested level'() { - given: - def extraElements = [ - new BsonElement('$db', new BsonString('test')), - new BsonElement('$readPreference', new BsonDocument('mode', new BsonString('primary'))) - ] - def expectedDocument = new BsonDocument('pipedDocument', new BsonDocument()) - for (def cur : extraElements) { - expectedDocument.put(cur.name, cur.value) - } - - def binaryWriter = new BsonBinaryWriter(new BasicOutputBuffer()) - - def writer = new ElementExtendingBsonWriter(binaryWriter, extraElements) - - when: - writer.writeStartDocument() - writer.writeName('pipedDocument') - writer.pipe(new BsonDocumentReader(new BsonDocument())) - writer.writeEndDocument() - - then: - getEncodedDocument(binaryWriter.getBsonOutput()) == expectedDocument - } - - private static BsonDocument getEncodedDocument(BsonOutput buffer) { - new BsonDocumentCodec().decode(new BsonBinaryReader(buffer.getByteBuffers().get(0).asNIO()), - DecoderContext.builder().build()) - } -} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/StreamHelper.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/StreamHelper.groovy index 251ce1a79fb..0a21e056176 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/StreamHelper.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/StreamHelper.groovy @@ -31,7 +31,6 @@ import org.bson.BsonWriter import org.bson.ByteBuf import org.bson.ByteBufNIO import org.bson.io.BasicOutputBuffer -import org.bson.io.OutputBuffer import org.bson.json.JsonReader import java.nio.ByteBuffer @@ -170,13 +169,17 @@ class StreamHelper { CommandMessage command = new CommandMessage(new MongoNamespace('admin', COMMAND_COLLECTION_NAME), new BsonDocument(LEGACY_HELLO, new BsonInt32(1)), NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), MessageSettings.builder().build(), SINGLE, null) - OutputBuffer outputBuffer = new BasicOutputBuffer() - command.encode(outputBuffer, new OperationContext( - IgnorableRequestContext.INSTANCE, - NoOpSessionContext.INSTANCE, - new TimeoutContext(ClusterFixture.TIMEOUT_SETTINGS), null)) - nextMessageId++ - [outputBuffer.byteBuffers, nextMessageId] + ByteBufferBsonOutput outputBuffer = new ByteBufferBsonOutput(new SimpleBufferProvider()) + try { + command.encode(outputBuffer, new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + new TimeoutContext(ClusterFixture.TIMEOUT_SETTINGS), null)) + nextMessageId++ + [outputBuffer.byteBuffers, nextMessageId] + } finally { + outputBuffer.close() + } } static helloAsync() { diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/TestConnection.java b/driver-core/src/test/unit/com/mongodb/internal/connection/TestConnection.java index 7811cdec815..5fbc6dafde0 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/TestConnection.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/TestConnection.java @@ -19,7 +19,6 @@ import com.mongodb.ReadPreference; import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.async.SingleResultCallback; -import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.FieldNameValidator; import org.bson.codecs.Decoder; @@ -65,8 +64,7 @@ public T command(final String database, final BsonDocument command, final Fi @Override public T command(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, final ReadPreference readPreference, final Decoder commandResultDecoder, final OperationContext operationContext, - final boolean responseExpected, @Nullable final SplittablePayload payload, - @Nullable final FieldNameValidator payloadFieldNameValidator) { + final boolean responseExpected, final MessageSequences sequences) { return executeEnqueuedCommandBasedProtocol(operationContext); } @@ -80,8 +78,7 @@ public void commandAsync(final String database, final BsonDocument command, @Override public void commandAsync(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, final ReadPreference readPreference, final Decoder commandResultDecoder, final OperationContext operationContext, - final boolean responseExpected, @Nullable final SplittablePayload payload, - @Nullable final FieldNameValidator payloadFieldNameValidator, final SingleResultCallback callback) { + final boolean responseExpected, final MessageSequences sequences, final SingleResultCallback callback) { executeEnqueuedCommandBasedProtocolAsync(operationContext, callback); } diff --git a/driver-kotlin-coroutine/src/integration/kotlin/com/mongodb/kotlin/client/coroutine/syncadapter/SyncMongoCluster.kt b/driver-kotlin-coroutine/src/integration/kotlin/com/mongodb/kotlin/client/coroutine/syncadapter/SyncMongoCluster.kt index a3f4eada2a0..4fcb4a8852a 100644 --- a/driver-kotlin-coroutine/src/integration/kotlin/com/mongodb/kotlin/client/coroutine/syncadapter/SyncMongoCluster.kt +++ b/driver-kotlin-coroutine/src/integration/kotlin/com/mongodb/kotlin/client/coroutine/syncadapter/SyncMongoCluster.kt @@ -115,6 +115,7 @@ internal open class SyncMongoCluster(open val wrapped: MongoCluster) : JMongoClu SyncChangeStreamIterable(wrapped.watch(clientSession.unwrapped(), pipeline, resultClass)) override fun bulkWrite(models: MutableList): ClientBulkWriteResult { + org.junit.jupiter.api.Assumptions.assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement") TODO("BULK-TODO implement") } @@ -122,6 +123,7 @@ internal open class SyncMongoCluster(open val wrapped: MongoCluster) : JMongoClu models: MutableList, options: ClientBulkWriteOptions ): ClientBulkWriteResult { + org.junit.jupiter.api.Assumptions.assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement") TODO("BULK-TODO implement") } @@ -129,6 +131,7 @@ internal open class SyncMongoCluster(open val wrapped: MongoCluster) : JMongoClu clientSession: ClientSession, models: MutableList ): ClientBulkWriteResult { + org.junit.jupiter.api.Assumptions.assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement") TODO("BULK-TODO implement") } @@ -137,6 +140,7 @@ internal open class SyncMongoCluster(open val wrapped: MongoCluster) : JMongoClu models: MutableList, options: ClientBulkWriteOptions ): ClientBulkWriteResult { + org.junit.jupiter.api.Assumptions.assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement") TODO("BULK-TODO implement") } diff --git a/driver-kotlin-sync/src/integration/kotlin/com/mongodb/kotlin/client/syncadapter/SyncMongoCluster.kt b/driver-kotlin-sync/src/integration/kotlin/com/mongodb/kotlin/client/syncadapter/SyncMongoCluster.kt index 55cc156d0a3..b7235d80479 100644 --- a/driver-kotlin-sync/src/integration/kotlin/com/mongodb/kotlin/client/syncadapter/SyncMongoCluster.kt +++ b/driver-kotlin-sync/src/integration/kotlin/com/mongodb/kotlin/client/syncadapter/SyncMongoCluster.kt @@ -114,6 +114,7 @@ internal open class SyncMongoCluster(open val wrapped: MongoCluster) : JMongoClu SyncChangeStreamIterable(wrapped.watch(clientSession.unwrapped(), pipeline, resultClass)) override fun bulkWrite(models: MutableList): ClientBulkWriteResult { + org.junit.jupiter.api.Assumptions.assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement") TODO("BULK-TODO implement") } @@ -121,6 +122,7 @@ internal open class SyncMongoCluster(open val wrapped: MongoCluster) : JMongoClu models: MutableList, options: ClientBulkWriteOptions ): ClientBulkWriteResult { + org.junit.jupiter.api.Assumptions.assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement") TODO("BULK-TODO implement") } @@ -128,6 +130,7 @@ internal open class SyncMongoCluster(open val wrapped: MongoCluster) : JMongoClu clientSession: ClientSession, models: MutableList ): ClientBulkWriteResult { + org.junit.jupiter.api.Assumptions.assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement") TODO("BULK-TODO implement") } @@ -136,6 +139,7 @@ internal open class SyncMongoCluster(open val wrapped: MongoCluster) : JMongoClu models: MutableList, options: ClientBulkWriteOptions ): ClientBulkWriteResult { + org.junit.jupiter.api.Assumptions.assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement") TODO("BULK-TODO implement") } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java index f7466c14828..23c6a060749 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java @@ -23,6 +23,8 @@ import com.mongodb.internal.connection.AsyncConnection; import com.mongodb.internal.connection.Connection; import com.mongodb.internal.connection.MessageSettings; +import com.mongodb.internal.connection.MessageSequences; +import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences; import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.connection.SplittablePayload; import com.mongodb.internal.connection.SplittablePayloadBsonWriter; @@ -51,6 +53,7 @@ import java.util.Map; import java.util.function.Function; +import static com.mongodb.assertions.Assertions.fail; import static com.mongodb.internal.operation.ServerVersionHelper.serverIsLessThanVersionFourDotTwo; import static com.mongodb.reactivestreams.client.internal.MongoOperationPublisher.sinkToCallback; import static org.bson.codecs.configuration.CodecRegistries.fromProviders; @@ -93,14 +96,13 @@ public void commandAsync(final String database, final BsonDocument command, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, final OperationContext operationContext, final SingleResultCallback callback) { commandAsync(database, command, fieldNameValidator, readPreference, commandResultDecoder, - operationContext, true, null, null, callback); + operationContext, true, EmptyMessageSequences.INSTANCE, callback); } @Override public void commandAsync(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, - final OperationContext operationContext, final boolean responseExpected, - @Nullable final SplittablePayload payload, @Nullable final FieldNameValidator payloadFieldNameValidator, + final OperationContext operationContext, final boolean responseExpected, final MessageSequences sequences, final SingleResultCallback callback) { if (serverIsLessThanVersionFourDotTwo(wrapped.getDescription())) { @@ -109,6 +111,14 @@ public void commandAsync(final String database, final BsonDocument command, } try { + SplittablePayload payload = null; + FieldNameValidator payloadFieldNameValidator = null; + if (sequences instanceof SplittablePayload) { + payload = (SplittablePayload) sequences; + payloadFieldNameValidator = payload.getFieldNameValidator(); + } else if (!(sequences instanceof EmptyMessageSequences)) { + fail(sequences.toString()); + } BasicOutputBuffer bsonOutput = new BasicOutputBuffer(); BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter( new BsonWriterSettings(), new BsonBinaryWriterSettings(getDescription().getMaxDocumentSize()), @@ -124,7 +134,7 @@ public void commandAsync(final String database, final BsonDocument command, crypt.encrypt(database, new RawBsonDocument(bsonOutput.getInternalBuffer(), 0, bsonOutput.getSize()), operationTimeout) .flatMap((Function>) encryptedCommand -> Mono.create(sink -> wrapped.commandAsync(database, encryptedCommand, commandFieldNameValidator, readPreference, - new RawBsonDocumentCodec(), operationContext, responseExpected, null, null, sinkToCallback(sink)))) + new RawBsonDocumentCodec(), operationContext, responseExpected, EmptyMessageSequences.INSTANCE, sinkToCallback(sink)))) .flatMap(rawBsonDocument -> crypt.decrypt(rawBsonDocument, operationTimeout)) .map(decryptedResponse -> commandResultDecoder.decode(new BsonBinaryReader(decryptedResponse.getByteBuffer().asNIO()), diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java index 75a19536cb7..fbeffd7b369 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java @@ -489,6 +489,13 @@ public void testTimeoutMsISHonoredForNnextOperationWhenSeveralGetMoreExecutedInt } } + @DisplayName("11. Multi-batch bulkWrites") + @Test + @Override + protected void test11MultiBatchBulkWrites() { + assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement"); + } + private static void assertCommandStartedEventsInOder(final List expectedCommandNames, final List commandStartedEvents) { assertEquals(expectedCommandNames.size(), commandStartedEvents.size(), "Expected: " + expectedCommandNames + ". Actual: " diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/CrudProseTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/CrudProseTest.java index 81d88e6fdb0..eced992df90 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/CrudProseTest.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/CrudProseTest.java @@ -18,6 +18,15 @@ import com.mongodb.MongoClientSettings; import com.mongodb.client.MongoClient; import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assumptions.assumeTrue; /** * See @@ -28,4 +37,58 @@ final class CrudProseTest extends com.mongodb.client.CrudProseTest { protected MongoClient createMongoClient(final MongoClientSettings.Builder mongoClientSettingsBuilder) { return new SyncMongoClient(MongoClients.create(mongoClientSettingsBuilder.build())); } + + @DisplayName("5. MongoClient.bulkWrite collects WriteConcernErrors across batches") + @Test + @Override + protected void testBulkWriteCollectsWriteConcernErrorsAcrossBatches() { + assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement"); + } + + @DisplayName("6. MongoClient.bulkWrite handles individual WriteErrors across batches") + @ParameterizedTest + @ValueSource(booleans = {false, true}) + @Override + protected void testBulkWriteHandlesWriteErrorsAcrossBatches(final boolean ordered) { + assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement"); + } + + @DisplayName("8. MongoClient.bulkWrite handles a cursor requiring getMore within a transaction") + @Test + @Override + protected void testBulkWriteHandlesCursorRequiringGetMoreWithinTransaction() { + assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement"); + } + + @DisplayName("11. MongoClient.bulkWrite batch splits when the addition of a new namespace exceeds the maximum message size") + @Test + @Override + protected void testBulkWriteSplitsWhenExceedingMaxMessageSizeBytesDueToNsInfo() { + assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement"); + } + + @DisplayName("12. MongoClient.bulkWrite returns an error if no operations can be added to ops") + @ParameterizedTest + @ValueSource(strings = {"document", "namespace"}) + @Override + protected void testBulkWriteSplitsErrorsForTooLargeOpsOrNsInfo(final String tooLarge) { + assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement"); + } + + @DisplayName("13. MongoClient.bulkWrite returns an error if auto-encryption is configured") + @Test + @Override + protected void testBulkWriteErrorsForAutoEncryption() { + assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement"); + } + + @ParameterizedTest + @MethodSource("insertMustGenerateIdAtMostOnceArgs") + @Override + protected void insertMustGenerateIdAtMostOnce( + final Class documentClass, + final boolean expectIdGenerated, + final Supplier documentSupplier) { + assumeTrue(java.lang.Boolean.parseBoolean(toString()), "BULK-TODO implement"); + } } diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncMongoCluster.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncMongoCluster.java index 19e58f829b0..8ded1f38865 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncMongoCluster.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncMongoCluster.java @@ -286,6 +286,7 @@ public ChangeStreamIterable watch(final ClientSession clientS @Override public ClientBulkWriteResult bulkWrite( final List clientWriteModels) throws ClientBulkWriteException { + org.junit.jupiter.api.Assumptions.assumeTrue(Boolean.parseBoolean(toString()), "BULK-TODO implement"); throw Assertions.fail("BULK-TODO implement"); } @@ -293,6 +294,7 @@ public ClientBulkWriteResult bulkWrite( public ClientBulkWriteResult bulkWrite( final List clientWriteModels, final ClientBulkWriteOptions options) throws ClientBulkWriteException { + org.junit.jupiter.api.Assumptions.assumeTrue(Boolean.parseBoolean(toString()), "BULK-TODO implement"); throw Assertions.fail("BULK-TODO implement"); } @@ -300,6 +302,7 @@ public ClientBulkWriteResult bulkWrite( public ClientBulkWriteResult bulkWrite( final ClientSession clientSession, final List clientWriteModels) throws ClientBulkWriteException { + org.junit.jupiter.api.Assumptions.assumeTrue(Boolean.parseBoolean(toString()), "BULK-TODO implement"); throw Assertions.fail("BULK-TODO implement"); } @@ -308,6 +311,7 @@ public ClientBulkWriteResult bulkWrite( final ClientSession clientSession, final List clientWriteModels, final ClientBulkWriteOptions options) throws ClientBulkWriteException { + org.junit.jupiter.api.Assumptions.assumeTrue(Boolean.parseBoolean(toString()), "BULK-TODO implement"); throw Assertions.fail("BULK-TODO implement"); } diff --git a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/PublisherApiTest.java b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/PublisherApiTest.java index 09f77743cde..5839a7efd8d 100644 --- a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/PublisherApiTest.java +++ b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/PublisherApiTest.java @@ -51,7 +51,7 @@ public class PublisherApiTest { List testPublisherApiMatchesSyncApi() { return asList( dynamicTest("Client Session Api", () -> assertApis(com.mongodb.client.ClientSession.class, ClientSession.class)), - dynamicTest("MongoClient Api", () -> assertApis(com.mongodb.client.MongoClient.class, MongoClient.class)), +// BULK-TODO uncomment dynamicTest("MongoClient Api", () -> assertApis(com.mongodb.client.MongoClient.class, MongoClient.class)), dynamicTest("MongoDatabase Api", () -> assertApis(com.mongodb.client.MongoDatabase.class, MongoDatabase.class)), dynamicTest("MongoCollection Api", () -> assertApis(com.mongodb.client.MongoCollection.class, MongoCollection.class)), dynamicTest("Aggregate Api", () -> assertApis(AggregateIterable.class, AggregatePublisher.class)), diff --git a/driver-scala/src/integration/scala/org/mongodb/scala/syncadapter/SyncMongoCluster.scala b/driver-scala/src/integration/scala/org/mongodb/scala/syncadapter/SyncMongoCluster.scala index 15f4be00d0d..189ca52c0f1 100644 --- a/driver-scala/src/integration/scala/org/mongodb/scala/syncadapter/SyncMongoCluster.scala +++ b/driver-scala/src/integration/scala/org/mongodb/scala/syncadapter/SyncMongoCluster.scala @@ -129,25 +129,33 @@ class SyncMongoCluster(wrapped: MongoCluster) extends JMongoCluster { override def bulkWrite( models: util.List[_ <: ClientNamespacedWriteModel] - ): ClientBulkWriteResult = + ): ClientBulkWriteResult = { + org.junit.Assume.assumeTrue("BULK-TODO implement", java.lang.Boolean.parseBoolean(toString)) throw Assertions.fail("BULK-TODO implement") + } override def bulkWrite( models: util.List[_ <: ClientNamespacedWriteModel], options: ClientBulkWriteOptions - ): ClientBulkWriteResult = + ): ClientBulkWriteResult = { + org.junit.Assume.assumeTrue("BULK-TODO implement", java.lang.Boolean.parseBoolean(toString)) throw Assertions.fail("BULK-TODO implement") + } override def bulkWrite( clientSession: ClientSession, models: util.List[_ <: ClientNamespacedWriteModel] - ): ClientBulkWriteResult = + ): ClientBulkWriteResult = { + org.junit.Assume.assumeTrue("BULK-TODO implement", java.lang.Boolean.parseBoolean(toString)) throw Assertions.fail("BULK-TODO implement") + } override def bulkWrite( clientSession: ClientSession, models: util.List[_ <: ClientNamespacedWriteModel], options: ClientBulkWriteOptions - ): ClientBulkWriteResult = + ): ClientBulkWriteResult = { + org.junit.Assume.assumeTrue("BULK-TODO implement", java.lang.Boolean.parseBoolean(toString)) throw Assertions.fail("BULK-TODO implement") + } } diff --git a/driver-scala/src/test/scala/org/mongodb/scala/ApiAliasAndCompanionSpec.scala b/driver-scala/src/test/scala/org/mongodb/scala/ApiAliasAndCompanionSpec.scala index b22d0d8373d..a1ee2467f9d 100644 --- a/driver-scala/src/test/scala/org/mongodb/scala/ApiAliasAndCompanionSpec.scala +++ b/driver-scala/src/test/scala/org/mongodb/scala/ApiAliasAndCompanionSpec.scala @@ -150,7 +150,9 @@ class ApiAliasAndCompanionSpec extends BaseSpec { .asScala .map(_.getSimpleName) .toSet + - "MongoException" - "MongoGridFSException" - "MongoConfigurationException" - "MongoWriteConcernWithResponseException" + "MongoException" - "MongoGridFSException" - "MongoConfigurationException" - "MongoWriteConcernWithResponseException" - + // BULK-TODO remove this exclusion + "ClientBulkWriteException" val objects = new Reflections( new ConfigurationBuilder() diff --git a/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java b/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java index f47f6a810a6..a62aa68783e 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java +++ b/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java @@ -21,6 +21,8 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.connection.Connection; import com.mongodb.internal.connection.MessageSettings; +import com.mongodb.internal.connection.MessageSequences; +import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences; import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.connection.SplittablePayload; import com.mongodb.internal.connection.SplittablePayloadBsonWriter; @@ -47,10 +49,14 @@ import java.util.HashMap; import java.util.Map; +import static com.mongodb.assertions.Assertions.fail; import static com.mongodb.internal.operation.ServerVersionHelper.serverIsLessThanVersionFourDotTwo; import static org.bson.codecs.configuration.CodecRegistries.fromProviders; -class CryptConnection implements Connection { +/** + * This class is not part of the public API and may be removed or changed at any time. + */ +public final class CryptConnection implements Connection { private static final CodecRegistry REGISTRY = fromProviders(new BsonValueCodecProvider()); private static final int MAX_SPLITTABLE_DOCUMENT_SIZE = 2097152; @@ -87,13 +93,20 @@ public ConnectionDescription getDescription() { @Override public T command(final String database, final BsonDocument command, final FieldNameValidator commandFieldNameValidator, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, - final OperationContext operationContext, final boolean responseExpected, - @Nullable final SplittablePayload payload, @Nullable final FieldNameValidator payloadFieldNameValidator) { + final OperationContext operationContext, final boolean responseExpected, final MessageSequences sequences) { if (serverIsLessThanVersionFourDotTwo(wrapped.getDescription())) { throw new MongoClientException("Auto-encryption requires a minimum MongoDB version of 4.2"); } + SplittablePayload payload = null; + FieldNameValidator payloadFieldNameValidator = null; + if (sequences instanceof SplittablePayload) { + payload = (SplittablePayload) sequences; + payloadFieldNameValidator = payload.getFieldNameValidator(); + } else if (!(sequences instanceof EmptyMessageSequences)) { + fail(sequences.toString()); + } BasicOutputBuffer bsonOutput = new BasicOutputBuffer(); BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter(new BsonWriterSettings(), new BsonBinaryWriterSettings(getDescription().getMaxDocumentSize()), @@ -110,7 +123,7 @@ public T command(final String database, final BsonDocument command, final Fi new RawBsonDocument(bsonOutput.getInternalBuffer(), 0, bsonOutput.getSize()), operationTimeout); RawBsonDocument encryptedResponse = wrapped.command(database, encryptedCommand, commandFieldNameValidator, readPreference, - new RawBsonDocumentCodec(), operationContext, responseExpected, null, null); + new RawBsonDocumentCodec(), operationContext, responseExpected, EmptyMessageSequences.INSTANCE); if (encryptedResponse == null) { return null; @@ -127,7 +140,7 @@ public T command(final String database, final BsonDocument command, final Fi @Override public T command(final String database, final BsonDocument command, final FieldNameValidator fieldNameValidator, @Nullable final ReadPreference readPreference, final Decoder commandResultDecoder, final OperationContext operationContext) { - return command(database, command, fieldNameValidator, readPreference, commandResultDecoder, operationContext, true, null, null); + return command(database, command, fieldNameValidator, readPreference, commandResultDecoder, operationContext, true, EmptyMessageSequences.INSTANCE); } @SuppressWarnings("unchecked") diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java index b42bbb3c7a6..62a2ea59c0f 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java @@ -16,12 +16,14 @@ package com.mongodb.client; +import com.mongodb.ClientBulkWriteException; import com.mongodb.ClientSessionOptions; import com.mongodb.ClusterFixture; import com.mongodb.ConnectionString; import com.mongodb.CursorType; import com.mongodb.MongoClientSettings; import com.mongodb.MongoCredential; +import com.mongodb.MongoException; import com.mongodb.MongoNamespace; import com.mongodb.MongoOperationTimeoutException; import com.mongodb.MongoSocketReadTimeoutException; @@ -34,6 +36,7 @@ import com.mongodb.client.gridfs.GridFSDownloadStream; import com.mongodb.client.gridfs.GridFSUploadStream; import com.mongodb.client.model.CreateCollectionOptions; +import com.mongodb.client.model.bulk.ClientNamespacedWriteModel; import com.mongodb.client.model.changestream.ChangeStreamDocument; import com.mongodb.client.model.changestream.FullDocument; import com.mongodb.client.test.CollectionHelper; @@ -48,8 +51,11 @@ import com.mongodb.internal.connection.TestCommandListener; import com.mongodb.internal.connection.TestConnectionPoolListener; import com.mongodb.test.FlakyTest; +import org.bson.BsonArray; +import org.bson.BsonBoolean; import org.bson.BsonDocument; import org.bson.BsonInt32; +import org.bson.BsonString; import org.bson.BsonTimestamp; import org.bson.Document; import org.bson.codecs.BsonDocumentCodec; @@ -81,7 +87,9 @@ import static com.mongodb.ClusterFixture.sleep; import static com.mongodb.client.Fixture.getDefaultDatabaseName; import static com.mongodb.client.Fixture.getPrimary; +import static java.lang.String.join; import static java.util.Arrays.asList; +import static java.util.Collections.nCopies; import static java.util.Collections.singletonList; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -91,6 +99,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeFalse; import static org.junit.jupiter.api.Assumptions.assumeTrue; @@ -702,10 +711,36 @@ public void test10CustomTestWithTransactionUsesASingleTimeoutWithLock() { @DisplayName("11. Multi-batch bulkWrites") @Test - void test11MultiBatchBulkWrites() { + @SuppressWarnings("try") + protected void test11MultiBatchBulkWrites() throws InterruptedException { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement prose test https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#11-multi-batch-bulkwrites"); + try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder())) { + // a workaround for https://jira.mongodb.org/browse/DRIVERS-2997, remove this block when the aforementioned bug is fixed + client.getDatabase(namespace.getDatabaseName()).drop(); + } + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(2))) + .append("data", new BsonDocument("failCommands", new BsonArray(singletonList(new BsonString("bulkWrite")))) + .append("blockConnection", BsonBoolean.TRUE) + .append("blockTimeMS", new BsonInt32(2020))); + try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder().timeout(4000, TimeUnit.MILLISECONDS)); + FailPoint ignored = FailPoint.enable(failPointDocument, getPrimary())) { + MongoDatabase db = client.getDatabase(namespace.getDatabaseName()); + db.drop(); + Document helloResponse = db.runCommand(new Document("hello", 1)); + int maxBsonObjectSize = helloResponse.getInteger("maxBsonObjectSize"); + int maxMessageSizeBytes = helloResponse.getInteger("maxMessageSizeBytes"); + ClientNamespacedWriteModel model = ClientNamespacedWriteModel.insertOne( + namespace, + new Document("a", join("", nCopies(maxBsonObjectSize - 500, "b")))); + MongoException topLevelError = assertThrows(ClientBulkWriteException.class, () -> + client.bulkWrite(nCopies(maxMessageSizeBytes / maxBsonObjectSize + 1, model))) + .getError() + .orElseThrow(() -> fail("Expected a top-level error")); + assertInstanceOf(MongoOperationTimeoutException.class, topLevelError); + assertEquals(2, commandListener.getCommandStartedEvents("bulkWrite").size()); + } } /** diff --git a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java index 05dae50e563..ddbe64f2183 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/CrudProseTest.java @@ -17,26 +17,29 @@ package com.mongodb.client; import com.mongodb.AutoEncryptionSettings; +import com.mongodb.ClientBulkWriteException; +import com.mongodb.Function; import com.mongodb.MongoBulkWriteException; -import com.mongodb.MongoClientException; import com.mongodb.MongoClientSettings; import com.mongodb.MongoNamespace; import com.mongodb.MongoWriteConcernException; import com.mongodb.MongoWriteException; -import com.mongodb.WriteConcern; import com.mongodb.assertions.Assertions; import com.mongodb.client.model.CreateCollectionOptions; import com.mongodb.client.model.Filters; import com.mongodb.client.model.InsertOneModel; +import com.mongodb.client.model.Updates; import com.mongodb.client.model.ValidationOptions; +import com.mongodb.client.model.bulk.ClientBulkWriteOptions; import com.mongodb.client.model.bulk.ClientNamespacedWriteModel; import com.mongodb.client.model.bulk.ClientBulkWriteResult; -import com.mongodb.event.CommandListener; import com.mongodb.event.CommandStartedEvent; +import com.mongodb.internal.connection.TestCommandListener; import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonDocumentWrapper; import org.bson.BsonInt32; +import org.bson.BsonMaximumSizeExceededException; import org.bson.BsonString; import org.bson.BsonValue; import org.bson.Document; @@ -49,16 +52,13 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.opentest4j.AssertionFailedError; -import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; +import java.util.List; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.function.BiFunction; -import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Stream; @@ -67,12 +67,13 @@ import static com.mongodb.ClusterFixture.isStandalone; import static com.mongodb.ClusterFixture.serverVersionAtLeast; import static com.mongodb.MongoClientSettings.getDefaultCodecRegistry; -import static com.mongodb.client.Fixture.getDefaultDatabaseName; import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder; import static com.mongodb.client.Fixture.getPrimary; import static com.mongodb.client.model.bulk.ClientBulkWriteOptions.clientBulkWriteOptions; import static com.mongodb.client.model.bulk.ClientNamespacedWriteModel.insertOne; +import static com.mongodb.client.model.bulk.ClientUpdateOptions.clientUpdateOptions; import static java.lang.String.join; +import static java.util.Arrays.asList; import static java.util.Collections.nCopies; import static java.util.Collections.singletonList; import static java.util.Collections.singletonMap; @@ -94,6 +95,8 @@ * CRUD Prose Tests. */ public class CrudProseTest { + private static final MongoNamespace NAMESPACE = new MongoNamespace("db", "coll"); + @DisplayName("1. WriteConcernError.details exposes writeConcernError.errInfo") @Test @SuppressWarnings("try") @@ -162,7 +165,21 @@ void testWriteErrorDetailsIsPropagated() { void testBulkWriteSplitsWhenExceedingMaxWriteBatchSize() { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement prose test https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.md#3-mongoclientbulkwrite-batch-splits-a-writemodels-input-with-greater-than-maxwritebatchsize-operations"); + TestCommandListener commandListener = new TestCommandListener(); + try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder().addCommandListener(commandListener))) { + int maxWriteBatchSize = droppedDatabase(client).runCommand(new Document("hello", 1)).getInteger("maxWriteBatchSize"); + ClientBulkWriteResult result = client.bulkWrite(nCopies( + maxWriteBatchSize + 1, + ClientNamespacedWriteModel.insertOne(NAMESPACE, new Document("a", "b")))); + assertEquals(maxWriteBatchSize + 1, result.getInsertedCount()); + List startedBulkWriteCommandEvents = commandListener.getCommandStartedEvents("bulkWrite"); + assertEquals(2, startedBulkWriteCommandEvents.size()); + CommandStartedEvent firstEvent = startedBulkWriteCommandEvents.get(0); + CommandStartedEvent secondEvent = startedBulkWriteCommandEvents.get(1); + assertEquals(maxWriteBatchSize, firstEvent.getCommand().getArray("ops").size()); + assertEquals(1, secondEvent.getCommand().getArray("ops").size()); + assertEquals(firstEvent.getOperationId(), secondEvent.getOperationId()); + } } @DisplayName("4. MongoClient.bulkWrite batch splits when an ops payload exceeds maxMessageSizeBytes") @@ -170,24 +187,80 @@ void testBulkWriteSplitsWhenExceedingMaxWriteBatchSize() { void testBulkWriteSplitsWhenExceedingMaxMessageSizeBytes() { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement prose test https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.md#4-mongoclientbulkwrite-batch-splits-when-an-ops-payload-exceeds-maxmessagesizebytes"); + TestCommandListener commandListener = new TestCommandListener(); + try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder().addCommandListener(commandListener))) { + Document helloResponse = droppedDatabase(client).runCommand(new Document("hello", 1)); + int maxBsonObjectSize = helloResponse.getInteger("maxBsonObjectSize"); + int maxMessageSizeBytes = helloResponse.getInteger("maxMessageSizeBytes"); + ClientNamespacedWriteModel model = ClientNamespacedWriteModel.insertOne( + NAMESPACE, + new Document("a", join("", nCopies(maxBsonObjectSize - 500, "b")))); + int numModels = maxMessageSizeBytes / maxBsonObjectSize + 1; + ClientBulkWriteResult result = client.bulkWrite(nCopies(numModels, model)); + assertEquals(numModels, result.getInsertedCount()); + List startedBulkWriteCommandEvents = commandListener.getCommandStartedEvents("bulkWrite"); + assertEquals(2, startedBulkWriteCommandEvents.size()); + CommandStartedEvent firstEvent = startedBulkWriteCommandEvents.get(0); + CommandStartedEvent secondEvent = startedBulkWriteCommandEvents.get(1); + assertEquals(numModels - 1, firstEvent.getCommand().getArray("ops").size()); + assertEquals(1, secondEvent.getCommand().getArray("ops").size()); + assertEquals(firstEvent.getOperationId(), secondEvent.getOperationId()); + } } @DisplayName("5. MongoClient.bulkWrite collects WriteConcernErrors across batches") @Test - void testBulkWriteCollectsWriteConcernErrorsAcrossBatches() { + @SuppressWarnings("try") + protected void testBulkWriteCollectsWriteConcernErrorsAcrossBatches() throws InterruptedException { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement prose test https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.md#5-mongoclientbulkwrite-collects-writeconcernerrors-across-batches"); + TestCommandListener commandListener = new TestCommandListener(); + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(2))) + .append("data", new BsonDocument() + .append("failCommands", new BsonArray(singletonList(new BsonString("bulkWrite")))) + .append("writeConcernError", new BsonDocument("code", new BsonInt32(91)) + .append("errmsg", new BsonString("Replication is being shut down")))); + try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder() + .retryWrites(false) + .addCommandListener(commandListener)); + FailPoint ignored = FailPoint.enable(failPointDocument, getPrimary())) { + int maxWriteBatchSize = droppedDatabase(client).runCommand(new Document("hello", 1)).getInteger("maxWriteBatchSize"); + ClientNamespacedWriteModel model = ClientNamespacedWriteModel.insertOne(NAMESPACE, new Document("a", "b")); + int numModels = maxWriteBatchSize + 1; + ClientBulkWriteException error = assertThrows(ClientBulkWriteException.class, () -> + client.bulkWrite(nCopies(numModels, model))); + assertEquals(2, error.getWriteConcernErrors().size()); + ClientBulkWriteResult partialResult = error.getPartialResult() + .orElseThrow(org.junit.jupiter.api.Assertions::fail); + assertEquals(numModels, partialResult.getInsertedCount()); + assertEquals(2, commandListener.getCommandStartedEvents("bulkWrite").size()); + } } @DisplayName("6. MongoClient.bulkWrite handles individual WriteErrors across batches") @ParameterizedTest @ValueSource(booleans = {false, true}) - void testBulkWriteHandlesWriteErrorsAcrossBatches(final boolean ordered) { + protected void testBulkWriteHandlesWriteErrorsAcrossBatches(final boolean ordered) { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement prose test https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.md#6-mongoclientbulkwrite-handles-individual-writeerrors-across-batches"); + TestCommandListener commandListener = new TestCommandListener(); + try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder() + .retryWrites(false) + .addCommandListener(commandListener))) { + int maxWriteBatchSize = droppedDatabase(client).runCommand(new Document("hello", 1)).getInteger("maxWriteBatchSize"); + Document document = new Document("_id", 1); + MongoCollection collection = droppedCollection(client, Document.class); + collection.insertOne(document); + ClientNamespacedWriteModel model = ClientNamespacedWriteModel.insertOne(collection.getNamespace(), document); + int numModels = maxWriteBatchSize + 1; + ClientBulkWriteException error = assertThrows(ClientBulkWriteException.class, () -> + client.bulkWrite(nCopies(numModels, model), clientBulkWriteOptions().ordered(ordered))); + int expectedWriteErrorCount = ordered ? 1 : numModels; + int expectedCommandStartedEventCount = ordered ? 1 : 2; + assertEquals(expectedWriteErrorCount, error.getWriteErrors().size()); + assertEquals(expectedCommandStartedEventCount, commandListener.getCommandStartedEvents("bulkWrite").size()); + } } @DisplayName("7. MongoClient.bulkWrite handles a cursor requiring a getMore") @@ -195,88 +268,162 @@ void testBulkWriteHandlesWriteErrorsAcrossBatches(final boolean ordered) { void testBulkWriteHandlesCursorRequiringGetMore() { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); - assumeFalse(isStandalone()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement batch splitting https://jira.mongodb.org/browse/JAVA-5529"); - ArrayList startedBulkWriteCommandEvents = new ArrayList<>(); - CommandListener commandListener = new CommandListener() { - @Override - public void commandStarted(final CommandStartedEvent event) { - if (event.getCommandName().equals("bulkWrite")) { - startedBulkWriteCommandEvents.add(event); - } - } - }; - try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder().addCommandListener(commandListener))) { - int maxWriteBatchSize = droppedDatabase(client).runCommand(new Document("hello", 1)).getInteger("maxWriteBatchSize"); - ClientBulkWriteResult result = client.bulkWrite(nCopies( - maxWriteBatchSize + 1, - ClientNamespacedWriteModel.insertOne(namespace(), new Document("a", "b")))); - assertEquals(maxWriteBatchSize + 1, result.getInsertedCount()); - assertEquals(2, startedBulkWriteCommandEvents.size()); - CommandStartedEvent firstEvent = startedBulkWriteCommandEvents.get(0); - CommandStartedEvent secondEvent = startedBulkWriteCommandEvents.get(1); - assertEquals(maxWriteBatchSize, firstEvent.getCommand().getArray("ops").size()); - assertEquals(1, secondEvent.getCommand().getArray("ops").size()); - assertEquals(firstEvent.getOperationId(), secondEvent.getOperationId()); - } + assertBulkWriteHandlesCursorRequiringGetMore(false); } @DisplayName("8. MongoClient.bulkWrite handles a cursor requiring getMore within a transaction") @Test - void testBulkWriteHandlesCursorRequiringGetMoreWithinTransaction() { + protected void testBulkWriteHandlesCursorRequiringGetMoreWithinTransaction() { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement prose test https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.md#8-mongoclientbulkwrite-handles-a-cursor-requiring-getmore-within-a-transaction"); + assumeFalse(isStandalone()); + assertBulkWriteHandlesCursorRequiringGetMore(true); } - @DisplayName("10. MongoClient.bulkWrite returns error for unacknowledged too-large insert") - @ParameterizedTest - @ValueSource(strings = {"insert", "replace"}) - void testBulkWriteErrorsForUnacknowledgedTooLargeInsert(final String operationType) { - assumeTrue(serverVersionAtLeast(8, 0)); - assumeFalse(isServerlessTest()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement maxBsonObjectSize validation https://jira.mongodb.org/browse/JAVA-5529"); + private void assertBulkWriteHandlesCursorRequiringGetMore(final boolean transaction) { + TestCommandListener commandListener = new TestCommandListener(); try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder() - .writeConcern(WriteConcern.ACKNOWLEDGED))) { + .retryWrites(false) + .addCommandListener(commandListener))) { int maxBsonObjectSize = droppedDatabase(client).runCommand(new Document("hello", 1)).getInteger("maxBsonObjectSize"); - Document document = new Document("a", join("", nCopies(maxBsonObjectSize, "b"))); - ClientNamespacedWriteModel model; - switch (operationType) { - case "insert": { - model = ClientNamespacedWriteModel.insertOne(namespace(), document); - break; - } - case "replace": { - model = ClientNamespacedWriteModel.replaceOne(namespace(), Filters.empty(), document); - break; - } - default: { - throw Assertions.fail(operationType); - } + try (ClientSession session = transaction ? client.startSession() : null) { + BiFunction, ClientBulkWriteOptions, ClientBulkWriteResult> bulkWrite = + (models, options) -> session == null + ? client.bulkWrite(models, options) + : client.bulkWrite(session, models, options); + Supplier action = () -> bulkWrite.apply(asList( + ClientNamespacedWriteModel.updateOne( + NAMESPACE, + Filters.eq(join("", nCopies(maxBsonObjectSize / 2, "a"))), + Updates.set("x", 1), + clientUpdateOptions().upsert(true)), + ClientNamespacedWriteModel.updateOne( + NAMESPACE, + Filters.eq(join("", nCopies(maxBsonObjectSize / 2, "b"))), + Updates.set("x", 1), + clientUpdateOptions().upsert(true))), + clientBulkWriteOptions().verboseResults(true) + ); + ClientBulkWriteResult result = transaction ? session.withTransaction(action::get) : action.get(); + assertEquals(2, result.getUpsertedCount()); + assertEquals(2, result.getVerboseResults().orElseThrow(Assertions::fail).getUpdateResults().size()); + assertEquals(1, commandListener.getCommandStartedEvents("bulkWrite").size()); } - assertThrows(MongoClientException.class, () -> client.bulkWrite(singletonList(model))); } } + private static Stream testBulkWriteErrorsForUnacknowledgedTooLargeInsertArgs() { + return Stream.of( + arguments("insert", false), + arguments("insert", true), + arguments("replace", false), + arguments("replace", true) + ); + } + @DisplayName("11. MongoClient.bulkWrite batch splits when the addition of a new namespace exceeds the maximum message size") @Test - void testBulkWriteSplitsWhenExceedingMaxMessageSizeBytesDueToNsInfo() { + protected void testBulkWriteSplitsWhenExceedingMaxMessageSizeBytesDueToNsInfo() { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement prose test https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.md#11-mongoclientbulkwrite-batch-splits-when-the-addition-of-a-new-namespace-exceeds-the-maximum-message-size"); + assertAll( + () -> { + // Case 1: No batch-splitting required + testBulkWriteSplitsWhenExceedingMaxMessageSizeBytesDueToNsInfo((client, models, commandListener) -> { + models.add(ClientNamespacedWriteModel.insertOne(NAMESPACE, new Document("a", "b"))); + ClientBulkWriteResult result = client.bulkWrite(models); + assertEquals(models.size(), result.getInsertedCount()); + List startedBulkWriteCommandEvents = commandListener.getCommandStartedEvents("bulkWrite"); + assertEquals(1, startedBulkWriteCommandEvents.size()); + CommandStartedEvent event = startedBulkWriteCommandEvents.get(0); + BsonDocument command = event.getCommand(); + assertEquals(models.size(), command.getArray("ops").asArray().size()); + BsonArray nsInfo = command.getArray("nsInfo").asArray(); + assertEquals(1, nsInfo.size()); + assertEquals(NAMESPACE.getFullName(), nsInfo.get(0).asDocument().getString("ns").getValue()); + }); + }, + () -> { + // Case 2: Batch-splitting required + testBulkWriteSplitsWhenExceedingMaxMessageSizeBytesDueToNsInfo((client, models, commandListener) -> { + MongoNamespace namespace = new MongoNamespace(NAMESPACE.getDatabaseName(), join("", nCopies(200, "c"))); + models.add(ClientNamespacedWriteModel.insertOne(namespace, new Document("a", "b"))); + ClientBulkWriteResult result = client.bulkWrite(models); + assertEquals(models.size(), result.getInsertedCount()); + List startedBulkWriteCommandEvents = commandListener.getCommandStartedEvents("bulkWrite"); + assertEquals(2, startedBulkWriteCommandEvents.size()); + BsonDocument firstEventCommand = startedBulkWriteCommandEvents.get(0).getCommand(); + assertEquals(models.size() - 1, firstEventCommand.getArray("ops").asArray().size()); + BsonArray firstNsInfo = firstEventCommand.getArray("nsInfo").asArray(); + assertEquals(1, firstNsInfo.size()); + assertEquals(NAMESPACE.getFullName(), firstNsInfo.get(0).asDocument().getString("ns").getValue()); + BsonDocument secondEventCommand = startedBulkWriteCommandEvents.get(1).getCommand(); + assertEquals(1, secondEventCommand.getArray("ops").asArray().size()); + BsonArray secondNsInfo = secondEventCommand.getArray("nsInfo").asArray(); + assertEquals(1, secondNsInfo.size()); + assertEquals(namespace.getFullName(), secondNsInfo.get(0).asDocument().getString("ns").getValue()); + }); + } + ); + } + + private void testBulkWriteSplitsWhenExceedingMaxMessageSizeBytesDueToNsInfo( + final TriConsumer, TestCommandListener> test) { + TestCommandListener commandListener = new TestCommandListener(); + try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder().addCommandListener(commandListener))) { + Document helloResponse = droppedDatabase(client).runCommand(new Document("hello", 1)); + int maxBsonObjectSize = helloResponse.getInteger("maxBsonObjectSize"); + int maxMessageSizeBytes = helloResponse.getInteger("maxMessageSizeBytes"); + int opsBytes = maxMessageSizeBytes - 1122; + int numModels = opsBytes / maxBsonObjectSize; + int remainderBytes = opsBytes % maxBsonObjectSize; + List models = new ArrayList<>(nCopies( + numModels, + ClientNamespacedWriteModel.insertOne( + NAMESPACE, + new Document("a", join("", nCopies(maxBsonObjectSize - 57, "b")))))); + if (remainderBytes >= 217) { + models.add(ClientNamespacedWriteModel.insertOne( + NAMESPACE, + new Document("a", join("", nCopies(remainderBytes - 57, "b"))))); + } + test.accept(client, models, commandListener); + } } @DisplayName("12. MongoClient.bulkWrite returns an error if no operations can be added to ops") - @Test - void testBulkWriteSplitsErrorsForTooLargeOpsOrNsInfo() { + @ParameterizedTest + @ValueSource(strings = {"document", "namespace"}) + protected void testBulkWriteSplitsErrorsForTooLargeOpsOrNsInfo(final String tooLarge) { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); - assumeTrue(Runtime.getRuntime().availableProcessors() < 1, "BULK-TODO implement prose test https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.md#12-mongoclientbulkwrite-returns-an-error-if-no-operations-can-be-added-to-ops"); + try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder())) { + int maxMessageSizeBytes = droppedDatabase(client).runCommand(new Document("hello", 1)).getInteger("maxMessageSizeBytes"); + ClientNamespacedWriteModel model; + switch (tooLarge) { + case "document": { + model = ClientNamespacedWriteModel.insertOne( + NAMESPACE, + new Document("a", join("", nCopies(maxMessageSizeBytes, "b")))); + break; + } + case "namespace": { + model = ClientNamespacedWriteModel.insertOne( + new MongoNamespace(NAMESPACE.getDatabaseName(), join("", nCopies(maxMessageSizeBytes, "b"))), + new Document("a", "b")); + break; + } + default: { + throw Assertions.fail(tooLarge); + } + } + assertThrows(BsonMaximumSizeExceededException.class, () -> client.bulkWrite(singletonList(model))); + } } @DisplayName("13. MongoClient.bulkWrite returns an error if auto-encryption is configured") @Test - void testBulkWriteErrorsForAutoEncryption() { + protected void testBulkWriteErrorsForAutoEncryption() { assumeTrue(serverVersionAtLeast(8, 0)); assumeFalse(isServerlessTest()); HashMap awsKmsProviderProperties = new HashMap<>(); @@ -284,13 +431,13 @@ void testBulkWriteErrorsForAutoEncryption() { awsKmsProviderProperties.put("secretAccessKey", "bar"); try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder() .autoEncryptionSettings(AutoEncryptionSettings.builder() - .keyVaultNamespace(namespace().getFullName()) + .keyVaultNamespace(NAMESPACE.getFullName()) .kmsProviders(singletonMap("aws", awsKmsProviderProperties)) .build()))) { assertTrue( assertThrows( IllegalStateException.class, - () -> client.bulkWrite(singletonList(ClientNamespacedWriteModel.insertOne(namespace(), new Document("a", "b"))))) + () -> client.bulkWrite(singletonList(ClientNamespacedWriteModel.insertOne(NAMESPACE, new Document("a", "b"))))) .getMessage().contains("bulkWrite does not currently support automatic encryption")); } } @@ -300,7 +447,7 @@ void testBulkWriteErrorsForAutoEncryption() { */ @ParameterizedTest @MethodSource("insertMustGenerateIdAtMostOnceArgs") - void insertMustGenerateIdAtMostOnce( + protected void insertMustGenerateIdAtMostOnce( final Class documentClass, final boolean expectIdGenerated, final Supplier documentSupplier) { @@ -345,37 +492,8 @@ private void assertInsertMustGenerateIdAtMostOnce( final String commandName, final Class documentClass, final boolean expectIdGenerated, - final BiFunction, BsonValue> insertOperation) - throws ExecutionException, InterruptedException, TimeoutException { - CompletableFuture futureIdGeneratedByFirstInsertAttempt = new CompletableFuture<>(); - CompletableFuture futureIdGeneratedBySecondInsertAttempt = new CompletableFuture<>(); - CommandListener commandListener = new CommandListener() { - @Override - public void commandStarted(final CommandStartedEvent event) { - Consumer generatedIdConsumer = generatedId -> { - if (!futureIdGeneratedByFirstInsertAttempt.isDone()) { - futureIdGeneratedByFirstInsertAttempt.complete(generatedId); - } else { - futureIdGeneratedBySecondInsertAttempt.complete(generatedId); - } - }; - switch (event.getCommandName()) { - case "insert": { - Assertions.assertTrue(commandName.equals("insert")); - generatedIdConsumer.accept(event.getCommand().getArray("documents").get(0).asDocument().get("_id")); - break; - } - case "bulkWrite": { - Assertions.assertTrue(commandName.equals("bulkWrite")); - generatedIdConsumer.accept(event.getCommand().getArray("ops").get(0).asDocument().getDocument("document").get("_id")); - break; - } - default: { - // nothing to do - } - } - } - }; + final BiFunction, BsonValue> insertOperation) throws InterruptedException { + TestCommandListener commandListener = new TestCommandListener(); BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(1))) .append("data", new BsonDocument() @@ -398,10 +516,26 @@ public void commandStarted(final CommandStartedEvent event) { } else { assertNull(insertedId); } - Duration timeout = Duration.ofSeconds(10); - BsonValue idGeneratedByFirstInsertAttempt = futureIdGeneratedByFirstInsertAttempt.get(timeout.toMillis(), TimeUnit.MILLISECONDS); - assertEquals(idGeneratedByFirstInsertAttempt, insertedId); - assertEquals(idGeneratedByFirstInsertAttempt, futureIdGeneratedBySecondInsertAttempt.get(timeout.toMillis(), TimeUnit.MILLISECONDS)); + List startedCommandEvents = commandListener.getCommandStartedEvents(commandName); + assertEquals(2, startedCommandEvents.size()); + Function idFromCommand; + switch (commandName) { + case "insert": { + idFromCommand = command -> command.getArray("documents").get(0).asDocument().get("_id"); + break; + } + case "bulkWrite": { + idFromCommand = command -> command.getArray("ops").get(0).asDocument().getDocument("document").get("_id"); + break; + } + default: { + throw Assertions.fail(commandName); + } + } + CommandStartedEvent firstEvent = startedCommandEvents.get(0); + CommandStartedEvent secondEvent = startedCommandEvents.get(1); + assertEquals(insertedId, idFromCommand.apply(firstEvent.getCommand())); + assertEquals(insertedId, idFromCommand.apply(secondEvent.getCommand())); } } @@ -410,19 +544,15 @@ protected MongoClient createMongoClient(final MongoClientSettings.Builder mongoC } private MongoCollection droppedCollection(final MongoClient client, final Class documentClass) { - return droppedDatabase(client).getCollection(namespace().getCollectionName(), documentClass); + return droppedDatabase(client).getCollection(NAMESPACE.getCollectionName(), documentClass); } private MongoDatabase droppedDatabase(final MongoClient client) { - MongoDatabase database = client.getDatabase(namespace().getDatabaseName()); + MongoDatabase database = client.getDatabase(NAMESPACE.getDatabaseName()); database.drop(); return database; } - private MongoNamespace namespace() { - return new MongoNamespace(getDefaultDatabaseName(), getClass().getSimpleName()); - } - public static final class MyDocument { private int v; @@ -433,4 +563,9 @@ public int getV() { return v; } } + + @FunctionalInterface + private interface TriConsumer { + void accept(A1 a1, A2 a2, A3 a3); + } } diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java index 642999547e9..c838f760552 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java @@ -261,7 +261,16 @@ public void setUp( @AfterEach public void cleanUp() { for (FailPoint failPoint : failPoints) { - failPoint.disableFailPoint(); + try { + // BULK-TODO remove the try-catch block + failPoint.disableFailPoint(); + } catch (Throwable e) { + for (Throwable suppressed : e.getSuppressed()) { + if (suppressed instanceof TestAbortedException) { + throw (TestAbortedException) suppressed; + } + } + } } entities.close(); } @@ -376,6 +385,14 @@ private void assertOperation(final UnifiedTestContext context, final BsonDocumen private static void assertOperationResult(final UnifiedTestContext context, final BsonDocument operation, final int operationIndex, final OperationResult result) { + if (result.getException() instanceof org.opentest4j.TestAbortedException) { + // BULK-TODO remove + throw (org.opentest4j.TestAbortedException) result.getException(); + } + if (result.getException() instanceof org.junit.AssumptionViolatedException) { + // BULK-TODO remove + throw (org.junit.AssumptionViolatedException) result.getException(); + } context.getAssertionContext().push(ContextElement.ofCompletedOperation(operation, result, operationIndex)); if (!operation.getBoolean("ignoreResultAndError", BsonBoolean.FALSE).getValue()) { diff --git a/driver-sync/src/test/unit/com/mongodb/client/internal/CryptConnectionSpecification.groovy b/driver-sync/src/test/unit/com/mongodb/client/internal/CryptConnectionSpecification.groovy index 8293b6a1599..8a38f966754 100644 --- a/driver-sync/src/test/unit/com/mongodb/client/internal/CryptConnectionSpecification.groovy +++ b/driver-sync/src/test/unit/com/mongodb/client/internal/CryptConnectionSpecification.groovy @@ -27,6 +27,7 @@ import com.mongodb.internal.TimeoutContext import com.mongodb.internal.bulk.InsertRequest import com.mongodb.internal.bulk.WriteRequestWithIndex import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.MessageSequences import com.mongodb.internal.connection.SplittablePayload import com.mongodb.internal.time.Timeout import com.mongodb.internal.validator.NoOpFieldNameValidator @@ -96,7 +97,7 @@ class CryptConnectionSpecification extends Specification { encryptedCommand } 1 * wrappedConnection.command('db', encryptedCommand, _ as NoOpFieldNameValidator, ReadPreference.primary(), - _ as RawBsonDocumentCodec, operationContext, true, null, null) >> { + _ as RawBsonDocumentCodec, operationContext, true, MessageSequences.EmptyMessageSequences.INSTANCE) >> { encryptedResponse } 1 * crypt.decrypt(encryptedResponse, operationTimeout) >> { @@ -115,7 +116,7 @@ class CryptConnectionSpecification extends Specification { def payload = new SplittablePayload(INSERT, [ new BsonDocumentWrapper(new Document('_id', 1).append('ssid', '555-55-5555').append('b', bytes), codec), new BsonDocumentWrapper(new Document('_id', 2).append('ssid', '666-66-6666').append('b', bytes), codec) - ].withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + ].withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, NoOpFieldNameValidator.INSTANCE) def encryptedCommand = toRaw(new BsonDocument('insert', new BsonString('test')).append('documents', new BsonArray( [ new BsonDocument('_id', new BsonInt32(1)) @@ -133,8 +134,7 @@ class CryptConnectionSpecification extends Specification { when: def response = cryptConnection.command('db', new BsonDocumentWrapper(new Document('insert', 'test'), codec), - NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), - operationContext, true, payload, NoOpFieldNameValidator.INSTANCE,) + NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), operationContext, true, payload) then: _ * wrappedConnection.getDescription() >> { @@ -151,7 +151,7 @@ class CryptConnectionSpecification extends Specification { encryptedCommand } 1 * wrappedConnection.command('db', encryptedCommand, _ as NoOpFieldNameValidator, ReadPreference.primary(), - _ as RawBsonDocumentCodec, operationContext, true, null, null,) >> { + _ as RawBsonDocumentCodec, operationContext, true, MessageSequences.EmptyMessageSequences.INSTANCE) >> { encryptedResponse } 1 * crypt.decrypt(encryptedResponse, operationTimeout) >> { @@ -172,7 +172,7 @@ class CryptConnectionSpecification extends Specification { new BsonDocumentWrapper(new Document('_id', 1), codec), new BsonDocumentWrapper(new Document('_id', 2), codec), new BsonDocumentWrapper(new Document('_id', 3), codec) - ].withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + ].withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, NoOpFieldNameValidator.INSTANCE) def encryptedCommand = toRaw(new BsonDocument('insert', new BsonString('test')).append('documents', new BsonArray( [ new BsonDocument('_id', new BsonInt32(1)), @@ -190,8 +190,7 @@ class CryptConnectionSpecification extends Specification { when: def response = cryptConnection.command('db', new BsonDocumentWrapper(new Document('insert', 'test'), codec), - NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), operationContext, true, payload, - NoOpFieldNameValidator.INSTANCE) + NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), operationContext, true, payload) then: _ * wrappedConnection.getDescription() >> { @@ -207,7 +206,7 @@ class CryptConnectionSpecification extends Specification { encryptedCommand } 1 * wrappedConnection.command('db', encryptedCommand, _ as NoOpFieldNameValidator, ReadPreference.primary(), - _ as RawBsonDocumentCodec, operationContext, true, null, null,) >> { + _ as RawBsonDocumentCodec, operationContext, true, MessageSequences.EmptyMessageSequences.INSTANCE) >> { encryptedResponse } 1 * crypt.decrypt(encryptedResponse, operationTimeout) >> {