Versions Compared

Key

  • This line was added.
  • This line was removed.
  • Formatting was changed.

...

The following changes are the most important changes proposed by this doc:

1) Added the AlgoOperator class. The AlgoOperator class has the same interface as the existing Transformer (i.e. it has the transform method).

...

Code Block
languagejava
/**
 * Base interface for a stage in a Pipeline or Graph. The interface is only a concept, and does not have any actual
 * functionality. Its subclasses could be Estimator, Transformer or AlgoOperator. No other classes should inherit this
 * interface directly.
 *
 * <p>Each stage has parameters (see WithParams) and requires a public empty constructor for restoration.
 *
 * @param <T> The class type of the Stage implementation itself.
 * @see WithParams
 */
@PublicEvolving
interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable {
    /**
     * Saves this stage to the given path.
     *
     * @param path the path to save this stage to
     */
    void save(String path);

    /**
     * Loads this stage from the given path.
     *
     * @param path the path to load this stage from
     */
    void load(String path);
}

/**
 * A MLFuncAlgoOperator is a Stage that takes a list of tables as inputs and produces a list of
 * tables as results. It can be used to encode a generic multi-input multi-output machine learning function.
 *
 * @param <T> The class type of the MLFuncAlgoOperator implementation itself.
 */
@PublicEvolving
public interface MLFunc<TAlgoOperator<T extends MLFunc<T>>AlgoOperator<T>> extends Stage<T> {
    /**
     * Applies the MLFuncAlgoOperator on the given input tables, and returns the result tables.
     *
     * @param inputs a list of tables
     * @return a list of tables
     */
    Table[] transform(Table... inputs);
}

/**
 * A Transformer is a MLFuncAlgoOperator with additional support for state streams, which could be set by the Estimator that fitted
 * this Transformer. Unlike MLFuncAlgoOperator, a Transformer is typically associated with an Estimator.
 *
 * @param <T> The class type of the Transformer implementation itself.
 */
@PublicEvolving
public interface Transformer<T extends Transformer<T>> extends MLFunc<T>AlgoOperator<T> {
    /**
     * Uses the given list of tables to update internal states. This can be useful for e.g. online
     * learning where an Estimator fits an infinite stream of training samples and streams the model
     * diff data to this Transformer.
     *
     * <p>This method may be called at most once.
     *
     * @param inputs a list of tables
     */
    default void setStateStreams(Table... inputs) {
        throw new UnsupportedOperationException("this method is not implemented");
    }

    /**
     * Gets a list of tables representing changes of internal states of this Transformer. These
     * tables might come from the Estimator that instantiated this Transformer.
     *
     * @return a list of tables
     */
    default Table[] getStateStreams() {
        throw new UnsupportedOperationException("this method is not implemented");
    }
}

/**
 * An Estimator is a Stage that takes a list of tables as inputs and produces a Transformer.
 *
 * @param <E> The class type of the Estimator implementation itself.
 * @param <M> The class type of the Transformer this Estimator produces.
 */
@PublicEvolving
public interface Estimator<E extends Estimator<E, M>, M extends Transformer<M>> extends Stage<E> {
    /**
     * Trains on the given inputs and produces a Transformer.
     *
     * @param inputs a list of tables
     * @return a Transformer
     */
    M fit(Table... inputs);
}

/**
 * A Pipeline acts as an Estimator. It consists of an ordered list of stages, each of which could be
 * an Estimator, Transformer or AlgoOperator.
 */
@PublicEvolving
public final class Pipeline implements Estimator<Pipeline, pipelineTransformer> {

    /**
     * Creates a Pipeline from the given list of stages.
     *
     * @param stages the stages of this Pipeline
     */
    public Pipeline(List<Stage<?>> stages) {...}

    /**
     * Trains on the given inputs and produces a pipelineTransformer.
     *
     * @param inputs a list of tables
     * @return a pipelineTransformer
     */
    @Override
    public pipelineTransformer fit(Table... inputs) {...}

    /** Skipped a few methods, including the implementations of the Estimator APIs. */
}

/**
 * A pipelineTransformer acts as a Transformer. It consists of an ordered list of Transformers or AlgoOperators.
 */
@PublicEvolving
public final class pipelineTransformer implements Transformer<pipelineTransformer> {

    /**
     * Creates a pipelineTransformer from the given list of transformers.
     *
     * <p>NOTE(review): the class description allows AlgoOperators as members, but this constructor
     * only accepts {@code Transformer} instances - confirm which of the two is intended.
     *
     * @param transformers the members of this pipelineTransformer
     */
    public pipelineTransformer(List<Transformer<?>> transformers) {...}

    /** Skipped a few methods, including the implementations of the Transformer APIs. */
}

...

Code Block
languagejava
/**
 * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be
 * an Estimator, Transformer or AlgoOperator. When `Graph::fit` is called, the stages are executed in a
 * topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method will be
 * called on the input tables (from the input edges) to fit a model. Then the model, which is a
 * Transformer, will be used to transform the input tables to produce output tables to the output
 * edges. If a stage is a Transformer or AlgoOperator, its `AlgoOperator::transform` method will be called on the
 * input tables to produce output tables to the output edges. The fitted model from a Graph is a
 * graphTransformer, which consists of fitted models and transformers, corresponding to the Graph's
 * stages.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, graphTransformer> {
    public Graph(...) {...}

    /**
     * Trains on the given inputs and produces a graphTransformer.
     *
     * @param inputs a list of tables
     * @return a graphTransformer
     */
    @Override
    public graphTransformer fit(Table... inputs) {...}

    /** Skipped a few methods, including the implementations of some Estimator APIs. */
}

