...
...
...
...
...
JIRA: <TODO>
This method is easier to use, but it also limits some possible optimizations.
...
Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).
[This FLIP proposal is a joint work between Yun Gao, Dong Lin and Zhipeng Zhang]
...
It is important to note that users typically should not invoke IterationBody::process directly, because the model-variables expected by the iteration body are not the same as the initial-model-variables provided by the user. Instead, the model-variables are computed as the union of the feedback-model-variables (emitted by the iteration body) and the initial-model-variables (provided by the caller of the iteration body). To relieve users from creating this union operator, we have added a utility class (see Iterations) to run an iteration body with the user-provided inputs.
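As a conceptual illustration of this union semantics, the sketch below simulates it in plain Java (the class and method names here, `UnionSemanticsSketch` and `run`, are hypothetical and not part of the proposed API; the real framework wires the feedback edge and the union for you on unbounded streams):

```java
import java.util.ArrayList;
import java.util.List;

public class UnionSemanticsSketch {
    // A toy "iteration body": doubles each variable value and feeds it back
    // until the value reaches a bound. Round 0 consumes the initial variables;
    // every later round consumes the feedback of the previous round, so over
    // time the body sees the union of initial values and feedback values.
    public static List<Integer> run(List<Integer> initialVariables, int bound) {
        List<Integer> outputs = new ArrayList<>();
        List<Integer> current = new ArrayList<>(initialVariables);
        while (!current.isEmpty()) {
            List<Integer> feedback = new ArrayList<>();
            for (int v : current) {
                int updated = v * 2;           // the body's update step
                if (updated < bound) {
                    feedback.add(updated);     // re-enters the variable stream
                } else {
                    outputs.add(updated);      // leaves the loop as output
                }
            }
            current = feedback;                // next round sees the feedback
        }
        return outputs;
    }

    public static void main(String[] args) {
        System.out.println(run(List.of(1, 3), 10)); // prints [12, 16]
    }
}
```

The sketch only mirrors the dataflow shape; in the actual proposal the union is a streaming operator and rounds are tracked via epochs rather than an explicit loop.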
...
```java
package org.apache.flink.iteration;

/**
 * The callbacks defined below will be invoked only if the operator instance which implements this
 * interface is used within an iteration body.
 */
@PublicEvolving
public interface IterationListener<T> {
    /**
     * This callback is invoked every time the epoch watermark of this operator increments. The
     * initial epoch watermark of an operator is 0.
     *
     * <p>The epochWatermark is the maximum integer that meets this requirement: every record that
     * arrives at the operator going forward should have an epoch larger than the epochWatermark.
     * See the Java docs in Iterations for how the epoch is determined for records ingested into
     * the iteration body and for records emitted by operators within the iteration body.
     *
     * <p>If all inputs are bounded, the maximum epoch of all records ingested into this operator
     * is used as the epochWatermark parameter for the last invocation of this callback.
     *
     * @param epochWatermark The incremented epoch watermark.
     * @param context A context that allows emitting side output. The context is only valid during
     *     the invocation of this method.
     * @param collector The collector for returning result values.
     */
    void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector);

    /**
     * This callback is invoked after the execution of the iteration body has terminated.
     *
     * <p>See the Java doc of the methods in Iterations for the termination conditions.
     *
     * @param context A context that allows emitting side output. The context is only valid during
     *     the invocation of this method.
     * @param collector The collector for returning result values.
     */
    void onIterationTermination(Context context, Collector<T> collector);

    /** Information available in an invocation of the callbacks defined in the IterationListener. */
    interface Context {
        /**
         * Emits a record to the side output identified by the {@link OutputTag}.
         *
         * @param outputTag the {@code OutputTag} that identifies the side output to emit to.
         * @param value The record to emit.
         */
        <X> void output(OutputTag<X> outputTag, X value);
    }
}
```
...
```java
package org.apache.flink.iteration;

/**
 * A helper class to create iterations. To construct an iteration, users are required to provide
 *
 * <ul>
 *   <li>initVariableStreams: the initial values of the variable data streams, which are updated
 *       in each round.
 *   <li>dataStreams: the other data streams used inside the iteration, which are not updated.
 *   <li>iterationBody: specifies the subgraph that updates the variable streams and the outputs.
 * </ul>
 *
 * <p>The iteration body will be invoked with two parameters: the first parameter is a list of
 * input variable streams, which are created as the union of the initial variable streams and the
 * corresponding feedback variable streams (returned by the iteration body); the second parameter
 * is the data streams given to this method.
 *
 * <p>During the execution of the iteration body, each record involved in the iteration has an
 * epoch attached, which marks the progress of the iteration. The epoch is computed as follows:
 *
 * <ul>
 *   <li>All records in the initial variable streams and initial data streams have epoch = 0.
 *   <li>For any record emitted by an operator into a non-feedback stream, the epoch of the
 *       emitted record = the epoch of the input record that triggers this emission. If the record
 *       is emitted by onEpochWatermarkIncremented(), then its epoch = epochWatermark.
 *   <li>For any record emitted by an operator into a feedback variable stream, the epoch of the
 *       emitted record = the epoch of the input record that triggers this emission + 1.
 * </ul>
 *
 * <p>The framework notifies operators and UDFs that implement {@link IterationListener} at the
 * end of each epoch.
 *
 * <p>The limitations on constructing the subgraph inside the iteration body are described in
 * {@link IterationBody}.
 *
 * <p>An example of an iteration:
 *
 * <pre>{@code
 * DataStreamList result = Iterations.iterateUnboundedStreams(
 *     DataStreamList.of(first, second),
 *     DataStreamList.of(third),
 *     (variableStreams, dataStreams) -> {
 *         ...
 *         return new IterationBodyResult(
 *             DataStreamList.of(firstFeedback, secondFeedback),
 *             DataStreamList.of(output));
 *     });
 * result.<Integer>get(0).addSink(...);
 * }</pre>
 */
public class Iterations {

    /**
     * This method uses an iteration body to process records in possibly unbounded data streams.
     * The iteration will not terminate if at least one of its inputs is unbounded. Otherwise it
     * will terminate after all the inputs have terminated and no more records are iterating.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateUnboundedStreams(
            DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {
        return null;
    }

    /**
     * This method uses an iteration body to process records in some bounded data streams
     * iteratively until a termination criterion is reached (e.g. the given number of rounds is
     * completed or no further variable update is needed). Because this method does not replay
     * records in the data streams, the iteration body needs to cache those records in order to
     * visit them repeatedly.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}.
     * @param config The config of the iteration, e.g. whether to re-create the operators on each
     *     round.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateBoundedStreamsUntilTermination(
            DataStreamList initVariableStreams,
            ReplayableDataStreamList dataStreams,
            IterationConfig config,
            IterationBody body) {
        return null;
    }
}
```
5) Introduce the forEachRound utility method.
forEachRound allows the users to specify a sub-graph that executes in per-round mode, namely all the operators in the sub-graph are re-created for each round.
```java
public interface IterationBody {
    ....

    /**
     * Constructs a subgraph inside the iteration body whose operators all have a lifecycle of
     * {@link org.apache.flink.iteration.IterationConfig.OperatorLifeCycle#PER_ROUND}.
     */
    class PerRound {
        /**
         * @param inputs The inputs of the subgraph.
         * @param perRoundSubBody The computation logic to be executed per round.
         * @return The output of the subgraph.
         */
        public static DataStreamList forEachRound(
                DataStreamList inputs, PerRoundSubBody perRoundSubBody) {
            return null;
        }
    }

    /** The sub-graph inside the iteration body that should be executed per round. */
    interface PerRoundSubBody {
        DataStreamList process(DataStreamList input);
    }
}
```
6) Add the DataStreamList and ReplayableDataStreamList classes.
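The examples in this FLIP use these classes via `DataStreamList.of(...)`, a caller-typed `get(int)`, and `ReplayableDataStreamList.notReplay(...)`. As a rough, self-contained sketch of the shape `DataStreamList` takes (inferred from its usage here; `DataStreamListSketch` is a hypothetical stand-in, not the actual class, and the real one wraps Flink `DataStream` objects):

```java
import java.util.Arrays;
import java.util.List;

// A minimal stand-in for the proposed DataStreamList: an ordered, heterogeneous
// list whose get() lets the caller choose the expected element type, mirroring
// resultStreams.<double[]>get(0) in the examples of this FLIP.
public class DataStreamListSketch {
    private final List<Object> streams;

    private DataStreamListSketch(List<Object> streams) {
        this.streams = streams;
    }

    public static DataStreamListSketch of(Object... streams) {
        return new DataStreamListSketch(Arrays.asList(streams));
    }

    // The caller supplies the expected type; an unchecked cast performs the
    // narrowing, just like a typed get on a heterogeneous container.
    @SuppressWarnings("unchecked")
    public <T> T get(int index) {
        return (T) streams.get(index);
    }

    public int size() {
        return streams.size();
    }

    public static void main(String[] args) {
        DataStreamListSketch list = DataStreamListSketch.of("variables", 42);
        String first = list.get(0);
        System.out.println(first + " " + list.size()); // prints "variables 2"
    }
}
```

The typed `get` trades compile-time safety for convenience: a wrong type argument fails only at runtime with a ClassCastException, which is the same trade-off the proposed API makes.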
...
```java
public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
    private static final int N_EPOCH = 5;
    private static final int N_BATCH_PER_EPOCH = 10;
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>() {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams =
            Iterations.iterateBoundedStreamsUntilTermination(
                DataStreamList.of(initParameters),
                ReplayableDataStreamList.notReplay(dataset),
                IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(),
                (variableStreams, dataStreams) -> {
                    DataStream<double[]> parameterUpdates = variableStreams.get(0);
                    DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

                    SingleOutputStreamOperator<double[]> parameters =
                        parameterUpdates.process(new ParametersCacheFunction());
                    DataStream<double[]> modelUpdate =
                        parameters.setParallelism(1)
                            .broadcast()
                            .connect(dataset)
                            .coProcess(new TrainFunction())
                            .setParallelism(10)
                            .process(new ReduceFunction())
                            .setParallelism(1);
                    return new IterationBodyResult(
                        DataStreamList.of(modelUpdate),
                        DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
                });

        DataStream<double[]> finalModel = resultStreams.get(0);
        finalModel.print();
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
            implements IterationListener<double[]> {

        private final double[] parameters = new double[N_DIM];

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
        }

        public void onEpochWatermarkIncremented(
                int epochWatermark, Context context, Collector<double[]> collector) {
            if (epochWatermark < N_EPOCH * N_BATCH_PER_EPOCH) {
                collector.collect(parameters);
            }
        }

        public void onIterationTermination(Context context, Collector<double[]> collector) {
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
        }
    }

    public static class TrainFunction
            extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]>
            implements IterationListener<double[]> {

        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        private double[] firstRoundCachedParameter;
        private Supplier<int[]> recordRoundQuerier;

        public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
            this.recordRoundQuerier = querier;
        }

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            int[] round = recordRoundQuerier.get();
            if (round[0] == 0) {
                firstRoundCachedParameter = parameter;
                return;
            }
            calculateModelUpdate(parameter, output);
        }

        public void processElement2(
                Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            dataset.add(trainSample);
        }

        public void onEpochWatermarkIncremented(
                int epochWatermark, Context context, Collector<double[]> collector) {
            if (epochWatermark == 0) {
                calculateModelUpdate(firstRoundCachedParameter, collector);
                firstRoundCachedParameter = null;
            }
        }

        public void onIterationTermination(Context context, Collector<double[]> collector) {}

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                double diff = ArrayUtils.muladd(record.f0, parameters) - record.f1;
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }
            output.collect(modelUpdate);
        }
    }

    public static class ReduceFunction extends ProcessFunction<double[], double[]>
            implements IterationListener<double[]> {

        private double[] mergedValue = ArrayUtils.newArray(N_DIM);

        public void processElement(double[] parameter, Context context, Collector<double[]> output) {
            mergedValue = ArrayUtils.add(mergedValue, parameter);
        }

        public void onEpochWatermarkIncremented(
                int epochWatermark, Context context, Collector<double[]> collector) {
            collector.collect(mergedValue);
            mergedValue = ArrayUtils.newArray(N_DIM);
        }

        public void onIterationTermination(Context context, Collector<double[]> collector) {}
    }
}
```
...
```java
DataStreamList resultStreams =
    Iterations.iterateBoundedStreamsUntilTermination(
        DataStreamList.of(initParameters),
        ReplayableDataStreamList.notReplay(dataset),
        IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(),
        (variableStreams, dataStreams) -> {
            DataStream<double[]> parameterUpdates = variableStreams.get(0);
            DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

            SingleOutputStreamOperator<double[]> parameters =
                parameterUpdates.process(new ParametersCacheFunction());
            DataStream<double[]> modelUpdate =
                parameters.setParallelism(1)
                    .broadcast()
                    .connect(dataset)
                    .coProcess(new TrainFunction())
                    .setParallelism(10);
            DataStream<double[]> reduced =
                IterationBody.PerRound.forEachRound(
                        DataStreamList.of(modelUpdate),
                        streams ->
                            DataStreamList.of(
                                streams.<double[]>get(0)
                                    .windowAll()
                                    .reduce((x, y) -> ArrayUtils.add(x, y))))
                    .<double[]>get(0);
            return new IterationBodyResult(
                DataStreamList.of(reduced),
                DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
        });
```
...
```java
public static class ParametersCacheFunction
        extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>
        implements IterationListener<Tuple2<Integer, double[]>> {

    private final double[] parameters = new double[N_DIM];

    public void processElement(
            Tuple2<Integer, double[]> update,
            Context ctx,
            Collector<Tuple2<Integer, double[]>> output) {
        // Suppose we have a util to add the second array to the first.
        ArrayUtils.addWith(parameters, update.f1);
        output.collect(new Tuple2<>(update.f0, parameters));
    }

    public void onEpochWatermarkIncremented(
            int epochWatermark, Context context, Collector<Tuple2<Integer, double[]>> collector) {}

    public void onIterationTermination(
            Context context, Collector<Tuple2<Integer, double[]>> collector) {
        context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
    }
}

public class AsynchronousBoundedLinearRegression {
    ...

    DataStreamList resultStreams =
        Iterations.iterateBoundedStreamsUntilTermination(
            DataStreamList.of(initParameters),
            DataStreamList.of(dataset),
            (variableStreams, dataStreams) -> {
                DataStream<Tuple2<Integer, double[]>> parameterUpdates = variableStreams.get(0);
                DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

                SingleOutputStreamOperator<Tuple2<Integer, double[]>> parameters =
                    parameterUpdates.process(new ParametersCacheFunction());
                DataStream<double[]> modelUpdate =
                    parameters.setParallelism(1)
                        .partitionCustom(
                            (key, numPartitions) -> key % numPartitions, update -> update.f0)
                        .connect(dataset)
                        .coProcess(new TrainFunction())
                        .setParallelism(10);
                return new IterationBodyResult(
                    DataStreamList.of(modelUpdate),
                    DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
            });

    ...
}
```
...
```java
public class SynchronousUnboundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> MODEL_UPDATE_OUTPUT_TAG = new OutputTag<double[]>() {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams =
            Iterations.iterateUnboundedStreams(
                DataStreamList.of(initParameters),
                DataStreamList.of(dataset),
                (variableStreams, dataStreams) -> {
                    DataStream<double[]> parameterUpdates = variableStreams.get(0);
                    DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

                    SingleOutputStreamOperator<double[]> parameters =
                        parameterUpdates.process(new ParametersCacheFunction(10));
                    DataStream<double[]> modelUpdate =
                        parameters.setParallelism(1)
                            .broadcast()
                            .connect(dataset)
                            .transform(
                                "operator",
                                PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                                new TrainOperators(50))
                            .setParallelism(10);
                    return new IterationBodyResult(
                        DataStreamList.of(modelUpdate),
                        DataStreamList.of(parameters.getSideOutput(MODEL_UPDATE_OUTPUT_TAG)));
                });

        DataStream<double[]> finalModel = resultStreams.get(0);
        finalModel.addSink(...);
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> {
        private final int numOfTrainTasks;
        private int numOfUpdatesReceived = 0;
        private final double[] parameters = new double[N_DIM];

        public ParametersCacheFunction(int numOfTrainTasks) {
            this.numOfTrainTasks = numOfTrainTasks;
        }

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
            numOfUpdatesReceived++;
            if (numOfUpdatesReceived == numOfTrainTasks) {
                output.collect(parameters);
                numOfUpdatesReceived = 0;
            }
        }
    }

    public static class TrainOperators extends AbstractStreamOperator<double[]>
            implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>,
                    InputSelectable {

        private final int miniBatchSize;
        private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>();

        public TrainOperators(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
        }

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            calculateModelUpdate(parameter, output);
            miniBatch.clear();
        }

        public void processElement2(
                Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            miniBatch.add(trainSample);
        }

        public InputSelection nextSelection() {
            if (miniBatch.size() < miniBatchSize) {
                return InputSelection.SECOND;
            } else {
                return InputSelection.FIRST;
            }
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : miniBatch) {
                double diff = ArrayUtils.muladd(record.f0, parameters) - record.f1;
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }
            output.collect(modelUpdate);
        }
    }
}
```
...