Status
Current state: ["Under Discussion"]
Discussion thread: here (<- link to http://apache-flink-mailing-list-archive.1008284.n3.nabble.com/DISCUSS-FLIP-139-General-Python-User-Defined-Aggregate-Function-on-Table-API-td44139.html)
JIRA: here (<- link to https://issues.apache.org/jira/browse/FLINK-XXXX)
Released: <Flink Version>
Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).
Motivation
FLIP-58 introduced stateless Python UDFs, which have been supported since previous releases. However, stateful Python UDFs, i.e. user-defined aggregate functions (UDAFs), are not yet supported in PyFlink. This FLIP introduces Python user-defined aggregate functions for the PyFlink Table API.
Goals
- Support running Python UDAF in streaming/batch mode with the blink planner
- Support accessing the value state/list state/map state remotely from the Python side
- Support the FILTER and DISTINCT keywords on Python UDAF
- Support DataView (ListView and MapView)
Non-Goals
- Support optimizations such as hash agg, local/global, and minibatch
- Support using Python UDAF in group windows
- Support using Python UDAF in over windows
- Support mixed use of Java UDAF and Python UDAF
- Support Python user-defined table aggregate functions
NOTE: Although these features will not be supported in the first version, it doesn’t mean that they cannot be supported according to the design proposed in this FLIP.
Architecture
Executing Python user-defined aggregate functions is similar to executing the stateless functions described in FLIP-58. The key difference is that aggregate functions need to access state, which can be implemented on top of the state service of the current design.
Public Interfaces
AggregateFunction
To define an aggregate function, users need to implement the following interface:
    class AggregateFunction(UserDefinedFunction):
        """
        Base interface for user-defined aggregate function.

        .. versionadded:: 1.12.0
        """

        def open(self, function_context: FunctionContext):
            pass

        def get_value(self, accumulator):
            pass

        def create_accumulator(self):
            pass

        def accumulate(self, accumulator, *args):
            pass

        def retract(self, accumulator, *args):
            pass

        def merge(self, accumulator, accumulators):
            pass

        def reset_accumulator(self, accumulator):
            pass

        def close(self):
            pass

        def get_result_type(self):
            pass

        def get_accumulator_type(self):
            pass
The semantics of each method above are aligned with the AggregateFunction class of the Java API. Depending on actual usage, not all methods need to be implemented; e.g. the "merge" method is optional if the local-global optimization is not enabled.
Examples:
    class CountAggregateFunction(AggregateFunction):

        def get_value(self, accumulator):
            return accumulator[0]

        def create_accumulator(self):
            # The accumulator is a one-element list holding the running count.
            return [0]

        def accumulate(self, accumulator, *args):
            accumulator[0] = accumulator[0] + 1

        def retract(self, accumulator, *args):
            accumulator[0] = accumulator[0] - 1

        def merge(self, accumulator, accumulators):
            # Each other accumulator is itself a list, so add its first element.
            for other_acc in accumulators:
                accumulator[0] = accumulator[0] + other_acc[0]

        def reset_accumulator(self, accumulator):
            accumulator[0] = 0

        def get_accumulator_type(self):
            return DataTypes.ARRAY(DataTypes.BIGINT())

        def get_result_type(self):
            return DataTypes.BIGINT()

    # t_env is an existing TableEnvironment.
    t_env.register_function("my_count", CountAggregateFunction())
    t = t_env.from_elements([(1, 'Hi', 'Hello'),
                             (3, 'Hi', 'hi'),
                             (3, 'Hi2', 'hi'),
                             (3, 'Hi', 'hi'),
                             (2, 'Hi', 'Hello')], ['a', 'b', 'c'])
    t.group_by("c").select("a.my_count, c as b")
DataView
DataView is an important feature of Java UDAFs. The proposed Python interfaces are as follows:
    from typing import Any, Dict, Generic, Iterable, Iterator, List, Tuple, TypeVar

    T = TypeVar('T')
    K = TypeVar('K')
    V = TypeVar('V')

    class ListView(Generic[T]):

        def get(self) -> Iterable[T]:
            pass

        def add(self, value: T) -> None:
            pass

        def add_all(self, values: List[T]) -> None:
            pass

        def clear(self) -> None:
            pass

        def __eq__(self, other: Any) -> bool:
            pass

        def __hash__(self) -> int:
            pass

        def __iter__(self) -> Iterator[T]:
            return iter(self.get())

    class MapView(Generic[K, V]):

        def get(self, key: K) -> V:
            pass

        def put(self, key: K, value: V) -> None:
            pass

        def put_all(self, dict_value: Dict[K, V]) -> None:
            pass

        def remove(self, key: K) -> None:
            pass

        def contains(self, key: K) -> bool:
            pass

        def items(self) -> Iterable[Tuple[K, V]]:
            pass

        def keys(self) -> Iterable[K]:
            pass

        def values(self) -> Iterable[V]:
            pass

        def is_empty(self) -> bool:
            pass

        def clear(self) -> None:
            pass

        def __eq__(self, other: Any) -> bool:
            pass

        def __hash__(self) -> int:
            pass

        def __getitem__(self, key: K) -> V:
            return self.get(key)

        def __setitem__(self, key: K, value: V) -> None:
            self.put(key, value)

        def __delitem__(self, key: K) -> None:
            self.remove(key)

        def __contains__(self, key: K) -> bool:
            return self.contains(key)

        def __iter__(self) -> Iterator[K]:
            return iter(self.keys())
Note that ListView does not have a "remove" method for removing a single element from the ListState. This operation is too expensive for Python state: it is equivalent to get + modify + clear + add_all.
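The equivalence mentioned above can be sketched as follows. This is only an illustration under stated assumptions: `InMemoryListView` and `remove_first` are hypothetical names, and the view is a plain in-memory stand-in rather than the state-backed implementation proposed here.

```python
from typing import Generic, Iterable, List, TypeVar

T = TypeVar('T')


class InMemoryListView(Generic[T]):
    """Hypothetical in-memory stand-in exposing the proposed ListView methods."""

    def __init__(self) -> None:
        self._data: List[T] = []

    def get(self) -> Iterable[T]:
        return list(self._data)

    def add(self, value: T) -> None:
        self._data.append(value)

    def add_all(self, values: List[T]) -> None:
        self._data.extend(values)

    def clear(self) -> None:
        self._data.clear()


def remove_first(view: InMemoryListView, value) -> bool:
    """Emulate a single-element remove as get + modify + clear + add_all."""
    elements = list(view.get())   # get: read the whole list
    if value not in elements:
        return False
    elements.remove(value)        # modify: drop the first match locally
    view.clear()                  # clear: wipe the stored list
    view.add_all(elements)        # add_all: write every remaining element back
    return True


view = InMemoryListView()
view.add_all([1, 2, 3, 2])
removed = remove_first(view, 2)
```

Every step touches the whole list, which is why a remote-state-backed ListView would pay a full read and a full rewrite for each removed element.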
As PyFlink does not support user-defined types yet, the way to support DataView in PyFlink is different from Java. Users need to declare a special type field for MapView/ListView in the accumulator type:
    # MapView:
    DataTypes.MAP_VIEW(DataTypes.BIGINT(), DataTypes.BIGINT())
    # ListView:
    DataTypes.LIST_VIEW(DataTypes.BIGINT())
The example is as follows:
    class CountDistinctAggregateFunction(AggregateFunction):

        def get_value(self, accumulator):
            return accumulator["count"]

        def create_accumulator(self):
            return Row(count=0, map=MapView())

        def accumulate(self, accumulator, *args):
            # The map counts occurrences per value; "count" tracks distinct values.
            if args[0] in accumulator["map"]:
                accumulator["map"][args[0]] += 1
            else:
                accumulator["count"] += 1
                accumulator["map"][args[0]] = 1

        def retract(self, accumulator, *args):
            if args[0] in accumulator["map"]:
                accumulator["map"][args[0]] -= 1
                if accumulator["map"][args[0]] <= 0:
                    del accumulator["map"][args[0]]
                    accumulator["count"] -= 1

        def merge(self, accumulator, accumulators):
            pass

        def reset_accumulator(self, accumulator):
            accumulator["count"] = 0
            accumulator["map"].clear()

        def get_accumulator_type(self):
            return DataTypes.ROW([
                DataTypes.FIELD("count", DataTypes.BIGINT()),
                DataTypes.FIELD("map", DataTypes.MAP_VIEW(DataTypes.BIGINT(), DataTypes.BIGINT()))
            ])

        def get_result_type(self):
            return DataTypes.BIGINT()
When running in streaming mode, the data views in accumulators will access the Flink state if needed. When running in batch mode, the data views just store the data in memory.
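To make the in-memory case concrete, here is a minimal sketch of a dict-backed map view driving the count-distinct logic. `InMemoryMapView`, `accumulate` and `retract` are hypothetical names used for illustration only, the accumulator is a plain dict rather than a Row, and only a subset of the MapView methods is shown.

```python
from typing import Dict, Generic, Iterable, Iterator, TypeVar

K = TypeVar('K')
V = TypeVar('V')


class InMemoryMapView(Generic[K, V]):
    """Hypothetical dict-backed map view, as batch mode could use it."""

    def __init__(self) -> None:
        self._data: Dict[K, V] = {}

    def get(self, key: K) -> V:
        return self._data[key]

    def put(self, key: K, value: V) -> None:
        self._data[key] = value

    def remove(self, key: K) -> None:
        del self._data[key]

    def contains(self, key: K) -> bool:
        return key in self._data

    def keys(self) -> Iterable[K]:
        return self._data.keys()

    # The operator overloads delegate to the named methods.
    def __getitem__(self, key: K) -> V:
        return self.get(key)

    def __setitem__(self, key: K, value: V) -> None:
        self.put(key, value)

    def __delitem__(self, key: K) -> None:
        self.remove(key)

    def __contains__(self, key: K) -> bool:
        return self.contains(key)

    def __iter__(self) -> Iterator[K]:
        return iter(self.keys())


def accumulate(acc, value) -> None:
    if value in acc["map"]:
        acc["map"][value] += 1
    else:
        acc["count"] += 1
        acc["map"][value] = 1


def retract(acc, value) -> None:
    if value in acc["map"]:
        acc["map"][value] -= 1
        if acc["map"][value] <= 0:
            del acc["map"][value]
            acc["count"] -= 1


acc = {"count": 0, "map": InMemoryMapView()}
for v in [1, 1, 2]:
    accumulate(acc, v)
count_after_accumulate = acc["count"]
retract(acc, 1)   # one occurrence of 1 is still left
retract(acc, 1)   # the last occurrence of 1 is gone
```

In streaming mode the same accumulator code would run unchanged, with the view's methods reading and writing Flink state instead of a local dict.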
Proposed Design
Introduced Rules For Python UDAF
As Python functions should be executed in a separate Python worker, we need to introduce the following rules, which translate logical nodes containing Python UDAF calls into special PyFlink UDAF physical nodes so that the plan is executable.
- StreamExecPythonGroupAggregateRule
- BatchExecPythonGroupAggregateRule
Unlike UDF and UDTF, UDAFs would not chain with other UDFs, so we don’t need to introduce merge rules and split rules for them. Just one physical conversion rule is enough.
StreamExecPythonGroupAggregateRule
This rule converts a logical aggregate node that contains Python UDAFs into the special PyFlink physical node that is used to execute Python UDAFs. We do not plan to execute Java aggs and Python aggs in one operator. If the logical node contains both Java UDAFs and Python UDAFs, an exception will be thrown.
The physical plan will be adjusted as follows:
BatchExecPythonGroupAggregateRule
Similar to the StreamExecPythonGroupAggregateRule, it will replace the BatchExecXXXAggregate with BatchExecPythonAggregate, e.g. replace the sort aggregate:
UDAF Execution
Java Side
Just like other Python UDF implementations, on the Java side there would be an operator class and a runner class for each execution mode. The following diagram shows how the operator transfers data to and receives data from the Python process:
For streaming execution, we additionally need to implement a StateRequestHandler to support accessing state remotely. The simple group aggregate will be supported in the first version; optimizations such as minibatch and local/global can be introduced later if needed. For batch execution, as the sort agg is a general solution for all agg functions, we will support the sort agg in the first version. The hash agg and other optimizations can be introduced later if needed.
Watermark
Just like what we did when supporting the Python stateless functions, the watermarks from upstream will be held and sent downstream after the current bundle finishes.
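The holding behaviour can be illustrated with a small sketch. `BundlingOperator` is a hypothetical stand-in for the Java operator, written in Python for brevity; forwarding a watermark immediately when no elements are buffered is an assumption of this sketch, not something this FLIP specifies.

```python
from typing import Any, List, Optional, Tuple


class BundlingOperator:
    """Hypothetical operator that processes elements in fixed-size bundles."""

    def __init__(self, bundle_size: int) -> None:
        self.bundle_size = bundle_size
        self.buffer: List[Any] = []
        self.held_watermark: Optional[int] = None
        self.emitted: List[Tuple[str, Any]] = []   # downstream, in emission order

    def process_element(self, element: Any) -> None:
        self.buffer.append(element)
        if len(self.buffer) >= self.bundle_size:
            self.finish_bundle()

    def process_watermark(self, timestamp: int) -> None:
        if self.buffer:
            # Elements are still in flight: hold the latest watermark back.
            if self.held_watermark is None or timestamp > self.held_watermark:
                self.held_watermark = timestamp
        else:
            # Assumption: with nothing buffered the watermark can pass through.
            self.emitted.append(("watermark", timestamp))

    def finish_bundle(self) -> None:
        # Results of the bundle go downstream first, then the held watermark.
        for element in self.buffer:
            self.emitted.append(("result", element))
        self.buffer.clear()
        if self.held_watermark is not None:
            self.emitted.append(("watermark", self.held_watermark))
            self.held_watermark = None


op = BundlingOperator(bundle_size=2)
op.process_element("a")
op.process_watermark(10)
emitted_before_finish = list(op.emitted)
op.process_element("b")   # completes the bundle and releases the watermark
```

Emitting the watermark only after the bundle's results keeps downstream operators from seeing a watermark that is ahead of records still being processed in the Python worker.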
Checkpointing
Compared to the Python stateless functions, a checkpoint not only flushes all the data in the buffer but also flushes the cached state on the Python side to the Java side.
Thread Safety of the StateRequestHandler
The StateRequestHandler is a gRPC service, which means it runs in a thread separate from the main thread. Before accessing the state, we need to set the current key manually in the StateRequestHandler. However, the framework also sets the current key before calling processElement for each record. To ensure the thread safety of setting the current key, we will override the "setKeyContextElement1" and "setCurrentKey" methods of the Python aggregate operator with empty methods, so that the current key is managed only by the StateRequestHandler.
Python Side
For streaming mode, the workflow on the Python side is similar to the Java aggregation, i.e. the GroupAggFunction + AggsHandler structure:
The GroupAggFunction and AggsHandler would access the state via the PyFlink State API, which is implemented based on the Beam GrpcStateService.
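A heavily simplified sketch of this structure is shown below. The class names follow the text, but every method name and signature is an assumption made for illustration, the keyed state is replaced by a plain dict, and details of the Java implementation such as emitting retraction messages for previous results are left out.

```python
from typing import Any, Dict, List, Tuple


class CountAgg:
    """Tiny aggregate function used only to drive the sketch."""

    def create_accumulator(self) -> List[int]:
        return [0]

    def accumulate(self, acc: List[int], *args: Any) -> None:
        acc[0] += 1

    def retract(self, acc: List[int], *args: Any) -> None:
        acc[0] -= 1

    def get_value(self, acc: List[int]) -> int:
        return acc[0]


class AggsHandler:
    """Fans one input row out to every aggregate function of the node."""

    def __init__(self, aggs: List[Any]) -> None:
        self.aggs = aggs

    def create_accumulators(self) -> List[Any]:
        return [agg.create_accumulator() for agg in self.aggs]

    def accumulate(self, accs: List[Any], row: Tuple) -> None:
        for agg, acc in zip(self.aggs, accs):
            agg.accumulate(acc, *row)

    def retract(self, accs: List[Any], row: Tuple) -> None:
        for agg, acc in zip(self.aggs, accs):
            agg.retract(acc, *row)

    def get_value(self, accs: List[Any]) -> Tuple:
        return tuple(agg.get_value(acc) for agg, acc in zip(self.aggs, accs))


class GroupAggFunction:
    """Loads the accumulators of a key, updates them, and stores them back."""

    def __init__(self, handler: AggsHandler) -> None:
        self.handler = handler
        self.state: Dict[Any, List[Any]] = {}   # stand-in for keyed value state

    def process_element(self, key: Any, row: Tuple, is_retract: bool = False) -> Tuple:
        accs = self.state.get(key)
        if accs is None:
            accs = self.handler.create_accumulators()
        if is_retract:
            self.handler.retract(accs, row)
        else:
            self.handler.accumulate(accs, row)
        self.state[key] = accs
        return (key,) + self.handler.get_value(accs)


func = GroupAggFunction(AggsHandler([CountAgg()]))
first = func.process_element("hi", (3, "Hi"))
second = func.process_element("hi", (3, "Hi2"))
after_retract = func.process_element("hi", (3, "Hi"), is_retract=True)
other_key = func.process_element("Hello", (1, "Hi"))
```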
For batch mode, the structure is similar to streaming mode, but we don’t need to consider state. Besides, for sort agg all the data for a single key is sent to the Python process together, so the implementation is simple and straightforward.
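Because rows with the same key arrive consecutively, the batch case reduces to a single pass over the sorted input. The following is a minimal sketch of that idea in plain Python; `sort_aggregate` and `CountAgg` are hypothetical names, and the real operator would receive the rows from the Java side instead of a local list.

```python
from itertools import groupby
from operator import itemgetter
from typing import Any, Iterable, List, Tuple


class CountAgg:
    """Tiny aggregate function used only to drive the sketch."""

    def create_accumulator(self) -> List[int]:
        return [0]

    def accumulate(self, acc: List[int], *args: Any) -> None:
        acc[0] += 1

    def get_value(self, acc: List[int]) -> int:
        return acc[0]


def sort_aggregate(sorted_rows: Iterable[Tuple], key_index: int, agg: Any) -> List[Tuple]:
    """Aggregate rows that are already sorted (grouped) by the key column."""
    results = []
    for key, rows in groupby(sorted_rows, key=itemgetter(key_index)):
        # One fresh accumulator per key; no state is needed across keys.
        acc = agg.create_accumulator()
        for row in rows:
            agg.accumulate(acc, *row)
        results.append((key, agg.get_value(acc)))
    return results


rows = [(1, 'Hi', 'Hello'), (3, 'Hi', 'hi'), (3, 'Hi2', 'hi'),
        (3, 'Hi', 'hi'), (2, 'Hi', 'Hello')]
result = sort_aggregate(sorted(rows, key=itemgetter(2)), 2, CountAgg())
```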
PyFlink State API
The execution of Python UDAF depends on Flink state, so we need to support access to Flink state on the Python side.
Interfaces
Currently the state is used internally. Considering that this part can be reused in the PyFlink Stateful DataStream API, we abstract the reusable part into interfaces:
    import abc
    from abc import abstractmethod
    from typing import Dict, Generic, Iterable, Iterator, List, Tuple, TypeVar

    T = TypeVar('T')
    K = TypeVar('K')
    V = TypeVar('V')

    class State(abc.ABC):

        @abstractmethod
        def clear(self) -> None:
            pass

    class ValueState(State, Generic[T]):

        @abstractmethod
        def value(self) -> T:
            pass

        @abstractmethod
        def update(self, value: T) -> None:
            pass

    class ListState(State, Generic[T]):

        @abstractmethod
        def get(self) -> Iterable[T]:
            pass

        @abstractmethod
        def add(self, value: T) -> None:
            pass

        @abstractmethod
        def update(self, values: List[T]) -> None:
            pass

        @abstractmethod
        def add_all(self, values: List[T]) -> None:
            pass

        def __iter__(self) -> Iterator[T]:
            return iter(self.get())

    class MapState(State, Generic[K, V]):

        @abstractmethod
        def get(self, key: K) -> V:
            pass

        @abstractmethod
        def put(self, key: K, value: V) -> None:
            pass

        @abstractmethod
        def put_all(self, dict_value: Dict[K, V]) -> None:
            pass

        @abstractmethod
        def remove(self, key: K) -> None:
            pass

        @abstractmethod
        def contains(self, key: K) -> bool:
            pass

        @abstractmethod
        def items(self) -> Iterable[Tuple[K, V]]:
            pass

        @abstractmethod
        def keys(self) -> Iterable[K]:
            pass

        @abstractmethod
        def values(self) -> Iterable[V]:
            pass

        @abstractmethod
        def is_empty(self) -> bool:
            pass

        def __getitem__(self, key: K) -> V:
            return self.get(key)

        def __setitem__(self, key: K, value: V) -> None:
            self.put(key, value)

        def __delitem__(self, key: K) -> None:
            self.remove(key)

        def __contains__(self, key: K) -> bool:
            return self.contains(key)

        def __iter__(self) -> Iterator[K]:
            return iter(self.keys())
Currently we only introduce ValueState, ListState, MapState for PyFlink UDAF. Other kinds of state will be introduced when proposing the PyFlink Stateful DataStream API.
The StateDescriptors are not introduced here, because most methods in the StateDescriptor interfaces would not be used in PyFlink UDAF. It is more appropriate to propose them in the design document of the PyFlink Stateful DataStream API. For PyFlink UDAF the states will be created directly.
Implementation
The Beam portability framework has a state channel to transmit state requests from the Python side to the Java side. For now it only supports BagState (similar to Flink's ListState) or other states based on BagState, but its protocol design still allows us to implement MapState.
We can use ListState to simulate ValueState, i.e. ValueState.value() = next(iter(ListState.get())) and ValueState.update(value) = ListState.update([value]). ListState can also be used to simulate ReducingState and AggregatingState in the future. So supporting only ListState and MapState at the protocol level is enough.
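The simulation of ValueState on top of ListState could look roughly like this. `InMemoryListState` and `ListBackedValueState` are hypothetical names, the list state is a local stand-in for the remote one, and returning `None` for an empty state is an assumption of this sketch.

```python
from typing import Generic, Iterable, List, Optional, TypeVar

T = TypeVar('T')


class InMemoryListState(Generic[T]):
    """Hypothetical local stand-in for a ListState held on the Java side."""

    def __init__(self) -> None:
        self._data: List[T] = []

    def get(self) -> Iterable[T]:
        return list(self._data)

    def add(self, value: T) -> None:
        self._data.append(value)

    def update(self, values: List[T]) -> None:
        self._data = list(values)

    def add_all(self, values: List[T]) -> None:
        self._data.extend(values)

    def clear(self) -> None:
        self._data.clear()


class ListBackedValueState(Generic[T]):
    """ValueState expressed purely through ListState operations."""

    def __init__(self, list_state: InMemoryListState) -> None:
        self._list_state = list_state

    def value(self) -> Optional[T]:
        # The single value is the first (and only) element of the list.
        return next(iter(self._list_state.get()), None)

    def update(self, value: T) -> None:
        # Replacing the whole list keeps exactly one element in it.
        self._list_state.update([value])

    def clear(self) -> None:
        self._list_state.clear()


backing = InMemoryListState()
value_state = ListBackedValueState(backing)
empty_value = value_state.value()
value_state.update(5)
value_state.update(7)
```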
Cache Strategy
It is too expensive to have a network communication for every operation of the states. We need to cache read and write operations to State. There are two kinds of cache: read-cache and write-cache. Different states have different caching strategies:
State | Read/Iterate | Write/Append | Remove/Clear |
ListState | Request the entire list state of a key the first time an iterate operation is executed, then store it in the read-cache. | Only support appending data to the end of the list or replacing the entire list. The appended/updated data will be cached in the write-cache. The cached data will be flushed to the Java side when flushing is triggered. | Removing a single element of the list is not supported. The clear operation will clear all the cache of the current state, and the operation itself will be cached in the write-cache. The clear request will be sent to the Java side when finishing the bundle. Its priority is higher than append/update requests. |
MapState | Support requesting data by map key. It will try to read data from the cache first. The data requested from the remote JVM will be stored in the read-cache. If the whole state is cached on the Python side, just iterate it locally. Otherwise it will request remote data and combine it with the local cache. | Support writing data by map key. The data will be cached in the write-cache. The cached data will be flushed to the Java side when flushing is triggered. | Support removing data by map key. The remove operation will be applied to the write-cache first. If the data to remove is not in the write-cache, the remove operation will be cached in the write-cache and sent to the Java side when flushing is triggered. The clear operation will clear all the cache of the current state, and the operation itself will be cached in the write-cache. The clear request will be sent to the Java side when flushing is triggered. Its priority is higher than any other requests. |
The principle of the strategies above is: cache everything as much as possible. If the cache size is too large, the least recently used cache can be cleared/flushed. Note that the cache referenced by any iterators should never be cleared.
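The MapState strategy can be approximated with the sketch below. All names (`RemoteMapStore`, `CachedMapState`, `apply`, `flush`) are hypothetical, the remote side is a local object that merely counts simulated round trips, and two simplifications are made: a remove is always recorded as a marker in the write-cache, and size-based eviction of least recently used entries is omitted.

```python
from typing import Any, Dict

_REMOVED = object()   # marker for a cached remove operation


class RemoteMapStore:
    """Hypothetical stand-in for the map state held on the Java side."""

    def __init__(self) -> None:
        self.data: Dict[Any, Any] = {}
        self.requests = 0   # number of simulated network round trips

    def get(self, key: Any) -> Any:
        self.requests += 1
        return self.data.get(key)

    def apply(self, clear_first: bool, writes: Dict[Any, Any]) -> None:
        self.requests += 1
        if clear_first:   # the clear request has the highest priority
            self.data.clear()
        for key, value in writes.items():
            if value is _REMOVED:
                self.data.pop(key, None)
            else:
                self.data[key] = value


class CachedMapState:
    """Map state with a read-cache and a write-cache in front of the remote."""

    def __init__(self, remote: RemoteMapStore) -> None:
        self._remote = remote
        self._read_cache: Dict[Any, Any] = {}
        self._write_cache: Dict[Any, Any] = {}
        self._clear_pending = False

    def get(self, key: Any) -> Any:
        if key in self._write_cache:   # pending local changes win
            value = self._write_cache[key]
            return None if value is _REMOVED else value
        if self._clear_pending:        # remote content is logically gone
            return None
        if key not in self._read_cache:
            self._read_cache[key] = self._remote.get(key)
        return self._read_cache[key]

    def put(self, key: Any, value: Any) -> None:
        self._write_cache[key] = value

    def remove(self, key: Any) -> None:
        self._write_cache[key] = _REMOVED

    def clear(self) -> None:
        self._read_cache.clear()
        self._write_cache.clear()
        self._clear_pending = True

    def flush(self) -> None:
        if self._clear_pending or self._write_cache:
            self._remote.apply(self._clear_pending, self._write_cache)
        # Flushed entries now match the remote, so they can serve later reads.
        for key, value in self._write_cache.items():
            if value is _REMOVED:
                self._read_cache.pop(key, None)
            else:
                self._read_cache[key] = value
        self._write_cache = {}
        self._clear_pending = False


remote = RemoteMapStore()
remote.data['a'] = 1
state = CachedMapState(remote)
reads = (state.get('a'), state.get('a'))   # the second read hits the cache
requests_after_reads = remote.requests
state.put('b', 2)
state.remove('a')
remote_before_flush = dict(remote.data)
state.flush()
remote_after_flush = dict(remote.data)
state.clear()
state.put('c', 3)
state.flush()
```

In this sketch a checkpoint or the end of a bundle would simply call `flush()`, which matches the idea that cached state on the Python side is pushed to the Java side at those points.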
"FILTER" Keyword and "DISTINCT" Keyword
According to the introduced rule, we just replace the GroupAggregate operator with the PythonGroupAggregate operator, so the implementation of the "FILTER" and "DISTINCT" keywords is the same as in Java. For every agg call, the AggsHandler first checks the filter arg before calling the agg's "accumulate" method, and it prepares a MapView object as the distinct multiset for every distinct agg call. For each agg the accumulate logic is:
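One possible reading of this logic is sketched below. The function names and signatures are assumptions made for illustration, a plain dict stands in for the MapView used as the distinct multiset, and sharing one multiset between several agg calls is not covered.

```python
from typing import Any, Dict, List, Optional, Tuple


class SumAgg:
    """Tiny aggregate function used only to drive the sketch."""

    def create_accumulator(self) -> List[int]:
        return [0]

    def accumulate(self, acc: List[int], value: int) -> None:
        acc[0] += value

    def retract(self, acc: List[int], value: int) -> None:
        acc[0] -= value

    def get_value(self, acc: List[int]) -> int:
        return acc[0]


def accumulate_one(agg: Any, acc: Any, args: Tuple, filter_arg: bool = True,
                   distinct_view: Optional[Dict[Tuple, int]] = None) -> None:
    if not filter_arg:
        return                      # FILTER evaluated to false: skip the row
    if distinct_view is not None:
        seen = distinct_view.get(args, 0)
        distinct_view[args] = seen + 1
        if seen > 0:
            return                  # DISTINCT: value was already accumulated
    agg.accumulate(acc, *args)


def retract_one(agg: Any, acc: Any, args: Tuple, filter_arg: bool = True,
                distinct_view: Optional[Dict[Tuple, int]] = None) -> None:
    if not filter_arg:
        return
    if distinct_view is not None:
        remaining = distinct_view.get(args, 0) - 1
        if remaining > 0:
            distinct_view[args] = remaining
            return                  # other occurrences of the value remain
        distinct_view.pop(args, None)
    agg.retract(acc, *args)


agg = SumAgg()
acc = agg.create_accumulator()
seen_values: Dict[Tuple, int] = {}
# (value, filter result) pairs; the last row is rejected by the filter.
for value, keep in [(1, True), (1, True), (2, True), (5, False)]:
    accumulate_one(agg, acc, (value,), keep, seen_values)
distinct_sum = agg.get_value(acc)
retract_one(agg, acc, (1,), True, seen_values)
sum_after_first_retract = agg.get_value(acc)
retract_one(agg, acc, (1,), True, seen_values)
```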
Built-in Python Aggs
Currently, we do not plan to support running both Java aggs and Python aggs in one operator. But mixed use of Python UDAFs and built-in aggs would be a common scenario, so we need to implement Python versions of the built-in aggs. All the Java built-in aggs will be replaced with the corresponding Python built-in aggs when converting a logical aggregate node to a Python aggregate node.
Compatibility, Deprecation, and Migration Plan
This FLIP introduces a new feature, so there is no compatibility issue with previous versions.
Implementation Plan
- Support ValueState and basic Python UDAF (without DataView, DISTINCT keyword, and FILTER keyword support) on the blink stream planner
- Support basic Python UDAF (without DataView, DISTINCT keyword, and FILTER keyword support) on the blink batch planner
- Support ListState and MapState
- Support DataView
- Support FILTER keyword and DISTINCT keyword
- Support mixed use with built-in aggs