Status

Current state: "Under Discussion"

...

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).


Motivation and Use-cases

The existing Flink ML library allows users to compose an Estimator/Transformer from a pipeline (i.e. linear sequence) of Estimator/Transformer, and each Estimator/Transformer has one input and one output.

...

9) The Estimator/Transformer interface provides no API to validate the schema consistency of a Pipeline. Users have to instantiate Tables (with I/O logic) and run fit/transform to find out whether the stages in the Pipeline are compatible with each other.
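
As an illustration of item 9, the following is a minimal sketch of a schema-only compatibility check for a linear chain of stages. It is not the Flink ML API: the `SchemaStage` interface and the `requiresAndAdds` and `validateChain` helpers are hypothetical names, and a schema is simplified to a list of column names. The point is only that incompatible stages can be detected without instantiating Tables or running fit/transform.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: validates a linear chain of stages using schemas only, without data. */
final class SchemaValidationSketch {

    /** Hypothetical stand-in for a stage; a schema is simplified to a list of column names. */
    interface SchemaStage {
        /** Returns the output schema, or throws IllegalArgumentException if the input is incompatible. */
        List<String> transformSchema(List<String> inputSchema);
    }

    /** Creates a toy stage that requires one input column and appends one output column. */
    static SchemaStage requiresAndAdds(String required, String added) {
        return input -> {
            if (!input.contains(required)) {
                throw new IllegalArgumentException("Missing column: " + required);
            }
            List<String> output = new ArrayList<>(input);
            output.add(added);
            return output;
        };
    }

    /** Propagates the schema through every stage and fails on the first incompatible stage. */
    static List<String> validateChain(List<SchemaStage> stages, List<String> inputSchema) {
        List<String> schema = inputSchema;
        for (SchemaStage stage : stages) {
            schema = stage.transformSchema(schema);
        }
        return schema;
    }

    public static void main(String[] args) {
        List<SchemaStage> stages = List.of(
                requiresAndAdds("features", "scaled"),
                requiresAndAdds("scaled", "prediction"));
        // No table is created and no job is executed; only column names are propagated.
        System.out.println(validateChain(stages, List.of("features")));
    }
}
```

Passing an input schema that lacks the `features` column makes `validateChain` throw before any data is read, which is the kind of early feedback the existing interface does not offer.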

Background

Note: Readers who are familiar with the existing Estimator/Transformer/Pipeline APIs can skip this section.

...

It is important to make the following observation: even without the Pipeline class, users can still accomplish the use-cases targeted by Pipeline by explicitly writing the training logic and the inference logic separately using the Estimator/Transformer APIs. However, users would then have to construct the chain of Estimator/Transformer twice, once for training and once for inference.
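
To make this observation concrete, here is a toy sketch of the fit-then-transform chaining that lets a chain be specified once and reused for inference. It is not the Flink ML implementation: the nested `Transformer` and `Estimator` interfaces, the `scale` and `meanCenter` stages, and the use of `List<Double>` in place of Tables are all assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

/** Hypothetical sketch: a chain of stages is declared once and reused for training and inference. */
final class PipelineSketch {

    /** Hypothetical stand-in for a Transformer: maps data to data. */
    interface Transformer extends UnaryOperator<List<Double>> {}

    /** Hypothetical stand-in for an Estimator: learns a Transformer from training data. */
    interface Estimator {
        Transformer fit(List<Double> trainingData);
    }

    /** A stateless toy Transformer that multiplies every value by a factor. */
    static Transformer scale(double factor) {
        return rows -> rows.stream().map(x -> x * factor).collect(Collectors.toList());
    }

    /** A toy Estimator that learns the mean of its training data and subtracts it. */
    static Estimator meanCenter() {
        return trainingData -> {
            double mean = trainingData.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
            return rows -> rows.stream().map(x -> x - mean).collect(Collectors.toList());
        };
    }

    /** Fits the stages in order; the output of each fitted stage feeds the next stage. */
    static List<Transformer> fit(List<Object> stages, List<Double> trainingData) {
        List<Transformer> fitted = new ArrayList<>();
        List<Double> current = trainingData;
        for (Object stage : stages) {
            Transformer transformer =
                    stage instanceof Estimator
                            ? ((Estimator) stage).fit(current)
                            : (Transformer) stage;
            current = transformer.apply(current);
            fitted.add(transformer);
        }
        return fitted;
    }

    /** Applies the fitted chain in the same order for inference. */
    static List<Double> transform(List<Transformer> model, List<Double> data) {
        List<Double> current = data;
        for (Transformer transformer : model) {
            current = transformer.apply(current);
        }
        return current;
    }

    public static void main(String[] args) {
        // The chain is written once; fit() derives the inference chain from it.
        List<Object> stages = List.of(scale(2.0), meanCenter());
        List<Transformer> model = fit(stages, List.of(1.0, 2.0, 3.0));
        System.out.println(transform(model, List.of(5.0))); // prints [6.0]
    }
}
```

Without a helper like `fit`, the caller would have to wire `scale` and `meanCenter` together once for training and then wire `scale` and the fitted mean-centering transformer together again for inference.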

Design Principles

There are multiple ways to address the use-cases targeted by this design doc. Below we explain the design principles followed by the proposed design, so that the reasoning behind the design choices is easier to follow.

...

  • Allow users to compose an Estimator from a DAG of Estimator/Transformer, without requiring users to specify this DAG twice

Public Interfaces

This FLIP proposes quite a few changes and additions to the existing Flink ML APIs. We first describe the proposed API additions and changes, followed by the API code of interfaces and classes after making the proposed changes.

API additions and changes

Here we list the additions and the changes to the Flink ML API.

...

This change is reasonable because we will now compose Graph (not just Pipeline) using this class.

Interfaces and classes after the proposed API changes

The following code block shows the interface of Stage, Transformer, Estimator, Pipeline and PipelineModel after making the changes listed above.

...

Code Block
languagejava
/**
 * A Graph, which acts as an estimator. A Graph consists of a DAG of stages, each of which is either
 * an Estimator or a Transformer. When `Graph::fit` is called, the stages are executed in a
 * topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method will be
 * called on the input tables (from the input edges) to fit a model. Then the model, which is a
 * transformer, will be used to transform the input tables to produce output tables to the output
 * edges. If a stage is a transformer, its `Transformer::transform` method will be called on the
 * input tables to produce output tables to the output edges. The fitted model from a Graph is a
 * GraphModel consisting of fitted models and transformers, with those stages having a 1-1 mapping to the
 * stages in the original Graph.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, GraphModel> {
    public Graph(...) {...}

    @Override
    public GraphModel fit(Table... inputs) {...}

    @Override
    public TableSchema[] transformSchemas(TableSchema... schemas) {
        return schemas;
    }

    /** Skipped a few methods, including the implementations of some Estimator APIs. */
}

/**
 * A GraphModel, which acts as a Transformer. A GraphModel consists of a DAG of transformers. When
 * `GraphModel::transform` is called, the stages are executed in a topologically-sorted order. When
 * a stage is executed, its `Transformer::transform` method will be called on the input tables (from
 * the input edges) to produce output tables to the output edges.
 */
public final class GraphModel implements Transformer<GraphModel> {
    /** Skipped a few methods, including the implementations of the Transformer APIs. */
}

/** A GraphBuilder provides APIs to build Graph and GraphModel from a DAG of Estimator and Transformer instances. */
@PublicEvolving
public final class GraphBuilder {
    /**
     * Specifies the upper bound (could be loose) of the number of output tables that can be
     * returned by the Transformer::getStateStreams and Transformer::transform methods, for any
     * stage involved in this Graph.
     *
     * <p>The default upper bound is 20.
     */
    public GraphBuilder setMaxOutputLength(int maxOutputLength) {...}

    /**
     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing of
     * tables between stages, as well as the input/output tables of the Graph/GraphModel generated
     * by this builder.
     */
    public TableId createTableId() {...}

    /**
     * If the stage is an Estimator, both its fit method and the transform method of its fitted Transformer would be
     * invoked with the given inputs when the graph runs.
     *
     * If this stage is a Transformer, its transform method would be invoked with the given inputs when the graph runs.
     *
     * Returns a list of TableIds, which represents outputs of the Transformer::transform invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId... inputs) {...}

    /**
     * If this stage is an Estimator, its fit method would be invoked with estimatorInputs, and the transform method
     * of its fitted Transformer would be invoked with transformerInputs, when the graph runs.
     *
     * This method throws Exception if the stage is a Transformer.
     *
     * This method is useful when the stage is an Estimator AND the Estimator::fit needs to take a different list of
     * Tables from the Transformer::transform of the fitted Transformer.
     *
     * Returns a list of TableIds, which represents outputs of the Transformer::transform invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId[] estimatorInputs, TableId[] transformerInputs) {...}

    /**
     * The GraphModel::setStateStreams should invoke the setStateStreams of the corresponding stage
     * with the corresponding inputs.
     */
    void setStateStreams(Stage<?> stage, TableId... inputs) {...}

    /**
     * The GraphModel::getStateStreams should invoke the getStateStreams of the corresponding stage.
     *
     * <p>Returns a list of TableIds, which represents outputs of the getStateStreams invocation.
     */
    TableId[] getStateStreams(Stage<?> stage) {...}

    /**
     * Returns a Graph instance with the following API specification:
     *
     * 1) Graph::fit should take inputs and return a GraphModel with the following specification.
     *
     * 2) GraphModel::transform should take inputs and return outputs.
     *
     * 3) GraphModel::setStateStreams should take inputStates.
     *
     * 4) GraphModel::getStateStreams should return outputStates.
     *
     * The fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages in the order specified by the DAG of stages.
     */
    Graph build(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns a Graph instance with the following API specification:
     *
     * 1) Graph::fit should take estimatorInputs and return a GraphModel with the following specification.
     *
     * 2) GraphModel::transform should take transformerInputs and return outputs.
     *
     * 3) GraphModel::setStateStreams should take inputStates.
     *
     * 4) GraphModel::getStateStreams should return outputStates.
     *
     * The fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages in the order specified by the DAG of stages.
     *
     * This method is useful when the Graph::fit needs to take a different list of Tables from the GraphModel::transform of the fitted GraphModel.
     */
    Graph build(TableId[] estimatorInputs, TableId[] transformerInputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns a GraphModel instance with the following API specification:
     *
     * 1) GraphModel::transform should take inputs and return outputs.
     *
     * 2) GraphModel::setStateStreams should take inputStates.
     *
     * 3) GraphModel::getStateStreams should return outputStates.
     *
     * The transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages in the order specified by the DAG of stages.
     *
     * This method throws Exception if any stage of this graph is an Estimator.
     */
    GraphModel buildModel(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    // The TableId is necessary to pass the inputs/outputs of various API calls across the
    // Graph/GraphModel stages.
    static class TableId {}

}
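
The Javadoc above states that the stages are executed in a topologically-sorted order, but this section does not say how that order is computed. As one possible illustration, the sketch below applies Kahn's algorithm to a DAG keyed by stage name. The class name, the `executionOrder` method and the string stage names are hypothetical and are not part of the proposed API.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: orders the stages of a DAG so each stage runs after its upstream stages. */
final class TopologicalOrderSketch {

    /**
     * Computes an execution order using Kahn's algorithm.
     *
     * @param downstream maps a stage name to the names of the stages that consume its outputs
     */
    static List<String> executionOrder(Map<String, List<String>> downstream) {
        // Counts, for every stage, how many upstream edges have not been satisfied yet.
        Map<String, Integer> pendingInputs = new LinkedHashMap<>();
        for (String stage : downstream.keySet()) {
            pendingInputs.put(stage, 0);
        }
        for (List<String> consumers : downstream.values()) {
            for (String consumer : consumers) {
                pendingInputs.merge(consumer, 1, Integer::sum);
            }
        }

        // Stages without upstream edges can run immediately.
        Deque<String> ready = new ArrayDeque<>();
        pendingInputs.forEach((stage, count) -> {
            if (count == 0) {
                ready.add(stage);
            }
        });

        List<String> order = new ArrayList<>();
        while (!ready.isEmpty()) {
            String stage = ready.remove();
            order.add(stage);
            for (String consumer : downstream.getOrDefault(stage, List.of())) {
                if (pendingInputs.merge(consumer, -1, Integer::sum) == 0) {
                    ready.add(consumer);
                }
            }
        }
        if (order.size() != pendingInputs.size()) {
            throw new IllegalStateException("The stages contain a cycle, so they do not form a DAG");
        }
        return order;
    }

    public static void main(String[] args) {
        // Two independent stages feed a third stage.
        Map<String, List<String>> dag = new LinkedHashMap<>();
        dag.put("stage1", List.of("stage3"));
        dag.put("stage2", List.of("stage3"));
        dag.put("stage3", List.of());
        System.out.println(executionOrder(dag)); // prints [stage1, stage2, stage3]
    }
}
```

A cyclic input is rejected with an exception, since a topological order only exists for a DAG.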


Example Usage

In this section we provide example code snippets to demonstrate how the APIs proposed in this FLIP can be used to address the use-cases in the motivation section.

Composing an Estimator from a DAG of Estimator/Transformer

Suppose we have the following Transformer and Estimator classes:

...

Code Block
languagejava
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new TransformerA();
Stage<?> stage2 = new TransformerA();
Stage<?> stage3 = new EstimatorB();
// Creates inputs
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, input1)[0];
TableId output2 = builder.getOutputs(stage2, input2)[0];
TableId output3 = builder.getOutputs(stage3, output1, output2)[0];

// Specifies the ordered lists of inputs, outputs, input states and output states that will
// be used as the inputs/outputs of the corresponding Graph and GraphModel APIs.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};

// Generates the Graph instance.
Graph graph = builder.build(inputs, outputs, new TableId[]{}, new TableId[]{});
// The fit method takes 2 tables which are mapped to input1 and input2.
GraphModel model = graph.fit(...);
// The transform method takes 2 tables which are mapped to input1 and input2.
Table[] results = model.transform(...);
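
The comments in the snippet above say that the tables passed to fit and transform are mapped to input1 and input2. A minimal sketch of such a positional mapping follows. It assumes that the i-th runtime table is bound to the i-th TableId passed to the builder. The `bind` helper is hypothetical, and tables are represented by plain strings so that the example runs without Flink.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch: binds runtime tables to the placeholder ids declared at build time. */
final class TableIdBindingSketch {

    /** Opaque placeholder; two ids are equal only if they are the same object. */
    static final class TableId {}

    /** Binds the i-th table to the i-th declared id. Tables are plain strings in this sketch. */
    static Map<TableId, String> bind(TableId[] ids, String... tables) {
        if (ids.length != tables.length) {
            throw new IllegalArgumentException(
                    "Expected " + ids.length + " tables but got " + tables.length);
        }
        Map<TableId, String> bound = new HashMap<>();
        for (int i = 0; i < ids.length; i++) {
            bound.put(ids[i], tables[i]);
        }
        return bound;
    }

    public static void main(String[] args) {
        TableId input1 = new TableId();
        TableId input2 = new TableId();
        Map<TableId, String> bound = bind(new TableId[] {input1, input2}, "tableA", "tableB");
        // Each stage can now look up its inputs by the ids recorded when the graph was built.
        System.out.println(bound.get(input1) + ", " + bound.get(input2)); // prints tableA, tableB
    }
}
```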

Online learning by running Transformer and Estimator concurrently on different machines

Here is an online learning scenario:

...

Code Block
languagejava
void runInferenceOnWebServer(...) {
  // Creates the state stream from Kafka topicA which is written by the above code snippet. 
  Table state_stream = ...;
  // Creates the input stream that needs inference.
  Table input_stream = ...;

  Transformer transformer = new Transformer(...);
  transformer.load(remote_path);
  transformer.setStateStreams(new Table[]{state_stream});
  Table output_stream = transformer.transform(input_stream);

  // Do something with the output_stream.

  // Executes the operators generated by the Transformer::transform(...), which reads from state_stream to update its parameters. 
  // It also does inference on input_stream and produces results to the output_stream.
  env.execute();
}

Composing an Estimator from a chain of Estimator/Transformer whose input schemas differ from those of its fitted Transformer

Suppose we have the following Estimator and Transformer classes, where an Estimator's input schemas can differ from the input schemas of its fitted Transformer:

...

Code Block
languagejava
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new EstimatorA();
Stage<?> stage2 = new TransformerB();
// Creates inputs
TableId estimatorInput1 = builder.createTableId();
TableId estimatorInput2 = builder.createTableId();
TableId transformerInput1 = builder.createTableId();

// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, new TableId[] {estimatorInput1, estimatorInput2}, new TableId[] {transformerInput1})[0];
TableId output2 = builder.getOutputs(stage2, output1)[0];

// Specifies the ordered lists of estimator inputs, transformer inputs, outputs, input states and output states
// that will be used as the inputs/outputs of the corresponding Graph and GraphModel APIs.
TableId[] estimatorInputs = new TableId[] {estimatorInput1, estimatorInput2};
TableId[] transformerInputs = new TableId[] {transformerInput1};
TableId[] outputs = new TableId[] {output2};
TableId[] inputStates = new TableId[] {};
TableId[] outputStates = new TableId[] {};

// Generates the Graph instance.
Graph graph = builder.build(estimatorInputs, transformerInputs, outputs, inputStates, outputStates);
// The fit method takes 2 tables which are mapped to estimatorInput1 and estimatorInput2.
GraphModel model = graph.fit(...);
// The transform method takes 1 table which is mapped to transformerInput1.
Table[] results = model.transform(...);

Compatibility, Deprecation, and Migration Plan

The changes proposed in this FLIP are backward incompatible with the existing APIs. We propose to change the APIs directly without a deprecation period, and we will manually migrate the existing open source projects that use the existing Flink ML API to the proposed APIs.

...

To the best of our knowledge, the only open source project that uses the Flink ML API is https://github.com/alibaba/Alink. We will work together with Alink developers to migrate the existing code to use the proposed API. Furthermore, we will migrate Alink's Estimator/Transformer implementation to the Flink ML library codebase as much as possible.

Test Plan

We will provide unit tests to validate the proposed changes.

Rejected Alternatives

There are no rejected alternatives to list here yet.

...