...

Therefore, this design doc proposes adding the Graph/GraphTransformer/GraphBuilder classes to provide the following capabilities:

...

This change addresses the use-cases described in the motivation section, e.g. a graph embedding Estimator needs to take 2 tables as inputs.

3) Added Graph, GraphTransformer and GraphBuilder.

This change addresses the use-cases described in the motivation section, where we need to compose an Estimator from a DAG of Estimators/Transformers. Note that Graph/GraphBuilder support an Estimator whose input schemas differ from those of its fitted Transformer.

...

This change simplifies the usage of fit/transform APIs.

7) Added PipelineTransformer and let Pipeline implement only the Estimator interface. Pipeline is no longer a Transformer.

...

9) Removed the Model interface and renamed PipelineModel to PipelineTransformer.

This change simplifies the class hierarchy by removing a redundant class. It follows the philosophy of only adding complexity when we have an explicit use-case for it.

...

The following code block shows the interfaces of Stage, Transformer, Estimator, Pipeline and PipelineTransformer after making the changes listed above.

...

Code Block
language: java
/**
 * Base class for a stage in a Pipeline or Graph. The interface is only a concept and does not have any actual
 * functionality. Its subclasses could be Estimator, Transformer or MLFunc. No other classes should inherit this
 * interface directly.
 *
 * <p>Each stage has parameters and requires a public empty constructor for restoration.
 *
 * @param <T> The class type of the Stage implementation itself.
 * @see WithParams
 */
@PublicEvolving
interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable {
    /**
     * Saves this stage to the given path.
     */
    void save(String path);

    /**
     * Loads this stage from the given path.
     */
    void load(String path);
}

/**
 * An MLFunc is a Stage that takes a list of tables as inputs and produces a list of
 * tables as results. It can be used to encode a generic multi-input multi-output machine learning function.
 *
 * @param <T> The class type of the MLFunc implementation itself.
 */
@PublicEvolving
public interface MLFunc<T extends MLFunc<T>> extends Stage<T> {
    /**
     * Applies the MLFunc on the given input tables, and returns the result tables.
     *
     * @param inputs a list of tables
     * @return a list of tables
     */
    Table[] transform(Table... inputs);
}

/**
 * A Transformer is an MLFunc with additional support for state streams, which could be set by the Estimator that fitted
 * this Transformer. Unlike MLFunc, a Transformer is typically associated with an Estimator.
 *
 * @param <T> The class type of the Transformer implementation itself.
 */
@PublicEvolving
public interface Transformer<T extends Transformer<T>> extends MLFunc<T> {
    /**
     * Uses the given list of tables to update internal states. This can be useful for e.g. online
     * learning where an Estimator fits an infinite stream of training samples and streams the model
     * diff data to this Transformer.
     *
     * <p>This method may be called at most once.
     *
     * @param inputs a list of tables
     */
    default void setStateStreams(Table... inputs) {
        throw new UnsupportedOperationException("this method is not implemented");
    }

    /**
     * Gets a list of tables representing changes of internal states of this Transformer. These
     * tables might come from the Estimator that instantiated this Transformer.
     *
     * @return a list of tables
     */
    default Table[] getStateStreams() {
        throw new UnsupportedOperationException("this method is not implemented");
    }
}

/**
 * An Estimator is a Stage that takes a list of tables as inputs and produces a Transformer.
 *
 * @param <E> class type of the Estimator implementation itself.
 * @param <M> class type of the Transformer this Estimator produces.
 */
@PublicEvolving
public interface Estimator<E extends Estimator<E, M>, M extends Transformer<M>> extends Stage<E> {
    /**
     * Trains on the given inputs and produces a Transformer.
     *
     * @param inputs a list of tables
     * @return a Transformer
     */
    M fit(Table... inputs);
}

/**
 * A Pipeline acts as an Estimator. It consists of an ordered list of stages, each of which could be
 * an Estimator, Transformer or MLFunc.
 */
@PublicEvolving
public final class Pipeline implements Estimator<Pipeline, PipelineTransformer> {

    public Pipeline(List<Stage<?>> stages) {...}

    @Override
    public PipelineTransformer fit(Table... inputs) {...}

    /** Skipped a few methods, including the implementations of the Estimator APIs. */
}

/**
 * A PipelineTransformer acts as a Transformer. It consists of an ordered list of Transformers or MLFuncs.
 */
@PublicEvolving
public final class PipelineTransformer implements Transformer<PipelineTransformer> {

    public PipelineTransformer(List<Transformer<?>> transformers) {...}

    /** Skipped a few methods, including the implementations of the Transformer APIs. */
}
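
To make the relationship between Pipeline and PipelineTransformer concrete, the following is a minimal, self-contained sketch, not the proposed implementation. It assumes that fitting a pipeline walks the stages in order, fits each Estimator on the current table, and passes the fitted stage's output on to the next stage. PipelineSketch, Centerer and doubler are hypothetical names, and a List<Double> stands in for Table; the real API takes Table... inputs.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class PipelineSketch {
    // Stand-ins for the proposed interfaces; a List<Double> plays the role of a Table.
    public interface Stage {}
    public interface MLFunc extends Stage { List<Double> transform(List<Double> input); }
    public interface Estimator extends Stage { MLFunc fit(List<Double> input); }

    /** Hypothetical Estimator: learns the mean; the fitted stage subtracts it. */
    public static final class Centerer implements Estimator {
        @Override
        public MLFunc fit(List<Double> input) {
            double mean = input.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
            return in -> in.stream().map(x -> x - mean).collect(Collectors.toList());
        }
    }

    /** Hypothetical stateless stage that doubles every value. */
    public static MLFunc doubler() {
        return in -> in.stream().map(x -> x * 2).collect(Collectors.toList());
    }

    /** Assumed fit semantics: fit Estimators in order, feeding each stage's output forward. */
    public static List<MLFunc> fit(List<Stage> stages, List<Double> input) {
        List<MLFunc> fitted = new ArrayList<>();
        List<Double> current = input;
        for (Stage stage : stages) {
            MLFunc f = (stage instanceof Estimator) ? ((Estimator) stage).fit(current) : (MLFunc) stage;
            fitted.add(f);
            current = f.transform(current);
        }
        return fitted;
    }

    /** The fitted chain applies every stage in the same order. */
    public static List<Double> transform(List<MLFunc> fitted, List<Double> input) {
        List<Double> current = input;
        for (MLFunc f : fitted) {
            current = f.transform(current);
        }
        return current;
    }

    public static void main(String[] args) {
        List<Stage> stages = Arrays.asList(doubler(), new Centerer());
        // Doubling [1, 2, 3] gives [2, 4, 6], so the Centerer learns a mean of 4.
        List<MLFunc> fitted = fit(stages, Arrays.asList(1.0, 2.0, 3.0));
        System.out.println(transform(fitted, Arrays.asList(5.0))); // prints [6.0]
    }
}
```

The fitted list contains no Estimator, which mirrors the split in which Pipeline is only an Estimator and PipelineTransformer is only a Transformer.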

...

The following code block shows the interfaces of Graph, GraphTransformer and GraphBuilder that we propose to add.

Code Block
language: java
/**
 * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be
 * an Estimator, Transformer or MLFunc. When `Graph::fit` is called, the stages are executed in a
 * topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method will be
 * called on the input tables (from the input edges) to fit a model. Then the model, which is a
 * Transformer, will be used to transform the input tables to produce output tables to the output
 * edges. If a stage is a Transformer or MLFunc, its `MLFunc::transform` method will be called on the
 * input tables to produce output tables to the output edges. The fitted model from a Graph is a
 * GraphTransformer, which consists of fitted models and transformers, corresponding to the Graph's
 * stages.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, GraphTransformer> {
    public Graph(...) {...}

    @Override
    public GraphTransformer fit(Table... inputs) {...}

    /** Skipped a few methods, including the implementations of some Estimator APIs. */
}

/**
 * A GraphTransformer acts as a Transformer. A GraphTransformer consists of a DAG of Transformers or MLFuncs. When
 * `GraphTransformer::transform` is called, the stages are executed in a topologically-sorted order. When
 * a stage is executed, its `MLFunc::transform` method will be called on the input tables (from
 * the input edges) to produce output tables to the output edges.
 */
public final class GraphTransformer implements Transformer<GraphTransformer> {
    /** Skipped a few methods, including the implementations of the Transformer APIs. */
}

/**
 * A GraphBuilder provides APIs to build a Graph or GraphTransformer from a DAG of Estimator, Transformer and MLFunc instances.
 */
@PublicEvolving
public final class GraphBuilder {
    /**
     * Specifies the upper bound (could be loose) of the number of output tables that can be
     * returned by the Transformer::getStateStreams and Transformer::transform methods, for any
     * stage involved in this Graph.
     *
     * <p>The default upper bound is 20.
     */
    public GraphBuilder setMaxOutputLength(int maxOutputLength) {...}

    /**
     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing of
     * tables between stages, as well as the input/output tables of the Graph/GraphTransformer generated
     * by this builder.
     */
    public TableId createTableId() {...}

    /**
     * If the stage is an Estimator, both its fit method and the transform method of its fitted Transformer would be
     * invoked with the given inputs when the graph runs.
     *
     * <p>If this stage is a Transformer or MLFunc, its transform method would be invoked with the given
     * inputs when the graph runs.
     *
     * Returns a list of TableIds, which represents outputs of the Transformer::transform invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId... inputs) {...}

    /**
     * If this stage is an Estimator, its fit method would be invoked with estimatorInputs, and the transform method
     * of its fitted Transformer would be invoked with transformerInputs, when the graph runs.
     *
     * <p>This method throws an exception if the stage is a Transformer or MLFunc.
     *
     * This method is useful when the stage is an Estimator AND Estimator::fit needs to take a different list of
     * Tables from the Transformer::transform of the fitted Transformer.
     *
     * Returns a list of TableIds, which represents outputs of the Transformer::transform invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId[] estimatorInputs, TableId[] transformerInputs) {...}

    /**
     * The GraphTransformer::setStateStreams should invoke the setStateStreams of the corresponding stage
     * with the corresponding inputs.
     */
    void setStateStreams(Stage<?> stage, TableId... inputs) {...}

    /**
     * The GraphTransformer::getStateStreams should invoke the getStateStreams of the corresponding stage.
     *
     * <p>Returns a list of TableIds, which represents outputs of the getStateStreams invocation.
     */
    TableId[] getStateStreams(Stage<?> stage) {...}

    /**
     * Returns a Graph instance with the following API specification:
     *
     * 1) Graph::fit should take inputs and return a GraphTransformer with the following specification.
     *
     * 2) GraphTransformer::transform should take inputs and return outputs.
     *
     * 3) GraphTransformer::setStateStreams should take inputStates.
     *
     * 4) GraphTransformer::getStateStreams should return outputStates.
     *
     * The fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages in the order specified by the DAG of stages.
     */
    Graph build(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns a Graph instance with the following API specification:
     *
     * 1) Graph::fit should take estimatorInputs and return a GraphTransformer with the following specification.
     *
     * 2) GraphTransformer::transform should take transformerInputs and return outputs.
     *
     * 3) GraphTransformer::setStateStreams should take inputStates.
     *
     * 4) GraphTransformer::getStateStreams should return outputStates.
     *
     * The fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages in the order specified by the DAG of stages.
     *
     * This method is useful when Graph::fit needs to take a different list of Tables from the GraphTransformer::transform of the fitted GraphTransformer.
     */
    Graph build(TableId[] estimatorInputs, TableId[] transformerInputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns a GraphTransformer instance with the following API specification:
     *
     * 1) GraphTransformer::transform should take inputs and return outputs.
     *
     * 2) GraphTransformer::setStateStreams should take inputStates.
     *
     * 3) GraphTransformer::getStateStreams should return outputStates.
     *
     * The transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages in the order specified by the DAG of stages.
     *
     * This method throws an exception if any stage of this graph is an Estimator.
     */
    GraphTransformer buildTransformer(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    // The TableId is necessary to pass the inputs/outputs of various API calls across the
    // Graph/GraphTransformer stages.
    static class TableId {}

}
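
The Javadoc above says that stages run in a topologically-sorted order. As a rough illustration only, the sketch below derives such an order by repeatedly running any stage whose input tables are all available. TopoOrderSketch and Node are hypothetical names, and plain integers stand in for TableId; nothing here is part of the proposed API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

public class TopoOrderSketch {
    /** Hypothetical node: a stage name plus the ids of the tables it reads and writes. */
    public static final class Node {
        final String name;
        final int[] inputs;
        final int[] outputs;

        public Node(String name, int[] inputs, int[] outputs) {
            this.name = name;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    /** Returns stage names in an order where every stage runs after the producers of its inputs. */
    public static List<String> executionOrder(List<Node> nodes, Set<Integer> graphInputs) {
        Set<Integer> available = new HashSet<>(graphInputs);
        List<Node> pending = new ArrayList<>(nodes);
        List<String> order = new ArrayList<>();
        while (!pending.isEmpty()) {
            boolean progressed = false;
            Iterator<Node> it = pending.iterator();
            while (it.hasNext()) {
                Node node = it.next();
                if (Arrays.stream(node.inputs).allMatch(available::contains)) {
                    order.add(node.name);
                    for (int out : node.outputs) {
                        available.add(out);
                    }
                    it.remove();
                    progressed = true;
                }
            }
            if (!progressed) {
                // No stage became runnable: the wiring has a cycle or an unbound input.
                throw new IllegalStateException("cycle or missing input table");
            }
        }
        return order;
    }

    public static void main(String[] args) {
        // Two independent stages feed a third; the nodes are deliberately declared out of order.
        List<Node> nodes = Arrays.asList(
                new Node("stage3", new int[] {3, 4}, new int[] {5}),
                new Node("stage1", new int[] {1}, new int[] {3}),
                new Node("stage2", new int[] {2}, new int[] {4}));
        System.out.println(executionOrder(nodes, new HashSet<>(Arrays.asList(1, 2))));
        // prints [stage1, stage2, stage3]
    }
}
```

A real implementation would also need to decide, per node, whether to call fit before transform, as described for Estimator stages in the Graph Javadoc.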

...

  • The method takes 2 input tables. The 1st input table is given to one TransformerA instance, and the 2nd input table is given to another TransformerA instance.
  • An EstimatorB instance fits the output tables of these two TransformerA instances and generates a new TransformerB instance.
  • Returns a GraphTransformer instance which contains 2 TransformerA instances and 1 TransformerB instance, connected using the same DAG as shown above.

...

Code Block
language: java
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new TransformerA();
Stage<?> stage2 = new TransformerA();
Stage<?> stage3 = new EstimatorB();
// Creates inputs
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, input1)[0];
TableId output2 = builder.getOutputs(stage2, input2)[0];
TableId output3 = builder.getOutputs(stage3, output1, output2)[0];

// Specifies the ordered lists of inputs, outputs, input states and output states that will
// be used as the inputs/outputs of the corresponding Graph and GraphTransformer APIs.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};

// Generates the Graph instance.
Graph graph = builder.build(inputs, outputs, new TableId[]{}, new TableId[]{});
// The fit method takes 2 tables which are mapped to input1 and input2.
GraphTransformer transformer = graph.fit(...);
// The transform method takes 2 tables which are mapped to input1 and input2.
Table[] results = transformer.transform(...);

Online learning by running Transformer and Estimator concurrently on different machines

...

  • The method takes 2 input tables. Both tables are given to EstimatorA::fit.
  • EstimatorA fits the input tables and generates a TransformerA instance. The TransformerA instance takes 1 table input, which is different from the 2 tables given to the EstimatorA.
  • Returns a GraphTransformer instance which contains a TransformerA instance and a TransformerB instance, which are connected as a chain.

The fitted GraphTransformer is represented by the following DAG:

Notes:

  • The fitted GraphTransformer takes only 1 table as input whereas the Graph takes 2 tables as inputs.
  • The proposed APIs also support composing an Estimator from a DAG of Estimators/Transformers whose input schemas differ from those of its fitted Transformer.
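
The asymmetry noted above, where fit takes more tables than the fitted transformer does, can be illustrated with a small stand-alone sketch. It is an assumption-laden toy and not the proposed API: FitTransformArity and OneTableTransformer are hypothetical names, a List<Double> stands in for Table, and the "model" is just a ratio of means.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class FitTransformArity {
    /** Stand-in for a fitted Transformer that consumes exactly one table. */
    public interface OneTableTransformer {
        List<Double> transform(List<Double> features);
    }

    private static double mean(List<Double> values) {
        return values.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
    }

    /** Hypothetical estimator: fit sees two tables, the fitted transformer sees only one. */
    public static OneTableTransformer fit(List<Double> features, List<Double> labels) {
        // Toy model: a single scale factor learned from both training tables.
        double scale = mean(labels) / mean(features);
        return in -> in.stream().map(x -> x * scale).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Means are 2.0 and 4.0, so the learned scale is 2.0.
        OneTableTransformer fitted = fit(Arrays.asList(1.0, 2.0, 3.0), Arrays.asList(2.0, 4.0, 6.0));
        System.out.println(fitted.transform(Arrays.asList(5.0))); // prints [10.0]
    }
}
```

In the proposed design this difference is expressed through the separate estimatorInputs and transformerInputs arguments instead of through distinct method signatures.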

...

Code Block
language: java
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new EstimatorA();
Stage<?> stage2 = new TransformerB();
// Creates inputs
TableId estimatorInput1 = builder.createTableId();
TableId estimatorInput2 = builder.createTableId();
TableId transformerInput1 = builder.createTableId();

// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, new TableId[] {estimatorInput1, estimatorInput2}, new TableId[] {transformerInput1})[0];
TableId output2 = builder.getOutputs(stage2, output1)[0];

// Specifies the ordered lists of estimator inputs, transformer inputs, outputs, input states and output states
// that will be used as the inputs/outputs of the corresponding Graph and GraphTransformer APIs.
TableId[] estimatorInputs = new TableId[] {estimatorInput1, estimatorInput2};
TableId[] transformerInputs = new TableId[] {transformerInput1};
TableId[] outputs = new TableId[] {output2};
TableId[] inputStates = new TableId[] {};
TableId[] outputStates = new TableId[] {};

// Generates the Graph instance.
Graph graph = builder.build(estimatorInputs, transformerInputs, outputs, inputStates, outputStates);
// The fit method takes 2 tables which are mapped to estimatorInput1 and estimatorInput2.
GraphTransformer transformer = graph.fit(...);
// The transform method takes 1 table which is mapped to transformerInput1.
Table[] results = transformer.transform(...);

Compatibility, Deprecation, and Migration Plan

...