Status
Current state: Accepted
Discussion thread: https://lists.apache.org/thread.html/r6729f351fb1bc13a93754c199d5fee1051cc8146e22374737c578779%40%3Cdev.flink.apache.org%3E
Voting thread: https://lists.apache.org/thread/2087m6t1d58lw3xngtwhws6xbd9fm30r
JIRA:
Released: Not released yet.
Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).
[This FLIP proposal is a joint work between Dong Lin and Zhipeng Zhang]
Motivation and Use-cases
The existing Flink ML library allows users to compose an Estimator/Transformer from a pipeline (i.e., a linear sequence) of Estimators/Transformers, where each Estimator/Transformer has one input and one output.
The following use-cases are not supported yet. We would like to address them with the changes proposed in this FLIP.
1) Express an Estimator/Transformer that has multiple inputs/outputs.
For example, some graph embedding algorithms (e.g., MetaPath2Vec) need to take two tables as inputs. These two tables represent the node labels and the edges of the graph respectively. This logic can be expressed as an Estimator with 2 input tables.
And some workflows may need to split 1 table into 2 tables, and use these tables for training and validation respectively. This logic can be expressed by a Transformer with 1 input table and 2 output tables.
2) Express a generic machine learning computation logic that does not have the "transformation" semantic.
We believe most machine learning engineers associate the name "Transformer" with the "transformation" semantic, where a record in the output typically corresponds to one record in the input. Thus it is counter-intuitive to use Transformer to encode aggregation logic, where a record in the output corresponds to an arbitrary number of records in the input.
Therefore we need to have a class with a name different from "Transformer" to encode generic multi-input multi-output computation logic.
3) Online learning where a long-running Model instance needs to be continuously updated by the latest model data generated by another long-running Estimator instance.
In this scenario, we need to allow the Estimator to be run on a different machine than the Model, so that the Estimator could consume sufficient computation resource in a cluster while the Model could be deployed on edge devices.
4) Provide APIs to allow Estimator/Model to be efficiently saved/loaded even if state (e.g. model data) of Estimator/Model is more than 10s of GBs.
The existing PipelineStage::toJson basically requires the developer of an Estimator/Model to serialize all model data into an in-memory string, which could be very inefficient (or practically impossible) if the model data is very large (e.g., 10s of GBs).
In addition to addressing the above use-cases, this FLIP also proposes a few more changes to simplify the class hierarchy and improve API usability. The existing Flink ML library has the following usability issues:
5) The fit/transform API requires users to explicitly provide the TableEnvironment, even though the TableEnvironment could be retrieved from the Table instances given to fit/transform.
6) A Pipeline is currently both a Transformer and an Estimator. The experience of using Pipeline is inconsistent with the experience of using an Estimator (with the needFit API).
Public Interfaces
This FLIP proposes quite a few changes and additions to the existing Flink ML APIs. We first describe the proposed API additions and changes, followed by the API code of interfaces and classes after making the proposed changes.
API additions and changes
Here we list the additions and the changes to the Flink ML API.
The following changes are the most important changes proposed by this doc:
1) Added the AlgoOperator class. The AlgoOperator class has the same interface as the existing Transformer (i.e., it provides the transform(...) API).
This change addresses the need to encode a generic multi-input multi-output machine learning function.
2) Updated the fit/transform methods to take a list of tables as input and return a list of tables as output.
This change addresses the use-cases described in the motivation section, e.g. a graph embedding Estimator needs to take 2 tables as inputs.
3) Added setModelData and getModelData to the Model interface.
This change addresses the use-cases described in the motivation section, where a long-running Model instance needs to ingest the model state streams emitted by an Estimator, which could be running on a different machine.
4) Removed the methods PipelineStage::toJson and PipelineStage::loadJson. Added methods save(...) and load(...) to the Stage interface.
This change addresses the need to efficiently save/load a Model instance even if its model data is very large.
The following changes are relatively minor:
5) Removed TableEnvironment from the parameter list of fit/transform APIs.
This change simplifies the usage of fit/transform APIs.
6) Added PipelineModel and made Pipeline implement only Estimator. Pipeline is no longer a Transformer.
This change makes the experience of using Pipeline consistent with the experience of using Estimator.
7) Removed Pipeline::appendStage from the Pipeline class.
8) Renamed PipelineStage to Stage and added the PublicEvolving tag to the Stage interface.
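To make change 4 concrete, here is a minimal, self-contained sketch of the proposed save/load convention: an instance method save(path) plus a static load(path) factory on each concrete subclass. All class and parameter names here (ToyStage, ToyNormalizer) are hypothetical stand-ins of our own invention; the real Stage::save persists Flink-specific state rather than a single string, and the real static load method also takes a StreamExecutionEnvironment, which this sketch omits.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Toy stand-in for the proposed Stage interface: instances persist themselves
// via save(path), and each concrete subclass provides a static load(path) factory.
interface ToyStage {
    void save(String path) throws IOException;
}

// A hypothetical stage whose only state is one string parameter.
final class ToyNormalizer implements ToyStage {
    private final String param;

    ToyNormalizer(String param) {
        this.param = param;
    }

    String getParam() {
        return param;
    }

    @Override
    public void save(String path) throws IOException {
        // Writes the stage's parameters to the given path.
        Files.writeString(Path.of(path), param);
    }

    // Mirrors the NOTE in the Stage javadoc: the static load method instantiates
    // a new stage instance based on the data previously written by save(...).
    static ToyNormalizer load(String path) throws IOException {
        return new ToyNormalizer(Files.readString(Path.of(path)));
    }
}

class SaveLoadDemo {
    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("stage-demo");
        String path = dir.resolve("toy-normalizer").toString();

        new ToyNormalizer("p=0.5").save(path);
        ToyNormalizer restored = ToyNormalizer.load(path);
        System.out.println(restored.getParam()); // prints "p=0.5"
    }
}
```

Because load is static rather than an interface method, each subclass can return its own concrete type, and loading never requires an already-constructed instance.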
Interfaces and classes after the proposed API changes
The following code block shows the interfaces Stage, Estimator, Model, Transformer and AlgoOperator, as well as the Pipeline and PipelineModel classes, after making the changes listed above.
/**
 * Base class for a node in a Pipeline or Graph. The interface is only a concept, and does not have
 * any actual functionality. Its subclasses could be Estimator, Model, Transformer or AlgoOperator.
 * No other classes should inherit this interface directly.
 *
 * <p>Each stage is with parameters, and requires a public empty constructor for restoration.
 *
 * @param <T> The class type of the Stage implementation itself.
 */
@PublicEvolving
public interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable {
    /** Saves this stage to the given path. */
    void save(String path) throws IOException;

    // NOTE: every Stage subclass should implement a static method with signature "static T
    // load(StreamExecutionEnvironment env, String path)", where T refers to the concrete
    // subclass. This static method should instantiate a new stage instance based on the data
    // read from the given path.
}

/**
 * An AlgoOperator takes a list of tables as inputs and produces a list of tables as results. It can
 * be used to encode generic multi-input multi-output computation logic.
 *
 * @param <T> The class type of the AlgoOperator implementation itself.
 */
@PublicEvolving
public interface AlgoOperator<T extends AlgoOperator<T>> extends Stage<T> {
    /**
     * Applies the AlgoOperator on the given input tables and returns the result tables.
     *
     * @param inputs a list of tables
     * @return a list of tables
     */
    Table[] transform(Table... inputs);
}

/**
 * A Transformer is an AlgoOperator with the semantic difference that it encodes the Transformation
 * logic, such that a record in the output typically corresponds to one record in the input. In
 * contrast, an AlgoOperator is a better fit to express aggregation logic where a record in the
 * output could be computed from an arbitrary number of records in the input.
 *
 * @param <T> The class type of the Transformer implementation itself.
 */
@PublicEvolving
public interface Transformer<T extends Transformer<T>> extends AlgoOperator<T> {}

/**
 * A Model is typically generated by invoking {@link Estimator#fit(Table...)}. A Model is a
 * Transformer with the extra APIs to set and get model data.
 *
 * @param <T> The class type of the Model implementation itself.
 */
@PublicEvolving
public interface Model<T extends Model<T>> extends Transformer<T> {
    /**
     * Sets model data using the given list of tables. Each table could be an unbounded stream of
     * model data changes.
     *
     * @param inputs a list of tables
     */
    default T setModelData(Table... inputs) {
        throw new UnsupportedOperationException("this operation is not supported");
    }

    /**
     * Gets a list of tables representing the model data. Each table could be an unbounded stream of
     * model data changes.
     *
     * @return a list of tables
     */
    default Table[] getModelData() {
        throw new UnsupportedOperationException("this operation is not supported");
    }
}

/**
 * Estimators are responsible for training and generating Models.
 *
 * @param <E> class type of the Estimator implementation itself.
 * @param <M> class type of the Model this Estimator produces.
 */
@PublicEvolving
public interface Estimator<E extends Estimator<E, M>, M extends Model<M>> extends Stage<E> {
    /**
     * Trains on the given inputs and produces a Model.
     *
     * @param inputs a list of tables
     * @return a Model
     */
    M fit(Table... inputs);
}

/**
 * A PipelineModel acts as a Model. It consists of an ordered list of stages, each of which could be
 * a Model, Transformer or AlgoOperator.
 */
@PublicEvolving
public final class PipelineModel implements Model<PipelineModel> {
    public PipelineModel(List<Stage<?>> stages) {...}

    public static PipelineModel load(String path) throws IOException {...}

    /** Skipped a few methods, including the implementations of the Model APIs. */
}

/**
 * A Pipeline acts as an Estimator. It consists of an ordered list of stages, each of which could be
 * an Estimator, Model, Transformer or AlgoOperator.
 */
@PublicEvolving
public final class Pipeline implements Estimator<Pipeline, PipelineModel> {
    public Pipeline(List<Stage<?>> stages) {...}

    public static Pipeline load(String path) throws IOException {...}

    /** Skipped a few methods, including the implementations of the Estimator APIs. */
}
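The self-referencing generic parameter in the interfaces above (e.g., T extends Stage<T>) lets a method like setModelData(...) return the concrete subclass rather than a base type, so subclass-specific calls can be chained. A minimal self-contained sketch of the pattern, using String as a placeholder for Flink's Table and toy class names of our own invention:

```java
// Toy stand-ins for the proposed interfaces; String replaces Flink's Table so the
// sketch is runnable without any Flink dependency.
interface ToyStageIface<T extends ToyStageIface<T>> {}

interface ToyModelIface<T extends ToyModelIface<T>> extends ToyStageIface<T> {
    // Default implementations throw, matching the proposed Model interface.
    default T setModelData(String... inputs) {
        throw new UnsupportedOperationException("this operation is not supported");
    }

    default String[] getModelData() {
        throw new UnsupportedOperationException("this operation is not supported");
    }
}

// A toy Model implementation that stores its model data in memory.
final class ToyModel implements ToyModelIface<ToyModel> {
    private String[] modelData = new String[0];

    @Override
    public ToyModel setModelData(String... inputs) {
        this.modelData = inputs;
        // Returns ToyModel, not ToyModelIface<ToyModel>, thanks to the self type.
        return this;
    }

    @Override
    public String[] getModelData() {
        return modelData;
    }
}

class SelfTypeDemo {
    public static void main(String[] args) {
        // setModelData returns the concrete ToyModel, so no cast is needed.
        ToyModel model = new ToyModel().setModelData("model-data-table");
        System.out.println(model.getModelData()[0]); // prints "model-data-table"
    }
}
```

The same pattern is why Estimator carries two type parameters: E for the estimator itself and M for the concrete Model type that fit(...) returns.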
Example Usage
In this section we provide example code snippets to demonstrate how the APIs proposed in this FLIP can be used to address the use-cases in the motivation section.
Online learning by running Transformer and Estimator concurrently on different machines
Here is an online learning scenario:
- We have an unbounded stream of tagged data that can be used for training.
- We have an algorithm that can be trained using this unbounded stream of data. This algorithm (with its latest states/parameters) can be used to do inference. And the accuracy of the algorithm increases with the increasing amount of training data it has seen.
- We would like to train this algorithm using the unbounded data stream on clusterA, and use this algorithm with the up-to-date states/parameters to do inference on 10 different web servers.
In order to address this use-case, we can write the training and inference logic of this algorithm into an EstimatorA class and a ModelA class with the following API behaviors:
- EstimatorA::fit takes a table as input and returns an instance of ModelA. Before fit() returns this ModelA, it calls ModelA.setModelData(model_data), where model_data represents the stream of algorithm parameter changes emitted by EstimatorA.
- ModelA::setModelData(...) takes a table as input. Its implementation reads the data from this table to continuously update its algorithm parameters.
- ModelA::getModelData(...) returns the same table instance that has been provided via ModelA::setModelData(...).
- ModelA::transform takes a table as input and returns a table. The returned table represents the inference results.
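The handshake described by the four behaviors above can be sketched end to end with toy, in-memory stand-ins. ToyEstimatorA and ToyModelA below are placeholder classes of our own invention, and a mutable List<String> stands in for an unbounded Table of model-data changes; the real classes would operate on Flink Tables and the "stream" would keep flowing after fit() returns.

```java
import java.util.ArrayList;
import java.util.List;

// Toy model: a mutable list stands in for an unbounded table of model-data changes.
class ToyModelA {
    private List<String> modelData = new ArrayList<>();

    // Mirrors ModelA::setModelData: keeps a reference to the stream, so updates
    // appended later are observed by subsequent inference calls.
    ToyModelA setModelData(List<String> modelData) {
        this.modelData = modelData;
        return this;
    }

    // Mirrors ModelA::getModelData: returns the same "table" instance that was set.
    List<String> getModelData() {
        return modelData;
    }

    // Mirrors ModelA::transform: inference parameterized by the latest model data.
    String transform(String input) {
        String latest = modelData.isEmpty() ? "<none>" : modelData.get(modelData.size() - 1);
        return input + " scored with " + latest;
    }
}

class ToyEstimatorA {
    // Mirrors EstimatorA::fit: hands the model the stream of parameter changes
    // before returning it. (Actual training on trainingStream is elided.)
    ToyModelA fit(List<String> trainingStream) {
        List<String> modelDataStream = new ArrayList<>();
        modelDataStream.add("params-v1"); // initial parameters from training
        return new ToyModelA().setModelData(modelDataStream);
    }
}

class OnlineLearningDemo {
    public static void main(String[] args) {
        ToyModelA model = new ToyEstimatorA().fit(List.of("sample-1"));
        System.out.println(model.transform("query")); // prints "query scored with params-v1"

        // A later training update flows through the shared model-data stream,
        // so the next inference call picks it up without reloading the model.
        model.getModelData().add("params-v2");
        System.out.println(model.transform("query")); // prints "query scored with params-v2"
    }
}
```

In the real setting, the estimator and the model run on different machines, so the shared list is replaced by a Kafka-backed Table, as the snippets below show.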
Here are the code snippets that address this use-case by using the proposed APIs.
First run the following code on clusterA:
void runTrainingOnClusterA(...) {
    // Creates the training stream from a Kafka topic.
    Table training_stream = ...;

    Estimator estimator = new EstimatorA(...);
    Model model = estimator.fit(training_stream);
    Table model_data = model.getModelData()[0];

    // This method writes the data from the given Table to a Kafka topic.
    writeToKafka(model_data, "topicA");

    // Saves model's state/metadata to a remote path.
    model.save(remote_path);

    // Executes the operators generated by Estimator::fit(...), which reads from
    // training_stream and writes to model_data.
    env.execute();
}
Then run the following code on each web server:
void runInferenceOnWebServer(...) {
    // This method creates a Table from a Kafka topic, where the data in this topic is
    // generated by the above code snippet.
    Table model_data = readFromKafka("topicA");

    // Creates the input stream that needs inference.
    Table input_stream = ...;

    Model model = ModelA.load(remote_path);
    model.setModelData(new Table[]{model_data});
    Table output_stream = model.transform(input_stream)[0];
    // Do something with the output_stream.

    // Executes the operators generated by Model::transform(...), which reads from
    // model_data to update its parameters. It also does inference on input_stream
    // and produces results to the output_stream.
    env.execute();
}
Compatibility, Deprecation, and Migration Plan
The changes proposed in this FLIP are backward incompatible with the existing APIs. We propose to change the APIs directly without a deprecation period, and we will help some existing Flink ecosystem open source projects to migrate to the proposed APIs.
Note that there is no implementation of Estimator/Transformer (excluding test-only implementations) in the existing Flink codebase. So no work is needed to migrate the existing Flink codebase.
To the best of our knowledge, the only open source project that uses the Flink ML API is https://github.com/alibaba/Alink. We will work together with Alink developers to migrate the existing code to the proposed API. Furthermore, we will migrate Alink's Estimator/Transformer implementations to the Flink ML library codebase as much as possible.
Test Plan
We will provide unit tests to validate the proposed changes.
Rejected Alternatives
There are no rejected alternatives to be listed here yet.