...
Therefore, this design doc proposes to add the Graph/GraphTransformer/GraphBuilder classes to provide the following capability:
...
This change addresses the use-cases described in the motivation section, e.g. a graph embedding Estimator needs to take 2 tables as inputs.
3) Added Graph, GraphTransformer and GraphBuilder.
This change addresses the use-cases described in the motivation section, where we need to compose an Estimator from a DAG of Estimators/Transformers. Note that Graph/GraphBuilder supports an Estimator whose input schemas differ from those of its fitted Transformer.
...
This change simplifies the usage of fit/transform APIs.
7) Added PipelineTransformer and let Pipeline implement only the Estimator interface. Pipeline is no longer a Transformer.
...
9) Removed the Model interface and renamed PipelineModel to PipelineTransformer.
This change simplifies the class hierarchy by removing a redundant class. It follows the philosophy of only adding complexity when we have an explicit use-case for it.
...
The following code block shows the interfaces of Stage, MLFunc, Transformer, Estimator, Pipeline and PipelineTransformer after making the changes listed above.
...
```java
/**
 * Base class for a stage in a Pipeline or Graph. The interface is only a concept, and does not
 * have any actual functionality. Its subclasses could be Estimator, Transformer or MLFunc. No
 * other classes should inherit this interface directly.
 *
 * <p>Each stage has parameters, and requires a public empty constructor for restoration.
 *
 * @param <T> The class type of the Stage implementation itself.
 * @see WithParams
 */
@PublicEvolving
public interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable {
    /** Saves this stage to the given path. */
    void save(String path);

    /** Loads this stage from the given path. */
    void load(String path);
}

/**
 * A MLFunc is a Stage that takes a list of tables as inputs and produces a list of
 * tables as results. It can be used to encode a generic multi-input multi-output machine
 * learning function.
 *
 * @param <T> The class type of the MLFunc implementation itself.
 */
@PublicEvolving
public interface MLFunc<T extends MLFunc<T>> extends Stage<T> {
    /**
     * Applies the MLFunc on the given input tables, and returns the result tables.
     *
     * @param inputs a list of tables
     * @return a list of tables
     */
    Table[] transform(Table... inputs);
}

/**
 * A Transformer is a MLFunc with additional support for state streams, which could be set by the
 * Estimator that fitted this Transformer. Unlike MLFunc, a Transformer is typically associated
 * with an Estimator.
 *
 * @param <T> The class type of the Transformer implementation itself.
 */
@PublicEvolving
public interface Transformer<T extends Transformer<T>> extends MLFunc<T> {
    /**
     * Uses the given list of tables to update internal states. This can be useful for e.g. online
     * learning where an Estimator fits an infinite stream of training samples and streams the
     * model diff data to this Transformer.
     *
     * <p>This method may be called at most once.
     *
     * @param inputs a list of tables
     */
    default void setStateStreams(Table... inputs) {
        throw new UnsupportedOperationException("this method is not implemented");
    }

    /**
     * Gets a list of tables representing changes of internal states of this Transformer. These
     * tables might come from the Estimator that instantiated this Transformer.
     *
     * @return a list of tables
     */
    default Table[] getStateStreams() {
        throw new UnsupportedOperationException("this method is not implemented");
    }
}

/**
 * An Estimator is a Stage that takes a list of tables as inputs and produces a Transformer.
 *
 * @param <E> class type of the Estimator implementation itself.
 * @param <M> class type of the Transformer this Estimator produces.
 */
@PublicEvolving
public interface Estimator<E extends Estimator<E, M>, M extends Transformer<M>> extends Stage<E> {
    /**
     * Trains on the given inputs and produces a Transformer.
     *
     * @param inputs a list of tables
     * @return a Transformer
     */
    M fit(Table... inputs);
}

/**
 * A Pipeline acts as an Estimator. It consists of an ordered list of stages, each of which could
 * be an Estimator, Transformer or MLFunc.
 */
@PublicEvolving
public final class Pipeline implements Estimator<Pipeline, PipelineTransformer> {

    public Pipeline(List<Stage<?>> stages) {...}

    @Override
    public PipelineTransformer fit(Table... inputs) {...}

    /** Skipped a few methods, including the implementations of the Estimator APIs. */
}

/**
 * A PipelineTransformer acts as a Transformer. It consists of an ordered list of Transformers or
 * MLFuncs.
 */
@PublicEvolving
public final class PipelineTransformer implements Transformer<PipelineTransformer> {

    public PipelineTransformer(List<Transformer<?>> transformers) {...}

    /** Skipped a few methods, including the implementations of the Transformer APIs. */
}
```
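To make the chaining implied by Pipeline::fit more concrete, here is a minimal, self-contained sketch. It is not the proposed implementation: `Table`, `Stage`, `MLFunc`, `Transformer` and `Estimator` below are simplified stand-ins invented for this example (no generics, params, or save/load), and the assumed semantics are that each Estimator is fitted on the current tables and replaced by its fitted Transformer, whose output feeds the next stage.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical sketch of pipeline chaining; all nested types are stand-ins, not Flink ML classes. */
public class PipelineFitSketch {
    /** Stand-in for a Table: just an array of doubles. */
    static final class Table {
        final double[] values;
        Table(double... values) { this.values = values; }
    }

    interface Stage {}
    interface MLFunc extends Stage { Table[] transform(Table... inputs); }
    interface Transformer extends MLFunc {}
    interface Estimator extends Stage { Transformer fit(Table... inputs); }

    /** Assumed fit semantics: fit each Estimator on the current tables, then transform and move on. */
    static List<MLFunc> fit(List<Stage> stages, Table... inputs) {
        List<MLFunc> fitted = new ArrayList<>();
        Table[] current = inputs;
        for (Stage stage : stages) {
            MLFunc func = (stage instanceof Estimator)
                    ? ((Estimator) stage).fit(current)
                    : (MLFunc) stage;
            fitted.add(func);
            // For simplicity the tables are transformed after every stage, including the last one.
            current = func.transform(current);
        }
        return fitted;
    }

    /** Assumed transform semantics of the fitted pipeline: apply the stages in order. */
    static Table[] transform(List<MLFunc> fitted, Table... inputs) {
        Table[] current = inputs;
        for (MLFunc func : fitted) {
            current = func.transform(current);
        }
        return current;
    }

    static Table map(Table t, DoubleUnaryOperator op) {
        return new Table(Arrays.stream(t.values).map(op).toArray());
    }

    static double demo() {
        // An MLFunc that doubles every value.
        MLFunc doubler = in -> new Table[] {map(in[0], v -> v * 2)};
        // An Estimator that learns the mean and yields a Transformer subtracting it.
        Estimator centerer = in -> {
            double mean = Arrays.stream(in[0].values).average().orElse(0.0);
            return out -> new Table[] {map(out[0], v -> v - mean)};
        };
        List<MLFunc> fitted = fit(Arrays.asList(doubler, centerer), new Table(1, 2, 3));
        return transform(fitted, new Table(5))[0].values[0];
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

In this sketch the training table [1, 2, 3] is doubled to [2, 4, 6] before the second stage is fitted, so the learned mean is 4 and transforming [5] yields 6.0.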
...
The following code block shows the interfaces of Graph, GraphTransformer and GraphBuilder that we propose to add.
```java
/**
 * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be
 * an Estimator, Transformer or MLFunc. When `Graph::fit` is called, the stages are executed in a
 * topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method will be
 * called on the input tables (from the input edges) to fit a model. Then the model, which is a
 * Transformer, will be used to transform the input tables to produce output tables to the output
 * edges. If a stage is a Transformer or MLFunc, its `MLFunc::transform` method will be called on
 * the input tables to produce output tables to the output edges. The fitted model from a Graph is
 * a GraphTransformer, which consists of fitted models and transformers, corresponding to the
 * Graph's stages.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, GraphTransformer> {

    public Graph(...) {...}

    @Override
    public GraphTransformer fit(Table... inputs) {...}

    /** Skipped a few methods, including the implementations of some Estimator APIs. */
}

/**
 * A GraphTransformer acts as a Transformer. A GraphTransformer consists of a DAG of Transformers
 * or MLFuncs. When `GraphTransformer::transform` is called, the stages are executed in a
 * topologically-sorted order. When a stage is executed, its `MLFunc::transform` method will be
 * called on the input tables (from the input edges) to produce output tables to the output edges.
 */
public final class GraphTransformer implements Transformer<GraphTransformer> {
    /** Skipped a few methods, including the implementations of the Transformer APIs. */
}

/**
 * A GraphBuilder provides APIs to build Graph and GraphTransformer from a DAG of Estimator,
 * Transformer and MLFunc instances.
 */
@PublicEvolving
public final class GraphBuilder {
    /**
     * Specifies the upper bound (could be loose) of the number of output tables that can be
     * returned by the Transformer::getStateStreams and Transformer::transform methods, for any
     * stage involved in this Graph.
     *
     * <p>The default upper bound is 20.
     */
    public GraphBuilder setMaxOutputLength(int maxOutputLength) {...}

    /**
     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing
     * of tables between stages, as well as the input/output tables of the Graph/GraphTransformer
     * generated by this builder.
     */
    public TableId createTableId() {...}

    /**
     * If the stage is an Estimator, both its fit method and the transform method of its fitted
     * Transformer would be invoked with the given inputs when the graph runs.
     *
     * <p>If this stage is a Transformer or MLFunc, its transform method would be invoked with the
     * given inputs when the graph runs.
     *
     * <p>Returns a list of TableIds, which represents outputs of the Transformer::transform
     * invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId... inputs) {...}

    /**
     * If this stage is an Estimator, its fit method would be invoked with estimatorInputs, and the
     * transform method of its fitted Transformer would be invoked with transformerInputs, when the
     * graph runs.
     *
     * <p>This method throws Exception if the stage is a Transformer or MLFunc.
     *
     * <p>This method is useful when the stage is an Estimator AND the Estimator::fit needs to take
     * a different list of Tables from the Transformer::transform of the fitted Transformer.
     *
     * <p>Returns a list of TableIds, which represents outputs of the Transformer::transform
     * invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId[] estimatorInputs, TableId[] transformerInputs) {...}

    /**
     * The GraphTransformer::setStateStreams should invoke the setStateStreams of the corresponding
     * stage with the corresponding inputs.
     */
    void setStateStreams(Stage<?> stage, TableId... inputs) {...}

    /**
     * The GraphTransformer::getStateStreams should invoke the getStateStreams of the corresponding
     * stage.
     *
     * <p>Returns a list of TableIds, which represents outputs of the getStateStreams invocation.
     */
    TableId[] getStateStreams(Stage<?> stage) {...}

    /**
     * Returns a Graph instance with the following API specification:
     *
     * 1) Graph::fit should take inputs and return a GraphTransformer with the following
     * specification.
     *
     * 2) GraphTransformer::transform should take inputs and return outputs.
     *
     * 3) GraphTransformer::setStateStreams should take inputStates.
     *
     * 4) GraphTransformer::getStateStreams should return outputStates.
     *
     * The fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal
     * stages in the order specified by the DAG of stages.
     */
    Graph build(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns a Graph instance with the following API specification:
     *
     * 1) Graph::fit should take estimatorInputs and return a GraphTransformer with the following
     * specification.
     *
     * 2) GraphTransformer::transform should take transformerInputs and return outputs.
     *
     * 3) GraphTransformer::setStateStreams should take inputStates.
     *
     * 4) GraphTransformer::getStateStreams should return outputStates.
     *
     * The fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal
     * stages in the order specified by the DAG of stages.
     *
     * This method is useful when the Graph::fit needs to take a different list of Tables from the
     * GraphTransformer::transform of the fitted GraphTransformer.
     */
    Graph build(TableId[] estimatorInputs, TableId[] transformerInputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns a GraphTransformer instance with the following API specification:
     *
     * 1) GraphTransformer::transform should take inputs and return outputs.
     *
     * 2) GraphTransformer::setStateStreams should take inputStates.
     *
     * 3) GraphTransformer::getStateStreams should return outputStates.
     *
     * The transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages
     * in the order specified by the DAG of stages.
     *
     * This method throws Exception if any stage of this graph is an Estimator.
     */
    GraphTransformer buildTransformer(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    // The TableId is necessary to pass the inputs/outputs of various API calls across the
    // Graph/GraphTransformer stages.
    static class TableId {}
}
```
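The topologically-sorted execution described in the Javadoc above can be illustrated with a small standalone sketch. Everything in it is an assumption made for illustration: table ids are plain `int`s, tables are plain `String`s, stages are `Function<String[], String[]>` values that only transform (Estimator fitting and state streams are left out), and the number of outputs is passed explicitly instead of being bounded by `setMaxOutputLength`. It is one possible way to resolve the DAG, not the proposed implementation.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical sketch: runs transform-only stages once all of their input tables exist. */
public class GraphExecutionSketch {
    /** A stage together with the ids of the tables it reads and writes. */
    static final class Node {
        final Function<String[], String[]> transform;
        final int[] inputs;
        final int[] outputs;
        Node(Function<String[], String[]> transform, int[] inputs, int[] outputs) {
            this.transform = transform;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    private final List<Node> nodes = new ArrayList<>();
    private int nextId = 0;

    int createTableId() {
        return nextId++;
    }

    /** Registers a stage and returns freshly created ids for its output tables. */
    int[] getOutputs(Function<String[], String[]> transform, int numOutputs, int... inputs) {
        int[] outputs = new int[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
            outputs[i] = createTableId();
        }
        nodes.add(new Node(transform, inputs, outputs));
        return outputs;
    }

    /** Executes the stages in a topological order and returns the requested output tables. */
    String[] run(int[] inputIds, int[] outputIds, String... inputTables) {
        Map<Integer, String> tables = new HashMap<>();
        for (int i = 0; i < inputIds.length; i++) {
            tables.put(inputIds[i], inputTables[i]);
        }
        List<Node> pending = new ArrayList<>(nodes);
        while (!pending.isEmpty()) {
            // Pick any stage whose input tables have all been produced already.
            Node ready = null;
            for (Node node : pending) {
                boolean available = true;
                for (int id : node.inputs) {
                    available &= tables.containsKey(id);
                }
                if (available) {
                    ready = node;
                    break;
                }
            }
            if (ready == null) {
                throw new IllegalStateException("some stage inputs can never be produced");
            }
            String[] in = new String[ready.inputs.length];
            for (int i = 0; i < in.length; i++) {
                in[i] = tables.get(ready.inputs[i]);
            }
            String[] out = ready.transform.apply(in);
            for (int i = 0; i < ready.outputs.length; i++) {
                tables.put(ready.outputs[i], out[i]);
            }
            pending.remove(ready);
        }
        String[] results = new String[outputIds.length];
        for (int i = 0; i < results.length; i++) {
            results[i] = tables.get(outputIds[i]);
        }
        return results;
    }

    static String demo() {
        GraphExecutionSketch graph = new GraphExecutionSketch();
        Function<String[], String[]> a = s -> new String[] {"A(" + s[0] + ")"};
        Function<String[], String[]> b = s -> new String[] {"B(" + s[0] + "," + s[1] + ")"};
        int in1 = graph.createTableId();
        int in2 = graph.createTableId();
        int out1 = graph.getOutputs(a, 1, in1)[0];
        int out2 = graph.getOutputs(a, 1, in2)[0];
        int out3 = graph.getOutputs(b, 1, out1, out2)[0];
        return graph.run(new int[] {in1, in2}, new int[] {out3}, "x", "y")[0];
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

Because a stage can only consume ids that were already handed out by `createTableId` or an earlier `getOutputs` call, the registration order in this sketch is itself a valid topological order; the readiness loop simply makes that dependency check explicit.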
...
- The method takes 2 input tables. The 1st input table is given to a TransformerA instance. And the 2nd input table is given to another TransformerA instance.
- An EstimatorB instance fits the output tables of these two TransformerA instances and generates a new TransformerB instance.
- Returns a GraphTransformer instance which contains 2 TransformerA instances and 1 TransformerB instance, connected using the same DAG as shown above.
...
```java
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new TransformerA();
Stage<?> stage2 = new TransformerA();
Stage<?> stage3 = new EstimatorB();

// Creates inputs
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();

// Feeds inputs to nodes and gets outputs.
TableId output1 = builder.getOutputs(stage1, input1)[0];
TableId output2 = builder.getOutputs(stage2, input2)[0];
TableId output3 = builder.getOutputs(stage3, output1, output2)[0];

// Specifies the ordered lists of inputs, outputs, input states and output states that will
// be used as the inputs/outputs of the corresponding Graph and GraphTransformer APIs.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};

// Generates the Graph instance.
Graph graph = builder.build(inputs, outputs, new TableId[]{}, new TableId[]{});

// The fit method takes 2 tables which are mapped to input1 and input2.
GraphTransformer transformer = graph.fit(...);

// The transform method takes 2 tables which are mapped to input1 and input2.
Table[] results = transformer.transform(...);
```
Online learning by running Transformer and Estimator concurrently on different machines
...
- The method takes 2 input tables. Both tables are given to EstimatorA::fit.
- EstimatorA fits the input tables and generates a TransformerA instance. The TransformerA instance takes 1 table input, which is different from the 2 tables given to the EstimatorA.
- Returns a GraphTransformer instance which contains a TransformerA instance and a TransformerB instance, which are connected as a chain.
The fitted GraphTransformer is represented by the following DAG:
Notes:
- The fitted GraphTransformer takes only 1 table as input whereas the Graph takes 2 tables as inputs.
- The proposed APIs also support composing an Estimator from a DAG of Estimators/Transformers when the composed Estimator's input schemas differ from those of its fitted Transformer.
...
```java
GraphBuilder builder = new GraphBuilder();

// Creates nodes
Stage<?> stage1 = new EstimatorA();
Stage<?> stage2 = new TransformerB();

// Creates inputs
TableId estimatorInput1 = builder.createTableId();
TableId estimatorInput2 = builder.createTableId();
TableId transformerInput1 = builder.createTableId();

// Feeds inputs to nodes and gets outputs.
TableId output1 =
        builder.getOutputs(
                stage1,
                new TableId[] {estimatorInput1, estimatorInput2},
                new TableId[] {transformerInput1})[0];
TableId output2 = builder.getOutputs(stage2, output1)[0];

// Specifies the ordered lists of estimator inputs, transformer inputs, outputs, input states and
// output states that will be used as the inputs/outputs of the corresponding Graph and
// GraphTransformer APIs.
TableId[] estimatorInputs = new TableId[] {estimatorInput1, estimatorInput2};
TableId[] transformerInputs = new TableId[] {transformerInput1};
TableId[] outputs = new TableId[] {output2};
TableId[] inputStates = new TableId[] {};
TableId[] outputStates = new TableId[] {};

// Generates the Graph instance.
Graph graph = builder.build(estimatorInputs, transformerInputs, outputs, inputStates, outputStates);

// The fit method takes 2 tables which are mapped to estimatorInput1 and estimatorInput2.
GraphTransformer transformer = graph.fit(...);

// The transform method takes 1 table which is mapped to transformerInput1.
Table[] results = transformer.transform(...);
```
Compatibility, Deprecation, and Migration Plan
...