
Page properties

Discussion thread: /9o56d7f094gdqc9mj28mwm9h4ffv02sx
Vote thread: https://lists.apache.org/thread/899vt2momfqpn65zmx6cq74o3qn41yf1
JIRA: FLINK-24642
Release: ml-2.0

This method is easier to use, but it also limits some possible optimizations.


Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).



[This FLIP proposal is joint work by Yun Gao, Dong Lin and Zhipeng Zhang.]

...

| Type | Bounded / Unbounded | Data Granularity | Synchronization Pattern | Support in the existing APIs | Support in the proposed API | Examples |
|---|---|---|---|---|---|---|
| Non-SGD-based | Bounded | Epoch | Mostly Synchronous | Yes | Yes | Offline K-Means, Apriori, Decision Tree, Random Walk |
| SGD-Based Synchronous Offline algorithm | Bounded | Batch → Epoch* | Synchronous | Yes | Yes | Linear Regression, Logistic Regression, Deep Learning algorithms |
| SGD-Based Asynchronous Offline algorithm | Bounded | Batch → Epoch* | Asynchronous | No | Yes | Same as the above |
| SGD-Based Synchronous Online algorithm | Unbounded | Batch | Synchronous | No | Yes | Online version of the above algorithm |
| SGD-Based Asynchronous Online algorithm | Unbounded | Batch | Asynchronous | Yes | Yes | Online version of the above algorithm |

...

It is important to note that users typically should not invoke IterationBody::process directly, because the model-variables expected by the iteration body are not the same as the initial-model-variables provided by the user. Instead, the model-variables are computed as the union of the feedback-model-variables (emitted by the iteration body) and the initial-model-variables (provided by the caller of the iteration body). To relieve users from creating this union operator, we have added a utility class (see Iterations) to run an iteration body with the user-provided inputs.
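To make the union concrete, here is a minimal single-threaded sketch in plain Java. It is not the Flink API: `IterationDriverSketch`, `run` and `halveAboveOne` are hypothetical names, and a real iteration unions two streams concurrently instead of running rounds in a loop. The body first sees the initial variables and then whatever it fed back itself, which is exactly the input a caller would have to assemble if it invoked `IterationBody::process` directly.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/**
 * Hypothetical, single-threaded model of how an iteration feeds its body (not the Flink API).
 * The body's input is the union of the initial variables (round 0) and the feedback
 * variables that the body emitted itself (later rounds).
 */
class IterationDriverSketch {

    /** Runs the body until it emits no feedback; returns every variable the body consumed. */
    static <T> List<T> run(List<T> initialVariables, Function<List<T>, List<T>> body) {
        List<T> consumed = new ArrayList<>();
        List<T> nextInput = new ArrayList<>(initialVariables);
        while (!nextInput.isEmpty()) {
            consumed.addAll(nextInput);
            nextInput = body.apply(nextInput); // feedback variables for the next round
        }
        return consumed;
    }

    /** Example body: halves every variable greater than 1 and feeds it back. */
    static List<Integer> halveAboveOne(List<Integer> variables) {
        List<Integer> feedback = new ArrayList<>();
        for (int v : variables) {
            if (v > 1) {
                feedback.add(v / 2);
            }
        }
        return feedback;
    }

    public static void main(String[] args) {
        // Initial variables {8} plus feedback {4, 2, 1}.
        System.out.println(run(Arrays.asList(8), IterationDriverSketch::halveAboveOne)); // prints [8, 4, 2, 1]
    }
}
```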

...

Code Block
languagejava
linenumberstrue
package org.apache.flink.iteration;

/**
 * The callbacks defined below will be invoked only if the operator instance which implements this interface is used
 * within an iteration body.
 */
@PublicEvolving
public interface IterationListener<T> {
    /**
     * This callback is invoked every time the epoch watermark of this operator increments. The initial epoch watermark
     * of an operator is -1.
     *
     * The epochWatermark is the maximum integer that meets this requirement: every record that arrives at the operator
     * going forward should have an epoch larger than the epochWatermark. See Java docs in Iterations for how epoch
     * is determined for records ingested into the iteration body and for records emitted by operators within the
     * iteration body.
     *
     * If all inputs are bounded, the maximum epoch of all records ingested into this operator is used as the
     * epochWatermark parameter for the last invocation of this callback.
     *
     * @param epochWatermark The incremented epoch watermark.
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector);

    /**
     * This callback is invoked after the execution of the iteration body has terminated.
     *
     * See Java doc of methods in Iterations for the termination conditions.
     *
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onIterationTermination(Context context, Collector<T> collector);

    /**
     * Information available in an invocation of the callbacks defined in the IterationListener.
     */
    interface Context {
        /**
         * Emits a record to the side output identified by the {@link OutputTag}.
         *
         * @param outputTag the {@code OutputTag} that identifies the side output to emit to.
         * @param value The record to emit.
         */
        <X> void output(OutputTag<X> outputTag, X value);
    }
}
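The epoch-watermark contract can be illustrated with a small plain-Java model. It assumes, as one plausible alignment scheme rather than the actual implementation, that each input channel reports when it has finished an epoch and that the operator's watermark is the minimum over its channels. Following the definition above (the largest integer below the epoch of every future record), the model starts at -1, since records with epoch 0 may still arrive; each time the minimum advances, a simulated `onEpochWatermarkIncremented` call is recorded. `EpochWatermarkTracker` and `onChannelEpochFinished` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical model of how an operator's epoch watermark could advance (not Flink ML code).
 * Assumption: every input channel reports when it has finished an epoch, and the operator's
 * watermark is the minimum over all of its channels.
 */
class EpochWatermarkTracker {
    private final int[] channelWatermarks; // last finished epoch per input channel
    private int watermark = -1;            // nothing finished yet: epoch-0 records may still arrive
    final List<Integer> callbacks = new ArrayList<>(); // arguments of the simulated callbacks

    EpochWatermarkTracker(int numChannels) {
        channelWatermarks = new int[numChannels];
        Arrays.fill(channelWatermarks, -1);
    }

    /** The channel promises to send no more records with an epoch <= finishedEpoch. */
    void onChannelEpochFinished(int channel, int finishedEpoch) {
        channelWatermarks[channel] = Math.max(channelWatermarks[channel], finishedEpoch);
        int min = Arrays.stream(channelWatermarks).min().getAsInt();
        if (min > watermark) {
            watermark = min;
            callbacks.add(watermark); // stands in for onEpochWatermarkIncremented(watermark, ...)
        }
    }

    public static void main(String[] args) {
        EpochWatermarkTracker tracker = new EpochWatermarkTracker(2);
        tracker.onChannelEpochFinished(0, 0); // channel 1 may still send epoch-0 records
        tracker.onChannelEpochFinished(1, 0); // now every channel is past epoch 0
        tracker.onChannelEpochFinished(0, 1);
        tracker.onChannelEpochFinished(1, 1);
        System.out.println(tracker.callbacks); // prints [0, 1]
    }
}
```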

...

Code Block
languagejava
linenumberstrue
package org.apache.flink.iteration;


/**
 * A helper class to create iterations. To construct an iteration, users are required to provide:
 *
 * <ul>
 *   <li>initVariableStreams: the initial values of the variable data streams, which are updated
 *       in each round.
 *   <li>dataStreams: the other data streams used inside the iteration, which are not updated.
 *   <li>iterationBody: specifies the subgraph that updates the variable streams and produces the outputs.
 * </ul>
 *
 * <p>The iteration body will be invoked with two parameters: the first parameter is a list of input
 * variable streams, which are created as the union of the initial variable streams and the
 * corresponding feedback variable streams (returned by the iteration body); the second parameter is
 * the data streams given to this method.
 *
 * <p>During the execution of the iteration body, each record involved in the iteration has an
 * epoch attached, which marks the progress of the iteration. The epoch is computed as follows:
 *
 * <ul>
 *   <li>All records in the initial variable streams and initial data streams have epoch = 0.
 *   <li>For any record emitted by an operator into a non-feedback stream, the epoch of the
 *       emitted record = the epoch of the input record that triggers this emission. If the record
 *       is emitted by onEpochWatermarkIncremented(), then the epoch of the record =
 *       epochWatermark.
 *   <li>For any record emitted by an operator into a feedback variable stream, the epoch of the
 *       emitted record = the epoch of the input record that triggers this emission + 1.
 * </ul>
 *
 * <p>The framework notifies operators and UDFs that implement {@link IterationListener} at the
 * end of each epoch.
 *
 * <p>The limitations on constructing the subgraph inside the iteration body are described in
 * {@link IterationBody}.
 *
 * <p>An example of an iteration:
 *
 * <pre>{@code
 * DataStreamList result = Iterations.iterateUnboundedStreams(
 *     DataStreamList.of(first, second),
 *     DataStreamList.of(third),
 *     (variableStreams, dataStreams) -> {
 *         ...
 *         return new IterationBodyResult(
 *             DataStreamList.of(firstFeedback, secondFeedback),
 *             DataStreamList.of(output));
 *     });
 * result.<Integer>get(0).addSink(...);
 * }</pre>
 */
public class Iterations {

    /**
     * This method uses an iteration body to process records in possibly unbounded data streams. The
     * iteration does not terminate if at least one of its inputs is unbounded. Otherwise, it
     * terminates after all the inputs are terminated and no more records are iterating.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateUnboundedStreams(
            DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {
        return null;
    }

    /**
     * This method uses an iteration body to process records in some bounded data streams
     * iteratively until a termination criterion is reached (e.g. the given number of rounds is
     * completed or no further variable update is needed). Records in non-replayed data streams
     * are not replayed, so the iteration body needs to cache those records in order to visit
     * them repeatedly.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}.
     * @param config The config for the iteration, e.g. whether the operators are re-created in
     *     each round.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateBoundedStreamsUntilTermination(
            DataStreamList initVariableStreams,
            ReplayableDataStreamList dataStreams,
            IterationConfig config,
            IterationBody body) {
        return null;
    }
}  
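The epoch rules in the `Iterations` Javadoc can be restated as four small functions. This is only an illustrative plain-Java sketch; `EpochRules` and its methods are hypothetical and not part of the proposed API. It assumes that a record emitted from `onEpochWatermarkIncremented` into a feedback stream counts the `epochWatermark` as its triggering epoch.

```java
/**
 * Hypothetical helper spelling out the epoch rules from the Iterations Javadoc
 * (not part of the proposed API).
 */
final class EpochRules {
    private EpochRules() {}

    /** Records in the initial variable streams and in the data streams. */
    static int ofInitialRecord() {
        return 0;
    }

    /** Record emitted into a non-feedback stream while processing an input record. */
    static int ofForwardedRecord(int inputEpoch) {
        return inputEpoch;
    }

    /** Record emitted from onEpochWatermarkIncremented(epochWatermark, ...). */
    static int ofCallbackRecord(int epochWatermark) {
        return epochWatermark;
    }

    /** Record emitted into a feedback variable stream. */
    static int ofFeedbackRecord(int triggeringEpoch) {
        return triggeringEpoch + 1;
    }

    public static void main(String[] args) {
        // One round: an initial record is forwarded, summarized at the end of epoch 0,
        // and fed back, so the summary re-enters the iteration body with epoch 1.
        int forwarded = ofForwardedRecord(ofInitialRecord());
        int summarized = ofCallbackRecord(forwarded);
        int fedBack = ofFeedbackRecord(summarized);
        System.out.println(fedBack); // prints 1
    }
}
```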

5) Introduce the forEachRound utility method.

forEachRound allows users to specify a sub-graph that executes in per-round mode, namely all of its operators are re-created for each round.

Code Block
languagejava
package org.apache.flink.iteration;

public interface IterationBody {

    // .... (other members elided)

    /**
     * Constructs a sub-graph inside the iteration body whose operators are re-created in each round.
     *
     * @param inputs The inputs of the subgraph.
     * @param perRoundSubBody The computational logic that should be executed per round.
     * @return The output of the subgraph.
     */
    static DataStreamList forEachRound(DataStreamList inputs, PerRoundSubBody perRoundSubBody) {
        return null;
    }

    /** The sub-graph inside the iteration body that should be executed per round. */
    interface PerRoundSubBody {

        DataStreamList process(DataStreamList input);
    }
}
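The difference between operators that live across all rounds and operators that are re-created per round (as inside `forEachRound`, or with `OperatorLifeCycle.PER_ROUND`) is easiest to see with a stateful operator. The plain-Java sketch below is a toy model, not Flink code; `LifecycleSketch`, `SummingOperator` and the local `Mode` enum are hypothetical stand-ins.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;

/** Hypothetical model of operator lifecycles across rounds (not the Flink runtime). */
class LifecycleSketch {
    /** Local stand-in for IterationConfig.OperatorLifeCycle. */
    enum Mode { ALL_ROUND, PER_ROUND }

    /** A toy stateful operator that outputs the running sum of everything it has seen. */
    static class SummingOperator {
        private long sum;

        long process(long value) {
            sum += value;
            return sum;
        }
    }

    /** Feeds round r the values inputsPerRound.get(r); returns the last output of each round. */
    static List<Long> run(Mode mode, List<List<Long>> inputsPerRound) {
        Supplier<SummingOperator> factory = SummingOperator::new;
        SummingOperator operator = factory.get();
        List<Long> lastOutputs = new ArrayList<>();
        for (int round = 0; round < inputsPerRound.size(); round++) {
            if (mode == Mode.PER_ROUND && round > 0) {
                operator = factory.get(); // re-created: the previous round's state is dropped
            }
            long last = 0;
            for (long value : inputsPerRound.get(round)) {
                last = operator.process(value);
            }
            lastOutputs.add(last);
        }
        return lastOutputs;
    }

    public static void main(String[] args) {
        List<List<Long>> rounds =
                Arrays.asList(Arrays.asList(1L, 2L), Arrays.asList(3L), Arrays.asList(4L));
        System.out.println(run(Mode.ALL_ROUND, rounds)); // prints [3, 6, 10]
        System.out.println(run(Mode.PER_ROUND, rounds)); // prints [3, 3, 4]
    }
}
```

With `ALL_ROUND` the running sum carries over between rounds, while per-round re-creation resets it, which is why per-round mode lets ordinary operators such as a windowed reduce compute round-scoped results.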


6) Add the DataStreamList and ReplayableDataStreamList classes.

...

Code Block
languagejava
linenumberstrue
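package org.apache.flink.iteration;

import org.apache.flink.streaming.api.datastream.DataStream;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class DataStreamList {
    // Creates a list holding the given data streams.
    public static DataStreamList of(DataStream<?>... streams) {...}

    // Returns the number of data streams in this list.
    public int size() {...}

    // Returns the data stream at the given index in this list.
    public <T> DataStream<T> get(int index) {...}
}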

public class ReplayableDataStreamList {

    private final List<DataStream<?>> replayedDataStreams;

    private final List<DataStream<?>> nonReplayedStreams;

    private ReplayableDataStreamList(
            List<DataStream<?>> replayedDataStreams, List<DataStream<?>> nonReplayedStreams) {
        this.replayedDataStreams = replayedDataStreams;
        this.nonReplayedStreams = nonReplayedStreams;
    }

    /** Creates a list whose data streams are replayed in each round. */
    public static ReplayedDataStreamList replay(DataStream<?>... dataStreams) {
        return new ReplayedDataStreamList(Arrays.asList(dataStreams));
    }

    /** Creates a list whose data streams are not replayed. */
    public static NonReplayedDataStreamList notReplay(DataStream<?>... dataStreams) {
        return new NonReplayedDataStreamList(Arrays.asList(dataStreams));
    }

    List<DataStream<?>> getReplayedDataStreams() {
        return Collections.unmodifiableList(replayedDataStreams);
    }

    List<DataStream<?>> getNonReplayedStreams() {
        return Collections.unmodifiableList(nonReplayedStreams);
    }

    // Public so that callers can chain andNotReplay(...) on the result of replay(...).
    public static class ReplayedDataStreamList extends ReplayableDataStreamList {

        public ReplayedDataStreamList(List<DataStream<?>> replayedDataStreams) {
            super(replayedDataStreams, Collections.emptyList());
        }

        public ReplayableDataStreamList andNotReplay(DataStream<?>... nonReplayedStreams) {
            return new ReplayableDataStreamList(
                    getReplayedDataStreams(), Arrays.asList(nonReplayedStreams));
        }
    }

    // Public so that callers can chain andReplay(...) on the result of notReplay(...).
    public static class NonReplayedDataStreamList extends ReplayableDataStreamList {

        public NonReplayedDataStreamList(List<DataStream<?>> nonReplayedDataStreams) {
            super(Collections.emptyList(), nonReplayedDataStreams);
        }

        public ReplayableDataStreamList andReplay(DataStream<?>... replayedStreams) {
            return new ReplayableDataStreamList(
                    Arrays.asList(replayedStreams), getNonReplayedStreams());
        }
    }
}
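The practical difference between `replay(...)` and `notReplay(...)` is what the iteration body receives in later rounds. The following plain-Java model assumes that replayed streams are re-emitted in every round and that non-replayed streams are emitted only once, in round 0, so a body that needs the data again must cache it. `ReplaySketch`, `deliveredInRound` and `visibleSizes` are hypothetical names used only for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Hypothetical model of what an iteration body receives from one bounded data stream in
 * each round (not the Flink runtime). Assumption: replayed streams are re-emitted every
 * round; non-replayed streams are emitted only once, in round 0.
 */
class ReplaySketch {

    static <T> List<T> deliveredInRound(List<T> stream, boolean replay, int round) {
        return (replay || round == 0) ? stream : Collections.<T>emptyList();
    }

    /** Size of the data a body can use in each round; it caches non-replayed input. */
    static <T> List<Integer> visibleSizes(List<T> stream, boolean replay, int rounds) {
        List<T> cache = new ArrayList<>();
        List<Integer> sizes = new ArrayList<>();
        for (int round = 0; round < rounds; round++) {
            List<T> delivered = deliveredInRound(stream, replay, round);
            if (replay) {
                sizes.add(delivered.size()); // the stream provides the data again
            } else {
                cache.addAll(delivered);     // only round 0 delivers anything
                sizes.add(cache.size());
            }
        }
        return sizes;
    }

    public static void main(String[] args) {
        List<String> data = Arrays.asList("a", "b", "c");
        System.out.println(deliveredInRound(data, true, 1).size());  // prints 3
        System.out.println(deliveredInRound(data, false, 1).size()); // prints 0
        System.out.println(visibleSizes(data, false, 3));            // prints [3, 3, 3]
    }
}
```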


...

7) Introduce the IterationConfig

IterationConfig allows users to specify the config for each iteration. Currently, users can only specify the default operator lifecycle inside the iteration, but the class keeps the API forward compatible if more options are added in the future.

Code Block
languagejava
package org.apache.flink.iteration;

/** The config for an iteration. */
public class IterationConfig {

    private final OperatorLifeCycle operatorRoundMode;

    private IterationConfig(OperatorLifeCycle operatorRoundMode) {
        this.operatorRoundMode = operatorRoundMode;
    }

    public static IterationConfigBuilder newBuilder() {

        return new IterationConfigBuilder();
    }

    public static class IterationConfigBuilder {

        private OperatorLifeCycle operatorRoundMode = OperatorLifeCycle.ALL_ROUND;

        private IterationConfigBuilder() {}

        public IterationConfigBuilder setOperatorRoundMode(OperatorLifeCycle operatorRoundMode) {
            this.operatorRoundMode = operatorRoundMode;
            return this;
        }

        public IterationConfig build() {
            return new IterationConfig(operatorRoundMode);
        }
    }

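    public enum OperatorLifeCycle {
        /** Operators are created once and kept alive across all rounds. */
        ALL_ROUND,

        /** Operators are re-created in each round. */
        PER_ROUND
    }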
}


8) Deprecate the existing DataStream::iterate() and the DataStream::iterate(long maxWaitTimeMillis) methods.

We plan to remove both methods after the APIs added in this doc are ready for production use. This change is needed to decouple the iteration-related APIs from the Flink core runtime, so that we can keep the Flink core runtime as simple and maintainable as possible.

Example Usage

This section shows how commonly used ML algorithms can be implemented with the iteration API.

Offline Algorithms

We first show the usage of the bounded iteration with the linear regression case: the model is Y = XA, and we would like to obtain the best estimate of A with the SGD algorithm. For simplicity, we assume the parameters can be held in the memory of one task.

The job graph of the algorithm is shown in Figure 2: in each round, we use the latest parameters to calculate the update to the parameters: ΔA = ∑(Y - XA)X. To achieve this, the Parameters vertex broadcasts the latest parameters to the Train vertex. Each subtask of the Train vertex holds a part of the dataset. Following the spirit of SGD, it samples a small batch of training records and calculates the update with the above equation. Then the Reducer merges ΔA from all the Train subtasks and emits the reduced value to the Parameters vertex to update the parameters. The Reducer is required because the feedback streams must have the same parallelism as the corresponding initial variable streams.
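The update formula is the negative gradient of the squared loss. A short derivation, writing x_i for the i-th row of X, y_i for the i-th label, B for the sampled mini-batch, and adding a learning rate η that the text leaves implicit (an assumption of this sketch):

```latex
L(A) = \frac{1}{2} \sum_{i} \left( y_i - x_i A \right)^2,
\qquad
\nabla_A L = -\sum_{i} \left( y_i - x_i A \right) x_i^{\top}

\Delta A = \sum_{i \in B} \left( y_i - x_i A \right) x_i^{\top},
\qquad
A \leftarrow A + \eta \, \Delta A
```

Because the sum over B splits into per-subtask sums, each Train subtask can compute its share of ΔA independently, and the Reducer only needs to add those shares up.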



Figure 2. The JobGraph for the offline training of the linear regression case.


We start with synchronous training. Synchronous training requires that the updates from all subtasks of the Train vertex are merged before the next round of training starts. This can be achieved by emitting the next round of parameters only at the end of each round. The code is shown as follows:


Code Block
languagejava
linenumberstrue
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationConfig;
import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;

import java.util.ArrayList;
import java.util.List;

// loadParameters(), loadDataSet(), sample() and ArrayUtils are hypothetical helpers that are not shown.
public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
    private static final int N_EPOCH = 5;
    private static final int N_BATCH_PER_EPOCH = 10;
    // An OutputTag needs an id; the anonymous subclass ({}) keeps the generic type information.
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG =
            new OutputTag<double[]>("final_model") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams =
            Iterations.iterateBoundedStreamsUntilTermination(
                DataStreamList.of(initParameters),
                ReplayableDataStreamList.notReplay(dataset),
                IterationConfig.newBuilder().setOperatorRoundMode(OperatorLifeCycle.ALL_ROUND).build(),
                (variableStreams, dataStreams) -> {
                    DataStream<double[]> parameterUpdates = variableStreams.get(0);
                    // Not named "dataset": a lambda cannot redeclare a local variable of main().
                    DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

                    SingleOutputStreamOperator<double[]> parameters =
                        parameterUpdates.process(new ParametersCacheFunction()).setParallelism(1);
                    DataStream<double[]> modelUpdate = parameters
                        .broadcast()
                        .connect(trainData)
                        .process(new TrainFunction())
                        .setParallelism(10)
                        .process(new ReduceFunction())
                        .setParallelism(1);

                    return new IterationBodyResult(
                        DataStreamList.of(modelUpdate),
                        DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
                });

        // DataStreamList is indexed by position; the final model is its only element.
        DataStream<double[]> finalModel = resultStreams.get(0);
        finalModel.print();
    }

    /** Keeps the latest parameters and emits them once per round. */
    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
            implements IterationListener<double[]> {

        private final double[] parameters = new double[N_DIM];

        @Override
        public void processElement(
                double[] update,
                ProcessFunction<double[], double[]>.Context ctx,
                Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            // Emitting only at the end of a round is what makes the training synchronous.
            if (epochWatermark < N_EPOCH * N_BATCH_PER_EPOCH) {
                // Emit a copy, since `parameters` keeps being mutated.
                collector.collect(parameters.clone());
            }
        }

        @Override
        public void onIterationTermination(IterationListener.Context context, Collector<double[]> collector) {
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
        }
    }

    /** Caches its part of the dataset and computes one model update per round. */
    public static class TrainFunction
            extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]>
            implements IterationListener<double[]> {

        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        private double[] firstRoundCachedParameter;
        // Set at epoch watermark 0, i.e. once every epoch-0 record (the whole dataset) has arrived.
        private boolean datasetFullyCached = false;

        @Override
        public void processElement1(
                double[] parameter,
                CoProcessFunction<double[], Tuple2<double[], Double>, double[]>.Context context,
                Collector<double[]> output) {
            if (!datasetFullyCached) {
                // Round 0: the dataset may still be arriving, so defer the computation.
                firstRoundCachedParameter = parameter;
                return;
            }
            calculateModelUpdate(parameter, output);
        }

        @Override
        public void processElement2(
                Tuple2<double[], Double> trainSample,
                CoProcessFunction<double[], Tuple2<double[], Double>, double[]>.Context context,
                Collector<double[]> output) {
            dataset.add(trainSample);
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            if (epochWatermark == 0) {
                datasetFullyCached = true;
                calculateModelUpdate(firstRoundCachedParameter, collector);
                firstRoundCachedParameter = null;
            }
        }

        @Override
        public void onIterationTermination(IterationListener.Context context, Collector<double[]> collector) {}

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            // sample() draws a mini-batch from the cached dataset.
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                // dA = sum((y - x.A) * x); ArrayUtils.muladd is assumed to return the dot product x.A.
                double diff = record.f1 - ArrayUtils.muladd(record.f0, parameters);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }

    /** Sums the updates of all Train subtasks and emits the sum once per round. */
    public static class ReduceFunction extends ProcessFunction<double[], double[]>
            implements IterationListener<double[]> {

        private double[] mergedValue;

        @Override
        public void processElement(
                double[] update,
                ProcessFunction<double[], double[]>.Context context,
                Collector<double[]> output) {
            mergedValue = (mergedValue == null) ? update.clone() : ArrayUtils.add(mergedValue, update);
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            // Emit only if this round produced updates; an all-zero feedback record emitted in
            // every round would keep the iteration from terminating.
            if (mergedValue != null) {
                collector.collect(mergedValue);
                mergedValue = null;
            }
        }

        @Override
        public void onIterationTermination(IterationListener.Context context, Collector<double[]> collector) {}
    }
}
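Stripped of Flink, the synchronous data flow in Figure 2 reduces to a loop in which every Train subtask computes its share of ΔA from the same parameter snapshot, and the parameters change only once per round. The plain-Java simulation below is a sketch under simplifying assumptions, not the proposal's implementation: each subtask uses its whole shard instead of sampling, the learning rate is explicit, and the merged sum is divided by the dataset size. `SyncSgdSimulation`, `trainSubtask` and `train` are hypothetical names.

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Hypothetical single-process simulation of synchronous SGD (not Flink code):
 * Parameters -> (broadcast) -> Train subtasks -> Reducer -> (feedback) -> Parameters.
 */
class SyncSgdSimulation {

    /** Returns the dot product x . a. */
    static double dot(double[] x, double[] a) {
        double s = 0;
        for (int j = 0; j < x.length; j++) {
            s += x[j] * a[j];
        }
        return s;
    }

    /** One Train subtask: dA_s = sum over its shard of (y - x.A) * x. */
    static double[] trainSubtask(double[][] xs, double[] ys, int shard, int numShards, double[] a) {
        double[] delta = new double[a.length];
        for (int i = shard; i < xs.length; i += numShards) {
            double residual = ys[i] - dot(xs[i], a);
            for (int j = 0; j < a.length; j++) {
                delta[j] += residual * xs[i][j];
            }
        }
        return delta;
    }

    /** Runs `rounds` synchronous rounds starting from all-zero parameters. */
    static double[] train(double[][] xs, double[] ys, int numShards, int rounds, double learningRate) {
        double[] a = new double[xs[0].length];
        for (int round = 0; round < rounds; round++) {
            // Reducer: every subtask uses the same snapshot `a`; updates are only summed here.
            double[] merged = new double[a.length];
            for (int s = 0; s < numShards; s++) {
                double[] delta = trainSubtask(xs, ys, s, numShards, a);
                for (int j = 0; j < a.length; j++) {
                    merged[j] += delta[j];
                }
            }
            // Parameters: apply the merged update once per round (the synchronization barrier).
            for (int j = 0; j < a.length; j++) {
                a[j] += learningRate * merged[j] / xs.length;
            }
        }
        return a;
    }

    public static void main(String[] args) {
        double[] trueA = {2.0, -3.0};
        Random random = new Random(42);
        int n = 200;
        double[][] xs = new double[n][2];
        double[] ys = new double[n];
        for (int i = 0; i < n; i++) {
            xs[i][0] = random.nextDouble() * 2 - 1;
            xs[i][1] = random.nextDouble() * 2 - 1;
            ys[i] = dot(xs[i], trueA); // noise-free labels
        }
        // Should be very close to [2.0, -3.0].
        System.out.println(Arrays.toString(train(xs, ys, 10, 200, 0.5)));
    }
}
```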

Here we use a specialized reduce function that handles the boundary of each round. However, outside of the iteration the same reduction could already be expressed with windowAll().reduce(). To reuse such operators inside the iteration, we can use the forEachRound utility method:

Code Block
languagejava
DataStreamList resultStreams =
    Iterations.iterateBoundedStreamsUntilTermination(
        DataStreamList.of(initParameters),
        ReplayableDataStreamList.notReplay(dataset),
        IterationConfig.newBuilder().setOperatorRoundMode(OperatorLifeCycle.ALL_ROUND).build(),
        (variableStreams, dataStreams) -> {
            DataStream<double[]> parameterUpdates = variableStreams.get(0);
            DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

            SingleOutputStreamOperator<double[]> parameters =
                parameterUpdates.process(new ParametersCacheFunction()).setParallelism(1);
            DataStream<double[]> modelUpdate = parameters
                .broadcast()
                .connect(trainData)
                .process(new TrainFunction())
                .setParallelism(10);

            // Operators created inside forEachRound are re-created in each round, so the
            // window below only ever sees the updates of the current round.
            DataStream<double[]> reduced = IterationBody.forEachRound(
                    DataStreamList.of(modelUpdate),
                    streams -> DataStreamList.of(
                        streams.<double[]>get(0)
                            // Assumed: a window assigner that fires once when the round's
                            // input ends, e.g. EndOfStreamWindows from Flink ML.
                            .windowAll(EndOfStreamWindows.get())
                            .reduce((x, y) -> ArrayUtils.add(x, y))))
                .<double[]>get(0);

            // The reduced (not the raw) updates are fed back to the Parameters vertex.
            return new IterationBodyResult(
                DataStreamList.of(reduced),
                DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
        });

This is very helpful in complex scenarios, since it allows reusing the existing DataStream and Table operators inside the iteration.
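Whichever way it is expressed, the reducer in this example computes a round-scoped sum: updates are grouped by the epoch they belong to and added element-wise. The plain-Java sketch below models only that result. `PerRoundReduceSketch` and `TaggedUpdate` are hypothetical, and in the real job the epoch is implicit in when a record arrives rather than stored in the record.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * Hypothetical model of a round-scoped reduce (not Flink code): updates are grouped by epoch
 * and summed element-wise, which is what a reducer that flushes once per round emits.
 */
class PerRoundReduceSketch {

    /** An update tagged with the epoch (round) it was produced in. */
    static final class TaggedUpdate {
        final int epoch;
        final double[] values;

        TaggedUpdate(int epoch, double... values) {
            this.epoch = epoch;
            this.values = values;
        }
    }

    static Map<Integer, double[]> reducePerRound(List<TaggedUpdate> updates) {
        Map<Integer, double[]> sums = new TreeMap<>();
        for (TaggedUpdate update : updates) {
            double[] sum = sums.computeIfAbsent(update.epoch, e -> new double[update.values.length]);
            for (int j = 0; j < sum.length; j++) {
                sum[j] += update.values[j];
            }
        }
        return sums;
    }

    public static void main(String[] args) {
        List<TaggedUpdate> updates = Arrays.asList(
                new TaggedUpdate(0, 1.0, 1.0),
                new TaggedUpdate(0, 2.0, 0.0),
                new TaggedUpdate(1, 5.0, 5.0));
        Map<Integer, double[]> sums = reducePerRound(updates);
        System.out.println(Arrays.toString(sums.get(0))); // prints [3.0, 1.0]
        System.out.println(Arrays.toString(sums.get(1))); // prints [5.0, 5.0]
    }
}
```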



If we instead want to do asynchronous training, we need to make the following changes:

...

Code Block
languagejava
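public static class ParametersCacheFunction
        extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>
        implements IterationListener<Tuple2<Integer, double[]>> {

    private final double[] parameters = new double[N_DIM];

    @Override
    public void processElement(
            Tuple2<Integer, double[]> update,
            ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>.Context ctx,
            Collector<Tuple2<Integer, double[]>> output) {
        // f0 is the index of the Train subtask that sent the update, f1 is its dA.
        ArrayUtils.addWith(parameters, update.f1);
        // No end-of-round barrier: reply at once with a copy of the latest parameters,
        // addressed to the subtask that sent the update.
        output.collect(new Tuple2<>(update.f0, parameters.clone()));
    }

    @Override
    public void onEpochWatermarkIncremented(
            int epochWatermark,
            IterationListener.Context context,
            Collector<Tuple2<Integer, double[]>> collector) {}

    @Override
    public void onIterationTermination(
            IterationListener.Context context, Collector<Tuple2<Integer, double[]>> collector) {
        context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
    }
}

public class AsynchronousBoundedLinearRegression {
    
    ...

    DataStreamList resultStreams =
        Iterations.iterateBoundedStreamsUntilTermination(
            DataStreamList.of(initParameters),
            ReplayableDataStreamList.notReplay(dataset),
            IterationConfig.newBuilder().build(),
            (variableStreams, dataStreams) -> {
                DataStream<Tuple2<Integer, double[]>> parameterUpdates = variableStreams.get(0);
                DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

                SingleOutputStreamOperator<Tuple2<Integer, double[]>> parameters =
                    parameterUpdates.process(new ParametersCacheFunction()).setParallelism(1);
                // Route each parameter record to the Train subtask whose index is in f0. The
                // TrainFunction used here (not shown) is assumed to tag its updates the same way.
                DataStream<Tuple2<Integer, double[]>> modelUpdate = parameters
                    .partitionCustom((Integer key, int numPartitions) -> key % numPartitions, update -> update.f0)
                    .connect(trainData)
                    .process(new TrainFunction())
                    .setParallelism(10);

                return new IterationBodyResult(
                    DataStreamList.of(modelUpdate),
                    DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
            });

    ...
}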
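For comparison with the synchronous loop, the essential change in the asynchronous version is that each Train subtask's update is applied immediately, so the next subtask already trains on the new parameters. The plain-Java simulation below visits subtasks round-robin, which is just one possible interleaving; real asynchronous training also lets subtasks run on stale parameters. `AsyncSgdSimulation` and its methods are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical single-process simulation of asynchronous SGD (not Flink code). Each Train
 * subtask's update is applied as soon as it arrives; there is no end-of-round barrier.
 * Subtasks are visited round-robin, which is only one of many possible interleavings.
 */
class AsyncSgdSimulation {

    /** Returns the dot product x . a. */
    static double dot(double[] x, double[] a) {
        double s = 0;
        for (int j = 0; j < x.length; j++) {
            s += x[j] * a[j];
        }
        return s;
    }

    /** dA_s = sum over shard s of (y - x.A) * x, using the parameters the subtask holds now. */
    static double[] shardUpdate(double[][] xs, double[] ys, int shard, int numShards, double[] a) {
        double[] delta = new double[a.length];
        for (int i = shard; i < xs.length; i += numShards) {
            double residual = ys[i] - dot(xs[i], a);
            for (int j = 0; j < a.length; j++) {
                delta[j] += residual * xs[i][j];
            }
        }
        return delta;
    }

    static double[] train(double[][] xs, double[] ys, int numShards, int rounds, double learningRate) {
        double[] a = new double[xs[0].length];
        for (int round = 0; round < rounds; round++) {
            for (int s = 0; s < numShards; s++) {
                double[] delta = shardUpdate(xs, ys, s, numShards, a);
                // Applied immediately: no Reducer waits for the other subtasks.
                for (int j = 0; j < a.length; j++) {
                    a[j] += learningRate * delta[j] / xs.length;
                }
            }
        }
        return a;
    }

    public static void main(String[] args) {
        double[][] xs = {{1, 0}, {0, 1}, {1, 1}, {1, -1}};
        double[] ys = {2, -3, -1, 5}; // generated from A = [2, -3]
        // Should be very close to [2.0, -3.0].
        System.out.println(Arrays.toString(train(xs, ys, 2, 300, 0.5)));
    }
}
```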

...

Code Block
languagejava
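import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.InputSelectable;
import org.apache.flink.streaming.api.operators.InputSelection;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;

import java.util.ArrayList;
import java.util.List;

// loadParameters(), loadDataSet() and ArrayUtils are hypothetical helpers that are not shown.
public class SynchronousUnboundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> MODEL_UPDATE_OUTPUT_TAG =
            new OutputTag<double[]>("model_update") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams = Iterations.iterateUnboundedStreams(
                DataStreamList.of(initParameters),
                DataStreamList.of(dataset),
                (variableStreams, dataStreams) -> {
                    DataStream<double[]> parameterUpdates = variableStreams.get(0);
                    DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

                    SingleOutputStreamOperator<double[]> parameters =
                            parameterUpdates.process(new ParametersCacheFunction(10)).setParallelism(1);
                    DataStream<double[]> modelUpdate = parameters
                            .broadcast()
                            .connect(trainData)
                            .transform(
                                    "operator",
                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                                    new TrainOperators(50))
                            .setParallelism(10);
                    return new IterationBodyResult(
                            DataStreamList.of(modelUpdate),
                            DataStreamList.of(parameters.getSideOutput(MODEL_UPDATE_OUTPUT_TAG)));
                });

        DataStream<double[]> finalModel = resultStreams.get(0);
        finalModel.addSink(...);
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> {

        private final int numOfTrainTasks;

        private int numOfUpdatesReceived = 0;
        private boolean initialized = false;
        private final double[] parameters = new double[N_DIM];

        public ParametersCacheFunction(int numOfTrainTasks) {
            this.numOfTrainTasks = numOfTrainTasks;
        }

        @Override
        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);

            // The first record is the initial model: emit it right away, otherwise the Train
            // tasks would wait for parameters forever and no update would ever arrive.
            if (!initialized) {
                initialized = true;
                emit(ctx, output);
                return;
            }

            numOfUpdatesReceived++;
            if (numOfUpdatesReceived == numOfTrainTasks) {
                emit(ctx, output);
                numOfUpdatesReceived = 0;
            }
        }

        private void emit(Context ctx, Collector<double[]> output) {
            output.collect(parameters.clone());
            // Also publish every new model version as an output of the iteration.
            ctx.output(MODEL_UPDATE_OUTPUT_TAG, parameters.clone());
        }
    }

    public static class TrainOperators extends AbstractStreamOperator<double[]>
            implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>, InputSelectable {

        private final int miniBatchSize;

        private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>();

        public TrainOperators(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
        }

        @Override
        public void processElement1(StreamRecord<double[]> element) throws Exception {
            // Parameters are only read once a full mini-batch is buffered (see nextSelection()).
            calculateModelUpdate(element.getValue());
            miniBatch.clear();
        }

        @Override
        public void processElement2(StreamRecord<Tuple2<double[], Double>> element) throws Exception {
            miniBatch.add(element.getValue());
        }

        @Override
        public InputSelection nextSelection() {
            if (miniBatch.size() < miniBatchSize) {
                return InputSelection.SECOND;
            } else {
                return InputSelection.FIRST;
            }
        }

        private void calculateModelUpdate(double[] parameters) {
            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : miniBatch) {
                // dA = sum((y - x.A) * x); ArrayUtils.muladd is assumed to return the dot product x.A.
                double diff = record.f1 - ArrayUtils.muladd(record.f0, parameters);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(new StreamRecord<>(modelUpdate));
        }
    }
}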
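The `nextSelection()` logic of `TrainOperators` boils down to a small state machine: keep reading samples until the mini-batch is full, then read exactly one parameter record. Below is a plain-Java model of that control flow, not Flink code, in which a hypothetical `Input` enum stands in for `InputSelection.FIRST` and `InputSelection.SECOND`.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical model of the input-selection logic in TrainOperators (not Flink code): read
 * training samples until a mini-batch is full, then read one parameter record, compute an
 * update from the buffered samples, and start a new mini-batch.
 */
class MiniBatchSelectionSketch {
    enum Input { PARAMETERS, SAMPLES } // stand-ins for InputSelection.FIRST / SECOND

    private final int miniBatchSize;
    private final List<double[]> miniBatch = new ArrayList<>();

    MiniBatchSelectionSketch(int miniBatchSize) {
        this.miniBatchSize = miniBatchSize;
    }

    Input nextSelection() {
        return miniBatch.size() < miniBatchSize ? Input.SAMPLES : Input.PARAMETERS;
    }

    void onSample(double[] sample) {
        miniBatch.add(sample);
    }

    /** Consumes the buffered mini-batch; returns how many samples the update used. */
    int onParameters(double[] parameters) {
        int used = miniBatch.size(); // a real operator would compute dA from these samples here
        miniBatch.clear();
        return used;
    }

    public static void main(String[] args) {
        MiniBatchSelectionSketch op = new MiniBatchSelectionSketch(2);
        List<Input> trace = new ArrayList<>();
        for (int step = 0; step < 6; step++) {
            Input next = op.nextSelection();
            trace.add(next);
            if (next == Input.SAMPLES) {
                op.onSample(new double[] {step});
            } else {
                op.onParameters(new double[] {0});
            }
        }
        System.out.println(trace); // prints [SAMPLES, SAMPLES, PARAMETERS, SAMPLES, SAMPLES, PARAMETERS]
    }
}
```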

...