
Status

Current state: <TODO>

Discussion thread: <TODO>

JIRA: <TODO>

Released: <Flink Version>

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

Motivation

Iteration is a basic building block for an ML library. It is required for training ML models in both the offline and the online case. In general, two types of iterations are required:

  1. Bounded Iteration: Usually used in the offline case. The algorithm trains on a bounded dataset and updates the parameters for multiple rounds until convergence.
  2. Unbounded Iteration: Usually used in the online case. The algorithm trains on an unbounded dataset; it accumulates a mini-batch of data and then does one update to the parameters. 

Previously Flink supported bounded iteration with the DataSet API and unbounded iteration with the DataStream API. However, since Flink aims to deprecate the DataSet API and the iteration support in the DataStream API is rather incomplete, we need to re-implement a new iteration library in the flink-ml repository to support these algorithms. 

The Goals

The Types of the Algorithms

In general an ML algorithm updates the model iteratively according to the data until the model converges. According to the granularity of the dataset used to update the model, ML algorithms can be classified into two types:

  1. Epoch-based: each epoch means the algorithm goes through the whole training dataset once. Epoch-based algorithms must work with a bounded dataset. 
  2. Batch-based: each batch means the algorithm samples a subset of all the records and uses the sampled records to update the model. Batch-based algorithms could work with either a bounded or an unbounded dataset.

In a distributed setting, the dataset might be partitioned across multiple subtasks. For each epoch, every subtask goes through all of its assigned records; for each batch, every subtask samples from its assigned records. When updating the model with this data in a distributed setting, there are further two styles:

  1. Synchronous: the model must wait for the updates from all the subtasks before it can be used in the computation of the next update in each subtask. 
  2. Asynchronous: the model can directly apply the updates from some subtasks and use the updated value in the following computation immediately. 

In general the synchronous pattern would have higher accuracy and the asynchronous pattern would converge faster. 
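To make the difference concrete, the following minimal Java sketch (hypothetical types and wiring, not part of the proposed API) contrasts the two update styles:

import java.util.ArrayList;
import java.util.List;

/** Illustrative sketch only; Model and the aggregators are hypothetical. */
class SyncVsAsyncSketch {

    interface Model {
        void apply(double[] update);
    }

    /** Synchronous: updates are applied only after every subtask has reported. */
    static class SyncAggregator {
        private final Model model;
        private final int numSubtasks;
        private final List<double[]> pending = new ArrayList<>();

        SyncAggregator(Model model, int numSubtasks) {
            this.model = model;
            this.numSubtasks = numSubtasks;
        }

        void onUpdate(double[] update) {
            pending.add(update);
            if (pending.size() == numSubtasks) { // barrier across all subtasks
                pending.forEach(model::apply);
                pending.clear();
                // ... broadcast the new model to all subtasks here ...
            }
        }
    }

    /** Asynchronous: each update is applied immediately, with no barrier. */
    static class AsyncAggregator {
        private final Model model;

        AsyncAggregator(Model model) {
            this.model = model;
        }

        void onUpdate(double[] update) {
            model.apply(update);
            // ... send the new model back to the sending subtask only ...
        }
    }
}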

Based on the above dimensions, the algorithms could be classified into the following types:

| Type | Data Granularity | Synchronization Pattern | Bounded / Unbounded | Examples |
| ---- | ---------------- | ----------------------- | ------------------- | -------- |
| Non-SGD-based | Epoch | Mostly Synchronous | Bounded | K-Means, Apriori, Decision Tree, Random Walk |
| SGD-based Synchronous Offline algorithm | Batch → Epoch* | Synchronous | Bounded | Linear Regression, Logistic Regression, Deep Learning algorithms |
| SGD-based Asynchronous Offline algorithm | Batch → Epoch* | Asynchronous | Bounded | Same as the above |
| SGD-based Synchronous Online algorithm | Batch | Synchronous | Unbounded | Online versions of the above algorithms |
| SGD-based Asynchronous Online algorithm | Batch | Asynchronous | Unbounded | Online versions of the above algorithms |

*Although SGD-based algorithms are batch-based, they could be implemented with an epoch-based method if intermediate state is allowed: the subtasks could sample each batch starting from the position where the last batch ended. 
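For illustration, a minimal sketch of such epoch-based sequential sampling (a hypothetical helper, not part of the proposal) could look like:

import java.util.List;

/** Hypothetical sketch: turning batch-based sampling into an epoch-based pass. */
class SequentialBatchSampler<T> {
    private final List<T> records;   // the records assigned to this subtask
    private final int batchSize;
    private int position = 0;        // intermediate state: where the last batch ended

    SequentialBatchSampler(List<T> records, int batchSize) {
        this.records = records;
        this.batchSize = batchSize;
    }

    /** Returns the next batch, continuing from the end of the previous one. */
    List<T> nextBatch() {
        int end = Math.min(position + batchSize, records.size());
        List<T> batch = records.subList(position, end);
        position = end;
        return batch;
    }

    /** One epoch ends once every record has been visited exactly once. */
    boolean epochFinished() {
        return position >= records.size();
    }

    void startNewEpoch() {
        position = 0;
    }
}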


Based on the above classification and the epoch-based replacement implementation for SGD-based algorithms on bounded datasets, we mainly need to support

  1. The synchronous / asynchronous epoch-based algorithms on the bounded dataset.
  2. The synchronous / asynchronous batch-based algorithms on the unbounded dataset. 

The Goals of the Iteration Library

If we directly copied the current implementation of iteration in the DataSet and DataStream APIs, we would still meet some problems. Thus we would like to apply some optimizations to the existing iteration functionality.

The Iteration Body and Round Semantics

At the iteration level, we need concepts corresponding to epoch and batch. We call processing one epoch a round: users specify a subgraph as the body of the iteration that describes how to calculate the update; one round ends after the iteration body has processed the whole dataset once (namely one epoch). Apparently the round is meaningful only for the bounded case.

Per-Round v.s. All-Rounds Semantics

How could users specify the iteration body? If we first consider the bounded case, there are two options:

  1. Per-round: Users specify a subgraph, and for each round, the framework would recreate the operators and do the same computation.
  2. All-rounds: Users specify a subgraph, and the operators inside the subgraph would process the epochs of all the rounds. 

The DataSet iteration chooses the per-round semantics. To support this semantics, in addition to re-creating the operators for each round, the framework also needs to:

  1. Merge the feedbacks of one round with the initial inputs to form the inputs of the next round (in bulk or delta style).
  2. Replay the inputs that come from outside of the iteration and have no feedbacks.


The benefit of this method is that writing an iteration body is no different from constructing a DAG outside of the iteration. 

Synchronization

Since for the bounded dataset all the algorithms, to the best of our knowledge, can be converted into epoch-based algorithms, we only need to support synchronization between epochs, namely between rounds. As described later, this is achieved by operators that only emit records at the end of each round.


Besides, the previous DataStream and DataSet iteration APIs also have some caveats for supporting algorithm implementations:

  1. Lack of support for multiple inputs, arbitrary outputs and nested iteration in both iteration APIs, which is required by scenarios like Metapath (multiple inputs), boosting algorithms (nested iteration), or when we want to output both loss and model (multiple outputs). The new iteration would support these functionalities.
  2. Lack of asynchronous iteration support in the DataSet iteration, which is required by algorithms like asynchronous linear regression. The new iteration would support both synchronous and asynchronous modes for the bounded iteration. 
  3. The current DataSet iteration by default provides a "per-round" semantics, namely users only need to specify the computation logic in each round, and the framework executes the subgraph multiple times until convergence. To cooperate with this semantics, the DataSet iteration framework merges the initial input and the feedback (bulk style and delta style), and replays the datasets that come from outside of the iteration. This method is easier to use, but it also limits some possible optimizations.

We would also like to address these caveats in the new iteration library. 

Overall Design

To reduce the development and maintenance overhead, it would be preferred to have a unified implementation for the different types of iterations. In fact, the different iteration types share the same requirements in terms of runtime implementation:

  1. All the iteration types should support multiple inputs and multiple outputs. 
  2. All the iteration types require some kind of back edges that transfer the data back to the iteration head. Since Flink does not support cycles in scheduler and network stack, the back edges should not be visible in the StreamGraph and JobGraph.
  3. All the iteration types should support the checkpoint mechanism in STREAM execution mode.

Different types of iterations differ in their requirements for progress tracking. Progress tracking is analogous to the watermark outside the iteration: it tracks the “progress” inside the iteration:

  1. For bounded iteration, we could track whether we have processed all the records for a specific round. This is necessary for operators like aggregation inside the iteration: if it is notified that all the records of the current round have been processed, it could output the result of this round. We could also track whether the whole iteration has ended, namely all the inputs are finished and there are no pending records inside the iteration. 
  2. For unbounded iteration, there is no concept of global rounds, and the only progress tracking is at the end of the iteration. 

The difference in progress tracking would also affect the API. For example, for bounded iteration we could allow users to specify a termination condition based on the number of rounds, but this is meaningless for unbounded iteration.

To make the API easy to use, we propose to have a dedicated API for each type of iteration, and underneath we would translate them onto the same framework. The framework would implement the basic functionality like building the iteration StreamGraph, the runtime structure and checkpointing, and it would allow implementing different types of progress tracking support for the different iterations. 

Public Interfaces

As shown in Figure 1, an iteration is composed of:

  1. The inputs from outside of the iteration. 
  2. An iteration body specifying the structure inside the iteration:
    1. The subgraph inside the iteration.
    2. The feedbacks corresponding to some of the inputs, which update the underlying data streams. The feedbacks are unioned with the corresponding inputs: the original inputs are emitted into the iteration body only once, and the feedbacks are then emitted to the same set of operators.
  3. The outputs going out of the iteration. The outputs could be emitted from arbitrary data streams.

Figure 1. The structure of an iteration. 

Unbounded Iteration

Similar to FLIP-15, we tend to provide a structured iteration API to make it easier to understand. With this method, users are required to specify an IterationBody that generates the part of the JobGraph inside the iteration. The iteration body should specify the DAG inside the iteration, and also the list of feedback streams and output streams. The feedback streams would be unioned with the corresponding inputs, and the output streams would be provided to the caller routine. 

However, since we do not know the exact number and types of the input streams, it is not easy to define a unified interface for the iteration body without type casting. Thus we propose to use annotations to allow for an arbitrary number of inputs:

The IterationBody API
/** The iteration body specifies the sub-graph inside the iteration. */
public interface IterationBody {

	/** This annotation marks the function that builds the subgraph. */
	@Target(ElementType.METHOD)
	@Retention(RetentionPolicy.RUNTIME)
	public @interface IterationFunction {}

	/** 
	* This annotation marks a parameter of the iteration function as an input to the subgraph. 
	* The input is the union of the initial inputs bound to the iteration and the corresponding 
	* feedback. 
	*/
	@Target(ElementType.PARAMETER)
	@Retention(RetentionPolicy.RUNTIME)
	public @interface IterationInput {
		String value();
	}
}

/** An example usage of the iteration body with two inputs. */
new IterationBody() {
	@IterationFunction
	UnboundedIterationDeclaration iterate(
		@IterationInput("first") DataStream<Integer> first,
		@IterationInput("second") DataStream<String> second
	) {
		DataStream<Integer> feedBack1 = ...;
		DataStream<String> output1 = ...;
		return new UnboundedIterationDeclaration.Builder()
			.withFeedback("first", feedBack1)
			.withOutput("output1", output1)
			.build();
	}
}

The interface for the unbounded iteration is straightforward:

Unbounded Iteration API
/** Builder for the unbounded iteration. */
public class UnboundedIteration {

	/** Set the body of the iteration. */
    UnboundedIteration withBody(IterationBody body) {...}
	
	/** Bind the initial input with the specific name. */
	UnboundedIteration bindInput(String name, DataStream<?> input) {...}

	/** Generates and adds the subgraph corresponding to the iteration.  */
	ResultStreams build() {...}
}

/** The expected return type of the iteration function, which specifies the feedbacks and outputs. */
public class UnboundedIterationDeclaration {
	
	public static class Builder {

		/** 
		* Specify the feedback corresponding to the specific name. The feedback would
		* be unioned with the initial input with the same name and provided to the iteration
		* body. 
		*/
		public Builder withFeedback(String name, DataStream<?> feedback) {...}

		/** Specify one output with the specific name. */
		public Builder withOutput(String name, DataStream<?> output) {...}
		
		/** Generate the Declaration. */
		UnboundedIterationDeclaration build() {...}
	}
}

/** The map of the output streams of an iteration.  */
public class ResultStreams {
	
	/** Gets the DataStream with the specific name. */
	public <T> DataStream<T> getStream(String name) {...}

}
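Putting these pieces together, a minimal usage sketch (the sources loadModel() / loadSamples() and the body myIterationBody are hypothetical placeholders) could look like:

// Sketch only: wiring an unbounded iteration with the proposed builder.
DataStream<Integer> model = loadModel();       // hypothetical source
DataStream<String> samples = loadSamples();    // hypothetical source

ResultStreams results = new UnboundedIteration()
    .withBody(myIterationBody)     // an IterationBody as in the example above
    .bindInput("first", model)     // names must match the @IterationInput values
    .bindInput("second", samples)
    .build();

// Outputs are looked up by the names passed to withOutput(...).
DataStream<String> output = results.getStream("output1");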

To avoid reading more data from the inputs while too much data accumulates inside the iteration, the iteration would process the feedback data first if both sides have data available. 

For termination detection, the iteration would continue until:

  1. All the inputs have terminated.
  2. There are no pending records inside the iteration subgraph. 

Then the iteration terminates.

Bounded Iteration

As mentioned in the motivation, the existing DataSet iteration API uses the "per-round" semantics: it views the iteration as a repeated execution of the same DAG. Underneath it would automatically merge the inputs and feedbacks, replay the inputs without feedbacks, and the operators inside the iteration live only for one round. This might cause bad performance for some algorithms that could cache this data in a more efficient way. 

To avoid this issue, similar to the unbounded iteration, by default we use the "per-iteration" semantics: 

  1. Operators inside the iteration would live till the whole iteration is finished.
  2. We do not automatically merge the inputs and feedbacks. Instead, we union the original inputs and the feedbacks so that users could decide how to merge them.
  3. We do not replay the inputs without feedbacks. Users could decide how to cache them more efficiently. 

Besides, to cooperate with the "per-round" semantics, previously the iteration is synchronous by default: before the current round has fully finished, the feedback data is cached and not emitted. Thus it could not support algorithms like asynchronous regression. To cope with this issue, we view synchronous iteration as a special case of asynchronous iteration with additional synchronization. Thus by default the iteration is asynchronous. 

Based on the above, the API to add an iteration to a job is nearly the same as for the unbounded iteration. The only difference is that the bounded iteration supports more sophisticated termination conditions: a function is evaluated when each round ends, based on the round number or the records of a specified data stream. If it returns true, the iteration would discard all the following feedback records, end all the ongoing rounds and finish. 

Since the operators now live across multiple rounds and multiple rounds might be concurrent, the operators inside the iteration need to know the round of the current record and when one round is fully finished, namely progress tracking. For example, an operator that computes the sum of the records in each round would add each record to the corresponding partial sum, and when one round finishes, it would emit the sum for this round. To support progress tracking, UDFs / operators inside the iteration could implement `BoundedIterationProgressListener` to acquire this additional information about the progress. 

Based on the progress tracking interface, if users want to implement a synchronous method, some operators inside the subgraph need to be synchronous: they only emit records in `onRoundEnd`, namely after all the data of the current round has been received. If, for the subgraph of the iteration body, every path from an input to a feedback contains at least one such operator, then the iteration is synchronous. 
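As an illustration, the per-round sum operator mentioned above could be sketched against the proposed interfaces as follows (a minimal sketch; the exact signatures may evolve):

public static class RoundSumFunction extends ProcessFunction<Long, Long>
        implements BoundedIterationProgressListener<Long> {

    private long partialSum = 0;

    public void processElement(Long value, Context ctx, Collector<Long> out) {
        // Only accumulate here; emitting nothing before the round ends is
        // exactly what makes every path through this operator synchronous.
        partialSum += value;
    }

    public void onRoundEnd(int[] round, Context context, Collector<Long> collector) {
        // All records of this round have been received: emit the round's sum.
        collector.collect(partialSum);
        partialSum = 0;
    }
}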

For users who still want to use the iteration with the "per-round" semantics, a utility `forEachRound()` is provided. With the utility, users could add a subgraph inside the iteration body such that:

  1. The operators inside the subgraph would live only for one round.
  2. If an input stream without feedback is referenced, the input stream would be replayed for each round.

For input streams with feedback, we also provide two utility process functions that automatically merge the original inputs and the feedbacks. Both the existing bulk and delta methods are supported. Users would then be able to implement a per-round iteration with input.process(bulkCache()).forEachRound(() -> {...}) (see the usage sketch after the API listing below).

The API for the bounded iteration is as follows:

Bounded Iteration API
/** Builder for the bounded iteration. */
public class BoundedIteration {
	
	/** Set the body of the iteration. */
    BoundedIteration withBody(IterationBody body) {...}

	/** Bind the initial input with the specific name. */
	BoundedIteration bindInput(String name, DataStream<?> input) {...}
	
	/** Generates and adds the subgraph corresponding to the iteration.  */
	ResultStreams build() {...}
}

/** The expected return type of the iteration function, which specifies the feedbacks, outputs and termination conditions. */
public class BoundedIterationDeclaration {
	
	public static class Builder {
		
		/** 
		* Specify the feedback corresponding to the specific name. The feedback would
		* be unioned with the initial input with the same name and provided to the iteration
		* body. 
		*/
		public Builder withFeedback(String name, DataStream<?> feedback) {...}

		/** Specify one output with the specific name. */
		public Builder withOutput(String name, DataStream<?> output) {...}
		
		/** Specify the termination condition of the iteration. */
		Builder until(TerminationCondition terminationCondition) { ... }
		
		/** Generate the Declaration. */
		BoundedIterationDeclaration build() {...}
	}
}

/** 
* The termination condition judges if the iteration should stop based on the round number
* or the records of a given data stream in one round. 
* 
* Since there might be asynchronous iterations where multiple rounds are executed in parallel,
* the condition is evaluated at the end of each round. If it evaluates to true for one
* round, the iteration would discard the feedback records of the following rounds, but
* the already emitted records would not be withdrawn. 
*/
public class TerminationCondition {

	/**
	* The records of the given DataStream will be collected in each
	* round to be used in judging whether the loop should terminate.
	*/
	@Nullable DataStream<?> refStream;

	/**
	* A user-defined function that is evaluated at the end of each round.
	*/
	Function<Context, Boolean> isConverged;

	interface Context {
		
		/** The round number. */
     	int[] getRound();
		
		/** The list of records of the referred stream. */
     	<T> List<T> getStreamRecords();
  	}	
}

/** The progress tracking interface for UDF / Operator */
public interface BoundedIterationProgressListener<T> {
	
	/** Sets a tool to be used to query the round number of the current record. */
	default void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {}
	
	/** Notified at the end of each round. */
	void onRoundEnd(int[] round, Context context, Collector<T> collector);
	
	/** Notified at the end of the whole iteration. */
	default void onIterationEnd(int[] rounds, Context context) {}

    public interface Context {
		
		<X> void output(OutputTag<X> outputTag, X value);
		
		Long timestamp();

		TimerService timerService();
	}
}

/** The utility methods to support per-round semantics. */
public class BoundedIterationPerRoundUtils {
	/** The builder that creates the subgraph that is executed with the per-round semantics. **/	
	public interface EachRound {
	
		Map<String, DataStream<?>> executeInEachRound();

	}
	
	/** Create a subgraph inside the iteration that executes with the per-round semantics. */
	static ResultStreams forEachRound(EachRound eachRoundBuilder);

	/** A cache for a DataStream. For each round, it replaces the DataStream with the records received in this round and outputs all these records. */	
	static <T> ProcessFunction bulkCache(DataStream<T> inputStream);

	/** A cache for a KeyedStream. For each round, it updates the KeyedStream with the records received in this round and outputs the updated records. */	
	static <K, T> ProcessFunction deltaCache(KeyedStream<K, T> inputStream);
}
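A minimal usage sketch of these per-round utilities (hypothetical body code, assuming the interfaces above; the exact wiring of bulkCache() may differ) could look like:

// Sketch only: "dataset" is an input stream bound to the iteration that has
// a corresponding feedback; the names and the exact wiring are assumptions.
DataStream<Tuple2<double[], Double>> cached =
    dataset.process(BoundedIterationPerRoundUtils.bulkCache(dataset));

ResultStreams perRound = BoundedIterationPerRoundUtils.forEachRound(() -> {
    // Operators created here live for a single round only; input streams
    // without feedback referenced here are replayed for each round.
    DataStream<double[]> roundResult = cached.map(sample -> sample.f0);
    return Collections.singletonMap("roundResult", roundResult);
});

DataStream<double[]> roundResult = perRound.getStream("roundResult");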

Key Implementation

The runtime physical structure of the iteration is shown in Figure 2, which is similar to the current implementation. The head and tail operators are added by the framework. They would be co-located so that we could implement the feedback edge with a local queue. The head could coordinate with an operator coordinator bound to a virtual operator ID for synchronization, including progress tracking and termination condition calculation. 


Figure 2. The physical runtime structure for the iteration. 

To support progress tracking, we would introduce new events inside the iteration body, similar to how the watermark is implemented. However, since normal operators cannot identify these events, we would wrap the operators inside the iteration to parse them.

To wrap the operators for the part of the DAG inside the iteration, when building the stream graph we would introduce a mock execution environment and build the iteration DAG inside this environment first. Then, when build() is called, we would translate the DAG into the real execution environment with the suitable wrappers. Besides, since all the edges inside the iteration should be PIPELINED, we would also set the edge property during translation.

The operator wrapper needs to simulate the context in which an operator executes. Specifically, for operators with a single-round lifecycle in bounded iteration, we would need to isolate the states used for each round and clean up the corresponding state after the round ends.
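As a purely illustrative sketch of this wrapping (an internal detail, not public API; IterationRecord and the functional fields are hypothetical names, and the real wrapper would adapt StreamOperator instead):

import java.util.function.BiConsumer;
import java.util.function.Consumer;

/** Hypothetical record type: either a normal data record or a progress event. */
class IterationRecord<T> {
    final T value;                 // non-null for normal data records
    final int[] round;             // the round this record / event belongs to
    final boolean isProgressEvent;

    IterationRecord(T value, int[] round, boolean isProgressEvent) {
        this.value = value;
        this.round = round;
        this.isProgressEvent = isProgressEvent;
    }
}

class OperatorWrapperSketch<T> {
    private final BiConsumer<T, int[]> wrappedProcessElement;
    private final Consumer<int[]> roundEndListener;

    OperatorWrapperSketch(BiConsumer<T, int[]> wrappedProcessElement,
                          Consumer<int[]> roundEndListener) {
        this.wrappedProcessElement = wrappedProcessElement;
        this.roundEndListener = roundEndListener;
    }

    /** The wrapper parses progress events so the wrapped operator never sees them. */
    void processElement(IterationRecord<T> record) {
        if (record.isProgressEvent) {
            // All records of record.round are processed: notify the listener and,
            // for operators with single-round lifecycle, clean up that round's state.
            roundEndListener.accept(record.round);
        } else {
            wrappedProcessElement.accept(record.value, record.round);
        }
    }
}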

Examples

This section shows how commonly used ML algorithms could be implemented with the iteration API. 

Offline Training with Bounded Iteration

We would like to first show the usage of the bounded iteration with a linear regression case: the model is Y = XA, and we would like to acquire the best estimation of A with the SGD algorithm. For simplicity, we assume the parameters can be held in the memory of one task.

The job graph of the algorithm is shown in Figure 3: in each round, we use the latest parameters to calculate the update to the parameters: ΔA = ∑(Y - XA)X. To achieve this, the Parameters vertex broadcasts the latest parameters to the Train vertex. Each subtask of the Train vertex holds a part of the dataset. Following the spirit of SGD, it samples a small batch of training records and calculates the update with the above equation. Then the Train vertex emits ΔA to the Parameters vertex to update the parameters.
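For clarity, the update rule follows from minimizing the squared loss over the sampled batch (a standard derivation, not part of the proposal):

L(A) = \frac{1}{2}\sum_{i}\left(y_i - x_i A\right)^2, \qquad
\Delta A = -\frac{\partial L}{\partial A} = \sum_{i}\left(y_i - x_i A\right)x_i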


Figure 3. The JobGraph for the offline training of the linear regression case.


We will start with synchronous training. Synchronous training requires the updates from all the Train vertex subtasks to be merged before the next round of training. This could be done by only emitting the next round of parameters at the end of each round. The code is shown as follows:

public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>(){};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        int batch = 5;
        int epochEachBatch = 10;

        ResultStreams resultStreams = new BoundedIteration()
            .withBody(new IterationBody(
                @IterationInput("model") DataStream<double[]> model,
                @IterationInput("dataset") DataStream<Tuple2<double[], Double>> dataset
            ) {
                SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction());
                DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                    .broadcast()
                    .connect(dataset)
                    .coProcess(new TrainFunction())
                    .setParallelism(10);

                return new BoundedIterationDeclaration.Builder()
                    .withFeedback("model", modelUpdate)
                    .withOutput("final_model", parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG))
                    .until(new TerminationCondition(null, context -> context.getRound()[0] >= batch * epochEachBatch))
                    .build();
            })
            .build();
        
        DataStream<double[]> finalModel = resultStreams.getStream("final_model");
        finalModel.print();
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
        implements BoundedIterationProgressListener<double[]> {  
        
        private final double[] parameters = new double[N_DIM];

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
        }

        public void onRoundEnd(int[] round, Context context, Collector<double[]> collector) {
            collector.collect(parameters);
        }

        public void onIterationEnd(int[] round, Context context) {
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
        }
    }

    public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements BoundedIterationProgressListener<double[]> {

        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        private Supplier<int[]> recordRoundQuerier;

        public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
            this.recordRoundQuerier = querier;
        } 

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            int[] round = recordRoundQuerier.get();
            if (round[0] == 0) {
                firstRoundCachedParameter = parameter;
                return;
            }

            calculateModelUpdate(parameter, output);
        }

        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            dataset.add(trainSample);
        }

        public void onRoundEnd(int[] round, Context context, Collector<double[]> output) {
            if (round[0] == 0) {
                calculateModelUpdate(firstRoundCachedParameter, output);
                firstRoundCachedParameter = null;                
            }
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }
}

If instead we want to do asynchronous training, we would need the following changes:

  1. The Parameters vertex would not wait till the end of the round to ensure it has received all the updates from the iteration. Instead, it would immediately output the current parameter values once it receives the model update from one Train subtask.
  2. To label the source of the update, we change the input type to Tuple2<Integer, double[]>. The Parameters vertex would only output the new parameter values to the Train task that sent the update.

We omit the change to the graph building code since it is trivial (changing the output type and using a customized partitioner). The change to the Parameters vertex is as follows:

public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>
    implements BoundedIterationProgressListener<double[]> {  
    
    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        // Suppose we have a util to add the second array to the first.
        ArrayUtils.addWith(parameters, update.f1);
        output.collect(new Tuple2<>(update.f0, parameters));
    }

    public void onIterationEnd(int[] round, Context context) {
        context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
    }
}

Online Training with Unbounded Iteration

Suppose now we change the algorithm to use unbounded iteration. Compared to the offline case, the differences are:

  1. The dataset is unbounded. The Train operator could not cache all the data in the first round.
  2. The training algorithm might be changed to others like FTRL. But we keep using SGD in this example since it does not affect showing the usage of the iteration.

We also start with the synchronous case. For online training, the Train vertex usually does one update after accumulating one mini-batch. This is to ensure the distribution of the samples is similar to the global statistics. In this example we omit the complex data re-sampling process and just fetch the next several records as one mini-batch. 

The JobGraph for online training is still the one shown in Figure 3, with the training dataset becoming unbounded. Similar to the bounded case, for synchronous training the process would look like:

  1. The Parameters vertex broadcasts the initial values upon receiving the input values.
  2. Each Train task reads the next mini-batch of records, calculates an update and emits it to the Parameters vertex. Then it waits until it receives the updated parameters from the Parameters vertex before it proceeds to the next mini-batch.
  3. The Parameters vertex waits until it has received the updates from all the Train tasks before it broadcasts the updated parameters. 

Since in the unbounded case there is no concept of round and we do one update per mini-batch, we could instead use the InputSelectable functionality to implement the algorithm:

public class SynchronousUnboundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> MODEL_UPDATE_OUTPUT_TAG = new OutputTag<double[]>(){};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        ResultStreams resultStreams = new UnboundedIteration()
            .withBody(new IterationBody(
                @IterationInput("model") DataStream<double[]> model,
                @IterationInput("dataset") DataStream<Tuple2<double[], Double>> dataset
            ) {
                SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction(10));
                DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                    .broadcast()
                    .connect(dataset)
                    .transform(
                                "operator",
                                PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                                new TrainOperators(50))
                    .setParallelism(10);

                return new UnboundedIterationDeclaration.Builder()
                    .withFeedback("model", modelUpdate)
                    .withOutput("model_update", parameters.getSideOutput(MODEL_UPDATE_OUTPUT_TAG))
                    .build();
            })
            .build();
        
        DataStream<double[]> modelUpdates = resultStreams.getStream("model_update");
        modelUpdates.addSink(...);
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> {  
        
        private final int numOfTrainTasks;

        private int numOfUpdatesReceived = 0;
        private final double[] parameters = new double[N_DIM];

        public ParametersCacheFunction(int numOfTrainTasks) {
            this.numOfTrainTasks = numOfTrainTasks;
        }

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
            numOfUpdatesReceived++;

            if (numOfUpdatesReceived == numOfTrainTasks) {
                output.collect(parameters);
                numOfUpdatesReceived = 0;
            }
        }
    }

    public static class TrainOperators extends AbstractStreamOperator<double[]> implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>, InputSelectable {

        private final int miniBatchSize;

        private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>();

        public TrainOperators(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
        }

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            calculateModelUpdate(parameter, output);
            miniBatch.clear();
        }

        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            miniBatch.add(trainSample);
        }

        public InputSelection nextSelection() {
            if (miniBatch.size() < miniBatchSize) {
                return InputSelection.SECOND;
            } else {
                return InputSelection.FIRST;
            }
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : miniBatch) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }
}

Also similar to the bounded case, for asynchronous training the Parameters vertex would not wait to receive the updates from all the Train tasks. Instead, it would directly respond to the task sending the update:

public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>> {  
    
    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        ArrayUtils.addWith(parameters, update.f1);
                
        if (update.f0 < 0) {
            // Received the initialized parameter values; broadcast to all the downstream tasks.
            for (int i = 0; i < 10; ++i) {
                output.collect(new Tuple2<>(i, parameters));
            }
        } else {
            output.collect(new Tuple2<>(update.f0, parameters));
        }
        }
    }
}

Implementation Plan

Logically all the iteration types would support both BATCH and STREAM execution modes. However, according to the algorithms' requirements, we would implement:

  1. Unbounded iteration + STREAM mode.
  2. Bounded iteration + BATCH mode.

Currently we do not see requirements for bounded iteration + STREAM mode. If there are additional requirements in the future, we would implement this mode; it could also be supported within the current framework. 

Compatibility, Deprecation, and Migration Plan

The API is added as a library inside the flink-ml repository, thus it does not have compatibility problems. However, it has some differences from the existing iteration APIs, and the algorithms would need some re-implementation.

In the long run, the new iteration implementation might provide an alternative to the existing iteration functionality, and we may consider deprecating and removing the existing APIs to reduce the complexity of the core Flink code. 

Rejected Alternatives

Naiad proposed a unified model for the watermark mechanism (namely progress tracking outside of the iteration) and the progress tracking inside the iteration. It extends the event time and watermark to be a vector (long timestamp, int[] rounds) and implements a vectorized alignment algorithm. Although Naiad provides an elegant model, a direct implementation on Flink would require a large amount of modification to the Flink runtime, which would cause a lot of complexity and maintenance overhead. Thus we choose to implement a simplified version on top of Flink, as a part of the flink-ml library.

For the iteration DAG building, it would be simpler if we could directly refer to the data stream variables outside of the closure of the iteration body. However, since we need to make the iteration DAG creation happen first in the mock execution environment, we could not use these variables directly; otherwise we would directly modify the real environment and would not have the chance to add wrappers to the operators. 
