diff --git a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java index 3dd1effc3b52..d7bf6363e465 100644 --- a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java +++ b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java @@ -1407,13 +1407,16 @@ Mono downloadStreamWithResponse(BlobRange range, Down if (decoderStateObj instanceof StorageContentValidationDecoderPolicy.DecoderState) { DecoderState decoderState = (DecoderState) decoderStateObj; - // Use totalEncodedBytesProcessed to request NEW bytes from the server - // The pending buffer already contains bytes we've received, so we request - // starting from the next byte after what we've already received - long encodedOffset = decoderState.getTotalEncodedBytesProcessed(); + // Use getRetryOffset() to get the correct offset for retry + // This accounts for pending bytes that have been received but not yet consumed + long encodedOffset = decoderState.getRetryOffset(); long remainingCount = finalCount - encodedOffset; retryRange = new BlobRange(initialOffset + encodedOffset, remainingCount); + LOGGER.info( + "Structured message smart retry: resuming from offset {} (initial={}, encoded={})", + initialOffset + encodedOffset, initialOffset, encodedOffset); + // Preserve the decoder state for the retry retryContext = retryContext .addData(Constants.STRUCTURED_MESSAGE_DECODER_STATE_CONTEXT_KEY, decoderState); diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecoder.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecoder.java index 6117a7765541..be12a2d5ea85 100644 --- a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecoder.java +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/structuredmessage/StructuredMessageDecoder.java @@ -27,15 +27,19 @@ public class StructuredMessageDecoder { private int numSegments; private final long expectedContentLength; - private int messageOffset = 0; + private long messageOffset = 0; private int currentSegmentNumber = 0; - private int currentSegmentContentLength = 0; - private int currentSegmentContentOffset = 0; + private long currentSegmentContentLength = 0; + private long currentSegmentContentOffset = 0; private long messageCrc64 = 0; private long segmentCrc64 = 0; private final Map segmentCrcs = new HashMap<>(); + // Track the last complete segment boundary for smart retry + private long lastCompleteSegmentStart = 0; + private long currentSegmentStart = 0; + /** * Constructs a new StructuredMessageDecoder. * @@ -45,6 +49,50 @@ public StructuredMessageDecoder(long expectedContentLength) { this.expectedContentLength = expectedContentLength; } + /** + * Gets the byte offset where the last complete segment ended. + * This is used for smart retry to resume from a segment boundary. + * + * @return The byte offset of the last complete segment boundary. + */ + public long getLastCompleteSegmentStart() { + return lastCompleteSegmentStart; + } + + /** + * Gets the current message offset (total bytes consumed from the structured message). + * + * @return The current message offset. + */ + public long getMessageOffset() { + return messageOffset; + } + + /** + * Resets the decoder position to the last complete segment boundary. + * This is used during smart retry to ensure the decoder is in sync with + * the data being provided from the retry offset. + */ + public void resetToLastCompleteSegment() { + if (messageOffset != lastCompleteSegmentStart) { + LOGGER.atInfo() + .addKeyValue("fromOffset", messageOffset) + .addKeyValue("toOffset", lastCompleteSegmentStart) + .addKeyValue("currentSegmentNum", currentSegmentNumber) + .addKeyValue("currentSegmentContentOffset", currentSegmentContentOffset) + .addKeyValue("currentSegmentContentLength", currentSegmentContentLength) + .log("Resetting decoder to last complete segment boundary"); + messageOffset = lastCompleteSegmentStart; + // Reset current segment state - next decode will read the segment header + currentSegmentContentOffset = 0; + currentSegmentContentLength = 0; + } else { + LOGGER.atVerbose() + .addKeyValue("offset", messageOffset) + .log("Decoder already at last complete segment boundary, no reset needed"); + } + } + /** * Reads the message header from the given buffer. * @@ -79,6 +127,91 @@ private void readMessageHeader(ByteBuffer buffer) { messageOffset += V1_HEADER_LENGTH; } + /** + * Converts a ByteBuffer range to hex string for diagnostic purposes. + */ + private static String toHex(ByteBuffer buf, int len) { + int pos = buf.position(); + int peek = Math.min(len, buf.remaining()); + byte[] out = new byte[peek]; + buf.get(out, 0, peek); + buf.position(pos); + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < out.length; i++) { + sb.append(String.format("%02X", out[i])); + if (i < out.length - 1) + sb.append(' '); + } + return sb.toString(); + } + + /** + * Peeks the next segment length without consuming from the buffer. + * Used by the policy to calculate encoded segment size before slicing. + * + * @param buffer The buffer to peek from. + * @param relativeIndex The position in the buffer to start reading from. + * @return The segment content length, or -1 if not enough bytes. + */ + public long peekNextSegmentLength(ByteBuffer buffer, int relativeIndex) { + // Need at least V1_SEGMENT_HEADER_LENGTH bytes to read segment number (2) + segment size (8) + if (relativeIndex + V1_SEGMENT_HEADER_LENGTH > buffer.limit()) { + return -1; + } + // Segment size is at offset 2 (after segment number which is 2 bytes) + return buffer.getLong(relativeIndex + 2); + } + + /** + * Gets the flags for the current message (needed to determine if CRC is present). + * + * @return The message flags, or null if header not yet read. + */ + public StructuredMessageFlags getFlags() { + return flags; + } + + /** + * Reads and validates segment length with diagnostic logging. + */ + private long readAndValidateSegmentLength(ByteBuffer buffer, long remaining) { + final int SEGMENT_SIZE_BYTES = 8; // segment size is 8 bytes (long) + if (buffer.remaining() < SEGMENT_SIZE_BYTES) { + LOGGER.error("Not enough bytes to read segment size. pos={}, remaining={}", buffer.position(), + buffer.remaining()); + throw new IllegalStateException("Not enough bytes to read segment size"); + } + + // Diagnostic: dump first 16 bytes at this position so we can see what's being read + LOGGER.atInfo() + .addKeyValue("decoderOffset", messageOffset) + .addKeyValue("bufferPos", buffer.position()) + .addKeyValue("bufferRemaining", buffer.remaining()) + .addKeyValue("peek16", toHex(buffer, 16)) + .addKeyValue("lastCompleteSegment", lastCompleteSegmentStart) + .log("Decoder about to read segment length"); + + long segmentLength = buffer.getLong(); + + if (segmentLength < 0 || segmentLength > remaining) { + // Peek next bytes for extra detail + String peekNext = toHex(buffer, 16); + LOGGER.error( + "Invalid segment length read: segmentLength={}, remaining={}, decoderOffset={}, " + + "lastCompleteSegment={}, bufferPos={}, peek-next-bytes={}", + segmentLength, remaining, messageOffset, lastCompleteSegmentStart, buffer.position(), peekNext); + throw new IllegalArgumentException("Invalid segment size detected: " + segmentLength + " (remaining=" + + remaining + ", decoderOffset=" + messageOffset + ")"); + } + + LOGGER.atVerbose() + .addKeyValue("segmentLength", segmentLength) + .addKeyValue("decoderOffset", messageOffset) + .log("Valid segment length read"); + + return segmentLength; + } + /** * Reads the segment header from the given buffer. * @@ -90,13 +223,13 @@ private void readSegmentHeader(ByteBuffer buffer) { throw LOGGER.logExceptionAsError(new IllegalArgumentException("Segment header is incomplete.")); } + // Mark the start of this segment (before reading the header) + currentSegmentStart = messageOffset; + int segmentNum = Short.toUnsignedInt(buffer.getShort()); - int segmentSize = (int) buffer.getLong(); - if (segmentSize < 0 || segmentSize > buffer.remaining()) { - throw LOGGER - .logExceptionAsError(new IllegalArgumentException("Invalid segment size detected: " + segmentSize)); - } + // Read segment size with validation and diagnostics + long segmentSize = readAndValidateSegmentLength(buffer, buffer.remaining()); if (segmentNum != currentSegmentNumber + 1) { throw LOGGER.logExceptionAsError(new IllegalArgumentException("Unexpected segment number.")); @@ -126,8 +259,8 @@ private void readSegmentHeader(ByteBuffer buffer) { * @throws IllegalArgumentException if there is a segment size mismatch. */ private void readSegmentContent(ByteBuffer buffer, ByteArrayOutputStream output, int size) { - int toRead = Math.min(buffer.remaining(), currentSegmentContentLength - currentSegmentContentOffset); - toRead = Math.min(toRead, size); + long remaining = currentSegmentContentLength - currentSegmentContentOffset; + int toRead = (int) Math.min(buffer.remaining(), Math.min(remaining, size)); if (toRead == 0) { return; @@ -182,10 +315,17 @@ private void readSegmentFooter(ByteBuffer buffer) { messageOffset += CRC64_LENGTH; } + // Mark that this segment is complete - update the last complete segment boundary + // This is the position where we can safely resume if a retry occurs + lastCompleteSegmentStart = messageOffset; + LOGGER.atInfo() + .addKeyValue("segmentNum", currentSegmentNumber) + .addKeyValue("offset", lastCompleteSegmentStart) + .addKeyValue("segmentLength", currentSegmentContentLength) + .log("Segment complete at byte offset"); + if (currentSegmentNumber == numSegments) { readMessageFooter(buffer); - } else { - readSegmentHeader(buffer); } } diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java index f33d9fcef890..0ecd5076f8bf 100644 --- a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java @@ -14,11 +14,14 @@ import com.azure.core.util.logging.ClientLogger; import com.azure.storage.common.DownloadContentValidationOptions; import com.azure.storage.common.implementation.Constants; +import com.azure.storage.common.implementation.structuredmessage.StructuredMessageConstants; import com.azure.storage.common.implementation.structuredmessage.StructuredMessageDecoder; +import com.azure.storage.common.implementation.structuredmessage.StructuredMessageFlags; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.charset.Charset; import java.util.concurrent.atomic.AtomicLong; @@ -78,6 +81,8 @@ public Mono process(HttpPipelineCallContext context, HttpPipelineN /** * Decodes a stream of byte buffers using the decoder state. + * Uses relative indexing based on decoder's message offset to correctly + * slice encoded segments and handle pending buffers across chunks. * * @param encodedFlux The flux of encoded byte buffers. * @param state The decoder state. @@ -85,59 +90,142 @@ public Mono process(HttpPipelineCallContext context, HttpPipelineN */ private Flux decodeStream(Flux encodedFlux, DecoderState state) { return encodedFlux.concatMap(encodedBuffer -> { - // Combine with pending data if any - ByteBuffer dataToProcess = state.combineWithPending(encodedBuffer); + // Capture absoluteStartOfCombined BEFORE adding new bytes + long absoluteStartOfCombined = state.totalEncodedBytesProcessed.get(); - // Track encoded bytes - int encodedBytesInBuffer = encodedBuffer.remaining(); - state.totalEncodedBytesProcessed.addAndGet(encodedBytesInBuffer); + // Track the NEW bytes received from the network + int newBytesReceived = encodedBuffer.remaining(); + // Note: we add to totalEncodedBytesProcessed AFTER we're done processing this chunk + + int pendingSize = (state.pendingBuffer != null) ? state.pendingBuffer.remaining() : 0; + // Adjust absoluteStartOfCombined to account for pending bytes that came before + absoluteStartOfCombined -= pendingSize; + + LOGGER.atInfo() + .addKeyValue("newBytes", newBytesReceived) + .addKeyValue("pendingBytes", pendingSize) + .addKeyValue("absoluteStartOfCombined", absoluteStartOfCombined) + .addKeyValue("decoderOffset", state.decoder.getMessageOffset()) + .addKeyValue("lastCompleteSegment", state.decoder.getLastCompleteSegmentStart()) + .log("Received buffer in decodeStream"); + + // Combine with pending data if any - always returns buffer with position=0 and LITTLE_ENDIAN + ByteBuffer combined = state.combineWithPending(encodedBuffer); try { - // Try to decode what we have - decoder handles partial data - // Create duplicate for decoder - it will advance the duplicate's position as it reads - int availableSize = dataToProcess.remaining(); - ByteBuffer duplicateForDecode = dataToProcess.duplicate(); - int initialPosition = duplicateForDecode.position(); - - // Decode - this advances duplicateForDecode's position - ByteBuffer decodedData = state.decoder.decode(duplicateForDecode, availableSize); - - // Track decoded bytes - int decodedBytes = decodedData.remaining(); - state.totalBytesDecoded.addAndGet(decodedBytes); - - // Calculate how much of the input buffer was consumed by checking the duplicate's position - int bytesConsumed = duplicateForDecode.position() - initialPosition; - int bytesRemaining = availableSize - bytesConsumed; - - // Save only unconsumed portion to pending - if (bytesRemaining > 0) { - // Position the original buffer to skip consumed bytes, then slice to get unconsumed - dataToProcess.position(bytesConsumed); - ByteBuffer unconsumed = dataToProcess.slice(); - state.updatePendingBuffer(unconsumed); - } else { - // All data was consumed - state.pendingBuffer = null; + java.io.ByteArrayOutputStream decodedOutput = new java.io.ByteArrayOutputStream(); + + // Loop to decode complete segments from combined buffer + while (true) { + long decoderOffset = state.decoder.getMessageOffset(); + int relativeIndex = (int) (decoderOffset - absoluteStartOfCombined); + + // Defensive check + if (relativeIndex < 0) { + LOGGER.error( + "Negative relative index detected: relativeIndex={}, decoderOffset={}, absoluteStart={}", + relativeIndex, decoderOffset, absoluteStartOfCombined); + throw new IllegalStateException("Negative relative index: " + relativeIndex); + } + + // Check if we have enough for segment header + if (relativeIndex + StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH > combined.limit()) { + // Save remaining as pending and break + if (relativeIndex < combined.limit()) { + combined.position(relativeIndex); + state.updatePendingBuffer(combined.slice()); + } else { + state.pendingBuffer = null; + } + break; + } + + // For the first chunk, we need to read message header first + if (decoderOffset == 0) { + // Decode up to message header length to bootstrap + ByteBuffer headerSlice = combined.duplicate(); + headerSlice.position(relativeIndex); + headerSlice.order(ByteOrder.LITTLE_ENDIAN); + ByteBuffer decoded + = state.decoder.decode(headerSlice, StructuredMessageConstants.V1_HEADER_LENGTH); + // After header is read, continue loop to process segments + continue; + } + + // Peek segment length + long segmentLength = state.decoder.peekNextSegmentLength(combined, relativeIndex); + if (segmentLength < 0) { + // Not enough bytes to read segment header + combined.position(relativeIndex); + state.updatePendingBuffer(combined.slice()); + break; + } + + // Calculate encoded segment size + int crcLength = (state.decoder.getFlags() == StructuredMessageFlags.STORAGE_CRC64) + ? StructuredMessageConstants.CRC64_LENGTH + : 0; + long encodedSegmentSize + = StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + segmentLength + crcLength; + + // Check if we have the complete segment + if (relativeIndex + encodedSegmentSize > combined.limit()) { + // Save pending and break + combined.position(relativeIndex); + state.updatePendingBuffer(combined.slice()); + break; + } + + // Slice encoded segment + ByteBuffer encodedSlice = combined.duplicate(); + encodedSlice.position(relativeIndex); + encodedSlice.limit(relativeIndex + (int) encodedSegmentSize); + encodedSlice = encodedSlice.slice(); + encodedSlice.order(ByteOrder.LITTLE_ENDIAN); + + // Decode the segment + ByteBuffer decoded = state.decoder.decode(encodedSlice); + + LOGGER.atVerbose() + .addKeyValue("relativeIndex", relativeIndex) + .addKeyValue("encodedSegmentSize", encodedSegmentSize) + .addKeyValue("decodedBytes", decoded.remaining()) + .addKeyValue("newDecoderOffset", state.decoder.getMessageOffset()) + .log("Decoded segment"); + + // Update tracked bytes + state.totalEncodedBytesProcessed.addAndGet(encodedSegmentSize); + if (decoded.remaining() > 0) { + state.totalBytesDecoded.addAndGet(decoded.remaining()); + // Accumulate decoded bytes + byte[] decodedBytes = new byte[decoded.remaining()]; + decoded.get(decodedBytes); + decodedOutput.write(decodedBytes, 0, decodedBytes.length); + } + + // Check if we've completed the message + if (state.decoder.getMessageOffset() >= state.expectedContentLength) { + state.pendingBuffer = null; + break; + } } // Return decoded data if any - if (decodedBytes > 0) { - return Flux.just(decodedData); + byte[] decodedBytes = decodedOutput.toByteArray(); + if (decodedBytes.length > 0) { + return Flux.just(ByteBuffer.wrap(decodedBytes)); } else { return Flux.empty(); } + } catch (IllegalArgumentException e) { // Handle decoder exceptions - check if it's due to incomplete data String errorMsg = e.getMessage(); if (errorMsg != null && (errorMsg.contains("not long enough") || errorMsg.contains("is incomplete"))) { // Not enough data to decode yet - preserve all data in pending buffer - state.updatePendingBuffer(dataToProcess); - - // Don't fail - just return empty and wait for more data + state.updatePendingBuffer(combined); return Flux.empty(); } else { - // Other errors should propagate LOGGER.error("Failed to decode structured message chunk: " + e.getMessage(), e); return Flux.error(e); } @@ -250,29 +338,43 @@ public DecoderState(long expectedContentLength) { /** * Combines pending buffer with new data. + * Always returns a buffer with position=0 and LITTLE_ENDIAN byte order. * * @param newBuffer The new buffer to combine. - * @return Combined buffer. + * @return Combined buffer with LITTLE_ENDIAN byte order and position=0. */ private ByteBuffer combineWithPending(ByteBuffer newBuffer) { if (pendingBuffer == null || !pendingBuffer.hasRemaining()) { - return newBuffer.duplicate(); + // Return a duplicate slice with LITTLE_ENDIAN and position=0 + ByteBuffer dup = newBuffer.duplicate().slice(); + dup.order(java.nio.ByteOrder.LITTLE_ENDIAN); + return dup; } - ByteBuffer combined = ByteBuffer.allocate(pendingBuffer.remaining() + newBuffer.remaining()); - combined.put(pendingBuffer.duplicate()); - combined.put(newBuffer.duplicate()); + // Create slices with LITTLE_ENDIAN order + ByteBuffer pendingSlice = pendingBuffer.duplicate().slice(); + pendingSlice.order(java.nio.ByteOrder.LITTLE_ENDIAN); + ByteBuffer newSlice = newBuffer.duplicate().slice(); + newSlice.order(java.nio.ByteOrder.LITTLE_ENDIAN); + + // Allocate combined buffer with LITTLE_ENDIAN order + ByteBuffer combined = ByteBuffer.allocate(pendingSlice.remaining() + newSlice.remaining()); + combined.order(java.nio.ByteOrder.LITTLE_ENDIAN); + combined.put(pendingSlice); + combined.put(newSlice); combined.flip(); return combined; } /** * Updates the pending buffer with remaining data. + * Allocates a new buffer with LITTLE_ENDIAN byte order. * * @param dataToProcess The buffer with remaining data. */ private void updatePendingBuffer(ByteBuffer dataToProcess) { pendingBuffer = ByteBuffer.allocate(dataToProcess.remaining()); + pendingBuffer.order(java.nio.ByteOrder.LITTLE_ENDIAN); pendingBuffer.put(dataToProcess); pendingBuffer.flip(); } @@ -297,15 +399,46 @@ public long getTotalEncodedBytesProcessed() { /** * Gets the offset to use for retry requests. - * This is the total encoded bytes processed minus any bytes in the pending buffer, - * since pending bytes have already been counted but haven't been successfully processed yet. + * This uses the decoder's last complete segment boundary to ensure retries + * resume from a valid segment boundary, not mid-segment. + * + * Also clears the pending buffer and resets decoder state to align with + * the segment boundary. * - * @return The offset for retry requests. + * @return The offset for retry requests (last complete segment boundary). */ public long getRetryOffset() { - long processed = totalEncodedBytesProcessed.get(); - int pending = (pendingBuffer != null) ? pendingBuffer.remaining() : 0; - return processed - pending; + // Use the decoder's last complete segment start as the retry offset + // This ensures we resume from a segment boundary, not mid-segment + long retryOffset = decoder.getLastCompleteSegmentStart(); + long decoderOffsetBefore = decoder.getMessageOffset(); + int pendingSize = (pendingBuffer != null) ? pendingBuffer.remaining() : 0; + + LOGGER.atInfo() + .addKeyValue("retryOffset", retryOffset) + .addKeyValue("decoderOffsetBefore", decoderOffsetBefore) + .addKeyValue("pendingBytes", pendingSize) + .addKeyValue("totalProcessed", totalEncodedBytesProcessed.get()) + .log("Computing retry offset"); + + // Reset decoder to the last complete segment boundary + // This ensures messageOffset and segment state match the retry offset + decoder.resetToLastCompleteSegment(); + + // Clear pending buffer since we're restarting from the segment boundary + // Any bytes in pending are from after this boundary and will be re-fetched + if (pendingBuffer != null && pendingBuffer.hasRemaining()) { + LOGGER.atInfo() + .addKeyValue("pendingBytes", pendingBuffer.remaining()) + .addKeyValue("retryOffset", retryOffset) + .log("Clearing pending bytes for retry from segment boundary"); + pendingBuffer = null; + } + + LOGGER.atInfo() + .addKeyValue("retryOffset", retryOffset) + .log("Retry offset calculated (last complete segment boundary)"); + return retryOffset; } /**