...

Iteration is a basic building block for an ML library. It is required for training ML models in both the offline and online cases. However, different cases might require different types of iteration:

  1. Offline Training: The training dataset is bounded. The algorithm usually iterates over the dataset for multiple rounds to optimize the model until convergence.
  2. Online Training: The training dataset is unbounded. The algorithm usually calculates the update to the model according to the current model and the input records, then sends the update back to the operator that holds the model.

In general, two types of iteration are required:

  1. Bounded Iteration: Used in the offline case. The algorithm usually trains on a bounded dataset and updates the parameters for multiple rounds until convergence.
  2. Unbounded Iteration: Used in the online case. The algorithm usually trains on an unbounded dataset: it accumulates a mini-batch of data and then does one update to the parameters.
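
The two styles can be illustrated without any Flink API. The following is a minimal plain-Java sketch, not part of the proposed library: `fitBounded` passes over a finite dataset round after round until the update becomes small, while `MiniBatchModel` accumulates a mini-batch of records and applies one update per full batch. All class and method names, the learning rate and the tolerance are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.List;

/** Plain-Java illustration of the two iteration styles; every name here is hypothetical. */
final class IterationStyles {

    /**
     * Bounded style: pass over a finite dataset round after round, moving the
     * parameter towards the data mean by gradient descent on the squared error,
     * until the update is smaller than the tolerance or maxRounds is reached.
     */
    static double fitBounded(double[] data, double learningRate, double tolerance, int maxRounds) {
        double param = 0.0;
        for (int round = 0; round < maxRounds; round++) {
            double gradient = 0.0;
            for (double x : data) {
                gradient += param - x;
            }
            double update = learningRate * gradient / data.length;
            param -= update;
            if (Math.abs(update) < tolerance) {
                break; // converged
            }
        }
        return param;
    }

    /** Unbounded style: buffer a mini-batch of records, then apply one update per full batch. */
    static final class MiniBatchModel {
        private final int batchSize;
        private final double learningRate;
        private final List<Double> batch = new ArrayList<>();
        private double param = 0.0;
        private int updates = 0;

        MiniBatchModel(int batchSize, double learningRate) {
            this.batchSize = batchSize;
            this.learningRate = learningRate;
        }

        void onRecord(double x) {
            batch.add(x);
            if (batch.size() < batchSize) {
                return; // keep accumulating until the mini-batch is full
            }
            double gradient = 0.0;
            for (double v : batch) {
                gradient += param - v;
            }
            param -= learningRate * gradient / batchSize;
            updates++;
            batch.clear();
        }

        double param() {
            return param;
        }

        int updates() {
            return updates;
        }
    }

    public static void main(String[] args) {
        System.out.println(fitBounded(new double[] {1, 2, 3, 4, 5}, 0.5, 1e-9, 1000));
        MiniBatchModel model = new MiniBatchModel(2, 1.0);
        for (double x : new double[] {1, 2, 3, 4, 5}) {
            model.onRecord(x);
        }
        System.out.println(model.param() + " after " + model.updates() + " updates");
    }
}
```

The bounded variant can revisit every record in each round, whereas the unbounded variant sees each record only once and never knows that the input has ended; this is the distinction the two iteration types are meant to capture.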

Previously, Flink supported bounded iteration with the DataSet API and unbounded iteration with the DataStream API. However, since Flink aims to deprecate the DataSet API and the iteration in the DataStream API is rather incomplete, we need to implement a new iteration library in the Flink-ml repository to support the algorithms.

...

  1. Lack of support for multiple inputs, arbitrary outputs and nested iteration in both iteration APIs, which is required by scenarios like Metapath (multiple inputs), boost algorithms (nested iteration), or when we want to output both the loss and the model (multiple outputs). The new iteration would support these functionalities.
  2. Lack of asynchronous iteration support in the DataSet iteration, which is required by algorithms like asynchronous linear regression. The new iteration would support both synchronous and asynchronous modes for the bounded iteration.
  3. The current DataSet iteration by default provides "for each round" semantics, namely users only need to specify the computation logic of each round, and the framework executes the subgraph multiple times until convergence. To cooperate with these semantics, the DataSet iteration framework merges the initial input and the feedback (bulk style and delta style) and replays the datasets that come from outside of the iteration. This method is easier to use, but it also limits some possible optimizations.

...

  1. The inputs from outside of the iteration. 
  2. An iteration body that specifies the structure inside the iteration:
    1. The subgraph inside the iteration.
    2. Some inputs have corresponding feedbacks to update the underlying data stream. The feedbacks are unioned with the corresponding inputs: the original inputs are emitted into the iteration body only once, and the feedbacks are emitted to the same set of operators.
    3. The outputs going out of the iteration. The outputs could be emitted from arbitrary data streams.


Figure 1. The structure of an iteration.

...

Code Block (Java): Unbounded Iteration API
/** Builder for the unbounded iteration. */
public class UnboundedIteration {

	UnboundedIteration withBody(IterationBody body) {...}

	UnboundedIteration bindInput(String name, DataStream<?> input) {...}

	ResultStreams apply() {...}
}

/** The expected return type of the iteration function, which specifies the feedbacks and outputs. */
public class UnboundedIterationDeclaration {
	
	public static class Builder {

		public Builder withFeedback(String name, DataStream<?> feedback) {...}

		public Builder withOutput(String name, DataStream<?> output) {...}
	
		UnboundedIterationDeclaration build() {...}
	}
}

/** The map of the output streams of an iteration.  */
public class ResultStreams {

	public <T> DataStream<T> getStream(String name) {...}

}

/** An example unbounded iteration. */
public class UnboundedIterationExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<Integer> source1 = env.fromElements(1, 2, 3, 4, 5);
		DataStream<String> source2 = env.fromElements("2", "3", "4");

		ResultStreams resultStreams = new UnboundedIteration()
			.withBody(new IterationBody() {
				@IterationFunction
				UnboundedIterationDeclaration iterate(
					@IterationInput("first") DataStream<Integer> first,
					@IterationInput("second") DataStream<String> second
				) {
					DataStream<Integer> feedBack1 = ...;
					DataStream<String> output1 = ...;
					return new UnboundedIterationDeclaration.Builder()
						.withFeedback("first", feedBack1)
						.withOutput("output1", output1)
						.build();
				}
			})
			.bindInput("first", source1)
			.bindInput("second", source2)
			.apply();

		DataStream<String> output = resultStreams.getStream("output1");
	}
}

To avoid reading more data from the inputs while too much data accumulates inside the iteration, the iteration would process the feedback data first if data is available on both sides.

For termination detection, the iteration would continue until both of the following hold:

  1. All the inputs are terminated.
  2. There are no records inside the iteration subgraph.

Then the iteration terminates.
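
The scheduling and termination rules above can be sketched as a single-threaded loop. This is only an illustration under simplifying assumptions: the class `FeedbackFirstLoop`, the toy iteration body (decrement each record and feed it back while it is positive) and the use of one in-memory queue to stand for "records inside the iteration subgraph" are all hypothetical and do not reflect the actual runtime.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;

/** Toy single-threaded model of feedback-first processing; all names are hypothetical. */
final class FeedbackFirstLoop {

    /** Returns the order in which records were consumed, tagged by where they came from. */
    static List<String> run(List<Integer> input) {
        Iterator<Integer> source = input.iterator();
        Deque<Integer> feedback = new ArrayDeque<>();
        List<String> trace = new ArrayList<>();

        // Terminate only when the input is exhausted AND no record is left inside the loop.
        while (source.hasNext() || !feedback.isEmpty()) {
            int record;
            if (!feedback.isEmpty()) {
                // Feedback has priority so that in-flight data is drained before reading more input.
                record = feedback.poll();
                trace.add("feedback:" + record);
            } else {
                record = source.next();
                trace.add("input:" + record);
            }

            // Toy iteration body: decrement the record and feed it back while it stays positive.
            int next = record - 1;
            if (next > 0) {
                feedback.add(next);
            }
        }
        return trace;
    }

    public static void main(String[] args) {
        System.out.println(run(List.of(2, 1)));
    }
}
```

For the input `[2, 1]` the record fed back from `2` is consumed before the second input record is read, and the loop ends once both the source and the feedback queue are empty.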

Bounded Iteration

As mentioned in the motivation, the existing DataSet iteration API uses the "per-round" semantics: it views the iteration as a repeated execution of the same DAG. Under the hood it therefore automatically merges the inputs and feedbacks and replays the inputs without feedbacks, and the operators inside the iteration live only for one round. This might cause bad performance for algorithms that could cache these data in a more efficient way.

To avoid this issue, similar to the unbounded iteration, by default we use the "per-iteration" semantics:

  1. Operators inside the iteration would live until the whole iteration is finished.
  2. We do not automatically merge the inputs and feedbacks. Instead, we union the original inputs and the feedbacks so that users could decide how to merge them.
  3. We do not replay the inputs without feedbacks. Users could decide how to cache them more efficiently.

Besides, to cooperate with the "per-round" semantics, the iteration was previously synchronous by default: before the current round fully finished, the feedback data was cached and would not be emitted. Thus it could not support algorithms like asynchronous regression. To cope with this issue, we view synchronous iteration as a special case of asynchronous iteration with additional synchronization. Thus the iteration is asynchronous by default.

Based on the above assumptions, the API to add an iteration to a job is nearly the same as that of the unbounded iteration. The only difference is that the bounded iteration supports more sophisticated termination conditions: a function is evaluated when each round ends, based on the round number or the records of a specified data stream. If it returns true, the iteration would discard all the following feedback records, end all the ongoing rounds and finish.
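
As a rough illustration of a termination condition evaluated at the end of each round, consider the plain-Java sketch below. It is not the proposed API: `RoundTermination`, the toy body (halve each value and feed it back while it is at least 1.0), the `maxRounds` safety bound and the use of a `BiPredicate` over the round count and the feedback records are assumptions made for this example.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiPredicate;

/** Toy model of a termination condition checked at each round end; all names are hypothetical. */
final class RoundTermination {

    /**
     * Runs rounds until isConverged returns true or maxRounds is reached.
     * Returns the number of completed rounds.
     */
    static int iterate(
            List<Double> initial, int maxRounds, BiPredicate<Integer, List<Double>> isConverged) {
        List<Double> current = new ArrayList<>(initial);
        int round = 0;
        while (round < maxRounds) {
            // Toy iteration body: halve each value and feed it back while it is still >= 1.0.
            List<Double> feedback = new ArrayList<>();
            for (double value : current) {
                double halved = value / 2.0;
                if (halved >= 1.0) {
                    feedback.add(halved);
                }
            }
            round++;

            // The condition sees the number of completed rounds and the records of the reference stream.
            if (isConverged.test(round, feedback)) {
                break; // the remaining feedback records are discarded
            }
            current = feedback;
        }
        return round;
    }

    public static void main(String[] args) {
        List<Double> input = Arrays.asList(8.0, 3.0);
        // Stop when a round produces no feedback records.
        System.out.println(iterate(input, 100, (r, records) -> records.isEmpty()));
        // Stop after a fixed number of rounds, regardless of the records.
        System.out.println(iterate(input, 100, (r, records) -> r >= 2));
    }
}
```

The same loop supports both kinds of condition mentioned above: one that looks only at the round number and one that inspects the records produced in the round.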

Since the operators now live across multiple rounds and multiple rounds might run concurrently, the operators inside the iteration need to know the round of the current record and when a round is fully finished, namely the progress tracking. For example, an operator that computes the sum of the records in each round would add each record to the corresponding partial sum and, when a round is finished, emit the sum for that round. To support progress tracking, UDFs / operators inside the iteration could implement `BoundedIterationProgressListener` to acquire the additional information about the progress.
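
The per-round sum example can be sketched in plain Java as follows. The class `PerRoundSum` and its two callbacks are hypothetical stand-ins that take a single `int` round number; they only mimic the idea of the progress listener and do not use its actual signature.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Toy operator that keeps one partial sum per round; all names are hypothetical. */
final class PerRoundSum {
    // Rounds may overlap, so a separate partial sum is kept for every round still in progress.
    private final Map<Integer, Long> partialSums = new HashMap<>();
    private final List<String> emitted = new ArrayList<>();

    /** Called for each record together with the round the record belongs to. */
    void onRecord(int round, long value) {
        partialSums.merge(round, value, Long::sum);
    }

    /** Called once all records of the given round have been received. */
    void onRoundEnd(int round) {
        Long sum = partialSums.remove(round);
        emitted.add("round " + round + " sum=" + (sum == null ? 0L : sum));
    }

    List<String> emitted() {
        return emitted;
    }

    public static void main(String[] args) {
        PerRoundSum operator = new PerRoundSum();
        operator.onRecord(0, 1);
        operator.onRecord(1, 10); // a record of round 1 arrives before round 0 has finished
        operator.onRecord(0, 2);
        operator.onRoundEnd(0);
        operator.onRecord(1, 5);
        operator.onRoundEnd(1);
        System.out.println(operator.emitted());
    }
}
```

Because records of different rounds can interleave, the operator cannot keep a single running total; it needs the round of each record to pick the right partial sum, and the round-end notification to know when that sum is complete.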

Based on the progress tracking interface, if users want to implement a synchronous method, some operators inside the subgraph need to be synchronous: they only emit records in `onRoundEnd`, namely after all the data of the current round has been received. If every path from the inputs to the feedbacks in the subgraph of the iteration body contains at least one such operator, the iteration would be synchronous.

For users who still want to use the iteration with the "per-round" semantics, a utility `forEachRound()` is provided. With the utility users could add a subgraph inside the iteration body such that

  1. The operators inside the subgraph would live only for one round.
  2. If an input stream without feedback is referenced, the input stream would be replayed for each round.

For input streams with feedbacks, we also provide two utility process functions that automatically merge the original inputs and feedbacks. Both the existing bulk and delta methods are supported. Users would then be able to implement a per-round iteration with `input.process(bulkCache()).forEachRound(() -> {...})`.
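
To make the difference between the bulk and delta merge styles concrete, here is a minimal plain-Java sketch. It is an interpretation rather than the proposed utilities: the classes `BulkCache` and `DeltaCache` are hypothetical, and the sketch assumes that the delta style emits the whole updated keyed state at the end of each round.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Toy models of the bulk and delta merge styles; all names are hypothetical. */
final class RoundCaches {

    /** Bulk style: the records received in a round replace the cached dataset entirely. */
    static final class BulkCache<T> {
        private List<T> cached = new ArrayList<>();
        private List<T> received = new ArrayList<>();

        void onRecord(T record) {
            received.add(record);
        }

        /** Swaps in the records of this round and returns a copy of the new dataset. */
        List<T> onRoundEnd() {
            cached = received;
            received = new ArrayList<>();
            return new ArrayList<>(cached);
        }
    }

    /** Delta style: the records received in a round update the cached dataset by key. */
    static final class DeltaCache<K, V> {
        private final Map<K, V> cached = new LinkedHashMap<>();
        private final Map<K, V> received = new LinkedHashMap<>();

        void onRecord(K key, V value) {
            received.put(key, value);
        }

        /** Applies the updates of this round and returns a copy of the updated state (assumption). */
        Map<K, V> onRoundEnd() {
            cached.putAll(received);
            received.clear();
            return new LinkedHashMap<>(cached);
        }
    }

    public static void main(String[] args) {
        BulkCache<String> bulk = new BulkCache<>();
        bulk.onRecord("a");
        bulk.onRecord("b");
        System.out.println(bulk.onRoundEnd());
        bulk.onRecord("c");
        System.out.println(bulk.onRoundEnd());

        DeltaCache<String, Integer> delta = new DeltaCache<>();
        delta.onRecord("x", 1);
        delta.onRecord("y", 2);
        System.out.println(delta.onRoundEnd());
        delta.onRecord("y", 5);
        System.out.println(delta.onRoundEnd());
    }
}
```

In the bulk style the second round forgets `a` and `b`, while in the delta style the untouched key `x` survives and only `y` changes.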

The API for the bounded iteration is as follows:

Code Block (Java): Bounded Iteration API
/** Builder for the bounded iteration. */
public class BoundedIteration {

	BoundedIteration withBody(IterationBody body) {...}

	BoundedIteration bindInput(String name, DataStream<?> input) {...}

	ResultStreams apply() {...}
}

/** The expected return type of the iteration function, which specifies the feedbacks, outputs and termination conditions. */
public class BoundedIterationDeclaration {
	
	public static class Builder {

		public Builder withFeedback(String name, DataStream<?> feedback) {...}

		public Builder withOutput(String name, DataStream<?> output) {...}

		public Builder until(TerminationCondition terminationCondition) {...}
	
		BoundedIterationDeclaration build() {...}
	}
}

/** The termination condition judges whether the iteration should stop, based on the round or the records of a data stream in one round. */
public class TerminationCondition {

	@Nullable DataStream<?> refStream;

	Function<Context, Boolean> isConverged;

	TerminationCondition(@Nullable DataStream<?> refStream, Function<Context, Boolean> isConverged) {...}

	interface Context {

		int[] getRound();

		<T> List<T> getStreamRecords();
	}
}

/** The progress tracking interface for UDF / Operator */
public interface BoundedIterationProgressListener<T> {

	default void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {}

	void onRoundEnd(int[] round, Context context, Collector<T> collector);
	
	default void onIterationEnd(int[] rounds, Context context) {}

    public interface Context {
		
		<X> void output(OutputTag<X> outputTag, X value);
		
		Long timestamp();

		TimerService timerService();
	}
}

// Example
public class BoundedIterationExample {

	public static class MyReducer implements FlatMapFunction<Integer, Integer>, BoundedIterationProgressListener<Integer> {

		// Partial sum of the records received in the current round.
		private int value = 0;

		@Override
		public void flatMap(Integer v1, Collector<Integer> v2) {
			value += v1;
		}

		@Override
		public void onRoundEnd(int[] round, Context context, Collector<Integer> collector) {
			// Emit the sum once the round has fully finished, then reset it for the next round.
			collector.collect(value);
			value = 0;
		}
	}

	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<Integer> source1 = env.fromElements(1, 2, 3, 4, 5);
		DataStream<String> source2 = env.fromElements("2", "3", "4");

		ResultStreams resultStreams = new BoundedIteration()
			.withBody(new IterationBody() {

				@IterationFunction
				BoundedIterationDeclaration iterate(
					@IterationInput("first") DataStream<Integer> first,
					@IterationInput("second") DataStream<String> second
				) {
					DataStream<Integer> feedBack1 = ...;

					// The subgraph created here is executed with the per-round semantics.
					ResultStreams results = BoundedIterationPerRoundUtils.forEachRound(() -> {
						return Collections.singletonMap("result", feedBack1.map(xx));
					});
					DataStream<String> feedBack2 = results.getStream("result");

					DataStream<Integer> output1 = feedBack1.flatMap(new MyReducer());

					return new BoundedIterationDeclaration.Builder()
						.withFeedback("first", feedBack1)
						.withOutput("output1", output1)
						// Terminate once a round produces no records on feedBack2.
						.until(new TerminationCondition(feedBack2, context -> context.getStreamRecords().size() == 0))
						.build();
				}
			})
			.bindInput("first", source1)
			.bindInput("second", source2)
			.apply();

		DataStream<Integer> output = resultStreams.getStream("output1");
	}
}

/** The utility methods to support per-round semantics. */
public class BoundedIterationPerRoundUtils {

	/** The builder that creates the subgraph that is executed with the per-round semantics. */
	public interface EachRound {

		Map<String, DataStream<?>> executeInEachRound();
	}

	/** Creates a subgraph inside the iteration that executes with the per-round semantics. */
	static ResultStreams forEachRound(EachRound eachRoundBuilder) {...}

	/** A cache for a DataStream. For each round it replaces the DataStream with the records received in this round and outputs all these records. */
	static <T> ProcessFunction bulkCache(DataStream<T> inputStream) {...}

	/** A cache for a KeyedStream. For each round it updates the KeyedStream with the records received in this round and outputs the updated records. */
	static <K, T> ProcessFunction deltaCache(KeyedStream<K, T> inputStream) {...}
}

Key Implementation

The runtime physical structure of the iteration is shown in Figure 1, which is similar to the current implementation. The head and tail are added by the framework. They would be colocated so that we could implement the feedback edge with a local queue. The head could coordinate with an operator coordinator bound to a virtual operator ID for synchronization, including progress tracking and termination condition calculation.

...