Status

Current state: "Under Discussion"

Discussion thread: here (<- link to http://apache-flink-mailing-list-archive.1008284.n3.nabble.com/DISCUSS-FLIP-139-General-Python-User-Defined-Aggregate-Function-on-Table-API-td44139.html)

JIRA: here (<- link to https://issues.apache.org/jira/browse/FLINK-XXXX)

Released: <Flink Version>

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

Motivation

FLIP-58 introduced stateless Python UDFs, which have been supported since previous releases. However, stateful Python UDFs, i.e. user-defined aggregate functions (UDAFs), are not yet supported in PyFlink. This FLIP introduces Python user-defined aggregate functions for the PyFlink Table API.

Goals

  1. Support running Python UDAF in streaming/batch mode with the blink planner
  2. Support accessing the value state/list state/map state remotely from the Python side
  3. Support FILTER and DISTINCT keyword on Python UDAF
  4. Support DataView(ListView and MapView)


Non-Goals

  1. Support optimizations such as hash agg, local/global, and minibatch
  2. Support using Python UDAF in group windows
  3. Support using Python UDAF in over windows
  4. Support mixed use of Java UDAF and Python UDAF
  5. Support Python user-defined table aggregate functions


NOTE: Although these features will not be supported in the first version, it doesn’t mean that they cannot be supported according to the design proposed in this FLIP. 

Architecture

Executing Python user-defined aggregate functions is similar to executing the stateless functions described in FLIP-58. The key difference is that aggregate functions need to access state, which can be implemented via the state service of the current design.

Public Interfaces

AggregateFunction

To define an aggregate function, users need to implement such an interface:

class AggregateFunction(UserDefinedFunction):

   """

   Base interface for user-defined aggregate function.

   .. versionadded:: 1.12.0

   """

   def open(self, function_context: FunctionContext):

       pass

   

   def get_value(self, accumulator):

       pass

   def create_accumulator(self):

       pass

   def accumulate(self, accumulator, *args):

       pass

   def retract(self, accumulator, *args):

       pass

   def merge(self, accumulator, accumulators):

       pass

   def reset_accumulator(self, accumulator):

       pass

   def close(self):

       pass

   def get_result_type(self):

       pass

   def get_accumulator_type(self):

       pass


The semantics of each method above are aligned with the AggregateFunction class of the Java API. Depending on actual usage, not all methods need to be implemented. For example, the "merge" method is optional if the local-global optimization is not enabled.

Examples:

class CountAggregateFunction(AggregateFunction):

   def get_value(self, accumulator):

       return accumulator[0]

   def create_accumulator(self):

       return [0]

   def accumulate(self, accumulator, *args):

       accumulator[0] = accumulator[0] + 1

   def retract(self, accumulator, *args):

       accumulator[0] = accumulator[0] - 1

   def merge(self, accumulator, accumulators):

       for other_acc in accumulators:

           accumulator[0] = accumulator[0] + other_acc[0]

   def reset_accumulator(self, accumulator):

       accumulator[0] = 0

   def get_accumulator_type(self):

       return DataTypes.ARRAY(DataTypes.BIGINT())

   def get_result_type(self):

       return DataTypes.BIGINT()


t_env.register_function("my_count", CountAggregateFunction())

t = t_env.from_elements([(1, 'Hi', 'Hello'),

                                         (3, 'Hi', 'hi'),

                                         (3, 'Hi2', 'hi'),

                                         (3, 'Hi', 'hi'),

                                         (2, 'Hi', 'Hello')], ['a', 'b', 'c'])

t.group_by("c").select("a.my_count, c as b")
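To make the method semantics concrete, the following is a minimal plain-Python sketch of how a runtime might drive the methods of the count example. It does not use PyFlink. The `run_group_aggregate` driver and the `(is_insert, key, value)` record shape are assumptions made for illustration only and are not part of the proposed API.

```python
class CountAggregateFunction:
    """Plain-Python stand-in mirroring the count example (no PyFlink needed)."""

    def create_accumulator(self):
        return [0]

    def accumulate(self, accumulator, *args):
        accumulator[0] += 1

    def retract(self, accumulator, *args):
        accumulator[0] -= 1

    def merge(self, accumulator, accumulators):
        # Each partial accumulator is itself a one-element list.
        for other_acc in accumulators:
            accumulator[0] += other_acc[0]

    def get_value(self, accumulator):
        return accumulator[0]


def run_group_aggregate(agg, changelog):
    """Hypothetical driver: applies (is_insert, key, value) records per key."""
    accumulators = {}
    for is_insert, key, value in changelog:
        if key not in accumulators:
            accumulators[key] = agg.create_accumulator()
        if is_insert:
            agg.accumulate(accumulators[key], value)
        else:
            agg.retract(accumulators[key], value)
    return {key: agg.get_value(acc) for key, acc in accumulators.items()}


changelog = [
    (True, "Hello", 1),
    (True, "hi", 3),
    (True, "hi", 3),
    (True, "Hello", 2),
    (False, "hi", 3),  # retraction of an earlier row
]
result = run_group_aggregate(CountAggregateFunction(), changelog)

# merge combines partial accumulators, e.g. for a local/global optimization.
merged = [2]
CountAggregateFunction().merge(merged, [[3], [4]])
```

In this sketch the accumulators live in a local dict. In the proposed streaming design they would be kept in Flink state per key.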

DataView

DataView is an important feature in the Java UDAF. The interfaces are as follows:

T = TypeVar('T')

K = TypeVar('K')

V = TypeVar('V')

class ListView(Generic[T]):

   def get(self) -> Iterable[T]:

       pass

   def add(self, value: T) -> None:

       pass

   def add_all(self, values: List[T]) -> None:

       pass

   def clear(self) -> None:

       pass

   def __eq__(self, other: Any) -> bool:

       pass

   def __hash__(self) -> int:

       pass

   def __iter__(self) -> Iterator[T]:

       return iter(self.get())


class MapView(Generic[K, V]):

   def get(self, key: K) -> V:

       pass

   def put(self, key: K, value: V) -> None:

       pass

   def put_all(self, dict_value: Dict[K, V]) -> None:

       pass

   def remove(self, key: K) -> None:

       pass

   def contains(self, key: K) -> bool:

       pass

   def items(self) -> Iterable[Tuple[K, V]]:

       pass

   def keys(self) -> Iterable[K]:

       pass

   def values(self) -> Iterable[V]:

       pass

   def is_empty(self) -> bool:

       pass

   def clear(self) -> None:

       pass

   def __eq__(self, other: Any) -> bool:

       pass

   def __hash__(self) -> int:

       pass

   def __getitem__(self, key: K) -> V:

       return self.get(key)

   def __setitem__(self, key: K, value: V) -> None:

       self.put(key, value)

   def __delitem__(self, key: K) -> None:

       self.remove(key)

   def __contains__(self, key: K) -> bool:

       return self.contains(key)

   def __iter__(self) -> Iterator[K]:

       return iter(self.keys())


Note that the ListView does not have the "remove" method, which is used to remove a single element from the ListState. This operation is too expensive for Python State. It is equivalent to get + modify + clear + add_all.
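The get + modify + clear + add_all sequence mentioned above can be sketched as follows. `InMemoryListView` and `remove_first` are hypothetical names used only for this illustration. The stand-in keeps its data in a local list, whereas a state-backed ListView would turn each of these steps into state requests, which is why the operation is considered too expensive.

```python
from typing import Generic, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


class InMemoryListView(Generic[T]):
    """Hypothetical in-memory stand-in for the ListView interface."""

    def __init__(self) -> None:
        self._data: List[T] = []

    def get(self) -> Iterable[T]:
        return list(self._data)

    def add(self, value: T) -> None:
        self._data.append(value)

    def add_all(self, values: List[T]) -> None:
        self._data.extend(values)

    def clear(self) -> None:
        self._data.clear()

    def __iter__(self) -> Iterator[T]:
        return iter(self.get())


def remove_first(view: InMemoryListView, value) -> bool:
    """Emulates "remove" using only the methods ListView offers."""
    items = list(view.get())   # get: read the whole list
    if value not in items:
        return False
    items.remove(value)        # modify: drop the first match locally
    view.clear()               # clear: wipe the stored list
    view.add_all(items)        # add_all: write the remainder back
    return True


view = InMemoryListView()
view.add_all([1, 2, 3, 2])
removed = remove_first(view, 2)
```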

As PyFlink does not support user-defined types yet, the way to support DataView in PyFlink is different from Java. Users need to declare a special type field for MapView/ListView in the accumulator type:

# MapView:

DataTypes.MAP_VIEW(DataTypes.BIGINT(), DataTypes.BIGINT())

# ListView:

DataTypes.LIST_VIEW(DataTypes.BIGINT())


The example is as follows:

class CountDistinctAggregateFunction(AggregateFunction):

   def get_value(self, accumulator):

       return accumulator["count"]

   def create_accumulator(self):

       return Row(count=0, map=MapView())

   def accumulate(self, accumulator, *args):

       if args[0] in accumulator["map"]:

           accumulator["map"][args[0]] += 1

       else:

           accumulator["count"] += 1

           accumulator["map"][args[0]] = 1

   def retract(self, accumulator, *args):

       if args[0] in accumulator["map"]:

           accumulator["map"][args[0]] -= 1

           if accumulator["map"][args[0]] <= 0:

               del accumulator["map"][args[0]]

               accumulator["count"] -= 1

   def merge(self, accumulator, accumulators):

       pass

   def reset_accumulator(self, accumulator):

       accumulator["count"] = 0

       accumulator["map"].clear()

   def get_accumulator_type(self):

       return DataTypes.ROW([

           DataTypes.FIELD("count", DataTypes.BIGINT()),

           DataTypes.FIELD("map", DataTypes.MAP_VIEW(DataTypes.BIGINT(), DataTypes.BIGINT()))

       ])

   def get_result_type(self):

       return DataTypes.BIGINT()


When running in streaming mode, the data views in accumulators will access the Flink state if needed. When running in batch mode, the data views just store the data in memory.
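For the batch case, a data view that only stores data in memory could look roughly like the sketch below. `InMemoryMapView` is a hypothetical name, only a subset of the MapView methods is shown, and the accumulator is a plain dict rather than a Row so that the example runs without PyFlink. The accumulate and retract logic follows the count-distinct example.

```python
from typing import Dict, Generic, Iterable, Iterator, TypeVar

K = TypeVar("K")
V = TypeVar("V")


class InMemoryMapView(Generic[K, V]):
    """Hypothetical dict-backed MapView, as batch mode might use."""

    def __init__(self) -> None:
        self._data: Dict[K, V] = {}

    def get(self, key: K) -> V:
        return self._data[key]

    def put(self, key: K, value: V) -> None:
        self._data[key] = value

    def remove(self, key: K) -> None:
        del self._data[key]

    def contains(self, key: K) -> bool:
        return key in self._data

    def keys(self) -> Iterable[K]:
        return self._data.keys()

    def is_empty(self) -> bool:
        return not self._data

    def __getitem__(self, key: K) -> V:
        return self.get(key)

    def __setitem__(self, key: K, value: V) -> None:
        self.put(key, value)

    def __delitem__(self, key: K) -> None:
        self.remove(key)

    def __contains__(self, key: K) -> bool:
        return self.contains(key)

    def __iter__(self) -> Iterator[K]:
        return iter(self.keys())


def accumulate(acc, value):
    if value in acc["map"]:
        acc["map"][value] += 1
    else:
        acc["count"] += 1
        acc["map"][value] = 1


def retract(acc, value):
    if value in acc["map"]:
        acc["map"][value] -= 1
        if acc["map"][value] <= 0:
            del acc["map"][value]
            acc["count"] -= 1


acc = {"count": 0, "map": InMemoryMapView()}
for v in [1, 1, 2]:
    accumulate(acc, v)
distinct_after_inserts = acc["count"]
retract(acc, 1)  # one occurrence of 1 remains, so the distinct count is unchanged
distinct_after_retract = acc["count"]
```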

Proposed Design

Introduced Rules For Python UDAF

As Python functions should be executed in a separate Python worker, we need to introduce the following rules to translate the logical nodes which contain Python UDAF calls into special PyFlink UDAF physical nodes, to make sure the plan is executable.

  • StreamExecPythonGroupAggregateRule
  • BatchExecPythonGroupAggregateRule


Unlike UDF and UDTF, UDAFs would not chain with other UDFs, so we don’t need to introduce merge rules and split rules for them. Just one physical conversion rule is enough.

StreamExecPythonGroupAggregateRule

This rule will convert the logical aggregate node which contains Python UDAFs to the special PyFlink physical node which is used to execute Python UDAFs. We do not plan to execute Java aggs and Python aggs in one operator. If the logical node has both Java UDAFs and Python UDAFs, an exception will be thrown.

The physical plan will be adjusted as follows:

BatchExecPythonGroupAggregateRule

Similar to the StreamExecPythonGroupAggregateRule, it will replace the BatchExecXXXAggregate with BatchExecPythonAggregate, e.g. replace the sort aggregate:

UDAF Execution

Java Side

Just like other Python UDF implementations, on the Java side there would be an operator class and a runner class for each execution mode. The following diagram shows how the operator transfers data to and receives data from the Python process:

For streaming execution, we additionally need to implement a StateRequestHandler to support accessing state remotely. In our plan, the simple group aggregate will be supported in the first version. Optimizations such as minibatch and local/global can be introduced later if needed. For batch execution, as the sort agg is a general solution for all agg functions, we will support the sort agg in the first version. The hash agg and other optimizations can be introduced later if needed.

Watermark

Just like what we did when supporting the Python stateless functions, the watermarks from upstream would be held and sent to downstream after the current bundle finished.
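The holding behaviour can be illustrated with a small plain-Python sketch. `BundlingOperator`, its `bundle_size` parameter and the tuple-shaped output records are assumptions made for this illustration; the real operator is implemented on the Java side and finishes bundles based on its own size and time settings.

```python
class BundlingOperator:
    """Hypothetical operator: buffers elements and holds watermarks per bundle."""

    def __init__(self, bundle_size, emit):
        self._bundle_size = bundle_size
        self._emit = emit
        self._buffer = []
        self._held_watermark = None

    def process_element(self, element):
        self._buffer.append(element)
        if len(self._buffer) >= self._bundle_size:
            self.finish_bundle()

    def process_watermark(self, watermark):
        if self._buffer:
            # Results of buffered elements are not emitted yet, so the
            # watermark must not overtake them.
            self._held_watermark = watermark
        else:
            self._emit(("watermark", watermark))

    def finish_bundle(self):
        for element in self._buffer:
            self._emit(("result", element))
        self._buffer.clear()
        if self._held_watermark is not None:
            self._emit(("watermark", self._held_watermark))
            self._held_watermark = None


output = []
operator = BundlingOperator(bundle_size=2, emit=output.append)
operator.process_element("a")
operator.process_watermark(10)  # held: the bundle still contains "a"
operator.process_element("b")   # the bundle is full, so it is finished here
operator.process_watermark(20)  # nothing buffered: forwarded immediately
```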

Checkpointing

As with the stateless Python functions, all the data in the buffer will be flushed on checkpoint. In addition, the state cached on the Python side will also be flushed to the Java side.

Thread Safety of the StateRequestHandler

The StateRequestHandler is a gRPC service, which means it runs in a thread separate from the main thread. Before accessing the state, we need to set the current key manually in the StateRequestHandler. However, the framework also sets the current key before calling processElement for each record. To ensure the thread safety of setting the current key, we will override the "setKeyContextElement1" and "setCurrentKey" methods of the Python aggregate operator with empty methods, so that the current key will be managed only by the StateRequestHandler.

Python Side

For streaming mode, the workflow on the Python side is similar to the Java aggregation, i.e. the GroupAggFunction + AggsHandler structure:

The GroupAggFunction and AggsHandler would access the state via the PyFlink State API, which is implemented based on the Beam GrpcStateService.

For batch mode, the structure is similar to stream. But in batch mode we don’t need to consider state. Besides, all the data for a single key will be sent to the Python process together for sort agg and so the implementation will be very simple and straightforward.
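A minimal sketch of the batch idea, assuming the input is already grouped by sorting on the key: every group is aggregated with a fresh in-memory accumulator and no state access is needed. The `sort_aggregate` helper and its parameters are hypothetical and only illustrate the control flow, not the actual operator.

```python
from itertools import groupby
from operator import itemgetter


def sort_aggregate(rows, key_index, create_accumulator, accumulate, get_value):
    """Hypothetical sort agg: sort by key, then aggregate each run of equal keys."""
    results = []
    sorted_rows = sorted(rows, key=itemgetter(key_index))
    for key, group in groupby(sorted_rows, key=itemgetter(key_index)):
        acc = create_accumulator()  # a fresh accumulator per key, kept in memory
        for row in group:
            accumulate(acc, row)
        results.append((key, get_value(acc)))
    return results


def count_accumulate(acc, row):
    acc[0] += 1


rows = [
    (1, "Hi", "Hello"),
    (3, "Hi", "hi"),
    (3, "Hi2", "hi"),
    (3, "Hi", "hi"),
    (2, "Hi", "Hello"),
]
counts = sort_aggregate(
    rows,
    key_index=2,
    create_accumulator=lambda: [0],
    accumulate=count_accumulate,
    get_value=lambda acc: acc[0],
)
```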

PyFlink State API

The execution of Python UDAF depends on Flink state, so we need to support access to Flink state on the Python side.

Interfaces

Currently the state is used internally. Considering that this part can be reused in the PyFlink Stateful DataStream API, we abstract the reusable part into interfaces:

T = TypeVar('T')

K = TypeVar('K')

V = TypeVar('V')


class State(abc.ABC):

   @abstractmethod

   def clear(self) -> None:

       pass


class ValueState(State, Generic[T]):

   @abstractmethod

   def value(self) -> T:

       pass

   @abstractmethod

   def update(self, value: T) -> None:

       pass


class ListState(State, Generic[T]):

   @abstractmethod

   def get(self) -> Iterable[T]:

       pass

   @abstractmethod

   def add(self, value: T) -> None:

       pass

   @abstractmethod

   def update(self, values: List[T]) -> None:

       pass

   @abstractmethod

   def add_all(self, values: List[T]) -> None:

       pass

   def __iter__(self) -> Iterator[T]:

       return iter(self.get())


class MapState(State, Generic[K, V]):

   @abstractmethod

   def get(self, key: K) -> V:

       pass

   @abstractmethod

   def put(self, key: K, value: V) -> None:

       pass

   @abstractmethod

   def put_all(self, dict_value: Dict[K, V]) -> None:

       pass

   @abstractmethod

   def remove(self, key: K) -> None:

       pass

   @abstractmethod

   def contains(self, key: K) -> bool:

       pass

   @abstractmethod

   def items(self) -> Iterable[Tuple[K, V]]:

       pass

   @abstractmethod

   def keys(self) -> Iterable[K]:

       pass

   @abstractmethod

   def values(self) -> Iterable[V]:

       pass

   @abstractmethod

   def is_empty(self) -> bool:

       pass

   def __getitem__(self, key: K) -> V:

       return self.get(key)

   def __setitem__(self, key: K, value: V) -> None:

       self.put(key, value)

   def __delitem__(self, key: K) -> None:

       self.remove(key)

   def __contains__(self, key: K) -> bool:

       return self.contains(key)

   def __iter__(self) -> Iterator[K]:

       return iter(self.keys())


Currently we only introduce ValueState, ListState, MapState for PyFlink UDAF. Other kinds of state will be introduced when proposing the PyFlink Stateful DataStream API.

The StateDescriptors are not introduced here, because most methods in the StateDescriptor interfaces would not be used in PyFlink UDAF. It is more appropriate to propose them in the design document of the PyFlink Stateful DataStream API. For PyFlink UDAF the states will be created directly.

Implementation

Beam portability framework has a state channel to transmit the state request from Python side to Java side. It only supports BagState (similar to Flink's ListState) or other states based on BagState for now, but its protocol design still allows us to implement MapState.

We can use ListState to simulate ValueState, i.e. ValueState.value() = next(iter(ListState.get())) and ValueState.update(value) = ListState.update([value]). ListState can also be used to simulate ReducingState and AggregatingState in the future. So supporting only ListState and MapState at the protocol level is enough.
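The simulation of ValueState on top of ListState can be sketched as follows. `InMemoryListState` is a hypothetical local stand-in for a list state that would normally be served over the state channel, and returning None for an empty state is an assumption of this sketch rather than something the proposal specifies.

```python
class InMemoryListState:
    """Hypothetical stand-in for a ListState served by the Java side."""

    def __init__(self):
        self._data = []

    def get(self):
        return list(self._data)

    def add(self, value):
        self._data.append(value)

    def update(self, values):
        self._data = list(values)

    def add_all(self, values):
        self._data.extend(values)

    def clear(self):
        self._data = []


class ListBackedValueState:
    """ValueState expressed purely in terms of ListState operations."""

    def __init__(self, list_state):
        self._list_state = list_state

    def value(self):
        # The list holds at most one element; None (an assumption) if empty.
        return next(iter(self._list_state.get()), None)

    def update(self, value):
        # Replacing the whole list keeps it at exactly one element.
        self._list_state.update([value])

    def clear(self):
        self._list_state.clear()


state = ListBackedValueState(InMemoryListState())
initial = state.value()
state.update(5)
state.update(7)
latest = state.value()
```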

Cache Strategy


It is too expensive to have a network communication for every operation of the states. We need to cache read and write operations to State. There are two kinds of cache: read-cache and write-cache. Different states have different caching strategies:



ListState

  • Read/Iterate: Request the entire list state of a key the first time an iterate operation is executed, then store it in the read-cache.
  • Write/Append: Only appending data to the end of the list or replacing the entire list is supported. The appended/updated data will be cached in the write-cache. The cached data will be flushed to the Java side when flushing is triggered.
  • Remove/Clear: Removing a single element of the list is not supported. The clear operation will clear all the caches of the current state, and is itself cached in the write-cache. The clear request will be sent to the Java side when the bundle finishes. Its priority is higher than that of append/update requests.

MapState

  • Read/Iterate: Requesting data by map key is supported. It will try to read data from the cache first. The data requested from the remote JVM will be stored in the read-cache. If the whole state is cached on the Python side, it is iterated locally. Otherwise it will request the remote data and combine it with the local cache.
  • Write/Append: Writing data by map key is supported. The data will be cached in the write-cache. The cached data will be flushed to the Java side when flushing is triggered.
  • Remove/Clear: Removing data by map key is supported. The remove operation will be applied to the write-cache first. If the data to remove is not in the write-cache, the remove operation will be cached in the write-cache and sent to the Java side when flushing is triggered. The clear operation will clear all the caches of the current state, and is itself cached in the write-cache. The clear request will be sent to the Java side when flushing is triggered. Its priority is higher than that of any other request.


The principle of the strategies above is: cache everything as much as possible. If the cache size is too large, the least recently used cache can be cleared/flushed. Note that the cache referenced by any iterators should never be cleared.
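The MapState strategy can be sketched in plain Python as below. `CachedMapState` is a hypothetical name and the `backend` dict stands in for the state held on the Java side. The sketch simplifies the strategy in two ways that are assumptions of this example: a remove is always recorded as a tombstone in the write-cache, and size-based (least recently used) eviction is not modelled.

```python
_REMOVED = object()  # tombstone marking a pending remove in the write-cache


class CachedMapState:
    """Hypothetical MapState with a read-cache and a write-cache."""

    def __init__(self, backend):
        self._backend = backend   # stand-in for the remote Java state
        self._read_cache = {}
        self._write_cache = {}    # key -> value or _REMOVED
        self._clear_pending = False
        self.remote_reads = 0     # counts simulated network requests

    def get(self, key):
        if key in self._write_cache:
            value = self._write_cache[key]
            return None if value is _REMOVED else value
        if self._clear_pending:
            return None           # remote content is logically gone already
        if key not in self._read_cache:
            self.remote_reads += 1
            self._read_cache[key] = self._backend.get(key)
        return self._read_cache[key]

    def put(self, key, value):
        self._write_cache[key] = value

    def remove(self, key):
        self._write_cache[key] = _REMOVED

    def clear(self):
        self._read_cache.clear()
        self._write_cache.clear()
        self._clear_pending = True

    def flush(self):
        if self._clear_pending:   # clear has priority over other requests
            self._backend.clear()
            self._clear_pending = False
        for key, value in self._write_cache.items():
            if value is _REMOVED:
                self._backend.pop(key, None)
                self._read_cache.pop(key, None)
            else:
                self._backend[key] = value
                self._read_cache[key] = value
        self._write_cache.clear()


backend = {"a": 1}
state = CachedMapState(backend)
first = state.get("a")
second = state.get("a")           # served from the read-cache
state.put("b", 2)
state.remove("a")
backend_before_flush = dict(backend)
state.flush()
backend_after_flush = dict(backend)
```

Until `flush` is called, none of the writes reach the backend, which is the point of the write-cache: one flush replaces one request per operation.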

"FILTER" Keyword and "DISTINCT" Keyword

According to the introduced rule, we just replace the GroupAggregate operator with PythonGroupAggregate operator. So the implementation of "FILTER" keyword and "DISTINCT" keyword is just the same as Java’s implementation. The AggsHandler will check the filter arg for every agg call first before calling the "accumulate" method of them, and will prepare a MapView object as the distinct multiset for every distinct agg call. For each agg the accumulate logic is:
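A rough plain-Python sketch of this per-agg accumulate logic follows. The function name, the `filter_arg` parameter and the use of a dict in place of a MapView are assumptions made for illustration. The distinct multiset maps each value to the number of times it has been seen, and the aggregate's own "accumulate" is only called the first time a value appears.

```python
class SumAggregateFunction:
    """Small example aggregate used to exercise the logic."""

    def create_accumulator(self):
        return [0]

    def accumulate(self, accumulator, value):
        accumulator[0] += value

    def get_value(self, accumulator):
        return accumulator[0]


def accumulate_one_agg(agg, acc, value, filter_arg=True, distinct_counts=None):
    """Hypothetical AggsHandler step for a single agg call.

    filter_arg:      result of the FILTER condition for this row.
    distinct_counts: multiset for a DISTINCT agg call (a dict standing in
                     for a MapView), or None for a non-distinct agg call.
    """
    if not filter_arg:
        return                      # row rejected by FILTER
    if distinct_counts is not None:
        seen = distinct_counts.get(value, 0)
        distinct_counts[value] = seen + 1
        if seen > 0:
            return                  # value already counted once
    agg.accumulate(acc, value)


agg = SumAggregateFunction()
acc = agg.create_accumulator()
distinct_counts = {}
rows = [(1, True), (1, True), (2, False), (3, True)]  # (value, filter result)
for value, passes_filter in rows:
    accumulate_one_agg(agg, acc, value, passes_filter, distinct_counts)
distinct_sum = agg.get_value(acc)
```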

Built-in Python Aggs

Currently, we do not plan to support running both Java aggs and Python aggs in one operator. But the mixed use of Python UDAF and built-in aggs would be a common scenario, so we need to implement Python versions of the built-in aggs. All the Java built-in aggs will be replaced with the corresponding Python built-in aggs when converting a logical aggregate node to a Python aggregate node.

Compatibility, Deprecation, and Migration Plan

This FLIP introduces a new feature, so there are no compatibility issues with previous versions.

Implementation Plan

  1. Support ValueState and Basic Python UDAF(without dataview support, distinct keyword and filter keyword support) on blink stream planner 
  2. Support Basic Python UDAF(without dataview support, distinct keyword and filter keyword support) on blink batch planner
  3. Support ListState and MapState
  4. Support DataView
  5. Support FILTER keyword and DISTINCT keyword
  6. Support mixed use with built-in aggs

