Skip to content

Commit

Permalink
[sql-functions] add aggregation function(s) to collect the last non-n…
Browse files Browse the repository at this point in the history
…ull value
  • Loading branch information
NicoK committed Jan 4, 2022
1 parent 566c2be commit 10ee1f0
Show file tree
Hide file tree
Showing 3 changed files with 335 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package com.ververica.platform.sql.functions;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.ConstantArgumentCount;
import org.apache.flink.table.types.inference.InputTypeStrategy;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.inference.Signature.Argument;
import org.apache.flink.table.types.inference.TypeInference;

/** Aggregate function for collecting the latest non-null data of type T. */
@SuppressWarnings("unused")
public class LastNonNullValueAggFunction
extends AggregateFunction<Object, LastNonNullValueAggFunction.MyAccumulator> {

public static class MyAccumulator {
public Object acc = null;
}

@Override
public MyAccumulator createAccumulator() {
return new MyAccumulator();
}

public void accumulate(MyAccumulator acc, Object value) {
if (value != null) {
acc.acc = value;
}
}

public void retract(MyAccumulator acc, Object value) {
acc.acc = null;
}

@Override
public Object getValue(MyAccumulator acc) {
return acc.acc;
}

@Override
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
return TypeInference.newBuilder()
.inputTypeStrategy(
new InputTypeStrategy() {
@Override
public ArgumentCount getArgumentCount() {
return ConstantArgumentCount.of(1);
}

@Override
public Optional<List<DataType>> inferInputTypes(
CallContext callContext, boolean throwOnFailure) {
DataType argType = callContext.getArgumentDataTypes().get(0);
return Optional.of(Collections.singletonList(argType));
}

@Override
public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
return Collections.singletonList(Signature.of(Argument.of("value", "T")));
}
})
.accumulatorTypeStrategy(
callContext -> {
DataType argType = callContext.getArgumentDataTypes().get(0);
return Optional.of(
DataTypes.STRUCTURED(MyAccumulator.class, DataTypes.FIELD("acc", argType)));
})
.outputTypeStrategy(
callContext -> {
DataType argType = callContext.getArgumentDataTypes().get(0);
return Optional.of(argType);
})
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package com.ververica.platform.sql.functions;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.ConstantArgumentCount;
import org.apache.flink.table.types.inference.InputTypeStrategy;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.utils.DataTypeUtils;

/** Aggregate function for collecting the latest non-null data of type T (using internal types). */
@SuppressWarnings("unused")
public class LastNonNullValueAggFunction2
extends AggregateFunction<Object, LastNonNullValueAggFunction2.MyAccumulator> {

public static class MyAccumulator {
public Object acc = null;
}

@Override
public MyAccumulator createAccumulator() {
return new MyAccumulator();
}

public void accumulate(MyAccumulator acc, Object value) {
if (value != null) {
acc.acc = value;
}
}

public void retract(MyAccumulator acc, Object value) {
acc.acc = null;
}

@Override
public Object getValue(MyAccumulator acc) {
return acc.acc;
}

@Override
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
return TypeInference.newBuilder()
.inputTypeStrategy(
new InputTypeStrategy() {
@Override
public ArgumentCount getArgumentCount() {
return ConstantArgumentCount.of(1);
}

@Override
public Optional<List<DataType>> inferInputTypes(
CallContext callContext, boolean throwOnFailure) {
DataType argType = callContext.getArgumentDataTypes().get(0);
return Optional.of(
Collections.singletonList(DataTypeUtils.toInternalDataType(argType)));
}

@Override
public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
return Collections.singletonList(Signature.of(Signature.Argument.of("value", "T")));
}
})
.accumulatorTypeStrategy(
callContext -> {
DataType argType = callContext.getArgumentDataTypes().get(0);
DataType argTypeInternal = DataTypeUtils.toInternalDataType(argType);
return Optional.of(
DataTypes.STRUCTURED(
MyAccumulator.class, DataTypes.FIELD("acc", argTypeInternal)));
})
.outputTypeStrategy(
callContext -> {
DataType argType = callContext.getArgumentDataTypes().get(0);
DataType argTypeInternal = DataTypeUtils.toInternalDataType(argType);
return Optional.of(argTypeInternal);
})
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package com.ververica.platform.sql.functions;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;

