...

Code Block
languagejava
linenumberstrue
package org.apache.flink.ml.iteration;

@PublicEvolving
public interface IterationBody {
    /**
     * This method creates the graph for the iteration body.
     *
     * See Iterations::iterateUnboundedStreams and Iterations::iterateBoundedStreamsUntilTermination for how the
     * iteration body can be executed and when execution of the corresponding graph should terminate.
     *
     * Required: the number of feedback variable streams returned by this method must equal the number of variable
     * streams given to this method.
     *
     * @param variableStreams the variable streams.
     * @param dataStreams the data streams.
     * @return an IterationBodyResult.
     */
    IterationBodyResult process(DataStreamList variableStreams, DataStreamList dataStreams);
}

...

Code Block
languagejava
linenumberstrue
package org.apache.flink.ml.iteration;

/**
 * A helper class that contains the streams returned by the iteration body.
 */
class IterationBodyResult {
    /**
     * A list of feedback variable streams. These streams will only be used during the iteration execution and will
     * not be returned to the caller of the iteration body. It is assumed that the method which executes the
     * iteration body will feed the records of the feedback variable streams back to the corresponding input variable
     * streams.
     */
    DataStreamList feedbackVariableStreams;

    /**
     * A list of output streams. These streams will be returned to the caller of the methods that execute the
     * iteration body.
     */
    DataStreamList outputStreams;

    /**
     * An optional termination criteria stream. If this stream is not null, it will be used together with the
     * feedback variable streams to determine when the iteration should terminate.
     */
    @Nullable DataStream<?> terminationCriteria;
}

...

Code Block
languagejava
linenumberstrue
package org.apache.flink.ml.iteration;

/**
 * The callbacks defined below will be invoked only if the operator instance which implements this interface is used
 * within an iteration body.
 */
@PublicEvolving
public interface IterationListener<T> {
    /**
     * This callback is invoked every time the epoch watermark of this operator increments. The initial epoch watermark
     * of an operator is 0.
     *
     * The epochWatermark is the maximum integer that meets this requirement: every record that arrives at the operator
     * going forward should have an epoch larger than the epochWatermark. See the Javadoc of Iterations for how the epoch
     * is determined for records ingested into the iteration body and for records emitted by operators within the
     * iteration body.
     *
     * If all inputs are bounded, the maximum epoch of all records ingested into this operator is used as the
     * epochWatermark parameter for the last invocation of this callback.
     *
     * @param epochWatermark The incremented epoch watermark.
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector);

    /**
     * This callback is invoked after the execution of the iteration body has terminated.
     *
     * See the Javadoc of the methods in Iterations for the termination conditions.
     *
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onIterationTermination(Context context, Collector<T> collector);

    /**
     * Information available in an invocation of the callbacks defined in the IterationListener.
     */
    interface Context {
        /**
         * Emits a record to the side output identified by the {@link OutputTag}.
         *
         * @param outputTag the {@code OutputTag} that identifies the side output to emit to.
         * @param value The record to emit.
         */
        <X> void output(OutputTag<X> outputTag, X value);
    }
}
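The callback pattern above can be tried out without Flink. The following is a minimal plain-Java sketch, not the proposed API: `EpochCallbackSketch`, `EpochListener`, `SumPerEpoch` and `run` are hypothetical stand-ins, the `Context` parameter is dropped, a `Consumer` replaces the `Collector`, and the driver assumes records arrive sorted by epoch and that the first watermark value passed to the callback is 0.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

/** Plain-Java stand-in for the callback pattern; none of these types are Flink classes. */
class EpochCallbackSketch {

    /** Simplified, hypothetical mirror of IterationListener (no Context, Consumer as collector). */
    interface EpochListener<T> {
        void onEpochWatermarkIncremented(int epochWatermark, Consumer<T> collector);

        void onIterationTermination(Consumer<T> collector);
    }

    /** Accumulates values as they arrive and only emits the running sum at epoch boundaries. */
    static class SumPerEpoch implements EpochListener<String> {
        private long sum;

        void processElement(long value) {
            sum += value;
        }

        @Override
        public void onEpochWatermarkIncremented(int epochWatermark, Consumer<String> collector) {
            collector.accept("epoch " + epochWatermark + " sum=" + sum);
        }

        @Override
        public void onIterationTermination(Consumer<String> collector) {
            collector.accept("final sum=" + sum);
        }
    }

    /**
     * Hypothetical driver. Each element of records is an {epoch, value} pair, sorted by epoch.
     * When the next record has a larger epoch, no record of the current epoch can arrive any
     * more, so the callback fires for the current epoch before the record is processed.
     */
    static List<String> run(long[][] records) {
        List<String> out = new ArrayList<>();
        SumPerEpoch op = new SumPerEpoch();
        int currentEpoch = 0;
        for (long[] r : records) {
            int epoch = (int) r[0];
            while (currentEpoch < epoch) {
                op.onEpochWatermarkIncremented(currentEpoch, out::add);
                currentEpoch++;
            }
            op.processElement(r[1]);
        }
        if (records.length > 0) {
            // Bounded input: the last callback uses the maximum epoch of all ingested records.
            op.onEpochWatermarkIncremented(currentEpoch, out::add);
        }
        op.onIterationTermination(out::add);
        return out;
    }

    public static void main(String[] args) {
        // Prints "epoch 0 sum=3", "epoch 1 sum=13" and "final sum=13", one per line.
        run(new long[][] {{0, 1}, {0, 2}, {1, 10}}).forEach(System.out::println);
    }
}
```

The operator buffers state in `processElement` and emits only from the callbacks, which is the same division of work that an operator implementing IterationListener would follow.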


4) Add the Iterations class.

This class provides APIs to execute an iteration body with the user-provided inputs. The APIs differ in the input types they accept (e.g. bounded data streams vs. unbounded data streams) and in their data replay semantics (i.e. whether the user-provided data streams are replayed).

...

Code Block
languagejava
linenumberstrue
package org.apache.flink.ml.iteration;

import org.apache.flink.annotation.Experimental;
import org.apache.flink.iteration.operator.allround.AllRoundOperatorWrapper;

/**
 * A helper class to create iterations. To construct an iteration, users are required to provide:
 *
 * <ul>
 *   <li>initVariableStreams: the initial values of the variable data streams, which are updated
 *       in each round.
 *   <li>dataStreams: the other data streams used inside the iteration, which are not updated.
 *   <li>iterationBody: specifies the subgraph that updates the variable streams and produces the
 *       outputs.
 * </ul>
 *
 * <p>The iteration body will be invoked with two parameters: The first parameter is a list of input
 * variable streams, which are created as the union of the initial variable streams and the
 * corresponding feedback variable streams (returned by the iteration body); The second parameter is
 * the data streams given to this method.
 *
 * <p>During the execution of the iteration body, each record involved in the iteration has an
 * epoch attached, which marks the progress of the iteration. The epoch is computed as follows:
 *
 * <ul>
 *   <li>All records in the initial variable streams and initial data streams have epoch = 0.
 *   <li>For any record emitted by this operator into a non-feedback stream, the epoch of this
 *       emitted record = the epoch of the input record that triggers this emission. If this record
 *       is emitted by onEpochWatermarkIncremented(), then the epoch of this record =
 *       epochWatermark.
 *   <li>For any record emitted by this operator into a feedback variable stream, the epoch of the
 *       emitted record = the epoch of the input record that triggers this emission + 1.
 * </ul>
 *
 * <p>The framework notifies operators and UDFs that implement {@link IterationListener} at the
 * end of each epoch.
 *
 * <p>The limitations on constructing the subgraph inside the iteration body are described in
 * {@link IterationBody}.
 *
 * <p>An example of an iteration:
 *
 * <pre>{@code
 * DataStreamList result = Iterations.iterateUnboundedStreams(
 *  DataStreamList.of(first, second),
 *  DataStreamList.of(third),
 *  (variableStreams, dataStreams) -> {
 *      ...
 *      return new IterationBodyResult(
 *          DataStreamList.of(firstFeedback, secondFeedback),
 *          DataStreamList.of(output));
 *  });
 *  result.<Integer>get(0).addSink(...);
 * }</pre>
 */
@Experimental
public class Iterations {

    /**
     * This method uses an iteration body to process records in possibly unbounded data streams. The
     * iteration does not terminate if at least one of its inputs is unbounded. Otherwise it
     * terminates after all the inputs are terminated and no more records are iterating.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateUnboundedStreams(
            DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {
        return null;
    }

    /**
     * This method uses an iteration body to process records in some bounded data streams
     * iteratively until a termination criterion is reached (e.g. the given number of rounds is
     * completed or no further variable update is needed). For each data stream that is not
     * replayed, the iteration body needs to cache the records in order to visit them repeatedly.
     * For each data stream that is replayed, the iteration body does not need to cache the
     * records; this corresponds to the bulk iteration in the DataSet API.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}, together
     *     with whether each of them is replayed.
     * @param config The config for the iteration, like whether to re-create the operator on each
     *     round.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateBoundedStreamsUntilTermination(
            DataStreamList initVariableStreams,
            ReplayableDataStreamList dataStreams,
            IterationConfig config,
            IterationBody body) {
        return null;
    }
}
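The epoch rules in the Javadoc above can be traced with a small plain-Java model. This is only a sketch under stated assumptions: `EpochRulesSketch`, `Rec`, `BodyResult`, `iterate` and `doubleUntilTen` are hypothetical names, streams are modelled as in-memory lists with a single variable stream, and the loop processes one whole round at a time, which a real streaming runtime would not need to do.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/** Plain-Java model of the epoch rules; all types here are hypothetical stand-ins. */
class EpochRulesSketch {

    /** A record value tagged with the epoch it belongs to. */
    static final class Rec {
        final int value;
        final int epoch;

        Rec(int value, int epoch) {
            this.value = value;
            this.epoch = epoch;
        }

        @Override
        public String toString() {
            return "(value=" + value + ", epoch=" + epoch + ")";
        }
    }

    /** What one pass of the body returns: records fed back and records sent to the output. */
    static final class BodyResult {
        final List<Rec> feedback = new ArrayList<>();
        final List<Rec> output = new ArrayList<>();
    }

    /** Runs the body until no record is fed back any more and returns all output records. */
    static List<Rec> iterate(List<Integer> initVariables, Function<List<Rec>, BodyResult> body) {
        List<Rec> input = new ArrayList<>();
        for (int v : initVariables) {
            input.add(new Rec(v, 0)); // initial records have epoch 0
        }
        List<Rec> outputs = new ArrayList<>();
        while (!input.isEmpty()) {
            BodyResult result = body.apply(input);
            outputs.addAll(result.output);
            input = result.feedback; // feedback records become the next round's input
        }
        return outputs;
    }

    /** Example body: double each value and feed it back until it reaches 10, then output it. */
    static BodyResult doubleUntilTen(List<Rec> in) {
        BodyResult result = new BodyResult();
        for (Rec r : in) {
            if (r.value < 10) {
                // Feedback stream: epoch of the triggering input record + 1.
                result.feedback.add(new Rec(r.value * 2, r.epoch + 1));
            } else {
                // Non-feedback stream: same epoch as the triggering input record.
                result.output.add(new Rec(r.value, r.epoch));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Prints [(value=12, epoch=0), (value=12, epoch=2)]
        System.out.println(iterate(Arrays.asList(3, 12), EpochRulesSketch::doubleUntilTen));
    }
}
```

The record that starts at 12 leaves in epoch 0, while the record that starts at 3 is fed back twice and therefore leaves in epoch 2. The loop ends once the feedback list is empty, mirroring the "no more records are iterating" termination condition for bounded inputs.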

5) Introduce the PerRoundSubGraphBuilder.

PerRoundSubGraphBuilder allows users to specify a subgraph that executes in per-round mode, namely all the operators in the subgraph are re-created for each round.

Code Block
languagejava
package org.apache.flink.iteration;

import org.apache.flink.annotation.Experimental;

/** Constructs a subgraph inside the iteration body to execute as per-round */
@Experimental
public class PerRoundSubGraphBuilder {

    /** The sub-graph inside the iteration body that should be executed as per-round. */
    public interface PerRoundSubGraph {

        DataStreamList process(DataStreamList input);
    }

    public static DataStreamList forEachRound(DataStreamList inputs, PerRoundSubGraph subGraph) {
        return null;
    }
}
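The difference between the two lifecycles can be seen in a plain-Java sketch. Nothing below is Flink code: `OperatorLifecycleSketch`, `CountingOperator`, `runAllRound` and `runPerRound` are hypothetical names, and a round is reduced to a single method call on a stateful object.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

/** Plain-Java illustration of all-round vs. per-round operator lifecycles; not Flink code. */
class OperatorLifecycleSketch {

    /** A stateful stand-in operator that counts the records it has seen since it was created. */
    static class CountingOperator {
        private int seen;

        int process(int recordsInRound) {
            seen += recordsInRound;
            return seen;
        }
    }

    /** All-round: one operator instance lives across all rounds, so its state accumulates. */
    static List<Integer> runAllRound(Supplier<CountingOperator> factory, int rounds, int perRound) {
        CountingOperator op = factory.get();
        List<Integer> counts = new ArrayList<>();
        for (int i = 0; i < rounds; i++) {
            counts.add(op.process(perRound));
        }
        return counts;
    }

    /** Per-round: the operator is re-created for each round, so every round starts from scratch. */
    static List<Integer> runPerRound(Supplier<CountingOperator> factory, int rounds, int perRound) {
        List<Integer> counts = new ArrayList<>();
        for (int i = 0; i < rounds; i++) {
            counts.add(factory.get().process(perRound));
        }
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(runAllRound(CountingOperator::new, 3, 5)); // prints [5, 10, 15]
        System.out.println(runPerRound(CountingOperator::new, 3, 5)); // prints [5, 5, 5]
    }
}
```

An operator that must remember data between rounds fits the all-round mode, while an operator that should not see state from earlier rounds fits the per-round mode.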


6) Add the DataStreamList and ReplayableDataStreamList classes.

DataStreamList is a helper class that contains a list of data streams with possibly different element types. ReplayableDataStreamList is a helper class that contains a list of data streams together with a flag for each stream that tells whether it needs to be replayed in each round.

Code Block
languagejava
linenumberstrue
package org.apache.flink.iteration;

public class DataStreamList {
    public static DataStreamList of(DataStream<?>... streams) {...}

    // Returns the number of data streams in this list.
    public int size() {...}

    // Returns the data stream at the given index in this list.
    public <T> DataStream<T> get(int index) {...}
}

public class ReplayableDataStreamList {

    public static ReplayableDataStreamList of(
            Tuple2<DataStream<?>, Boolean>... dataStreamAndIfNeedReplays);

    public static Tuple2<DataStream<?>, Boolean> replay(DataStream<?> dataStream);

    public static Tuple2<DataStream<?>, Boolean> noReplay(DataStream<?> dataStream);

    /** Returns the number of data streams in this list. */
    public int size();

    /** Returns the data stream at the given index in this list. */
    @SuppressWarnings("unchecked")
    public <T> DataStream<T> get(int index);

    public boolean shouldReplay(int index);
}
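To make the replay flags concrete, here is a minimal plain-Java model of the same shape. It is an assumption-laden sketch rather than the proposed class: `ReplayableListSketch` and `Entry` are hypothetical, and a stream is represented only by its name (a `String`) instead of a `DataStream`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Plain-Java model of a list of (stream, replay flag) pairs; streams are just names here. */
class ReplayableListSketch {

    /** Pairs a stream (represented by its name) with whether it is replayed in each round. */
    static final class Entry {
        final String stream;
        final boolean replay;

        Entry(String stream, boolean replay) {
            this.stream = stream;
            this.replay = replay;
        }
    }

    private final List<Entry> entries;

    private ReplayableListSketch(List<Entry> entries) {
        this.entries = entries;
    }

    /** Marks a stream as replayed in every round. */
    static Entry replay(String stream) {
        return new Entry(stream, true);
    }

    /** Marks a stream as consumed only once, so the iteration body has to cache it. */
    static Entry noReplay(String stream) {
        return new Entry(stream, false);
    }

    static ReplayableListSketch of(Entry... entries) {
        return new ReplayableListSketch(new ArrayList<>(Arrays.asList(entries)));
    }

    int size() {
        return entries.size();
    }

    String get(int index) {
        return entries.get(index).stream;
    }

    boolean shouldReplay(int index) {
        return entries.get(index).replay;
    }

    public static void main(String[] args) {
        ReplayableListSketch list = of(replay("trainingData"), noReplay("lookupTable"));
        for (int i = 0; i < list.size(); i++) {
            System.out.println(list.get(i) + " replay=" + list.shouldReplay(i));
        }
    }
}
```

The two static factory methods keep the call site readable, since each stream is wrapped by a word that states its replay behavior.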

7) Introduce the IterationConfig.

IterationConfig allows users to specify the config for each iteration. For now, users can only specify the default operator lifecycle inside the iteration, but having a dedicated config class keeps the API forward compatible if more options are added in the future.

Code Block
languagejava
package org.apache.flink.iteration;

/** The config for an iteration. */
public class IterationConfig {

    private final OperatorLifeCycle operatorRoundMode;

    public IterationConfig(OperatorLifeCycle operatorRoundMode) {
        this.operatorRoundMode = operatorRoundMode;
    }

    public static IterationConfigBuilder newBuilder() {
        return new IterationConfigBuilder();
    }

    public static class IterationConfigBuilder {

        private OperatorLifeCycle operatorRoundMode = OperatorLifeCycle.ALL_ROUND;

        private IterationConfigBuilder() {}

        public IterationConfigBuilder setOperatorRoundMode(OperatorLifeCycle operatorRoundMode) {
            this.operatorRoundMode = operatorRoundMode;
            return this;
        }

        public IterationConfig build() {
            return new IterationConfig(operatorRoundMode);
        }
    }

    public enum OperatorLifeCycle {
        ALL_ROUND,

        PER_ROUND
    }
}


8) Deprecate the existing DataStream::iterate() and the DataStream::iterate(long maxWaitTimeMillis) methods.

...

Code Block
languagejava
linenumberstrue
public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
    private static final int N_EPOCH = 5;
    private static final int N_BATCH_PER_EPOCH = 10;
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>("final_model") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams =
            Iterations.iterateBoundedStreamsUntilTermination(
                DataStreamList.of(initParameters),
                ReplayableDataStreamList.of(ReplayableDataStreamList.noReplay(dataset)),
                IterationConfig.newBuilder()
                    .setOperatorRoundMode(IterationConfig.OperatorLifeCycle.ALL_ROUND)
                    .build(),
                (variableStreams, dataStreams) -> {
                    DataStream<double[]> parameterUpdates = variableStreams.get(0);
                    DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

                    SingleOutputStreamOperator<double[]> parameters =
                        parameterUpdates.process(new ParametersCacheFunction());
                    DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                        .broadcast()
                        .connect(trainData)
                        .process(new TrainFunction())
                        .setParallelism(10);

                    return new IterationBodyResult(
                        DataStreamList.of(modelUpdate),
                        DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
                });

        // The final model is the only output stream returned by the iteration body.
        DataStream<double[]> finalModel = resultStreams.get(0);
        finalModel.print();
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
        implements IterationListener<double[]> {

        private final double[] parameters = new double[N_DIM];

        @Override
        public void processElement(
                double[] update,
                ProcessFunction<double[], double[]>.Context ctx,
                Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            // Send the current parameters downstream for another round until all rounds are completed.
            if (epochWatermark < N_EPOCH * N_BATCH_PER_EPOCH) {
                collector.collect(parameters);
            }
        }

        @Override
        public void onIterationTermination(
                IterationListener.Context context, Collector<double[]> collector) {
            // Emit the final model to the side output once the iteration has terminated.
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
        }
    }

    public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements IterationListener<double[]> {

        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        private Supplier<int[]> recordRoundQuerier;

        public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
            this.recordRoundQuerier = querier;
        } 

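        @Override
        public void processElement1(
                double[] parameter,
                CoProcessFunction<double[], Tuple2<double[], Double>, double[]>.Context context,
                Collector<double[]> output) {
            int[] round = recordRoundQuerier.get();
            if (round[0] == 0) {
                // The dataset may not be fully cached yet, so defer the first update.
                firstRoundCachedParameter = parameter;
                return;
            }

            calculateModelUpdate(parameter, output);
        }

        @Override
        public void processElement2(
                Tuple2<double[], Double> trainSample,
                CoProcessFunction<double[], Tuple2<double[], Double>, double[]>.Context context,
                Collector<double[]> output) {
            dataset.add(trainSample);
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            if (epochWatermark == 0) {
                calculateModelUpdate(firstRoundCachedParameter, collector);
                firstRoundCachedParameter = null;
            }
        }

        @Override
        public void onIterationTermination(
                IterationListener.Context context, Collector<double[]> collector) {
            // Nothing to emit when the iteration terminates.
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {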
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }
}
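The arithmetic inside calculateModelUpdate can be checked in isolation with a single-process sketch. The code below is plain Java and not part of the proposal: `LinearRegressionStepSketch`, `dot`, `gradient` and `step` are hypothetical names, the whole dataset is used instead of a sample, and the update is applied as a negative gradient scaled by an assumed learning rate (the example above adds the update to the parameters as-is through the assumed ArrayUtils helper).

```java
/** Single-process sketch of one synchronous training round; the learning rate is an assumption. */
class LinearRegressionStepSketch {

    /** Dot product of two vectors of equal length. */
    static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** Sums x * (x . w - y) over all samples, mirroring the loop in calculateModelUpdate. */
    static double[] gradient(double[][] xs, double[] ys, double[] w) {
        double[] g = new double[w.length];
        for (int n = 0; n < xs.length; n++) {
            double diff = dot(xs[n], w) - ys[n];
            for (int i = 0; i < w.length; i++) {
                g[i] += xs[n][i] * diff;
            }
        }
        return g;
    }

    /** One synchronous round: compute the update from the current parameters, then apply it. */
    static double[] step(double[][] xs, double[] ys, double[] w, double learningRate) {
        double[] g = gradient(xs, ys, w);
        double[] next = w.clone();
        for (int i = 0; i < w.length; i++) {
            next[i] -= learningRate * g[i];
        }
        return next;
    }

    public static void main(String[] args) {
        // Samples generated from y = 2 * x, so the parameter should approach 2.
        double[][] xs = {{1.0}, {2.0}, {3.0}};
        double[] ys = {2.0, 4.0, 6.0};
        double[] w = {0.0};
        for (int round = 0; round < 50; round++) {
            w = step(xs, ys, w, 0.05);
        }
        System.out.println(w[0]);
    }
}
```

Each call to `step` plays the role of one epoch in the synchronous example: all samples see the same parameters, and the parameters change only after the full update has been computed.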

...

To do asynchronous training instead, we would need to make the following change:

...