...
To give a complete picture, the table below lists all key Java interfaces and the corresponding Python interfaces to be added.
Interface type | Java Interface Name | Python Interface Name | Description |
ML core interface | PipelineStage | PipelineStage | The base node of a Pipeline |
ML core interface | Transformer | Transformer | Native Python interface |
ML core interface | Transformer | JavaTransformer | Python wrapper for calling the Java interface |
ML core interface | Estimator | Estimator | Native Python interface |
ML core interface | Estimator | JavaEstimator | Python wrapper for calling the Java interface |
ML core interface | Model | Model | Native Python interface |
ML core interface | Model | JavaModel | Python wrapper for calling the Java interface |
ML Pipeline | Pipeline | Pipeline | Describes an ML workflow |
ML environment | MLEnvironment | MLEnvironment | Stores the necessary context in Flink |
ML environment | MLEnvironmentFactory | MLEnvironmentFactory | Factory to get the MLEnvironment |
Helper interface | Params | Params | A container of parameters |
Helper interface | ParamInfo | ParamInfo | Definition of a parameter |
Helper interface | WithParams | WithParams | Common interface for interacting with classes that have parameters |
Support native Python Transformer/Estimator/Model
...
To support both cases, we provide two kinds of interfaces for Transformer/Estimator/Model: one for Java wrappers and the other for native Python APIs.
Java Interface Name | Python Interface Name |
Transformer | Transformer |
Transformer | JavaTransformer |
Estimator | Estimator |
Estimator | JavaEstimator |
Model | Model |
Model | JavaModel |
Below, we take Transformer as an example to show what the interfaces look like and how to implement each of the two kinds.
```python
from abc import ABCMeta, abstractmethod

from pyflink.table import Table


class Transformer(PipelineStage, metaclass=ABCMeta):
    """
    A transformer is a PipelineStage that transforms an input Table to a result Table.
    """

    @abstractmethod
    def transform(self, table_env, table):
        raise NotImplementedError()


class JavaTransformer(Transformer):
    """
    Base class for :py:class:`Transformer`s that wrap Java implementations.
    Subclasses should ensure they have the transformer Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def transform(self, table_env, table):
        # _convert_params_to_java (not shown in this proposal) is assumed to copy
        # the Python-side params onto the Java object before delegating.
        self._convert_params_to_java(self._j_obj)
        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))
```
To write a native Python Transformer, extend the Transformer class and implement the transform() method. For a Java wrapper, extend the JavaTransformer class and pass the Java object to its constructor. In this case, you don't have to implement the transform() method, because JavaTransformer already implements it by delegating to the wrapped Java object.
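The two styles can be sketched side by side without a JVM. This is a minimal sketch under stated assumptions: `Transformer` and `JavaTransformer` are simplified stand-ins (no PipelineStage, no params, no Py4J handle unwrapping), a plain list of strings stands in for a Flink Table, and `UpperCaseTransformer`, `FakeJavaObject`, and `ReverseTransformer` are hypothetical names used only for illustration.

```python
from abc import ABCMeta, abstractmethod


class Transformer(metaclass=ABCMeta):
    """Simplified stand-in for the proposed Transformer interface."""

    @abstractmethod
    def transform(self, table_env, table):
        raise NotImplementedError()


class JavaTransformer(Transformer):
    """Simplified stand-in: delegates transform() to the wrapped object."""

    def __init__(self, j_obj):
        self._j_obj = j_obj

    def transform(self, table_env, table):
        # The proposal also converts params and unwraps _j_tenv/_j_table;
        # both steps are omitted in this stand-in.
        return self._j_obj.transform(table_env, table)


# Style 1: a native Python Transformer implements transform() itself.
class UpperCaseTransformer(Transformer):  # hypothetical example stage
    def transform(self, table_env, table):
        # A list of strings stands in for a Flink Table.
        return [row.upper() for row in table]


# Style 2: a Java wrapper only hands the Java object to the constructor.
class FakeJavaObject:  # hypothetical stand-in for a Py4J JavaObject
    def transform(self, j_tenv, j_table):
        return [row[::-1] for row in j_table]


class ReverseTransformer(JavaTransformer):  # hypothetical wrapper
    def __init__(self):
        super().__init__(FakeJavaObject())


print(UpperCaseTransformer().transform(None, ["ab", "cd"]))  # ['AB', 'CD']
print(ReverseTransformer().transform(None, ["ab", "cd"]))    # ['ba', 'dc']
```

Because transform() is abstract, instantiating `Transformer` directly raises a TypeError, while the wrapper subclass needs no transform() of its own.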
...
ML core interface
PipelineStage
```python
class PipelineStage(WithParams):
    """
    Base class for a stage in a pipeline.
    """

    def __init__(self, params=None):
        self._params = Params() if params is None else params

    def get_params(self):
        return self._params
```
Transformer
```python
from abc import ABCMeta, abstractmethod


class Transformer(PipelineStage, metaclass=ABCMeta):
    """
    A transformer is a PipelineStage that transforms an input Table to a result Table.
    """

    @abstractmethod
    def transform(self, table_env, table):
        """
        Applies the transformer on the input table, and returns the result table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table to be transformed
        :returns: the transformed table
        """
        raise NotImplementedError()
```
JavaTransformer
```python
from pyflink.table import Table


class JavaTransformer(Transformer):
    """
    Base class for :py:class:`Transformer`s that wrap Java implementations.
    Subclasses should ensure they have the transformer Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def transform(self, table_env, table):
        """
        Applies the transformer on the input table, and returns the result table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table to be transformed
        :returns: the transformed table
        """
        # Push the Python-side params to the Java object, then delegate to Java.
        self._convert_params_to_java(self._j_obj)
        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))
```
Estimator
```python
from abc import ABCMeta, abstractmethod


class Estimator(PipelineStage, metaclass=ABCMeta):
    """
    Estimators are PipelineStages responsible for training and generating machine
    learning models. The implementations are expected to take an input table as
    training samples and generate a Model which fits these samples.
    """

    @abstractmethod
    def fit(self, table_env, table):
        """
        Train and produce a Model which fits the records in the given Table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table with records to train the Model.
        :returns: a model trained to fit on the given Table.
        """
        raise NotImplementedError()
```
JavaEstimator
```python
class JavaEstimator(Estimator):
    """
    Base class for :py:class:`Estimator`s that wrap Java implementations.
    Subclasses should ensure they have the estimator Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def fit(self, table_env, table):
        """
        Train and produce a Model which fits the records in the given Table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table with records to train the Model.
        :returns: a model trained to fit on the given Table.
        """
        self._convert_params_to_java(self._j_obj)
        # Wrap the Java model returned by the Java estimator in a JavaModel.
        return JavaModel(self._j_obj.fit(table_env._j_tenv, table._j_table))
```
Model
```python
from abc import ABCMeta


class Model(Transformer, metaclass=ABCMeta):
    """
    Abstract class for models that are fitted by estimators.

    A model is an ordinary Transformer except for how it is created. While ordinary
    transformers are defined by specifying the parameters directly, a model is
    usually generated by an Estimator when Estimator.fit(table_env, table) is invoked.
    """
```
JavaModel
```python
class JavaModel(JavaTransformer, Model):
    """
    Base class for :py:class:`Model`s that wrap Java implementations.
    Subclasses should ensure they have the model Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__(j_obj)
```
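The delegation chain between JavaEstimator, JavaModel, and JavaTransformer can be sketched without a JVM. This is a minimal sketch under stated assumptions: the stand-in classes drop PipelineStage, params, and Py4J handle unwrapping, lists of floats stand in for Tables, and `FakeJavaMeanEstimator`/`FakeJavaMeanModel` are hypothetical objects playing the role of Java objects reached through Py4J.

```python
from abc import ABCMeta, abstractmethod


class Transformer(metaclass=ABCMeta):
    @abstractmethod
    def transform(self, table_env, table): ...


class Estimator(metaclass=ABCMeta):
    @abstractmethod
    def fit(self, table_env, table): ...


class Model(Transformer):
    pass


class JavaTransformer(Transformer):
    def __init__(self, j_obj):
        self._j_obj = j_obj

    def transform(self, table_env, table):
        return self._j_obj.transform(table_env, table)


class JavaModel(JavaTransformer, Model):
    def __init__(self, j_obj):
        super().__init__(j_obj)


class JavaEstimator(Estimator):
    def __init__(self, j_obj):
        self._j_obj = j_obj

    def fit(self, table_env, table):
        # Whatever model object the "Java" side returns is wrapped as a JavaModel.
        return JavaModel(self._j_obj.fit(table_env, table))


# Hypothetical stand-ins for Java objects; they center values on a mean.
class FakeJavaMeanModel:
    def __init__(self, mean):
        self.mean = mean

    def transform(self, j_tenv, j_table):
        return [x - self.mean for x in j_table]


class FakeJavaMeanEstimator:
    def fit(self, j_tenv, j_table):
        return FakeJavaMeanModel(sum(j_table) / len(j_table))


model = JavaEstimator(FakeJavaMeanEstimator()).fit(None, [1.0, 2.0, 3.0])
print(type(model).__name__, model.transform(None, [4.0, 5.0]))  # JavaModel [2.0, 3.0]
```

With the base order `(JavaTransformer, Model)`, the method resolution order is JavaModel, JavaTransformer, Model, Transformer, so `transform()` resolves to the delegating JavaTransformer implementation while the object still passes `isinstance(model, Model)` checks.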
ML Pipeline
```python
class Pipeline(Estimator, Model):
    """
    A pipeline is a linear workflow which chains Estimators and Transformers to
    execute an algorithm.
    """

    def __init__(self, stages=None):
        super().__init__()
        self.stages = []
        self.last_estimator_index = -1
        # Go through append_stage() so that last_estimator_index also reflects
        # stages passed to the constructor.
        for stage in stages or []:
            self.append_stage(stage)

    def _need_fit(self):
        return self.last_estimator_index >= 0

    @staticmethod
    def _is_stage_need_fit(stage):
        # A nested Pipeline needs fitting only if it contains an Estimator;
        # any other Estimator always needs fitting.
        return (isinstance(stage, Pipeline) and stage._need_fit()) or \
            ((not isinstance(stage, Pipeline)) and isinstance(stage, Estimator))

    def append_stage(self, stage):
        if self._is_stage_need_fit(stage):
            self.last_estimator_index = len(self.stages)
        elif not isinstance(stage, Transformer):
            raise RuntimeError("All PipelineStages should be Estimator or Transformer!")
        self.stages.append(stage)
        return self

    def fit(self, t_env, input):
        """
        Train the pipeline to fit on the records in the given Table.

        :param t_env: the table environment to which the input table is bound.
        :param input: the table with records to train the Pipeline.
        :returns: a pipeline with same stages as this Pipeline except all Estimators
                  replaced with their corresponding Models.
        """
        transform_stages = []
        for i, s in enumerate(self.stages):
            if i <= self.last_estimator_index:
                if self._is_stage_need_fit(s):
                    t = s.fit(t_env, input)
                else:
                    t = s
                transform_stages.append(t)
                # Feed the output of this stage to the next stage's fit().
                input = t.transform(t_env, input)
            else:
                transform_stages.append(s)
        return Pipeline(transform_stages)

    def transform(self, t_env, input):
        """
        Generate a result table by applying all the stages in this pipeline to
        the input table in order.

        :param t_env: the table environment to which the input table is bound.
        :param input: the table to be transformed.
        :returns: a result table with all the stages applied to the input tables in order.
        """
        if self._need_fit():
            raise RuntimeError("Pipeline contains Estimator, need to fit first.")
        for s in self.stages:
            input = s.transform(t_env, input)
        return input
```
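The fit loop above can be traced with toy stages. This is a simplified sketch, assuming lists of numbers in place of Tables and ignoring nested Pipelines; `AddConstant`, `CenterEstimator`, `fit_pipeline`, and `transform_pipeline` are hypothetical names that mirror the logic of `Pipeline.fit` and `Pipeline.transform`, not part of the proposed API.

```python
class Estimator:
    """Stand-in: fit() returns a Transformer."""


class Transformer:
    """Stand-in: transform() maps a 'table' (a list here) to a new one."""


class AddConstant(Transformer):  # hypothetical transformer
    def __init__(self, c):
        self.c = c

    def transform(self, t_env, table):
        return [x + self.c for x in table]


class CenterEstimator(Estimator):  # hypothetical estimator
    def fit(self, t_env, table):
        # The fitted "model" subtracts the mean of its training input.
        return AddConstant(-sum(table) / len(table))


def fit_pipeline(stages, t_env, table):
    """Same loop as Pipeline.fit: fit every Estimator up to the last one,
    feeding each stage the output of the previous fitted stage."""
    last = max((i for i, s in enumerate(stages) if isinstance(s, Estimator)),
               default=-1)
    fitted = []
    for i, s in enumerate(stages):
        if i <= last:
            t = s.fit(t_env, table) if isinstance(s, Estimator) else s
            fitted.append(t)
            table = t.transform(t_env, table)
        else:
            # Stages after the last Estimator never touch the training data.
            fitted.append(s)
    return fitted


def transform_pipeline(stages, t_env, table):
    for s in stages:
        table = s.transform(t_env, table)
    return table


stages = [AddConstant(10), CenterEstimator(), AddConstant(1)]
model = fit_pipeline(stages, None, [0, 1, 2])  # mean seen by the estimator is 11
print(transform_pipeline(model, None, [5]))    # 5 + 10 - 11 + 1 -> [5.0]
```

The estimator is fitted on the output of the stage before it ([10, 11, 12]), not on the raw input, which is why `fit()` transforms the table as it walks the stages.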
ML environment
MLEnvironmentFactory
```python
import threading

from pyflink.dataset import ExecutionEnvironment
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.java_gateway import get_gateway
from pyflink.table import BatchTableEnvironment, StreamTableEnvironment


class MLEnvironmentFactory:
    """
    Factory to get the MLEnvironment using a MLEnvironmentId.
    """
    _lock = threading.RLock()
    _default_ml_environment_id = 0
    _next_id = 1
    _map = {}

    # Wrap the JVM-side default MLEnvironment as the Python default. This assumes
    # the Java MLEnvironmentFactory class is visible in the gateway's JVM view.
    gateway = get_gateway()
    j_ml_env = gateway.jvm.MLEnvironmentFactory.getDefault()
    _default_ml_env = MLEnvironment(
        ExecutionEnvironment(j_ml_env.getExecutionEnvironment()),
        StreamExecutionEnvironment(j_ml_env.getStreamExecutionEnvironment()),
        BatchTableEnvironment(j_ml_env.getBatchTableEnvironment()),
        StreamTableEnvironment(j_ml_env.getStreamTableEnvironment()))
    _map[_default_ml_environment_id] = _default_ml_env

    @staticmethod
    def get(ml_env_id):
        """
        Get the MLEnvironment using a MLEnvironmentId.

        :param ml_env_id: the MLEnvironmentId
        :return: the MLEnvironment
        """
        with MLEnvironmentFactory._lock:
            if ml_env_id not in MLEnvironmentFactory._map:
                raise ValueError(
                    "Cannot find MLEnvironment for MLEnvironmentId %s. "
                    "Did you get the MLEnvironmentId by calling "
                    "get_new_ml_environment_id?" % ml_env_id)
            return MLEnvironmentFactory._map[ml_env_id]

    @staticmethod
    def get_default():
        """
        Get the MLEnvironment using the default MLEnvironmentId.

        :return: the default MLEnvironment.
        """
        with MLEnvironmentFactory._lock:
            return MLEnvironmentFactory._map[MLEnvironmentFactory._default_ml_environment_id]

    @staticmethod
    def get_new_ml_environment_id():
        """
        Create a unique MLEnvironment id and register a new MLEnvironment in the factory.

        :return: the MLEnvironment id.
        """
        with MLEnvironmentFactory._lock:
            return MLEnvironmentFactory.register_ml_environment(MLEnvironment())

    @staticmethod
    def register_ml_environment(ml_environment):
        """
        Register a new MLEnvironment to the factory and return a new MLEnvironment id.

        :param ml_environment: the MLEnvironment that will be stored in the factory.
        :return: the MLEnvironment id.
        """
        with MLEnvironmentFactory._lock:
            MLEnvironmentFactory._map[MLEnvironmentFactory._next_id] = ml_environment
            MLEnvironmentFactory._next_id += 1
            return MLEnvironmentFactory._next_id - 1

    @staticmethod
    def remove(ml_env_id):
        """
        Remove the MLEnvironment using the MLEnvironmentId.

        :param ml_env_id: the id.
        :return: the removed MLEnvironment
        """
        with MLEnvironmentFactory._lock:
            if ml_env_id is None:
                raise ValueError("The environment id cannot be null.")
            # Never remove the default MLEnvironment. Just return the default environment.
            # Compare ids here, not the default environment object against an id.
            if MLEnvironmentFactory._default_ml_environment_id == ml_env_id:
                return MLEnvironmentFactory.get_default()
            else:
                return MLEnvironmentFactory._map.pop(ml_env_id)
```
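The factory is essentially a lock-protected id-to-environment registry whose default entry can never be removed. A minimal sketch of that pattern, assuming plain Python objects in place of PyFlink environments; `EnvRegistry` is a hypothetical name and its methods only mirror `register_ml_environment`, `get`, and `remove`.

```python
import threading


class EnvRegistry:
    """Thread-safe id -> environment registry mirroring MLEnvironmentFactory.
    Environments are plain objects here instead of PyFlink environments."""

    DEFAULT_ID = 0

    def __init__(self, default_env):
        self._lock = threading.RLock()
        self._next_id = 1
        self._map = {self.DEFAULT_ID: default_env}

    def register(self, env):
        # Allocating the id and storing the env happen under one lock, so
        # concurrent callers can never receive the same id.
        with self._lock:
            env_id = self._next_id
            self._map[env_id] = env
            self._next_id += 1
            return env_id

    def get(self, env_id):
        with self._lock:
            if env_id not in self._map:
                raise ValueError("Cannot find environment for id %s" % env_id)
            return self._map[env_id]

    def remove(self, env_id):
        with self._lock:
            if env_id is None:
                raise ValueError("The environment id cannot be None.")
            # Compare ids (not environments) so the default is never removed.
            if env_id == self.DEFAULT_ID:
                return self._map[self.DEFAULT_ID]
            return self._map.pop(env_id)


registry = EnvRegistry(default_env="default")
ids = []
threads = [threading.Thread(target=lambda: ids.append(registry.register(object())))
           for _ in range(100)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(set(ids)))       # 100
print(registry.remove(0))  # default
print(registry.get(0))     # default
```

An `RLock` (rather than a plain `Lock`) matters in the proposed factory because `get_new_ml_environment_id` calls `register_ml_environment` while already holding the lock.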
MLEnvironment
```python
from pyflink.dataset import ExecutionEnvironment
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import BatchTableEnvironment, StreamTableEnvironment


class MLEnvironment(object):
    """
    The MLEnvironment stores the necessary context in Flink. Each MLEnvironment
    will be associated with a unique ID. The operations associated with the same
    MLEnvironment ID will share the same Flink job context.

    Both MLEnvironment ID and MLEnvironment can only be retrieved from
    MLEnvironmentFactory.
    """

    def __init__(self, exe_env=None, stream_exe_env=None, batch_tab_env=None,
                 stream_tab_env=None):
        self._exe_env = exe_env
        self._stream_exe_env = stream_exe_env
        self._batch_tab_env = batch_tab_env
        self._stream_tab_env = stream_tab_env

    def get_execution_environment(self):
        if self._exe_env is None:
            self._exe_env = ExecutionEnvironment.get_execution_environment()
        return self._exe_env

    def get_stream_execution_environment(self):
        if self._stream_exe_env is None:
            self._stream_exe_env = StreamExecutionEnvironment.get_execution_environment()
        return self._stream_exe_env

    def get_batch_table_environment(self):
        if self._batch_tab_env is None:
            # Build on this MLEnvironment's own execution environment so that
            # both share the same Flink job context.
            self._batch_tab_env = BatchTableEnvironment.create(
                self.get_execution_environment())
        return self._batch_tab_env

    def get_stream_table_environment(self):
        if self._stream_tab_env is None:
            self._stream_tab_env = StreamTableEnvironment.create(
                self.get_stream_execution_environment())
        return self._stream_tab_env
```
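The lazy getters create each environment on first access and cache it, and a table environment should be built on the MLEnvironment's own execution environment rather than a fresh one. A minimal sketch of that behavior, assuming hypothetical stand-ins `FakeExecutionEnvironment` and `FakeTableEnvironment` in place of the PyFlink classes; `LazyMLEnvironment` covers only the batch pair for brevity.

```python
class FakeExecutionEnvironment:  # hypothetical stand-in for ExecutionEnvironment
    pass


class FakeTableEnvironment:  # hypothetical stand-in for BatchTableEnvironment
    def __init__(self, exec_env):
        self.exec_env = exec_env


class LazyMLEnvironment:
    """Create each environment on first access and cache it, so every
    operation using this object shares one job context."""

    def __init__(self, exe_env=None, batch_tab_env=None):
        self._exe_env = exe_env
        self._batch_tab_env = batch_tab_env

    def get_execution_environment(self):
        if self._exe_env is None:
            self._exe_env = FakeExecutionEnvironment()
        return self._exe_env

    def get_batch_table_environment(self):
        if self._batch_tab_env is None:
            # Reuse this object's execution environment (possibly the one
            # passed to __init__) instead of creating an unrelated one.
            self._batch_tab_env = FakeTableEnvironment(self.get_execution_environment())
        return self._batch_tab_env


custom = FakeExecutionEnvironment()
env = LazyMLEnvironment(exe_env=custom)
t_env = env.get_batch_table_environment()
print(t_env.exec_env is custom)                    # True
print(env.get_batch_table_environment() is t_env)  # True
```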
Params interface
Params
```python
class Params(object):
    """
    The map-like container class for parameter. This class is provided to unify
    the interaction with parameters.
    """

    def __init__(self):
        self._param_map = {}

    def set(self, k, v):
        self._param_map[k] = v

    def get(self, k):
        return self._param_map[k]
```
ParamInfo
```python
class ParamInfo(object):
    """
    Definition of a parameter, including name, description, type_converter and so on.
    """

    def __init__(self, name, description, type_converter=None):
        self.name = str(name)
        self.description = str(description)
        # TypeConverters (not shown in this proposal) is assumed to provide an
        # identity converter used when no type_converter is given.
        self.type_converter = TypeConverters.identity if type_converter is None else type_converter
```
WithParams
```python
class WithParams(object):
    """
    Parameters are widely used in machine learning realm. This class defines a
    common interface to interact with classes with parameters.
    """

    def get_params(self):
        # Implemented by subclasses, e.g. PipelineStage.
        raise NotImplementedError()

    def set(self, k, v):
        self.get_params().set(k, v)
        return self

    def get(self, k):
        return self.get_params().get(k)

    def _set(self, **kwargs):
        """
        Sets user-supplied params.
        """
        for param, value in kwargs.items():
            # Each keyword names a ParamInfo attribute defined on the class.
            p = getattr(self, param)
            if value is not None:
                try:
                    value = p.type_converter(value)
                except TypeError as e:
                    raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
            self.get_params().set(p, value)
        return self
```
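Params, ParamInfo, and WithParams together let a stage expose typed, chainable setters such as `set_k()` in the example below. A minimal sketch of that wiring, assuming a hypothetical stage `KMeansLike` with a single `k` parameter and a hypothetical `to_int` converter; the real algorithm classes and the `TypeConverters` helpers are not specified in this proposal.

```python
class Params:
    def __init__(self):
        self._param_map = {}

    def set(self, k, v):
        self._param_map[k] = v

    def get(self, k):
        return self._param_map[k]


class ParamInfo:
    def __init__(self, name, description, type_converter=None):
        self.name = str(name)
        self.description = str(description)
        self.type_converter = (lambda v: v) if type_converter is None else type_converter


def to_int(value):
    """Hypothetical converter: accept ints and integral floats only."""
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    raise TypeError("Could not convert %r to int" % (value,))


class WithParams:
    def get_params(self):
        raise NotImplementedError()

    def get(self, k):
        return self.get_params().get(k)

    def _set(self, **kwargs):
        for param, value in kwargs.items():
            p = getattr(self, param)  # the ParamInfo class attribute
            if value is not None:
                try:
                    value = p.type_converter(value)
                except TypeError as e:
                    raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
            self.get_params().set(p, value)
        return self


class KMeansLike(WithParams):  # hypothetical stage
    k = ParamInfo("k", "number of clusters", type_converter=to_int)

    def __init__(self):
        self._params = Params()

    def get_params(self):
        return self._params

    def set_k(self, value):
        # Returning self from _set() is what makes setters chainable.
        return self._set(k=value)


stage = KMeansLike().set_k(2.0)
print(stage.get(KMeansLike.k))  # 2
```

Params are keyed by the ParamInfo objects themselves, so `get()` takes the class attribute (`KMeansLike.k`), and invalid values such as `set_k("two")` fail early with a TypeError naming the parameter.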
Example
```python
# VectorAssembler and KMeans are algorithm stages; they are not defined in this document.
train_table = t_env.from_path('trainingSource')
serving_table = t_env.from_path('servingSource')

# transformer
va = VectorAssembler(selected_cols=["a", "b"], output_col="features")

# estimator
kmeans = KMeans() \
    .set_vector_col("features") \
    .set_k(2) \
    .set_reserved_cols(["a", "b"]) \
    .set_prediction_col("prediction_result") \
    .set_max_iter(100)

# pipeline
pipeline = Pipeline().append_stage(va).append_stage(kmeans)
pipeline \
    .fit(t_env, train_table) \
    .transform(t_env, serving_table) \
    .insert_into('mySink')

t_env.execute('KmeansTest')
```
Implementation Plan
- Align the Python interfaces for MLEnvironment and MLEnvironmentFactory with their Java counterparts
- Add support for native Python Transformer/Estimator/Model
- Add support for Transformer/Estimator/Model Java wrappers