THIS IS A TEST INSTANCE. ALL YOUR CHANGES WILL BE LOST!!!!
...
Code Block | ||||
---|---|---|---|---|
| ||||
public class SynchronousBoundedLinearRegression {
private static final N_DIM = 50;
private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>{};
public static void main(String[] args) {
DataStream<double[]> initParameters = loadParameters().setParallelism(1);
DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);
int batch = 5;
int epochEachBatch = 10;
ResultStreams resultStreams = new BoundedIteration()
.withBody(new IterationBody(
@IterationInput("model") DataStream<double[]> model,
@IterationInput("dataset") DataStream<Tuple2<double[], Double>> dataset
) {
SingleOutputStreamOperator<double[]> parameters = model.process(new ParametersCacheFunction());
DataStream<double[]> modelUpdate = parameters.setParallelism(1)
.broadcast()
.connect(dataset)
.coProcess(new TrainFunction())
.setParallelism(10)
return new BoundedIterationDeclarationBuilder()
.withFeedback("model", modelUpdate)
.withOutput("final_model", parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG))
.until(new TerminationCondition(null, context -> context.getRound() >= batch * epochEachBatch))
.build();
})
.build();
DataStream<double[]> finalModel = resultStreams.get("final_model");
finalModel.print();
}
public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
implements BoundedIterationProgressListener<double[]> {
private final double[] parameters = new double[N_DIM];
public void processElement(double[] update, Context ctx, Collector<O> output) {
// Suppose we have a util to add the second array to the first.
ArrayUtils.addWith(parameters, update);
}
public void onRoundEnd(int[] round, Context context, Collector<T> collector) {
collector.collect(parameters);
}
public void onIterationEnd(int[] round, Context context) {
context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
}
}
public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]> implements BoundedIterationProgressListener<double[]> {
private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
private double[] firstRoundCachedParameter;
private Supplier<int[]> recordRoundQuerier;
public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
this.recordRoundQuerier = querier;
}
public void processElement1(double[] parameter, Context context, Collector<O> output) {
int[] round = recordRoundQuerier.get();
if (round[0] == 0) {
firstRoundCachedParameter = parameter;
return;
}
calculateModelUpdate(parameter, output);
}
public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<O> output) {
dataset.add(trainSample)
}
public void onRoundEnd(int[] round, Context context, Collector<T> output) {
if (round[0] == 0) {
calculateModelUpdate(firstRoundCachedParameter, output);
firstRoundCachedParameter = null;
}
}
private void calculateModelUpdate(double[] parameters, Collector<O> output) {
List<Tuple2<double[], Double>> samples = sample(dataset);
double[] modelUpdate = new double[N_DIM];
for (Tuple2<double[], Double> record : samples) {
double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
}
output.collect(modelUpdate);
}
}
} |
Implementation Plan
Logically, all iteration types should support both the BATCH and STREAMING execution modes. However, based on the algorithms' requirements, we will implement
...