
Page properties


Current state: "Under Discussion"

Discussion thread: To be added

...

JIRA: FLINK-23959

...

Release: ml-2.0.0


Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).


[This FLIP proposal is a joint work between Dong Lin and Zhipeng Zhang]

Motivation and Use-cases

The existing Flink ML library allows users to compose an Estimator/Transformer/AlgoOperator from a pipeline (i.e. linear sequence) of Estimator/Transformer/AlgoOperator. Users only need to construct this Pipeline once and generate the corresponding PipelineModel, without having to explicitly construct the fitted PipelineModel as a linear sequence of stages. However, in order to train a DAG of Estimator/Transformer/AlgoOperator and use the trained model for inference, users currently need to construct the DAG twice, once for the training logic and once for the inference logic. This experience is inferior to the experience of training and using a chain of Estimator/Transformer/AlgoOperator. In addition to requiring more work from users, this approach is more error-prone because the DAG for the training logic may be inconsistent with the DAG for the inference logic.

To improve the user experience, we propose to add several helper classes that allow users to compose an Estimator/Transformer/AlgoOperator from a DAG of Estimator/Transformer/AlgoOperator.

Public Interfaces

This FLIP proposes to add the TableId, GraphNode, Graph, GraphModel and GraphBuilder classes. The following code blocks show the public APIs of these classes.

1) Add the TableId class to represent the input/output of a stage.

This class is needed to construct the DAG before the concrete Tables are available. It also overrides equals/hashCode so that it can be used as the key of a hash map.

Code Block
languagejava
public class TableId {
    private final int tableId;

    @Override
    public boolean equals(Object obj) {...}

    @Override
    public int hashCode() {...}
}
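The FLIP leaves the equals/hashCode bodies elided. The following is a minimal sketch, assuming a TableId is identified solely by its integer id; the constructor is hypothetical (in the proposal, TableIds are obtained from GraphBuilder::createTableId), and String stands in for a Flink Table. It shows why value-based equality lets a TableId serve as a hash-map key that is bound to a concrete table later.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch, not Flink ML code: TableId equality is assumed to depend only on the
// integer id, and String stands in for a Flink Table.
final class TableIdSketch {
    static final class TableId {
        private final int tableId;

        // Hypothetical constructor; in the proposal, ids come from GraphBuilder::createTableId.
        TableId(int tableId) {
            this.tableId = tableId;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof TableId)) {
                return false;
            }
            return tableId == ((TableId) obj).tableId;
        }

        // Must agree with equals so that equal ids land in the same hash bucket.
        @Override
        public int hashCode() {
            return Integer.hashCode(tableId);
        }
    }

    // Binds a placeholder id to a concrete table once the table exists, then looks it up
    // through a different but equal TableId object.
    static String demo() {
        Map<TableId, String> boundTables = new HashMap<>();
        boundTables.put(new TableId(0), "trainingTable");
        return boundTables.get(new TableId(0));
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints trainingTable
    }
}
```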


2) Add the GraphNode class.

This class contains the stage as well as the input/output of this stage in the form of TableId lists. A DAG can thus be represented as a list of GraphNodes.

Code Block
languagejava
public class GraphNode {
    public final Stage<?> stage;
    public final TableId[] estimatorInputs;
    public final TableId[] algoOpInputs;
    public final TableId[] outputs;
}
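Because every GraphNode names its input and output TableIds, an execution order can be derived from the node list alone. The sketch below illustrates this under simplifying assumptions that are not part of the proposal: stages are String names, TableIds are Integers, and each node has a single input list instead of separate estimatorInputs/algoOpInputs. It repeatedly runs a node whose inputs are all available, which yields a topologically sorted order or fails on a cycle.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical sketch, not Flink ML code: stages are named by Strings, TableIds are plain
// Integers, and each node has a single input list (the real GraphNode separates
// estimatorInputs from algoOpInputs).
final class GraphNodeOrderSketch {
    static final class Node {
        final String stage;
        final List<Integer> inputs;
        final List<Integer> outputs;

        Node(String stage, List<Integer> inputs, List<Integer> outputs) {
            this.stage = stage;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    // Repeatedly picks a pending node whose input tables are all available, "runs" it
    // (here: records its name), and marks its output tables as available.
    static List<String> executionOrder(List<Node> nodes, List<Integer> graphInputs) {
        Set<Integer> available = new HashSet<>(graphInputs);
        List<Node> pending = new ArrayList<>(nodes);
        List<String> order = new ArrayList<>();
        while (!pending.isEmpty()) {
            Node ready = null;
            for (Node node : pending) {
                if (available.containsAll(node.inputs)) {
                    ready = node;
                    break;
                }
            }
            if (ready == null) {
                throw new IllegalStateException("The nodes contain a cycle or an unbound input");
            }
            pending.remove(ready);
            order.add(ready.stage);
            available.addAll(ready.outputs);
        }
        return order;
    }

    public static void main(String[] args) {
        // Two TransformerA nodes feed one EstimatorB node; the list order is deliberately scrambled.
        List<Node> nodes = List.of(
                new Node("estimatorB", List.of(2, 3), List.of(4)),
                new Node("transformerA1", List.of(0), List.of(2)),
                new Node("transformerA2", List.of(1), List.of(3)));
        System.out.println(executionOrder(nodes, List.of(0, 1)));
        // prints [transformerA1, transformerA2, estimatorB]
    }
}
```

The scan is quadratic in the number of nodes, which is fine for small pipelines; an in-degree queue (Kahn's algorithm) would avoid rescanning the pending list for large graphs.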


3) Add the Graph class to wrap a DAG of Estimator/Model/Transformer/AlgoOperator into an Estimator.

Code Block
languagejava
/**
 * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be an
 * Estimator, Model, Transformer or AlgoOperator. When `Graph::fit` is called, the stages are
 * executed in a topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method
 * will be called on the input tables (from the input edges) to fit a Model. Then the Model will be
 * used to transform the input tables and produce output tables to the output edges. If a stage is
 * an AlgoOperator, its `AlgoOperator::transform` method will be called on the input tables and
 * produce output tables to the output edges. The GraphModel fitted from a Graph consists of the
 * fitted Models and AlgoOperators, corresponding to the Graph's stages.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, GraphModel> {
    public Graph(List<GraphNode> nodes, TableId[] estimatorInputIds, TableId[] algoOpInputs, TableId[] outputs, TableId[] inputModelData, TableId[] outputModelData) {...}

    @Override
    public GraphModel fit(Table... inputs) {...}

    @Override
    public void save(String path) throws IOException {...}

    public static Graph load(StreamTableEnvironment tEnv, String path) throws IOException {...}
}
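The fit semantics in the Javadoc above can be illustrated with a small, Flink-free model. This is a hypothetical sketch, not the proposed implementation: tables are double[] columns, AlgoOp and Estimator are stand-in interfaces invented for the sketch, and the nodes are assumed to be in topological order already. Each Estimator is fitted on its input tables, the resulting model transforms those same tables, and the returned list of fitted stages is what a GraphModel would then hold.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the Graph::fit semantics, not Flink ML code: a "table" is a
// double[] column, AlgoOp/Estimator are stand-in interfaces, and nodes are assumed to be
// listed in topological order already.
final class GraphFitSketch {
    interface AlgoOp {
        double[][] transform(double[]... inputs);
    }

    interface Estimator {
        // The returned AlgoOp plays the role of the fitted Model.
        AlgoOp fit(double[]... inputs);
    }

    static final class Node {
        final Object stage; // either an Estimator or an AlgoOp
        final int[] inputs;
        final int[] outputs;

        Node(Object stage, int[] inputs, int[] outputs) {
            this.stage = stage;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    // Runs every node once. Estimators are fitted on their input tables and the fitted model
    // then transforms those same tables; AlgoOps simply transform. Returns the stages of the
    // would-be GraphModel (fitted models and the original AlgoOps), in node order.
    static List<AlgoOp> fit(List<Node> nodes, Map<Integer, double[]> tables) {
        List<AlgoOp> fittedStages = new ArrayList<>();
        for (Node node : nodes) {
            double[][] in = new double[node.inputs.length][];
            for (int i = 0; i < in.length; i++) {
                in[i] = tables.get(node.inputs[i]);
            }
            AlgoOp op = node.stage instanceof Estimator
                    ? ((Estimator) node.stage).fit(in)
                    : (AlgoOp) node.stage;
            double[][] out = op.transform(in);
            for (int i = 0; i < out.length; i++) {
                tables.put(node.outputs[i], out[i]);
            }
            fittedStages.add(op);
        }
        return fittedStages;
    }

    // Graph: table 0 --doubler--> table 1 --mean-centering estimator--> table 2.
    static double[] demo() {
        AlgoOp doubler = in -> new double[][] {Arrays.stream(in[0]).map(v -> 2 * v).toArray()};
        Estimator centerer = in -> {
            double mean = Arrays.stream(in[0]).average().orElse(0);
            return x -> new double[][] {Arrays.stream(x[0]).map(v -> v - mean).toArray()};
        };
        Map<Integer, double[]> tables = new HashMap<>();
        tables.put(0, new double[] {1, 2, 3});
        fit(List.of(
                new Node(doubler, new int[] {0}, new int[] {1}),
                new Node(centerer, new int[] {1}, new int[] {2})), tables);
        return tables.get(2);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(demo())); // prints [-2.0, 0.0, 2.0]
    }
}
```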


4) Add the GraphModel class to wrap a DAG of Estimator/Model/Transformer/AlgoOperator into a Model.

Code Block
languagejava
/**
 * A GraphModel acts as a Model. A GraphModel consists of a DAG of stages, each of which could be an
 * Estimator, Model, Transformer or AlgoOperator. When `GraphModel::transform` is called, the
 * stages are executed in a topologically-sorted order. When a stage is executed, its
 * `AlgoOperator::transform` method will be called on the input tables (from the input edges) and
 * produce output tables to the output edges.
 */
public final class GraphModel implements Model<GraphModel> {

    public GraphModel(List<GraphNode> nodes, TableId[] inputIds, TableId[] outputIds, TableId[] inputModelData, TableId[] outputModelData) {...}

    @Override
    public Table[] transform(Table... inputTables) {...}

    @Override
    public void setModelData(Table... inputs) {...}

    @Override
    public Table[] getModelData() {...}

    @Override
    public void save(String path) throws IOException {...}

    public static GraphModel load(StreamTableEnvironment tEnv, String path) throws IOException {...}
}
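At inference time, a GraphModel only needs to bind the positional input tables to its input TableIds, run each already-fitted stage in order, and read the results back by output TableId. A hypothetical sketch under simplifying assumptions (double[] tables, one input and one output table per stage, nodes already topologically sorted, int ids in place of TableId):

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.UnaryOperator;

// Hypothetical sketch of GraphModel::transform, not Flink ML code. Tables are double[]
// columns, each already-fitted stage has exactly one input and one output table, and the
// nodes are listed in topological order.
final class GraphModelTransformSketch {
    static final class Node {
        final UnaryOperator<double[]> stage;
        final int inputId;
        final int outputId;

        Node(UnaryOperator<double[]> stage, int inputId, int outputId) {
            this.stage = stage;
            this.inputId = inputId;
            this.outputId = outputId;
        }
    }

    // Binds the i-th input table to inputIds[i], runs each stage, and returns the tables
    // bound to outputIds, in that order.
    static double[][] transform(
            List<Node> nodes, int[] inputIds, int[] outputIds, double[]... inputTables) {
        if (inputTables.length != inputIds.length) {
            throw new IllegalArgumentException("Expected " + inputIds.length + " input tables");
        }
        Map<Integer, double[]> tables = new HashMap<>();
        for (int i = 0; i < inputIds.length; i++) {
            tables.put(inputIds[i], inputTables[i]);
        }
        for (Node node : nodes) {
            tables.put(node.outputId, node.stage.apply(tables.get(node.inputId)));
        }
        double[][] results = new double[outputIds.length][];
        for (int i = 0; i < outputIds.length; i++) {
            results[i] = tables.get(outputIds[i]);
        }
        return results;
    }

    public static void main(String[] args) {
        // A fitted chain that doubles its input and then subtracts a learned mean of 4.
        List<Node> nodes = List.of(
                new Node(t -> Arrays.stream(t).map(v -> 2 * v).toArray(), 0, 1),
                new Node(t -> Arrays.stream(t).map(v -> v - 4).toArray(), 1, 2));
        double[][] out = transform(nodes, new int[] {0}, new int[] {2}, new double[] {5});
        System.out.println(Arrays.toString(out[0])); // prints [6.0]
    }
}
```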


5) Add the GraphBuilder class to build GraphModel or Graph from a DAG of stages.

Code Block
languagejava
/**
 * A GraphBuilder provides APIs to build Estimator/Model/AlgoOperator from a DAG of stages, each of
 * which could be an Estimator, Model, Transformer or AlgoOperator.
 */
@PublicEvolving
public final class GraphBuilder {     

    /**
     * Specifies the loose upper bound of the number of output tables that can be returned by the
     * Model::getModelData() and AlgoOperator::transform() methods, for any stage involved in this
     * Graph.
     *
     * <p>The default upper bound is 20.
     */
    public GraphBuilder setMaxOutputTableNum(int maxOutputTableNum) {...}

    /**
     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing of
     * tables between stages, as well as the input/output tables of the Graph/GraphModel generated
     * by this builder.
     *
     * @return A TableId.
     */
    public TableId createTableId() {...}

    /**
     * Adds an AlgoOperator to the graph.
     *
     * <p>When the graph runs as Estimator, the transform() of the given AlgoOperator would be
     * invoked with the given inputs. Then when the GraphModel fitted by this graph runs, the
     * transform() of the given AlgoOperator would be invoked with the given inputs.
     *
     * <p>When the graph runs as AlgoOperator or Model, the transform() of the given AlgoOperator
     * would be invoked with the given inputs.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by transform(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by transform().
     *
     * @param algoOp An AlgoOperator instance.
     * @param inputs A list of TableIds which represents inputs to transform() of the given
     *     AlgoOperator.
     * @return A list of TableIds which represents the outputs of transform() of the given
     *     AlgoOperator.
     */
    public TableId[] addAlgoOperator(AlgoOperator<?> algoOp, TableId... inputs) {...}

    /**
     * Adds an Estimator to the graph.
     *
     * <p>When the graph runs as Estimator, the fit() of the given Estimator would be invoked with
     * the given inputs. Then when the GraphModel fitted by this graph runs, the transform() of the
     * Model fitted by the given Estimator would be invoked with the given inputs.
     *
     * <p>When the graph runs as AlgoOperator or Model, the fit() of the given Estimator would be
     * invoked with the given inputs, then the transform() of the Model fitted by the given
     * Estimator would be invoked with the given inputs.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by transform(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by transform().
     *
     * @param estimator An Estimator instance.
     * @param inputs A list of TableIds which represents inputs to fit() of the given Estimator as
     *     well as inputs to transform() of the Model fitted by the given Estimator.
     * @return A list of TableIds which represents the outputs of transform() of the Model fitted by
     *     the given Estimator.
     */
    public TableId[] addEstimator(Estimator<?, ?> estimator, TableId... inputs) {...}

    /**
     * Adds an Estimator to the graph.
     *
     * <p>When the graph runs as Estimator, the fit() of the given Estimator would be invoked with
     * estimatorInputs. Then when the GraphModel fitted by this graph runs, the transform() of the
     * Model fitted by the given Estimator would be invoked with modelInputs.
     *
     * <p>When the graph runs as AlgoOperator or Model, the fit() of the given Estimator would be
     * invoked with estimatorInputs, then the transform() of the Model fitted by the given Estimator
     * would be invoked with modelInputs.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by transform(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by transform().
     *
     * @param estimator An Estimator instance.
     * @param estimatorInputs A list of TableIds which represents inputs to fit() of the given
     *     Estimator.
     * @param modelInputs A list of TableIds which represents inputs to transform() of the Model
     *     fitted by the given Estimator.
     * @return A list of TableIds which represents the outputs of transform() of the Model fitted by
     *     the given Estimator.
     */
    public TableId[] addEstimator(
            Estimator<?, ?> estimator, TableId[] estimatorInputs, TableId[] modelInputs) {...}

    /**
     * When the graph runs as Estimator, it first generates a GraphModel that contains the Model
     * fitted by the given Estimator. Then when this GraphModel runs, the setModelData() of the
     * fitted Model would be invoked with the given inputs before its transform() is invoked.
     *
     * <p>When the graph runs as AlgoOperator or Model, the setModelData() of the Model fitted by
     * the given Estimator would be invoked with the given inputs before its transform() is invoked.
     *
     * @param estimator An Estimator instance.
     * @param inputs A list of TableIds which represents inputs to setModelData() of the Model
     *     fitted by the given Estimator.
     */
    public void setModelDataOnEstimator(Estimator<?, ?> estimator, TableId... inputs) {...}

    /**
     * When the graph runs as Estimator, the setModelData() of the given Model would be invoked with
     * the given inputs before its transform() is invoked. Then when the GraphModel fitted by this
     * graph runs, the setModelData() of the given Model would be invoked with the given inputs.
     *
     * <p>When the graph runs as AlgoOperator or Model, the setModelData() of the given Model would
     * be invoked with the given inputs before its transform() is invoked.
     *
     * @param model A Model instance.
     * @param inputs A list of TableIds which represents inputs to setModelData() of the given
     *     Model.
     */
    public void setModelDataOnModel(Model<?> model, TableId... inputs) {...}

    /**
     * When the graph runs as Estimator, it first generates a GraphModel that contains the Model
     * fitted by the given Estimator. Then when this GraphModel runs, the getModelData() of the
     * fitted Model would be invoked.
     *
     * <p>When the graph runs as AlgoOperator or Model, the getModelData() of the Model fitted by
     * the given Estimator would be invoked.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by getModelData(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by getModelData().
     *
     * @param estimator An Estimator instance.
     * @return A list of TableIds which represents the outputs of getModelData() of the Model fitted
     *     by the given Estimator.
     */
    public TableId[] getModelDataFromEstimator(Estimator<?, ?> estimator) {...}

    /**
     * When the graph runs as Estimator, the getModelData() of the given Model would be invoked.
     * Then when the GraphModel fitted by this graph runs, the getModelData() of the given Model
     * would be invoked.
     *
     * <p>When the graph runs as AlgoOperator or Model, the getModelData() of the given Model would
     * be invoked.
     *
     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
     * outputted by getModelData(). This number could be configured using {@link
     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
     * Tables outputted by getModelData().
     *
     * @param model A Model instance.
     * @return A list of TableIds which represents the outputs of getModelData() of the given Model.
     */
    public TableId[] getModelDataFromModel(Model<?> model) {...}

    /**
     * Wraps nodes of the graph into an Estimator.
     *
     * <p>When the returned Estimator runs, and when the Model fitted by the returned Estimator
     * runs, the sequence of operations recorded by the {@code addAlgoOperator(...)}, {@code
     * addEstimator(...)}, {@code setModelData(...)} and {@code getModelData(...)} would be executed
     * as specified in the Java doc of the corresponding methods.
     *
     * @param inputs A list of TableIds which represents inputs to fit() of the returned Estimator
     *     as well as inputs to transform() of the Model fitted by the returned Estimator.
     * @param outputs A list of TableIds which represents outputs of transform() of the Model fitted
     *     by the returned Estimator.
     * @return An Estimator which wraps the nodes of this graph.
     */
    public Estimator<?, ?> buildEstimator(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Wraps nodes of the graph into an Estimator.
     *
     * <p>When the returned Estimator runs, and when the Model fitted by the returned Estimator
     * runs, the sequence of operations recorded by the {@code addAlgoOperator(...)}, {@code
     * addEstimator(...)}, {@code setModelData(...)} and {@code getModelData(...)} would be executed
     * as specified in the Java doc of the corresponding methods.
     *
     * @param inputs A list of TableIds which represents inputs to fit() of the returned Estimator
     *     as well as inputs to transform() of the Model fitted by the returned Estimator.
     * @param outputs A list of TableIds which represents outputs of transform() of the Model fitted
     *     by the returned Estimator.
     * @param inputModelData A list of TableIds which represents inputs to setModelData() of the
     *     Model fitted by the returned Estimator.
     * @param outputModelData A list of TableIds which represents outputs of getModelData() of the
     *     Model fitted by the returned Estimator.
     * @return An Estimator which wraps the nodes of this graph.
     */
    public Estimator<?, ?> buildEstimator(
            TableId[] inputs,
            TableId[] outputs,
            TableId[] inputModelData,
            TableId[] outputModelData) {...}

    /**
     * Wraps nodes of the graph into an Estimator.
     *
     * <p>When the returned Estimator runs, and when the Model fitted by the returned Estimator
     * runs, the sequence of operations recorded by the {@code addAlgoOperator(...)}, {@code
     * addEstimator(...)}, {@code setModelData(...)} and {@code getModelData(...)} would be executed
     * as specified in the Java doc of the corresponding methods.
     *
     * @param estimatorInputs A list of TableIds which represents inputs to fit() of the returned
     *     Estimator.
     * @param modelInputs A list of TableIds which represents inputs to transform() of the Model
     *     fitted by the returned Estimator.
     * @param outputs A list of TableIds which represents outputs of transform() of the Model fitted
     *     by the returned Estimator.
     * @param inputModelData A list of TableIds which represents inputs to setModelData() of the
     *     Model fitted by the returned Estimator.
     * @param outputModelData A list of TableIds which represents outputs of getModelData() of the
     *     Model fitted by the returned Estimator.
     * @return An Estimator which wraps the nodes of this graph.
     */
    public Estimator<?, ?> buildEstimator(
            TableId[] estimatorInputs,
            TableId[] modelInputs,
            TableId[] outputs,
            TableId[] inputModelData,
            TableId[] outputModelData) {...}

    /**
     * Wraps nodes of the graph into an AlgoOperator.
     *
     * <p>When the returned AlgoOperator runs, the sequence of operations recorded by the {@code
     * addAlgoOperator(...)} and {@code addEstimator(...)} would be executed as specified in the
     * Java doc of the corresponding methods.
     *
     * @param inputs A list of TableIds which represents inputs to transform() of the returned
     *     AlgoOperator.
     * @param outputs A list of TableIds which represents outputs of transform() of the returned
     *     AlgoOperator.
     * @return An AlgoOperator which wraps the nodes of this graph.
     */
    public AlgoOperator<?> buildAlgoOperator(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Wraps nodes of the graph into a Model.
     *
     * <p>When the returned Model runs, the sequence of operations recorded by the {@code
     * addAlgoOperator(...)} and {@code addEstimator(...)} would be executed as specified in the
     * Java doc of the corresponding methods.
     *
     * @param inputs A list of TableIds which represents inputs to transform() of the returned
     *     Model.
     * @param outputs A list of TableIds which represents outputs of transform() of the returned
     *     Model.
     * @return A Model which wraps the nodes of this graph.
     */
    public Model<?> buildModel(TableId[] inputs, TableId[] outputs) {...}

    /**
     * Wraps nodes of the graph into a Model.
     *
     * <p>When the returned Model runs, the sequence of operations recorded by the {@code
     * addAlgoOperator(...)}, {@code addEstimator(...)}, {@code setModelData(...)} and {@code
     * getModelData(...)} would be executed as specified in the Java doc of the corresponding
     * methods.
     *
     * @param inputs A list of TableIds which represents inputs to transform() of the returned
     *     Model.
     * @param outputs A list of TableIds which represents outputs of transform() of the returned
     *     Model.
     * @param inputModelData A list of TableIds which represents inputs to setModelData() of the
     *     returned Model.
     * @param outputModelData A list of TableIds which represents outputs of getModelData() of the
     *     returned Model.
     * @return A Model which wraps the nodes of this graph.
     */
    public Model<?> buildModel(
            TableId[] inputs,
            TableId[] outputs,
            TableId[] inputModelData,
            TableId[] outputModelData) {...}
}
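The NOTE on the add* methods implies that the builder must hand out output TableIds before any stage has run, so it reserves setMaxOutputTableNum ids per stage and binds only as many as transform() actually returns. The mini-builder below is a hypothetical sketch of that bookkeeping, not GraphBuilder's implementation: MiniGraphBuilder, int ids, String tables and Function<String[], String[]> as the operator type are all assumptions made for illustration. It also shows one way exceeding the configured bound could be reported.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical mini-builder that mirrors GraphBuilder's bookkeeping, not its implementation:
// TableIds are ints, tables are Strings, and an "AlgoOperator" is a Function<String[], String[]>.
final class MiniGraphBuilder {
    private int maxOutputTableNum = 20;
    private int nextId = 0;
    private final List<Function<String[], String[]>> ops = new ArrayList<>();
    private final List<int[]> opInputs = new ArrayList<>();
    private final List<int[]> opOutputs = new ArrayList<>();

    MiniGraphBuilder setMaxOutputTableNum(int maxOutputTableNum) {
        this.maxOutputTableNum = maxOutputTableNum;
        return this;
    }

    int createTableId() {
        return nextId++;
    }

    // The real output count is only known once transform() runs, so this reserves
    // maxOutputTableNum output ids up front; callers use only the ones they need.
    int[] addAlgoOperator(Function<String[], String[]> op, int... inputs) {
        int[] outputs = new int[maxOutputTableNum];
        for (int i = 0; i < outputs.length; i++) {
            outputs[i] = createTableId();
        }
        ops.add(op);
        opInputs.add(inputs.clone());
        opOutputs.add(outputs);
        return outputs.clone();
    }

    // Every input of a node is either a created id or an output of an earlier add call, so
    // replaying the nodes in insertion order respects the DAG's dependencies.
    Function<String[], String[]> buildAlgoOperator(int[] inputs, int[] outputs) {
        List<Function<String[], String[]>> nodeOps = new ArrayList<>(ops);
        List<int[]> nodeInputs = new ArrayList<>(opInputs);
        List<int[]> nodeOutputs = new ArrayList<>(opOutputs);
        return tables -> {
            Map<Integer, String> bound = new HashMap<>();
            for (int i = 0; i < inputs.length; i++) {
                bound.put(inputs[i], tables[i]);
            }
            for (int n = 0; n < nodeOps.size(); n++) {
                int[] in = nodeInputs.get(n);
                String[] stageArgs = new String[in.length];
                for (int i = 0; i < in.length; i++) {
                    stageArgs[i] = bound.get(in[i]);
                }
                String[] results = nodeOps.get(n).apply(stageArgs);
                int[] out = nodeOutputs.get(n);
                if (results.length > out.length) {
                    throw new IllegalStateException(results.length
                            + " output tables exceed setMaxOutputTableNum(" + out.length + ")");
                }
                // Only the first results.length reserved ids get bound; the rest stay unused.
                for (int i = 0; i < results.length; i++) {
                    bound.put(out[i], results[i]);
                }
            }
            String[] result = new String[outputs.length];
            for (int i = 0; i < outputs.length; i++) {
                result[i] = bound.get(outputs[i]);
            }
            return result;
        };
    }

    public static void main(String[] args) {
        MiniGraphBuilder builder = new MiniGraphBuilder().setMaxOutputTableNum(2);
        int source = builder.createTableId();
        int[] pair = builder.addAlgoOperator(t -> new String[] {t[0] + "A", t[0] + "B"}, source);
        int[] bang = builder.addAlgoOperator(t -> new String[] {t[0] + "!"}, pair[0]);
        String[] out = builder.buildAlgoOperator(new int[] {source}, new int[] {bang[0], pair[1]})
                .apply(new String[] {"x"});
        System.out.println(String.join(", ", out)); // prints xA!, xB
    }
}
```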


Example Usage

In this section we provide example code snippets to demonstrate how the APIs proposed in this FLIP address the use-cases in the motivation section.

Composing an Estimator from a DAG of Estimator/Transformer

Suppose we have the following Transformer and Estimator classes:

  • TransformerA whose transform(...) takes 1 input table and has 1 output table.
  • ModelB whose transform(...) takes 2 input tables and has 1 output table.
  • EstimatorB whose fit(...) takes 2 input tables and returns an instance of ModelB.

And we want to compose an Estimator (e.g. Graph) from the following DAG of Transformer/Estimator.

...

  • The method takes 2 input tables. The 1st input table is given to a TransformerA instance. And the 2nd input table is given to another TransformerA instance.
  • An EstimatorB instance fits the output tables of these two TransformerA instances and generates a new ModelB instance.
  • Returns a GraphModel instance which contains 2 TransformerA instances and 1 ModelB instance, connected using the same DAG as shown above.

...

Code Block
languagejava
GraphBuilder builder = new GraphBuilder();

// Creates nodes
AlgoOperator<?> stage1 = new TransformerA();
AlgoOperator<?> stage2 = new TransformerA();
Estimator<?, ?> stage3 = new EstimatorB();
// Creates inputs
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.addAlgoOperator(stage1, input1)[0];
TableId output2 = builder.addAlgoOperator(stage2, input2)[0];
TableId output3 = builder.addEstimator(stage3, output1, output2)[0];

// Specifies the ordered lists of inputs and outputs that will be used as the
// inputs/outputs of the corresponding Estimator and Model APIs.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};

// Generates the Estimator instance. This graph has no input or output model data.
Estimator<?, ?> estimator = builder.buildEstimator(inputs, outputs, new TableId[] {}, new TableId[] {});
// The fit method takes 2 tables which are mapped to input1 and input2.
Model<?> model = estimator.fit(...);
// The transform method takes 2 tables which are mapped to input1 and input2.
Table[] results = model.transform(...);

Composing an Estimator from a chain of Estimator/Transformer whose input schemas are different from those of its fitted Transformer

Suppose we have the following Estimator and Transformer classes where an Estimator's input schemas could be different from the input schema of its fitted Transformer:

...

  • The method takes 2 input tables. Both tables are given to EstimatorA::fit.
  • EstimatorA fits the input tables and generates a TransformerA instance. The TransformerA instance takes 1 table input, which is different from the 2 tables given to the EstimatorA.
  • Returns a GraphModel instance which contains a TransformerA instance and a TransformerB instance, which are connected as a chain.


Notes:

...


The fitted GraphModel is represented by the following DAG:

[Figure: DAG of the fitted GraphModel]

...

Here is the code snippet that addresses this use-case by using the proposed APIs:

Code Block
languagejava
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Estimator<?, ?> stage1 = new EstimatorA();
AlgoOperator<?> stage2 = new TransformerB();
// Creates inputs
TableId estimatorInput1 = builder.createTableId();
TableId estimatorInput2 = builder.createTableId();
TableId transformerInput1 = builder.createTableId();

// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.addEstimator(stage1, new TableId[] {estimatorInput1, estimatorInput2}, new TableId[] {transformerInput1})[0];
TableId output2 = builder.addAlgoOperator(stage2, output1)[0];

// Specifies the ordered lists of estimator inputs, model inputs, outputs, input model data and
// output model data that will be used as the inputs/outputs of the corresponding Estimator and Model APIs.
TableId[] estimatorInputs = new TableId[] {estimatorInput1, estimatorInput2};
TableId[] transformerInputs = new TableId[] {transformerInput1};
TableId[] outputs = new TableId[] {output2};
TableId[] inputModelData = new TableId[] {};
TableId[] outputModelData = new TableId[] {};

// Generates the Estimator instance.
Estimator<?, ?> estimator = builder.buildEstimator(estimatorInputs, transformerInputs, outputs, inputModelData, outputModelData);
// The fit method takes 2 tables which are mapped to estimatorInput1 and estimatorInput2.
Model<?> model = estimator.fit(...);
// The transform method takes 1 table which is mapped to transformerInput1.
Table[] results = model.transform(...);
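Stripped of Flink types, this use-case boils down to an Estimator whose fit() consumes two tables while its fitted model transforms only one. The toy below is a hypothetical illustration (double[] tables, plus a ratio-learning "EstimatorA" and an add-one "TransformerB" invented for the sketch): fitting takes the two estimator inputs, and the returned function, standing in for the fitted GraphModel, takes only the single transformer input.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

// Hypothetical toy, not Flink ML code: tables are double[] columns, and EstimatorA and
// TransformerB are invented stand-ins for the classes in the example above.
final class EstimatorModelInputsSketch {
    // "EstimatorA::fit": consumes two tables and learns ratio = sum(labels) / sum(features).
    // The returned operator plays the role of the fitted TransformerA, which takes one table.
    static UnaryOperator<double[]> fitEstimatorA(double[] features, double[] labels) {
        double ratio = Arrays.stream(labels).sum() / Arrays.stream(features).sum();
        return table -> Arrays.stream(table).map(v -> v * ratio).toArray();
    }

    // "TransformerB::transform": a stateless stage that adds 1 to every value.
    static double[] transformerB(double[] table) {
        return Arrays.stream(table).map(v -> v + 1).toArray();
    }

    // Graph-level fit over the two estimator inputs. The returned operator stands in for the
    // fitted GraphModel: it takes the single model input and chains TransformerA, then TransformerB.
    static UnaryOperator<double[]> fitGraph(double[] estimatorInput1, double[] estimatorInput2) {
        UnaryOperator<double[]> transformerA = fitEstimatorA(estimatorInput1, estimatorInput2);
        return transformerInput1 -> transformerB(transformerA.apply(transformerInput1));
    }

    public static void main(String[] args) {
        UnaryOperator<double[]> model = fitGraph(new double[] {1, 2, 3}, new double[] {2, 4, 6});
        System.out.println(Arrays.toString(model.apply(new double[] {10}))); // prints [21.0]
    }
}
```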

Compatibility, Deprecation, and Migration Plan

This FLIP does not remove or modify any existing APIs. There is no backward incompatible change, deprecation or migration plan.

Test Plan

We will provide unit tests to validate the proposed changes.

Rejected Alternatives

There are no rejected alternatives to list yet.

...