21
21
import static com .google .cloud .dataflow .sdk .runners .worker .SourceTranslationUtils .cloudProgressToReaderProgress ;
22
22
import static com .google .cloud .dataflow .sdk .runners .worker .SourceTranslationUtils .splitRequestToApproximateSplitRequest ;
23
23
import static com .google .cloud .dataflow .sdk .util .common .Counter .AggregationKind .SUM ;
24
+ import static com .google .common .base .Preconditions .checkState ;
24
25
25
26
import com .google .api .services .dataflow .model .ApproximateReportedProgress ;
26
27
import com .google .api .services .dataflow .model .ApproximateSplitRequest ;
27
28
import com .google .cloud .dataflow .sdk .coders .Coder ;
29
+ import com .google .cloud .dataflow .sdk .coders .Coder .Context ;
28
30
import com .google .cloud .dataflow .sdk .coders .IterableCoder ;
29
31
import com .google .cloud .dataflow .sdk .coders .KvCoder ;
30
32
import com .google .cloud .dataflow .sdk .options .PipelineOptions ;
54
56
import org .slf4j .Logger ;
55
57
import org .slf4j .LoggerFactory ;
56
58
59
+ import java .io .ByteArrayInputStream ;
57
60
import java .io .IOException ;
58
61
import java .util .Iterator ;
59
62
import java .util .concurrent .atomic .AtomicLong ;
@@ -80,7 +83,8 @@ public class GroupingShuffleReader<K, V> extends NativeReader<WindowedValue<KV<K
80
83
// Counts how many bytes were from by a given operation from a given shuffle session.
81
84
@ Nullable Counter <Long > perOperationPerDatasetBytesCounter ;
82
85
Coder <K > keyCoder ;
83
- Coder <V > valueCoder ;
86
+ Coder <?> valueCoder ;
87
+ @ Nullable Coder <?> secondaryKeyCoder ;
84
88
85
89
public GroupingShuffleReader (
86
90
PipelineOptions options ,
@@ -90,15 +94,16 @@ public GroupingShuffleReader(
90
94
Coder <WindowedValue <KV <K , Iterable <V >>>> coder ,
91
95
BatchModeExecutionContext executionContext ,
92
96
CounterSet .AddCounterMutator addCounterMutator ,
93
- String operationName )
97
+ String operationName ,
98
+ boolean valuesAreSorted )
94
99
throws Exception {
95
100
this .shuffleReaderConfig = shuffleReaderConfig ;
96
101
this .startShufflePosition = startShufflePosition ;
97
102
this .stopShufflePosition = stopShufflePosition ;
98
103
this .executionContext = executionContext ;
99
104
this .addCounterMutator = addCounterMutator ;
100
105
this .operationName = operationName ;
101
- initCoder (coder );
106
+ initCoder (coder , valuesAreSorted );
102
107
// We cannot initialize perOperationPerDatasetBytesCounter here, as it
103
108
// depends on shuffleReaderConfig, which isn't populated yet.
104
109
}
@@ -131,7 +136,8 @@ public GroupingShuffleReaderIterator<K, V> iterator() throws IOException {
131
136
new ChunkingShuffleBatchReader (asr )));
132
137
}
133
138
134
- private void initCoder (Coder <WindowedValue <KV <K , Iterable <V >>>> coder ) throws Exception {
139
+ private void initCoder (Coder <WindowedValue <KV <K , Iterable <V >>>> coder ,
140
+ boolean valuesAreSorted ) throws Exception {
135
141
if (!(coder instanceof WindowedValueCoder )) {
136
142
throw new Exception ("unexpected kind of coder for WindowedValue: " + coder );
137
143
}
@@ -151,7 +157,17 @@ private void initCoder(Coder<WindowedValue<KV<K, Iterable<V>>>> coder) throws Ex
151
157
+ "a key-grouping shuffle" );
152
158
}
153
159
IterableCoder <V > iterCoder = (IterableCoder <V >) kvValueCoder ;
154
- this .valueCoder = iterCoder .getElemCoder ();
160
+ if (valuesAreSorted ) {
161
+ checkState (iterCoder .getElemCoder () instanceof KvCoder ,
162
+ "unexpected kind of coder for elements read from a "
163
+ + "key-grouping value sorting shuffle: %s" , iterCoder .getElemCoder ());
164
+ @ SuppressWarnings ("rawtypes" )
165
+ KvCoder <?, ?> valueKvCoder = (KvCoder ) iterCoder .getElemCoder ();
166
+ this .secondaryKeyCoder = valueKvCoder .getKeyCoder ();
167
+ this .valueCoder = valueKvCoder .getValueCoder ();
168
+ } else {
169
+ this .valueCoder = iterCoder .getElemCoder ();
170
+ }
155
171
}
156
172
157
173
final GroupingShuffleReaderIterator <K , V > iterator (ShuffleEntryReader reader ) {
@@ -390,7 +406,20 @@ public V next() {
390
406
// notify the bytes that have been read so far.
391
407
notifyValueReturned (currentGroupSize .getAndSet (0L ));
392
408
try {
393
- return CoderUtils .decodeFromByteArray (parentReader .valueCoder , entry .getValue ());
409
+ if (parentReader .secondaryKeyCoder != null ) {
410
+ ByteArrayInputStream bais = new ByteArrayInputStream (entry .getSecondaryKey ());
411
+ @ SuppressWarnings ("unchecked" )
412
+ V value = (V ) KV .of (
413
+ // We ignore decoding the timestamp.
414
+ parentReader .secondaryKeyCoder .decode (bais , Context .NESTED ),
415
+ CoderUtils .decodeFromByteArray (parentReader .valueCoder , entry .getValue ()));
416
+ return value ;
417
+ } else {
418
+ @ SuppressWarnings ("unchecked" )
419
+ V value = (V ) CoderUtils .decodeFromByteArray (parentReader .valueCoder ,
420
+ entry .getValue ());
421
+ return value ;
422
+ }
394
423
} catch (IOException exn ) {
395
424
throw new RuntimeException (exn );
396
425
}
0 commit comments