
Discussion thread: https://lists.apache.org/thread/9o56d7f094gdqc9mj28mwm9h4ffv02sx
Vote thread: https://lists.apache.org/thread/899vt2momfqpn65zmx6cq74o3qn41yf1
JIRA: FLINK-24642
Release: ml-2.0

Status

Current state: "Under Discussion"

Discussion thread: <TODO>

JIRA: <TODO>

This method is easier to use, but it also limits some possible optimizations.

...


Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).



[This FLIP proposal is a joint work between Yun Gao, Dong Lin and Zhipeng Zhang]

...

In general, an ML algorithm iteratively updates the model according to the data until the model converges. The target algorithms can be classified by three properties: the boundedness of the input datasets, the amount of data required for each variable update, and the synchronization policy.

...

  • In the sync mode, parallel subtasks, which execute the iteration body, update the model variables in a coordinated manner. There exist global epochs, such that all subtasks fetch the shared model variable values at the beginning of an epoch, calculate model variable updates based on the fetched values, and update the model variable values at the end of the epoch. In other words, all subtasks read and update model variables in global lock steps.

...

| Type | Bounded / Unbounded | Data Granularity | Synchronization Pattern | Support in the existing APIs | Support in the proposed API | Examples |
|---|---|---|---|---|---|---|
| Non-SGD-based | Bounded | Epoch | Mostly Synchronous | Yes | Yes | Offline K-Means, Apriori, Decision Tree, Random Walk |
| SGD-Based Synchronous Offline algorithm | Bounded | Batch → Epoch* | Synchronous | Yes | Yes | Linear Regression, Logistic Regression, Deep Learning algorithms |
| SGD-Based Asynchronous Offline algorithm | Bounded | Batch → Epoch* | Asynchronous | No | Yes | Same as above |
| SGD-Based Synchronous Online algorithm | Unbounded | Batch | Synchronous | No | Yes | Online version of the above algorithms |
| SGD-Based Asynchronous Online algorithm | Unbounded | Batch | Asynchronous | Yes | Yes | Online version of the above algorithms |

...

It is important to note that users typically should not invoke IterationBody::process directly, because the model-variables expected by the iteration body are not the same as the initial-model-variables provided by the user. Instead, model-variables are computed as the union of the feedback-model-variables (emitted by the iteration body) and the initial-model-variables (provided by the caller of the iteration body). To relieve users from creating this union operator, we have added a utility class (see Iterations) to run an iteration body with the user-provided inputs.
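To make the union concrete, here is a minimal single-threaded sketch. It is not the Flink API: ToyIteration, BodyResult and iterate are hypothetical names, "streams" are plain lists, and each call of the body handles one round. The caller only passes the initial variables; the driver, not the user, feeds the body's feedback records back as the variable input of the next round.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Hypothetical round-based model of running an iteration body; not part of Flink. */
final class ToyIteration {

    /** Records the body feeds back and records it exposes to the caller. */
    static final class BodyResult<V, O> {
        final List<V> feedback;
        final List<O> outputs;

        BodyResult(List<V> feedback, List<O> outputs) {
            this.feedback = feedback;
            this.outputs = outputs;
        }
    }

    /**
     * Runs the body until a round receives no variable records or maxRounds is reached.
     * The variable input of each round is the union of the initial variables (which only
     * arrive in the first round) and the feedback emitted by the previous round.
     */
    static <V, O> List<O> iterate(
            List<V> initVariables, Function<List<V>, BodyResult<V, O>> body, int maxRounds) {
        List<O> outputs = new ArrayList<>();
        List<V> feedback = new ArrayList<>();
        for (int round = 0; round < maxRounds; round++) {
            List<V> variables = new ArrayList<>(feedback);
            if (round == 0) {
                variables.addAll(initVariables);
            }
            if (variables.isEmpty()) {
                break; // no variable records left to iterate on
            }
            BodyResult<V, O> result = body.apply(variables);
            outputs.addAll(result.outputs);
            feedback = new ArrayList<>(result.feedback);
        }
        return outputs;
    }

    private ToyIteration() {}
}
```

The body never receives the initial variables and the feedback as separate inputs; merging them is the driver's job, which is the role the Iterations utility plays for the proposed API.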

The figure below summarizes the iteration paradigm described above. The streams in red are the inputs provided by the user to the iteration body, as well as the outputs emitted by the iteration body to the user.

[Figure: diagram of the iteration paradigm]

...


Public Interfaces

We propose to make the following API changes to support the iteration paradigm described above. 



1) Add the IterationBody interface.

...

Code Block
languagejava
linenumberstrue
package org.apache.flink.ml.iteration;

import org.apache.flink.annotation.PublicEvolving;

@PublicEvolving
public interface IterationBody {
    /**
     * This method creates the graph for the iteration body.
     *
     * See Iterations::iterateUnboundedStreams and Iterations::iterateBoundedStreamsUntilTermination for how the
     * iteration body can be executed and when execution of the corresponding graph should terminate.
     *
     * Required: the number of feedback variable streams returned by this method must equal the number of variable
     * streams given to this method.
     *
     * @param variableStreams the variable streams.
     * @param dataStreams the data streams.
     * @return an IterationBodyResult.
     */
    IterationBodyResult process(DataStreamList variableStreams, DataStreamList dataStreams);
}

...

Code Block
languagejava
linenumberstrue
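package org.apache.flink.ml.iteration;

import org.apache.flink.streaming.api.datastream.DataStream;

import javax.annotation.Nullable;

/**
 * A helper class that contains the streams returned by the iteration body.
 */
public class IterationBodyResult {
    /**
     * A list of feedback variable streams. These streams will only be used during the iteration execution and will
     * not be returned to the caller of the iteration body. It is assumed that the method which executes the
     * iteration body will feed the records of the feedback variable streams back to the corresponding input variable
     * streams.
     */
    final DataStreamList feedbackVariableStreams;

    /**
     * A list of output streams. These streams will be returned to the caller of the methods that execute the
     * iteration body.
     */
    final DataStreamList outputStreams;

    /**
     * An optional termination criteria stream. If this stream is not null, it will be used together with the
     * feedback variable streams to determine when the iteration should terminate.
     */
    @Nullable final DataStream<?> terminationCriteria;

    public IterationBodyResult(DataStreamList feedbackVariableStreams, DataStreamList outputStreams) {
        this(feedbackVariableStreams, outputStreams, null);
    }

    public IterationBodyResult(
            DataStreamList feedbackVariableStreams,
            DataStreamList outputStreams,
            @Nullable DataStream<?> terminationCriteria) {
        this.feedbackVariableStreams = feedbackVariableStreams;
        this.outputStreams = outputStreams;
        this.terminationCriteria = terminationCriteria;
    }
}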
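The class above does not spell out the exact termination rule, so the following sketch only illustrates one plausible reading in a round-based toy model: the iteration stops after the first round that emits no feedback records, or, when a termination-criteria stream is present, after the first round in which that stream emits nothing. TerminationDemo and RoundEmission are hypothetical names, not Flink classes, and the rule itself is an assumption of the sketch.

```java
import java.util.function.IntFunction;

/** Hypothetical, round-based illustration of termination; not part of Flink. */
final class TerminationDemo {

    /** Number of records emitted in one round by the feedback and criteria streams. */
    static final class RoundEmission {
        final int feedbackRecords;
        final Integer criteriaRecords; // null when no termination-criteria stream is used

        RoundEmission(int feedbackRecords, Integer criteriaRecords) {
            this.feedbackRecords = feedbackRecords;
            this.criteriaRecords = criteriaRecords;
        }
    }

    /** Assumed rule: stop when no variable is updated or the criteria stream goes quiet. */
    static boolean shouldTerminate(RoundEmission emission) {
        if (emission.feedbackRecords == 0) {
            return true;
        }
        return emission.criteriaRecords != null && emission.criteriaRecords == 0;
    }

    /** Executes rounds until shouldTerminate holds; returns the number of rounds executed. */
    static int runUntilTermination(IntFunction<RoundEmission> runRound, int maxRounds) {
        for (int round = 0; round < maxRounds; round++) {
            if (shouldTerminate(runRound.apply(round))) {
                return round + 1;
            }
        }
        return maxRounds;
    }

    private TerminationDemo() {}
}
```

A termination-criteria stream is useful when the variables keep changing slightly forever (so feedback never goes quiet) but the user can decide, for example from a loss value, that further rounds are unnecessary.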


3) Add the IterationListener interface.

...

This interface allows users to achieve the following goals:
- Run an algorithm in sync mode, i.e. each subtask waits for model parameter updates from all other subtasks before reading the aggregated model parameters and starting the next epoch of execution. As long as we can find a cut in the subgraph of the iteration body such that all the operators in the cut emit records only in onEpochWatermarkIncremented, the algorithm is synchronous. The detailed proof can be found in the appendix.
- Emit final output after the iteration terminates.
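The following toy simulation (plain Java, not the Flink API) contrasts the two behaviours. LockStepDemo is a hypothetical name, each "subtask" is one loop step, and the update rule learningRate * (target - value) is an arbitrary stand-in for a real gradient. With lockStep set, the parameter holder only buffers updates in its processElement analogue and applies them in its onEpochWatermarkIncremented analogue, so every subtask in an epoch reads the same value; without it, later subtasks already see earlier subtasks' updates.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical illustration of lock-step (sync) vs. eager parameter updates; not Flink code. */
final class LockStepDemo {

    /**
     * Simulates one parameter holder and targets.length subtasks for numEpochs epochs.
     * Subtask s computes update = learningRate * (targets[s] - value it read).
     * Returns, per epoch, the parameter value each subtask read.
     */
    static List<double[]> run(
            double initial, double[] targets, int numEpochs, double learningRate, boolean lockStep) {
        double parameter = initial; // value held by the parameter operator
        double visible = initial;   // value most recently emitted to the subtasks
        double buffered = 0.0;      // updates received but not yet applied
        List<double[]> reads = new ArrayList<>();
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            double[] seen = new double[targets.length];
            for (int s = 0; s < targets.length; s++) {
                seen[s] = visible;
                double update = learningRate * (targets[s] - visible);
                if (lockStep) {
                    buffered += update; // processElement analogue: buffer, emit nothing
                } else {
                    parameter += update; // eager: apply and emit immediately
                    visible = parameter;
                }
            }
            if (lockStep) {
                // onEpochWatermarkIncremented analogue: all updates of this epoch have arrived.
                parameter += buffered;
                buffered = 0.0;
                visible = parameter;
            }
            reads.add(seen);
        }
        return reads;
    }

    private LockStepDemo() {}
}
```

With lockStep = true, both subtasks read 0.0, then 1.0, then 1.5 for targets {1.0, 3.0} and learningRate 0.25; in the eager variant the second subtask already reads 0.25 in the first epoch.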

...

Code Block
languagejava
linenumberstrue
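package org.apache.flink.ml.iteration;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;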

/**
 * The callbacks defined below will be invoked only if the operator instance which implements this interface is used
 * within an iteration body.
 */
@PublicEvolving
public interface IterationListener<T> {
    /**
     * This callback is invoked every time the epoch watermark of this operator increments. The initial epoch watermark
     * of an operator is 0.
     *
     * The epochWatermark is the maximum integer that meets this requirement: every record that arrives at the operator
     * going forward should have an epoch larger than the epochWatermark. See Java docs in Iterations for how epoch
     * is determined for records ingested into the iteration body and for records emitted by operators within the
     * iteration body.
     *
     * If all inputs are bounded, the maximum epoch of all records ingested into this operator is used as the
     * epochWatermark parameter for the last invocation of this callback.
     *
     * @param epochWatermark The incremented epoch watermark.
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector);

    /**
     * This callback is invoked after the execution of the iteration body has terminated.
     *
     * See the Java docs of the methods in Iterations for the termination conditions.
     *
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onIterationTermination(Context context, Collector<T> collector);

    /**
     * Information available in an invocation of the callbacks defined in the IterationListener.
     */
    interface Context {
        /**
         * Emits a record to the side output identified by the {@link OutputTag}.
         *
         * @param outputTag the {@code OutputTag} that identifies the side output to emit to.
         * @param value The record to emit.
         */
        <X> void output(OutputTag<X> outputTag, X value);
    }
}
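The Javadoc above defines the epoch watermark per operator but does not say how an operator with several inputs derives it. A common approach, assumed here rather than taken from the proposal, is to track a watermark per input channel, use their minimum, and fire the callback whenever that minimum increases. EpochWatermarkTracker is a hypothetical, framework-free class.

```java
import java.util.Arrays;
import java.util.function.IntConsumer;

/** Hypothetical per-operator epoch watermark tracking; not part of Flink. */
final class EpochWatermarkTracker {
    private final int[] channelWatermarks;
    private final IntConsumer onEpochWatermarkIncremented;
    private int epochWatermark;

    EpochWatermarkTracker(int numChannels, int initialWatermark, IntConsumer onEpochWatermarkIncremented) {
        this.channelWatermarks = new int[numChannels];
        Arrays.fill(channelWatermarks, initialWatermark);
        this.epochWatermark = initialWatermark;
        this.onEpochWatermarkIncremented = onEpochWatermarkIncremented;
    }

    /** The given channel promises that all its future records have an epoch > watermark. */
    void onChannelWatermark(int channel, int watermark) {
        channelWatermarks[channel] = Math.max(channelWatermarks[channel], watermark);
        int min = Arrays.stream(channelWatermarks).min().getAsInt();
        if (min > epochWatermark) {
            epochWatermark = min;
            // Invoked once per increase; a jump of several epochs yields a single call here.
            onEpochWatermarkIncremented.accept(epochWatermark);
        }
    }

    int epochWatermark() {
        return epochWatermark;
    }
}
```

Taking the minimum is what makes the callback safe for synchronous algorithms: it only fires once no input channel can deliver another record of the finished epoch.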


4) Add the Iterations class.

This class provides APIs to execute an iteration body with the user-provided inputs, covering different input types (e.g. bounded data streams vs. unbounded data streams) and data replay semantics (i.e. whether to replay the user-provided data streams).

...

Code Block
languagejava
linenumberstrue
package org.apache.flink.ml.iteration;

import org.apache.flink.annotation.PublicEvolving;

/**
 * A helper class to apply {@link IterationBody} to data streams.
 *
 * <p>To construct an iteration, users are required to provide
 *
 * <ul>
 *   <li>initVariableStreams: the initial values of the variable data streams which would be updated
 *       in each round.
 *   <li>dataStreams: the other data streams used inside the iteration, but would not be updated.
 *   <li>iterationBody: specifies the subgraph to update the variable streams and the outputs.
 * </ul>
 *
 * <p>The iteration body will be invoked with two parameters: The first parameter is a list of input
 * variable streams, which are created as the union of the initial variable streams and the
 * corresponding feedback variable streams (returned by the iteration body); The second parameter is
 * the data streams given to this method.
 *
 * <p>During the execution of the iteration body, each of the records involved in the iteration has
 * an epoch attached, which marks the progress of the iteration. The epoch is computed as:
 *
 * <ul>
 *   <li>All records in the initial variable streams and initial data streams have epoch = 0.
 *   <li>For any record emitted by this operator into a non-feedback stream, the epoch of this
 *       emitted record = the epoch of the input record that triggers this emission. If this record
 *       is emitted by onEpochWatermarkIncremented(), then the epoch of this record =
 *       epochWatermark.
 *   <li>For any record emitted by this operator into a feedback variable stream, the epoch of the
 *       emitted record = the epoch of the input record that triggers this emission + 1. If this
 *       record is emitted by onEpochWatermarkIncremented(), then the epoch of this record =
 *       epochWatermark + 1.
 * </ul>
 *
 * <p>The framework notifies the operators and UDFs that implement {@link IterationListener} at the
 * end of each epoch.
 *
 * <p>The limitations on constructing the subgraph inside the iteration body are described in
 * {@link IterationBody}.
 *
 * <p>An example of the iteration is like:
 *
 * <pre>{@code
 * DataStreamList result = Iterations.iterateUnboundedStreams(
 *  DataStreamList.of(first, second),
 *  DataStreamList.of(third),
 *  (variableStreams, dataStreams) -> {
 *      ...
 *      return new IterationBodyResult(
 *          DataStreamList.of(firstFeedback, secondFeedback),
 *          DataStreamList.of(output));
 *  });
 * result.<Integer>get(0).addSink(...);
 * }</pre>
 */
@PublicEvolving
public class Iterations {

    /**
     * This method uses an iteration body to process records in possibly unbounded data streams.
     * The iteration does not terminate if at least one of its inputs is unbounded. Otherwise, it
     * terminates after all the inputs have finished and no more records are iterating.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateUnboundedStreams(
            DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {
        return null;
    }

    /**
     * This method uses an iteration body to process records in some bounded data streams
     * iteratively until a termination criteria is reached (e.g. the given number of rounds is
     * completed or no further variable update is needed). For data streams that are not replayed,
     * the iteration body needs to cache their records in order to visit them repeatedly.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}.
     * @param config The config for the iteration, like whether to re-create the operator on each
     *     round.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateBoundedStreamsUntilTermination(
            DataStreamList initVariableStreams,
            ReplayableDataStreamList dataStreams,
            IterationConfig config,
            IterationBody body) {
        return null;
    }
}
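The epoch rules in the Javadoc above can be restated as two small functions. EpochRules is a hypothetical helper that uses int epochs and no Flink types; it only shows why a record that has travelled through the feedback edge n times carries epoch n, which is what lets epoch watermarks mark rounds.

```java
/** Hypothetical restatement of the epoch rules in the Iterations Javadoc. */
final class EpochRules {
    /** Epoch of every record in the initial variable streams and initial data streams. */
    static final int INITIAL_EPOCH = 0;

    /** Epoch of a record emitted while processing an input record with the given epoch. */
    static int epochOfEmitted(int inputEpoch, boolean intoFeedbackStream) {
        return intoFeedbackStream ? inputEpoch + 1 : inputEpoch;
    }

    /** Epoch of a record emitted from onEpochWatermarkIncremented(epochWatermark). */
    static int epochOfEmittedOnWatermark(int epochWatermark, boolean intoFeedbackStream) {
        return intoFeedbackStream ? epochWatermark + 1 : epochWatermark;
    }

    /** Follows a record from an initial stream that crosses the feedback edge `hops` times. */
    static int epochAfterFeedbackHops(int hops) {
        int epoch = INITIAL_EPOCH;
        for (int i = 0; i < hops; i++) {
            epoch = epochOfEmitted(epoch, false); // forwarded inside the body: unchanged
            epoch = epochOfEmitted(epoch, true);  // emitted into the feedback stream: + 1
        }
        return epoch;
    }

    private EpochRules() {}
}
```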

5) Introduce the forEachRound utility method.

forEachRound allows users to specify a sub-graph that executes in per-round mode, namely all the operators in the sub-graph are re-created for each round.

Code Block
languagejava
public interface IterationBody {

    // ...

    /**
     * @param inputs The inputs of the subgraph.
     * @param perRoundSubBody The computational logic that should be executed in per-round mode.
     * @return The output of the subgraph.
     */
    static DataStreamList forEachRound(DataStreamList inputs, PerRoundSubBody perRoundSubBody) {
        return null;
    }

    /** The sub-graph inside the iteration body that should be executed as per-round. */
    interface PerRoundSubBody {

        DataStreamList process(DataStreamList input);
    }
}
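The difference between the two lifecycles is easiest to see with a stateful operator. In this hypothetical sketch (LifeCycle mirrors IterationConfig.OperatorLifeCycle; CountingOperator and LifeCycleDemo are made up and use no Flink types), an ALL_ROUND operator keeps its state across rounds, while a PER_ROUND operator, as inside forEachRound, is re-created and starts from scratch in each round.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

/** Hypothetical illustration of operator lifecycles inside an iteration; not part of Flink. */
final class LifeCycleDemo {

    /** Mirrors the two values of IterationConfig.OperatorLifeCycle. */
    enum LifeCycle { ALL_ROUND, PER_ROUND }

    /** A stateful toy operator that reports how many records it has seen so far. */
    static final class CountingOperator {
        private int seen;

        int process(List<String> roundInput) {
            seen += roundInput.size();
            return seen;
        }
    }

    /** Feeds the given rounds to the operator and returns its output for each round. */
    static List<Integer> run(LifeCycle mode, List<List<String>> rounds) {
        Supplier<CountingOperator> factory = CountingOperator::new;
        CountingOperator operator = factory.get();
        List<Integer> outputs = new ArrayList<>();
        for (List<String> round : rounds) {
            if (mode == LifeCycle.PER_ROUND) {
                operator = factory.get(); // re-created: state from earlier rounds is discarded
            }
            outputs.add(operator.process(round));
        }
        return outputs;
    }

    private LifeCycleDemo() {}
}
```

Per-round mode lets users reuse ordinary, non-iteration-aware operators inside the loop without their state leaking from one round into the next.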


6) Add the DataStreamList and ReplayableDataStreamList classes.

DataStreamList is a helper class that contains a list of data streams with possibly different element types, and ReplayableDataStreamList is a helper class that contains a list of data streams together with whether each of them needs to be replayed in each round.

Code Block
languagejava
linenumberstrue
package org.apache.flink.iteration;

import org.apache.flink.streaming.api.datastream.DataStream;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class DataStreamList {
    public static DataStreamList of(DataStream<?>... streams) {...}

    // Returns the number of data streams in this list.
    public int size() {...}

    // Returns the data stream at the given index in this list.
    public <T> DataStream<T> get(int index) {...}
}

public class ReplayableDataStreamList {

    private final List<DataStream<?>> replayedDataStreams;

    private final List<DataStream<?>> nonReplayedStreams;

    private ReplayableDataStreamList(
            List<DataStream<?>> replayedDataStreams, List<DataStream<?>> nonReplayedStreams) {
        this.replayedDataStreams = replayedDataStreams;
        this.nonReplayedStreams = nonReplayedStreams;
    }

    public static ReplayedDataStreamList replay(DataStream<?>... dataStreams) {
        return new ReplayedDataStreamList(Arrays.asList(dataStreams));
    }

    public static NonReplayedDataStreamList notReplay(DataStream<?>... dataStreams) {
        return new NonReplayedDataStreamList(Arrays.asList(dataStreams));
    }

    List<DataStream<?>> getReplayedDataStreams() {
        return Collections.unmodifiableList(replayedDataStreams);
    }

    List<DataStream<?>> getNonReplayedStreams() {
        return Collections.unmodifiableList(nonReplayedStreams);
    }

    // Public so that callers can chain andNotReplay(...) on the result of replay(...).
    public static class ReplayedDataStreamList extends ReplayableDataStreamList {

        public ReplayedDataStreamList(List<DataStream<?>> replayedDataStreams) {
            super(replayedDataStreams, Collections.emptyList());
        }

        public ReplayableDataStreamList andNotReplay(DataStream<?>... nonReplayedStreams) {
            return new ReplayableDataStreamList(
                    getReplayedDataStreams(), Arrays.asList(nonReplayedStreams));
        }
    }

    // Public so that callers can chain andReplay(...) on the result of notReplay(...).
    public static class NonReplayedDataStreamList extends ReplayableDataStreamList {

        public NonReplayedDataStreamList(List<DataStream<?>> nonReplayedDataStreams) {
            super(Collections.emptyList(), nonReplayedDataStreams);
        }

        public ReplayableDataStreamList andReplay(DataStream<?>... replayedStreams) {
            return new ReplayableDataStreamList(
                    Arrays.asList(replayedStreams), getNonReplayedStreams());
        }
    }
}
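The practical effect of replay(...) versus notReplay(...) can be modelled without Flink. In this hypothetical sketch (ReplayDemo and CachingSum are made-up names), replayed records are re-delivered to the body in every round, while non-replayed records arrive only once, so a body that needs them in later rounds has to cache them itself, as the Javadoc of iterateBoundedStreamsUntilTermination notes.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model of replayed vs. non-replayed data streams; not part of Flink. */
final class ReplayDemo {

    /** Records a body receives in a round: replayed ones every round, others only in round 0. */
    static <T> List<T> recordsSeenInRound(int round, List<T> replayed, List<T> nonReplayed) {
        List<T> seen = new ArrayList<>(replayed);
        if (round == 0) {
            seen.addAll(nonReplayed);
        }
        return seen;
    }

    /** A body over non-replayed data: caches records when first seen and reuses the cache. */
    static final class CachingSum {
        private final List<Integer> cache = new ArrayList<>();

        int round(List<Integer> delivered) {
            cache.addAll(delivered);
            int sum = 0;
            for (int value : cache) {
                sum += value;
            }
            return sum;
        }
    }

    private ReplayDemo() {}
}
```

Replaying trades extra re-reading for simpler, cache-free operators; not replaying avoids re-reading but moves the caching burden into the iteration body.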

7) Introduce the IterationConfig

IterationConfig allows users to specify the config for each iteration. For now, users can only specify the default operator lifecycle inside the iteration, but the class ensures forward compatibility if more options are added in the future.

Code Block
languagejava
package org.apache.flink.iteration;

/** The config for an iteration. */
public class IterationConfig {

    private final OperatorLifeCycle operatorRoundMode;

    private IterationConfig(OperatorLifeCycle operatorRoundMode) {
        this.operatorRoundMode = operatorRoundMode;
    }

    public static IterationConfigBuilder newBuilder() {
        return new IterationConfigBuilder();
    }

    public static class IterationConfigBuilder {

        private OperatorLifeCycle operatorRoundMode = OperatorLifeCycle.ALL_ROUND;

        private IterationConfigBuilder() {}

        public IterationConfigBuilder setOperatorRoundMode(OperatorLifeCycle operatorRoundMode) {
            this.operatorRoundMode = operatorRoundMode;
            return this;
        }

        public IterationConfig build() {
            return new IterationConfig(operatorRoundMode);
        }
    }

    public enum OperatorLifeCycle {
        ALL_ROUND,

        PER_ROUND
    }
}


8) Deprecate the existing DataStream::iterate() and the DataStream::iterate(long maxWaitTimeMillis) methods.

We plan to remove both methods after the APIs added in this doc are ready for production use. This change is needed to decouple the iteration-related APIs from the Flink core runtime so that we can keep the Flink core runtime as simple and maintainable as possible.

Example Usage

This section shows how commonly used ML algorithms can be implemented with the iteration API.

Offline Algorithms

We first show the usage of the bounded iteration with the linear regression case: the model is Y = XA, and we would like to obtain the best estimate of A with the SGD algorithm. For simplicity, we assume the parameters can be held in the memory of one task.

The job graph of the algorithm is shown in Figure 2: in each round, we use the latest parameters to calculate the update to the parameters: ΔA = ∑(Y - XA)X. To achieve this, the Parameters vertex broadcasts the latest parameters to the Train vertex. Each subtask of the Train vertex holds a part of the dataset. Following the spirit of SGD, it samples a small batch of training records and calculates the update with the above equation. Then the Reducer merges ΔA from all the Train subtasks and emits the reduced value to the Parameters vertex to update the parameters. The Reducer is required because each feedback stream must have the same parallelism as the corresponding initial variable stream.
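The data flow of Figure 2 can be checked in isolation with a small sequential sketch (plain Java, not Flink code; LinearRegressionSgdSketch and its methods are hypothetical). Each Train "subtask" computes ΔA = ∑(y - x·A)x over its mini-batch, the Reducer sums the partial updates, and the Parameters vertex applies the sum once per round. The learning rate and the cyclic (rather than random) batch selection are assumptions added to keep the example deterministic.

```java
/** Hypothetical, sequential sketch of the Figure 2 job graph; not Flink code. */
final class LinearRegressionSgdSketch {

    /** Partial update over rows [from, to): sum of (y - dot(x, A)) * x. */
    static double[] partialUpdate(double[][] xs, double[] ys, int from, int to, double[] a) {
        double[] delta = new double[a.length];
        for (int i = from; i < to; i++) {
            double residual = ys[i] - dot(xs[i], a);
            for (int j = 0; j < a.length; j++) {
                delta[j] += residual * xs[i][j];
            }
        }
        return delta;
    }

    static double dot(double[] u, double[] v) {
        double sum = 0.0;
        for (int i = 0; i < u.length; i++) {
            sum += u[i] * v[i];
        }
        return sum;
    }

    /**
     * xs[k] and ys[k] hold the partition of Train subtask k. Batches are taken cyclically
     * (an assumption; SGD would normally sample them) and learningRate scales delta A.
     */
    static double[] train(
            double[][][] xs, double[][] ys, int dim, int batchSize, int rounds, double learningRate) {
        double[] a = new double[dim]; // held by the Parameters vertex, starts at zero
        for (int round = 0; round < rounds; round++) {
            double[] reduced = new double[dim]; // Reducer output for this round
            for (int k = 0; k < xs.length; k++) { // one step per Train subtask
                int n = xs[k].length;
                int from = (round * batchSize) % n;
                int to = Math.min(from + batchSize, n);
                double[] delta = partialUpdate(xs[k], ys[k], from, to, a);
                for (int j = 0; j < dim; j++) {
                    reduced[j] += delta[j];
                }
            }
            for (int j = 0; j < dim; j++) {
                a[j] += learningRate * reduced[j]; // Parameters vertex applies the merged update
            }
        }
        return a;
    }

    private LinearRegressionSgdSketch() {}
}
```

Because ΔA is the negative gradient of ½∑(y - x·A)², adding a small multiple of it moves A downhill; on noise-free data generated from A = (2, -1), the sketch recovers A to within 1e-6 after a few hundred rounds.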



Figure 2. The JobGraph for the offline training of the linear regression case.


We start with synchronous training, which requires that the updates from all subtasks of the Train vertex are merged before the next round of training starts. This can be achieved by emitting the next round of parameters only at the end of each round. The code is shown as follows:

Code Block
languagejava
linenumberstrue
public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
Figure 2. The JobGraph for the offline training of the linear regression case.

We will start with synchronous training. Synchronous training requires that the updates from all subtasks of the Train vertex are merged before the next round of training starts. This can be done by emitting the next round of parameters only at the end of each round. The code is shown as follows:

Code Block
languagejava
linenumberstrue
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;

// Iterations, DataStreamList, ReplayableDataStreamList, IterationConfig, IterationBodyResult and
// IterationListener are the APIs proposed in this FLIP; loadParameters(), loadDataSet(), sample()
// and ArrayUtils are placeholder helpers.
public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
    private static final int N_EPOCH = 5;
    private static final int N_BATCH_PER_EPOCH = 10;
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>("final_model") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams =
            Iterations.iterateBoundedStreamsUntilTermination(
                DataStreamList.of(initParameters),
                ReplayableDataStreamList.notReplay(dataset),
                IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(),
                (variableStreams, dataStreams) -> {
                    DataStream<double[]> parameterUpdates = variableStreams.get(0);
                    // Named trainData: a lambda may not redeclare the enclosing local "dataset".
                    DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

                    SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction());
                    DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                        .broadcast()
                        .connect(trainData)
                        .process(new TrainFunction())
                        .setParallelism(10)
                        // Merges the updates of all Train subtasks once per round.
                        .process(new ReduceFunction())
                        .setParallelism(1);

                    return new IterationBodyResult(
                        DataStreamList.of(modelUpdate),
                        DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
                });

        // The only output stream of the iteration body carries the final model.
        DataStream<double[]> finalModel = resultStreams.get(0);
        finalModel.print();
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
        implements IterationListener<double[]> {

        private final double[] parameters = new double[N_DIM];

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
        }

        public void onEpochWatermarkIncremented(int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            // Emit the merged parameters once per round, until the last round is reached.
            if (epochWatermark < N_EPOCH * N_BATCH_PER_EPOCH) {
                collector.collect(parameters);
            }
        }

        public void onIterationEnd(int[] round, IterationListener.Context context) {
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
        }
    }

    public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]>
        implements IterationListener<double[]> {

        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        private Supplier<int[]> recordRoundQuerier;

        public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
            this.recordRoundQuerier = querier;
        }

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            int[] round = recordRoundQuerier.get();
            if (round[0] == 0) {
                // The dataset is still arriving in round 0: cache the parameters until the round ends.
                firstRoundCachedParameter = parameter;
                return;
            }

            calculateModelUpdate(parameter, output);
        }

        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            dataset.add(trainSample);
        }

        public void onEpochWatermarkIncremented(int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            if (epochWatermark == 0) {
                // The whole dataset has been received: train on the cached first-round parameters.
                calculateModelUpdate(firstRoundCachedParameter, collector);
                firstRoundCachedParameter = null;
            }
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }

    public static class ReduceFunction extends ProcessFunction<double[], double[]>
        implements IterationListener<double[]> {

        private double[] mergedValue = new double[N_DIM];

        public void processElement(double[] update, Context context, Collector<double[]> output) {
            ArrayUtils.addWith(mergedValue, update);
        }

        public void onEpochWatermarkIncremented(int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            // All updates of this round have arrived: emit their sum and reset.
            collector.collect(mergedValue);
            mergedValue = new double[N_DIM];
        }
    }
}
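To make the synchronous pattern concrete without depending on the proposed iteration API, the following plain-Java sketch (no Flink involved) simulates the same loop: in each round every Train worker computes the gradient of the squared error on its own partition against one shared parameter snapshot, the gradients are summed as the ReduceFunction does, and the parameters change only once the whole round has been merged. The class name `SyncRoundsSketch`, the explicit learning rate and the gradient-descent step are assumptions for illustration, not part of the proposal.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Plain-Java simulation (no Flink) of synchronous bounded training.
 * Each round: every worker computes a gradient against the same parameter
 * snapshot, the gradients are summed (the Reduce step), and only then are
 * the parameters updated. Class name and learning rate are illustrative
 * assumptions.
 */
public class SyncRoundsSketch {

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    /** Gradient of 0.5 * sum((w . x - y)^2) over one worker's partition. */
    static double[] localGradient(double[] w, List<double[]> xs, List<Double> ys) {
        double[] grad = new double[w.length];
        for (int i = 0; i < xs.size(); i++) {
            double[] x = xs.get(i);
            double diff = dot(x, w) - ys.get(i);
            for (int d = 0; d < w.length; d++) {
                grad[d] += x[d] * diff;
            }
        }
        return grad;
    }

    /** Runs nRounds synchronous rounds; partition k belongs to worker k. */
    static double[] train(double[] init, List<List<double[]>> xParts,
                          List<List<Double>> yParts, int nRounds, double learningRate) {
        double[] w = init.clone();
        for (int round = 0; round < nRounds; round++) {
            // All workers read the same snapshot w during this round.
            double[] merged = new double[w.length];
            for (int k = 0; k < xParts.size(); k++) {
                double[] g = localGradient(w, xParts.get(k), yParts.get(k));
                for (int d = 0; d < w.length; d++) {
                    merged[d] += g[d];
                }
            }
            // End of round: apply the merged update exactly once.
            for (int d = 0; d < w.length; d++) {
                w[d] -= learningRate * merged[d];
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // Two workers; every sample satisfies y = 2 * x.
        List<List<double[]>> xs = List.of(
            List.of(new double[] {1.0}, new double[] {2.0}),
            List.of(new double[] {3.0}));
        List<List<Double>> ys = List.of(List.of(2.0, 4.0), List.of(6.0));
        double[] w = train(new double[] {0.0}, xs, ys, 200, 0.05);
        System.out.println(Arrays.toString(w));
    }
}
```

Because every worker reads the same snapshot within a round, the result does not depend on how the samples are split across workers; that is exactly the guarantee the asynchronous variant gives up in exchange for not waiting on the slowest worker.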

Here we use a specialized reduce function that handles the round boundaries itself. Outside of an iteration, however, the same reduction could already be written as windowAll().reduce(). To reuse such existing operators inside the iteration, we can use the forEachRound utility method:

Code Block
languagejava
DataStreamList resultStreams =
    Iterations.iterateBoundedStreamsUntilTermination(
        DataStreamList.of(initParameters),
        ReplayableDataStreamList.notReplay(dataset),
        IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(),
        (variableStreams, dataStreams) -> {
    DataStream<double[]> parameterUpdates = variableStreams.get(0);
    DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

    SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction());
    DataStream<double[]> modelUpdate = parameters.setParallelism(1)
        .broadcast()
        .connect(trainData)
        .process(new TrainFunction())
        .setParallelism(10);

    // Reuse an ordinary windowAll().reduce() on the updates of each round instead of the
    // specialized ReduceFunction. The window assigner is elided here; the window is
    // assumed to cover the records of one round.
    DataStream<double[]> reduced = forEachRound(DataStreamList.of(modelUpdate), streams -> {
        return streams.<double[]>get(0).windowAll().reduce((x, y) -> ArrayUtils.add(x, y));
    }).<double[]>get(0);

    // Feed back the merged update, not the raw per-subtask updates.
    return new IterationBodyResult(DataStreamList.of(reduced), DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
});
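The per-round reuse that forEachRound enables can be illustrated with a small plain-Java sketch: each record carries the round that produced it, and an ordinary reduce function is applied separately to the records of each round, yielding one result per round. This only models the semantics; `Tagged` and `reducePerRound` are hypothetical names, and how the proposed forEachRound actually detects round boundaries (through epoch watermarks) is not modeled.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.BinaryOperator;

/**
 * Plain-Java sketch of the semantics behind forEachRound: an ordinary reduce
 * function is applied separately to the records of each round. Tagged and
 * reducePerRound are hypothetical names, not part of the proposed API.
 */
public class PerRoundReduceSketch {

    /** A value tagged with the round (epoch) that produced it. */
    static final class Tagged<T> {
        final int round;
        final T value;

        Tagged(int round, T value) {
            this.round = round;
            this.value = value;
        }
    }

    /** Reduces the records of every round independently, in ascending round order. */
    static <T> List<Tagged<T>> reducePerRound(List<Tagged<T>> records, BinaryOperator<T> reducer) {
        Map<Integer, T> perRound = new TreeMap<>();
        for (Tagged<T> r : records) {
            // Records of different rounds are never combined with each other.
            perRound.merge(r.round, r.value, reducer);
        }
        List<Tagged<T>> out = new ArrayList<>();
        perRound.forEach((round, value) -> out.add(new Tagged<>(round, value)));
        return out;
    }

    /** Element-wise sum, playing the role of (x, y) -> ArrayUtils.add(x, y). */
    static double[] add(double[] a, double[] b) {
        double[] sum = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            sum[i] = a[i] + b[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        List<Tagged<double[]>> updates = List.of(
            new Tagged<>(0, new double[] {1, 1}),
            new Tagged<>(1, new double[] {5, 0}),
            new Tagged<>(0, new double[] {2, 3}),
            new Tagged<>(1, new double[] {1, 1}));
        for (Tagged<double[]> t : reducePerRound(updates, PerRoundReduceSketch::add)) {
            System.out.println("round " + t.round + ": " + Arrays.toString(t.value));
        }
    }
}
```

The point of the design is that `add` knows nothing about rounds: the round bookkeeping lives entirely in the wrapper, which is what lets an unmodified operator be reused inside the iteration.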

This is especially helpful in complex scenarios, since it lets us reuse the existing DataStream and Table operators inside the iteration. Another way to implement synchronous training on a bounded dataset is to use iterateAndReplayBoundedStreamsUntilTermination.



If we instead want to do asynchronous training, we would need to make the following changes:

  1. The Parameters vertex would not wait until the end of the round to make sure it has received all the updates from the iteration. Instead, it would output the current parameter values as soon as it receives a model update from any Train subtask.
  2. To identify the source of each update, the input type would change to Tuple2<Integer, double[]>. The Parameters vertex would then send the new parameter values only to the Train subtask that sent the update.
  3. The Reducer would no longer merge the received updates; instead, it would forward them directly.

The changes to the graph-building code are minor (changing the output type and using a custom partitioner). The change to the Parameters vertex is as follows:

Code Block
languagejava
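public static class ParametersCacheFunction
    extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>
    implements IterationListener<Tuple2<Integer, double[]>> {

    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        // Suppose we have a util to add the second array to the first.
        // The update is applied immediately, without waiting for the end of the round.
        ArrayUtils.addWith(parameters, update.f1);
        // f0 identifies the Train subtask that sent the update; reply only to it.
        output.collect(new Tuple2<>(update.f0, parameters.clone()));
    }

    public void onEpochWatermarkIncremented(int epochWatermark, IterationListener.Context context, Collector<Tuple2<Integer, double[]>> collector) {
        // Nothing to do: the asynchronous variant does not synchronize on rounds.
    }

    public void onIterationEnd(int[] round, IterationListener.Context context) {
        context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
    }
}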

public class AsynchronousBoundedLinearRegression {
    
    ...

    DataStreamList resultStreams =
        Iterations.iterateBoundedStreamsUntilTermination(DataStreamList.of(initParameters), DataStreamList.of(dataset), (variableStreams, dataStreams) -> {
            DataStream<Tuple2<Integer, double[]>> parameterUpdates = variableStreams.get(0);
            DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

            SingleOutputStreamOperator<Tuple2<Integer, double[]>> parameters = parameterUpdates.process(new ParametersCacheFunction());
            // TrainFunction is assumed to emit Tuple2<Integer, double[]> with f0 set to its own subtask index.
            DataStream<Tuple2<Integer, double[]>> modelUpdate = parameters.setParallelism(1)
                // Send the new parameters only to the Train subtask that sent the update.
                .partitionCustom((key, numPartitions) -> key % numPartitions, update -> update.f0)
                .connect(trainData)
                .process(new TrainFunction())
                .setParallelism(10);

            return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
        });

    ...
}
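The asynchronous protocol can likewise be simulated in plain Java. In this sketch the parameter holder applies each update the moment it arrives and sends a snapshot only to the sender, using the same `key % numPartitions` routing as the custom partitioner above; no worker waits for the others. All class and method names are hypothetical, and the order in which updates arrive is chosen by hand.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/**
 * Plain-Java sketch of the asynchronous variant: the parameter holder applies
 * every update as soon as it arrives and replies only to the sending worker.
 * All names here are illustrative assumptions, not part of the proposal.
 */
public class AsyncParameterServerSketch {

    /** An update labelled with the index of the Train subtask that produced it. */
    static final class Update {
        final int workerId;
        final double[] delta;

        Update(int workerId, double[] delta) {
            this.workerId = workerId;
            this.delta = delta;
        }
    }

    /** Same routing rule as the custom partitioner: key % numPartitions. */
    static int route(int key, int numPartitions) {
        return key % numPartitions;
    }

    private final double[] parameters;
    private final List<Deque<double[]>> inboxes = new ArrayList<>();

    AsyncParameterServerSketch(int dim, int numWorkers) {
        parameters = new double[dim];
        for (int i = 0; i < numWorkers; i++) {
            inboxes.add(new ArrayDeque<>());
        }
    }

    /** Applies the update at once; there is no barrier across workers. */
    void onUpdate(Update u) {
        for (int d = 0; d < parameters.length; d++) {
            parameters[d] += u.delta[d];
        }
        // Copy, so later updates do not change what this worker already received.
        inboxes.get(route(u.workerId, inboxes.size())).add(parameters.clone());
    }

    double[] currentParameters() {
        return parameters.clone();
    }

    Deque<double[]> inbox(int worker) {
        return inboxes.get(worker);
    }

    public static void main(String[] args) {
        AsyncParameterServerSketch server = new AsyncParameterServerSketch(1, 3);
        // Worker 2 reports first, then worker 0; worker 1 has not reported yet.
        server.onUpdate(new Update(2, new double[] {1.0}));
        server.onUpdate(new Update(0, new double[] {0.5}));
        System.out.println(server.inbox(2).peek()[0] + " " + server.inbox(0).peek()[0]
            + " " + server.inbox(1).size());
    }
}
```

Worker 0 sees parameters that already include worker 2's update, while worker 1 receives nothing until it reports: each worker trains on whatever version of the model was current when it last reported, which is the staleness that asynchronous training accepts.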

...

Code Block
languagejava
import java.util.ArrayList;
import java.util.List;

import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.InputSelectable;
import org.apache.flink.streaming.api.operators.InputSelection;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;

public class SynchronousUnboundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> MODEL_UPDATE_OUTPUT_TAG = new OutputTag<double[]>("model_update") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams = Iterations.iterateUnboundedStreams(DataStreamList.of(initParameters), DataStreamList.of(dataset), (variableStreams, dataStreams) -> {
            DataStream<double[]> parameterUpdates = variableStreams.get(0);
            DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

            // The argument must match the parallelism of the Train operator below.
            SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction(10));
            DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                .broadcast()
                .connect(trainData)
                .transform(
                    "operator",
                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                    new TrainOperators(50))
                .setParallelism(10);
            return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOutput(MODEL_UPDATE_OUTPUT_TAG)));
        });

        DataStream<double[]> modelUpdates = resultStreams.get(0);
        modelUpdates.addSink(...);
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> {

        private final int numOfTrainTasks;

        private final double[] parameters = new double[N_DIM];
        private boolean initialized = false;
        private int numOfUpdatesReceived = 0;

        public ParametersCacheFunction(int numOfTrainTasks) {
            this.numOfTrainTasks = numOfTrainTasks;
        }

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);

            // The first record is the initial model from initParameters: publish it directly.
            if (!initialized) {
                initialized = true;
                output.collect(parameters);
                return;
            }

            numOfUpdatesReceived++;
            // Publish new parameters only after every Train task has sent its update.
            if (numOfUpdatesReceived == numOfTrainTasks) {
                output.collect(parameters);
                ctx.output(MODEL_UPDATE_OUTPUT_TAG, parameters);
                numOfUpdatesReceived = 0;
            }
        }
    }

    public static class TrainOperators extends AbstractStreamOperator<double[]>
        implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>, InputSelectable {

        private final int miniBatchSize;

        private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>();

        public TrainOperators(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
        }

        @Override
        public void processElement1(StreamRecord<double[]> parameters) {
            // Only selected once the mini-batch is full (see nextSelection()).
            calculateModelUpdate(parameters.getValue());
            miniBatch.clear();
        }

        @Override
        public void processElement2(StreamRecord<Tuple2<double[], Double>> trainSample) {
            miniBatch.add(trainSample.getValue());
        }

        @Override
        public InputSelection nextSelection() {
            // Read samples until the mini-batch is full, then wait for new parameters.
            if (miniBatch.size() < miniBatchSize) {
                return InputSelection.SECOND;
            } else {
                return InputSelection.FIRST;
            }
        }

        private void calculateModelUpdate(double[] parameters) {
            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : miniBatch) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(new StreamRecord<>(modelUpdate));
        }
    }
}
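The two mechanisms of the unbounded example, mini-batch buffering in the Train tasks and the count-based barrier in the parameter cache, can be checked in isolation with a plain-Java sketch. Samples are dealt round-robin to the Train tasks, a global step runs only once every task has a full mini-batch (the role of nextSelection()), and the cache publishes new parameters after one update from every task. The explicit learning rate, the negative-gradient update and all names are assumptions for illustration, not part of the proposal.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Plain-Java sketch (no Flink) of synchronous training on an unbounded stream:
 * Train tasks buffer a mini-batch before reading parameters, and the cache
 * publishes a new model only after one update from every Train task.
 * The learning rate and all names are illustrative assumptions.
 */
public class MiniBatchBarrierSketch {

    /** A Train task that buffers (x, y) samples until its mini-batch is full. */
    static final class TrainTask {
        final int miniBatchSize;
        final List<double[]> xs = new ArrayList<>();
        final List<Double> ys = new ArrayList<>();

        TrainTask(int miniBatchSize) { this.miniBatchSize = miniBatchSize; }

        void addSample(double[] x, double y) { xs.add(x); ys.add(y); }

        /** Mirrors nextSelection(): read parameters only once the mini-batch is full. */
        boolean readyForParameters() { return xs.size() >= miniBatchSize; }

        /** Returns -learningRate * gradient of the squared error over the batch, then clears it. */
        double[] computeUpdate(double[] w, double learningRate) {
            double[] update = new double[w.length];
            for (int i = 0; i < xs.size(); i++) {
                double[] x = xs.get(i);
                double diff = -ys.get(i);
                for (int d = 0; d < w.length; d++) diff += x[d] * w[d];
                for (int d = 0; d < w.length; d++) update[d] -= learningRate * x[d] * diff;
            }
            xs.clear();
            ys.clear();
            return update;
        }
    }

    /** The count-based barrier of ParametersCacheFunction. */
    static final class ParametersCache {
        final int numOfTrainTasks;
        final double[] parameters;
        final List<double[]> published = new ArrayList<>();
        int numOfUpdatesReceived = 0;

        ParametersCache(double[] initial, int numOfTrainTasks) {
            this.parameters = initial.clone();
            this.numOfTrainTasks = numOfTrainTasks;
        }

        void onUpdate(double[] update) {
            for (int d = 0; d < parameters.length; d++) parameters[d] += update[d];
            if (++numOfUpdatesReceived == numOfTrainTasks) {
                published.add(parameters.clone());
                numOfUpdatesReceived = 0;
            }
        }
    }

    /** Deals samples round-robin; a global step runs once every task has a full batch. */
    static ParametersCache run(double[][] xs, double[] ys, int numTasks, int miniBatchSize, double lr) {
        List<TrainTask> tasks = new ArrayList<>();
        for (int t = 0; t < numTasks; t++) tasks.add(new TrainTask(miniBatchSize));
        ParametersCache cache = new ParametersCache(new double[xs[0].length], numTasks);
        double[] current = cache.parameters.clone();
        for (int i = 0; i < xs.length; i++) {
            tasks.get(i % numTasks).addSample(xs[i], ys[i]);
            boolean allReady = true;
            for (TrainTask task : tasks) allReady &= task.readyForParameters();
            if (allReady) {
                // Every task trains against the same published snapshot.
                for (TrainTask task : tasks) cache.onUpdate(task.computeUpdate(current, lr));
                current = cache.published.get(cache.published.size() - 1).clone();
            }
        }
        return cache;
    }

    public static void main(String[] args) {
        int n = 400;
        double[][] xs = new double[n][];
        double[] ys = new double[n];
        for (int i = 0; i < n; i++) {
            xs[i] = new double[] {1.0 + (i % 3)};
            ys[i] = 3.0 * xs[i][0]; // every sample satisfies y = 3 * x
        }
        ParametersCache cache = run(xs, ys, 2, 2, 0.02);
        System.out.println(cache.published.size() + " models published, w = " + cache.parameters[0]);
    }
}
```

With 400 samples, two Train tasks and a mini-batch size of two, each global step consumes four samples, so the cache publishes exactly 100 models; samples left in an incomplete mini-batch never trigger a publication.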

...