Skip to content
This repository was archived by the owner on Nov 11, 2022. It is now read-only.

Commit 054f1e6

Browse files
peihe and davorbonaci
authored and committed
Optimize mergeAccumulators by reusing an existing accumulator
Also, added unit tests for Sum, Min, Max, Mean. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=112723885
1 parent 6e5d743 commit 054f1e6

File tree

5 files changed

+121
-19
lines changed

5 files changed

+121
-19
lines changed

sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -534,17 +534,23 @@ public Holder<V> addInput(Holder<V> accumulator, V input) {
534534

535535
@Override
536536
public Holder<V> mergeAccumulators(Iterable<Holder<V>> accumulators) {
537-
Holder<V> running = new Holder<>();
538-
for (Holder<V> accumulator : accumulators) {
539-
if (accumulator.present) {
540-
if (running.present) {
541-
running.set(apply(running.value, accumulator.value));
542-
} else {
543-
running.set(accumulator.value);
537+
Iterator<Holder<V>> iter = accumulators.iterator();
538+
if (!iter.hasNext()) {
539+
return createAccumulator();
540+
} else {
541+
Holder<V> running = iter.next();
542+
while (iter.hasNext()) {
543+
Holder<V> accum = iter.next();
544+
if (accum.present) {
545+
if (running.present) {
546+
running.set(apply(running.value, accum.value));
547+
} else {
548+
running.set(accum.value);
549+
}
544550
}
545551
}
552+
return running;
546553
}
547-
return running;
548554
}
549555

550556
@Override
@@ -632,7 +638,9 @@ public void verifyDeterministic() throws NonDeterministicException {
632638

633639
/**
634640
* An abstract subclass of {@link CombineFn} for implementing combiners that are more
635-
* easily and efficiently expressed as binary operations on <code>int</code>s.
641+
* easily and efficiently expressed as binary operations on <code>int</code>s.
642+
*
643+
* <p> It uses {@code int[0]} as the mutable accumulator.
636644
*/
637645
public abstract static class BinaryCombineIntegerFn extends CombineFn<Integer, int[], Integer> {
638646

@@ -664,11 +672,11 @@ public int[] mergeAccumulators(Iterable<int[]> accumulators) {
664672
if (!iter.hasNext()) {
665673
return createAccumulator();
666674
} else {
667-
int running = iter.next()[0];
675+
int[] running = iter.next();
668676
while (iter.hasNext()) {
669-
running = apply(running, iter.next()[0]);
677+
running[0] = apply(running[0], iter.next()[0]);
670678
}
671-
return wrap(running);
679+
return running;
672680
}
673681
}
674682

@@ -713,6 +721,8 @@ public Counter<Integer> getCounter(String name) {
713721
/**
714722
* An abstract subclass of {@link CombineFn} for implementing combiners that are more
715723
* easily and efficiently expressed as binary operations on <code>long</code>s.
724+
*
725+
* <p> It uses {@code long[0]} as the mutable accumulator.
716726
*/
717727
public abstract static class BinaryCombineLongFn extends CombineFn<Long, long[], Long> {
718728
/**
@@ -743,11 +753,11 @@ public long[] mergeAccumulators(Iterable<long[]> accumulators) {
743753
if (!iter.hasNext()) {
744754
return createAccumulator();
745755
} else {
746-
long running = iter.next()[0];
756+
long[] running = iter.next();
747757
while (iter.hasNext()) {
748-
running = apply(running, iter.next()[0]);
758+
running[0] = apply(running[0], iter.next()[0]);
749759
}
750-
return wrap(running);
760+
return running;
751761
}
752762
}
753763

@@ -791,6 +801,8 @@ public Counter<Long> getCounter(String name) {
791801
/**
792802
* An abstract subclass of {@link CombineFn} for implementing combiners that are more
793803
* easily and efficiently expressed as binary operations on <code>double</code>s.
804+
*
805+
* <p> It uses {@code double[0]} as the mutable accumulator.
794806
*/
795807
public abstract static class BinaryCombineDoubleFn extends CombineFn<Double, double[], Double> {
796808

@@ -822,11 +834,11 @@ public double[] mergeAccumulators(Iterable<double[]> accumulators) {
822834
if (!iter.hasNext()) {
823835
return createAccumulator();
824836
} else {
825-
double running = iter.next()[0];
837+
double[] running = iter.next();
826838
while (iter.hasNext()) {
827-
running = apply(running, iter.next()[0]);
839+
running[0] = apply(running[0], iter.next()[0]);
828840
}
829-
return wrap(running);
841+
return running;
830842
}
831843
}
832844

sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MaxTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616

1717
package com.google.cloud.dataflow.sdk.transforms;
1818

19+
import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn;
1920
import static org.junit.Assert.assertEquals;
2021

22+
import com.google.common.collect.Lists;
23+
2124
import org.junit.Test;
2225
import org.junit.runner.RunWith;
2326
import org.junit.runners.JUnit4;
@@ -36,4 +39,28 @@ public void testMeanGetNames() {
3639
assertEquals("Max.PerKey", Max.doublesPerKey().getName());
3740
assertEquals("Max.PerKey", Max.longsPerKey().getName());
3841
}
42+
43+
@Test
44+
public void testMaxIntegerFn() {
45+
checkCombineFn(
46+
new Max.MaxIntegerFn(),
47+
Lists.newArrayList(1, 2, 3, 4),
48+
4);
49+
}
50+
51+
@Test
52+
public void testMaxLongFn() {
53+
checkCombineFn(
54+
new Max.MaxLongFn(),
55+
Lists.newArrayList(1L, 2L, 3L, 4L),
56+
4L);
57+
}
58+
59+
@Test
60+
public void testMaxDoubleFn() {
61+
checkCombineFn(
62+
new Max.MaxDoubleFn(),
63+
Lists.newArrayList(1.0, 2.0, 3.0, 4.0),
64+
4.0);
65+
}
3966
}

sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MeanTest.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
package com.google.cloud.dataflow.sdk.transforms;
1818

19+
import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn;
1920
import static org.junit.Assert.assertEquals;
2021

2122
import com.google.cloud.dataflow.sdk.coders.Coder;
2223
import com.google.cloud.dataflow.sdk.testing.CoderProperties;
2324
import com.google.cloud.dataflow.sdk.transforms.Mean.CountSum;
2425
import com.google.cloud.dataflow.sdk.transforms.Mean.CountSumCoder;
26+
import com.google.common.collect.Lists;
2527

2628
import org.junit.Test;
2729
import org.junit.runner.RunWith;
@@ -59,4 +61,12 @@ public void testCountSumCoderEncodeDecode() throws Exception {
5961
public void testCountSumCoderSerializable() throws Exception {
6062
CoderProperties.coderSerializable(TEST_CODER);
6163
}
64+
65+
@Test
66+
public void testMeanFn() throws Exception {
67+
checkCombineFn(
68+
new Mean.MeanFn<Integer>(),
69+
Lists.newArrayList(1, 2, 3, 4),
70+
2.5);
71+
}
6272
}

sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MinTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616

1717
package com.google.cloud.dataflow.sdk.transforms;
1818

19+
import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn;
1920
import static org.junit.Assert.assertEquals;
2021

22+
import com.google.common.collect.Lists;
23+
2124
import org.junit.Test;
2225
import org.junit.runner.RunWith;
2326
import org.junit.runners.JUnit4;
@@ -36,4 +39,28 @@ public void testMeanGetNames() {
3639
assertEquals("Min.PerKey", Min.doublesPerKey().getName());
3740
assertEquals("Min.PerKey", Min.longsPerKey().getName());
3841
}
42+
43+
@Test
44+
public void testMinIntegerFn() {
45+
checkCombineFn(
46+
new Min.MinIntegerFn(),
47+
Lists.newArrayList(1, 2, 3, 4),
48+
1);
49+
}
50+
51+
@Test
52+
public void testMinLongFn() {
53+
checkCombineFn(
54+
new Min.MinLongFn(),
55+
Lists.newArrayList(1L, 2L, 3L, 4L),
56+
1L);
57+
}
58+
59+
@Test
60+
public void testMinDoubleFn() {
61+
checkCombineFn(
62+
new Min.MinDoubleFn(),
63+
Lists.newArrayList(1.0, 2.0, 3.0, 4.0),
64+
1.0);
65+
}
3966
}

sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SumTest.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
* License for the specific language governing permissions and limitations under
1414
* the License.
1515
*/
16-
1716
package com.google.cloud.dataflow.sdk.transforms;
1817

18+
import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn;
1919
import static org.junit.Assert.assertEquals;
2020

21+
import com.google.common.collect.Lists;
22+
2123
import org.junit.Test;
2224
import org.junit.runner.RunWith;
2325
import org.junit.runners.JUnit4;
@@ -37,4 +39,28 @@ public void testSumGetNames() {
3739
assertEquals("Sum.PerKey", Sum.doublesPerKey().getName());
3840
assertEquals("Sum.PerKey", Sum.longsPerKey().getName());
3941
}
42+
43+
@Test
44+
public void testSumIntegerFn() {
45+
checkCombineFn(
46+
new Sum.SumIntegerFn(),
47+
Lists.newArrayList(1, 2, 3, 4),
48+
10);
49+
}
50+
51+
@Test
52+
public void testSumLongFn() {
53+
checkCombineFn(
54+
new Sum.SumLongFn(),
55+
Lists.newArrayList(1L, 2L, 3L, 4L),
56+
10L);
57+
}
58+
59+
@Test
60+
public void testSumDoubleFn() {
61+
checkCombineFn(
62+
new Sum.SumDoubleFn(),
63+
Lists.newArrayList(1.0, 2.0, 3.0, 4.0),
64+
10.0);
65+
}
4066
}

0 commit comments

Comments
 (0)