...

The existing Flink ML library allows users to compose an Estimator/Transformer from a pipeline (i.e. linear sequence) of Estimator/Transformer. Users only need to construct this Pipeline once and generate the corresponding PipelineModel, without having to explicitly construct the fitted PipelineModel as a linear sequence of stages. However, in order to train a DAG of Estimator/Transformer and use the trained model for inference, users currently need to construct the DAG twice, once for the training logic and once for the inference logic. This experience is inferior to the experience of training and using a chain of Estimator/Transformer/AlgoOperator. In addition to requiring more work from users, this approach is more error prone because the DAG for the training logic may be inconsistent with the DAG for the inference logic.

In order to address the issues described above, we propose to add several helper classes that allow users to compose Estimator/Transformer/AlgoOperator from a DAG of Estimator/Transformer/AlgoOperator.

Public Interfaces

This FLIP proposes to add the Graph, GraphModel, GraphBuilder, GraphNode and TableId classes. The following code block shows the public APIs of these classes.

Code Block
languagejava
/**
 * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be an
 * Estimator, Model, Transformer or AlgoOperator. When `Graph::fit` is called, the stages are
 * executed in a topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method
 * will be called on the input tables (from the input edges) to fit a Model. Then the Model will be
 * used to transform the input tables and produce output tables to the output edges. If a stage is
 * an AlgoOperator, its `AlgoOperator::transform` method will be called on the input tables and
 * produce output tables to the output edges. The GraphModel fitted from a Graph consists of the
 * fitted Models and AlgoOperators, corresponding to the Graph's stages.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, GraphModel> {
    public Graph(List<GraphNode> nodes, TableId[] estimatorInputIds, TableId[] modelInputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}

    @Override
    public GraphModel fit(Table... inputs) {...}

    @Override
    public void save(String path) throws IOException {...}

    // Static methods cannot be annotated with @Override.
    public static Graph load(String path) throws IOException {...}
}

/**
 * A GraphModel acts as a Model. A GraphModel consists of a DAG of stages, each of which could be a
 * Model, Transformer or AlgoOperator. When `GraphModel::transform` is called, the stages are
 * executed in a topologically-sorted order. When a stage is executed, its
 * `AlgoOperator::transform` method will be called on the input tables (from the input edges) to
 * produce output tables to the output edges.
 */
@PublicEvolving
public final class GraphModel implements Model<GraphModel> {

    public GraphModel(List<GraphNode> nodes, TableId[] inputIds, TableId[] outputIds, TableId[] inputModelData, TableId[] outputModelData) {...}

    @Override
    public Table[] transform(Table... inputTables) {...}

    @Override
    public void setModelData(Table... inputs) {...}

    @Override
    public Table[] getModelData() {...}

    @Override
    public void save(String path) throws IOException {...}

    public static GraphModel load(String path) throws IOException {...}
}


/**
 * A GraphBuilder provides APIs to build Graph/Model/AlgoOperator from a DAG of stages, each of
 * which could be an Estimator, Model, Transformer or AlgoOperator.
 */
@PublicEvolving
public final class GraphBuilder {
    private int maxOutputLength = 20;

    public GraphBuilder() {}

    /**
     * Specifies the upper bound (could be loose) of the number of output tables that can be
     * returned by the Model::getModelData and AlgoOperator::transform methods, for any stage
     * involved in this Graph.
     *
     * <p>The default upper bound is 20.
     */
    public GraphBuilder setMaxOutputLength(int maxOutputLength) {...}

    /**
     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing of
     * tables between stages, as well as the input/output tables of the Graph/GraphModel generated
     * by this builder.
     */
    public TableId createTableId() {...}

    /**
     * If the stage is an Estimator, both its fit method and the transform method of its fitted
     * Model would be invoked with the given inputs when the graph runs.
     *
     * <p>If this stage is a Model, Transformer or AlgoOperator, its transform method would be
     * invoked with the given inputs when the graph runs.
     *
     * <p>Returns a list of TableIds, which represents outputs of AlgoOperator::transform of the
     * given stage.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId... inputs) {...}

    /**
     * If this stage is an Estimator, its fit method would be invoked with estimatorInputs, and the
     * transform method of its fitted Model would be invoked with modelInputs.
     *
     * <p>This method throws Exception if the stage is not an Estimator.
     *
     * <p>This method is useful when the stage is an Estimator AND the Estimator::fit needs to take
     * a different list of Tables from the Model::transform of the fitted Model.
     *
     * <p>Returns a list of TableIds, which represents outputs of Model::transform of the fitted
     * Model.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId[] estimatorInputs, TableId[] modelInputs) {...}

    /**
     * The setModelData() of the fitted GraphModel should invoke the setModelData() of the given
     * stage with the given inputs.
     */
    public void setModelData(Stage<?> stage, TableId... inputs) {...}

    /**
     * The getModelData() of the fitted GraphModel should invoke the getModelData() of the given
     * stage.
     *
     * <p>Returns a list of TableIds, which represents the outputs of getModelData() of the given
     * stage.
     */
    public TableId[] getModelData(Stage<?> stage) {...}

    /**
     * Returns an Estimator instance with the following behavior:
     *
     * <p>1) Estimator::fit should take the given inputs and return a Model with the following
     * behavior.
     *
     * <p>2) Model::transform should take the given inputs and return the given outputs.
     *
     * <p>The fit method of the returned Estimator and the transform method of the fitted Model
     * should invoke the corresponding methods of the internal stages as specified by the
     * GraphBuilder.
     */
    public Estimator<?, ?> buildEstimator(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Returns an Estimator instance with the following behavior:
     *
     * <p>1) Estimator::fit should take the given inputs and return a Model with the following
     * behavior.
     *
     * <p>2) Model::transform should take the given inputs and return the given outputs.
     *
     * <p>3) Model::setModelData should take the given inputModelData.
     *
     * <p>4) Model::getModelData should return the given outputModelData.
     *
     * <p>The fit method of the returned Estimator and the transform/setModelData/getModelData
     * methods of the fitted Model should invoke the corresponding methods of the internal stages as
     * specified by the GraphBuilder.
     */
    public Estimator<?, ?> buildEstimator(TableId[] inputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}

    /**
     * Returns an Estimator instance with the following behavior:
     *
     * <p>1) Estimator::fit should take the given estimatorInputs and return a Model with the
     * following behavior.
     *
     * <p>2) Model::transform should take the given modelInputs and return the given outputs.
     *
     * <p>3) Model::setModelData should take the given inputModelData.
     *
     * <p>4) Model::getModelData should return the given outputModelData.
     *
     * <p>The fit method of the returned Estimator and the transform/setModelData/getModelData
     * methods of the fitted Model should invoke the corresponding methods of the internal stages as
     * specified by the GraphBuilder.
     */
    public Estimator<?, ?> buildEstimator(TableId[] estimatorInputs, TableId[] modelInputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}

    /**
     * Returns an AlgoOperator instance with the following behavior:
     *
     * <p>1) AlgoOperator::transform should take the given inputs and return the given outputs.
     *
     * <p>The transform method of the returned AlgoOperator should invoke the corresponding methods
     * of the internal stages as specified by the GraphBuilder.
     */
    public AlgoOperator<?> buildAlgoOperator(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Returns a Model instance with the following behavior:
     *
     * <p>1) Model::transform should take the given inputs and return the given outputs.
     *
     * <p>The transform method of the returned Model should invoke the corresponding methods of the
     * internal stages as specified by the GraphBuilder.
     */
    public Model<?> buildModel(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Returns a Model instance with the following behavior:
     *
     * <p>1) Model::transform should take the given inputs and return the given outputs.
     *
     * <p>2) Model::setModelData should take the given inputModelData.
     *
     * <p>3) Model::getModelData should return the given outputModelData.
     *
     * <p>The transform/setModelData/getModelData methods of the returned Model should invoke the
     * corresponding methods of the internal stages as specified by the GraphBuilder.
     */
    public Model<?> buildModel(TableId[] inputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}
}

public class GraphNode {
    public final Stage<?> stage;
    public final TableId[] estimatorInputs;
    public final TableId[] modelInputs;
    public final TableId[] outputs;
}

public class TableId {
    private final int tableId;

    @Override
    public boolean equals(Object obj) {...}

    @Override
    public int hashCode() {...}
}
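
The Javadoc above states that the stages of a Graph or GraphModel are executed in a topologically-sorted order. The following is a minimal sketch of how such an order could be derived from the table ids that each node reads and writes, using Kahn's algorithm. The `Node` class, the integer table ids and the `executionOrder` method are hypothetical stand-ins introduced for illustration; they are not part of the proposed API, and the FLIP does not prescribe a particular sorting algorithm.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical toy model of graph nodes; not the GraphNode class proposed in this FLIP. */
final class TopoOrderSketch {

    /** A stage name plus the integer ids of the tables it reads and writes. */
    static final class Node {
        final String name;
        final List<Integer> inputs;
        final List<Integer> outputs;

        Node(String name, List<Integer> inputs, List<Integer> outputs) {
            this.name = name;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    /** Returns stage names in an order where every table is produced before it is consumed. */
    static List<String> executionOrder(List<Node> nodes) {
        // Map each table id to the node that produces it; graph inputs have no producer.
        Map<Integer, Node> producer = new HashMap<>();
        for (Node n : nodes) {
            for (int id : n.outputs) {
                producer.put(id, n);
            }
        }
        // Count unresolved upstream dependencies and record downstream consumers.
        Map<Node, Integer> pending = new HashMap<>();
        Map<Node, List<Node>> consumers = new HashMap<>();
        for (Node n : nodes) {
            int count = 0;
            for (int id : n.inputs) {
                Node p = producer.get(id);
                if (p != null) {
                    count++;
                    consumers.computeIfAbsent(p, k -> new ArrayList<>()).add(n);
                }
            }
            pending.put(n, count);
        }
        // Kahn's algorithm: repeatedly emit nodes whose dependencies are all satisfied.
        Deque<Node> ready = new ArrayDeque<>();
        for (Node n : nodes) {
            if (pending.get(n) == 0) {
                ready.add(n);
            }
        }
        List<String> order = new ArrayList<>();
        while (!ready.isEmpty()) {
            Node n = ready.poll();
            order.add(n.name);
            for (Node c : consumers.getOrDefault(n, List.of())) {
                if (pending.merge(c, -1, Integer::sum) == 0) {
                    ready.add(c);
                }
            }
        }
        if (order.size() != nodes.size()) {
            throw new IllegalStateException("The stages do not form a DAG");
        }
        return order;
    }

    public static void main(String[] args) {
        // Nodes are listed out of order on purpose: the estimator depends on both transformers.
        List<Node> nodes = List.of(
                new Node("estimatorB", List.of(2, 3), List.of(4)),
                new Node("transformerA-1", List.of(0), List.of(2)),
                new Node("transformerA-2", List.of(1), List.of(3)));
        System.out.println(executionOrder(nodes));
        // prints [transformerA-1, transformerA-2, estimatorB]
    }
}
```

Deriving the order from table ids means the caller never has to list stages in execution order, which matches the builder style of the proposed API where edges are declared through TableId values.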


Example Usage

In this section we provide example code snippets to demonstrate how we can use the APIs proposed in this FLIP to address the use-cases in the motivation section.

...

  • TransformerA whose transform(...) takes 1 input table and has 1 output table.
  • ModelB whose transform(...) takes 2 input tables and has 1 output table.
  • EstimatorB whose fit(...) takes 2 input tables and returns an instance of ModelB.

And we want to compose an Estimator (e.g. Graph) from the following DAG of Transformer/Estimator.

...

  • The method takes 2 input tables. The 1st input table is given to a TransformerA instance. And the 2nd input table is given to another TransformerA instance.
  • An EstimatorB instance fits the output tables of these two TransformerA instances and generates a new ModelB instance.
  • Returns a GraphModel instance which contains 2 TransformerA instances and 1 ModelB instance, connected using the same DAG as shown above.
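
The behavior listed above can be sketched with toy stand-ins, assuming a "table" is just a `List<Double>`. Everything in this block (`transformA`, `fitB`, `ModelB`, `FittedGraph` and the arithmetic they perform) is hypothetical and chosen only to make the data flow visible; none of it is a Flink ML class or the proposed Graph implementation.

```java
import java.util.ArrayList;
import java.util.List;

/** Toy stand-ins: a "table" is a List<Double>; none of these are Flink ML classes. */
final class GraphFitSketch {

    /** Hypothetical TransformerA: 1 input table, 1 output table (doubles every value). */
    static List<Double> transformA(List<Double> in) {
        List<Double> out = new ArrayList<>();
        for (double v : in) {
            out.add(2 * v);
        }
        return out;
    }

    /** Hypothetical ModelB: 2 input tables of equal length, 1 output table. */
    static final class ModelB {
        final double bias;

        ModelB(double bias) {
            this.bias = bias;
        }

        List<Double> transform(List<Double> a, List<Double> b) {
            List<Double> out = new ArrayList<>();
            for (int i = 0; i < a.size(); i++) {
                out.add(a.get(i) + b.get(i) - bias);
            }
            return out;
        }
    }

    /** Hypothetical EstimatorB: fits a ModelB whose bias is the mean over both input tables. */
    static ModelB fitB(List<Double> a, List<Double> b) {
        double sum = 0;
        for (double v : a) {
            sum += v;
        }
        for (double v : b) {
            sum += v;
        }
        return new ModelB(sum / (a.size() + b.size()));
    }

    /** The fitted graph: same wiring as training, with EstimatorB replaced by its ModelB. */
    static final class FittedGraph {
        final ModelB modelB;

        FittedGraph(ModelB modelB) {
            this.modelB = modelB;
        }

        List<Double> transform(List<Double> in1, List<Double> in2) {
            return modelB.transform(transformA(in1), transformA(in2));
        }
    }

    /** Each input goes through a TransformerA, then EstimatorB is fitted on the two outputs. */
    static FittedGraph fit(List<Double> in1, List<Double> in2) {
        ModelB modelB = fitB(transformA(in1), transformA(in2));
        return new FittedGraph(modelB);
    }

    public static void main(String[] args) {
        FittedGraph model = fit(List.of(1.0, 2.0), List.of(3.0, 4.0));
        System.out.println(model.transform(List.of(1.0, 2.0), List.of(3.0, 4.0)));
        // prints [3.0, 7.0]
    }
}
```

The point of the sketch is that the wiring is written once in `fit` and reused unchanged by `FittedGraph.transform`, which is the duplication the proposed GraphBuilder is meant to remove.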

...

Code Block
languagejava
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new TransformerA();
Stage<?> stage2 = new TransformerA();
Stage<?> stage3 = new EstimatorB();
// Creates inputs
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, input1)[0];
TableId output2 = builder.getOutputs(stage2, input2)[0];
TableId output3 = builder.getOutputs(stage3, output1, output2)[0];

// Specifies the ordered lists of inputs and outputs that will be used as the inputs/outputs
// of the corresponding Graph and GraphModel APIs.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};

// Generates the Graph instance. No model data is passed in or out, so both model data lists are empty.
Estimator<?, ?> estimator = builder.buildEstimator(inputs, outputs, new TableId[]{}, new TableId[]{});
// The fit method takes 2 tables which are mapped to input1 and input2.
Model<?> model = estimator.fit(...);
// The transform method takes 2 tables which are mapped to input1 and input2.
Table[] results = model.transform(...);

Compose an Estimator from a chain of Estimator/Transformer whose input schemas are different from its fitted Transformer 

...

  • The method takes 2 input tables. Both tables are given to EstimatorA::fit.
  • EstimatorA fits the input tables and generates a TransformerA instance. The TransformerA instance takes 1 table input, which is different from the 2 tables given to the EstimatorA.
  • Returns a GraphModel instance which contains a TransformerA instance and a TransformerB instance, which are connected as a chain.

Notes:

...

The fitted GraphModel is represented by the following DAG:

[Image: DAG of the fitted GraphModel]

Here is the code snippet that addresses this use-case by using the proposed APIs:

Code Block
languagejava
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new EstimatorA();
Stage<?> stage2 = new TransformerB();
// Creates inputs
TableId estimatorInput1 = builder.createTableId();
TableId estimatorInput2 = builder.createTableId();
TableId transformerInput1 = builder.createTableId();

// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, new TableId[] {estimatorInput1, estimatorInput2}, new TableId[] {transformerInput1})[0];
TableId output2 = builder.getOutputs(stage2, output1)[0];

// Specifies the ordered lists of estimator inputs, model inputs, outputs, input model data and output model data
// that will be used as the inputs/outputs of the corresponding Graph and GraphModel APIs.
TableId[] estimatorInputs = new TableId[] {estimatorInput1, estimatorInput2};
TableId[] transformerInputs = new TableId[] {transformerInput1};
TableId[] outputs = new TableId[] {output2};
TableId[] inputModelData = new TableId[] {};
TableId[] outputModelData = new TableId[] {};

// Generates the Graph instance.
Estimator<?, ?> estimator = builder.buildEstimator(estimatorInputs, transformerInputs, outputs, inputModelData, outputModelData);
// The fit method takes 2 tables which are mapped to estimatorInput1 and estimatorInput2.
Model<?> model = estimator.fit(...);
// The transform method takes 1 table which is mapped to transformerInput1.
Table[] results = model.transform(...);
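
The comments in the snippet above say that the tables passed to fit and transform are "mapped to" the declared TableIds. One plausible way to read this is positional binding: the i-th table is bound to the i-th TableId. The sketch below illustrates that reading under stated assumptions. The `TableId` class here is a hypothetical re-implementation (the FLIP only lists its field, `equals` and `hashCode`), `bind` is an invented helper, and the type parameter `T` stands in for Flink's Table.

```java
import java.util.HashMap;
import java.util.Map;

final class TableIdBindingSketch {

    /** Hypothetical value-type id; the FLIP only specifies the field, equals and hashCode. */
    static final class TableId {
        private final int tableId;

        TableId(int tableId) {
            this.tableId = tableId;
        }

        @Override
        public boolean equals(Object obj) {
            return obj instanceof TableId && ((TableId) obj).tableId == tableId;
        }

        @Override
        public int hashCode() {
            return Integer.hashCode(tableId);
        }
    }

    /** Binds the i-th runtime table to the i-th declared id; T stands in for Flink's Table. */
    @SafeVarargs
    static <T> Map<TableId, T> bind(TableId[] ids, T... tables) {
        if (ids.length != tables.length) {
            throw new IllegalArgumentException(
                    "Expected " + ids.length + " tables but got " + tables.length);
        }
        Map<TableId, T> bound = new HashMap<>();
        for (int i = 0; i < ids.length; i++) {
            bound.put(ids[i], tables[i]);
        }
        return bound;
    }

    public static void main(String[] args) {
        TableId[] estimatorInputs = {new TableId(0), new TableId(1)};
        Map<TableId, String> bound = bind(estimatorInputs, "trainFeatures", "trainLabels");
        // Lookup works with a fresh but equal TableId because equality is value-based.
        System.out.println(bound.get(new TableId(1)));
        // prints trainLabels
    }
}
```

Value-based `equals` and `hashCode` are what allow a TableId to serve as a map key, so a stage's declared inputs can be resolved to concrete tables when the graph runs. The length check also shows why fit takes 2 tables and transform takes 1 in this example: each call is validated against a different id list.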

Compatibility, Deprecation, and Migration Plan

...