Status

Current state: Work-in-progress. Not ready for discussion yet.

Discussion thread: To be added

JIRA: To be added

Released: To be decided

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

Motivation

To execute iterative machine learning algorithms, Flink supports an iteration primitive: some outputs of an operator subgraph (called the iteration body) can be fed back to the inputs of this subgraph, and the subgraph is executed iteratively until a termination criterion is met.

Flink currently provides DataSet::iterate(...) and DataStream::iterate(...) to support the iterative computation described above. However, going forward, there are issues that prevent us from using these APIs to run iterative computation on either bounded or unbounded streams:

  • DataSet::iterate(...) only supports iteration on bounded data streams. Moreover, as described in FLIP-131, the DataSet API will be deprecated in favor of the Table API/SQL.
  • DataStream::iterate(...) has a few design issues that prevent it from being used reliably in production jobs. Many of these issues, such as the possibility of deadlock, are described in FLIP-15.

DataSet::iterate(...) also has a performance issue: in order to replay the user-provided data streams multiple times, it requires the runtime to always dump these streams to disk. This introduces storage and disk I/O overhead even if the user's algorithm would prefer to cache those values in memory, possibly in a more compact format.

To address the issues described above and make Flink ML usable for more iteration use-cases in the long run, this FLIP proposes to add a few APIs to the flink-ml repository to achieve the following goals:

  • Provide a solution for all the iteration use-cases (see the Target Use-cases section below for more detail) supported by the existing APIs, without the issues described above.
  • Provide a solution for a few use-cases (e.g. bounded streams + async mode + per-round variable update) not supported by the existing APIs.
  • Decouple the iteration-related APIs from the Flink core runtime (by moving them to the flink-ml repo) so that we can keep the Flink core runtime as simple and maintainable as possible.
  • Provide an iteration API that does not impose the disk I/O overhead described above, so that users can optimize an iterative algorithm for the best possible performance.

Terminology

We define a few terms below to make the rest of this doc easier to follow.

1) Iteration body

An iteration body is a subgraph of operators that implements the computation logic of, e.g., an iterative machine learning algorithm, and whose outputs might be fed back as inputs of this subgraph. Therefore, a Flink program that has an iteration body has a cycle in its graph.

Note that not all outputs of an iteration body have to be fed back as inputs of this subgraph.

2) Feedback stream

For a given iteration body, a stream is said to be a feedback stream if it connects an output of this iteration body back to the input of this iteration body.

3) Epoch of records

In the proposed APIs, for any given record generated by the iteration body, we define the epoch of this record to be the number of times the iteration body has been invoked in the computation history of this record. The exact definition of epoch can be found in the Java doc of the IterationUtils class below.

Note that epoch is also a term commonly used in machine learning to indicate the number of passes that the machine learning algorithm has made over the entire training dataset. We refer to this definition as the "classic definition of epoch" below.

Our definition of the term "epoch" is a natural extension of the "classic definition of epoch" to asynchronous machine learning on unbounded streams. We make the following observations regarding their comparison:

  • The classic definition of epoch is well-defined only when the machine learning algorithm processes bounded streams AND all subtasks of the algorithm update the model variables synchronously.
  • Our definition of epoch can also be applied in cases where the machine learning algorithm processes unbounded streams or subtasks of the algorithm update the model variables asynchronously.
  • In cases where the classic definition of epoch is well-defined (as described above), if the machine learning algorithm updates the model variables once after making a pass over the training dataset, then the two definitions of epoch are exactly the same.

Target Use-cases

The target use-cases (i.e. algorithms) can be described w.r.t. the properties described below. In the following, we first describe the definitions of properties, followed by the combinations of the property choices supported by the existing APIs and the proposed APIs, respectively.

Properties of machine learning algorithms

Different algorithms might have different required properties for the input datasets (bounded or unbounded), synchronization between parallel subtasks (sync or async), amount of data processed for every variable update (a batch/subset or the entire dataset). We describe each of these properties below.

1) Algorithms have different needs for whether the input data streams should be bounded or unbounded. We classify algorithms into online algorithms and offline algorithms as follows.

  • For online training algorithms, the training samples are unbounded streams of data. The corresponding iteration body should ingest these unbounded streams, read each value in each stream once, and update the machine learning model repeatedly in near real-time. The iteration never terminates in this case. The algorithm should be executed as a streaming job.
  • For offline training algorithms, the training samples are bounded streams of data. The corresponding iteration body should read these bounded data streams for an arbitrary number of rounds and update the machine learning model repeatedly until a termination criterion is met (e.g. a given number of rounds is reached or the model has converged). The algorithm should be executed as a batch job.

2) Algorithms (either online or offline) have different needs for how their parallel subtasks are synchronized.

  • In the sync mode, the parallel subtasks that execute the iteration body update the model variables in a coordinated manner. There exist global epochs, such that all subtasks read the shared model variables at the beginning of an epoch, calculate variable updates based on the fetched variable values, and write updates of the variable values at the end of this epoch.
  • In the async mode, each parallel subtask that executes the iteration body can read/update the shared model variables without waiting for variable updates from other subtasks. For example, a subtask could have updated the model variables 10 times when another subtask has updated them only 3 times.

The sync mode is useful when an algorithm should be executed in a deterministic way to achieve the best possible accuracy, and the straggler issue (i.e. a subtask that is considerably slower than the others) does not slow down the algorithm execution too much. In comparison, the async mode is useful for algorithms that should be parallelized and executed across many subtasks as fast as possible, without performance issues caused by stragglers, at the possible cost of reduced accuracy.

3) An algorithm may have additional requirements in how much data should be consumed each time before a subtask can update variables. There are two categories of choices here:

  • Per-batch variable update: The algorithm wants to update variables every time an arbitrary subset of the user-provided data streams (either bounded or unbounded) is processed.
  • Per-round variable update: The algorithm wants to update variables every time all data of the user-provided bounded data streams is processed.

In the machine learning domain, some algorithms allow users to configure a batch size, and the model is updated every time each subtask processes a batch of data. Those algorithms fit into the first category. Such an algorithm can be either online or offline.

Other algorithms only update variables every time the entire data is consumed for one round. Those algorithms fit into the second category. And such an algorithm must be offline because, by this definition, the user-provided dataset must be bounded.

Combinations of property choices supported by the existing APIs

The existing DataSet::iterate supports algorithms that process bounded streams + sync mode + per-round variable update.

The existing DataStream::iterate has a few bugs that prevent it from being used in production (see FLIP-15). Apart from this, it is expected to support algorithms that process unbounded streams + async mode + per-batch variable update.

Combinations of categories supported by the proposed APIs

As described above, there are 3 properties, each with 2 choices, which gives a total of 8 combinations of property choices. Since "per-round variable update" cannot be used with "unbounded data streams", only 6 out of the 8 combinations are valid.

The APIs proposed in this FLIP support all the 6 valid combinations of property choices.
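
The count of 6 valid combinations can be checked by enumeration. The following is a minimal sketch in plain Java; the class and enum names are hypothetical and are not part of the proposed API.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of the proposed API: enumerates the property combinations.
final class PropertyCombinations {
    enum Input { BOUNDED, UNBOUNDED }
    enum Sync { SYNC, ASYNC }
    enum Update { PER_BATCH, PER_ROUND }

    /** Per-round update needs a full pass over the data, which requires bounded input. */
    static boolean isValid(Input input, Sync sync, Update update) {
        return !(input == Input.UNBOUNDED && update == Update.PER_ROUND);
    }

    /** Returns a description of every valid combination of the three properties. */
    static List<String> validCombinations() {
        List<String> result = new ArrayList<>();
        for (Input input : Input.values()) {
            for (Sync sync : Sync.values()) {
                for (Update update : Update.values()) {
                    if (isValid(input, sync, update)) {
                        result.add(input + " + " + sync + " + " + update);
                    }
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        validCombinations().forEach(System.out::println);
    }
}
```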

Summary

The following table summarizes the use-cases supported by the existing APIs and proposed APIs, respectively, with respect to the properties defined above.

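| Input type             | Synchronization mode | Variable update mode      | Supported by the existing APIs | Supported by the proposed APIs |
|------------------------|----------------------|---------------------------|--------------------------------|--------------------------------|
| bounded data streams   | sync mode            | per-batch variable update | No                             | Yes                            |
| bounded data streams   | sync mode            | per-round variable update | Yes                            | Yes                            |
| bounded data streams   | async mode           | per-batch variable update | No                             | Yes                            |
| bounded data streams   | async mode           | per-round variable update | No                             | Yes                            |
| unbounded data streams | sync mode            | per-batch variable update | No                             | Yes                            |
| unbounded data streams | async mode           | per-batch variable update | Yes                            | Yes                            |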


Overview of the Iteration Paradigm

In the following, we explain the key components of an iterative algorithm that have motivated our choice of the proposed APIs.

An iterative algorithm has the following behavior:

  • The iterative algorithm has an iteration body that is repeatedly invoked until some termination criterion is met (e.g. after a user-specified number of epochs has been reached).
  • In each invocation, the iteration body updates the model parameters based on the user-provided data as well as the most recent model parameters.
  • The iterative algorithm takes as inputs the user-provided data and the initial model parameters.
  • The iterative algorithm could output arbitrary user-defined information, such as the loss after each epoch, or the final model parameters. 

Therefore, the behavior of an iterative algorithm could be characterized with the following iteration paradigm (w.r.t. Flink concepts):

  • An iteration body is a Flink subgraph with the following inputs and outputs
    • Inputs: model-variables (as a list of unbounded data streams) and user-provided-data (as a list of data streams)
    • Outputs: feedback-model-variables (as a list of unbounded data streams) and user-observed-outputs (as a list of data streams)
  • A termination-condition that specifies when the iterative execution of the iteration body should terminate.
  • In order to execute an iteration body, the caller needs to provide the following inputs and gets the following outputs.
    • Inputs: initial-model-variables (as a list of bounded data streams) and user-provided-data (as a list of data streams)
    • Outputs: the user-observed-output emitted by the iteration body.

It is important to note that the model-variables observed by the iteration body are not the same as the initial-model-variables provided by the caller. Instead, model-variables are computed as the union of the feedback-model-variables (emitted by the iteration body) and the initial-model-variables (provided by the caller of the iteration body).
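
The union semantics can be illustrated without Flink. The following is a minimal single-threaded sketch in which a queue plays the role of the model-variables stream: it is seeded with the initial-model-variables, and the body appends its feedback records to the same queue. The halving step and the stopping rule are hypothetical placeholders for a real iteration body.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;

// Hypothetical simulation, not Flink code: model-variables = initial records + feedback records.
final class UnionFeedbackSketch {
    /** Returns every model-variable record the body observes, in arrival order. */
    static List<Double> run(List<Double> initialModelVariables) {
        // The queue is seeded with the caller-provided initial model variables.
        Queue<Double> modelVariables = new ArrayDeque<>(initialModelVariables);
        List<Double> observedByBody = new ArrayList<>();
        while (!modelVariables.isEmpty()) {
            double variable = modelVariables.poll();
            observedByBody.add(variable);
            // Placeholder "body": halve the variable and feed it back while it is still >= 1.
            double updated = variable / 2.0;
            if (updated >= 1.0) {
                modelVariables.add(updated); // feedback record joins the same stream
            }
        }
        return observedByBody;
    }

    public static void main(String[] args) {
        System.out.println(run(List.of(8.0)));
    }
}
```

Starting from a single initial variable, the body observes that record and then every record it fed back itself, which is the union described above.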

The figure below summarizes the iteration paradigm described above. The streams in red indicate inputs provided by the caller to the iteration body, as well as outputs emitted by the iteration body to the caller.

Public Interfaces

We propose to make the following API changes to support the iteration paradigm described above.

1) Add the IterationBody interface.

This interface abstracts the core computation logic that is expected to be computed iteratively. Any algorithm that needs iterative computation with feedback streams should implement this interface.

package org.apache.flink.ml.iteration;

@PublicEvolving
public interface IterationBody {
    /**
     * This method creates the graph for the iteration body.
     *
     * See IterationUtils::iterateUnboundedStreams, IterationUtils::iterateBoundedStreamsUntilTermination and
     * IterationUtils::iterateAndReplayBoundedStreamsUntilTermination for how the iteration body can be executed and
     * when execution of the corresponding graph should terminate.
     *
     * Required: the number of feedback variable streams returned by this method must equal the number of variable
     * streams given to this method.
     *
     * @param variableStreams the variable streams.
     * @param dataStreams the data streams.
     * @return an IterationBodyResult.
     */
    IterationBodyResult process(DataStreamList variableStreams, DataStreamList dataStreams);
}


2) Add the IterationListener interface.

If a UDF (i.e. user-defined function) used inside the IterationBody implements this interface, the callbacks on this interface will be invoked when the corresponding events happen.

This interface allows users to register callbacks that can be used to a) run an algorithm in sync mode; and b) emit final output after the iteration terminates.

package org.apache.flink.ml.iteration;

/**
 * The callbacks defined below will be invoked only if the operator instance which implements this interface is used
 * within an iteration body.
 */
@PublicEvolving
public interface IterationListener<T> {
    /**
     * This callback is invoked every time the epoch watermark increments. The initial epoch watermark is -1.
     *
     * The epochWatermark is the maximum integer that meets this requirement: every record that arrives at the operator
     * going forward should have an epoch larger than the epochWatermark. See Java docs in IterationUtils for how epoch
     * is determined for records ingested into the iteration body and for records emitted by operators within the
     * iteration body.
     *
     * If all inputs are bounded, the maximum epoch of all records ingested into this operator is used as the
     * epochWatermark by the last invocation of this callback.
     *
     * @param epochWatermark The incremented epoch watermark.
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector);

    /**
     * This callback is invoked after the execution of the iteration body has terminated.
     *
     * See Java doc of methods in IterationUtils for the termination conditions.
     *
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onIterationTermination(Context context, Collector<T> collector);

    /**
     * Information available in an invocation of the callbacks defined in the IterationListener.
     */
    interface Context {
        /**
         * Emits a record to the side output identified by the {@link OutputTag}.
         *
         * @param outputTag the {@code OutputTag} that identifies the side output to emit to.
         * @param value The record to emit.
         */
        <X> void output(OutputTag<X> outputTag, X value);
    }
}
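
As an illustration of use a) above, the following is a minimal sketch of how an operator could obtain sync-mode behavior from these callbacks: it buffers incoming updates and emits them only when the epoch watermark increments. To stay self-contained it does not use Flink; EpochListener is a hypothetical, simplified stand-in for IterationListener (no Context, and a java.util.function.Consumer in place of Collector), and the callbacks are driven by hand.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical sketch: simplified stand-ins for the proposed interfaces, driven manually.
final class SyncListenerSketch {
    /** Simplified stand-in for IterationListener (assumption: no Context, Consumer as collector). */
    interface EpochListener<T> {
        void onEpochWatermarkIncremented(int epochWatermark, Consumer<T> collector);

        void onIterationTermination(Consumer<T> collector);
    }

    /** Buffers updates and emits their sum once per epoch instead of once per record. */
    static final class SumPerEpoch implements EpochListener<Double> {
        private double pending; // updates received since the last watermark increment
        private double total;   // sum of all updates emitted so far

        void processElement(double update) {
            pending += update;
        }

        @Override
        public void onEpochWatermarkIncremented(int epochWatermark, Consumer<Double> collector) {
            total += pending;
            collector.accept(pending);
            pending = 0.0;
        }

        @Override
        public void onIterationTermination(Consumer<Double> collector) {
            collector.accept(total); // final output after the iteration terminates
        }
    }

    /** Feeds two epochs of updates through the operator and returns everything it emitted. */
    static List<Double> demo() {
        List<Double> emitted = new ArrayList<>();
        SumPerEpoch operator = new SumPerEpoch();
        operator.processElement(1.0);
        operator.processElement(2.0);
        operator.onEpochWatermarkIncremented(0, emitted::add);
        operator.processElement(4.0);
        operator.onEpochWatermarkIncremented(1, emitted::add);
        operator.onIterationTermination(emitted::add);
        return emitted;
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```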


3) Add the IterationUtils class.

This class provides APIs to execute an iteration body with user-specified inputs. This class provides three APIs to run an iteration body, each with different input types (e.g. bounded data streams vs. unbounded data streams) and data replay semantics (i.e. whether to replay the user-provided data streams).

The primary functionality shared by all APIs in this class is to union the feedback variable streams (returned by the iteration body) with the initial variable streams (provided by the user) and use the merged streams as inputs to the iteration body.

package org.apache.flink.ml.iteration;

/**
 * A helper class to apply {@link IterationBody} to data streams.
 */
@PublicEvolving
public class IterationUtils {
    /**
     * This method can use an iteration body to process records in unbounded data streams.
     *
     * This method invokes the iteration body with the following parameters:
     * 1) The 1st parameter is a list of input variable streams, which are created as the union of the initial variable
     * streams and the corresponding feedback variable streams (returned by the iteration body).
     * 2) The 2nd parameter is the data streams given to this method.
     *
     * The epoch values are determined as described below. See IterationListener for how the epoch values are used.
     * 1) All records in the initial variable streams have epoch=0.
     * 2) All records in the data streams have epoch=MAX_LONG. In this case, records in the data stream won't affect
     * any operator's low watermark.
     * 3) For any record emitted by this operator into a non-feedback stream, the epoch of this emitted record = the
     * epoch of the input record that triggers this emission. If this record is emitted by onEpochLWIncremented(), then
     * the epoch of this record = incrementedEpochLW - 1.
     * 4) For any record emitted by this operator into a feedback variable stream, the epoch of the emitted record =
     * min(the epoch of the input record that triggers this emission, MAX_LONG - 1) + 1. If this record is emitted by
     * onEpochLWIncremented(), then the epoch of this record = incrementedEpochLW.
     *
     * The execution of the graph created by the iteration body will not terminate by itself. This is because at least
     * one of its data streams is unbounded.
     *
     * Required:
     * 1) All the init variable streams must be bounded.
     * 2) There is at least one unbounded stream in the data streams list.
     * 3) The parallelism of any stream in the initial variable streams must equal the parallelism of the stream at the
     * same index of the feedback variable streams returned by the IterationBody.
     *
     * @param initVariableStreams The initial variable streams. These streams will be merged with the feedback variable
     *                            streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The data streams. These streams will be used as the 2nd parameter to invoke the iteration
     *                    body.
     * @param body The computation logic which takes variable/data streams and returns variable/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    static DataStreamList iterateUnboundedStreams(DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {...}

    /**
     * This method can use an iteration body to process records in some bounded data streams iteratively until a
     * termination criteria is reached (e.g. the given number of rounds is completed or no further variable update is
     * needed). Because this method does not replay records in the data streams, the iteration body needs to cache those
     * records in order to visit those records repeatedly.
     *
     * This method invokes the iteration body with the following parameters:
     * 1) The 1st parameter is a list of input variable streams, which are created as the union of the initial variable
     * streams and the corresponding feedback variable streams (returned by the iteration body).
     * 2) The 2nd parameter is the data streams given to this method.
     *
     * The epoch values are determined as described below. See IterationListener for how the epoch values are used.
     * 1) All records in the initial variable streams have epoch=0.
     * 2) All records in the data streams have epoch=0.
     * 3) For any record emitted by this operator into a non-feedback stream, the epoch of this emitted record = the
     * epoch of the input record that triggers this emission. If this record is emitted by onEpochLWIncremented(), then
     * the epoch of this record = incrementedEpochLW - 1.
     * 4) For any record emitted by this operator into a feedback variable stream, the epoch of the emitted record = the
     * epoch of the input record that triggers this emission + 1. If this record is emitted by onEpochLWIncremented(),
     * then the epoch of this record = incrementedEpochLW.
     *
     * Suppose there is a coordinator operator which takes all feedback variable streams (emitted by the iteration body)
     * and the termination criteria stream (if not null) as inputs. The execution of the graph created by the
     * iteration body will terminate when all input streams have been fully consumed AND any of the following
     * conditions is met:
     * 1) The termination criteria stream is not null. And the coordinator operator has not observed any new value from
     * the termination criteria stream between two consecutive onEpochLWIncremented invocations.
     * 2) The coordinator operator has not observed any new value from any feedback variable stream between two
     * consecutive onEpochLWIncremented invocations.
     *
     * Required:
     * 1) All the init variable streams and the data streams must be bounded.
     * 2) The parallelism of any stream in the initial variable streams must equal the parallelism of the stream at the
     * same index of the feedback variable streams returned by the IterationBody.
     *
     * @param initVariableStreams The initial variable streams. These streams will be merged with the feedback variable
     *                            streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The data streams. These streams will be used as the 2nd parameter to invoke the iteration
     *                    body.
     * @param body The computation logic which takes variable/data streams and returns variable/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    static DataStreamList iterateBoundedStreamsUntilTermination(DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {...}

    /**
     * This method can use an iteration body to process records in some bounded data streams iteratively until a
     * termination criteria is reached (e.g. the given number of rounds is completed or no further variable update is
     * needed). Because this method replays records in the data streams, the iteration body does not need to cache those
     * records to visit those records repeatedly.
     *
     * This method invokes the iteration body with the following parameters:
     * 1) The 1st parameter is a list of input variable streams, which are created as the union of the initial variable
     * streams and the corresponding feedback variable streams (returned by the iteration body).
     * 2) The 2nd parameter is a list of replayed data streams, which are created by replaying the initial data streams
     * round by round until the iteration terminates. The records in the Nth round will be emitted into the iteration
     * body only if the low watermark of the first operator in the iteration body >= N - 1.
     *
     * The epoch values are determined as described below. See IterationListener for how the epoch values are used.
     * 1) All records in the initial variable streams have epoch=0.
     * 2) The records from the initial data streams will be replayed round by round into the iteration body. The records
     * in the first round have epoch=0. And records in the Nth round have epoch = N - 1.
     * 3) For any record emitted by this operator into a non-feedback stream, the epoch of this emitted record = the
     * epoch of the input record that triggers this emission. If this record is emitted by onEpochLWIncremented(), then
     * the epoch of this record = incrementedEpochLW - 1.
     * 4) For any record emitted by this operator into a feedback stream, the epoch of the emitted record = the epoch
     * of the input record that triggers this emission + 1. If this record is emitted by onEpochLWIncremented(), then
     * the epoch of this record = incrementedEpochLW.
     *
     * Suppose there is a coordinator operator which takes all feedback variable streams (emitted by the iteration body)
     * and the termination criteria stream (if not null) as inputs. The execution of the graph created by the
     * iteration body will terminate when all input streams have been fully consumed AND any of the following conditions
     * is met:
     * 1) The termination criteria stream is not null. And the coordinator operator has not observed any new value from
     * the termination criteria stream between two consecutive onEpochLWIncremented invocations.
     * 2) The coordinator operator has not observed any new value from any feedback variable stream between two
     * consecutive onEpochLWIncremented invocations.
     *
     * Required:
     * 1) All the init variable streams and the data streams must be bounded.
     * 2) The parallelism of any stream in the initial variable streams must equal the parallelism of the stream at the
     * same index of the feedback variable streams returned by the IterationBody.
     *
     * @param initVariableStreams The initial variable streams. These streams will be merged with the feedback variable
     *                            streams before being used as the 1st parameter to invoke the iteration body.
     * @param initDataStreams The initial data streams. Records from these streams will be repeatedly replayed and used
     *                        as the 2nd parameter to invoke the iteration body.
     * @param body The computation logic which takes variable/data streams and returns variable/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    static DataStreamList iterateAndReplayBoundedStreamsUntilTermination(DataStreamList initVariableStreams, DataStreamList initDataStreams, IterationBody body) {...}

}
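
The epoch rules for records triggered by an input record can be written down as small pure functions. The following is a minimal sketch under two assumptions: MAX_LONG in the Java docs above is taken to mean Long.MAX_VALUE, and rounds are numbered from 1. The class and method names are hypothetical, and the rules for records emitted from the watermark callback are not covered.

```java
// Hypothetical helper, not part of the proposed API: epoch rules for input-triggered records.
final class EpochRulesSketch {
    /** Assumption: MAX_LONG in the Java docs corresponds to Long.MAX_VALUE. */
    static final long MAX_EPOCH = Long.MAX_VALUE;

    /** A record emitted into a non-feedback stream keeps the epoch of the triggering input record. */
    static long nonFeedbackEpoch(long inputEpoch) {
        return inputEpoch;
    }

    /**
     * A record emitted into a feedback variable stream gets the input epoch plus one.
     * The min(...) keeps records derived from epoch=MAX_LONG data records from overflowing.
     */
    static long feedbackEpoch(long inputEpoch) {
        return Math.min(inputEpoch, MAX_EPOCH - 1) + 1;
    }

    /** With replay, records of the Nth round (N starts at 1) have epoch N - 1. */
    static long replayedRecordEpoch(long round) {
        return round - 1;
    }

    public static void main(String[] args) {
        System.out.println(feedbackEpoch(0));         // initial variable record fed back once
        System.out.println(feedbackEpoch(MAX_EPOCH)); // saturates instead of overflowing
        System.out.println(replayedRecordEpoch(3));   // third replay round
    }
}
```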


4) Add IterationBodyResult as a helper class.

IterationBodyResult is a helper class which contains the streams returned by IterationBody::process(...).

package org.apache.flink.ml.iteration;

/**
 * A helper class that contains the streams returned by the iteration body.
 */
class IterationBodyResult {
    /**
     * A list of feedback variable streams. These streams will only be used during the iteration execution and will
     * not be returned to the caller of the iteration body. It is assumed that the method which executes the
     * iteration body will feed the records of the feedback variable streams back to the corresponding input variable
     * streams.
     */
    DataStreamList feedbackVariableStreams;

    /**
     * A list of output streams. These streams will be returned to the caller of the methods that execute the
     * iteration body.
     */
    DataStreamList outputStreams;

    /**
     * An optional termination criteria stream. If this stream is not null, it will be used together with the
     * feedback variable streams to determine when the iteration should terminate.
     */
    Optional<DataStream<?>> terminationCriteria;
}


5) Add the DataStreamList as a helper class.

DataStreamList is a helper class that serves as a container for a list of data streams with possibly different element types.

package org.apache.flink.ml.iteration;

public class DataStreamList {
    // Returns the number of data streams in this list.
    public int size() {...}

    // Returns the data stream at the given index in this list.
    public <T> DataStream<T> get(int index) {...}
}


6) Deprecate the existing DataStream::iterate() and the DataStream::iterate(long maxWaitTimeMillis) methods.

We plan to remove both methods after the APIs added in this doc are ready for production use. This change is needed to decouple the iteration-related APIs from the Flink core runtime so that we can keep the Flink core runtime as simple and maintainable as possible.


Proposed Changes

1) Termination of the iteration execution.

See the Java doc of the APIs in IterationUtils for how each API determines when the iteration terminates.
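
The termination conditions for the bounded-stream APIs can be sketched as a small coordinator that counts the records seen between two consecutive epoch callbacks. This is a minimal plain-Java sketch, not the actual implementation: the class and method names are hypothetical, it assumes all input streams have already been fully consumed, and it treats the start of the iteration as the first epoch boundary.

```java
// Hypothetical sketch, not the actual coordinator: decides termination at each epoch boundary.
final class TerminationCoordinatorSketch {
    private final boolean hasCriteriaStream; // whether a termination criteria stream was provided
    private long feedbackSinceLastBoundary;
    private long criteriaSinceLastBoundary;

    TerminationCoordinatorSketch(boolean hasCriteriaStream) {
        this.hasCriteriaStream = hasCriteriaStream;
    }

    /** Called for every record observed on any feedback variable stream. */
    void onFeedbackRecord() {
        feedbackSinceLastBoundary++;
    }

    /** Called for every record observed on the termination criteria stream. */
    void onCriteriaRecord() {
        criteriaSinceLastBoundary++;
    }

    /**
     * Called at each epoch boundary. Returns true if the iteration should terminate:
     * either no feedback record arrived since the previous boundary, or a criteria stream
     * exists and produced no record since the previous boundary.
     */
    boolean onEpochBoundary() {
        boolean noFeedback = feedbackSinceLastBoundary == 0;
        boolean noCriteria = hasCriteriaStream && criteriaSinceLastBoundary == 0;
        feedbackSinceLastBoundary = 0;
        criteriaSinceLastBoundary = 0;
        return noFeedback || noCriteria;
    }

    public static void main(String[] args) {
        TerminationCoordinatorSketch coordinator = new TerminationCoordinatorSketch(false);
        coordinator.onFeedbackRecord();
        System.out.println(coordinator.onEpochBoundary()); // a feedback record arrived this epoch
        System.out.println(coordinator.onEpochBoundary()); // nothing arrived this epoch
    }
}
```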

2) Execution mode.

If all inputs are bounded streams, then the iteration body can be executed in either the stream mode or the batch mode.

If some inputs are unbounded streams, then the iteration body must be executed in the stream mode.

3) Type of edges inside the iteration body.

All edges inside the iteration body are required to have the PIPELINE type.

If the user-defined iteration body contains an edge that does not have the PIPELINE type, methods that create the subgraph from the iteration body, such as iterateBoundedStreamsUntilTermination, will throw an exception upon invocation.

4) Implementation of the feedback edge.

The Flink core runtime supports only DAGs of operators. Thus it does not natively support feedback edges, since feedback edges introduce cycles in the operator graph.

As in the implementation of the DataSet::iterate() API, the proposed APIs are implemented with the following trick:

  • Automatically insert HEAD and TAIL operators as the first and last operators in the iteration body.
  • Co-locate HEAD and TAIL operators on the same task manager.
  • Have HEAD and TAIL operators transmit the records of the feedback edges using an in-memory queue.
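
The hand-off between the TAIL and HEAD operators can be pictured as a producer and a consumer sharing an in-memory queue. The following is a minimal sketch using plain Java threads and java.util.concurrent; it is not the actual operator implementation, and the class name and the end-of-feedback marker are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical sketch, not the actual operators: TAIL hands records to HEAD via an in-memory queue.
final class HeadTailQueueSketch {
    /** Hypothetical marker that tells HEAD no more feedback records will arrive. */
    private static final int END_OF_FEEDBACK = -1;

    /** Sends recordCount feedback records from a TAIL thread to a HEAD thread. */
    static List<Integer> transfer(int recordCount) {
        BlockingQueue<Integer> feedbackQueue = new LinkedBlockingQueue<>();
        List<Integer> receivedByHead = new ArrayList<>();

        // TAIL: the last operator of the iteration body writes feedback records to the queue.
        Thread tail = new Thread(() -> {
            for (int i = 0; i < recordCount; i++) {
                feedbackQueue.add(i);
            }
            feedbackQueue.add(END_OF_FEEDBACK);
        });

        // HEAD: the first operator of the iteration body reads feedback records from the queue.
        Thread head = new Thread(() -> {
            try {
                for (int r = feedbackQueue.take(); r != END_OF_FEEDBACK; r = feedbackQueue.take()) {
                    receivedByHead.add(r);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });

        head.start();
        tail.start();
        try {
            tail.join();
            head.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for HEAD/TAIL", e);
        }
        return receivedByHead;
    }

    public static void main(String[] args) {
        System.out.println(transfer(3));
    }
}
```

Because both ends live in the same process, no network edge is needed for the feedback, which is why the two operators must be co-located.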

5) Lifetime of the operators inside the iteration body.

The operators inside the iteration body are created only once and destroyed after the iteration terminates. In contrast, the existing DataSet::iterate(..) re-creates the iteration body (together with all state inside it) for every round of execution, which in general introduces more runtime overhead than the approach adopted in this FLIP.

6) Failover

To be described.

7) Support for synchronous iteration.

To be described.

Example Usages

Offline Training with Bounded Iteration

We first show the usage of the bounded iteration with a linear regression example: the model is Y = XA, and we would like to find the best estimate of A with the SGD algorithm. To simplify, we assume the parameters can be held in the memory of one task.

The job graph of the algorithm is shown in Figure 3: in each round, we use the latest parameters to calculate the update to the parameters: ΔA = ∑(Y - XA)X. To achieve this, the Parameters vertex broadcasts the latest parameters to the Train vertex. Each subtask of the Train vertex holds a part of the dataset. Following the spirit of SGD, it samples a small batch of training records and calculates the update with the above equation. Then the Train vertex emits ΔA to the Parameters vertex to update the parameters.
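
For a concrete reading of the update formula, the following is a minimal sketch that computes ΔA for one mini-batch in plain Java. It assumes each training record is a feature vector x with a scalar label y, and that A is a vector of the same length as x; the class and method names are hypothetical, and no learning rate is applied since the text does not specify one.

```java
import java.util.Arrays;

// Hypothetical helper, not part of the example job: computes the update for one mini-batch.
final class LinearRegressionUpdateSketch {
    /** Computes delta[j] = sum over records i of (y_i - x_i . a) * x_i[j]. */
    static double[] delta(double[][] xs, double[] ys, double[] a) {
        double[] delta = new double[a.length];
        for (int i = 0; i < xs.length; i++) {
            // Prediction of the current model for record i.
            double prediction = 0.0;
            for (int j = 0; j < a.length; j++) {
                prediction += xs[i][j] * a[j];
            }
            // Residual (Y - XA) for record i, weighted by the features of that record.
            double residual = ys[i] - prediction;
            for (int j = 0; j < a.length; j++) {
                delta[j] += residual * xs[i][j];
            }
        }
        return delta;
    }

    public static void main(String[] args) {
        double[][] xs = {{1.0, 2.0}, {3.0, 4.0}};
        double[] ys = {5.0, 6.0};
        System.out.println(Arrays.toString(delta(xs, ys, new double[] {0.0, 0.0})));
    }
}
```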


We start with synchronous training. Synchronous training requires that the updates from all subtasks of the Train vertex are merged before the next round of training. This can be done by emitting the next round of parameters only at the end of each round. The code is shown below:

public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>("final_model") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        int batch = 5;
        int epochEachBatch = 10;

        ResultStreams resultStreams = new BoundedIteration()
            .withBody(new IterationBody(
                @IterationInput("model") DataStream<double[]> model,
                @IterationInput("dataset") DataStream<Tuple2<double[], Double>> dataset
            ) {
                SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction());
                DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                    .broadcast()
                    .connect(dataset)
                    .coProcess(new TrainFunction())
                    .setParallelism(10);

                return new BoundedIterationDeclarationBuilder()
                    .withFeedback("model", modelUpdate)
                    .withOutput("final_model", parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG))
                    .until(new TerminationCondition(null, context -> context.getRound() >= batch * epochEachBatch))
                    .build();
            })
            .build();
        
        DataStream<double[]> finalModel = resultStreams.get("final_model");
        finalModel.print();
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
        implements BoundedIterationProgressListener<double[]> {  
        
        private final double[] parameters = new double[N_DIM];

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
        }

        public void onRoundEnd(int[] round, Context context, Collector<double[]> collector) {
            // All updates of this round have been merged; emit the parameters for the next round.
            collector.collect(parameters);
        }

        public void onIterationEnd(int[] round, Context context) {
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
        }
    }

    public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements BoundedIterationProgressListener<double[]> {

        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        private Supplier<int[]> recordRoundQuerier;

        public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
            this.recordRoundQuerier = querier;
        } 

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            int[] round = recordRoundQuerier.get();
            if (round[0] == 0) {
                firstRoundCachedParameter = parameter;
                return;
            }

            calculateModelUpdate(parameter, output);
        }

        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            dataset.add(trainSample);
        }

        public void onRoundEnd(int[] round, Context context, Collector<double[]> output) {
            if (round[0] == 0) {
                calculateModelUpdate(firstRoundCachedParameter, output);
                firstRoundCachedParameter = null;                
            }
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                // Residual (y - x·A), assuming muladd returns the dot product; matches ΔA = ∑(Y - XA)X.
                double diff = record.f1 - ArrayUtils.muladd(record.f0, parameters);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }
}

If instead we want to do asynchronous training, we would need to make the following changes:

  1. The Parameters vertex does not wait until the end of the round to make sure it has received all the updates from the iteration. Instead, it immediately outputs the current parameter values once it receives a model update from one Train subtask.
  2. To label the source of each update, we change the input type to Tuple2<Integer, double[]>. The Parameters vertex only outputs the new parameter values to the Train task that sent the update.

We omit the change to the graph building code since it is trivial (change the output type and use a customized partitioner). The change to the Parameters vertex is as follows:

public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>
    implements BoundedIterationProgressListener<double[]> {  
    
    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        // Suppose we have a util to add the second array to the first.
        ArrayUtils.addWith(parameters, update.f1);
        // Reply only to the Train subtask that sent this update.
        output.collect(new Tuple2<>(update.f0, parameters));
    }

    public void onIterationEnd(int[] round, Context context) {
        context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
    }
}
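The customized partitioner mentioned above is not spelled out in this document. A minimal sketch of its routing rule, assuming the Integer field of each record is the index of the Train subtask that sent the update, could look as follows. The interface SubtaskPartitionerSketch is a local, hypothetical mirror of the shape of Flink's Partitioner#partition(key, numPartitions), so that the sketch compiles without Flink on the classpath; in a real job the same logic would presumably be plugged in through DataStream#partitionCustom with a key selector that extracts f0.

```java
/** Hypothetical local mirror of the shape of Flink's Partitioner#partition(key, numPartitions). */
interface SubtaskPartitionerSketch {
    int partition(int key, int numPartitions);
}

/** Routes each reply to the Train subtask whose index is carried in the record. */
final class ReplyToSenderPartitioner implements SubtaskPartitionerSketch {
    @Override
    public int partition(int senderSubtaskIndex, int numPartitions) {
        if (senderSubtaskIndex < 0 || senderSubtaskIndex >= numPartitions) {
            throw new IllegalArgumentException(
                    "No Train subtask with index " + senderSubtaskIndex
                            + " among " + numPartitions + " subtasks");
        }
        // The key already is the target channel, so no hashing is involved.
        return senderSubtaskIndex;
    }
}
```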

Online Training with Unbounded Iteration

Suppose now we change the algorithm to use unbounded iteration. Compared to the offline case, the differences are:

  1. The dataset is unbounded. The Train operator cannot cache all the data in the first round.
  2. The training algorithm might be changed to another one such as FTRL. We keep using SGD in this example since the choice of algorithm does not affect how the iteration is used.

We again start with the synchronous case. For online training, the Train vertex usually performs one update after accumulating one mini-batch. This is to ensure that the distribution of the samples is similar to the global statistics. In this example we omit the complex data re-sampling process and just fetch the next several records as one mini-batch.

The JobGraph for online training is still the one shown in Figure 1, with the training dataset becoming unbounded. Similar to the bounded case, synchronous training is expected to proceed as follows:

  1. The Parameters vertex broadcasts the initialized values once it receives the input values.
  2. Each Train task reads the next mini-batch of records, calculates an update and emits it to the Parameters vertex. It then waits until it receives the updated parameters from the Parameters vertex before it proceeds to process the next mini-batch.
  3. The Parameters vertex waits until it has received the updates from all the Train tasks before it broadcasts the updated parameters.
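The third step above can be sketched in plain Java, independent of any Flink API. The class SyncParameterAggregatorSketch and its method onUpdate are hypothetical names used only for this illustration; the sketch assumes that every Train task sends exactly one update per mini-batch.

```java
import java.util.Optional;

/** Hypothetical model of the Parameters vertex in synchronous mode. */
final class SyncParameterAggregatorSketch {
    private final int numTrainTasks;
    private final double[] parameters;
    private int updatesReceived = 0;

    SyncParameterAggregatorSketch(int numTrainTasks, double[] initialParameters) {
        this.numTrainTasks = numTrainTasks;
        this.parameters = initialParameters.clone();
    }

    /**
     * Merges one update. Returns the parameters to broadcast once updates from all
     * Train tasks have arrived, and an empty Optional while still waiting.
     */
    Optional<double[]> onUpdate(double[] update) {
        for (int i = 0; i < parameters.length; i++) {
            parameters[i] += update[i];
        }
        updatesReceived++;
        if (updatesReceived < numTrainTasks) {
            return Optional.empty();
        }
        updatesReceived = 0;
        // Hand out a copy so that later merges do not change what was broadcast.
        return Optional.of(parameters.clone());
    }
}
```

With three Train tasks, the first two calls to onUpdate return an empty result and only the third one yields the merged parameters, after which the counter starts over for the next mini-batch.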

Since the unbounded case has no concept of rounds and we perform one update per mini-batch, we can instead use the InputSelectable functionality to implement the algorithm:

public class SynchronousUnboundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> MODEL_UPDATE_OUTPUT_TAG = new OutputTag<double[]>("model_update") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        ResultStreams resultStreams = new UnboundedIteration()
            .withBody(new IterationBody(
                @IterationInput("model") DataStream<double[]> model,
                @IterationInput("dataset") DataStream<Tuple2<double[], Double>> dataset
            ) {
                SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction(10));
                DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                    .broadcast()
                    .connect(dataset)
                    .transform(
                                "operator",
                                PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                                new TrainOperators(50))
                    .setParallelism(10);

                return new UnBoundedIterationDeclarationBuilder()
                    .withFeedback("model", modelUpdate)
                    .withOutput("model_update", parameters.getSideOutput(MODEL_UPDATE_OUTPUT_TAG))
                    .build();
            })
            .build();
        
        DataStream<double[]> finalModel = resultStreams.get("model_update");
        finalModel.addSink(...);
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> {  
        
        private final int numOfTrainTasks;

        // Not final: the counter is incremented and reset in processElement.
        private int numOfUpdatesReceived = 0;
        private final double[] parameters = new double[N_DIM];

        public ParametersCacheFunction(int numOfTrainTasks) {
            this.numOfTrainTasks = numOfTrainTasks;
        }

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
            numOfUpdatesReceived++;

            if (numOfUpdatesReceived == numOfTrainTasks) {
                output.collect(parameters);
                numOfUpdatesReceived = 0;
            }
        }
    }

    public static class TrainOperators extends AbstractStreamOperator<double[]> implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>, InputSelectable {

        private final int miniBatchSize;

        private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        public TrainOperators(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
        }

        @Override
        public void processElement1(StreamRecord<double[]> parameter) {
            // New parameters arrived: compute the update from the accumulated mini-batch, then reset it.
            calculateModelUpdate(parameter.getValue());
            miniBatch.clear();
        }

        @Override
        public void processElement2(StreamRecord<Tuple2<double[], Double>> trainSample) {
            miniBatch.add(trainSample.getValue());
        }

        @Override
        public InputSelection nextSelection() {
            // Read training samples until the mini-batch is full, then wait for the parameters.
            if (miniBatch.size() < miniBatchSize) {
                return InputSelection.SECOND;
            } else {
                return InputSelection.FIRST;
            }
        }

        private void calculateModelUpdate(double[] parameters) {
            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : miniBatch) {
                // Residual (y - x·A), assuming muladd returns the dot product; matches ΔA = ∑(Y - XA)X.
                double diff = record.f1 - ArrayUtils.muladd(record.f0, parameters);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(new StreamRecord<>(modelUpdate));
        }
    }
}
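The input-selection rule used by the Train operator can be isolated in a small plain-Java sketch, which may make the alternation between the two inputs easier to follow. MiniBatchSelectorSketch and its members are hypothetical names; the enum values stand in for the second input (training samples) and the first input (parameters) of the operator.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model of the mini-batch driven input selection of the Train operator. */
final class MiniBatchSelectorSketch<S> {
    /** Which input the operator wants to read next. */
    enum Input { PARAMETERS, SAMPLES }

    private final int miniBatchSize;
    private final List<S> miniBatch = new ArrayList<>();

    MiniBatchSelectorSketch(int miniBatchSize) {
        this.miniBatchSize = miniBatchSize;
    }

    /** Read samples until the mini-batch is full, then switch to the parameters input. */
    Input nextSelection() {
        return miniBatch.size() < miniBatchSize ? Input.SAMPLES : Input.PARAMETERS;
    }

    void onSample(S sample) {
        miniBatch.add(sample);
    }

    /** Consumes the current mini-batch when parameters arrive and starts a new one. */
    List<S> onParameters() {
        List<S> consumed = new ArrayList<>(miniBatch);
        miniBatch.clear();
        return consumed;
    }
}
```

The operator therefore blocks on the parameters input exactly while a full mini-batch is pending, which is what makes the training synchronous without any notion of rounds.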

Also similar to the bounded case, for asynchronous training the Parameters vertex does not wait until it has received updates from all the Train tasks. Instead, it responds directly to the task that sent the update:

public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>> {  
    
    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        // Suppose we have a util to add the second array to the first.
        ArrayUtils.addWith(parameters, update.f1);

        if (update.f0 < 0) {
            // Received the initialized parameter values, broadcast to all the downstream tasks
            // (10 is the parallelism of the Train vertex).
            for (int i = 0; i < 10; ++i) {
                output.collect(new Tuple2<>(i, parameters));
            }
        } else {
            // Reply only to the Train task that sent this update.
            output.collect(new Tuple2<>(update.f0, parameters));
        }
    }
}

Compatibility, Deprecation, and Migration Plan

TBD

Rejected Alternatives

Naiad has proposed a unified model for the watermark mechanism (namely progress tracking outside of the iteration) and the progress tracking inside the iteration. It extends the event time and watermark to a vector (long timestamp, int[] rounds) and implements a vectorized alignment algorithm. Although Naiad provides an elegant model, a direct implementation on Flink would require a large amount of modification to the Flink runtime, which would cause a lot of complexity and maintenance overhead. Thus we choose to implement a simplified version on top of Flink, as a part of the flink-ml library.

For building the iteration DAG, it would be simpler if we could directly refer to the data stream variables outside of the closure of the iteration body. However, since the iteration DAG must first be created in the mock execution environment, we cannot use these variables directly; otherwise we would directly modify the real environment and would not have a chance to add wrappers to the operators.





