diff --git a/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/JoinHints.java b/euphoria-core/src/main/java/cz/seznam/euphoria/core/client/operator/JoinHints.java similarity index 83% rename from euphoria-spark/src/main/java/cz/seznam/euphoria/spark/JoinHints.java rename to euphoria-core/src/main/java/cz/seznam/euphoria/core/client/operator/JoinHints.java index 37a1b07b..02b30ea1 100644 --- a/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/JoinHints.java +++ b/euphoria-core/src/main/java/cz/seznam/euphoria/core/client/operator/JoinHints.java @@ -13,10 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package cz.seznam.euphoria.spark; +package cz.seznam.euphoria.core.client.operator; import cz.seznam.euphoria.core.annotation.audience.Audience; -import cz.seznam.euphoria.core.client.operator.JoinHint; @Audience(Audience.Type.CLIENT) public class JoinHints { @@ -28,8 +27,7 @@ public static BroadcastHashJoin broadcastHashJoin() { } /** - * Broadcasts optional join side to all executors. See {@link BroadcastHashJoinTranslator} - * for more details. + * Broadcasts optional join side to all executors. */ public static class BroadcastHashJoin implements JoinHint { diff --git a/euphoria-core/src/main/java/cz/seznam/euphoria/core/executor/util/MultiValueContext.java b/euphoria-core/src/main/java/cz/seznam/euphoria/core/executor/util/MultiValueContext.java new file mode 100644 index 00000000..f482aed5 --- /dev/null +++ b/euphoria-core/src/main/java/cz/seznam/euphoria/core/executor/util/MultiValueContext.java @@ -0,0 +1,121 @@ +/* + * Copyright 2016-2018 Seznam.cz, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.seznam.euphoria.core.executor.util; + +import cz.seznam.euphoria.core.annotation.audience.Audience; +import cz.seznam.euphoria.core.client.accumulators.Counter; +import cz.seznam.euphoria.core.client.accumulators.Histogram; +import cz.seznam.euphoria.core.client.accumulators.Timer; +import cz.seznam.euphoria.core.client.dataset.windowing.Window; +import cz.seznam.euphoria.core.client.io.Collector; +import cz.seznam.euphoria.core.client.io.Context; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; + +@Audience(Audience.Type.EXECUTOR) +public class MultiValueContext implements Context, Collector { + + private final List elements = new ArrayList<>(1); + @Nullable + final Context wrap; + + public MultiValueContext() { + this(null); + } + + public MultiValueContext(Context wrap) { + this.wrap = wrap; + } + + /** + * Replace the stored value with given one. + * + * @param elem the element to store + */ + @Override + public void collect(T elem) { + elements.add(elem); + } + + @Override + public Context asContext() { + return this; + } + + /** + * Retrieve window associated with the stored element. + */ + @Override + public Window getWindow() throws UnsupportedOperationException { + if (wrap == null) { + throw new UnsupportedOperationException( + "The window is unknown in this context"); + } + return wrap.getWindow(); + } + + @Override + public Counter getCounter(String name) { + if (wrap == null) { + throw new UnsupportedOperationException( + "Accumulators not supported in this context"); + } + return wrap.getCounter(name); + } + + @Override + public Histogram getHistogram(String name) { + if (wrap == null) { + throw new UnsupportedOperationException( + "Accumulators not supported in this context"); + } + return wrap.getHistogram(name); + + } + + @Override + public Timer getTimer(String name) { + if (wrap == null) { + throw new UnsupportedOperationException( + "Accumulators not supported in this context"); + } + return wrap.getTimer(name); + + } + + /** + * Retrieve and reset the stored elements. + * + * @return the stored value + */ + public List getAndResetValue() { + List copiedElements = new ArrayList<>(elements); + elements.clear(); + return copiedElements; + } + + /** + * Retrieve value of this context. + * + * @return value + */ + public List get() { + return elements; + } +} + diff --git a/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BatchFlowTranslator.java b/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BatchFlowTranslator.java index bc59d6f3..b137d2d4 100644 --- a/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BatchFlowTranslator.java +++ b/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BatchFlowTranslator.java @@ -15,23 +15,18 @@ */ package cz.seznam.euphoria.flink.batch; -import cz.seznam.euphoria.flink.accumulators.FlinkAccumulatorFactory; -import cz.seznam.euphoria.shadow.com.google.common.base.Preconditions; import cz.seznam.euphoria.core.client.flow.Flow; import cz.seznam.euphoria.core.client.functional.UnaryPredicate; -import cz.seznam.euphoria.core.executor.graph.DAG; -import cz.seznam.euphoria.core.executor.graph.Node; import cz.seznam.euphoria.core.client.io.DataSink; -import cz.seznam.euphoria.core.client.operator.FlatMap; -import cz.seznam.euphoria.core.client.operator.Operator; -import cz.seznam.euphoria.core.client.operator.ReduceByKey; -import cz.seznam.euphoria.core.client.operator.ReduceStateByKey; -import cz.seznam.euphoria.core.client.operator.Union; +import cz.seznam.euphoria.core.client.operator.*; import cz.seznam.euphoria.core.executor.FlowUnfolder; +import cz.seznam.euphoria.core.executor.graph.DAG; +import cz.seznam.euphoria.core.executor.graph.Node; import cz.seznam.euphoria.core.util.Settings; import cz.seznam.euphoria.flink.FlinkOperator; import cz.seznam.euphoria.flink.FlowOptimizer; import cz.seznam.euphoria.flink.FlowTranslator; +import cz.seznam.euphoria.flink.accumulators.FlinkAccumulatorFactory; import cz.seznam.euphoria.flink.batch.io.DataSinkWrapper; import org.apache.flink.api.common.io.LocatableInputSplitAssigner; import org.apache.flink.api.java.DataSet; @@ -67,22 +62,23 @@ private Translation( this.accept = accept; } - static > void set( - Map idx, + static > void add( + Map> idx, Class type, BatchOperatorTranslator translator) { - set(idx, type, translator, null); + add(idx, type, translator, null); } - static > void set( - Map idx, + static > void add( + Map> idx, Class type, BatchOperatorTranslator translator, UnaryPredicate accept) { - idx.put(type, new Translation<>(translator, accept)); + idx.putIfAbsent(type, new ArrayList<>()); + idx.get(type).add(new Translation<>(translator, accept)); } } - private final Map translations = new IdentityHashMap<>(); + private final Map> translations = new IdentityHashMap<>(); private final Settings settings; private final ExecutionEnvironment env; @@ -103,21 +99,28 @@ public BatchFlowTranslator(Settings settings, this.accumulatorFactory = Objects.requireNonNull(accumulatorFactory); // basic operators - Translation.set(translations, FlowUnfolder.InputOperator.class, new InputTranslator(splitAssignerFactory)); - Translation.set(translations, FlatMap.class, new FlatMapTranslator()); - Translation.set(translations, ReduceStateByKey.class, new ReduceStateByKeyTranslator()); - Translation.set(translations, Union.class, new UnionTranslator()); + Translation.add(translations, FlowUnfolder.InputOperator.class, new InputTranslator + (splitAssignerFactory)); + Translation.add(translations, FlatMap.class, new FlatMapTranslator()); + Translation.add(translations, ReduceStateByKey.class, new ReduceStateByKeyTranslator()); + Translation.add(translations, Union.class, new UnionTranslator()); // derived operators - Translation.set(translations, ReduceByKey.class, new ReduceByKeyTranslator(), + Translation.add(translations, ReduceByKey.class, new ReduceByKeyTranslator(), ReduceByKeyTranslator::wantTranslate); + + // ~ batch broadcast join for a very small left side + Translation.add(translations, Join.class, new BroadcastHashJoinTranslator(), + BroadcastHashJoinTranslator::wantTranslate); } @SuppressWarnings("unchecked") @Override protected Collection getAcceptors() { return translations.entrySet().stream() - .map(e -> new TranslateAcceptor(e.getKey(), e.getValue().accept)) + .flatMap((entry) -> entry.getValue() + .stream() + .map(translator -> new TranslateAcceptor(entry.getKey(), translator.accept))) .collect(Collectors.toList()); } @@ -140,19 +143,27 @@ public List> translateInto(Flow flow) { // translate each operator to proper Flink transformation dag.traverse().map(Node::get).forEach(op -> { Operator originalOp = op.getOriginalOperator(); - Translation> tx = translations.get(originalOp.getClass()); - if (tx == null) { + List txs = this.translations.get(originalOp.getClass()); + if (txs.isEmpty()) { throw new UnsupportedOperationException( "Operator " + op.getClass().getSimpleName() + " not supported"); } // ~ verify the flowToDag translation - Preconditions.checkState( - tx.accept == null || Boolean.TRUE.equals(tx.accept.apply(originalOp))); - - DataSet out = tx.translator.translate(op, executorContext); - - // save output of current operator to context - executorContext.setOutput(op, out); + Translation> firstMatch = null; + for (Translation tx : txs) { + if (tx.accept == null || Boolean.TRUE.equals(tx.accept.apply(originalOp))) { + firstMatch = tx; + break; + } + } + final DataSet out; + if (firstMatch != null) { + out = firstMatch.translator.translate(op, executorContext); + // save output of current operator to context + executorContext.setOutput(op, out); + } else { + throw new IllegalStateException("No matching translation."); + } }); // process all sinks in the DAG (leaf nodes) diff --git a/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BroadcastHashJoinTranslator.java b/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BroadcastHashJoinTranslator.java new file mode 100644 index 00000000..88324a27 --- /dev/null +++ b/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BroadcastHashJoinTranslator.java @@ -0,0 +1,239 @@ +/* + * Copyright 2016-2018 Seznam.cz, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.seznam.euphoria.flink.batch; + +import cz.seznam.euphoria.core.client.dataset.windowing.MergingWindowing; +import cz.seznam.euphoria.core.client.dataset.windowing.Window; +import cz.seznam.euphoria.core.client.dataset.windowing.Windowing; +import cz.seznam.euphoria.core.client.functional.BinaryFunctor; +import cz.seznam.euphoria.core.client.functional.UnaryFunction; +import cz.seznam.euphoria.core.client.operator.Join; +import cz.seznam.euphoria.core.client.operator.JoinHints; +import cz.seznam.euphoria.core.client.util.Pair; +import cz.seznam.euphoria.core.executor.util.MultiValueContext; +import cz.seznam.euphoria.flink.FlinkOperator; +import org.apache.flink.api.common.functions.FlatJoinFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.operators.base.JoinOperatorBase; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.util.Collector; + +import java.util.List; +import java.util.Objects; + +public class BroadcastHashJoinTranslator implements BatchOperatorTranslator { + + static boolean wantTranslate(Join o) { + return o.getHints().contains(JoinHints.broadcastHashJoin()) + && (o.getType() == Join.Type.LEFT || o.getType() == Join.Type.RIGHT) + && !(o.getWindowing() instanceof MergingWindowing); + } + + @Override + @SuppressWarnings("unchecked") + public DataSet translate(FlinkOperator operator, BatchExecutorContext context) { + final List inputs = (List) context.getInputStreams(operator); + + if (inputs.size() != 2) { + throw new IllegalStateException( + "Join should have two data sets on input, got " + inputs.size()); + } + final DataSet left = inputs.get(0); + final DataSet right = inputs.get(1); + final Join originalOperator = operator.getOriginalOperator(); + + final UnaryFunction leftKeyExtractor = originalOperator.getLeftKeyExtractor(); + final UnaryFunction rightKeyExtractor = originalOperator.getRightKeyExtractor(); + final Windowing windowing = + originalOperator.getWindowing() == null + ? AttachedWindowing.INSTANCE + : originalOperator.getWindowing(); + DataSet> leftExtracted = + left.flatMap(new KeyExtractor(leftKeyExtractor, windowing)) + .returns(new TypeHint>() { + }) + .name(operator.getName() + "::extract-left-key") + .setParallelism(operator.getParallelism()); + + DataSet> rightExtracted = + right.flatMap(new KeyExtractor(rightKeyExtractor, windowing)) + .returns(new TypeHint>() { + }) + .name(operator.getName() + "::extract-right-key") + .setParallelism(operator.getParallelism()); + + DataSet joined; + switch (originalOperator.getType()) { + case LEFT: + joined = leftExtracted + .leftOuterJoin(rightExtracted, JoinOperatorBase.JoinHint.BROADCAST_HASH_SECOND) + .where(new JoinKeySelector()) + .equalTo(new JoinKeySelector()) + .with(new BroadcastFlatJoinFunction(originalOperator.getJoiner())) + .returns(new TypeHint>() { + }) + .name(operator.getName() + "::left-join"); + + break; + case RIGHT: + joined = leftExtracted + .rightOuterJoin(rightExtracted, JoinOperatorBase.JoinHint.BROADCAST_HASH_FIRST) + .where(new JoinKeySelector()) + .equalTo(new JoinKeySelector()) + .with(new BroadcastFlatJoinFunction(originalOperator.getJoiner())) + .returns(new TypeHint>() { + }) + .name(operator.getName() + "::right-join"); + break; + default: + throw new IllegalStateException("Invalid type: " + originalOperator.getType() + "."); + } + return joined; + } + + private static class KeyExtractor + implements FlatMapFunction> { + + private final UnaryFunction keyExtractor; + private final Windowing windowing; + + KeyExtractor(UnaryFunction keyExtractor, Windowing windowing) { + this.keyExtractor = keyExtractor; + this.windowing = windowing; + } + + @Override + @SuppressWarnings("unchecked") + public void flatMap(BatchElement wel, Collector> coll) throws Exception { + Iterable assigned = windowing.assignWindowsToElement(wel); + for (Window wid : assigned) { + Object el = wel.getElement(); + long stamp = wid.maxTimestamp() - 1; + coll.collect(new BatchElement( + wid, stamp, Pair.of(keyExtractor.apply(el), el))); + } + } + } + + static class BroadcastFlatJoinFunction + implements FlatJoinFunction, BatchElement, + BatchElement> { + final BinaryFunctor joiner; + transient MultiValueContext multiValueContext; + + BroadcastFlatJoinFunction(BinaryFunctor joiner) { + this.joiner = joiner; + } + + @Override + @SuppressWarnings("unchecked") + public void join(BatchElement first, BatchElement second, + Collector> coll) throws Exception { + + if (multiValueContext == null) { + multiValueContext = new MultiValueContext<>(); + } + final Window window = first == null ? second.getWindow() : first.getWindow(); + + final long maxTimestamp = Math.max( + first == null ? window.maxTimestamp() - 1 : first.getTimestamp(), + second == null ? window.maxTimestamp() - 1 : second.getTimestamp()); + + Object firstEl = first == null ? null : first.getElement().getSecond(); + Object secondEl = second == null ? null : second.getElement().getSecond(); + + joiner.apply(firstEl, secondEl, multiValueContext); + + final Object key = first == null + ? second.getElement().getFirst() + : first.getElement().getFirst(); + List values = multiValueContext.getAndResetValue(); + values.forEach(val -> coll.collect(new BatchElement<>( + window, + maxTimestamp, + Pair.of(key, val)))); + } + } + + static class JoinKeySelector + implements KeySelector, KeyedWindow> { + + @Override + public KeyedWindow getKey(BatchElement value) { + return new KeyedWindow(value.getWindow(), value.getElement().getFirst()); + } + + } + + public static final class KeyedWindow implements Comparable { + private final W window; + private final K key; + + KeyedWindow(W window, K key) { + this.window = Objects.requireNonNull(window); + this.key = Objects.requireNonNull(key); + } + + public W window() { + return window; + } + + public K key() { + return key; + } + + @Override + public boolean equals(Object o) { + if (o instanceof KeyedWindow) { + final KeyedWindow other = (KeyedWindow) o; + return Objects.equals(window, other.window) && Objects.equals(key, other.key); + } + return false; + } + + @Override + public int hashCode() { + int result = window.hashCode(); + result = 31 * result + (key != null ? key.hashCode() : 0); + return result; + } + + @Override + public String toString() { + return "KeyedWindow{" + + "window=" + window + + ", key=" + key + + '}'; + } + + @Override + public int compareTo(KeyedWindow other) { + final int compareWindowResult = this.window.compareTo(other.window); + if (compareWindowResult == 0) { + if (Objects.equals(key, other.key)) { + return 0; + } else { + return 1; + } + } + return compareWindowResult; + } + } +} + + diff --git a/euphoria-operator-testkit/src/main/java/cz/seznam/euphoria/operator/test/AllOperatorsSuite.java b/euphoria-operator-testkit/src/main/java/cz/seznam/euphoria/operator/test/AllOperatorsSuite.java index d7bfc1c9..99965a4d 100644 --- a/euphoria-operator-testkit/src/main/java/cz/seznam/euphoria/operator/test/AllOperatorsSuite.java +++ b/euphoria-operator-testkit/src/main/java/cz/seznam/euphoria/operator/test/AllOperatorsSuite.java @@ -25,6 +25,7 @@ */ @RunWith(ExecutorProviderRunner.class) @Suite.SuiteClasses({ + BroadcastHashJoinTest.class, CountByKeyTest.class, DistinctTest.class, FilterTest.class, diff --git a/euphoria-spark/src/test/java/cz/seznam/euphoria/spark/BroadcastHashJoinTest.java b/euphoria-operator-testkit/src/main/java/cz/seznam/euphoria/operator/test/BroadcastHashJoinTest.java similarity index 70% rename from euphoria-spark/src/test/java/cz/seznam/euphoria/spark/BroadcastHashJoinTest.java rename to euphoria-operator-testkit/src/main/java/cz/seznam/euphoria/operator/test/BroadcastHashJoinTest.java index bdd392e2..96c82490 100644 --- a/euphoria-spark/src/test/java/cz/seznam/euphoria/spark/BroadcastHashJoinTest.java +++ b/euphoria-operator-testkit/src/main/java/cz/seznam/euphoria/operator/test/BroadcastHashJoinTest.java @@ -13,28 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package cz.seznam.euphoria.spark; +package cz.seznam.euphoria.operator.test; import cz.seznam.euphoria.core.client.dataset.Dataset; import cz.seznam.euphoria.core.client.io.Collector; +import cz.seznam.euphoria.core.client.operator.JoinHints; import cz.seznam.euphoria.core.client.operator.LeftJoin; import cz.seznam.euphoria.core.client.operator.RightJoin; import cz.seznam.euphoria.core.client.util.Pair; -import cz.seznam.euphoria.operator.test.JoinTest; import cz.seznam.euphoria.operator.test.junit.AbstractOperatorTest; -import cz.seznam.euphoria.operator.test.junit.ExecutorProviderRunner; import cz.seznam.euphoria.operator.test.junit.Processing; import cz.seznam.euphoria.shadow.com.google.common.collect.Sets; -import cz.seznam.euphoria.spark.testkit.SparkExecutorProvider; import org.junit.Test; -import org.junit.runner.RunWith; import java.util.Arrays; import java.util.List; import java.util.Optional; -@RunWith(ExecutorProviderRunner.class) -public class BroadcastHashJoinTest extends AbstractOperatorTest implements SparkExecutorProvider { +public class BroadcastHashJoinTest extends AbstractOperatorTest { @Processing(Processing.Type.BOUNDED) @Test @@ -116,4 +112,40 @@ public List> getUnorderedOutput() { } }); } + + @Processing(Processing.Type.BOUNDED) + @Test + public void keyHashCollisionBroadcastHashJoin() { + final String sameHashCodeKey1 = "FB"; + final String sameHashCodeKey2 = "Ea"; + execute(new JoinTest.JoinTestCase>() { + + @Override + protected Dataset> getOutput( + Dataset left, Dataset right) { + return LeftJoin.of(left, right) + .by(e -> e, e -> e % 2 == 0 ? sameHashCodeKey2 : sameHashCodeKey1) + .using((String l, Optional r, Collector c) -> + c.collect(l + "+" + r.orElse(null))) + .withHints(Sets.newHashSet(JoinHints.broadcastHashJoin())) + .output(); + } + + @Override + protected List getLeftInput() { + return Arrays.asList(sameHashCodeKey1, sameHashCodeKey2, "keyWithoutRightSide"); + } + + @Override + protected List getRightInput() { + return Arrays.asList(1, 2); + } + + @Override + public List> getUnorderedOutput() { + return Arrays.asList(Pair.of(sameHashCodeKey1, "FB+1"), Pair.of(sameHashCodeKey2, "Ea+2"), + Pair.of("keyWithoutRightSide", "keyWithoutRightSide+null")); + } + }); + } } diff --git a/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/BroadcastHashJoinTranslator.java b/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/BroadcastHashJoinTranslator.java index 8471605f..0121c230 100644 --- a/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/BroadcastHashJoinTranslator.java +++ b/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/BroadcastHashJoinTranslator.java @@ -23,6 +23,7 @@ import cz.seznam.euphoria.core.client.dataset.windowing.Windowing; import cz.seznam.euphoria.core.client.functional.UnaryFunction; import cz.seznam.euphoria.core.client.operator.Join; +import cz.seznam.euphoria.core.client.operator.JoinHints; import cz.seznam.euphoria.core.client.util.Either; import cz.seznam.euphoria.core.client.util.Pair; import org.apache.spark.api.java.JavaPairRDD;