
Current state: Under discussion.

Discussion thread: To be added


Released: Not released yet.

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

[This FLIP proposal is a joint work between Dong Lin and Zhipeng Zhang]

Motivation and Use-cases

The existing Flink ML library allows users to compose an Estimator/Transformer/AlgoOperator from a pipeline (i.e. a linear sequence) of Estimator/Transformer/AlgoOperator. Users only need to construct this Pipeline once to generate the corresponding PipelineModel, without having to explicitly construct the fitted PipelineModel as a linear sequence of stages. However, in order to train a DAG of Estimator/Transformer/AlgoOperator and use the trained model for inference, users currently need to construct the DAG twice: once for the training logic and once for the inference logic. This experience is inferior to that of training and using a chain of Estimator/Transformer/AlgoOperator. In addition to requiring more work from users, this approach is error-prone because the DAG for the training logic may be inconsistent with the DAG for the inference logic.

In order to address the issues described above, we propose to add several helper classes that allow users to compose Estimator/Transformer/AlgoOperator from a DAG of Estimator/Transformer/AlgoOperator.

Public Interfaces

This FLIP proposes to add the Graph, GraphModel, GraphBuilder, GraphNode and TableId classes. The following code block shows the public APIs of these classes.

1) Add the TableId class to represent the input/output of a stage.

This class is necessary in order to construct the DAG before the concrete Tables are available. It also overrides equals/hashCode so that it can be used as the key of a hash map.

public class TableId {
    private final int tableId;

    @Override
    public boolean equals(Object obj) {...}

    @Override
    public int hashCode() {...}
}
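The equals/hashCode bodies are elided above. A minimal sketch of one way to fill them in follows, assuming (as the single `tableId` field suggests) that two TableIds are equal exactly when their integer ids match. The constructor, `toString`, the wrapper class `TableIdSketch` and its `bind` helper are hypothetical, and plain Strings stand in for Tables.

```java
import java.util.HashMap;
import java.util.Map;

class TableIdSketch {
    /** Hypothetical stand-in for the proposed TableId; the constructor and toString are assumptions. */
    static final class TableId {
        private final int tableId;

        TableId(int tableId) {
            this.tableId = tableId;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof TableId)) {
                return false;
            }
            // Two TableIds refer to the same table iff their integer ids match.
            return tableId == ((TableId) obj).tableId;
        }

        @Override
        public int hashCode() {
            // Consistent with equals: equal ids always yield equal hash codes.
            return Integer.hashCode(tableId);
        }

        @Override
        public String toString() {
            return "TableId(" + tableId + ")";
        }
    }

    /** Binds a placeholder id to a concrete (stand-in) table, as an executor could once Tables exist. */
    static Map<TableId, String> bind() {
        Map<TableId, String> tables = new HashMap<>();
        tables.put(new TableId(0), "input table");
        return tables;
    }

    public static void main(String[] args) {
        Map<TableId, String> tables = bind();
        // A different TableId object with the same id finds the same entry.
        System.out.println(tables.get(new TableId(0))); // prints "input table"
        System.out.println(tables.get(new TableId(1))); // prints "null"
    }
}
```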


2) Add the GraphNode class.

This class contains the stage as well as the input/output of this stage in the form of TableId lists. A DAG can thus be represented as a list of GraphNodes.

public class GraphNode {
    public final Stage<?> stage;
    public final TableId[] estimatorInputs;
    public final TableId[] algoInputs;
    public final TableId[] outputs;
}
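To illustrate how a list of GraphNodes encodes a DAG, the sketch below writes the DAG of the first example in the Example Usage section as GraphNodes and recovers its edges by matching each node's inputs against the other nodes' outputs. It is a simplified, hypothetical stand-in rather than the proposed class: stage names replace Stage<?> objects, ints replace TableIds, the id values are arbitrary, and `estimatorInputs` is assumed to be null for non-Estimator stages.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class GraphNodeSketch {
    /** Simplified GraphNode: a stage name stands in for Stage<?>, and ints stand in for TableIds. */
    static final class GraphNode {
        final String stage;
        final int[] estimatorInputs; // assumed null when the stage is not an Estimator
        final int[] algoInputs;
        final int[] outputs;

        GraphNode(String stage, int[] estimatorInputs, int[] algoInputs, int[] outputs) {
            this.stage = stage;
            this.estimatorInputs = estimatorInputs;
            this.algoInputs = algoInputs;
            this.outputs = outputs;
        }
    }

    /** The DAG of the first example: two TransformerA nodes feeding one EstimatorB node. */
    static List<GraphNode> exampleDag() {
        List<GraphNode> nodes = new ArrayList<>();
        nodes.add(new GraphNode("TransformerA#1", null, new int[] {0}, new int[] {2}));
        nodes.add(new GraphNode("TransformerA#2", null, new int[] {1}, new int[] {3}));
        // EstimatorB is fitted on, and its fitted model applied to, the same two tables.
        nodes.add(new GraphNode("EstimatorB", new int[] {2, 3}, new int[] {2, 3}, new int[] {4}));
        return nodes;
    }

    /** Recovers the DAG's edges by matching each node's inputs against other nodes' outputs. */
    static List<String> edges(List<GraphNode> nodes) {
        Map<Integer, String> producer = new HashMap<>();
        for (GraphNode node : nodes) {
            for (int out : node.outputs) {
                producer.put(out, node.stage);
            }
        }
        List<String> edges = new ArrayList<>();
        for (GraphNode node : nodes) {
            // Union of both input lists, without duplicates, in first-seen order.
            Set<Integer> inputs = new LinkedHashSet<>();
            if (node.estimatorInputs != null) {
                for (int in : node.estimatorInputs) {
                    inputs.add(in);
                }
            }
            for (int in : node.algoInputs) {
                inputs.add(in);
            }
            for (int in : inputs) {
                String from = producer.get(in);
                if (from != null) { // ids with no producer are Graph-level inputs
                    edges.add(from + " -> " + node.stage);
                }
            }
        }
        return edges;
    }

    public static void main(String[] args) {
        edges(exampleDag()).forEach(System.out::println);
    }
}
```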


3) Add the Graph class to wrap a DAG of Estimator/Model/Transformer/AlgoOperator into an Estimator.

/**
 * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be an
 * Estimator, Model, Transformer or AlgoOperator. When `Graph::fit` is called, the stages are
 * executed in a topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method
 * will be called on the input tables (from the input edges) to fit a Model. Then the Model will be
 * used to transform the input tables and produce output tables to the output edges. If a stage is
 * an AlgoOperator, its `AlgoOperator::transform` method will be called on the input tables and
 * produce output tables to the output edges. The GraphModel fitted from a Graph consists of the
 * fitted Models and AlgoOperators, corresponding to the Graph's stages.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, GraphModel> {
    public Graph(List<GraphNode> nodes, TableId[] estimatorInputIds, TableId[] algoInputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}

    @Override
    public GraphModel fit(Table... inputs) {...}

    @Override
    public void save(String path) throws IOException {...}

    public static Graph load(String path) throws IOException {...}
}
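The Javadoc above says that the stages are executed in a topologically-sorted order. One simple way to obtain such an order is sketched below, under the assumption that a node may run as soon as all of its input tables are available: repeatedly pick a pending node whose inputs are ready. The class, its `Node` type and the call-log strings are hypothetical; the sketch only records which calls Graph::fit would make, using the first example's DAG deliberately listed out of order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class GraphFitOrderSketch {
    /** Simplified node: a name, whether it is an Estimator, and its input/output table ids. */
    static final class Node {
        final String name;
        final boolean isEstimator;
        final List<Integer> inputs; // union of estimator and algo inputs
        final List<Integer> outputs;

        Node(String name, boolean isEstimator, List<Integer> inputs, List<Integer> outputs) {
            this.name = name;
            this.isEstimator = isEstimator;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    /**
     * Returns the calls a Graph::fit-style executor would make. A node runs once all of its input
     * tables are available, which yields a topological order of the DAG.
     */
    static List<String> fitCalls(List<Node> nodes, List<Integer> graphInputs) {
        Set<Integer> available = new HashSet<>(graphInputs);
        List<Node> pending = new ArrayList<>(nodes);
        List<String> calls = new ArrayList<>();
        while (!pending.isEmpty()) {
            Node ready = null;
            for (Node node : pending) {
                if (available.containsAll(node.inputs)) {
                    ready = node;
                    break;
                }
            }
            if (ready == null) {
                throw new IllegalStateException("Cycle or unbound input among the remaining nodes");
            }
            pending.remove(ready);
            if (ready.isEstimator) {
                calls.add("fit " + ready.name);
                // The fitted Model also transforms, so downstream nodes receive its outputs.
                calls.add("transform fitted model of " + ready.name);
            } else {
                calls.add("transform " + ready.name);
            }
            available.addAll(ready.outputs);
        }
        return calls;
    }

    /** The first example's DAG, deliberately listed out of order. */
    static List<String> demo() {
        List<Node> nodes = Arrays.asList(
                new Node("EstimatorB", true, Arrays.asList(2, 3), Arrays.asList(4)),
                new Node("TransformerA#2", false, Arrays.asList(1), Arrays.asList(3)),
                new Node("TransformerA#1", false, Arrays.asList(0), Arrays.asList(2)));
        return fitCalls(nodes, Arrays.asList(0, 1));
    }

    public static void main(String[] args) {
        demo().forEach(System.out::println);
    }
}
```

The repeated scan is quadratic in the number of nodes; Kahn's algorithm with per-node dependency counters would produce a valid order in linear time if that ever matters.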


4) Add the GraphModel class to wrap a DAG of Estimator/Model/Transformer/AlgoOperator into a Model.

/**
 * A GraphModel acts as a Model. A GraphModel consists of a DAG of stages, each of which could be an
 * Estimator, Model, Transformer or AlgoOperator. When `GraphModel::transform` is called, the
 * stages are executed in a topologically-sorted order. When a stage is executed, its
 * `AlgoOperator::transform` method will be called on the input tables (from the input edges) and
 * produce output tables to the output edges.
 */
public final class GraphModel implements Model<GraphModel> {

    public GraphModel(List<GraphNode> nodes, TableId[] inputIds, TableId[] outputIds, TableId[] inputModelData, TableId[] outputModelData) {...}

    @Override
    public Table[] transform(Table... inputTables) {...}

    @Override
    public void setModelData(Table... inputs) {...}

    @Override
    public Table[] getModelData() {...}

    @Override
    public void save(String path) throws IOException {...}

    public static GraphModel load(String path) throws IOException {...}
}
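GraphModel::transform can be thought of as passing tables along the DAG through a map from TableId to Table. The sketch below assumes the nodes are already topologically sorted, uses Strings in place of Tables and ints in place of TableIds, and models each stage's AlgoOperator::transform as a `Function`; all names are hypothetical. It also lets a node reserve more output ids than its stage fills, in the spirit of the maxOutputLength bound described for GraphBuilder.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

class GraphModelTransformSketch {
    /** Simplified node: a stand-in for AlgoOperator::transform plus its input/output table ids. */
    static final class Node {
        final Function<List<String>, List<String>> transform;
        final int[] inputs;
        final int[] outputs; // may reserve more ids than the stage actually fills

        Node(Function<List<String>, List<String>> transform, int[] inputs, int[] outputs) {
            this.transform = transform;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    /** Runs already topologically sorted nodes, passing tables along via a TableId -> table map. */
    static List<String> transform(
            List<Node> sortedNodes, int[] inputIds, int[] outputIds, List<String> inputTables) {
        if (inputTables.size() != inputIds.length) {
            throw new IllegalArgumentException("Expected " + inputIds.length + " input tables");
        }
        Map<Integer, String> tables = new HashMap<>();
        for (int i = 0; i < inputIds.length; i++) {
            tables.put(inputIds[i], inputTables.get(i));
        }
        for (Node node : sortedNodes) {
            List<String> args = new ArrayList<>();
            for (int id : node.inputs) {
                args.add(tables.get(id));
            }
            List<String> results = node.transform.apply(args);
            if (results.size() > node.outputs.length) {
                throw new IllegalStateException("Stage returned more tables than reserved output ids");
            }
            for (int i = 0; i < results.size(); i++) {
                tables.put(node.outputs[i], results.get(i));
            }
        }
        List<String> outputs = new ArrayList<>();
        for (int id : outputIds) {
            outputs.add(tables.get(id));
        }
        return outputs;
    }

    /** The fitted form of the first example: two TransformerA nodes feeding one ModelB node. */
    static List<String> demo() {
        Function<List<String>, List<String>> transformerA = in -> Arrays.asList("A(" + in.get(0) + ")");
        Function<List<String>, List<String>> modelB =
                in -> Arrays.asList("B(" + in.get(0) + "," + in.get(1) + ")");
        List<Node> nodes = Arrays.asList(
                new Node(transformerA, new int[] {0}, new int[] {2}),
                new Node(transformerA, new int[] {1}, new int[] {3}),
                new Node(modelB, new int[] {2, 3}, new int[] {4}));
        return transform(nodes, new int[] {0, 1}, new int[] {4}, Arrays.asList("t1", "t2"));
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints [B(A(t1),A(t2))]
    }
}
```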


5) Add the GraphBuilder class to build GraphModel or Graph from a DAG of stages.

/**
 * A GraphBuilder provides APIs to build Estimator/Model/AlgoOperator from a DAG of stages, each of
 * which could be an Estimator, Model, Transformer or AlgoOperator.
 */
@PublicEvolving
public final class GraphBuilder {
    private int maxOutputLength = 20;

    public GraphBuilder() {}

    /**
     * Specifies the upper bound (could be loose) of the number of output tables that can be
     * returned by the Model::getModelData and AlgoOperator::transform methods, for any stage
     * involved in this Graph.
     *
     * <p>The default upper bound is 20.
     */
    public GraphBuilder setMaxOutputLength(int maxOutputLength) {...}

    /**
     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing of
     * tables between stages, as well as the input/output tables of the Graph/GraphModel generated
     * by this builder.
     */
    public TableId createTableId() {...}

    /**
     * If the stage is an Estimator, both its fit method and the transform method of its fitted
     * Model would be invoked with the given inputs when the graph runs.
     *
     * <p>If this stage is a Model, Transformer or AlgoOperator, its transform method would be
     * invoked with the given inputs when the graph runs.
     *
     * <p>Returns a list of TableIds, which represents outputs of AlgoOperator::transform of the given stage.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId... inputs) {...}

    /**
     * If this stage is an Estimator, its fit method would be invoked with estimatorInputs, and the
     * transform method of its fitted Model would be invoked with algoInputs.
     *
     * <p>This method throws an exception if the stage is not an Estimator.
     *
     * <p>This method is useful when the stage is an Estimator AND the Estimator::fit needs to take
     * a different list of Tables from the Model::transform of the fitted Model.
     *
     * <p>Returns a list of TableIds, which represents outputs of Model::transform of the fitted Model.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId[] estimatorInputs, TableId[] algoInputs) {...}

    /**
     * The setModelData() of the fitted GraphModel should invoke the setModelData() of the given
     * stage with the given inputs.
     */
    public void setModelData(Stage<?> stage, TableId... inputs) {...}

    /**
     * The getModelData() of the fitted GraphModel should invoke the getModelData() of the given
     * stage.
     *
     * <p>Returns a list of TableIds, which represents the outputs of getModelData() of the given
     * stage.
     */
    public TableId[] getModelData(Stage<?> stage) {...}

    /**
     * Returns an Estimator instance with the following behavior:
     *
     * <p>1) Estimator::fit should take the given inputs and return a Model with the following
     * behavior.
     *
     * <p>2) Model::transform should take the given inputs and return the given outputs.
     *
     * <p>The fit method of the returned Estimator and the transform method of the fitted Model
     * should invoke the corresponding methods of the internal stages as specified by the
     * GraphBuilder.
     */
    public Estimator<?, ?> buildEstimator(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Returns an Estimator instance with the following behavior:
     *
     * <p>1) Estimator::fit should take the given inputs and return a Model with the following
     * behavior.
     *
     * <p>2) Model::transform should take the given inputs and return the given outputs.
     *
     * <p>3) Model::setModelData should take the given inputModelData.
     *
     * <p>4) Model::getModelData should return the given outputModelData.
     *
     * <p>The fit method of the returned Estimator and the transform/setModelData/getModelData
     * methods of the fitted Model should invoke the corresponding methods of the internal stages as
     * specified by the GraphBuilder.
     */
    public Estimator<?, ?> buildEstimator(TableId[] inputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}

    /**
     * Returns an Estimator instance with the following behavior:
     *
     * <p>1) Estimator::fit should take the given estimatorInputs and return a Model with the
     * following behavior.
     *
     * <p>2) Model::transform should take the given algoInputs and return the given outputs.
     *
     * <p>3) Model::setModelData should take the given inputModelData.
     *
     * <p>4) Model::getModelData should return the given outputModelData.
     *
     * <p>The fit method of the returned Estimator and the transform/setModelData/getModelData
     * methods of the fitted Model should invoke the corresponding methods of the internal stages as
     * specified by the GraphBuilder.
     */
    public Estimator<?, ?> buildEstimator(TableId[] estimatorInputs, TableId[] algoInputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}

    /**
     * Returns an AlgoOperator instance with the following behavior:
     *
     * <p>1) AlgoOperator::transform should take the given inputs and return the given outputs.
     *
     * <p>The transform method of the returned AlgoOperator should invoke the corresponding methods
     * of the internal stages as specified by the GraphBuilder.
     */
    public AlgoOperator<?> buildAlgoOperator(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Returns a Model instance with the following behavior:
     *
     * <p>1) Model::transform should take the given inputs and return the given outputs.
     *
     * <p>The transform method of the returned Model should invoke the corresponding methods of the
     * internal stages as specified by the GraphBuilder.
     */
    public Model<?> buildModel(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Returns a Model instance with the following behavior:
     *
     * <p>1) Model::transform should take the given inputs and return the given outputs.
     *
     * <p>2) Model::setModelData should take the given inputModelData.
     *
     * <p>3) Model::getModelData should return the given outputModelData.
     *
     * <p>The transform/setModelData/getModelData methods of the returned Model should invoke the
     * corresponding methods of the internal stages as specified by the GraphBuilder.
     */
    public Model<?> buildModel(TableId[] inputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}
}
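GraphBuilder::getOutputs has to return TableIds before any stage runs, so it cannot know how many tables a stage will actually produce. Our reading of setMaxOutputLength is that the builder reserves that many ids on each getOutputs call; the sketch below illustrates that assumption, with ints as ids and a stage name in place of Stage<?>. All names are hypothetical, and only the single-input-list variant of getOutputs is shown.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class GraphBuilderSketch {
    /** Hypothetical record of one getOutputs(...) call; a stage name stands in for Stage<?>. */
    static final class NodeSpec {
        final String stage;
        final int[] inputs;
        final int[] outputs;

        NodeSpec(String stage, int[] inputs, int[] outputs) {
            this.stage = stage;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    private int maxOutputLength = 20;
    private int nextId = 0;
    private final List<NodeSpec> nodes = new ArrayList<>();

    GraphBuilderSketch setMaxOutputLength(int maxOutputLength) {
        if (maxOutputLength <= 0) {
            throw new IllegalArgumentException("maxOutputLength must be positive");
        }
        this.maxOutputLength = maxOutputLength;
        return this;
    }

    /** Each call hands out a fresh id, so ids from one builder never collide. */
    int createTableId() {
        return nextId++;
    }

    /**
     * Records the node and reserves maxOutputLength output ids up front, because the number of
     * tables the stage will actually return is unknown until the graph runs.
     */
    int[] getOutputs(String stage, int... inputs) {
        int[] outputs = new int[maxOutputLength];
        for (int i = 0; i < maxOutputLength; i++) {
            outputs[i] = createTableId();
        }
        nodes.add(new NodeSpec(stage, inputs.clone(), outputs.clone()));
        return outputs;
    }

    static int[] demo() {
        GraphBuilderSketch builder = new GraphBuilderSketch().setMaxOutputLength(3);
        int input = builder.createTableId(); // id 0
        return builder.getOutputs("TransformerA", input); // ids 1, 2, 3
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(demo())); // prints [1, 2, 3]
    }
}
```

Under this assumption, unused reserved ids simply never get a table bound to them at run time, which is why a loose upper bound is acceptable.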


Example Usage

In this section we provide example code snippets to demonstrate how the APIs proposed in this FLIP address the use-cases in the motivation section.

Composing an Estimator from a DAG of Estimator/Transformer

Suppose we have the following Transformer and Estimator classes:

  • TransformerA whose transform(...) takes 1 input table and has 1 output table.
  • ModelB whose transform(...) takes 2 input tables and has 1 output table.
  • EstimatorB whose fit(...) takes 2 input tables and returns an instance of ModelB.

We want to compose an Estimator (e.g. a Graph) from the following DAG of Transformer/Estimator.


The resulting Graph::fit is expected to have the following behavior:

  • The method takes 2 input tables. The 1st input table is given to one TransformerA instance, and the 2nd input table is given to another TransformerA instance.
  • An EstimatorB instance fits the output tables of these two TransformerA instances and generates a new ModelB instance.
  • Returns a GraphModel instance which contains 2 TransformerA instances and 1 ModelB instance, connected using the same DAG as shown above.


Here is the code snippet that addresses this use-case by using the proposed APIs:

GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new TransformerA();
Stage<?> stage2 = new TransformerA();
Stage<?> stage3 = new EstimatorB();
// Creates the input TableIds
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, input1)[0];
TableId output2 = builder.getOutputs(stage2, input2)[0];
TableId output3 = builder.getOutputs(stage3, output1, output2)[0];

// Specifies the ordered lists of inputs and outputs that will be used as the inputs/outputs
// of the corresponding Graph and GraphModel APIs.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};

// Generates the Graph instance.
Estimator<?, ?> estimator = builder.buildEstimator(inputs, outputs);
// The fit method takes 2 tables which are mapped to input1 and input2.
Model<?> model = estimator.fit(...);
// The transform method takes 2 tables which are mapped to input1 and input2.
Table[] results = model.transform(...);

Composing an Estimator from a chain of Estimator/Transformer whose input schemas differ from those of its fitted Transformer

Suppose we have the following Estimator and Transformer classes where an Estimator's input schemas could be different from the input schema of its fitted Transformer:

  • TransformerA whose transform(...) takes 1 input table and has 1 output table.
  • EstimatorA whose fit(...) takes 2 input tables and returns an instance of TransformerA.
  • TransformerB whose transform(...) takes 1 input table and has 1 output table.

We want to compose an Estimator (e.g. a Graph) from the following DAG of Transformer/Estimator.


The resulting Graph::fit is expected to have the following behavior:

  • The method takes 2 input tables. Both tables are given to EstimatorA::fit.
  • EstimatorA fits the input tables and generates a TransformerA instance. The TransformerA instance takes 1 input table, which is different from the 2 tables given to EstimatorA.
  • Returns a GraphModel instance which contains a TransformerA instance and a TransformerB instance, which are connected as a chain.

The fitted GraphModel is represented by the following DAG:

Here is the code snippet that addresses this use-case by using the proposed APIs:

GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new EstimatorA();
Stage<?> stage2 = new TransformerB();
// Creates inputs
TableId estimatorInput1 = builder.createTableId();
TableId estimatorInput2 = builder.createTableId();
TableId transformerInput1 = builder.createTableId();

// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, new TableId[] {estimatorInput1, estimatorInput2}, new TableId[] {transformerInput1})[0];
TableId output2 = builder.getOutputs(stage2, output1)[0];

// Specifies the ordered lists of estimator inputs, transformer inputs, outputs, input model data
// and output model data that will be used as the inputs/outputs of the corresponding Graph and
// GraphModel APIs.
TableId[] estimatorInputs = new TableId[] {estimatorInput1, estimatorInput2};
TableId[] transformerInputs = new TableId[] {transformerInput1};
TableId[] outputs = new TableId[] {output2};
TableId[] inputModelData = new TableId[] {};
TableId[] outputModelData = new TableId[] {};

// Generates the Graph instance.
Estimator<?, ?> estimator = builder.buildEstimator(estimatorInputs, transformerInputs, outputs, inputModelData, outputModelData);
// The fit method takes 2 tables which are mapped to estimatorInput1 and estimatorInput2.
Model<?> model = estimator.fit(...);
// The transform method takes 1 table which is mapped to transformerInput1.
Table[] results = model.transform(...);
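The difference from the first example is that Graph::fit and GraphModel::transform bind different input tables. One plausible way to support this, sketched below, is for the fit step to replace each Estimator node by its fitted model while keeping the node's algoInputs and outputs, so the fitted DAG needs no extra wiring. This is a simplification under stated assumptions: Strings stand in for Tables, ints for TableIds, `Op` and `Fit` are hypothetical stand-ins for AlgoOperator::transform and Estimator::fit, and only Estimators whose fit inputs are Graph-level inputs are handled, which is enough for this example.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

class SplitInputsSketch {
    /** Hypothetical stand-in for AlgoOperator::transform over String "tables". */
    interface Op extends Function<List<String>, List<String>> {}

    /** Hypothetical stand-in for Estimator::fit: consumes tables, returns the fitted model's transform. */
    interface Fit extends Function<List<String>, Op> {}

    static final class Node {
        final Fit fit; // non-null only for Estimator nodes
        final Op op; // non-null for every other node
        final int[] estimatorInputs;
        final int[] algoInputs;
        final int[] outputs;

        Node(Fit fit, Op op, int[] estimatorInputs, int[] algoInputs, int[] outputs) {
            this.fit = fit;
            this.op = op;
            this.estimatorInputs = estimatorInputs;
            this.algoInputs = algoInputs;
            this.outputs = outputs;
        }
    }

    /** Fit step: swaps each Estimator for its fitted model, keeping algoInputs and outputs. */
    static List<Node> fit(List<Node> nodes, int[] estimatorInputIds, List<String> tables) {
        Map<Integer, String> available = bind(estimatorInputIds, tables);
        List<Node> fitted = new ArrayList<>();
        for (Node node : nodes) {
            if (node.fit == null) {
                fitted.add(node);
                continue;
            }
            List<String> fitArgs = new ArrayList<>();
            for (int id : node.estimatorInputs) {
                if (!available.containsKey(id)) {
                    throw new IllegalStateException("Table " + id + " is not available at fit time");
                }
                fitArgs.add(available.get(id));
            }
            fitted.add(new Node(null, node.fit.apply(fitArgs), null, node.algoInputs, node.outputs));
        }
        return fitted;
    }

    /** Transform step: runs the (topologically sorted) fitted nodes on the algo inputs only. */
    static List<String> transform(List<Node> fitted, int[] algoInputIds, int[] outputIds, List<String> tables) {
        Map<Integer, String> available = bind(algoInputIds, tables);
        for (Node node : fitted) {
            List<String> args = new ArrayList<>();
            for (int id : node.algoInputs) {
                args.add(available.get(id));
            }
            List<String> results = node.op.apply(args);
            for (int i = 0; i < results.size(); i++) {
                available.put(node.outputs[i], results.get(i));
            }
        }
        List<String> out = new ArrayList<>();
        for (int id : outputIds) {
            out.add(available.get(id));
        }
        return out;
    }

    private static Map<Integer, String> bind(int[] ids, List<String> tables) {
        Map<Integer, String> map = new HashMap<>();
        for (int i = 0; i < ids.length; i++) {
            map.put(ids[i], tables.get(i));
        }
        return map;
    }

    static List<String> demo() {
        // Ids: 0, 1 = estimatorInput1/2; 2 = transformerInput1; 3 = output1; 4 = output2.
        Fit estimatorA = in -> x -> Arrays.asList("A[" + in.get(0) + "+" + in.get(1) + "](" + x.get(0) + ")");
        Op transformerB = x -> Arrays.asList("B(" + x.get(0) + ")");
        List<Node> nodes = Arrays.asList(
                new Node(estimatorA, null, new int[] {0, 1}, new int[] {2}, new int[] {3}),
                new Node(null, transformerB, null, new int[] {3}, new int[] {4}));
        List<Node> model = fit(nodes, new int[] {0, 1}, Arrays.asList("e1", "e2"));
        return transform(model, new int[] {2}, new int[] {4}, Arrays.asList("t1"));
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints [B(A[e1+e2](t1))]
    }
}
```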

Compatibility, Deprecation, and Migration Plan

This FLIP does not remove or modify any existing APIs. There are no backward-incompatible changes, and no deprecation or migration plan is needed.

Test Plan

We will provide unit tests to validate the proposed changes.

Rejected Alternatives

There are no rejected alternatives to list yet.



