forked from ververica/lab-flink-repository-analytics
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[sql-functions] add aggregation function(s) to collect the last non-n…
…ull value
- Loading branch information
Showing
3 changed files
with
335 additions
and
0 deletions.
There are no files selected for viewing
83 changes: 83 additions & 0 deletions
83
...tions/src/main/java/com/ververica/platform/sql/functions/LastNonNullValueAggFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package com.ververica.platform.sql.functions; | ||
|
||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Optional; | ||
import org.apache.flink.table.api.DataTypes; | ||
import org.apache.flink.table.catalog.DataTypeFactory; | ||
import org.apache.flink.table.functions.AggregateFunction; | ||
import org.apache.flink.table.functions.FunctionDefinition; | ||
import org.apache.flink.table.types.DataType; | ||
import org.apache.flink.table.types.inference.ArgumentCount; | ||
import org.apache.flink.table.types.inference.CallContext; | ||
import org.apache.flink.table.types.inference.ConstantArgumentCount; | ||
import org.apache.flink.table.types.inference.InputTypeStrategy; | ||
import org.apache.flink.table.types.inference.Signature; | ||
import org.apache.flink.table.types.inference.Signature.Argument; | ||
import org.apache.flink.table.types.inference.TypeInference; | ||
|
||
/** Aggregate function for collecting the latest non-null data of type T. */ | ||
@SuppressWarnings("unused") | ||
public class LastNonNullValueAggFunction | ||
extends AggregateFunction<Object, LastNonNullValueAggFunction.MyAccumulator> { | ||
|
||
public static class MyAccumulator { | ||
public Object acc = null; | ||
} | ||
|
||
@Override | ||
public MyAccumulator createAccumulator() { | ||
return new MyAccumulator(); | ||
} | ||
|
||
public void accumulate(MyAccumulator acc, Object value) { | ||
if (value != null) { | ||
acc.acc = value; | ||
} | ||
} | ||
|
||
public void retract(MyAccumulator acc, Object value) { | ||
acc.acc = null; | ||
} | ||
|
||
@Override | ||
public Object getValue(MyAccumulator acc) { | ||
return acc.acc; | ||
} | ||
|
||
@Override | ||
public TypeInference getTypeInference(DataTypeFactory typeFactory) { | ||
return TypeInference.newBuilder() | ||
.inputTypeStrategy( | ||
new InputTypeStrategy() { | ||
@Override | ||
public ArgumentCount getArgumentCount() { | ||
return ConstantArgumentCount.of(1); | ||
} | ||
|
||
@Override | ||
public Optional<List<DataType>> inferInputTypes( | ||
CallContext callContext, boolean throwOnFailure) { | ||
DataType argType = callContext.getArgumentDataTypes().get(0); | ||
return Optional.of(Collections.singletonList(argType)); | ||
} | ||
|
||
@Override | ||
public List<Signature> getExpectedSignatures(FunctionDefinition definition) { | ||
return Collections.singletonList(Signature.of(Argument.of("value", "T"))); | ||
} | ||
}) | ||
.accumulatorTypeStrategy( | ||
callContext -> { | ||
DataType argType = callContext.getArgumentDataTypes().get(0); | ||
return Optional.of( | ||
DataTypes.STRUCTURED(MyAccumulator.class, DataTypes.FIELD("acc", argType))); | ||
}) | ||
.outputTypeStrategy( | ||
callContext -> { | ||
DataType argType = callContext.getArgumentDataTypes().get(0); | ||
return Optional.of(argType); | ||
}) | ||
.build(); | ||
} | ||
} |
87 changes: 87 additions & 0 deletions
87
...ions/src/main/java/com/ververica/platform/sql/functions/LastNonNullValueAggFunction2.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package com.ververica.platform.sql.functions; | ||
|
||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Optional; | ||
import org.apache.flink.table.api.DataTypes; | ||
import org.apache.flink.table.catalog.DataTypeFactory; | ||
import org.apache.flink.table.functions.AggregateFunction; | ||
import org.apache.flink.table.functions.FunctionDefinition; | ||
import org.apache.flink.table.types.DataType; | ||
import org.apache.flink.table.types.inference.ArgumentCount; | ||
import org.apache.flink.table.types.inference.CallContext; | ||
import org.apache.flink.table.types.inference.ConstantArgumentCount; | ||
import org.apache.flink.table.types.inference.InputTypeStrategy; | ||
import org.apache.flink.table.types.inference.Signature; | ||
import org.apache.flink.table.types.inference.TypeInference; | ||
import org.apache.flink.table.types.utils.DataTypeUtils; | ||
|
||
/** Aggregate function for collecting the latest non-null data of type T (using internal types). */ | ||
@SuppressWarnings("unused") | ||
public class LastNonNullValueAggFunction2 | ||
extends AggregateFunction<Object, LastNonNullValueAggFunction2.MyAccumulator> { | ||
|
||
public static class MyAccumulator { | ||
public Object acc = null; | ||
} | ||
|
||
@Override | ||
public MyAccumulator createAccumulator() { | ||
return new MyAccumulator(); | ||
} | ||
|
||
public void accumulate(MyAccumulator acc, Object value) { | ||
if (value != null) { | ||
acc.acc = value; | ||
} | ||
} | ||
|
||
public void retract(MyAccumulator acc, Object value) { | ||
acc.acc = null; | ||
} | ||
|
||
@Override | ||
public Object getValue(MyAccumulator acc) { | ||
return acc.acc; | ||
} | ||
|
||
@Override | ||
public TypeInference getTypeInference(DataTypeFactory typeFactory) { | ||
return TypeInference.newBuilder() | ||
.inputTypeStrategy( | ||
new InputTypeStrategy() { | ||
@Override | ||
public ArgumentCount getArgumentCount() { | ||
return ConstantArgumentCount.of(1); | ||
} | ||
|
||
@Override | ||
public Optional<List<DataType>> inferInputTypes( | ||
CallContext callContext, boolean throwOnFailure) { | ||
DataType argType = callContext.getArgumentDataTypes().get(0); | ||
return Optional.of( | ||
Collections.singletonList(DataTypeUtils.toInternalDataType(argType))); | ||
} | ||
|
||
@Override | ||
public List<Signature> getExpectedSignatures(FunctionDefinition definition) { | ||
return Collections.singletonList(Signature.of(Signature.Argument.of("value", "T"))); | ||
} | ||
}) | ||
.accumulatorTypeStrategy( | ||
callContext -> { | ||
DataType argType = callContext.getArgumentDataTypes().get(0); | ||
DataType argTypeInternal = DataTypeUtils.toInternalDataType(argType); | ||
return Optional.of( | ||
DataTypes.STRUCTURED( | ||
MyAccumulator.class, DataTypes.FIELD("acc", argTypeInternal))); | ||
}) | ||
.outputTypeStrategy( | ||
callContext -> { | ||
DataType argType = callContext.getArgumentDataTypes().get(0); | ||
DataType argTypeInternal = DataTypeUtils.toInternalDataType(argType); | ||
return Optional.of(argTypeInternal); | ||
}) | ||
.build(); | ||
} | ||
} |
165 changes: 165 additions & 0 deletions
165
...src/test/java/com/ververica/platform/sql/functions/LastNonNullValueAggFunctionITCase.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
package com.ververica.platform.sql.functions; | ||
|
||
import static org.hamcrest.MatcherAssert.assertThat; | ||
import static org.hamcrest.Matchers.contains; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import javax.annotation.Nullable; | ||
import org.apache.flink.api.common.restartstrategy.RestartStrategies; | ||
import org.apache.flink.api.common.typeinfo.TypeInformation; | ||
import org.apache.flink.api.common.typeinfo.Types; | ||
import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend; | ||
import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage; | ||
import org.apache.flink.streaming.api.datastream.DataStream; | ||
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; | ||
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; | ||
import org.apache.flink.table.api.EnvironmentSettings; | ||
import org.apache.flink.table.api.Schema; | ||
import org.apache.flink.table.api.Table; | ||
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; | ||
import org.apache.flink.table.connector.ChangelogMode; | ||
import org.apache.flink.table.functions.UserDefinedFunction; | ||
import org.apache.flink.types.Row; | ||
import org.apache.flink.types.RowKind; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.junit.runner.RunWith; | ||
import org.junit.runners.Parameterized; | ||
|
||
/** | ||
* Integration test for {@link LastNonNullValueAggFunction} and {@link | ||
* LastNonNullValueAggFunction2}. | ||
*/ | ||
@RunWith(Parameterized.class) | ||
public class LastNonNullValueAggFunctionITCase extends AbstractTableTestBase { | ||
|
||
protected StreamExecutionEnvironment env; | ||
protected StreamTableEnvironment tEnv; | ||
|
||
@Parameterized.Parameter public Class<UserDefinedFunction> implementation; | ||
|
||
@Parameterized.Parameters(name = "implementation = {0}") | ||
public static Iterable<Class<? extends UserDefinedFunction>> parameters() { | ||
return Arrays.asList(LastNonNullValueAggFunction.class, LastNonNullValueAggFunction2.class); | ||
} | ||
|
||
@Before | ||
public void setUp() { | ||
env = StreamExecutionEnvironment.getExecutionEnvironment(); | ||
// configure environment as needed | ||
env.setParallelism(4); | ||
env.getConfig().setRestartStrategy(RestartStrategies.noRestart()); | ||
env.setStateBackend(new EmbeddedRocksDBStateBackend()); | ||
env.getCheckpointConfig().setCheckpointStorage(new JobManagerCheckpointStorage()); | ||
|
||
// create table environment | ||
tEnv = | ||
StreamTableEnvironment.create( | ||
env, EnvironmentSettings.newInstance().inStreamingMode().build()); | ||
|
||
tEnv.createTemporaryFunction("LastNonNullValueAggFunction", implementation); | ||
} | ||
|
||
private void createSource( | ||
String name, | ||
@Nullable TypeInformation<Row> dataStreamTypeInfo, | ||
Schema schema, | ||
Row... inputData) { | ||
SingleOutputStreamOperator<Row> changelogStream = env.fromElements(inputData); | ||
if (dataStreamTypeInfo != null) { | ||
changelogStream = changelogStream.returns(dataStreamTypeInfo); | ||
} | ||
tEnv.createTemporaryView( | ||
name, tEnv.fromChangelogStream(changelogStream, schema, ChangelogMode.insertOnly())); | ||
} | ||
|
||
private List<Row> getResult(String query) throws Exception { | ||
Table resultTable = tEnv.sqlQuery(query); | ||
|
||
DataStream<Row> resultStream = tEnv.toChangelogStream(resultTable); | ||
return getRowsFromDataStream(resultStream); | ||
} | ||
|
||
private void createSourceForType( | ||
String tableName, TypeInformation<?> aggJavaType, String aggSqlType, Row... inputData) { | ||
|
||
Schema schema = | ||
Schema.newBuilder() | ||
.column("f0", "STRING NOT NULL") | ||
.column("f1", aggSqlType) | ||
.primaryKey("f0") | ||
.build(); | ||
createSource(tableName, Types.ROW(Types.STRING, aggJavaType), schema, inputData); | ||
} | ||
|
||
private void createSourceForType( | ||
TypeInformation<?> aggJavaType, String aggSqlType, Row... inputData) { | ||
createSourceForType("input", aggJavaType, aggSqlType, inputData); | ||
} | ||
|
||
private List<Row> getResultSimpleAgg() throws Exception { | ||
String query = "SELECT f0, LastNonNullValueAggFunction(f1) FROM input GROUP BY f0"; | ||
return getResult(query); | ||
} | ||
|
||
@Test | ||
public void testInt() throws Exception { | ||
createSourceForType( | ||
Types.INT, | ||
"INT", | ||
Row.ofKind(RowKind.INSERT, "john", null), | ||
Row.ofKind(RowKind.INSERT, "john", 1), | ||
Row.ofKind(RowKind.INSERT, "john", 2), | ||
Row.ofKind(RowKind.INSERT, "john", null)); | ||
|
||
assertThat( | ||
getResultSimpleAgg(), | ||
contains( | ||
Row.ofKind(RowKind.INSERT, "john", null), | ||
Row.ofKind(RowKind.UPDATE_BEFORE, "john", null), | ||
Row.ofKind(RowKind.UPDATE_AFTER, "john", 1), | ||
Row.ofKind(RowKind.UPDATE_BEFORE, "john", 1), | ||
Row.ofKind(RowKind.UPDATE_AFTER, "john", 2))); | ||
} | ||
|
||
@Test | ||
public void testString() throws Exception { | ||
createSourceForType( | ||
Types.STRING, | ||
"STRING", | ||
Row.ofKind(RowKind.INSERT, "john", null), | ||
Row.ofKind(RowKind.INSERT, "john", "1"), | ||
Row.ofKind(RowKind.INSERT, "john", "2"), | ||
Row.ofKind(RowKind.INSERT, "john", null)); | ||
|
||
assertThat( | ||
getResultSimpleAgg(), | ||
contains( | ||
Row.ofKind(RowKind.INSERT, "john", null), | ||
Row.ofKind(RowKind.UPDATE_BEFORE, "john", null), | ||
Row.ofKind(RowKind.UPDATE_AFTER, "john", "1"), | ||
Row.ofKind(RowKind.UPDATE_BEFORE, "john", "1"), | ||
Row.ofKind(RowKind.UPDATE_AFTER, "john", "2"))); | ||
} | ||
|
||
@Test | ||
public void testStringRetract() throws Exception { | ||
createSourceForType( | ||
Types.STRING, | ||
"STRING", | ||
Row.ofKind(RowKind.INSERT, "john", null), | ||
Row.ofKind(RowKind.INSERT, "john", "1"), | ||
Row.ofKind(RowKind.INSERT, "john", "2"), | ||
Row.ofKind(RowKind.INSERT, "john", null)); | ||
|
||
assertThat( | ||
getResultSimpleAgg(), | ||
contains( | ||
Row.ofKind(RowKind.INSERT, "john", null), | ||
Row.ofKind(RowKind.UPDATE_BEFORE, "john", null), | ||
Row.ofKind(RowKind.UPDATE_AFTER, "john", "1"), | ||
Row.ofKind(RowKind.UPDATE_BEFORE, "john", "1"), | ||
Row.ofKind(RowKind.UPDATE_AFTER, "john", "2"))); | ||
} | ||
} |