import java.util.Arrays;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend;
import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.connector.ChangelogMode;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
* Integration test for {@link LastNonNullValueAggFunction} and {@link
* LastNonNullValueAggFunction2}.
*/
@RunWith(Parameterized.class)
public class LastNonNullValueAggFunctionITCase extends AbstractTableTestBase {

protected StreamExecutionEnvironment env;
protected StreamTableEnvironment tEnv;

@Parameterized.Parameter public Class<UserDefinedFunction> implementation;

@Parameterized.Parameters(name = "implementation = {0}")
public static Iterable<Class<? extends UserDefinedFunction>> parameters() {
return Arrays.asList(LastNonNullValueAggFunction.class, LastNonNullValueAggFunction2.class);
}

@Before
public void setUp() {
env = StreamExecutionEnvironment.getExecutionEnvironment();
// configure environment as needed
env.setParallelism(4);
env.getConfig().setRestartStrategy(RestartStrategies.noRestart());
env.setStateBackend(new EmbeddedRocksDBStateBackend());
env.getCheckpointConfig().setCheckpointStorage(new JobManagerCheckpointStorage());

// create table environment
tEnv =
StreamTableEnvironment.create(
env, EnvironmentSettings.newInstance().inStreamingMode().build());

tEnv.createTemporaryFunction("LastNonNullValueAggFunction", implementation);
}

private void createSource(
String name,
@Nullable TypeInformation<Row> dataStreamTypeInfo,
Schema schema,
Row... inputData) {
SingleOutputStreamOperator<Row> changelogStream = env.fromElements(inputData);
if (dataStreamTypeInfo != null) {
changelogStream = changelogStream.returns(dataStreamTypeInfo);
}
tEnv.createTemporaryView(
name, tEnv.fromChangelogStream(changelogStream, schema, ChangelogMode.insertOnly()));
}

private List<Row> getResult(String query) throws Exception {
Table resultTable = tEnv.sqlQuery(query);

DataStream<Row> resultStream = tEnv.toChangelogStream(resultTable);
return getRowsFromDataStream(resultStream);
}

private void createSourceForType(
String tableName, TypeInformation<?> aggJavaType, String aggSqlType, Row... inputData) {

Schema schema =
Schema.newBuilder()
.column("f0", "STRING NOT NULL")
.column("f1", aggSqlType)
.primaryKey("f0")
.build();
createSource(tableName, Types.ROW(Types.STRING, aggJavaType), schema, inputData);
}

private void createSourceForType(
TypeInformation<?> aggJavaType, String aggSqlType, Row... inputData) {
createSourceForType("input", aggJavaType, aggSqlType, inputData);
}

private List<Row> getResultSimpleAgg() throws Exception {
String query = "SELECT f0, LastNonNullValueAggFunction(f1) FROM input GROUP BY f0";
return getResult(query);
}

@Test
public void testInt() throws Exception {
createSourceForType(
Types.INT,
"INT",
Row.ofKind(RowKind.INSERT, "john", null),
Row.ofKind(RowKind.INSERT, "john", 1),
Row.ofKind(RowKind.INSERT, "john", 2),
Row.ofKind(RowKind.INSERT, "john", null));

assertThat(
getResultSimpleAgg(),
contains(
Row.ofKind(RowKind.INSERT, "john", null),
Row.ofKind(RowKind.UPDATE_BEFORE, "john", null),
Row.ofKind(RowKind.UPDATE_AFTER, "john", 1),
Row.ofKind(RowKind.UPDATE_BEFORE, "john", 1),
Row.ofKind(RowKind.UPDATE_AFTER, "john", 2)));
}

@Test
public void testString() throws Exception {
createSourceForType(
Types.STRING,
"STRING",
Row.ofKind(RowKind.INSERT, "john", null),
Row.ofKind(RowKind.INSERT, "john", "1"),
Row.ofKind(RowKind.INSERT, "john", "2"),
Row.ofKind(RowKind.INSERT, "john", null));

assertThat(
getResultSimpleAgg(),
contains(
Row.ofKind(RowKind.INSERT, "john", null),
Row.ofKind(RowKind.UPDATE_BEFORE, "john", null),
Row.ofKind(RowKind.UPDATE_AFTER, "john", "1"),
Row.ofKind(RowKind.UPDATE_BEFORE, "john", "1"),
Row.ofKind(RowKind.UPDATE_AFTER, "john", "2")));
}

@Test
public void testStringRetract() throws Exception {
createSourceForType(
Types.STRING,
"STRING",
Row.ofKind(RowKind.INSERT, "john", null),
Row.ofKind(RowKind.INSERT, "john", "1"),
Row.ofKind(RowKind.INSERT, "john", "2"),
Row.ofKind(RowKind.INSERT, "john", null));

assertThat(
getResultSimpleAgg(),
contains(
Row.ofKind(RowKind.INSERT, "john", null),
Row.ofKind(RowKind.UPDATE_BEFORE, "john", null),
Row.ofKind(RowKind.UPDATE_AFTER, "john", "1"),
Row.ofKind(RowKind.UPDATE_BEFORE, "john", "1"),
Row.ofKind(RowKind.UPDATE_AFTER, "john", "2")));
}
}

0 comments on commit 10ee1f0

Please sign in to comment.