...

To make the API easy to use, we propose a dedicated API for each type of iteration; underneath, all of them are translated onto the same framework. The framework would implement the basic functionality, such as building the iteration StreamGraph, the runtime structure and checkpointing, and it allows different iterations to implement different types of progress tracking support.

...

Public Interfaces

As shown in Figure 1, an iteration is composed of 

...

Figure 1. The structure of an iteration.

Iteration API

Unbounded Iteration

Similar to FLIP-15, we tend towards a structured iteration API, which is easier to understand. With this approach, users are required to specify an IterationBody that generates the part of the JobGraph inside the iteration. The iteration body should specify the DAG inside the iteration, as well as the list of feedback streams and output streams. The feedback streams would be unioned with the corresponding inputs, and the output streams would be returned to the caller.

...

Code Block
languagejava
titleBounded Iteration API
linenumberstrue
/** Builder for the bounded iteration. */
public class BoundedIteration {

    BoundedIteration withBody(IterationBody body) {...}

	BoundedIteration bindInput(String name, DataStream<?> input) {...}

	ResultStreams apply() {...}
}

/** The expected return type of the iteration function, which specifies the feedbacks, outputs and termination conditions. */
public class BoundedIterationDeclaration {
	
	public static class Builder {

		public Builder withFeedback(String name, DataStream<?> feedback) {...}

		public Builder withOutput(String name, DataStream<?> output) {...}

		<U> Builder until(TerminationCondition terminationCondition) { ... }
	
		BoundedIterationDeclaration build() {...}
	}
}

/** The termination condition judges if iteration should stop based on the round or the records of a data stream in one round. */
public class TerminationCondition {
	
	@Nullable DataStream<?> refStream;

	Function<Context, Boolean> isConverged;

	TerminationCondition(DataStream<?> refStream, Function<Context, Boolean> isConverged) {...}

	interface Context {

     	int[] getRound();

     	<T> List<T> getStreamRecords();
  	}	
}

/** The progress tracking interface for UDF / Operator */
public interface BoundedIterationProgressListener<T> {

	default void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {}

	void onRoundEnd(int[] round, Context context, Collector<T> collector);
	
	default void onIterationEnd(int[] rounds, Context context) {}

    public interface Context {
		
		<X> void output(OutputTag<X> outputTag, X value);
		
		Long timestamp();

		TimerService timerService();
	}
}

/** The builder that creates the subgraph that is executed with per-round semantics. */
public interface EachRound {
	
	Map<String, DataStream<?>> executeInEachRound();

}

/** The utility methods to support per-round semantics. */
public class BoundedIterationPerRoundUtils {

	static ResultStreams forEachRound(EachRound eachRoundBuilder);
	
	static <T> ProcessFunction bulkCache(DataStream<T> inputStream);

	static <K, T> ProcessFunction deltaCache(KeyedStream<K, T> inputStream);
}

// Example
public class BoundedIterationExample {

	public static class MyReducer implements FlatMapFunction<Integer, Integer>, BoundedIterationProgressListener<Integer> {
			
		private int value = 0;

		@Override
		public void flatMap(Integer v1, Collector<Integer> v2) {
			// Accumulate within the round; the sum is emitted only when the round ends.
			value += v1;
		}

		@Override
		public void onRoundEnd(int[] round, Context context, Collector<Integer> collector) {
			collector.collect(value);
			value = 0;
		}
	}

	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		DataStream<Integer> source1 = env.fromElements(1, 2, 3, 4, 5);
		DataStream<String> source2 = env.fromElements("2", "3", "4");

		ResultStreams resultStreams = new BoundedIteration()
			.withBody(new IterationBody() {

				@IterationFunction
				BoundedIterationDeclaration iterate(
					@IterationInput("first") DataStream<Integer> first,
					@IterationInput("second") DataStream<String> second
				) {
					DataStream<Integer> feedback1 = ...;

					ResultStreams results = BoundedIterationPerRoundUtils.forEachRound(() -> {
						return Collections.singletonMap("result", feedback1.map(xx));
					});
					DataStream<String> feedback2 = results.getStream("result");

					DataStream<Integer> output1 = feedback1.flatMap(new MyReducer());

					return new BoundedIterationDeclaration.Builder()
						.withFeedback("first", feedback1)
						.withOutput("output1", output1)
						.until(new TerminationCondition(feedback2, context -> context.getStreamRecords().size() == 0))
						.build();
				}
			})
			.bindInput("first", source1)
			.bindInput("second", source2)
			.apply();

		DataStream<Integer> output = resultStreams.getStream("output1");
	}
}
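
The termination condition can be read as a predicate over the round number and the records a reference stream produced in that round. The following is a minimal plain-Java sketch of that reading, not the proposed implementation: it runs without Flink, and `TerminationSketch`, `RoundContext`, `runUntilConverged` and `halveAndDropZeros` are hypothetical names introduced only for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;

/** Plain-Java illustration of "iterate until the reference stream is empty in one round". */
final class TerminationSketch {

    /** Hypothetical stand-in for TerminationCondition.Context: one round and its records. */
    static final class RoundContext {
        final int round;
        final List<Integer> streamRecords;

        RoundContext(int round, List<Integer> streamRecords) {
            this.round = round;
            this.streamRecords = streamRecords;
        }
    }

    /**
     * Applies the body once per round, feeds its result back as the next input,
     * and stops as soon as isConverged accepts the round. Returns the number of rounds run.
     */
    static int runUntilConverged(
            List<Integer> input,
            Function<List<Integer>, List<Integer>> body,
            Predicate<RoundContext> isConverged,
            int maxRounds) {
        List<Integer> current = input;
        for (int round = 0; round < maxRounds; round++) {
            List<Integer> feedback = body.apply(current);
            if (isConverged.test(new RoundContext(round, feedback))) {
                return round + 1;
            }
            current = feedback;
        }
        return maxRounds;
    }

    /** Example body: halve every value and drop the ones that reach zero. */
    static List<Integer> halveAndDropZeros(List<Integer> values) {
        List<Integer> result = new ArrayList<>();
        for (int value : values) {
            int halved = value / 2;
            if (halved > 0) {
                result.add(halved);
            }
        }
        return result;
    }
}
```

With the input 1, 2, 3, 4, 5 and the condition `ctx -> ctx.streamRecords.isEmpty()`, the feedback shrinks from four records to two to none, so the loop stops after three rounds.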

Key Implementation

The runtime physical structure of the iteration is shown in Figure 1, which is similar to the current implementation. The head and tail are added by the framework. They would be colocated so that the feedback edge could be implemented with a local queue. The head could coordinate with an operator coordinator bound to a virtual operator ID for synchronization, including progress tracking and evaluating the termination condition.
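
To illustrate why colocating the head and tail allows the feedback edge to be a local queue, here is a single-threaded plain-Java sketch. It is only an assumption-laden model: `LocalFeedbackLoop` and its members are hypothetical names, the iteration body is reduced to an integer function, and the operator coordinator used for synchronization is left out entirely.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.function.IntUnaryOperator;

/** Single-threaded sketch of a colocated head/tail pair sharing a local feedback queue. */
final class LocalFeedbackLoop {

    /** Records that left the iteration through the output edge. */
    final List<Integer> outputs = new ArrayList<>();

    /** Number of records the head forwarded into the iteration body. */
    int processed = 0;

    private final Queue<Integer> feedbackQueue = new ArrayDeque<>();
    private final IntUnaryOperator body;

    LocalFeedbackLoop(IntUnaryOperator body) {
        this.body = body;
    }

    /** Head: forwards an external input record, then drains whatever the tail fed back. */
    void onInput(int record) {
        runBody(record);
        Integer fedBack;
        while ((fedBack = feedbackQueue.poll()) != null) {
            runBody(fedBack);
        }
    }

    private void runBody(int record) {
        processed++;
        tail(body.applyAsInt(record));
    }

    /** Tail: positive results return to the head via the local queue, the rest are emitted. */
    private void tail(int result) {
        if (result > 0) {
            feedbackQueue.add(result);
        } else {
            outputs.add(result);
        }
    }
}
```

For example, with the body `v -> v - 1` and the inputs 3, 1 and 2, the head forwards six records in total and three records leave through the output. Because the queue lives in the same process as both ends, no network channel is needed for the feedback edge.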

...

The operator wrapper needs to simulate the context in which an operator executes. In particular, for operators with a single-round lifecycle in a bounded iteration, we would need to isolate the state used for each round and clean up the corresponding state after the round ends.
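
One way to picture the per-round isolation is a state holder keyed by round number that is dropped when the round ends. The sketch below is plain Java under that assumption; `PerRoundState` and its methods are hypothetical names, and a real wrapper would go through Flink's state backends rather than an in-memory map.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

/** Keeps one independent state object per round and discards it when the round ends. */
final class PerRoundState<S> {

    private final Map<Integer, S> stateByRound = new HashMap<>();
    private final Supplier<S> initialState;

    PerRoundState(Supplier<S> initialState) {
        this.initialState = initialState;
    }

    /** Returns the state of the given round, creating a fresh one on first access. */
    S forRound(int round) {
        return stateByRound.computeIfAbsent(round, r -> initialState.get());
    }

    /** Removes and returns the state of a finished round, or null if none was created. */
    S endRound(int round) {
        return stateByRound.remove(round);
    }

    /** Number of rounds that currently hold state. */
    int liveRounds() {
        return stateByRound.size();
    }
}
```

A wrapper built this way would call `forRound` with the round of the record being processed and `endRound` when it is notified that the round has finished, so records of different rounds never observe each other's state.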

Examples

This section shows how commonly used ML algorithms could be implemented with the iteration API.

...

  1. The Parameters vertex would not wait until the end of the round to ensure it has received all the updates from the iteration. Instead, it would immediately output the current parameter values once it receives a model update from one Train subtask.
  2. To label the source of the update, we would change the input type to Tuple2<Integer, double[]>. The Parameters vertex would output the new parameter values only to the Train task that sent the update.

We omit the change to the graph building code since it is trivial (changing the output type and using a customized partitioner). The change to the Parameters vertex is as follows:

Code Block
languagejava
public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>
    implements BoundedIterationProgressListener<double[]> {  
    
    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        // Suppose we have a util to add the second array to the first.
        ArrayUtils.addWith(parameters, update.f1);
        output.collect(new Tuple2<>(update.f0, parameters));
    }

    public void onIterationEnd(int[] round, Context context) {
        context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
    }
}

...

  1. The Parameters vertex broadcasts the initialized values on receiving the input values.
  2. Each Train task reads the next mini-batch of records, calculates an update and emits it to the Parameters vertex. It would then wait until it receives the updated parameters from the Parameters vertex before it proceeds to process the next mini-batch.
  3. The Parameters vertex would wait until it has received the updates from all the Train tasks before it broadcasts the updated parameters.

Since there is no concept of a round in the unbounded case and we update per mini-batch, we could instead use the InputSelectable functionality to implement the algorithm:

Code Block
languagejava
public class SynchronousUnboundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> MODEL_UPDATE_OUTPUT_TAG = new OutputTag<double[]>("model_update") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        ResultStreams resultStreams = new UnboundedIteration()
            .withBody(new IterationBody(
                @IterationInput("model") DataStream<double[]> model,
                @IterationInput("dataset") DataStream<Tuple2<double[], Double>> dataset
            ) {
                SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction(10));
                DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                    .broadcast()
                    .connect(dataset)
                    .transform(
                                "operator",
                                PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                                new TrainOperators(50))
                    .setParallelism(10);

                return new UnBoundedIterationDeclarationBuilder()
                    .withFeedback("model", modelUpdate)
                    .withOutput("model_update", parameters.getSideOutput(MODEL_UPDATE_OUTPUT_TAG))
                    .build();
            })
            .build();
        
        DataStream<double[]> finalModel = resultStreams.getStream("model_update");
        finalModel.addSink(...);
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> {  
        
        private final int numOfTrainTasks;

        private int numOfUpdatesReceived = 0;
        private final double[] parameters = new double[N_DIM];

        public ParametersCacheFunction(int numOfTrainTasks) {
            this.numOfTrainTasks = numOfTrainTasks;
        }

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
            numOfUpdatesReceived++;

            if (numOfUpdatesReceived == numOfTrainTasks) {
                output.collect(parameters);
                numOfUpdatesReceived = 0;
            }
        }
    }

    public static class TrainOperators extends AbstractStreamOperator<double[]> implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>, InputSelectable {

        private final int miniBatchSize;

        private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        public TrainOperators(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
        }

        @Override
        public void processElement1(StreamRecord<double[]> parameter) throws Exception {
            // New parameters arrived: compute the update from the buffered mini-batch, then start a new one.
            calculateModelUpdate(parameter.getValue());
            miniBatch.clear();
        }

        @Override
        public void processElement2(StreamRecord<Tuple2<double[], Double>> trainSample) throws Exception {
            miniBatch.add(trainSample.getValue());
        }

        @Override
        public InputSelection nextSelection() {
            // Read samples until the mini-batch is full, then wait for the parameters.
            if (miniBatch.size() < miniBatchSize) {
                return InputSelection.SECOND;
            } else {
                return InputSelection.FIRST;
            }
        }

        private void calculateModelUpdate(double[] parameters) {
            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : miniBatch) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            // "output" is the Output<StreamRecord<double[]>> field inherited from AbstractStreamOperator.
            output.collect(new StreamRecord<>(modelUpdate));
        }
    }
}

Also similar to the bounded case, for asynchronous training the Parameters vertex would not wait to receive updates from all the Train tasks. Instead, it would respond directly to the task that sent the update:

Code Block
languagejava
public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>> {  
    
    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        ArrayUtils.addWith(parameters, update.f1);
                
        if (update.f0 < 0) {
            // Received the initialized parameter values, broadcast to all the downstream tasks
            for (int i = 0; i < 10; ++i) {
                output.collect(new Tuple2<>(i, parameters));
            }
        } else {
            output.collect(new Tuple2<>(update.f0, parameters));
        }
    }
}

Implementation Plan

Logically, all the iteration types would support both the BATCH and STREAM execution modes. However, according to the algorithms' requirements, we would implement 

...