THIS IS A TEST INSTANCE. ALL YOUR CHANGES WILL BE LOST!!!!
...
Code Block | ||||
---|---|---|---|---|
| ||||
public class SynchronousBoundedLinearRegression { private static final N_DIM = 50; private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>{}; public static void main(String[] args) { DataStream<double[]> initParameters = loadParameters().setParallelism(1); DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1); int batch = 5; int epochEachBatch = 10; ResultStreams resultStreams = new BoundedIteration() .withBody(new IterationBody( @IterationInput("model") DataStream<double[]> model, @IterationInput("dataset") DataStream<Tuple2<double[], Double>> dataset ) { SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction()); DataStream<double[]> modelUpdate = parameters.setParallelism(1) .broadcast() .connect(dataset) .coProcess(new TrainFunction()) .setParallelism(10) return new BoundedIterationDeclarationBuilder() .withFeedback("model", modelUpdate) .withOutput("final_model", parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG)) .until(new TerminationCondition(null, context -> context.getRound() >= batch * epochEachBatch)) .build(); }) .build(); DataStream<double[]> finalModel = resultStreams.get("final_model"); finalModel.print(); } public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> implements BoundedIterationProgressListener<double[]> { private final double[] parameters = new double[N_DIM]; public void processElement(double[] update, Context ctx, Collector<O> output) { // Suppose we have a util to add the second array to the first. 
ArrayUtils.addWith(parameters, update); } public void onRoundEnd(int[] round, Context context, Collector<T> collector) { collector.collect(parameters); } public void onIterationEnd(int[] round, Context context) { context.output(FINAL_MODEL_OUTPUT_TAG, parameters); } } public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements BoundedIterationProgressListener<double[]> { private final List<Tuple2<double[], Double>> dataset = new ArrayList<>(); private double[] firstRoundCachedParameter; private Supplier<int[]> recordRoundQuerier; public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) { this.recordRoundQuerier = querier; } public void processElement1(double[] parameter, Context context, Collector<O> output) { int[] round = recordRoundQuerier.get(); if (round[0] == 0) { firstRoundCachedParameter = parameter; return; } calculateModelUpdate(parameter, output); } public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<O> output) { dataset.add(trainSample) } public void onRoundEnd(int[] round, Context context, Collector<T> output) { if (round[0] == 0) { calculateModelUpdate(firstRoundCachedParameter, output); firstRoundCachedParameter = null; } } private void calculateModelUpdate(double[] parameters, Collector<O> output) { List<Tuple2<double[], Double>> samples = sample(dataset); double[] modelUpdate = new double[N_DIM]; for (Tuple2<double[], Double> record : samples) { double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1); ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff)); } output.collect(modelUpdate); } } } |
For the Parameters vertex, the
Implementation Plan
Logically, all the iteration types would support both BATCH and STREAM execution modes. However, according to the algorithms' requirements, we would implement
...