/**
 * A graphTransformer acts as a Transformer. A graphTransformer consists of a DAG of Transformers or AlgoOperators. When
 * `graphTransformer::transform` is called, the stages are executed in a topologically-sorted order. When
 * a stage is executed, its `AlgoOperator::transform` method will be called on the input tables (from
 * the input edges) to produce output tables to the output edges.
 */
// NOTE(review): unlike the other public types in this proposal, this class carries no
// @PublicEvolving annotation - confirm whether that is intentional.
public final class graphTransformer implements Transformer<graphTransformer> {
    /** Skipped a few methods, including the implementations of the Transformer APIs. */
}

/**
 * A GraphBuilder provides APIs to build Graph and graphTransformer from a DAG of Estimator, Transformer and
 * AlgoOperator instances.
 */
@PublicEvolving
public final class GraphBuilder {
    /**
     * Specifies the upper bound (could be loose) of the number of output tables that can be
     * returned by the Transformer::getStateStreams and Transformer::transform methods, for any
     * stage involved in this Graph.
     *
     * <p>The default upper bound is 20.
     */
    public GraphBuilder setMaxOutputLength(int maxOutputLength) {...}

    /**
     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing of
     * tables between stages, as well as the input/output tables of the Graph/graphTransformer generated
     * by this builder.
     */
    public TableId createTableId() {...}

    /**
     * If the stage is an Estimator, both its fit method and the transform method of its fitted Transformer would be
     * invoked with the given inputs when the graph runs.
     *
     * <p>If this stage is a Transformer or AlgoOperator, its transform method would be invoked with the given
     * inputs when the graph runs.
     *
     * <p>Returns a list of TableIds, which represents outputs of the transform invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId... inputs) {...}

    /**
     * If this stage is an Estimator, its fit method would be invoked with estimatorInputs, and the transform method
     * of its fitted Transformer would be invoked with transformerInputs, when the graph runs.
     *
     * <p>This method throws Exception if the stage is a Transformer or AlgoOperator.
     *
     * <p>This method is useful when the stage is an Estimator AND the Estimator::fit needs to take a different list of
     * Tables from the Transformer::transform of the fitted Transformer.
     *
     * <p>Returns a list of TableIds, which represents outputs of the Transformer::transform invocation.
     */
    public TableId[] getOutputs(Stage<?> stage, TableId[] estimatorInputs, TableId[] transformerInputs) {...}

    /**
     * The graphTransformer::setStateStreams should invoke the setStateStreams of the corresponding stage
     * with the corresponding inputs.
     */
    void setStateStreams(Stage<?> stage, TableId... inputs) {...}

    /**
     * The graphTransformer::getStateStreams should invoke the getStateStreams of the corresponding stage.
     *
     * <p>Returns a list of TableIds, which represents outputs of the getStateStreams invocation.
     */
    TableId[] getStateStreams(Stage<?> stage) {...}

    /**
     * Returns a Graph instance with the following API specification:
     *
     * <p>1) Graph::fit should take inputs and return a graphTransformer with the following specification.
     *
     * <p>2) graphTransformer::transform should take inputs and return outputs.
     *
     * <p>3) graphTransformer::setStateStreams should take inputStates.
     *
     * <p>4) graphTransformer::getStateStreams should return outputStates.
     *
     * <p>The fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages in the
     * order specified by the DAG of stages.
     */
    Graph build(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns a Graph instance with the following API specification:
     *
     * <p>1) Graph::fit should take estimatorInputs and return a graphTransformer with the following specification.
     *
     * <p>2) graphTransformer::transform should take transformerInputs and return outputs.
     *
     * <p>3) graphTransformer::setStateStreams should take inputStates.
     *
     * <p>4) graphTransformer::getStateStreams should return outputStates.
     *
     * <p>The fit/transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages in the
     * order specified by the DAG of stages.
     *
     * <p>This method is useful when the Graph::fit needs to take a different list of Tables from the
     * graphTransformer::transform of the fitted graphTransformer.
     */
    Graph build(TableId[] estimatorInputs, TableId[] transformerInputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    /**
     * Returns a graphTransformer instance with the following API specification:
     *
     * <p>1) graphTransformer::transform should take inputs and return outputs.
     *
     * <p>2) graphTransformer::setStateStreams should take inputStates.
     *
     * <p>3) graphTransformer::getStateStreams should return outputStates.
     *
     * <p>The transform/setStateStreams/getStateStreams should invoke the APIs of the internal stages in the
     * order specified by the DAG of stages.
     *
     * <p>This method throws Exception if any stage of this graph is an Estimator.
     */
    graphTransformer buildTransformer(TableId[] inputs, TableId[] outputs, TableId[] inputStates, TableId[] outputStates) {...}

    // The TableId is necessary to pass the inputs/outputs of various API calls across the
    // Graph/graphTransformer stages.
    static class TableId {}

}

...