...

IterationBody#forEachRound allows users to specify a sub-graph that executes in per-round mode, namely all of its operators are re-created for each round. 

Code Block
languagejava
package org.apache.flink.iteration;

import org.apache.flink.annotation.Experimental;

@Experimental
public interface IterationBody {

    ....

    /**
     * Constructs a subgraph inside the iteration body in which all the operators have a
     * lifecycle of {@link org.apache.flink.iteration.IterationConfig.OperatorLifeCycle#PER_ROUND}.
     *
     * @param inputs The inputs of the subgraph.
     * @param perRoundSubBody The computational logic to be executed per-round.
     * @return The output of the subgraph.
     */
    public static DataStreamList forEachRound(
            DataStreamList inputs, PerRoundSubBody perRoundSubBody) {
        return null;
    }

    /** The sub-graph inside the iteration body that should be executed as per-round. */
    interface PerRoundSubBody {

        DataStreamList process(DataStreamList input);
    }
}
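
For illustration, a per-round sub-graph could be wrapped roughly as follows (a minimal sketch; modelUpdate and ArrayUtils are borrowed from the linear-regression example later in this document):

Code Block
languagejava
// Minimal sketch: the map operator is constructed inside forEachRound(), so it is
// re-created on every round instead of living across the whole iteration.
DataStream<double[]> scaled =
        IterationBody.forEachRound(
                DataStreamList.of(modelUpdate),
                inputs -> DataStreamList.of(
                        inputs.<double[]>get(0).map(update -> ArrayUtils.multiply(update, 0.5))))
            .<double[]>get(0);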


6) Add the DataStreamList and ReplayableDataStreamList classes.

...

Code Block
languagejava
linenumberstrue
package org.apache.flink.iteration;

public class DataStreamList {
    public static DataStreamList of(DataStream<?>... streams);

    // Returns the number of data streams in this list.
    public int size() {...}

    // Returns the data stream at the given index in this list.
    public <T> DataStream<T> get(int index) {...}
}

public class ReplayableDataStreamList {

    private final List<DataStream<?>> replayedDataStreams;

    private final List<DataStream<?>> nonReplayedStreams;

    private ReplayableDataStreamList(
            List<DataStream<?>> replayedDataStreams, List<DataStream<?>> nonReplayedStreams) {
        this.replayedDataStreams = replayedDataStreams;
        this.nonReplayedStreams = nonReplayedStreams;
    }

    public static ReplayedDataStreamList replay(DataStream<?>... dataStreams) {
        return new ReplayedDataStreamList(Arrays.asList(dataStreams));
    }

    public static NonReplayedDataStreamList notReplay(DataStream<?>... dataStreams) {
        return new NonReplayedDataStreamList(Arrays.asList(dataStreams));
    }

    List<DataStream<?>> getReplayedDataStreams() {
        return Collections.unmodifiableList(replayedDataStreams);
    }

    List<DataStream<?>> getNonReplayedStreams() {
        return Collections.unmodifiableList(nonReplayedStreams);
    }

    private static class ReplayedDataStreamList extends ReplayableDataStreamList {

        public ReplayedDataStreamList(List<DataStream<?>> replayedDataStreams) {
            super(replayedDataStreams, Collections.emptyList());
        }

        public ReplayableDataStreamList andNotReplay(DataStream<?>... nonReplayedStreams) {
            return new ReplayableDataStreamList(
                    getReplayedDataStreams(), Arrays.asList(nonReplayedStreams));
        }
    }

    private static class NonReplayedDataStreamList extends ReplayableDataStreamList {

        public NonReplayedDataStreamList(List<DataStream<?>> nonReplayedDataStreams) {
            super(Collections.emptyList(), nonReplayedDataStreams);
        }

        public ReplayableDataStreamList andReplay(DataStream<?>... replayedStreams) {
            return new ReplayableDataStreamList(
                    Arrays.asList(replayedStreams), getNonReplayedStreams());
        }
    }
}
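
For illustration, a list that mixes replayed and non-replayed inputs could be built roughly as follows (a minimal sketch; trainData and initModel are hypothetical bounded data streams):

Code Block
languagejava
// Hypothetical usage sketch: trainData is replayed at every round,
// while initModel is consumed only in the first round.
ReplayableDataStreamList inputs =
        ReplayableDataStreamList.replay(trainData).andNotReplay(initModel);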

7) Introduce the IterationConfig

...

Code Block
languagejava
package org.apache.flink.iteration;

/** The config for an iteration. */
public class IterationConfig {

    private final OperatorLifeCycle operatorRoundMode;

    private IterationConfig(OperatorLifeCycle operatorRoundMode) {
        this.operatorRoundMode = operatorRoundMode;
    }

    public static IterationConfigBuilder newBuilder() {

        return new IterationConfigBuilder();
    }

    public static class IterationConfigBuilder {

        private OperatorLifeCycle operatorRoundMode = OperatorLifeCycle.ALL_ROUND;

        private IterationConfigBuilder() {}

        public IterationConfigBuilder setOperatorRoundMode(OperatorLifeCycle operatorRoundMode) {
            this.operatorRoundMode = operatorRoundMode;
            return this;
        }

        public IterationConfig build() {
            return new IterationConfig(operatorRoundMode);
        }
    }

    public enum OperatorLifeCycle {
        ALL_ROUND,

        PER_ROUND
    }
}
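
For example, an iteration whose operators should all be re-created on each round could be configured roughly as follows (a minimal sketch):

Code Block
languagejava
// Minimal sketch: request the PER_ROUND lifecycle for the operators in the iteration body.
IterationConfig config =
        IterationConfig.newBuilder()
                .setOperatorRoundMode(IterationConfig.OperatorLifeCycle.PER_ROUND)
                .build();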

...

Code Block
languagejava
linenumberstrue
public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
    private static final int N_EPOCH = 5;
    private static final int N_BATCH_PER_EPOCH = 10;
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>("final_model") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams = 
            IterationUtils.iterateBoundedStreamsUntilTermination(
                DataStreamList.of(initParameters), 
                ReplayableDataStreamList.notReplay(dataset), 
                IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(),
                (variableStreams, dataStreams) -> {
                    DataStream<double[]> parameterUpdates = variableStreams.get(0);
                    DataStream<Tuple2<double[], Double>> trainDataset = dataStreams.get(0);

                    SingleOutputStreamOperator<double[]> parameters =
                        parameterUpdates.process(new ParametersCacheFunction());
                    DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                        .broadcast()
                        .connect(trainDataset)
                        .coProcess(new TrainFunction())
                        .setParallelism(10)
                        .process(new ReduceFunction())
                        .setParallelism(1);

                    return new IterationBodyResult(
                        DataStreamList.of(modelUpdate),
                        DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
                });
        
        DataStream<double[]> finalModel = resultStreams.get(0);
        finalModel.print();
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
        implements IterationListener<double[]> {  
        
        private final double[] parameters = new double[N_DIM];

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
        }

        public void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<double[]> collector) {
            if (epochWatermark < N_EPOCH * N_BATCH_PER_EPOCH) {
                collector.collect(parameters);
            }
        }

        public void onIterationEnd(int[] round, Context context) {
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
        }
    }

    public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements IterationListener<double[]> {

        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        private Supplier<int[]> recordRoundQuerier;

        public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
            this.recordRoundQuerier = querier;
        } 

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            int[] round = recordRoundQuerier.get();
            if (round[0] == 0) {
                firstRoundCachedParameter = parameter;
                return;
            }

            calculateModelUpdate(parameter, output);
        }

        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            dataset.add(trainSample);
        }

        public void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<double[]> collector) {
            if (epochWatermark == 0) {
                calculateModelUpdate(firstRoundCachedParameter, collector);
                firstRoundCachedParameter = null;
            }
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }
	
    public static class ReduceFunction extends ProcessFunction<double[], double[]>
        implements IterationListener<double[]> {

        private double[] mergedValue = ArrayUtils.newArray(N_DIM);

        public void processElement(double[] parameter, Context context, Collector<double[]> output) {
            mergedValue = ArrayUtils.add(mergedValue, parameter);
        }

        public void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<double[]> collector) {
            collector.collect(mergedValue);
            mergedValue = ArrayUtils.newArray(N_DIM);
        }
    }
}

...

Code Block
languagejava
        DataStreamList resultStreams = 
            IterationUtils.iterateBoundedStreamsUntilTermination(
                DataStreamList.of(initParameters), 
                ReplayableDataStreamList.notReplay(dataset), 
                IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(),
                (variableStreams, dataStreams) -> {
                    DataStream<double[]> parameterUpdates = variableStreams.get(0);
                    DataStream<Tuple2<double[], Double>> trainDataset = dataStreams.get(0);

                    SingleOutputStreamOperator<double[]> parameters =
                        parameterUpdates.process(new ParametersCacheFunction());
                    DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                        .broadcast()
                        .connect(trainDataset)
                        .coProcess(new TrainFunction())
                        .setParallelism(10);

                    DataStream<double[]> reduced = IterationBody.forEachRound(
                        DataStreamList.of(modelUpdate), streams -> {
                            return DataStreamList.of(
                                streams.<double[]>get(0).windowAll().reduce((x, y) -> ArrayUtils.add(x, y)));
                        }).<double[]>get(0);

                    return new IterationBodyResult(
                        DataStreamList.of(reduced),
                        DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
                });

...