...

To give a complete picture, the table below lists all key Java interfaces and the corresponding Python interfaces to be added.

Interface type    | Java Interface Name  | Python Interface Name | Description
------------------|----------------------|-----------------------|--------------------------------------------------------------
ML core interface | PipelineStage        | PipelineStage         | The base node of a Pipeline
ML core interface | Transformer          | Transformer           | Native Python interface
ML core interface | Transformer          | JavaTransformer       | Python wrapper for calling the Java interface
ML core interface | Estimator            | Estimator             | Native Python interface
ML core interface | Estimator            | JavaEstimator         | Python wrapper for calling the Java interface
ML core interface | Model                | Model                 | Native Python interface
ML core interface | Model                | JavaModel             | Python wrapper for calling the Java interface
ML Pipeline       | Pipeline             | Pipeline              | Describes an ML workflow
ML environment    | MLEnvironment        | MLEnvironment         | Stores the necessary context in Flink
ML environment    | MLEnvironmentFactory | MLEnvironmentFactory  | Factory to get the MLEnvironment
Helper interface  | Params               | Params                | A container of parameters
Helper interface  | ParamInfo            | ParamInfo             | Definition of a parameter
Helper interface  | WithParams           | WithParams            | Common interface for interacting with classes that have parameters

Support native Python Transformer/Estimator/Model

...

To support both of these cases, we provide two kinds of interfaces for Transformer/Estimator/Model: one for Java wrappers and the other for native Python implementations.

Java Interface Name | Python Interface Names
--------------------|------------------------------
Transformer         | Transformer, JavaTransformer
Estimator           | Estimator, JavaEstimator
Model               | Model, JavaModel

Below, we take Transformer as an example to show what the interfaces look like and how to implement both kinds.

class Transformer(PipelineStage):
    """
    A transformer is a PipelineStage that transforms an input Table to a result Table.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def transform(self, table_env, table):
        raise NotImplementedError()


class JavaTransformer(Transformer):
    """
    Base class for :py:class:`Transformer`s that wrap Java implementations.
    Subclasses should ensure they have the transformer Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def transform(self, table_env, table):
        self._convert_params_to_java(self._j_obj)
        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))

To write a native Python Transformer, extend the Transformer class and implement the transform() method. For a Java wrapper, extend the JavaTransformer class and pass the Java object to the constructor; in that case you do not have to implement the transform() method, because it is already implemented by the base class. A hypothetical example of both approaches is sketched below.
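The following sketch is illustrative only: ColumnSelector, MyJavaTransformer, and the Java class path org.example.ml.MyTransformer are assumptions, not part of this proposal; only the Transformer and JavaTransformer base classes above come from the design.

from pyflink.java_gateway import get_gateway


class ColumnSelector(Transformer):
    """A hypothetical native Python Transformer that keeps only the selected columns."""

    def __init__(self, selected_cols):
        super().__init__()
        self._selected_cols = selected_cols

    def transform(self, table_env, table):
        # Implemented purely in Python on top of the Table API.
        return table.select(",".join(self._selected_cols))


class MyJavaTransformer(JavaTransformer):
    """A hypothetical wrapper around an existing Java Transformer."""

    def __init__(self):
        # The Java class path is a placeholder for illustration only.
        j_obj = get_gateway().jvm.org.example.ml.MyTransformer()
        super().__init__(j_obj)
        # transform() is inherited from JavaTransformer and needs no override.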

...

ML core interface

PipelineStage

class PipelineStage(WithParams):
    """
    Base class for a stage in a pipeline.
    """

    def __init__(self, params=None):
        if params is None:
            self._params = Params()
        else:
            self._params = params

    def get_params(self):
        return self._params


Transformer

class Transformer(PipelineStage):
    """
    A transformer is a PipelineStage that transforms an input Table to a result Table.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def transform(self, table_env, table):
        """
        Applies the transformer on the input table, and returns the result table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table to be transformed.
        :returns: the transformed table.
        """
        raise NotImplementedError()


JavaTransformer

class JavaTransformer(Transformer):
    """
    Base class for :py:class:`Transformer`s that wrap Java implementations.
    Subclasses should ensure they have the transformer Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def transform(self, table_env, table):
        """
        Applies the transformer on the input table, and returns the result table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table to be transformed.
        :returns: the transformed table.
        """
        self._convert_params_to_java(self._j_obj)
        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))


Estimator

class Estimator(PipelineStage):
    """
    Estimators are PipelineStages responsible for training and generating machine learning models.

    The implementations are expected to take an input table as training samples and generate a
    Model which fits these samples.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def fit(self, table_env, table):
        """
        Train and produce a Model which fits the records in the given Table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table with records to train the Model.
        :returns: a model trained to fit on the given Table.
        """
        raise NotImplementedError()


JavaEstimator

class JavaEstimator(Estimator):
    """
    Base class for :py:class:`Estimator`s that wrap Java implementations.
    Subclasses should ensure they have the estimator Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def fit(self, table_env, table):
        """
        Train and produce a Model which fits the records in the given Table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table with records to train the Model.
        :returns: a model trained to fit on the given Table.
        """
        self._convert_params_to_java(self._j_obj)
        return JavaModel(self._j_obj.fit(table_env._j_tenv, table._j_table))


Model

class Model(Transformer):
    """
    Abstract class for models that are fitted by estimators.

    A model is an ordinary Transformer except for how it is created. While ordinary transformers
    are defined by specifying the parameters directly, a model is usually generated by an Estimator
    when Estimator.fit(table_env, table) is invoked.
    """

    __metaclass__ = ABCMeta


JavaModel

class JavaModel(JavaTransformer, Model):
    """
    Base class for :py:class:`Model`s that wrap Java implementations.
    Subclasses should ensure they have the model Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__(j_obj)
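To make the split between native Python classes and Java wrappers concrete, here is a hedged sketch of a native Python Estimator together with the Model it produces. MyEstimator, MyModel, and the placeholder logic inside fit()/transform() are illustrative assumptions only; a real implementation would derive and apply actual model state.

class MyEstimator(Estimator):
    """A hypothetical native Python Estimator (illustrative only)."""

    def fit(self, table_env, table):
        # A real implementation would learn model state from the training
        # table here (for example by running a Flink job). This sketch simply
        # hands the training table to the model as its state.
        return MyModel(model_data=table)


class MyModel(Model):
    """A hypothetical native Python Model produced by MyEstimator."""

    def __init__(self, model_data):
        super().__init__()
        self._model_data = model_data

    def transform(self, table_env, table):
        # A real implementation would apply the learned state to the input;
        # as a placeholder, the input table is returned unchanged.
        return table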

ML Pipeline

class Pipeline(Estimator, Model):
    """
    A pipeline is a linear workflow which chains Estimators and Transformers to
    execute an algorithm.
    """

    def __init__(self, stages=None):
        super().__init__()
        self.stages = []
        if stages is not None:
            self.stages = stages
        self.last_estimator_index = -1

    def _need_fit(self):
        return self.last_estimator_index >= 0

    @staticmethod
    def _is_stage_need_fit(stage):
        return (isinstance(stage, Pipeline) and stage._need_fit()) or \
               ((not isinstance(stage, Pipeline)) and isinstance(stage, Estimator))

    def append_stage(self, stage):
        if self._is_stage_need_fit(stage):
            self.last_estimator_index = len(self.stages)
        elif not isinstance(stage, Transformer):
            raise RuntimeError("All PipelineStages should be Estimator or Transformer!")
        self.stages.append(stage)
        return self

    def fit(self, t_env, input):
        """
        Train the pipeline to fit on the records in the given Table.

        :param t_env: the table environment to which the input table is bound.
        :param input: the table with records to train the Pipeline.
        :returns: a pipeline with same stages as this Pipeline except all Estimators
                  replaced with their corresponding Models.
        """
        transform_stages = []
        for i in range(0, len(self.stages)):
            s = self.stages[i]
            if i <= self.last_estimator_index:
                need_fit = self._is_stage_need_fit(s)
                if need_fit:
                    t = s.fit(t_env, input)
                else:
                    t = s
                transform_stages.append(t)
                input = t.transform(t_env, input)
            else:
                transform_stages.append(s)
        return Pipeline(transform_stages)

    def transform(self, t_env, input):
        """
        Generate a result table by applying all the stages in this pipeline to the input table in order.

        :param t_env: the table environment to which the input table is bound.
        :param input: the table to be transformed.
        :returns: a result table with all the stages applied to the input tables in order.
        """
        if self._need_fit():
            raise RuntimeError("Pipeline contains Estimator, need to fit first.")
        for s in self.stages:
            input = s.transform(t_env, input)
        return input
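As a hedged illustration of the fit/transform contract above: fitting a pipeline that mixes Transformers and Estimators returns a new Pipeline in which every Estimator has been replaced by its fitted Model, and only that fitted pipeline may be used for transform(). The ColumnSelector and MyEstimator classes come from the sketches earlier in this document; t_env, train_table and serving_table are assumed to be a table environment and two input Tables.

# Assumes the ColumnSelector and MyEstimator sketches from the sections above.
pipeline = Pipeline() \
    .append_stage(ColumnSelector(["a", "b"])) \
    .append_stage(MyEstimator())

# fit() trains MyEstimator and returns a Pipeline containing only Transformers/Models.
fitted = pipeline.fit(t_env, train_table)

# The fitted pipeline no longer contains Estimators, so transform() is allowed.
result = fitted.transform(t_env, serving_table)

# Calling transform() on the unfitted pipeline would raise RuntimeError,
# because it still contains an Estimator.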

ML environment 

MLEnvironmentFactory

class MLEnvironmentFactory:
    """
    Factory to get the MLEnvironment using a MLEnvironmentId.
    """

    _lock = threading.RLock()
    _default_ml_environment_id = 0
    _next_id = 1
    _map = {}

    gateway = get_gateway()
    j_ml_env = gateway.jvm.MLEnvironmentFactory.getDefault()
    _default_ml_env = MLEnvironment(
        ExecutionEnvironment(j_ml_env.getExecutionEnvironment()),
        StreamExecutionEnvironment(j_ml_env.getStreamExecutionEnvironment()),
        BatchTableEnvironment(j_ml_env.getBatchTableEnvironment()),
        StreamTableEnvironment(j_ml_env.getStreamTableEnvironment()))
    _map[_default_ml_environment_id] = _default_ml_env

    @staticmethod
    def get(ml_env_id):
        """
        Get the MLEnvironment using a MLEnvironmentId.

        :param ml_env_id: the MLEnvironmentId
        :return: the MLEnvironment
        """
        with MLEnvironmentFactory._lock:
            if ml_env_id not in MLEnvironmentFactory._map:
                raise ValueError(
                    "Cannot find MLEnvironment for MLEnvironmentId %s. "
                    "Did you get the MLEnvironmentId by calling "
                    "get_new_ml_environment_id?" % ml_env_id)
            return MLEnvironmentFactory._map[ml_env_id]

    @staticmethod
    def get_default():
        """
        Get the MLEnvironment using the default MLEnvironmentId.

        :return: the default MLEnvironment.
        """
        with MLEnvironmentFactory._lock:
            return MLEnvironmentFactory._map[MLEnvironmentFactory._default_ml_environment_id]

    @staticmethod
    def get_new_ml_environment_id():
        """
        Create a unique MLEnvironment id and register a new MLEnvironment in the factory.

        :return: the MLEnvironment id.
        """
        with MLEnvironmentFactory._lock:
            return MLEnvironmentFactory.register_ml_environment(MLEnvironment())

    @staticmethod
    def register_ml_environment(ml_environment):
        """
        Register a new MLEnvironment to the factory and return a new MLEnvironment id.

        :param ml_environment: the MLEnvironment that will be stored in the factory.
        :return: the MLEnvironment id.
        """
        with MLEnvironmentFactory._lock:
            MLEnvironmentFactory._map[MLEnvironmentFactory._next_id] = ml_environment
            MLEnvironmentFactory._next_id += 1
            return MLEnvironmentFactory._next_id - 1

    @staticmethod
    def remove(ml_env_id):
        """
        Remove the MLEnvironment using the MLEnvironmentId.

        :param ml_env_id: the id.
        :return: the removed MLEnvironment
        """
        with MLEnvironmentFactory._lock:
            if ml_env_id is None:
                raise ValueError("The environment id cannot be null.")
            # Never remove the default MLEnvironment. Just return the default environment.
            if MLEnvironmentFactory._default_ml_environment_id == ml_env_id:
                return MLEnvironmentFactory.get_default()
            else:
                return MLEnvironmentFactory._map.pop(ml_env_id)
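A brief, hedged usage sketch of the factory methods defined above; the variable names are illustrative.

# Obtain the shared default environment (id 0).
default_env = MLEnvironmentFactory.get_default()

# Create and register a fresh environment when a separate job context is needed.
ml_env_id = MLEnvironmentFactory.get_new_ml_environment_id()
ml_env = MLEnvironmentFactory.get(ml_env_id)

# Clean up when the environment is no longer needed; the default
# environment can never be removed.
MLEnvironmentFactory.remove(ml_env_id)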


MLEnvironment

class MLEnvironment(object):
    """
    The MLEnvironment stores the necessary context in Flink. Each MLEnvironment
    will be associated with a unique ID. The operations associated with the same
    MLEnvironment ID will share the same Flink job context. Both MLEnvironment
    ID and MLEnvironment can only be retrieved from MLEnvironmentFactory.
    """

    def __init__(self, exe_env=None, stream_exe_env=None, batch_tab_env=None, stream_tab_env=None):
        self._exe_env = exe_env
        self._stream_exe_env = stream_exe_env
        self._batch_tab_env = batch_tab_env
        self._stream_tab_env = stream_tab_env

    def get_execution_environment(self):
        if self._exe_env is None:
            self._exe_env = ExecutionEnvironment.get_execution_environment()
        return self._exe_env

    def get_stream_execution_environment(self):
        if self._stream_exe_env is None:
            self._stream_exe_env = StreamExecutionEnvironment.get_execution_environment()
        return self._stream_exe_env

    def get_batch_table_environment(self):
        if self._batch_tab_env is None:
            self._batch_tab_env = BatchTableEnvironment.create(ExecutionEnvironment.get_execution_environment())
        return self._batch_tab_env

    def get_stream_table_environment(self):
        if self._stream_tab_env is None:
            self._stream_tab_env = StreamTableEnvironment.create(StreamExecutionEnvironment.get_execution_environment())
        return self._stream_tab_env
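The environments are created lazily by the getters above, but a caller can also pre-build an MLEnvironment from explicitly created environments and register it with the factory. The sketch below is a hedged illustration assuming the standard PyFlink environment classes are imported as shown; exact import paths may differ by Flink version.

from pyflink.dataset import ExecutionEnvironment
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import BatchTableEnvironment, StreamTableEnvironment

# Build the underlying Flink environments explicitly ...
exec_env = ExecutionEnvironment.get_execution_environment()
stream_env = StreamExecutionEnvironment.get_execution_environment()
batch_t_env = BatchTableEnvironment.create(exec_env)
stream_t_env = StreamTableEnvironment.create(stream_env)

# ... wrap them in an MLEnvironment and register it so that other stages
# can look it up by id and share the same job context.
custom_env = MLEnvironment(exec_env, stream_env, batch_t_env, stream_t_env)
custom_env_id = MLEnvironmentFactory.register_ml_environment(custom_env)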


Params interface

Params

class Params(object):
    """
    The map-like container class for parameters. This class is provided to unify the interaction
    with parameters.
    """

    def __init__(self):
        self._paramMap = {}

    def set(self, k, v):
        self._paramMap[k] = v

    def get(self, k):
        return self._paramMap[k]


ParamInfo

class ParamInfo(object):
    """
    Definition of a parameter, including name, description, type_converter and so on.
    """

    def __init__(self, name, description, type_converter=None):
        self.name = str(name)
        self.description = str(description)
        self.type_converter = TypeConverters.identity if type_converter is None else type_converter


WithParams

class WithParams(object):
    """
    Parameters are widely used in the machine learning realm. This class defines a common interface
    to interact with classes that have parameters.
    """

    def get_params(self):
        pass

    def set(self, k, v):
        self.get_params().set(k, v)
        return self

    def get(self, k):
        return self.get_params().get(k)

    def _set(self, **kwargs):
        """
        Sets user-supplied params.
        """
        for param, value in kwargs.items():
            p = getattr(self, param)
            if value is not None:
                try:
                    value = p.type_converter(value)
                except TypeError as e:
                    raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
            self.get_params().set(p, value)
        return self
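A hedged sketch of how Params, ParamInfo and WithParams might be used together on a stage. The HasK mixin, the 'k' parameter, and the use of the built-in int as a type converter are illustrative assumptions, not part of the proposal.

class HasK(WithParams):
    """A hypothetical parameter mixin that declares a 'k' parameter."""

    # ParamInfo instances serve as the keys in the Params container.
    k = ParamInfo("k", "The number of clusters to create.", type_converter=int)

    def set_k(self, value):
        return self._set(k=value)

    def get_k(self):
        return self.get(self.k)


class MyStage(PipelineStage, HasK):
    """A hypothetical stage combining PipelineStage storage with the HasK mixin."""
    pass


stage = MyStage()
stage.set_k("2")           # the type converter turns the string "2" into the int 2
assert stage.get_k() == 2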

Example

train_table = t_env.from_path('trainingSource')
serving_table = t_env.from_path('servingSource')

# transformer
va = VectorAssembler(selected_cols=["a", "b"], output_col="features")

# estimator
kmeans = KMeans()\
    .set_vector_col("features")\
    .set_k(2)\
    .set_reserved_cols(["a", "b"])\
    .set_prediction_col("prediction_result")\
    .set_max_iter(100)

# pipeline
pipeline = Pipeline().append_stage(va).append_stage(kmeans)
pipeline\
    .fit(t_env, train_table)\
    .transform(t_env, serving_table)\
    .insert_into('mySink')

t_env.execute('KmeansTest')


Implementation Plan

  • Align the interfaces of MLEnvironment and MLEnvironmentFactory
  • Add support for native Python Transformer/Estimator/Model
  • Add support for Transformer/Estimator/Model Java wrappers