THIS IS A TEST INSTANCE. ALL YOUR CHANGES WILL BE LOST!!!!
...
PerRoundSubGraphBuilder allows users to specify a sub-graph that executes in per-round mode; namely, all the operators in the sub-graph are re-created for each round.
Code Block | ||
---|---|---|
| ||
package org.apache.flink.iteration; import org.apache.flink.annotation.Experimental; /*public interface IterationBody { .... /** * Constructs a subgraph inside the iteration body tothat executeall asthe per-round */ @Experimental public class PerRoundSubGraphBuilder { operators would have a lifecycle /** The sub-graph inside the iteration body that should be executed as per-round.of {@link org.apache.flink.iteration.IterationConfig.OperatorLifeCycle#PER_ROUND}. */ publicclass interfacePerRound PerRoundSubGraph { /** DataStreamList process(DataStreamList input); } * @param inputs The inputs of the subgraph. * @param perRoundSubBody The computational logic that want to be executed as per-round. * @return The output of the subgraph. */ public static DataStreamList forEachRound( DataStreamList inputs, PerRoundSubGraphPerRoundSubBody subGraphperRoundSubBody) { return null; } } /** The sub-graph inside the iteration body that should be executed as per-round. */ interface PerRoundSubBody { DataStreamList process(DataStreamList input); } } |
6) Add the DataStreamList and ReplayableDataStreamList classes.
...
Code Block | ||||
---|---|---|---|---|
| ||||
package org.apache.flink.iteration; public class DataStreamList { public static DataStreamList of(DataStream<?>... streams); // Returns the number of data streams in this list. public int size() {...} // Returns the data stream at the given index in this list. public <T> DataStream<T> get(int index) {...} } public class ReplayableDataStreamList { public static ReplayableDataStreamList of( // Returns the number of data streams in this list. public int size() {...} // Returns the data stream at the given index in this list. public <T> DataStream<T> get(int index) {...} } public class ReplayableDataStreamList { private final List<DataStream<?>> replayedDataStreams; private final List<DataStream<?>> nonReplayedStreams; private ReplayableDataStreamList( List<DataStream<?>> replayedDataStreams, List<DataStream<?>> nonReplayedStreams) { this.replayedDataStreams = replayedDataStreams; this.nonReplayedStreams = nonReplayedStreams; } public static ReplayedDataStreamList replay(DataStream<?>... dataStreams) { return new ReplayedDataStreamList(Arrays.asList(dataStreams)); } public static NonReplayedDataStreamList notReplay(DataStream<?>... dataStreams) { return new NonReplayedDataStreamList(Arrays.asList(dataStreams)); } List<DataStream<?>> getReplayedDataStreams() { return Collections.unmodifiableList(replayedDataStreams); } List<DataStream<?>> getNonReplayedStreams() { return Collections.unmodifiableList(nonReplayedStreams); } private static class ReplayedDataStreamList extends ReplayableDataStreamList { public ReplayedDataStreamList(List<DataStream<?>> replayedDataStreams) { super(replayedDataStreams, Collections.emptyList()); } public ReplayableDataStreamList andNotReplay(DataStream<?>... 
nonReplayedStreams) { return new ReplayableDataStreamList( getReplayedDataStreams(), Arrays.asList(nonReplayedStreams)); } } private static class NonReplayedDataStreamList extends ReplayableDataStreamList { public NonReplayedDataStreamList(List<DataStream<?>> nonReplayedDataStreams) { Tuple2<DataStream<?>, Boolean>... dataStreamAndIfNeedReplays); super(Collections.emptyList(), nonReplayedDataStreams); public static Tuple2<DataStream<?>, Boolean> replay(DataStream<?> dataStream); } public static Tuple2<DataStream<?>, Boolean> noReplaypublic ReplayableDataStreamList andReplay(DataStream<?>... dataStreamreplayedStreams); { /** Returns the number of data streams in thisreturn list. */new ReplayableDataStreamList( public int size(); /** Returns the data stream at the given index in this listArrays. */ @SuppressWarnings("unchecked")asList(replayedStreams), getNonReplayedStreams()); public <T> DataStream<T> get(int index); } public boolean shouldReplay(int index);} } |
7) Introduce the IterationConfig class
...
Code Block | ||
---|---|---|
| ||
package org.apache.flink.iteration; /** The config for an iteration. */ public class IterationConfig { private final OperatorLifeCycle operatorRoundMode; publicprivate IterationConfig(OperatorLifeCycle operatorRoundMode) { this.operatorRoundMode = operatorRoundMode; } public static IterationConfigBuilder newBuilder() { return new IterationConfigBuilder(); } public static class IterationConfigBuilder { private OperatorLifeCycle operatorRoundMode = OperatorLifeCycle.ALL_ROUND; private IterationConfigBuilder() {} public IterationConfigBuilder setOperatorRoundMode(OperatorLifeCycle operatorRoundMode) { this.operatorRoundMode = operatorRoundMode; return this; } public IterationConfig build() { return new IterationConfig(operatorRoundMode); } } public enum OperatorLifeCycle { ALL_ROUND, PER_ROUND } } |
...
Code Block | ||||
---|---|---|---|---|
| ||||
public class SynchronousBoundedLinearRegression { private static final int N_DIM = 50; private static final int N_EPOCH = 5; private static final int N_BATCH_PER_EPOCH = 10; private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>{}; public static void main(String[] args) { DataStream<double[]> initParameters = loadParameters().setParallelism(1); DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1); DataStreamList resultStreams = IterationUtils.iterateBoundedStreamsUntilTermination( DataStreamList.of(initParameters), ReplayableDataStreamList.ofnotReplay(noReplay(dataset)), IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(); (variableStreams, dataStreams) -> { DataStream<double[]> parameterUpdates = variableStreams.get(0); DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0); SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction()); DataStream<double[]> modelUpdate = parameters.setParallelism(1) .broadcast() .connect(dataset) .coProcess(new TrainFunction()) .setParallelism(10) .process(new ReduceFunction()) .setParallelism(1) return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG))); }); DataStream<double[]> finalModel = resultStreams.get("final_model"); finalModel.print(); } public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> implements IterationListener<double[]> { private final double[] parameters = new double[N_DIM]; public void processElement(double[] update, Context ctx, Collector<O> output) { // Suppose we have a util to add the second array to the first. 
ArrayUtils.addWith(parameters, update); } void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) { if (epochWatermark < N_EPOCH * N_BATCH_PER_EPOCH) { collector.collect(parameters); } } public void onIterationEnd(int[] round, Context context) { context.output(FINAL_MODEL_OUTPUT_TAG, parameters); } } public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements IterationListener<double[]> { private final List<Tuple2<double[], Double>> dataset = new ArrayList<>(); private double[] firstRoundCachedParameter; private Supplier<int[]> recordRoundQuerier; public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) { this.recordRoundQuerier = querier; } public void processElement1(double[] parameter, Context context, Collector<O> output) { int[] round = recordRoundQuerier.get(); if (round[0] == 0) { firstRoundCachedParameter = parameter; return; } calculateModelUpdate(parameter, output); } public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<O> output) { dataset.add(trainSample) } void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) { if (epochWatermark == 0) { calculateModelUpdate(firstRoundCachedParameter, output); firstRoundCachedParameter = null; } } private void calculateModelUpdate(double[] parameters, Collector<O> output) { List<Tuple2<double[], Double>> samples = sample(dataset); double[] modelUpdate = new double[N_DIM]; for (Tuple2<double[], Double> record : samples) { double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1); ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff)); } output.collect(modelUpdate); } } public static class ReduceFunction { private double[] mergedValue = ArrayUtils.newArray(N_DIM); public void processElement(double[] parameter, Context context, Collector<O> output) { mergedValue = ArrayUtils.add(mergedValue, parameter); } void 
onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) { collector.collect(mergedValue); mergedValue = ArrayUtils.newArray(N_DIM); } } } |
...
Code Block | ||
---|---|---|
| ||
DataStreamList resultStreams = IterationUtils.iterateBoundedStreamsUntilTermination( DataStreamList.of(initParameters), ReplayableDataStreamList.ofnotReplay(noReplay(dataset)), IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(); (variableStreams, dataStreams) -> { DataStream<double[]> parameterUpdates = variableStreams.get(0); DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0); SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction()); DataStream<double[]> modelUpdate = parameters.setParallelism(1) .broadcast() .connect(dataset) .coProcess(new TrainFunction()) .setParallelism(10) DataStream<double[]> reduced = PerRoundGraphBuilder.forEachRound(DataStreamList.of(modelUpdate), streams -> { return streams.<double[]>get(0).windowAll().reduce((x, y) -> ArrayUtils.add(x, y)); }).<double[]>get(0); return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG))); }); |
...