Versions Compared


  • This line was added.
  • This line was removed.
  • Formatting was changed.


It is important to note that users typically should not invoke IterationBody::process directly, because the model-variables expected by the iteration body are not the same as the initial-model-variables provided by the user. Instead, the model-variables are computed as the union of the feedback-model-variables (emitted by the iteration body) and the initial-model-variables (provided by the caller of the iteration body). To relieve users from creating this union operator, we have added a utility class (see Iterations) to run an iteration body with the user-provided inputs.


Code Block
package org.apache.flink.iteration

 * The callbacks defined below will be invoked only if the operator instance which implements this interface is used
 * within an iteration body.
public interface IterationListener<T> {
     * This callback is invoked every time the epoch watermark of this operator increments. The initial epoch watermark
     * of an operator is 0.
     * The epochWatermark is the maximum integer that meets this requirement: every record that arrives at the operator
     * going forward should have an epoch larger than the epochWatermark. See Java docs in IterationUtilsIterations for how epoch
     * is determined for records ingested into the iteration body and for records emitted by operators within the
     * iteration body.
     * If all inputs are bounded, the maximum epoch of all records ingested into this operator is used as the
     * epochWatermark parameter for the last invocation of this callback.
     * @param epochWatermark The incremented epoch watermark.
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
    void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector);

     * This callback is invoked after the execution of the iteration body has terminated.
     * See Java doc of methods in IterationUtilsIterations for the termination conditions.
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
    void onIterationTermination(Context context, Collector<T> collector);

     * Information available in an invocation of the callbacks defined in the IterationProgressListener.
    interface Context {
         * Emits a record to the side output identified by the {@link OutputTag}.
         * @param outputTag the {@code OutputTag} that identifies the side output to emit to.
         * @param value The record to emit.
        <X> void output(OutputTag<X> outputTag, X value);


Code Block
/**
 * Example: synchronous linear regression on bounded inputs. A single ParametersCacheFunction
 * instance (parallelism 1) holds the parameters. The model updates are presumably fed back as the
 * next variable stream, and the final model leaves the iteration through a side output.
 *
 * NOTE(review): this is an excerpt from a design document and is not compilable as shown
 * (closing braces and some lines are missing). The notes below flag problems visible in it.
 */
public class SynchronousBoundedLinearRegression {
    // Dimensionality of every parameter / update array.
    private static final int N_DIM = 50;
    private static final int N_EPOCH = 5;
    private static final int N_BATCH_PER_EPOCH = 10;
    // Side output carrying the final model out of the iteration body.
    // NOTE(review): not valid Java as written. An anonymous OutputTag subclass needs a constructor
    // call with an id, e.g. new OutputTag<double[]>("final_model") {}. The id is presumably
    // "final_model", given the lookup at the end of main — confirm.
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>{};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        // NOTE(review): the lambda below is the iteration body, but the call that runs it is
        // missing from this excerpt. That call is presumably
        // Iterations.iterateBoundedStreamsUntilTermination(DataStreamList.of(initParameters),
        // DataStreamList.of(dataset), ...), as in AsynchronousBoundedLinearRegression.
        DataStreamList resultStreams = 
				(variableStreams, dataStreams) -> {
                DataStream<double[]> parameterUpdates = variableStreams.get(0);
                // NOTE(review): `dataset` is already a local of main. Java does not allow a lambda
                // body to redeclare an enclosing local, so this needs a different name.
                	DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

                	SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction());
                // NOTE(review): TrainFunction is a CoProcessFunction over (parameters, samples), so a
                // .connect(dataset) is presumably missing before coProcess. This statement is also
                // missing its terminating semicolon.
                	DataStream<double[]> modelUpdate = parameters.setParallelism(1)
    	                .coProcess(new TrainFunction())
						.process(new ReduceFunction())
                // Presumably: modelUpdate is fed back as the next variable stream, and the
                // final-model side output is the iteration's result.
                // NOTE(review): Flink's SingleOutputStreamOperator method is getSideOutput — confirm.
                	return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG)));
        // NOTE(review): elsewhere, streams are fetched from a DataStreamList by index (e.g.
        // variableStreams.get(0)). Confirm that a lookup by name such as "final_model" is supported.
        DataStream<double[]> finalModel = resultStreams.get("final_model");

    /**
     * Holds the model parameters. Every incoming update is added into a single N_DIM array. When
     * the iteration ends, that array is emitted to FINAL_MODEL_OUTPUT_TAG.
     */
    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
        implements IterationListener<double[]> {  
        private final double[] parameters = new double[N_DIM];

        // NOTE(review): `O` is not a declared type parameter; presumably Collector<double[]>.
        public void processElement(double[] update, Context ctx, Collector<O> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);

        // NOTE(review): `T` is undeclared here (presumably Collector<double[]>). Also, a method
        // implementing an interface method must be public.
        void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) {
            // NOTE(review): the branch body is missing from this excerpt. It presumably emits the
            // cached parameters for the next round while fewer than N_EPOCH * N_BATCH_PER_EPOCH
            // rounds have completed, which would bound the iteration — confirm.
            if (epochWatermark < N_EPOCH * N_BATCH_PER_EPOCH) {

        // NOTE(review): this signature does not match
        // IterationListener#onIterationTermination(Context, Collector<T>). It looks like a leftover
        // from an earlier revision of the listener API.
        public void onIterationEnd(int[] round, Context context) {
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);

    /**
     * Keeps the training samples received on the second input. For each parameter record received
     * on the first input, it computes a model update from a sample of those training samples.
     *
     * NOTE(review): several lines of this class are missing from the excerpt (see notes below).
     */
    public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements IterationListener<double[]> {

        // Training samples; presumably appended to in processElement2 (its body is not shown).
        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        // Parameters received in round 0, held until the epoch watermark reaches 0.
        private double[] firstRoundCachedParameter;

        // NOTE(review): the round querier (int[] rounds) looks like a leftover from an earlier
        // revision of the API. IterationListener tracks a single int epoch instead — confirm.
        private Supplier<int[]> recordRoundQuerier;

        public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
            this.recordRoundQuerier = querier;

        // NOTE(review): `O` is not a declared type parameter; presumably Collector<double[]>.
        public void processElement1(double[] parameter, Context context, Collector<O> output) {
            int[] round = recordRoundQuerier.get();
            // Round-0 parameters are cached, and their update is computed in
            // onEpochWatermarkIncremented(0). Presumably this is because the dataset is not fully
            // loaded until then.
            // NOTE(review): as shown, execution then falls through to calculateModelUpdate anyway.
            // An early return or else branch was presumably lost from this excerpt — confirm.
            if (round[0] == 0) {
                firstRoundCachedParameter = parameter;

            calculateModelUpdate(parameter, output);

        // NOTE(review): body missing; presumably dataset.add(trainSample).
        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<O> output) {

        // NOTE(review): `T` is undeclared, and the method must be public to implement
        // IterationListener. Also, `output` is not in scope here; the collector parameter is named
        // `collector`.
        void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) {
            // Once epoch 0 is complete, compute the deferred first-round update.
            if (epochWatermark == 0) {
                calculateModelUpdate(firstRoundCachedParameter, output);
                firstRoundCachedParameter = null;

        // For each (x, y) record in a sample of the cached dataset, accumulates
        // x * (muladd(x, parameters) - y) into a fresh N_DIM array.
        // Assumes ArrayUtils.muladd is the dot product of its arguments — TODO confirm. Under that
        // assumption, this is the least-squares gradient.
        // NOTE(review): the excerpt ends before modelUpdate is emitted (presumably
        // output.collect(modelUpdate)). `sample` is not shown.
        private void calculateModelUpdate(double[] parameters, Collector<O> output) {
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));

	// Combines every incoming array into mergedValue via ArrayUtils.add (presumably an element-wise
	// sum — confirm). Replaces it with a fresh N_DIM array whenever the epoch watermark increments.
	// NOTE(review): this class neither extends a ProcessFunction nor implements IterationListener,
	// although it is used via .process(new ReduceFunction()) and defines listener-style callbacks.
	// `O` and `T` are undeclared. As shown, the merged value is discarded on reset; a
	// collector.collect(mergedValue) before the reset is presumably missing.
	public static class ReduceFunction {
		private double[] mergedValue = ArrayUtils.newArray(N_DIM);

	 	public void processElement(double[] parameter, Context context, Collector<O> output) {
            mergedValue = ArrayUtils.add(mergedValue, parameter);

	 	void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) {
			mergedValue = ArrayUtils.newArray(N_DIM);


Code Block
 			// Variant of the iteration body that merges the model updates with
 			// PerRoundGraphBuilder.forEachRound and an all-window reduce, instead of a ReduceFunction.
 			// NOTE(review): the Iterations call that wraps this lambda is missing from the excerpt.
 			DataStreamList resultStreams = 
				(variableStreams, dataStreams) -> {
                DataStream<double[]> parameterUpdates = variableStreams.get(0);
                	DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

                	SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction());
                // NOTE(review): a .connect(dataset) before coProcess and the terminating semicolon
                // are presumably missing.
                	DataStream<double[]> modelUpdate = parameters.setParallelism(1)
    	                .coProcess(new TrainFunction())
                // Reduces the update stream with ArrayUtils.add, presumably once per round.
                // NOTE(review): Flink's windowAll requires a WindowAssigner argument, and the
                // lambda's closing "});" is missing.
			     	DataStream<double[]> reduced = PerRoundGraphBuilder.forEachRound(DataStreamList.of(modelUpdate), streams -> {
						return streams.<double[]>get(0).windowAll().reduce((x, y) -> ArrayUtils.add(x, y));
                // NOTE(review): `reduced` is computed but never used. The unreduced modelUpdate is
                // what gets fed back; DataStreamList.of(reduced) was presumably intended.
                	return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG)));


Code Block
// Asynchronous variant of the parameter cache. Every update arrives tagged with an Integer in f0.
// That tag is presumably the id of the train task that produced the update, since the
// asynchronous job partitions on f0. The accumulated parameters are sent back immediately with
// the same tag.
// NOTE(review): BoundedIterationProgressListener does not match the IterationListener interface
// and looks like a leftover name. onIterationEnd likewise does not match
// onIterationTermination(Context, Collector<T>).
public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>
    implements BoundedIterationProgressListener<double[]> {  
    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        // Suppose we have a util to add the second array to the first.
        // NOTE(review): addWith is given the whole Tuple2; presumably update.f1 is meant.
        ArrayUtils.addWith(parameters, update);
        // NOTE(review): missing semicolon. The same mutable `parameters` array is emitted every
        // time, so a downstream consumer holding a reference (e.g. chained with object reuse) may
        // see later mutations. Consider emitting a copy.
        output.collect(new Tuple2<>(update.f0, parameters))

    public void onIterationEnd(int[] round, Context context) {
        context.output(FINAL_MODEL_OUTPUT_TAG, parameters);

// Example: asynchronous linear regression on bounded inputs. Each parameter record is routed to a
// single TrainFunction subtask chosen by its Integer tag (f0 % numPartitions), rather than going
// to every subtask.
// NOTE(review): excerpt, not compilable as shown. The resultStreams statement sits directly in the
// class body and uses initParameters/dataset, which are not declared here. It presumably belongs
// inside a main method, as in the synchronous example.
public class AsynchronousBoundedLinearRegression {

    // NOTE(review): IterationUtilsIterations merges the old and new class names from the page
    // history. The entry point presumably lives in Iterations.
    DataStreamList resultStreams = 
        IterationUtilsIterations.iterateBoundedStreamsUntilTermination(DataStreamList.of(initParameters), DataStreamList.of(dataset), (variableStreams, dataStreams) -> {
            // NOTE(review): `double` is not a valid type argument. Given the async
            // ParametersCacheFunction's input type, this is presumably
            // DataStream<Tuple2<Integer, double[]>>.
            DataStream<double> parameterUpdates = variableStreams.get(0);
            // NOTE(review): if an enclosing local named `dataset` exists, Java rejects this
            // redeclaration inside the lambda.
            DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

            // NOTE(review): the async ParametersCacheFunction emits Tuple2<Integer, double[]>, not
            // double[]. The declared types here and below do not match.
            SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction());
            // Routes each tagged parameter record to partition (f0 % numPartitions).
            // NOTE(review): a .connect(dataset) before coProcess and the terminating semicolon are
            // presumably missing.
            DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                .partitionCustom((key, numPartitions) -> key % numPartitions, update -> update.f0)
                .coProcess(new TrainFunction())

            return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG)));



Code Block
/**
 * Example: synchronous linear regression over unbounded inputs. ParametersCacheFunction is created
 * for 10 train tasks, and TrainOperators uses a mini-batch size of 50.
 *
 * NOTE(review): this is an excerpt from a design document and is not compilable as shown
 * (closing braces and some lines are missing).
 */
public class SynchronousUnboundedLinearRegression {
    // NOTE(review): the field type is missing; presumably `private static final int N_DIM = 50;`.
    private static final N_DIM = 50;
    // NOTE(review): not valid Java as written. An anonymous OutputTag subclass needs a constructor
    // call with an id. The id is presumably "model_update", given the lookup at the end of main —
    // confirm.
    private static final OutputTag<double[]> MODEL_UPDATE_OUTPUT_TAG = new OutputTag<double[]>{};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        // NOTE(review): IterationUtilsIterations merges the old and new class names from the page
        // history. The entry point presumably lives in Iterations.
        DataStreamList resultStreams = IterationUtilsIterations.iterateUnboundedStreams(DataStreamList.of(initParameters), DataStreamList.of(dataset), (variableStreams, dataStreams) -> {
            // NOTE(review): `double` is not a valid type argument; presumably DataStream<double[]>.
            DataStream<double> parameterUpdates = variableStreams.get(0);
            // NOTE(review): this redeclares main's local `dataset`, which Java rejects inside a
            // lambda.
            DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

            // NOTE(review): `model` is undefined; presumably parameterUpdates.
            SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction(10));
            // NOTE(review): the expression wiring the dataset and TrainOperators together is
            // incomplete. TrainOperators is a TwoInputStreamOperator, so this is presumably
            // connect(dataset) followed by a transform(...) that takes the operator.
                DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                                new TrainOperators(50));
            // NOTE(review): FINAL_MODEL_OUTPUT_TAG is not declared in this class (only
            // MODEL_UPDATE_OUTPUT_TAG is). Nothing visible in this example's ParametersCacheFunction
            // emits to a side output.
            return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG)));
        // NOTE(review): confirm that DataStreamList supports lookup by name; elsewhere it is indexed.
        DataStream<double[]> finalModel = resultStreams.get("model_update");

    /**
     * Accumulates parameter updates into a single N_DIM array. It resets its update counter once
     * the counter equals numOfTrainTasks, i.e. presumably once every train task has reported.
     *
     * NOTE(review): numOfUpdatesReceived is declared final but reassigned below (compile error),
     * and it is never incremented in the visible code. A ++numOfUpdatesReceived and the emission
     * of the merged parameters are presumably missing from this excerpt. `O` is also undeclared
     * (presumably Collector<double[]>).
     */
    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> {  
        // Number of train tasks whose updates are expected per round.
        private final int numOfTrainTasks;

        private final int numOfUpdatesReceived = 0;
        private final double[] parameters = new double[N_DIM];

        public ParametersCacheFunction(int numOfTrainTasks) {
            this.numOfTrainTasks = numOfTrainTasks;

        public void processElement(double[] update, Context ctx, Collector<O> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);

            if (numOfUpdatesReceived == numOfTrainTasks) {
                numOfUpdatesReceived = 0;

    /**
     * Train operator for the unbounded case. Its input selection reads training samples (second
     * input) while the mini-batch holds fewer than miniBatchSize records, and reads parameters
     * (first input) otherwise. Each parameter record triggers a model update over the mini-batch.
     *
     * NOTE(review): Flink's TwoInputStreamOperator declares processElement1/2(StreamRecord<...>).
     * The ProcessFunction-style signatures below (with Context and an undeclared Collector<O>) do
     * not implement it. firstRoundCachedParameter is unused in the visible code.
     */
    public static class TrainOperators extends AbstractStreamOperator<double[]> implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>, InputSelectable {

        private final int miniBatchSize;

        private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        public TrainOperators(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;

        public void processElement1(double[] parameter, Context context, Collector<O> output) {
            calculateModelUpdate(parameter, output);

        // NOTE(review): body missing; presumably miniBatch.add(trainSample).
        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<O> output) {

        // Fill the mini-batch from the sample input first, then consume the next parameters.
        public InputSelection nextSelection() {
            if (miniBatch.size() < miniBatchSize) {
                return InputSelection.SECOND;
            } else {
                return InputSelection.FIRST;

        // Accumulates x * (muladd(x, parameters) - y) over the mini-batch. Assumes ArrayUtils.muladd
        // is the dot product of its arguments — TODO confirm.
        // NOTE(review): the loop iterates over miniBatchSize (an int); presumably miniBatch is meant.
        // The visible code never clears the mini-batch or emits modelUpdate.
        private void calculateModelUpdate(double[] parameters, Collector<O> output) {
            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : miniBatchSize) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));

