...

Code Block
languagejava
linenumberstrue
package org.apache.flink.ml.iteration;

/**
 * A helper class to apply {@link IterationBody} to data streams.
 */
@PublicEvolving
public class IterationUtils {
    /**
     * This method can use an iteration body to process records in unbounded data streams.
     *
     * This method invokes the iteration body with the following parameters:
     * 1) The 1st parameter is a list of input variable streams, which are created as the union of the initial variable
     * streams and the corresponding feedback variable streams (returned by the iteration body).
     * 2) The 2nd parameter is the data streams given to this method.
     *
     * The epoch values are determined as described below. See IterationListener for how the epoch values are used.
     * 1) All records in the initial variable streams have epoch=1.
     * 2) All records in the data streams have epoch=MAX_LONG. In this case, records in the data stream won't affect
     * any operator's epoch watermark.
     * 3) For any record emitted by this operator into a non-feedback stream, the epoch of this emitted record = the
     * epoch of the input record that triggers this emission. If this record is emitted by
     * onEpochWatermarkIncremented(), then the epoch of this record = epochWatermark.
     * 4) For any record emitted by this operator into a feedback variable stream, the epoch of the emitted record =
     * min(the epoch of the input record that triggers this emission, MAX_LONG - 1) + 1. If this record is emitted by
     * onEpochWatermarkIncremented(), then the epoch of this record = epochWatermark + 1.
     *
     * The execution of the graph created by the iteration body will not terminate by itself. This is because at least
     * one of its data streams is unbounded.
     *
     * Required:
     * 1) All the init variable streams must be bounded.
     * 2) There is at least one unbounded stream in the data streams list.
     * 3) The parallelism of any stream in the initial variable streams must equal the parallelism of the stream at the
     * same index of the feedback variable streams returned by the IterationBody.
     *
     * @param initVariableStreams The initial variable streams. These streams will be merged with the feedback variable
     *                            streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The data streams. These streams will be used as the 2nd parameter to invoke the iteration
     *                    body.
     * @param body The computation logic which takes variable/data streams and returns variable/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    static DataStreamList iterateUnboundedStreams(DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {...}

    /**
     * This method can use an iteration body to process records in some bounded data streams iteratively until a
     * termination criteria is reached (e.g. the given number of rounds is completed or no further variable update is
     * needed). Because this method does not replay records in the data streams, the iteration body needs to cache those
     * records in order to visit those records repeatedly.
     *
     * This method invokes the iteration body with the following parameters:
     * 1) The 1st parameter is a list of input variable streams, which are created as the union of the initial variable
     * streams and the corresponding feedback variable streams (returned by the iteration body).
     * 2) The 2nd parameter is the data streams given to this method.
     *
     * The epoch values are determined as described below. See IterationListener for how the epoch values are used.
     * 1) All records in the initial variable streams have epoch=1.
     * 2) All records in the data streams have epoch=1.
     * 3) For any record emitted by this operator into a non-feedback stream, the epoch of this emitted record = the
     * epoch of the input record that triggers this emission. If this record is emitted by
     * onEpochWatermarkIncremented(), then the epoch of this record = epochWatermark.
     * 4) For any record emitted by this operator into a feedback variable stream, the epoch of the emitted record = the
     * epoch of the input record that triggers this emission + 1. If this record is emitted by
     * onEpochWatermarkIncremented(), then the epoch of this record = epochWatermark + 1.
     *
     * Suppose there is a coordinator operator which takes all feedback variable streams (emitted by the iteration body)
     * and the termination criteria stream (if not null) as inputs. The execution of the graph created by the
     * iteration body will terminate when all input streams have been fully consumed AND any of the following conditions
     * is met:
     * 1) The termination criteria stream is not null. And the coordinator operator has not observed any new value from
     * the termination criteria stream between two consecutive onEpochWatermarkIncremented invocations.
     * 2) The coordinator operator has not observed any new value from any feedback variable stream between two
     * consecutive onEpochWatermarkIncremented invocations.
     *
     * Required:
     * 1) All the init variable streams and the data streams must be bounded.
     * 2) The parallelism of any stream in the initial variable streams must equal the parallelism of the stream at the
     * same index of the feedback variable streams returned by the IterationBody.
     *
     * @param initVariableStreams The initial variable streams. These streams will be merged with the feedback variable
     *                            streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The data streams. These streams will be used as the 2nd parameter to invoke the iteration
     *                    body.
     * @param body The computation logic which takes variable/data streams and returns variable/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    static DataStreamList iterateBoundedStreamsUntilTermination(DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {...}

    /**
     * This method can use an iteration body to process records in some bounded data streams iteratively until a
     * termination criteria is reached (e.g. the given number of rounds is completed or no further variable update is
     * needed). Because this method replays records in the data streams, the iteration body does not need to cache those
     * records to visit those records repeatedly.
     *
     * This method invokes the iteration body with the following parameters:
     * 1) The 1st parameter is a list of input variable streams, which are created as the union of the initial variable
     * streams and the corresponding feedback variable streams (returned by the iteration body).
     * 2) The 2nd parameter is a list of replayed data streams, which are created by replaying the initial data streams
     * round by round until the iteration terminates. The records in the Nth round will be emitted into the iteration
     * body only if the low watermark of the first operator in the iteration body >= N - 1.
     *
     * The epoch values are determined as described below. See IterationListener for how the epoch values are used.
     * 1) All records in the initial variable streams have epoch=1.
     * 2) The records from the initial data streams will be replayed round by round into the iteration body. The records
     * in the first round have epoch=1. And records in the Nth round have epoch = N.
     * 3) For any record emitted by this operator into a non-feedback stream, the epoch of this emitted record = the
     * epoch of the input record that triggers this emission. If this record is emitted by
     * onEpochWatermarkIncremented(), then the epoch of this record = epochWatermark.
     * 4) For any record emitted by this operator into a feedback stream, the epoch of the emitted record = the epoch
     * of the input record that triggers this emission + 1. If this record is emitted by onEpochWatermarkIncremented(),
     * then the epoch of this record = epochWatermark + 1.
     *
     * Suppose there is a coordinator operator which takes all feedback variable streams (emitted by the iteration body)
     * and the termination criteria stream (if not null) as inputs. The execution of the graph created by the
     * iteration body will terminate when all input streams have been fully consumed AND any of the following conditions
     * is met:
     * 1) The termination criteria stream is not null. And the coordinator operator has not observed any new value from
     * the termination criteria stream between two consecutive onEpochWatermarkIncremented invocations.
     * 2) The coordinator operator has not observed any new value from any feedback variable stream between two
     * consecutive onEpochWatermarkIncremented invocations.
     *
     * Required:
     * 1) All the init variable streams and the data streams must be bounded.
     * 2) The parallelism of any stream in the initial variable streams must equal the parallelism of the stream at the
     * same index of the feedback variable streams returned by the IterationBody.
     *
     * @param initVariableStreams The initial variable streams. These streams will be merged with the feedback variable
     *                            streams before being used as the 1st parameter to invoke the iteration body.
     * @param initDataStreams The initial data streams. Records from these streams will be repeatedly replayed and used
     *                        as the 2nd parameter to invoke the iteration body.
     * @param body The computation logic which takes variable/data streams and returns variable/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    static DataStreamList iterateAndReplayBoundedStreamsUntilTermination(DataStreamList initVariableStreams, DataStreamList initDataStreams, IterationBody body) {...}
}
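
For illustration, the following is a minimal usage sketch of iterateBoundedStreamsUntilTermination(...). The DataStreamList.of(...) factory, the lambda form of IterationBody, the IterationBodyResult holder and the env variable are assumptions made for brevity, not finalized API signatures; the termination follows the rule above that the iteration stops once no new feedback record is observed between two consecutive epoch watermarks.

Code Block
languagejava
// Hypothetical usage sketch: iterate a counter until 10 epochs have passed.
// Once the feedback stream stays empty for a whole epoch, the iteration
// terminates per the termination rule in the Java doc above.
DataStream<Integer> initCounter = env.fromElements(0);
DataStream<Double> dataset = env.fromElements(1.0, 2.0, 3.0);

DataStreamList outputs = IterationUtils.iterateBoundedStreamsUntilTermination(
    DataStreamList.of(initCounter),
    DataStreamList.of(dataset),
    (variableStreams, dataStreams) -> {
        DataStream<Integer> counter = variableStreams.get(0);
        DataStream<Integer> feedback = counter.flatMap(
            (Integer c, Collector<Integer> out) -> {
                if (c < 10) {
                    out.collect(c + 1); // keep feeding back for 10 epochs
                }
                // Emitting nothing ends the iteration: no new feedback record
                // is observed between two consecutive epoch watermarks.
            });
        return new IterationBodyResult(
            DataStreamList.of(feedback), // feedback variable streams
            DataStreamList.of(counter)); // output streams
    });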

...

We plan to remove both methods after the APIs added in this doc are ready for production use. This change is needed to decouple the iteration-related APIs from the Flink core runtime so that we can keep the Flink core runtime as simple and maintainable as possible.

Proposed Changes


In this section, we discuss a few design choices related to the implementation and usage of the proposed APIs.

1) How the termination of the iteration execution is determined.

We will add a coordinator operator which takes all feedback variable streams (emitted by the iteration body) and the termination criteria stream (if not null) as inputs. The execution of the graph created by the iteration body will terminate when all input streams have been fully consumed AND any of the following conditions is met:

  • The termination criteria stream is not null. And the coordinator operator has not observed any new value from the termination criteria stream between two consecutive onEpochWatermarkIncremented invocations.
  • The coordinator operator has not observed any new value from any feedback variable stream between two consecutive onEpochWatermarkIncremented invocations.
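
To make condition 1 concrete, below is a hypothetical sketch of an operator feeding the termination criteria stream. The lossStream variable and the convergence threshold are illustrative assumptions, not part of the proposal; the point is that the operator goes silent once the loss stops improving, so the coordinator observes no new value between two consecutive onEpochWatermarkIncremented invocations and terminates the iteration.

Code Block
languagejava
// Hypothetical convergence check feeding the termination criteria stream.
// While the loss keeps improving, one record is emitted per epoch; once the
// improvement falls below the threshold, the stream goes silent and the
// coordinator terminates the iteration (condition 1 above).
DataStream<Double> criteria = lossStream.flatMap(new RichFlatMapFunction<Double, Double>() {
    private double lastLoss = Double.MAX_VALUE;

    @Override
    public void flatMap(Double loss, Collector<Double> out) {
        if (lastLoss - loss > 1e-6) {
            out.collect(loss); // still improving: keep the iteration alive
        }
        lastLoss = loss;
    }
});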

2) The execution mode that is required to execute the iteration body.

  • If all input streams are bounded, then the iteration body can be executed in either the stream mode or the batch mode.
  • If any input stream is unbounded, then the iteration body must be executed in the stream mode.
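
For reference, the execution mode of a DataStream job can be chosen explicitly with the standard DataStream API (this snippet is independent of the APIs proposed in this FLIP):

Code Block
languagejava
// Standard DataStream API: select the execution mode explicitly. With bounded
// inputs either mode can execute the iteration body; with unbounded inputs
// the job must run in streaming mode.
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setRuntimeMode(RuntimeExecutionMode.STREAMING); // or RuntimeExecutionMode.BATCH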


3) The edge type that should be used inside the iteration body.

All edges inside the iteration body are required to have the PIPELINE type.

If the user-defined iteration body contains an edge that does not have the PIPELINE type, methods that create the subgraph from the iteration body, such as iterateBoundedStreamsUntilTermination, will throw an exception upon invocation.


4) How the feedback edge is supported.

The Flink core runtime supports only a DAG of operators. Thus it does not provide native support for feedback edges, since feedback edges introduce cycles in the operator graph.

Same as the implementation of the DataSet::iterate() API, the proposed APIs are implemented with the following approach:

  • Automatically insert the HEAD and the TAIL operators as the first and the last operators in the iteration body.
  • Co-locate the HEAD and the TAIL operators on the same task manager.
  • Have the HEAD and the TAIL operators transmit the records of the feedback edges using an in-memory queue.
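
To illustrate the last point, the following is a simplified sketch (not the actual runtime code) of how the co-located HEAD and TAIL operators could exchange feedback records through an in-memory queue; the broker class and its keying scheme are assumptions made for illustration.

Code Block
languagejava
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;

// Because HEAD and TAIL run on the same task manager (i.e. in the same JVM),
// they can share a queue registered in a static broker that is keyed by the
// iteration id and the subtask index.
public class FeedbackBroker {
    private static final ConcurrentHashMap<String, BlockingQueue<Object>> QUEUES =
        new ConcurrentHashMap<>();

    public static BlockingQueue<Object> getOrCreate(String iterationId, int subtaskIndex) {
        return QUEUES.computeIfAbsent(
            iterationId + "#" + subtaskIndex, key -> new LinkedBlockingQueue<>());
    }
}

// TAIL side: queue.put(record) for each record arriving on the feedback edge.
// HEAD side: queue.poll(...) to re-inject the records into the iteration body.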

...

5) Lifetime of the operators inside the iteration body.

With the approach proposed in this FLIP, the operators inside the iteration body are only created once and destroyed after the whole iteration terminates.

In comparison, the existing DataSet::iterate(..) would destroy and re-create the iteration body (together with all states inside it) once for each round of iteration, which in general could introduce more runtime overhead than the approach adopted in this FLIP.


6) How an iteration can resume from the most recently completed epoch after failover.

For any job that is executed in the batch mode, this FLIP does not support failover from the middle of an iteration, i.e. the job can not start from a recent epoch after failover. In other words, if an iterative job fails, it will restart from the very first epoch of this iteration. Note that the existing DataSet::iterate(...) has the same behavior after job failover.

For any job that is executed in the stream mode, this FLIP supports failover from the middle of an iteration, i.e. if an iterative job fails, it can be restarted from the latest epoch that was completed before the job failed.

This is achieved by re-using the existing checkpoint mechanism (which is only available in the stream mode) with the following extra work: the runtime recognizes the records buffered on the feedback edges and includes these records in the checkpoint.
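
The following sketch illustrates this extra work using the public CheckpointedFunction interface rather than the actual internal runtime hooks; the pass-through function standing in for the TAIL operator and its buffering strategy are assumptions made for illustration.

Code Block
languagejava
import java.util.ArrayDeque;
import java.util.Deque;

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

// Illustrative sketch only: the records still buffered on the feedback edge
// are stored in operator state, so a restored job resumes from the epoch that
// completed before the failure. (When records are drained from the buffer is
// omitted here.)
public class FeedbackBufferFunction<T> extends RichMapFunction<T, T> implements CheckpointedFunction {

    private final TypeInformation<T> typeInfo;

    private transient Deque<T> buffered;
    private transient ListState<T> checkpointedBuffer;

    public FeedbackBufferFunction(TypeInformation<T> typeInfo) {
        this.typeInfo = typeInfo;
    }

    @Override
    public T map(T value) {
        buffered.addLast(value); // remember the record before it crosses the feedback edge
        return value;
    }

    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        // The extra work described above: include the buffered feedback
        // records in the checkpoint.
        checkpointedBuffer.clear();
        for (T record : buffered) {
            checkpointedBuffer.add(record);
        }
    }

    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        checkpointedBuffer = context.getOperatorStateStore()
            .getListState(new ListStateDescriptor<>("feedback-buffer", typeInfo));

        buffered = new ArrayDeque<>();
        if (context.isRestored()) {
            for (T record : checkpointedBuffer.get()) {
                buffered.addLast(record); // restore the feedback records after failover
            }
        }
    }
}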


7) Support for synchronous iteration.

Definition of sync-mode

An iterative algorithm is run in sync-mode if there exists a global epoch, such that at the time a given operator computes its output for the Nth epoch, this operator has received exactly the following records from its input edges:

...

  • When any operator within the IterationBody receives values from its input edges, this operator does not immediately emit records to its output.
  • Operators inside the IterationBody only compute and emit records to their outputs in the onEpochWatermarkIncremented(...) callback. The emitted records should be computed based on the values received from the input edges up to the invocation of this callback.
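
As a minimal sketch of these two rules, consider an operator that sums the values received within each epoch. The IterationListener interface and the exact signature of onEpochWatermarkIncremented(...) are assumed here based on the description above:

Code Block
languagejava
// Sketch of an operator obeying the two rules above: processElement(...) only
// accumulates, and records are emitted exclusively from
// onEpochWatermarkIncremented(...), so each epoch produces exactly one output.
public static class PerEpochSumFunction extends ProcessFunction<Double, Double>
        implements IterationListener<Double> {

    private double sum = 0;

    public void processElement(Double value, Context ctx, Collector<Double> out) {
        sum += value; // rule 1: never emit here
    }

    public void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<Double> collector) {
        collector.collect(sum); // rule 2: emit exactly once per epoch
        sum = 0;
    }
}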

See the Appendix section for a proof of why the solution described above could achieve the sync-mode execution as defined above. Note that the calculation of the record's epoch and the semantics of onEpochWatermarkIncremented(...) are described in the Java doc of the corresponding APIs.




8) How to run an iterative algorithm without dumping all user-provided data streams to disk.

As mentioned in the motivation section, the existing DataSet::iterate() always dumps the user-provided data streams to disk so that it can replay the data streams regardless of their size. Since this is the only available API to do iteration on bounded data streams, there is no way for algorithm developers to avoid this performance overhead.

In comparison, the iterateBoundedStreamsUntilTermination(...) method proposed in this FLIP allows users to run an iteration body without incurring this disk performance overhead. Developers have the freedom to optimize the performance based on their algorithm and data size, e.g. by caching data in memory in a more compact format.

Example Usages

Offline Training with Bounded Iteration

We would like to first show the usage of the bounded iteration with a linear regression case: the model is Y = XA, and we would like to acquire the best estimation of A with the SGD algorithm. To simplify, we assume the parameters can be held in the memory of a single task.

The job graph of the algorithm is shown in Figure 3: in each round, we use the latest parameters to calculate the update to the parameters: ΔA = ∑(Y - XA)X. To achieve this, the Parameters vertex broadcasts the latest parameters to the Train vertex. Each subtask of the Train vertex holds a part of the dataset. Following the spirit of SGD, it samples a small batch of training records and calculates the update with the above equation. The Train vertex then emits ΔA to the Parameters vertex to update the parameters.

We will start with synchronous training. Synchronous training requires that the updates from all the Train vertex subtasks are merged before the next round of training. This can be done by emitting the next round of parameters only at the end of a round. The code is shown as follows:

Code Block
languagejava
linenumberstrue
public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>("final_model") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        int batch = 5;
        int epochEachBatch = 10;

        ResultStreams resultStreams = new BoundedIteration()
            .withBody(new IterationBody(
                @IterationInput("model") DataStream<double[]> model,
                @IterationInput("dataset") DataStream<Tuple2<double[], Double>> dataset
            ) {
                SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction());
                DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                    .broadcast()
                    .connect(dataset)
                    .coProcess(new TrainFunction())
                    .setParallelism(10);

                return new BoundedIterationDeclarationBuilder()
                    .withFeedback("model", modelUpdate)
                    .withOutput("final_model", parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG))
                    .until(new TerminationCondition(null, context -> context.getRound() >= batch * epochEachBatch))
                    .build();
            })
            .build();
        
        DataStream<double[]> finalModel = resultStreams.get("final_model");
        finalModel.print();
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
        implements BoundedIterationProgressListener<double[]> {  
        
        private final double[] parameters = new double[N_DIM];

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
        }

        public void onRoundEnd(int[] round, Context context, Collector<double[]> collector) {
            collector.collect(parameters);
        }

        public void onIterationEnd(int[] round, Context context) {
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
        }
    }

    public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements BoundedIterationProgressListener<double[]> {

        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        private Supplier<int[]> recordRoundQuerier;

        public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
            this.recordRoundQuerier = querier;
        }

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            int[] round = recordRoundQuerier.get();
            if (round[0] == 0) {
                firstRoundCachedParameter = parameter;
                return;
            }

            calculateModelUpdate(parameter, output);
        }

        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            dataset.add(trainSample);
        }

        public void onRoundEnd(int[] round, Context context, Collector<double[]> output) {
            if (round[0] == 0) {
                calculateModelUpdate(firstRoundCachedParameter, output);
                firstRoundCachedParameter = null;
            }
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }
}

If instead we want to do asynchronous training, we would need to make the following changes:

  1. The Parameters vertex would not wait until the end of a round to ensure it has received all the updates from the iteration. Instead, it would immediately output the current parameter values once it receives a model update from one Train subtask.
  2. To label the source of the update, we change the input type to Tuple2<Integer, double[]>. The Parameters vertex would only output the new parameter values to the Train task that sent the update.

We omit the change to the graph building code since the change is trivial (change the output type and the partitioner to a customized one). The change to the Parameters vertex is as follows:

Code Block
languagejava
public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>
    implements BoundedIterationProgressListener<double[]> {

    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        // Suppose we have a util to add the second array to the first.
        ArrayUtils.addWith(parameters, update.f1);
        output.collect(new Tuple2<>(update.f0, parameters));
    }

    public void onIterationEnd(int[] round, Context context) {
        context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
    }
}

Online Training with Unbounded Iteration

Suppose now we want to change the algorithm to use unbounded iteration. Compared to the offline case, the differences are that:

  1. The dataset is unbounded. The Train operator can not cache all the data in the first round.
  2. The training algorithm might be changed to another one like FTRL. We keep using SGD in this example since the choice of algorithm does not affect how the iteration is used.

We again start with the synchronous case. For online training, the Train vertex usually performs one update after accumulating one mini-batch. This ensures that the distribution of the samples is similar to the global statistics. In this example we omit the complex data re-sampling process and simply fetch the next several records as one mini-batch.

The JobGraph for online training is still the one shown in Figure 1, with the training dataset becoming unbounded. Similar to the bounded case, the synchronous training process is expected to work as follows:

  1. The Parameters vertex broadcasts the initialized values upon receiving the input values.
  2. Each Train task reads the next mini-batch of records, calculates an update and emits it to the Parameters vertex. It then waits until it has received the updated parameters from the Parameters vertex before processing the next mini-batch.
  3. The Parameters vertex waits until it has received the updates from all the Train tasks before it broadcasts the updated parameters.

Since in the unbounded case there is no concept of round, and we update per mini-batch, we can instead use the InputSelectable functionality to implement the algorithm:


Code Block
languagejava
public class SynchronousUnboundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> MODEL_UPDATE_OUTPUT_TAG = new OutputTag<double[]>("model_update") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        ResultStreams resultStreams = new UnboundedIteration()
            .withBody(new IterationBody(
                @IterationInput("model") DataStream<double[]> model,
                @IterationInput("dataset") DataStream<Tuple2<double[], Double>> dataset
            ) {
                SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction(10));
                DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                    .broadcast()
                    .connect(dataset)
                    .transform(
                        "operator",
                        BasicTypeInfo.INT_TYPE_INFO,
                        new TrainOperators(50))
                    .setParallelism(10);

                return new UnBoundedIterationDeclarationBuilder()
                    .withFeedback("model", modelUpdate)
                    .withOutput("model_update", parameters.getSideOut(MODEL_UPDATE_OUTPUT_TAG))
                    .build();
            })
            .build();

        DataStream<double[]> finalModel = resultStreams.get("model_update");
        finalModel.addSink(...);
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> {

        private final int numOfTrainTasks;

        private int numOfUpdatesReceived = 0;
        private final double[] parameters = new double[N_DIM];

        public ParametersCacheFunction(int numOfTrainTasks) {
            this.numOfTrainTasks = numOfTrainTasks;
        }

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
            numOfUpdatesReceived++;

            if (numOfUpdatesReceived == numOfTrainTasks) {
                output.collect(parameters);
                numOfUpdatesReceived = 0;
            }
        }
    }

    public static class TrainOperators extends AbstractStreamOperator<double[]> implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>, InputSelectable {

        private final int miniBatchSize;

        private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>();

        public TrainOperators(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
        }

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            calculateModelUpdate(parameter, output);
            miniBatch.clear();
        }

        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            miniBatch.add(trainSample);
        }

        public InputSelection nextSelection() {
            if (miniBatch.size() < miniBatchSize) {
                return InputSelection.SECOND;
            } else {
                return InputSelection.FIRST;
            }
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : miniBatch) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }
}


Also similar to the bounded case, for asynchronous training the Parameters vertex would not wait until it has received updates from all the Train tasks. Instead, it would directly respond to the task that sent an update:

Code Block
languagejava
public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>> {

    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        ArrayUtils.addWith(parameters, update.f1);

        if (update.f0 < 0) {
            // Received the initialized parameter values, broadcast to all the downstream tasks
            for (int i = 0; i < 10; ++i) {
                output.collect(new Tuple2<>(i, parameters));
            }
        } else {
            output.collect(new Tuple2<>(update.f0, parameters));
        }
    }
}


Compatibility, Deprecation, and Migration Plan

The following APIs will be deprecated and removed in a future Flink release:

  • The entire DataSet class. See FLIP-131 for its motivation and the migration plan. The deprecation of DataSet::iterate(...) proposed by this FLIP is covered by FLIP-131.
  • The DataStream::iterate(...) and DataStream::iterate(long).

The proposed removal of DataStream::iterate(...) and DataStream::iterate(long) is a backward incompatible change. However, we believe that there is no widespread usage of these two APIs due to the issues described in FLIP-15.

Users will need to re-write their application code in order to migrate from the existing iterative APIs to the proposed APIs. We expect that the APIs proposed in this FLIP can support all use-cases supported by the existing iterative APIs.


Appendix

1) In the following, we prove that the proposed solution can be used to implement an iterative algorithm in the sync mode.

Refer to the "Proposed Changes" section for the definition of sync mode and the description of the solution.

Proof

In the following, we prove that the solution described above enforces the sync-mode execution. Note that the calculation of the record's epoch and the semantics of onEpochWatermarkIncremented(...) are described in the Java doc of the corresponding APIs.

Lemma-1: For any operator OpB defined in the IterationBody, at the time its Nth invocation of onEpochWatermarkIncremented(...) starts, it is guaranteed that:

  • If an input edge is a non-feedback edge from OpA, then OpA's Nth invocation of onEpochWatermarkIncremented(...) has been completed.
  • If an input edge is a feedback edge from OpA, then OpA's (N-1)th invocation of onEpochWatermarkIncremented(...) has been completed.

Let's prove lemma-1 by contradiction:

  • At the time OpB's Nth invocation starts, its epoch watermark has incremented to N, which means OpB will no longer receive any record with epoch <= N.
  • Suppose there is a non-feedback edge from OpA AND OpA's Nth invocation has not been completed. Then when OpA's Nth invocation completes, OpA can generate a record with epoch=N and send it to OpB via this non-feedback edge, which contradicts the guarantee described above.
  • Suppose there is a feedback edge from OpA AND OpA's (N-1)th invocation has not been completed. Then when OpA's (N-1)th invocation completes, OpA can generate a record with epoch=N and send it to OpB via this feedback edge, which contradicts the guarantee described above.

Lemma-2: For any operator OpB defined in the IterationBody, at the time its Nth invocation of onEpochWatermarkIncremented(...) starts, it is guaranteed that:

  • If an edge is a non-feedback input edge from OpA and this edge is part of a feedback loop, then OpA's (N+1)th invocation of onEpochWatermarkIncremented(...) has not started.
  • If an edge is a feedback input edge from OpA and this edge is part of a feedback loop, then OpA's Nth invocation of onEpochWatermarkIncremented(...) has not started.

Let's prove this lemma by contradiction:

  • Suppose there is a non-feedback edge from OpA, this edge is part of a feedback loop, and OpA's (N+1)th invocation has started. Since this non-feedback edge is part of a feedback loop, there is a backward path from OpA to OpB with exactly 1 feedback edge on this path. By applying lemma-1 recursively for operators on this path, we can tell that OpB's Nth invocation has been completed. This contradicts the assumption that OpB's Nth invocation just started.
  • Suppose there is a feedback edge from OpA, this edge is part of a feedback loop, and OpA's Nth invocation has started. Since this feedback edge is part of a feedback loop, there is a backward path from OpA to OpB with no feedback edge on this path. By applying lemma-1 recursively for operators on this path, we can tell that OpB's Nth invocation has been completed. This contradicts the assumption that OpB's Nth invocation just started.

Let's now prove that the sync-mode is achieved:

  • For any operator in the IterationBody, we define its output for the Nth epoch as the output emitted by the Nth invocation of onEpochWatermarkIncremented(). This definition is well-defined because operators only emit records in onEpochWatermarkIncremented().
  • At the time an operator OpB computes its output for the Nth epoch, this operator must have received exactly the following records from its input edges:
    • Suppose an edge is a non-feedback input edge from OpA and this edge is part of a feedback loop. It follows that OpA has emitted records for its Nth epoch (by lemma-1) and has not started to emit records for its (N+1)th epoch (by lemma-2).
    • Suppose an edge is a feedback input edge from OpA and this edge is part of a feedback loop. It follows that OpA has emitted records for its (N-1)th epoch (by lemma-1) and has not started to emit records for its Nth epoch (by lemma-2).