
Page properties


Current state: Not ready for discussion.

Discussion thread: To be added

...

JIRA: FLINK-23959

...

Release: ml-2.0.0


Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).


[This FLIP proposal is joint work between Dong Lin and Zhipeng Zhang]

Motivation and Use-cases

The existing Flink ML library allows users to compose an Estimator/Transformer/AlgoOperator from a pipeline (i.e. a linear sequence) of Estimator/Transformer/AlgoOperator stages. Users only need to construct this Pipeline once and fit it to generate the corresponding PipelineModel; they do not have to explicitly construct the fitted PipelineModel as a linear sequence of stages. However, in order to train a DAG of Estimator/Transformer/AlgoOperator and use the trained model for inference, users currently need to construct the DAG twice: once for the training logic and once for the inference logic. This experience is inferior to that of training and using a chain of Estimator/Transformer/AlgoOperator. In addition to requiring more work from users, this approach is more error-prone, because the DAG for the training logic may be inconsistent with the DAG for the inference logic.

...

Code Block
languagejava
public class GraphNode {
    public final Stage<?> stage;
    public final TableId[] estimatorInputs;
    public final TableId[] algoOpInputs;
    public final TableId[] outputs;
}
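Each GraphNode describes its edges implicitly: a node depends on whichever node produces one of the TableIds listed in its estimatorInputs or algoOpInputs. As a rough illustration of how a topologically-sorted execution order can be derived from such nodes, here is a minimal sketch using Kahn's algorithm. It is not the Flink ML implementation: the `Node` class, the integer ids standing in for TableId, and the string names standing in for Stage are all hypothetical simplifications.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class GraphNodeTopoSort {
    /** Hypothetical stand-in for GraphNode: a stage name plus integer table ids. */
    static final class Node {
        final String stage;
        final int[] estimatorInputs;
        final int[] algoOpInputs;
        final int[] outputs;

        Node(String stage, int[] estimatorInputs, int[] algoOpInputs, int[] outputs) {
            this.stage = stage;
            this.estimatorInputs = estimatorInputs;
            this.algoOpInputs = algoOpInputs;
            this.outputs = outputs;
        }
    }

    /** Returns stage names in an order where every producer precedes its consumers. */
    static List<String> sort(List<Node> nodes) {
        // Map each table id to the index of the node that produces it.
        Map<Integer, Integer> producer = new HashMap<>();
        for (int i = 0; i < nodes.size(); i++) {
            for (int id : nodes.get(i).outputs) {
                producer.put(id, i);
            }
        }
        // Build producer -> consumer edges; graph inputs have no producer and add no edge.
        int[] inDegree = new int[nodes.size()];
        List<List<Integer>> consumers = new ArrayList<>();
        for (int i = 0; i < nodes.size(); i++) {
            consumers.add(new ArrayList<>());
        }
        for (int i = 0; i < nodes.size(); i++) {
            Node n = nodes.get(i);
            for (int[] ids : new int[][] {n.estimatorInputs, n.algoOpInputs}) {
                for (int id : ids) {
                    Integer p = producer.get(id);
                    if (p != null) {
                        consumers.get(p).add(i);
                        inDegree[i]++;
                    }
                }
            }
        }
        // Kahn's algorithm: repeatedly emit a node whose producers have all been emitted.
        Deque<Integer> ready = new ArrayDeque<>();
        for (int i = 0; i < nodes.size(); i++) {
            if (inDegree[i] == 0) {
                ready.add(i);
            }
        }
        List<String> order = new ArrayList<>();
        while (!ready.isEmpty()) {
            int i = ready.poll();
            order.add(nodes.get(i).stage);
            for (int c : consumers.get(i)) {
                if (--inDegree[c] == 0) {
                    ready.add(c);
                }
            }
        }
        if (order.size() != nodes.size()) {
            throw new IllegalArgumentException("The stages do not form a DAG");
        }
        return order;
    }

    public static void main(String[] args) {
        // Two transformers (tables 1 -> 3 and 2 -> 4) feed one estimator (3, 4 -> 5).
        List<Node> nodes = List.of(
                new Node("estimatorB", new int[] {3, 4}, new int[] {3, 4}, new int[] {5}),
                new Node("transformerA1", new int[] {}, new int[] {1}, new int[] {3}),
                new Node("transformerA2", new int[] {}, new int[] {2}, new int[] {4}));
        System.out.println(sort(nodes)); // prints [transformerA1, transformerA2, estimatorB]
    }
}
```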

...

Code Block
languagejava
/**
 * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be an
 * Estimator, Model, Transformer or AlgoOperator. When `Graph::fit` is called, the stages are
 * executed in a topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method
 * will be called on the input tables (from the input edges) to fit a Model. Then the Model will be
 * used to transform the input tables and produce output tables to the output edges. If a stage is
 * an AlgoOperator, its `AlgoOperator::transform` method will be called on the input tables and
 * produce output tables to the output edges. The GraphModel fitted from a Graph consists of the
 * fitted Models and AlgoOperators, corresponding to the Graph's stages.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, GraphModel> {
    public Graph(List<GraphNode> nodes, TableId[] estimatorInputIds, TableId[] algoOpInputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}

    @Override
    public GraphModel fit(Table... inputs) {...}

    @Override
    public void save(String path) throws IOException {...}

    public static Graph load(StreamTableEnvironment tEnv, String path) throws IOException {...}
}
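The fit contract described in the Graph Javadoc can be modeled with a few toy types. The sketch below is hypothetical and deliberately simplified: tables are `List<Double>`, `AlgoOp` and `Estimator` are stand-ins rather than the Flink ML interfaces, and the stages are assumed to be already topologically sorted into a chain with one table per edge. It shows the two key points: an Estimator stage is fitted and its Model then transforms the same tables, so downstream stages still receive inputs during fit(); and the resulting GraphModel holds the fitted Models plus the unchanged AlgoOperators.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleUnaryOperator;
import java.util.function.UnaryOperator;

final class ToyGraphFit {
    /** Toy AlgoOperator or fitted Model (hypothetical): maps one table to one table. */
    interface AlgoOp extends UnaryOperator<List<Double>> {}

    /** Toy Estimator (hypothetical): fits an AlgoOp, standing in for a Model, on one table. */
    interface Estimator {
        AlgoOp fit(List<Double> table);
    }

    /** Applies f to every value of a table. */
    static List<Double> map(List<Double> table, DoubleUnaryOperator f) {
        List<Double> out = new ArrayList<>();
        for (double v : table) {
            out.add(f.applyAsDouble(v));
        }
        return out;
    }

    /** Learns the mean of its training table; the fitted model subtracts that mean. */
    static final Estimator MEAN_CENTER =
            table -> {
                double mean = table.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
                return t -> map(t, v -> v - mean);
            };

    /** Toy Graph::fit over topologically-sorted stages that form a chain. */
    static List<AlgoOp> fit(List<Object> stages, List<Double> input) {
        List<AlgoOp> fittedStages = new ArrayList<>();
        List<Double> current = input;
        for (Object stage : stages) {
            // An Estimator is replaced by the Model it fits; an AlgoOperator is kept as-is.
            AlgoOp op =
                    stage instanceof Estimator ? ((Estimator) stage).fit(current) : (AlgoOp) stage;
            fittedStages.add(op);
            // The stage's output becomes the next stage's input, also during fit().
            current = op.apply(current);
        }
        return fittedStages;
    }

    /** Toy GraphModel::transform: runs the fitted stages in the same order. */
    static List<Double> transform(List<AlgoOp> fittedStages, List<Double> input) {
        List<Double> current = input;
        for (AlgoOp op : fittedStages) {
            current = op.apply(current);
        }
        return current;
    }

    public static void main(String[] args) {
        AlgoOp doubler = t -> map(t, v -> 2 * v);
        // Training data [1, 2, 3] is doubled to [2, 4, 6] before MEAN_CENTER sees it: mean = 4.
        List<AlgoOp> graphModel = fit(List.of(doubler, MEAN_CENTER), List.of(1.0, 2.0, 3.0));
        System.out.println(transform(graphModel, List.of(10.0))); // prints [16.0]
    }
}
```

Here MEAN_CENTER is fitted on the doubled training data, so the toy GraphModel subtracts 4 rather than 2. Because the same stage list drives both fit() and the later transform(), the training and inference logic cannot drift apart.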

...

Code Block
languagejava
/**
 * A GraphModel acts as a Model. A GraphModel consists of a DAG of stages, each of which could be an
 * Estimator, Model, Transformer or AlgoOperator. When `GraphModel::transform` is called, the
 * stages are executed in a topologically-sorted order. When a stage is executed, its
 * `AlgoOperator::transform` method will be called on the input tables (from the input edges) and
 * produce output tables to the output edges.
 */
public final class GraphModel implements Model<GraphModel> {

    public GraphModel(List<GraphNode> nodes, TableId[] inputIds, TableId[] outputIds, TableId[] inputModelData, TableId[] outputModelData) {...}

    @Override
    public Table[] transform(Table... inputTables) {...}

    @Override
    public void setModelData(Table... inputs) {...}

    @Override
    public Table[] getModelData() {...}

    @Override
    public void save(String path) throws IOException {...}

    public static GraphModel load(StreamTableEnvironment tEnv, String path) throws IOException {...}
}
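The table routing inside GraphModel::transform can be sketched by binding every TableId to a table in a map. In this hypothetical sketch, integer ids replace TableId, `double[]` arrays replace Table, and nodes are assumed to be supplied in topological order. It also illustrates one plausible reading of the NOTE about setMaxOutputTableNum in the GraphBuilder Javadoc: a stage may fill fewer output ids than were reserved for it, but must not return more tables than that.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

final class ToyGraphModelTransform {
    /** Hypothetical node: an operator on tables plus the ids of its input and output tables. */
    static final class Node {
        final Function<double[][], double[][]> op;
        final int[] inputIds;
        final int[] outputIds;

        Node(Function<double[][], double[][]> op, int[] inputIds, int[] outputIds) {
            this.op = op;
            this.inputIds = inputIds;
            this.outputIds = outputIds;
        }
    }

    /** Runs topologically-sorted nodes, passing tables between them by id. */
    static double[][] transform(List<Node> nodes, int[] inputIds, int[] outputIds, double[]... inputs) {
        if (inputs.length != inputIds.length) {
            throw new IllegalArgumentException("Expected " + inputIds.length + " input tables");
        }
        Map<Integer, double[]> tables = new HashMap<>();
        for (int i = 0; i < inputIds.length; i++) {
            tables.put(inputIds[i], inputs[i]);
        }
        for (Node node : nodes) {
            double[][] args = new double[node.inputIds.length][];
            for (int i = 0; i < args.length; i++) {
                args[i] = tables.get(node.inputIds[i]);
                if (args[i] == null) {
                    throw new IllegalStateException("Table " + node.inputIds[i] + " is not available");
                }
            }
            double[][] produced = node.op.apply(args);
            // A stage may return fewer tables than reserved ids, but never more.
            if (produced.length > node.outputIds.length) {
                throw new IllegalStateException("Stage returned more tables than reserved output ids");
            }
            for (int i = 0; i < produced.length; i++) {
                tables.put(node.outputIds[i], produced[i]);
            }
        }
        double[][] results = new double[outputIds.length][];
        for (int i = 0; i < outputIds.length; i++) {
            results[i] = tables.get(outputIds[i]);
        }
        return results;
    }

    public static void main(String[] args) {
        // Scales table 1 into table 3; id 5 is reserved but left unused.
        Node scale = new Node(t -> new double[][] {Arrays.stream(t[0]).map(v -> 10 * v).toArray()},
                new int[] {1}, new int[] {3, 5});
        // Adds tables 3 and 2 element-wise into table 4.
        Node add = new Node(t -> {
            double[] sum = new double[t[0].length];
            for (int i = 0; i < sum.length; i++) {
                sum[i] = t[0][i] + t[1][i];
            }
            return new double[][] {sum};
        }, new int[] {3, 2}, new int[] {4});
        double[][] out = transform(List.of(scale, add), new int[] {1, 2}, new int[] {4},
                new double[] {1, 2}, new double[] {0.5, 0.5});
        System.out.println(Arrays.toString(out[0])); // prints [10.5, 20.5]
    }
}
```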

...

Code Block
languagejava
/**
 * A GraphBuilder provides APIs to build Estimator/Model/AlgoOperator from a DAG of stages, each of
 * which could be an Estimator, Model, Transformer or AlgoOperator.
 */
@PublicEvolving
public final class GraphBuilder {
    private int maxOutputLength = 20;

    public GraphBuilder() {}

    /**
     * Specifies a loose upper bound on the number of output tables that can be returned by the
     * Model::getModelData() and AlgoOperator::transform() methods, for any stage involved in this
     * Graph.
     *
     * <p>The default upper bound is 20.
     */
    public GraphBuilder setMaxOutputTableNum(int maxOutputLength) {...}

    /**
     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing of
     * tables between stages, as well as the input/output tables of the Graph/GraphModel generated
     * by this builder.
     *
     * @return A TableId.
     */
    public TableId createTableId() {...}

    /**
     * Adds an AlgoOperator to the graph.
     *
     * <p>When the graph runs as Estimator, the transform() of the given AlgoOperator would be
     * invoked with the given inputs. Then when the GraphModel fitted by this graph runs, the
     * transform() of the given AlgoOperator would be invoked with the given inputs.
     *
     * <p>When the graph runs as AlgoOperator or Model, the transform() of the given AlgoOperator
     * would be invoked with the given inputs.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by transform(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by transform().
     *
     * @param algoOp An AlgoOperator instance.
     * @param inputs A list of TableIds which represents inputs to transform() of the given
     *     AlgoOperator.
     * @return A list of TableIds which represents the outputs of transform() of the given
     *     AlgoOperator.
     */
    public TableId[] addAlgoOperator(AlgoOperator<?> algoOp, TableId... inputs) {...}

    /**
     * Adds an Estimator to the graph.
     *
     * <p>When the graph runs as Estimator, the fit() of the given Estimator would be invoked with
     * the given inputs. Then when the GraphModel fitted by this graph runs, the transform() of the
     * Model fitted by the given Estimator would be invoked with the given inputs.
     *
     * <p>When the graph runs as AlgoOperator or Model, the fit() of the given Estimator would be
     * invoked with the given inputs, then the transform() of the Model fitted by the given
     * Estimator would be invoked with the given inputs.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by transform(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by transform().
     *
     * @param estimator An Estimator instance.
     * @param inputs A list of TableIds which represents inputs to fit() of the given Estimator as
     *     well as inputs to transform() of the Model fitted by the given Estimator.
     * @return A list of TableIds which represents the outputs of transform() of the Model fitted by
     *     the given Estimator.
     */
    public TableId[] addEstimator(Estimator<?, ?> estimator, TableId... inputs) {...}

    /**
     * Adds an Estimator to the graph.
     *
     * <p>When the graph runs as Estimator, the fit() of the given Estimator would be invoked with
     * estimatorInputs. Then when the GraphModel fitted by this graph runs, the transform() of the
     * Model fitted by the given Estimator would be invoked with modelInputs.
     *
     * <p>When the graph runs as AlgoOperator or Model, the fit() of the given Estimator would be
     * invoked with estimatorInputs, then the transform() of the Model fitted by the given Estimator
     * would be invoked with modelInputs.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by transform(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by transform().
     *
     * @param estimator An Estimator instance.
     * @param estimatorInputs A list of TableIds which represents inputs to fit() of the given
     *     Estimator.
     * @param modelInputs A list of TableIds which represents inputs to transform() of the Model
     *     fitted by the given Estimator.
     * @return A list of TableIds which represents the outputs of transform() of the Model fitted by
     *     the given Estimator.
     */
    public TableId[] addEstimator(
            Estimator<?, ?> estimator, TableId[] estimatorInputs, TableId[] modelInputs) {...}

    /**
     * When the graph runs as Estimator, it first generates a GraphModel that contains the Model
     * fitted by the given Estimator. Then when this GraphModel runs, the setModelData() of the
     * fitted Model would be invoked with the given inputs before its transform() is invoked.
     *
     * <p>When the graph runs as AlgoOperator or Model, the setModelData() of the Model fitted by
     * the given Estimator would be invoked with the given inputs before its transform() is invoked.
     *
     * @param estimator An Estimator instance.
     * @param inputs A list of TableIds which represents inputs to setModelData() of the Model
     *     fitted by the given Estimator.
     */
    public void setModelDataOnEstimator(Estimator<?, ?> estimator, TableId... inputs) {...}

    /**
     * When the graph runs as Estimator, the setModelData() of the given Model would be invoked with
     * the given inputs before its transform() is invoked. Then when the GraphModel fitted by this
     * graph runs, the setModelData() of the given Model would be invoked with the given inputs.
     *
     * <p>When the graph runs as AlgoOperator or Model, the setModelData() of the given Model would
     * be invoked with the given inputs before its transform() is invoked.
     *
     * @param model A Model instance.
     * @param inputs A list of TableIds which represents inputs to setModelData() of the given
     *     Model.
     */
    public void setModelDataOnModel(Model<?> model, TableId... inputs) {...}

    /**
     * When the graph runs as Estimator, it first generates a GraphModel that contains the Model
     * fitted by the given Estimator. Then when this GraphModel runs, the getModelData() of the
     * fitted Model would be invoked.
     *
     * <p>When the graph runs as AlgoOperator or Model, the getModelData() of the Model fitted by
     * the given Estimator would be invoked.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by getModelData(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by getModelData().
     *
     * @param estimator An Estimator instance.
     * @return A list of TableIds which represents the outputs of getModelData() of the Model fitted
     *     by the given Estimator.
     */
    public TableId[] getModelDataFromEstimator(Estimator<?, ?> estimator) {...}

    /**
     * When the graph runs as Estimator, the getModelData() of the given Model would be invoked.
     * Then when the GraphModel fitted by this graph runs, the getModelData() of the given Model
     * would be invoked.
     *
     * <p>When the graph runs as AlgoOperator or Model, the getModelData() of the given Model would
     * be invoked.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by getModelData(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by getModelData().
     *
     * @param model A Model instance.
     * @return A list of TableIds which represents the outputs of getModelData() of the given Model.
     */
    public TableId[] getModelDataFromModel(Model<?> model) {...}

    /**
     * Wraps nodes of the graph into an Estimator.
     *
     * <p>When the returned Estimator runs, and when the Model fitted by the returned Estimator
     * runs, the sequence of operations recorded by the {@code addAlgoOperator(...)}, {@code
     * addEstimator(...)}, {@code setModelData(...)} and {@code getModelData(...)} would be executed
     * as specified in the Java doc of the corresponding methods.
     *
     * @param inputs A list of TableIds which represents inputs to fit() of the returned Estimator
     *     as well as inputs to transform() of the Model fitted by the returned Estimator.
     * @param outputs A list of TableIds which represents outputs of transform() of the Model fitted
     *     by the returned Estimator.
     * @return An Estimator which wraps the nodes of this graph.
     */
    public Estimator<?, ?> buildEstimator(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Wraps nodes of the graph into an Estimator.
     *
     * <p>When the returned Estimator runs, and when the Model fitted by the returned Estimator
     * runs, the sequence of operations recorded by the {@code addAlgoOperator(...)}, {@code
     * addEstimator(...)}, {@code setModelData(...)} and {@code getModelData(...)} would be executed
     * as specified in the Java doc of the corresponding methods.
     *
     * @param inputs A list of TableIds which represents inputs to fit() of the returned Estimator
     *     as well as inputs to transform() of the Model fitted by the returned Estimator.
     * @param outputs A list of TableIds which represents outputs of transform() of the Model fitted
     *     by the returned Estimator.
     * @param inputModelData A list of TableIds which represents inputs to setModelData() of the
     *     Model fitted by the returned Estimator.
     * @param outputModelData A list of TableIds which represents outputs of getModelData() of the
     *     Model fitted by the returned Estimator.
     * @return An Estimator which wraps the nodes of this graph.
     */
    public Estimator<?, ?> buildEstimator(
            TableId[] inputs,
            TableId[] outputs,
            TableId[] inputModelData,
            TableId[] outputModelData) {...}

    /**
     * Wraps nodes of the graph into an Estimator.
     *
     * <p>When the returned Estimator runs, and when the Model fitted by the returned Estimator
     * runs, the sequence of operations recorded by the {@code addAlgoOperator(...)}, {@code
     * addEstimator(...)}, {@code setModelData(...)} and {@code getModelData(...)} would be executed
     * as specified in the Java doc of the corresponding methods.
     *
     * @param estimatorInputs A list of TableIds which represents inputs to fit() of the returned
     *     Estimator.
     * @param modelInputs A list of TableIds which represents inputs to transform() of the Model
     *     fitted by the returned Estimator.
     * @param outputs A list of TableIds which represents outputs of transform() of the Model fitted
     *     by the returned Estimator.
     * @param inputModelData A list of TableIds which represents inputs to setModelData() of the
     *     Model fitted by the returned Estimator.
     * @param outputModelData A list of TableIds which represents outputs of getModelData() of the
     *     Model fitted by the returned Estimator.
     * @return An Estimator which wraps the nodes of this graph.
     */
    public Estimator<?, ?> buildEstimator(
            TableId[] estimatorInputs,
            TableId[] modelInputs,
            TableId[] outputs,
            TableId[] inputModelData,
            TableId[] outputModelData) {...}

    /**
     * Wraps nodes of the graph into an AlgoOperator.
     *
     * <p>When the returned AlgoOperator runs, the sequence of operations recorded by the {@code
     * addAlgoOperator(...)} and {@code addEstimator(...)} would be executed as specified in the
     * Java doc of the corresponding methods.
     *
     * @param inputs A list of TableIds which represents inputs to transform() of the returned
     *     AlgoOperator.
     * @param outputs A list of TableIds which represents outputs of transform() of the returned
     *     AlgoOperator.
     * @return An AlgoOperator which wraps the nodes of this graph.
     */
    public AlgoOperator<?> buildAlgoOperator(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Wraps nodes of the graph into a Model.
     *
     * <p>When the returned Model runs, the sequence of operations recorded by the {@code
     * addAlgoOperator(...)} and {@code addEstimator(...)} would be executed as specified in the
     * Java doc of the corresponding methods.
     *
     * @param inputs A list of TableIds which represents inputs to transform() of the returned
     *     Model.
     * @param outputs A list of TableIds which represents outputs of transform() of the returned
     *     Model.
     * @return A Model which wraps the nodes of this graph.
     */
    public Model<?> buildModel(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Wraps nodes of the graph into a Model.
     *
     * <p>When the returned Model runs, the sequence of operations recorded by the {@code
     * addAlgoOperator(...)}, {@code addEstimator(...)}, {@code setModelData(...)} and {@code
     * getModelData(...)} would be executed as specified in the Java doc of the corresponding
     * methods.
     *
     * @param inputs A list of TableIds which represents inputs to transform() of the returned
     *     Model.
     * @param outputs A list of TableIds which represents outputs of transform() of the returned
     *     Model.
     * @param inputModelData A list of TableIds which represents inputs to setModelData() of the
     *     returned Model.
     * @param outputModelData A list of TableIds which represents outputs of getModelData() of the
     *     returned Model.
     * @return A Model which wraps the nodes of this graph.
     */
    public Model<?> buildModel(
            TableId[] inputs,
            TableId[] outputs,
            TableId[] inputModelData,
            TableId[] outputModelData) {...}
}
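The relationship between createTableId(), setMaxOutputTableNum(int) and the TableId[] returned by addAlgoOperator(...) can be illustrated with a hypothetical toy builder. It only records calls and hands out integer-backed ids; it executes nothing, and the class, its string stage names and the recording format are assumptions rather than part of the proposed API. Every call reserves maxOutputTableNum fresh ids regardless of how many tables the stage will actually produce, which is why the Javadoc asks users to keep that bound at least as large as the real output count.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class ToyGraphBuilder {
    /** Hypothetical TableId: a unique integer issued by one builder. */
    static final class TableId {
        final int id;

        TableId(int id) {
            this.id = id;
        }

        @Override
        public String toString() {
            return "t" + id;
        }
    }

    private int maxOutputTableNum = 20;
    private int nextId = 0;
    /** One entry per recorded call, e.g. "transformerA[t0] -> [t1, t2]". */
    final List<String> recorded = new ArrayList<>();

    ToyGraphBuilder setMaxOutputTableNum(int maxOutputTableNum) {
        this.maxOutputTableNum = maxOutputTableNum;
        return this;
    }

    TableId createTableId() {
        return new TableId(nextId++);
    }

    /** Records the call and reserves maxOutputTableNum output ids for the stage. */
    TableId[] addAlgoOperator(String algoOpName, TableId... inputs) {
        TableId[] outputs = new TableId[maxOutputTableNum];
        for (int i = 0; i < outputs.length; i++) {
            outputs[i] = createTableId();
        }
        recorded.add(algoOpName + Arrays.toString(inputs) + " -> " + Arrays.toString(outputs));
        return outputs;
    }

    public static void main(String[] args) {
        ToyGraphBuilder builder = new ToyGraphBuilder().setMaxOutputTableNum(2);
        TableId input1 = builder.createTableId();
        TableId output1 = builder.addAlgoOperator("transformerA", input1)[0];
        builder.addAlgoOperator("transformerB", output1);
        // Prints "transformerA[t0] -> [t1, t2]" then "transformerB[t1] -> [t3, t4]".
        builder.recorded.forEach(System.out::println);
    }
}
```

Because a TableId must already exist before it can be passed as an input, the order in which this toy records its calls is itself a valid topological order of the stages.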

...

Code Block
languagejava
GraphBuilder builder = new GraphBuilder();

// Creates nodes
AlgoOperator<?> stage1 = new TransformerA();
AlgoOperator<?> stage2 = new TransformerA();
Estimator<?, ?> stage3 = new EstimatorB();
// Creates inputs
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.addAlgoOperator(stage1, input1)[0];
TableId output2 = builder.addAlgoOperator(stage2, input2)[0];
TableId output3 = builder.addEstimator(stage3, output1, output2)[0];

// Specifies the ordered lists of inputs and outputs that will be used as the inputs/outputs of
// the generated Estimator and of the Model it fits.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};

// Generates the Graph instance.
Estimator<?, ?> estimator = builder.buildEstimator(inputs, outputs);
// The fit method takes 2 tables which are mapped to input1 and input2.
Model<?> model = estimator.fit(...);
// The transform method takes 2 tables which are mapped to input1 and input2.
Table[] results = model.transform(...);

...

Code Block
languagejava
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Estimator<?, ?> stage1 = new EstimatorA();
AlgoOperator<?> stage2 = new TransformerB();
// Creates inputs
TableId estimatorInput1 = builder.createTableId();
TableId estimatorInput2 = builder.createTableId();
TableId transformerInput1 = builder.createTableId();

// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.addEstimator(stage1, new TableId[] {estimatorInput1, estimatorInput2}, new TableId[] {transformerInput1})[0];
TableId output2 = builder.addAlgoOperator(stage2, output1)[0];

// Specifies the ordered lists of estimator inputs, transformer inputs, outputs, input model data
// and output model data that will be used as the inputs/outputs of the generated Estimator and of
// the Model it fits.
TableId[] estimatorInputs = new TableId[] {estimatorInput1, estimatorInput2};
TableId[] transformerInputs = new TableId[] {transformerInput1};
TableId[] outputs = new TableId[] {output2};
TableId[] inputModelData = new TableId[] {};
TableId[] outputModelData = new TableId[] {};

// Generates the Graph instance.
Estimator<?, ?> estimator = builder.buildEstimator(estimatorInputs, transformerInputs, outputs, inputModelData, outputModelData);
// The fit method takes 2 tables which are mapped to estimatorInput1 and estimatorInput2.
Model<?> model = estimator.fit(...);
// The transform method takes 1 table which is mapped to transformerInput1.
Table[] results = model.transform(...);
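The point of the separate estimatorInputs and modelInputs lists is that fit() and the fitted Model's transform() may read different tables. The toy sketch below, in which every name is hypothetical and plain lists stand in for Flink tables, mirrors the example above: the estimator learns a statistic from two training tables, and the model it returns is then applied to a third, unrelated table.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

final class ToySplitInputs {
    /**
     * Toy Estimator::fit over several training tables: learns the mean of all their values and
     * returns a toy Model that subtracts it.
     */
    static UnaryOperator<List<Double>> fitMeanCenter(List<List<Double>> estimatorInputs) {
        double sum = 0;
        int count = 0;
        for (List<Double> table : estimatorInputs) {
            for (double v : table) {
                sum += v;
                count++;
            }
        }
        double mean = count == 0 ? 0.0 : sum / count;
        return table -> {
            List<Double> out = new ArrayList<>();
            for (double v : table) {
                out.add(v - mean);
            }
            return out;
        };
    }

    public static void main(String[] args) {
        // fit() sees only the two estimator inputs: the mean of {1, 3, 5} is 3.
        UnaryOperator<List<Double>> model = fitMeanCenter(List.of(List.of(1.0, 3.0), List.of(5.0)));
        // The fitted Model's transform() sees only the separate model input.
        System.out.println(model.apply(List.of(3.0, 4.0))); // prints [0.0, 1.0]
    }
}
```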

...