
Current state: "Under Discussion"

Discussion thread: To be added


Released: 1.15

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

Motivation and Use-cases

The existing Flink ML library allows users to compose an Estimator/Transformer from a pipeline (i.e. a linear sequence) of Estimators/Transformers. Users only need to construct this Pipeline once to generate the corresponding PipelineModel; they do not have to explicitly construct the fitted PipelineModel as a linear sequence of stages.

However, for use-cases that need a DAG of Estimators/Transformers, users currently need to build the DAG twice: once for the training logic and once for the inference logic. This experience is inferior to the one that Pipeline provides for linear sequences.

To improve the user experience, we propose to add several helper classes that allow users to compose an Estimator/Transformer/AlgoOperator from a DAG of Estimators/Transformers/AlgoOperators.

Public Interfaces

This FLIP proposes to add the Graph, GraphTransformer and GraphBuilder classes. The following code block shows the public APIs of these classes.

/**
 * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be
 * an Estimator, Transformer or AlgoOperator. When `Graph::fit` is called, the stages are executed in a
 * topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method will be
 * called on the input tables (from the input edges) to fit a model. Then the model, which is a
 * Transformer, will be used to transform the input tables to produce output tables to the output
 * edges. If a stage is a Transformer or AlgoOperator, its `AlgoOperator::transform` method will be called on the
 * input tables to produce output tables to the output edges. The fitted model from a Graph is a
 * GraphTransformer, which consists of fitted models and transformers, corresponding to the Graph's
 * stages.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, GraphTransformer> {
    public Graph(...) {...}

    @Override
    public GraphTransformer fit(Table... inputs) {...}

    /** Skipped a few methods, including the implementations of some Estimator APIs. */
}

/**
 * A GraphTransformer acts as a Transformer. A GraphTransformer consists of a DAG of Transformers or AlgoOperators. When
 * `GraphTransformer::transform` is called, the stages are executed in a topologically-sorted order. When
 * a stage is executed, its `AlgoOperator::transform` method will be called on the input tables (from
 * the input edges) to produce output tables to the output edges.
 */
public final class GraphTransformer implements Transformer<GraphTransformer> {
    /** Skipped a few methods, including the implementations of the Transformer APIs. */
}

/**
 * A GraphBuilder provides APIs to build Graph and GraphTransformer from a DAG of Estimator, Transformer and
 * AlgoOperator instances.
 */
@PublicEvolving
public final class GraphBuilder {
    private int maxOutputLength = 20;

    public GraphBuilder() {}

    /**
     * Specifies the upper bound (could be loose) of the number of output tables that can be
     * returned by the Transformer::getStateStreams and Transformer::transform methods, for any
     * stage involved in this Graph.
     *
     * <p>The default upper bound is 20.
     */
    public GraphBuilder setMaxOutputLength(int maxOutputLength) {...}

    /**
     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing of
     * tables between stages, as well as the input/output tables of the Graph/GraphTransformer generated
     * by this builder.
     */
    public TableId createTableId() {...}

    /**
     * If the stage is an Estimator, both its fit method and the transform method of its fitted
     * Transformer would be invoked with the given inputs when the graph runs.
     *
     * <p>If this stage is a Transformer or AlgoOperator, its transform method would be invoked with the given
     * inputs when the graph runs.
     *
     * <p>Returns an array of TableIds that represent the outputs of the Transformer::transform
     * invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId... inputs) {...}

    /**
     * If this stage is an Estimator, its fit method would be invoked with estimatorInputs, and the
     * transform method of its fitted Transformer would be invoked with transformerInputs, when the
     * graph runs.
     *
     * <p>This method throws Exception if the stage is a Transformer or AlgoOperator.
     *
     * <p>This method is useful when the stage is an Estimator AND the Estimator::fit needs to take
     * a different list of Tables from the Transformer::transform of the fitted Transformer.
     *
     * <p>Returns an array of TableIds that represent the outputs of the Transformer::transform
     * invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId[] estimatorInputs, TableId[] transformerInputs) {...}

    /**
     * When GraphTransformer::setStateStreams is called, it should invoke the setStateStreams
     * method of the stage with the given inputs.
     */
    public void setStateStreams(Stage<?> stage, TableId... inputs) {...}

    /**
     * When GraphTransformer::getStateStreams is called, it should invoke the getStateStreams
     * method of the stage.
     *
     * <p>Returns an array of TableIds that represent the outputs of getStateStreams of the stage.
     */
    public TableId[] getStateStreams(Stage<?> stage) {...}

    /**
     * Returns an Estimator instance with the following API specification:
     *
     * <p>1) Estimator::fit should take inputs and return a Transformer with the following
     * specification.
     *
     * <p>2) Transformer::transform should take inputs and return outputs.
     *
     * <p>fit and transform should invoke the APIs of the internal stages in the order specified by
     * the DAG of stages.
     */
    public Estimator buildEstimator(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Returns an Estimator instance with the following API specification:
     *
     * <p>1) Estimator::fit should take inputs and return a Transformer with the following specification.
     *
     * <p>2) Transformer::transform should take inputs and return outputs.
     *
     * <p>3) Transformer::setStateStreams should take inputStates.
     *
     * <p>4) Transformer::getStateStreams should return outputStates.
     *
     * <p>fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal
     * stages in the order specified by the DAG of stages.
     */
    public Estimator buildEstimator(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns an Estimator instance with the following API specification:
     *
     * <p>1) Estimator::fit should take estimatorInputs and return a Transformer with the following specification.
     *
     * <p>2) Transformer::transform should take transformerInputs and return outputs.
     *
     * <p>3) Transformer::setStateStreams should take inputStates.
     *
     * <p>4) Transformer::getStateStreams should return outputStates.
     *
     * <p>fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal
     * stages in the order specified by the DAG of stages.
     *
     * <p>This method is useful when the Estimator::fit needs to take a different list of Tables from
     * the Transformer::transform of the fitted Transformer.
     */
    public Estimator buildEstimator(
            TableId[] estimatorInputs,
            TableId[] transformerInputs,
            TableId[] outputs,
            TableId[] inputStates,
            TableId[] outputStates) {...}

    /**
     * Returns a Transformer instance with the following API specification:
     *
     * <p>1) Transformer::transform should take inputs and return outputs.
     *
     * <p>2) Transformer::setStateStreams should take inputStates.
     *
     * <p>3) Transformer::getStateStreams should return outputStates.
     *
     * <p>transform/setStateStreams/getStateStreams should invoke the APIs of the internal
     * stages in the order specified by the DAG of stages.
     *
     * <p>This method throws Exception if any stage of this DAG is an Estimator.
     */
    public Transformer buildTransformer(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns an AlgoOperator instance with the following API specification:
     *
     * <p>1) AlgoOperator::transform should take inputs and return outputs.
     *
     * <p>transform should invoke the APIs of the internal stages in the order specified by the DAG of stages.
     *
     * <p>This method throws Exception if any stage of this DAG is an Estimator.
     */
    public AlgoOperator buildAlgoOperator(TableId[] inputs, TableId[] outputs) {...}

    // The TableId is necessary to pass the inputs/outputs of various API calls across the
    // Graph/GraphTransformer stages.
    public static class TableId {}
}
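The Graph and GraphTransformer Javadocs above state that stages are executed in a topologically sorted order. As a minimal sketch (not the proposed implementation), assuming each stage is reduced to a name plus the integer ids of the tables it reads and writes, such an order can be computed with Kahn's algorithm. `TopoOrderSketch`, `Node` and `topologicalOrder` are hypothetical names introduced only for this illustration.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class TopoOrderSketch {
    /** Hypothetical stand-in for a stage: its name and the table ids it reads and writes. */
    static final class Node {
        final String name;
        final List<Integer> inputs;
        final List<Integer> outputs;

        Node(String name, List<Integer> inputs, List<Integer> outputs) {
            this.name = name;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    /**
     * Returns the node names in an order where each node comes after the nodes producing its
     * inputs (Kahn's algorithm). Table ids produced by no node are treated as graph inputs.
     */
    static List<String> topologicalOrder(List<Node> nodes) {
        // Map each table id to the index of the node that produces it.
        Map<Integer, Integer> producer = new HashMap<>();
        for (int i = 0; i < nodes.size(); i++) {
            for (int out : nodes.get(i).outputs) {
                producer.put(out, i);
            }
        }
        // pending[i] counts the inputs of node i whose producer has not run yet;
        // consumers.get(p) lists the nodes that read a table produced by node p.
        int[] pending = new int[nodes.size()];
        List<List<Integer>> consumers = new ArrayList<>();
        for (int i = 0; i < nodes.size(); i++) {
            consumers.add(new ArrayList<>());
        }
        for (int i = 0; i < nodes.size(); i++) {
            for (int in : nodes.get(i).inputs) {
                Integer p = producer.get(in);
                if (p != null) {
                    pending[i]++;
                    consumers.get(p).add(i);
                }
            }
        }
        // Start with nodes that read only graph inputs, then release their consumers.
        Deque<Integer> ready = new ArrayDeque<>();
        for (int i = 0; i < nodes.size(); i++) {
            if (pending[i] == 0) ready.add(i);
        }
        List<String> order = new ArrayList<>();
        while (!ready.isEmpty()) {
            int current = ready.poll();
            order.add(nodes.get(current).name);
            for (int c : consumers.get(current)) {
                if (--pending[c] == 0) ready.add(c);
            }
        }
        if (order.size() != nodes.size()) {
            throw new IllegalStateException("The stages do not form a DAG");
        }
        return order;
    }

    public static void main(String[] args) {
        // The DAG of the first usage example; tables 1 and 2 are the graph inputs.
        List<Node> nodes = Arrays.asList(
                new Node("estimatorB", Arrays.asList(3, 4), Arrays.asList(5)),
                new Node("transformerA#1", Arrays.asList(1), Arrays.asList(3)),
                new Node("transformerA#2", Arrays.asList(2), Arrays.asList(4)));
        System.out.println(topologicalOrder(nodes));
    }
}
```

Running `main` prints `[transformerA#1, transformerA#2, estimatorB]`: both TransformerA stages are scheduled before the EstimatorB stage that consumes their outputs, even though EstimatorB was listed first. A cycle leaves some node with pending inputs, so the method throws instead of returning a partial order.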


Example Usage

In this section we provide example code snippets that demonstrate how the APIs proposed in this FLIP address the use-cases described in the motivation section.

Composing an Estimator from a DAG of Estimator/Transformer

Suppose we have the following Transformer and Estimator classes:

  • TransformerA whose transform(...) takes 1 input table and has 1 output table.
  • TransformerB whose transform(...) takes 2 input tables and has 1 output table.
  • EstimatorB whose fit(...) takes 2 input tables and returns an instance of TransformerB.

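We want to compose an Estimator (i.e. a Graph) from a DAG of Transformers/Estimators in which each of the two input tables is processed by its own TransformerA instance, and an EstimatorB instance consumes the outputs of both TransformerA instances.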


The resulting Graph::fit is expected to have the following behavior:

  • The method takes 2 input tables. The 1st input table is given to one TransformerA instance, and the 2nd input table is given to another TransformerA instance.
  • An EstimatorB instance fits the output tables of these two TransformerA instances and generates a new TransformerB instance.
  • Returns a GraphTransformer instance which contains 2 TransformerA instances and 1 TransformerB instance, connected in the same DAG as the Graph.


Here is the code snippet that addresses this use-case by using the proposed APIs:

GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new TransformerA();
Stage<?> stage2 = new TransformerA();
Stage<?> stage3 = new EstimatorB();
// Creates inputs
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, input1)[0];
TableId output2 = builder.getOutputs(stage2, input2)[0];
TableId output3 = builder.getOutputs(stage3, output1, output2)[0];

// Specifies the ordered lists of inputs and outputs that will be used as the inputs/outputs of the
// corresponding Graph and GraphTransformer APIs. This example has no input or output states.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};

// Generates the Graph instance. The Estimator returned by buildEstimator is expected to be a Graph.
Graph graph = (Graph) builder.buildEstimator(inputs, outputs, new TableId[] {}, new TableId[] {});
// The fit method takes 2 tables which are mapped to input1 and input2.
GraphTransformer transformer = graph.fit(...);
// The transform method takes 2 tables which are mapped to input1 and input2.
Table[] results = transformer.transform(...);
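To make the fit/transform semantics of this example concrete, here is a toy, in-memory sketch. It is not Flink code: every name (`ToyGraph`, `ToyEstimator`, `ToyTransformer`, `addStage`, `exampleOne`) is hypothetical, a "table" is a `double[]`, every stage has exactly one output, and the behaviors of TransformerA (doubling each value) and EstimatorB (learning a bias equal to the mean of its inputs) are invented for illustration. It assumes stages can run in the order they were registered, which is a valid topological order because a stage may only consume table ids that already exist.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;

/** Toy sketch of Graph/GraphTransformer semantics; all names are hypothetical. */
final class ToyGraph {
    interface ToyTransformer { double[] transform(List<double[]> inputs); }
    interface ToyEstimator { ToyTransformer fit(List<double[]> inputs); }

    /** A DAG node: a ToyEstimator or ToyTransformer, the table ids it reads and the id it writes. */
    static final class Node {
        final Object stage;
        final int[] inputs;
        final int output;

        Node(Object stage, int[] inputs, int output) {
            this.stage = stage;
            this.inputs = inputs;
            this.output = output;
        }
    }

    private final int numInputs; // table ids 0..numInputs-1 are the graph inputs
    private final List<Node> nodes = new ArrayList<>();

    ToyGraph(int numInputs) {
        this.numInputs = numInputs;
    }

    /** Loosely mirrors GraphBuilder::getOutputs: registers a stage and returns its output id. */
    int addStage(Object stage, int... inputs) {
        int output = numInputs + nodes.size();
        for (int id : inputs) {
            // Only existing ids are accepted, so registration order is a topological order.
            if (id < 0 || id >= output) throw new IllegalArgumentException("Unknown table id " + id);
        }
        nodes.add(new Node(stage, inputs, output));
        return output;
    }

    /** Like Graph::fit: fits each estimator, then lets its model produce the downstream table. */
    ToyGraph fit(List<double[]> inputs) {
        ToyGraph fitted = new ToyGraph(numInputs);
        execute(inputs, fitted);
        return fitted;
    }

    /** Like GraphTransformer::transform: runs the fitted DAG and returns the requested table. */
    double[] transform(List<double[]> inputs, int output) {
        return execute(inputs, null).get(output);
    }

    /** Runs all nodes in registration order; records fitted models in `fitted` when non-null. */
    private Map<Integer, double[]> execute(List<double[]> graphInputs, ToyGraph fitted) {
        Map<Integer, double[]> tables = new HashMap<>();
        for (int i = 0; i < numInputs; i++) tables.put(i, graphInputs.get(i));
        for (Node node : nodes) {
            List<double[]> in = new ArrayList<>();
            for (int id : node.inputs) in.add(tables.get(id));
            ToyTransformer model;
            if (node.stage instanceof ToyEstimator) {
                if (fitted == null) throw new IllegalStateException("Call fit() before transform()");
                model = ((ToyEstimator) node.stage).fit(in);
            } else {
                model = (ToyTransformer) node.stage;
            }
            if (fitted != null) fitted.nodes.add(new Node(model, node.inputs, node.output));
            tables.put(node.output, model.transform(in));
        }
        return tables;
    }

    /** The first example: fits on fixed data, then transforms (x, y) with the fitted graph. */
    static double[] exampleOne(double[] x, double[] y) {
        // Hypothetical TransformerA: doubles every value of its single input table.
        ToyTransformer transformerA = in -> Arrays.stream(in.get(0)).map(v -> 2 * v).toArray();
        // Hypothetical EstimatorB: learns bias = mean of both tables; its model returns a + b - bias.
        ToyEstimator estimatorB = in -> {
            double[] a = in.get(0);
            double[] b = in.get(1);
            double bias = (Arrays.stream(a).sum() + Arrays.stream(b).sum()) / (a.length + b.length);
            return t -> IntStream.range(0, t.get(0).length)
                    .mapToDouble(i -> t.get(0)[i] + t.get(1)[i] - bias).toArray();
        };
        ToyGraph graph = new ToyGraph(2);
        int output1 = graph.addStage(transformerA, 0);
        int output2 = graph.addStage(transformerA, 1);
        int output3 = graph.addStage(estimatorB, output1, output2);
        ToyGraph fitted = graph.fit(Arrays.asList(new double[] {1, 2}, new double[] {3, 4}));
        return fitted.transform(Arrays.asList(x, y), output3);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(exampleOne(new double[] {1, 2}, new double[] {3, 4})));
        System.out.println(Arrays.toString(exampleOne(new double[] {0, 0}, new double[] {1, 1})));
    }
}
```

Running `main` prints `[3.0, 7.0]` and then `[-3.0, -3.0]`. The second call shows that the fitted graph reuses the bias learned during fit (5.0) instead of re-fitting EstimatorB on the new inputs, and calling `transform` on the unfitted graph throws.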

Composing an Estimator from a chain of Estimators/Transformers whose input schemas differ from those of the fitted Transformer

Suppose we have the following Estimator and Transformer classes where an Estimator's input schemas could be different from the input schema of its fitted Transformer:

  • TransformerA whose transform(...) takes 1 input table and has 1 output table.
  • EstimatorA whose fit(...) takes 2 input tables and returns an instance of TransformerA.
  • TransformerB whose transform(...) takes 1 input table and has 1 output table.

We want to compose an Estimator (i.e. a Graph) from a chain of Estimators/Transformers in which EstimatorA is followed by TransformerB.


The resulting Graph::fit is expected to have the following behavior:

  • The method takes 2 input tables. Both tables are given to EstimatorA::fit.
  • EstimatorA fits the input tables and generates a TransformerA instance. The TransformerA instance takes 1 input table, which differs from the 2 tables given to EstimatorA.
  • Returns a GraphTransformer instance which contains a TransformerA instance and a TransformerB instance, which are connected as a chain.

The fitted GraphTransformer is a chain in which the output table of TransformerA is fed to TransformerB.

Notes:

  • The fitted GraphTransformer takes only 1 table as input whereas the Graph takes 2 tables as inputs.
  • More generally, the proposed APIs also support composing an Estimator from a DAG (not only a chain) of Estimators/Transformers whose input schemas differ from those of their fitted Transformers.

Here is the code snippet that addresses this use-case by using the proposed APIs:

GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new EstimatorA();
Stage<?> stage2 = new TransformerB();
// Creates inputs
TableId estimatorInput1 = builder.createTableId();
TableId estimatorInput2 = builder.createTableId();
TableId transformerInput1 = builder.createTableId();

// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, new TableId[] {estimatorInput1, estimatorInput2}, new TableId[] {transformerInput1})[0];
TableId output2 = builder.getOutputs(stage2, output1)[0];

// Specifies the ordered lists of estimator inputs, transformer inputs, outputs, input states and output states
// that will be used as the inputs/outputs of the corresponding Graph and GraphTransformer APIs.
TableId[] estimatorInputs = new TableId[] {estimatorInput1, estimatorInput2};
TableId[] transformerInputs = new TableId[] {transformerInput1};
TableId[] outputs = new TableId[] {output2};
TableId[] inputStates = new TableId[] {};
TableId[] outputStates = new TableId[] {};

// Generates the Graph instance. The Estimator returned by buildEstimator is expected to be a Graph.
Graph graph = (Graph) builder.buildEstimator(estimatorInputs, transformerInputs, outputs, inputStates, outputStates);
// The fit method takes 2 tables which are mapped to estimatorInput1 and estimatorInput2.
GraphTransformer transformer = graph.fit(...);
// The transform method takes 1 table which is mapped to transformerInput1.
Table[] results = transformer.transform(...);
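A similar toy sketch for this example shows how the fit step and the fitted chain can take different numbers of tables. Again, all names (`SplitInputsSketch`, `ToyEstimator`, `ToyTransformer`, `FittedChain`) are hypothetical and a "table" is a `double[]`. EstimatorA is assumed to learn the ratio between the sums of its two training tables, and TransformerB is assumed to add 1 to every value. The sketch only fits EstimatorA during `fit`, because no table is bound to transformerInput1 at that point; it does not model how the proposed Graph treats downstream stages during fit.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** Toy sketch of the second example; every name is hypothetical and a "table" is a double[]. */
final class SplitInputsSketch {
    interface ToyTransformer { double[] transform(List<double[]> inputs); }

    interface ToyEstimator { ToyTransformer fit(List<double[]> inputs); }

    /** Hypothetical EstimatorA: fit() takes 2 tables, but its fitted TransformerA takes only 1. */
    static final ToyEstimator ESTIMATOR_A = in -> {
        // Learns the ratio between the sums of the two training tables.
        double factor = Arrays.stream(in.get(1)).sum() / Arrays.stream(in.get(0)).sum();
        return t -> Arrays.stream(t.get(0)).map(v -> v * factor).toArray();
    };

    /** Hypothetical TransformerB: adds 1 to every value of its single input table. */
    static final ToyTransformer TRANSFORMER_B =
            in -> Arrays.stream(in.get(0)).map(v -> v + 1).toArray();

    /** The fitted chain TransformerA -> TransformerB, playing the role of the GraphTransformer. */
    static final class FittedChain {
        private final ToyTransformer transformerA;

        FittedChain(ToyTransformer transformerA) {
            this.transformerA = transformerA;
        }

        /** Takes 1 table, which plays the role of transformerInput1. */
        double[] transform(double[] transformerInput1) {
            double[] output1 = transformerA.transform(Collections.singletonList(transformerInput1));
            return TRANSFORMER_B.transform(Collections.singletonList(output1)); // output2
        }
    }

    /** Plays the role of Graph::fit: takes 2 tables (estimatorInput1, estimatorInput2). */
    static FittedChain fit(double[] estimatorInput1, double[] estimatorInput2) {
        ToyTransformer transformerA = ESTIMATOR_A.fit(Arrays.asList(estimatorInput1, estimatorInput2));
        return new FittedChain(transformerA);
    }

    public static void main(String[] args) {
        FittedChain chain = fit(new double[] {1, 2, 3}, new double[] {2, 4, 6});
        System.out.println(Arrays.toString(chain.transform(new double[] {5, 0.5})));
    }
}
```

Here `fit` consumes 2 tables and learns a factor of 2.0, whereas `FittedChain.transform` consumes a single table, so `main` prints `[11.0, 2.0]`.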

Compatibility, Deprecation, and Migration Plan

This FLIP does not remove or modify any existing APIs. There is no backward incompatible change, deprecation or migration plan.

Test Plan

We will provide unit tests to validate the proposed changes.

Rejected Alternatives

There are no rejected alternatives to list yet.



