...

In this scenario, we need to allow the Estimator to be run on a different machine than the Transformer, so that the Estimator can consume sufficient computation resources in a cluster while the Transformer is deployed on edge devices.

5) Provide APIs to allow Estimator/Transformer to be efficiently saved/loaded even if the state (e.g. model data) of the Estimator/Transformer is more than 10s of GBs.

The existing PipelineStage::toJson basically requires the developer of a Transformer/Estimator to serialize all model data into an in-memory string, which could be very inefficient (or practically impossible) if the model data is very large (e.g. 10s of GBs).
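
For illustration, here is a hedged sketch of the contrast; the stage variable and path are hypothetical, and the proposed save API is detailed later in this doc:

Code Block
languagejava
// Existing API: the entire model data must be materialized as one in-memory String.
String modelJson = stage.toJson();

// Proposed API: the stage can stream its (potentially huge) state to external storage.
stage.save(remote_path);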


In addition to addressing the above use-cases, this FLIP also proposes a few more changes to simplify the class hierarchy and improve API usability. The existing Flink ML library has the following usability issues:

6) The Model interface (which currently simply extends the Transformer interface without adding any extra logic) does not provide any added value. The added class hierarchy complexity is not justified.

7) The fit/transform APIs require users to explicitly provide the TableEnvironment, even though the TableEnvironment could be retrieved from the Table instances given to fit/transform.

8) A Pipeline is both a Transformer and an Estimator. The experience of using Pipeline is therefore different from the experience of using Estimator (with the needFit API).

9) There is no API provided by the Estimator/Transformer interface to validate the schema consistency of a Pipeline. Users would have to instantiate Tables (with I/O logic) and run fit/transform to know whether the stages in the Pipeline are compatible with each other.

...

Here we list the additions and the changes to the Flink ML API.


The following changes are the most important changes proposed by this doc:

1) Updated Transformer/Estimator to take a list of tables as input and return a list of tables as output.

...

This change addresses the use-cases described in the motivation section, where a running Transformer needs to ingest the model state streams emitted by an Estimator, which could be running on a different machine.

5) Removed the methods PipelineStage::toJson and PipelineStage::loadJson. Added methods save(...) and load(...) to the Stage interface.


The following changes are relatively minor:

6) Removed TableEnvironment from the parameter list of the fit/transform APIs.

This change simplifies the usage of fit/transform APIs.
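
A hedged before/after sketch of this change (the legacy signature is paraphrased and may not match the old API exactly):

Code Block
languagejava
// Before: the caller must pass the TableEnvironment explicitly.
// Table output = transformer.transform(tEnv, input);

// After: the TableEnvironment is retrieved from the input Table instances.
Table output = transformer.transform(input)[0];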

7) Added PipelineModel and let Pipeline implement only the Estimator. Pipeline is no longer a Transformer.

This change makes the experience of using Pipeline consistent with the experience of using Estimator/Transformer, where a class is either an Estimator or a Transformer.

8) Removed Pipeline::appendStage from the Pipeline class.

This change makes the concept of Pipeline consistent with that of Graph/GraphBuilder. Neither Graph nor Pipeline provides an API to construct itself.
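
The following sketch illustrates the two changes above; EstimatorA, TransformerB and the table variables are hypothetical:

Code Block
languagejava
// Stages are given to the constructor; there is no appendStage API.
Pipeline pipeline = new Pipeline(Arrays.asList(new EstimatorA(...), new TransformerB(...)));

// Fitting a Pipeline (an Estimator) produces a PipelineModel (a Transformer).
PipelineModel pipelineModel = pipeline.fit(train_table);
Table output_table = pipelineModel.transform(input_table)[0];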

9) Removed the Model interface.

...

The following code block shows the interfaces of Stage, Transformer, Estimator, Pipeline and PipelineModel after making the proposed changes listed above.


Code Block
languagejava
@PublicEvolving
interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable {
    /**
     * This method checks the compatibility between input schemas, stage's parameters and stage's
     * logic. It should raise an exception if there is any mismatch, e.g. the number of input
     * schemas is wrong, or if a required field is missing from a schema.
     *
     * <p>If there is no mismatch, the method derives and returns the output schemas from the input
     * schemas.
     *
     * <p>Note that the output schemas of a given Estimator instance should equal the output schemas
     * of the Transformer instance fitted by this Estimator instance, assuming the same list of
     * input schemas is used as input to the fit/transform methods respectively.
     *
     * @param schemas the list of schemas of the input tables.
     * @return the list of schemas of the output tables.
     */
    TableSchema[] transformSchemas(TableSchema... schemas);

    /**
     * Saves this stage to the given path.
     */
    void save(String path);

    /**
     * Loads this stage from the given path.
     */
    void load(String path);
}

@PublicEvolving
public interface Transformer<T extends Transformer<T>> extends Stage<T> {

    /**
     * Applies the Transformer on the given input tables, and returns the result tables.
     *
     * @param inputs a list of tables
     * @return a list of tables
     */
    Table[] transform(Table... inputs);

    /**
     * Uses the given list of tables to update internal states. This can be useful for e.g. online
     * learning where an Estimator fits an infinite stream of training samples and streams the model
     * diff data to this Transformer.
     *
     * <p>This method may be called at most once.
     *
     * @param inputs a list of tables
     */
    default void setStateStreams(Table... inputs) {
        throw new UnsupportedOperationException("this method is not implemented");
    }

    /**
     * Gets a list of tables representing changes of internal states of this Transformer. These
     * tables might come from the Estimator that instantiated this Transformer.
     *
     * @return a list of tables
     */
    default Table[] getStateStreams() {
        throw new UnsupportedOperationException("this method is not implemented");
    }
}

@PublicEvolving
public interface Estimator<E extends Estimator<E, M>, M extends Transformer<M>> extends Stage<E> {

    /**
     * Trains on the given inputs and produces a Transformer. If this Estimator may be used to
     * compose a Pipeline, the transform method of the returned Transformer should be able to accept
     * a list of tables with the same length and schemas as that given to the fit method of this
     * Estimator.
     *
     * @param inputs a list of tables
     * @return a Transformer
     */
    M fit(Table... inputs);
}

@PublicEvolving
public final class Pipeline implements Estimator<Pipeline, PipelineModel> {

    public Pipeline(List<Stage<?>> stages) {...}

    @Override
    public PipelineModel fit(Table... inputs) {...}

    /** Skipped a few methods, including the implementations of the Estimator APIs. */
}

@PublicEvolving
public final class PipelineModel implements Transformer<PipelineModel> {

    public PipelineModel(List<Transformer<?>> transformers) {...}

    /** Skipped a few methods, including the implementations of the Transformer APIs. */
}
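
To make the contracts above concrete, here is a hedged sketch of a trivial Transformer written against the proposed interfaces; the class is hypothetical, and the save/load and WithParams methods are skipped in the same style as above:

Code Block
languagejava
// A hypothetical Transformer that passes its single input table through unchanged.
public class IdentityTransformer implements Transformer<IdentityTransformer> {

    @Override
    public TableSchema[] transformSchemas(TableSchema... schemas) {
        // Raises an exception if the number of input schemas is wrong.
        if (schemas.length != 1) {
            throw new IllegalArgumentException("expected exactly one input table");
        }
        // For this no-op Transformer, the output schema equals the input schema.
        return schemas;
    }

    @Override
    public Table[] transform(Table... inputs) {
        // Returns the input tables unchanged.
        return inputs;
    }

    /** Skipped the save/load implementations and the methods inherited from WithParams. */
}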

...

Code Block
languagejava
void runTrainingOnClusterA(...) {
  // Creates the training stream from a Kafka topic.
  Table training_stream = ...;

  Estimator estimator = new EstimatorA(...);
  Transformer transformer = estimator.fit(training_stream);
  Table state_stream = transformer.getStateStreams()[0];

  // Writes the state_stream to Kafka topicA.
  state_stream.sinkTable(...);
  // Saves the transformer's state/metadata to a remote path.
  transformer.save(remote_path);

  // Executes the operators generated by Estimator::fit(...), which reads from training_stream and writes to state_stream.
  env.execute();
}

...

Code Block
languagejava
void runInferenceOnWebServer(...) {

  // Creates the state stream from Kafka topicA which is written by the above code snippet. 
  Table state_stream = ...;
  // Creates the input stream that needs inference.
  Table input_stream = ...;

  Transformer transformer = new Transformer(...);
  // Loads the transformer's state/metadata from the same remote path written by the above code snippet.
  transformer.load(remote_path);
  transformer.setStateStreams(new Table[]{state_stream});
  Table output_stream = transformer.transform(input_stream);

  // Do something with the output_stream.

  // Executes the operators generated by Transformer::transform(...), which reads from state_stream to update its parameters.
  // It also does inference on input_stream and produces results to the output_stream.
  env.execute();
}

...