...
It is important to note that users typically should not invoke IterationBody::process directly, because the model-variables expected by the iteration body are not the same as the initial-model-variables provided by the user. Instead, the model-variables are computed as the union of the feedback-model-variables (emitted by the iteration body) and the initial-model-variables (provided by the caller of the iteration body). To relieve users from creating this union operator, we have added a utility class (see Iterations) to run an iteration body with the user-provided inputs.
...
Code Block | ||||
---|---|---|---|---|
| ||||
package org.apache.flink.iteration /** * The callbacks defined below will be invoked only if the operator instance which implements this interface is used * within an iteration body. */ @PublicEvolving public interface IterationListener<T> { /** * This callback is invoked every time the epoch watermark of this operator increments. The initial epoch watermark * of an operator is 0. * * The epochWatermark is the maximum integer that meets this requirement: every record that arrives at the operator * going forward should have an epoch larger than the epochWatermark. See Java docs in IterationUtilsIterations for how epoch * is determined for records ingested into the iteration body and for records emitted by operators within the * iteration body. * * If all inputs are bounded, the maximum epoch of all records ingested into this operator is used as the * epochWatermark parameter for the last invocation of this callback. * * @param epochWatermark The incremented epoch watermark. * @param context A context that allows emitting side output. The context is only valid during the invocation of * this method. * @param collector The collector for returning result values. */ void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector); /** * This callback is invoked after the execution of the iteration body has terminated. * * See Java doc of methods in IterationUtilsIterations for the termination conditions. * * @param context A context that allows emitting side output. The context is only valid during the invocation of * this method. * @param collector The collector for returning result values. */ void onIterationTermination(Context context, Collector<T> collector); /** * Information available in an invocation of the callbacks defined in the IterationProgressListener. */ interface Context { /** * Emits a record to the side output identified by the {@link OutputTag}. 
* * @param outputTag the {@code OutputTag} that identifies the side output to emit to. * @param value The record to emit. */ <X> void output(OutputTag<X> outputTag, X value); } } |
...
Code Block | ||||
---|---|---|---|---|
| ||||
public class SynchronousBoundedLinearRegression { private static final int N_DIM = 50; private static final int N_EPOCH = 5; private static final int N_BATCH_PER_EPOCH = 10; private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>{}; public static void main(String[] args) { DataStream<double[]> initParameters = loadParameters().setParallelism(1); DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1); DataStreamList resultStreams = IterationUtilsIterations.iterateBoundedStreamsUntilTermination( DataStreamList.of(initParameters), ReplayableDataStreamList.notReplay(dataset), IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(); (variableStreams, dataStreams) -> { DataStream<double[]> parameterUpdates = variableStreams.get(0); DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0); SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction()); DataStream<double[]> modelUpdate = parameters.setParallelism(1) .broadcast() .connect(dataset) .coProcess(new TrainFunction()) .setParallelism(10) .process(new ReduceFunction()) .setParallelism(1) return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG))); }); DataStream<double[]> finalModel = resultStreams.get("final_model"); finalModel.print(); } public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> implements IterationListener<double[]> { private final double[] parameters = new double[N_DIM]; public void processElement(double[] update, Context ctx, Collector<O> output) { // Suppose we have a util to add the second array to the first. 
ArrayUtils.addWith(parameters, update); } void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) { if (epochWatermark < N_EPOCH * N_BATCH_PER_EPOCH) { collector.collect(parameters); } } public void onIterationEnd(int[] round, Context context) { context.output(FINAL_MODEL_OUTPUT_TAG, parameters); } } public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements IterationListener<double[]> { private final List<Tuple2<double[], Double>> dataset = new ArrayList<>(); private double[] firstRoundCachedParameter; private Supplier<int[]> recordRoundQuerier; public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) { this.recordRoundQuerier = querier; } public void processElement1(double[] parameter, Context context, Collector<O> output) { int[] round = recordRoundQuerier.get(); if (round[0] == 0) { firstRoundCachedParameter = parameter; return; } calculateModelUpdate(parameter, output); } public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<O> output) { dataset.add(trainSample) } void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) { if (epochWatermark == 0) { calculateModelUpdate(firstRoundCachedParameter, output); firstRoundCachedParameter = null; } } private void calculateModelUpdate(double[] parameters, Collector<O> output) { List<Tuple2<double[], Double>> samples = sample(dataset); double[] modelUpdate = new double[N_DIM]; for (Tuple2<double[], Double> record : samples) { double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1); ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff)); } output.collect(modelUpdate); } } public static class ReduceFunction { private double[] mergedValue = ArrayUtils.newArray(N_DIM); public void processElement(double[] parameter, Context context, Collector<O> output) { mergedValue = ArrayUtils.add(mergedValue, parameter); } void 
onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) { collector.collect(mergedValue); mergedValue = ArrayUtils.newArray(N_DIM); } } } |
...
Code Block | ||
---|---|---|
| ||
DataStreamList resultStreams = IterationUtilsIterations.iterateBoundedStreamsUntilTermination( DataStreamList.of(initParameters), ReplayableDataStreamList.notReplay(dataset), IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(); (variableStreams, dataStreams) -> { DataStream<double[]> parameterUpdates = variableStreams.get(0); DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0); SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction()); DataStream<double[]> modelUpdate = parameters.setParallelism(1) .broadcast() .connect(dataset) .coProcess(new TrainFunction()) .setParallelism(10) DataStream<double[]> reduced = PerRoundGraphBuilder.forEachRound(DataStreamList.of(modelUpdate), streams -> { return streams.<double[]>get(0).windowAll().reduce((x, y) -> ArrayUtils.add(x, y)); }).<double[]>get(0); return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG))); }); |
...
Code Block | ||
---|---|---|
| ||
public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>> implements BoundedIterationProgressListener<double[]> { private final double[] parameters = new double[N_DIM]; public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) { // Suppose we have a util to add the second array to the first. ArrayUtils.addWith(parameters, update); output.collect(new Tuple2<>(update.f0, parameters)) } public void onIterationEnd(int[] round, Context context) { context.output(FINAL_MODEL_OUTPUT_TAG, parameters); } } public class AsynchronousBoundedLinearRegression { ... DataStreamList resultStreams = IterationUtilsIterations.iterateBoundedStreamsUntilTermination(DataStreamList.of(initParameters), DataStreamList.of(dataset), (variableStreams, dataStreams) -> { DataStream<double> parameterUpdates = variableStreams.get(0); DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0); SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction()); DataStream<double[]> modelUpdate = parameters.setParallelism(1) .partitionCustom((key, numPartitions) -> key % numPartitions, update -> update.f0) .connect(dataset) .coProcess(new TrainFunction()) .setParallelism(10) return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG))); }); ... } |
...
Code Block | ||
---|---|---|
| ||
public class SynchronousUnboundedLinearRegression { private static final N_DIM = 50; private static final OutputTag<double[]> MODEL_UPDATE_OUTPUT_TAG = new OutputTag<double[]>{}; public static void main(String[] args) { DataStream<double[]> initParameters = loadParameters().setParallelism(1); DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1); DataStreamList resultStreams = IterationUtilsIterations.iterateUnboundedStreams(DataStreamList.of(initParameters), DataStreamList.of(dataset), (variableStreams, dataStreams) -> { DataStream<double> parameterUpdates = variableStreams.get(0); DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0); SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction(10)); DataStream<double[]> modelUpdate = parameters.setParallelism(1) .broadcast() .connect(dataset) .transform( "operator", BasicTypeInfo.INT_TYPE_INFO, new TrainOperators(50)); .setParallelism(10); return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG))); }) DataStream<double[]> finalModel = resultStreams.get("model_update"); finalModel.addSink(...) } public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> { private final int numOfTrainTasks; private final int numOfUpdatesReceived = 0; private final double[] parameters = new double[N_DIM]; public ParametersCacheFunction(int numOfTrainTasks) { this.numOfTrainTasks = numOfTrainTasks; } public void processElement(double[] update, Context ctx, Collector<O> output) { // Suppose we have a util to add the second array to the first. 
ArrayUtils.addWith(parameters, update); numOfUpdatesReceived++; if (numOfUpdatesReceived == numOfTrainTasks) { output.collect(parameters); numOfUpdatesReceived = 0; } } } public static class TrainOperators extends AbstractStreamOperator<double[]> implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>, InputSelectable { private final int miniBatchSize; private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>(); private double[] firstRoundCachedParameter; public TrainOperators(int miniBatchSize) { this.miniBatchSize = miniBatchSize; } public void processElement1(double[] parameter, Context context, Collector<O> output) { calculateModelUpdate(parameter, output); miniBatchSize.clear(); } public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<O> output) { dataset.add(trainSample); } public InputSelection nextSelection() { if (miniBatch.size() < miniBatchSize) { return InputSelection.SECOND; } else { return InputSelection.FIRST; } } private void calculateModelUpdate(double[] parameters, Collector<O> output) { double[] modelUpdate = new double[N_DIM]; for (Tuple2<double[], Double> record : miniBatchSize) { double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1); ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff)); } output.collect(modelUpdate); } } } |
...