
Status

Current state: Under Discussion

Discussion thread: http://apache-flink-mailing-list-archive.1008284.n3.nabble.com/DISCUSS-FLIP-153-Support-state-access-in-Python-DataStream-API-tt47127.html

JIRA: here (<- link to https://issues.apache.org/jira/browse/FLINK-XXXX)

Released: TBD

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).


Motivation

FLIP-130 added support for the stateless Python DataStream API, which lets users perform basic data transformations. More complex data processing requires access to state. In this doc, I propose adding state access support to the Python DataStream API so that users can perform stateful operations on a KeyedStream.

Goals

  1. Provide state access APIs in KeyedStream.
  2. Support ValueState, ListState, MapState, ReducingState and AggregatingState.

Background

PyFlink leverages the Beam portability framework to start a Python process that executes user-defined Python functions.

Beam’s portability framework provides the following building blocks for state access:

  1. A State API between the Runner and the SDK harness, which can be used for state access in Python user-defined functions. (Note: state communication uses a channel separate from the data channel.)
  2. The State API defines five kinds of state in its proto messages (refer to StateKey for details) and three types of state access operations: Get / Append / Clear:

    State Type             | Use case
    -----------------------|------------------------------------------------------------------------------
    Runner                 | Remote references (and other Runner-specific extensions)
    IterableSideInput      | Side input of an iterable
    MultimapSideInput      | Side input of the values of a map
    MultimapKeysSideInput  | Side input of the keys of a map
    BagUserState           | User state with a primary key (value state, bag state, combining value state)

    Among them, only BagUserState is dedicated to state access in Beam; the others are used for data access.

  3. On top of the BagUserState proto message, Beam’s Python SDK harness provides four kinds of user-facing state API: BagRuntimeState, SetRuntimeState, ReadModifyWriteRuntimeState and CombiningValueRuntimeState.
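To make the three BagUserState operations concrete, here is a minimal in-memory sketch of how a runner could answer Get / Append / Clear requests. The key fields (transform_id, user_state_id, key) are modeled on Beam's BagUserState state key with the window field omitted for brevity; the class names are hypothetical and this is not Beam's actual runner code.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class BagUserStateKey:
    # Simplified, hypothetical version of Beam's BagUserState key (window omitted).
    transform_id: str
    user_state_id: str
    key: bytes


class InMemoryBagStore:
    """Hypothetical runner-side store answering Get / Append / Clear requests."""

    def __init__(self) -> None:
        self._bags: Dict[BagUserStateKey, List[bytes]] = defaultdict(list)

    def get(self, state_key: BagUserStateKey) -> List[bytes]:
        # Get returns every element appended to the bag so far.
        return list(self._bags.get(state_key, []))

    def append(self, state_key: BagUserStateKey, data: List[bytes]) -> None:
        # Append adds elements and never overwrites existing ones.
        self._bags[state_key].extend(data)

    def clear(self, state_key: BagUserStateKey) -> None:
        # Clear drops the whole bag for this state key.
        self._bags.pop(state_key, None)


store = InMemoryBagStore()
k = BagUserStateKey("flat_map_1", "seen", b"hello")
store.append(k, [b"a"])
store.append(k, [b"b"])
print(store.get(k))  # [b'a', b'b']
store.clear(k)
print(store.get(k))  # []
```

Since there is no "overwrite" operation, any user-facing state that replaces its contents has to be expressed as Clear followed by Append.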

There are thus two layers, the proto message layer and the user-facing API layer:

  1. The proto message layer defines five types of proto messages.
  2. Each proto message type represents a single kind of state in the underlying execution engine, and it is up to the operator to decide which kind of state each proto message type is mapped to.
  3. In the user-facing API layer (the State used in Python functions), the Python SDK harness can expose different kinds of user-facing state on top of the same underlying proto message.

BagUserState will be mapped to ListState in Flink. To support kinds of state that cannot be simulated with ListState, such as MapState, we will make use of the other proto messages, even though Beam does not design them for state access. This works because the operator decides which kind of state each proto message is mapped to, as described in the following sections.
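Because the mapping is decided by the operator, the Java side can route each request by the kind of state key it carries. The following Python sketch illustrates that routing with hypothetical names, using plain dicts as stand-ins for Flink's ListState and MapState: BagUserState requests go to list-backed state and MultimapSideInput requests go to map-backed state. How map operations are encoded onto Get / Append / Clear is an assumption made only for this illustration.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class StateRequest:
    """Hypothetical, simplified state request; real requests are Beam protobufs."""
    key_type: str                   # "bag_user_state" or "multimap_side_input"
    state_name: str
    key: Any                        # key of the element that issued the request
    op: str                         # "get", "append" or "clear"
    map_key: Optional[Any] = None   # only used for map state
    data: Any = None


class StateRouter:
    """Routes each request to the kind of state chosen by the operator."""

    def __init__(self) -> None:
        # Plain dicts standing in for Flink ListState and MapState.
        self.list_states: Dict[Tuple[str, Any], List[Any]] = defaultdict(list)
        self.map_states: Dict[Tuple[str, Any], Dict[Any, Any]] = defaultdict(dict)

    def handle(self, req: StateRequest) -> Any:
        if req.key_type == "bag_user_state":
            return self._handle_list(self.list_states[(req.state_name, req.key)], req)
        if req.key_type == "multimap_side_input":
            return self._handle_map(self.map_states[(req.state_name, req.key)], req)
        raise ValueError(f"unsupported state key type: {req.key_type}")

    @staticmethod
    def _handle_list(bag: List[Any], req: StateRequest) -> Any:
        if req.op == "get":
            return list(bag)
        if req.op == "append":
            bag.extend(req.data)
            return None
        if req.op == "clear":
            bag.clear()
            return None
        raise ValueError(f"unsupported op: {req.op}")

    @staticmethod
    def _handle_map(entries: Dict[Any, Any], req: StateRequest) -> Any:
        # Assumed encoding: append writes one entry, clear removes one entry.
        if req.op == "get":
            return entries.get(req.map_key)
        if req.op == "append":
            entries[req.map_key] = req.data
            return None
        if req.op == "clear":
            entries.pop(req.map_key, None)
            return None
        raise ValueError(f"unsupported op: {req.op}")


router = StateRouter()
router.handle(StateRequest("bag_user_state", "buffer", "hello", "append", data=[1, 2]))
print(router.handle(StateRequest("bag_user_state", "buffer", "hello", "get")))  # [1, 2]
router.handle(StateRequest("multimap_side_input", "counts", "hello", "append",
                           map_key="a", data=3))
print(router.handle(StateRequest("multimap_side_input", "counts", "hello", "get",
                                 map_key="a")))  # 3
```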


Public Interfaces

We will introduce a series of States and StateDescriptors for the state types supported.

PS: ValueState, ListState and MapState were already introduced in FLIP-139 (though not yet exposed to users) to support Python UDAFs, so only their StateDescriptors need to be added.

ValueState

We will introduce ValueState and ValueStateDescriptor to let users use value state in the Python DataStream API. The interfaces are as follows:

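from abc import ABC, abstractmethod
from typing import Dict, Generic, Iterable, Iterator, List, Tuple, TypeVar

from pyflink.common.typeinfo import TypeInformation
from pyflink.datastream.functions import Function

# Type variables shared by the state interfaces in this section.
T = TypeVar('T')
K = TypeVar('K')
V = TypeVar('V')
ACC = TypeVar('ACC')
IN = TypeVar('IN')
OUT = TypeVar('OUT')


class StateDescriptor(ABC):
    """Base class of all state descriptors; the name identifies the state."""

    def __init__(self, name: str):
        self.name = name


class ValueStateDescriptor(StateDescriptor):

    def __init__(self, name: str, value_type_info: TypeInformation):
        super(ValueStateDescriptor, self).__init__(name)
        self.value_type_info = value_type_info


class State(ABC):

    @abstractmethod
    def clear(self) -> None:
        # Removes the value mapped under the current key.
        pass


class ValueState(State, Generic[T]):

    @abstractmethod
    def value(self) -> T:
        # Returns the value for the current key.
        pass

    @abstractmethod
    def update(self, value: T) -> None:
        # Overwrites the value for the current key.
        pass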

ListState

class ListStateDescriptor(StateDescriptor):

    def __init__(self, name: str, elem_type_info: TypeInformation):
        super(ListStateDescriptor, self).__init__(name)
        self.elem_type_info = elem_type_info


class ListState(State, Generic[T]):

    @abstractmethod
    def get(self) -> Iterable[T]:
        pass

    @abstractmethod
    def add(self, value: T) -> None:
        pass

    @abstractmethod
    def update(self, values: List[T]) -> None:
        # Replaces all existing elements with the given list.
        pass

    @abstractmethod
    def add_all(self, values: List[T]) -> None:
        # Appends all given elements to the existing list.
        pass

    def __iter__(self) -> Iterator[T]:
        return iter(self.get())
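One plausible way for the Python side to implement ListState on top of BagUserState's three operations is sketched below: add and add_all become Append, update becomes Clear followed by Append, and get becomes Get. The `send` callback standing in for the state channel, and the tuple-shaped requests, are hypothetical.

```python
from typing import Any, Callable, Iterable, Iterator, List, Tuple

# A state request as (operation, payload); a hypothetical stand-in for Beam's StateRequest.
Request = Tuple[str, List[Any]]


class BagBackedListState:
    def __init__(self, send: Callable[[Request], List[Any]]) -> None:
        self._send = send  # forwards a request over the (hypothetical) state channel

    def get(self) -> Iterable[Any]:
        return self._send(("get", []))

    def add(self, value: Any) -> None:
        self._send(("append", [value]))

    def add_all(self, values: List[Any]) -> None:
        self._send(("append", list(values)))

    def update(self, values: List[Any]) -> None:
        # Replacing the list needs two requests: Clear, then Append.
        self._send(("clear", []))
        if values:
            self._send(("append", list(values)))

    def clear(self) -> None:
        self._send(("clear", []))

    def __iter__(self) -> Iterator[Any]:
        return iter(self.get())


# A fake channel that records every request and keeps the bag in memory.
log: List[Request] = []
bag: List[Any] = []


def fake_send(request: Request) -> List[Any]:
    log.append(request)
    op, payload = request
    if op == "append":
        bag.extend(payload)
    elif op == "clear":
        bag.clear()
    return list(bag)


state = BagBackedListState(fake_send)
state.add(1)
state.update([2, 3])
print(list(state))            # [2, 3]
print([op for op, _ in log])  # ['append', 'clear', 'append', 'get']
```

Note that update costs two round trips in this sketch, which is one reason a perfect one-to-one match with ListState still leaves room for batching on the Python side.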

MapState

class MapStateDescriptor(StateDescriptor):

    def __init__(self,
                 name: str,
                 key_type_info: TypeInformation,
                 value_type_info: TypeInformation):
        super(MapStateDescriptor, self).__init__(name)
        self.key_type_info = key_type_info
        self.value_type_info = value_type_info


class MapState(State, Generic[K, V]):

    @abstractmethod
    def get(self, key: K) -> V:
        pass

    @abstractmethod
    def put(self, key: K, value: V) -> None:
        pass

    @abstractmethod
    def put_all(self, dict_value: Dict[K, V]) -> None:
        pass

    @abstractmethod
    def remove(self, key: K) -> None:
        pass

    @abstractmethod
    def contains(self, key: K) -> bool:
        pass

    @abstractmethod
    def items(self) -> Iterable[Tuple[K, V]]:
        pass

    @abstractmethod
    def keys(self) -> Iterable[K]:
        pass

    @abstractmethod
    def values(self) -> Iterable[V]:
        pass

    @abstractmethod
    def is_empty(self) -> bool:
        pass

    # Dict-style syntactic sugar delegating to the methods above.
    def __getitem__(self, key: K) -> V:
        return self.get(key)

    def __setitem__(self, key: K, value: V) -> None:
        self.put(key, value)

    def __delitem__(self, key: K) -> None:
        self.remove(key)

    def __contains__(self, key: K) -> bool:
        return self.contains(key)

    def __iter__(self) -> Iterator[K]:
        return iter(self.keys())
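The dunder methods let user code treat MapState like a dict. To illustrate, here is a hypothetical dict-backed stand-in (not part of the proposal) that implements the same method names, followed by dict-style usage. It assumes get returns None for a missing key, mirroring the null returned by Java's MapState.get.

```python
from typing import Dict, Generic, Iterable, Iterator, Optional, Tuple, TypeVar

K = TypeVar("K")
V = TypeVar("V")


class DictMapState(Generic[K, V]):
    """Hypothetical in-memory stand-in exposing the MapState method names."""

    def __init__(self) -> None:
        self._data: Dict[K, V] = {}

    def get(self, key: K) -> Optional[V]:
        return self._data.get(key)  # None for a missing key (assumption)

    def put(self, key: K, value: V) -> None:
        self._data[key] = value

    def put_all(self, dict_value: Dict[K, V]) -> None:
        self._data.update(dict_value)

    def remove(self, key: K) -> None:
        self._data.pop(key, None)

    def contains(self, key: K) -> bool:
        return key in self._data

    def items(self) -> Iterable[Tuple[K, V]]:
        return list(self._data.items())

    def keys(self) -> Iterable[K]:
        return list(self._data.keys())

    def values(self) -> Iterable[V]:
        return list(self._data.values())

    def is_empty(self) -> bool:
        return not self._data

    def clear(self) -> None:
        self._data.clear()

    # Same sugar as the MapState interface: each delegates to a named method.
    def __getitem__(self, key: K) -> Optional[V]:
        return self.get(key)

    def __setitem__(self, key: K, value: V) -> None:
        self.put(key, value)

    def __delitem__(self, key: K) -> None:
        self.remove(key)

    def __contains__(self, key: K) -> bool:
        return self.contains(key)

    def __iter__(self) -> Iterator[K]:
        return iter(self.keys())


counts: DictMapState[str, int] = DictMapState()
for word in ["a", "b", "a"]:
    counts[word] = (counts[word] or 0) + 1
del counts["b"]
print(sorted(counts.items()))  # [('a', 2)]
print("b" in counts)           # False
```

Keeping the sugar in the abstract base class means every concrete MapState implementation gets it for free by implementing only the named methods.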

ReducingState

class ReduceFunction(Function, Generic[T]):

    @abstractmethod
    def reduce(self, first: T, second: T) -> T:
        # Combines two values into one value of the same type.
        pass


class ReducingStateDescriptor(StateDescriptor):

    def __init__(self,
                 name: str,
                 reduce_function: ReduceFunction,
                 elem_type_info: TypeInformation):
        super(ReducingStateDescriptor, self).__init__(name)
        self.reduce_function = reduce_function
        self.elem_type_info = elem_type_info


class ReducingState(State, Generic[T]):

    @abstractmethod
    def get(self) -> T:
        pass

    @abstractmethod
    def add(self, value: T) -> None:
        pass

AggregatingState

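class AggregateFunction(Function, Generic[ACC, IN, OUT]):

    @abstractmethod
    def create_accumulator(self) -> ACC:
        # Creates a new, empty accumulator.
        pass

    @abstractmethod
    def add(self, acc: ACC, value: IN) -> ACC:
        # Adds an input value to the accumulator and returns the updated accumulator.
        pass

    @abstractmethod
    def get_result(self, acc: ACC) -> OUT:
        # Computes the aggregation result from the accumulator.
        pass

    @abstractmethod
    def merge(self, acc1: ACC, acc2: ACC) -> ACC:
        # Merges two accumulators into one.
        pass


class AggregatingStateDescriptor(StateDescriptor):

    def __init__(self,
                 name: str,
                 aggregate_function: AggregateFunction,
                 acc_type_info: TypeInformation):
        super(AggregatingStateDescriptor, self).__init__(name)
        self.aggregate_function = aggregate_function
        self.acc_type_info = acc_type_info


class AggregatingState(State, Generic[IN, OUT]):

    @abstractmethod
    def get(self) -> OUT:
        pass

    @abstractmethod
    def add(self, value: IN) -> None:
        pass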

RuntimeContext

RuntimeContext contains information about the context in which functions are executed. The following methods will be added to RuntimeContext to allow functions to create state:

class RuntimeContext(object):

    # These methods are only usable in functions applied on a KeyedStream;
    # the returned state is scoped to the key of the current element.

    def get_state(self, state_descriptor: ValueStateDescriptor) -> ValueState:
        pass

    def get_list_state(self, state_descriptor: ListStateDescriptor) -> ListState:
        pass

    def get_map_state(self, state_descriptor: MapStateDescriptor) -> MapState:
        pass

    def get_reducing_state(
            self, state_descriptor: ReducingStateDescriptor) -> ReducingState:
        pass

    def get_aggregating_state(
            self, state_descriptor: AggregatingStateDescriptor) -> AggregatingState:
        pass

Proposed Design

Architecture

The overall architecture is as follows:



1) On the Python side, when a user function accesses state, a StateRequest (Get/Append/Clear) is sent to the Java operator via the state channel.

2) On the Java side, upon receiving a StateRequest, the operator reads from or writes to the state backend according to the type of the request. It then returns a StateResponse to the Python worker; for read (Get) requests, the response holds the value of the state.

3) A StateRequestHandler will be created in the Java operator to process the state access requests from the Python side. It is a callback invoked by Beam’s portability framework upon receiving a state request, and it runs in a separate callback thread rather than in the operator's main thread.
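Points 1 to 3 can be illustrated with a small Python sketch in which threads and queues stand in for Beam's gRPC state channel: the user-function side sends a request and blocks on its response, while a dedicated handler thread, not the thread processing elements, serves requests against the store. All names are hypothetical; the real StateRequestHandler is Java code invoked by Beam.

```python
import queue
import threading
from typing import Any, Dict, List, Optional


def state_request_handler(requests: queue.Queue, backend: Dict[str, List[Any]]) -> None:
    """Serve state requests on its own thread until a None sentinel arrives."""
    while True:
        item = requests.get()
        if item is None:
            break
        op, name, payload, reply = item
        bag = backend.setdefault(name, [])
        if op == "get":
            reply.put(list(bag))  # the response carries the state's value
        elif op == "append":
            bag.extend(payload)
            reply.put(None)       # writes are acknowledged with an empty response
        elif op == "clear":
            bag.clear()
            reply.put(None)
        else:
            reply.put(ValueError(f"unknown operation: {op}"))


def send(requests: queue.Queue, op: str, name: str,
         payload: Optional[List[Any]] = None) -> Any:
    """Send one request and block until the handler thread answers it."""
    reply: queue.Queue = queue.Queue(maxsize=1)
    requests.put((op, name, payload or [], reply))
    response = reply.get(timeout=5)
    if isinstance(response, Exception):
        raise response
    return response


request_queue: queue.Queue = queue.Queue()
store: Dict[str, List[Any]] = {}
handler = threading.Thread(target=state_request_handler,
                           args=(request_queue, store), daemon=True)
handler.start()

send(request_queue, "append", "seen", [True])
print(send(request_queue, "get", "seen"))  # [True]
request_queue.put(None)  # stop the handler thread
handler.join()
```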

Current key in state backend

When an operator receives an element, it extracts the element's key and sets it as the current key of the underlying state backend. However, PyFlink processes data and executes the Python functions asynchronously: the operator moves on to the next element without waiting for the result of the previous one. As a result, a state request received by the operator may belong to a different element than the one currently being processed, and the two elements may have different keys.

To solve this problem, we need to manage the current key of the state backend ourselves:

  • Override setCurrentKey() in AbstractStreamOperator so that it does not set the current key of the underlying state backend.
  • Call StateBackend.setCurrentKey() with the key of the state request before processing the request in StateRequestHandler.
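The sketch below, with hypothetical names, shows why the handler must set the key itself. A backend with a single "current key" scopes every write to that key, so if the operator has already moved on to the element keyed "flink" when the state request from the earlier "hello" element arrives, trusting the operator's current key writes to the wrong key. Setting the key carried by the request avoids this.

```python
from typing import Any, Dict, Tuple


class FakeKeyedBackend:
    """Hypothetical keyed state backend: every write is scoped to the current key."""

    def __init__(self) -> None:
        self.current_key: Any = None
        self.values: Dict[Tuple[str, Any], Any] = {}

    def set_current_key(self, key: Any) -> None:
        self.current_key = key

    def put(self, name: str, value: Any) -> None:
        self.values[(name, self.current_key)] = value


def handle_state_request(backend: FakeKeyedBackend, request_key: Any,
                         name: str, value: Any) -> None:
    # Use the key carried by the state request, not the key of the
    # element the operator happens to be processing right now.
    backend.set_current_key(request_key)
    backend.put(name, value)


# The operator has already moved on to "flink" when the request for "hello" arrives.
naive = FakeKeyedBackend()
naive.set_current_key("flink")
naive.put("seen", True)   # write lands under the wrong key
print(naive.values)       # {('seen', 'flink'): True}

fixed = FakeKeyedBackend()
fixed.set_current_key("flink")
handle_state_request(fixed, "hello", "seen", True)
print(fixed.values)       # {('seen', 'hello'): True}
```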

Checkpoint

As all state write operations are delegated to the underlying state backend of the Java operator, we only need to make sure that all state mutation requests have been sent to the Java operator before a checkpoint is taken; the checkpoint then works as it does for any other operator.
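A minimal sketch of that precondition, assuming (hypothetically) that the Python side may still hold state mutations that have not reached the Java operator: the operator flushes them before snapshotting, so the snapshot reflects every write issued so far. The class names are illustrative only and do not describe PyFlink internals.

```python
from typing import Any, Dict, List, Tuple


class PythonWorkerSketch:
    """Hypothetical Python side that may hold state writes not yet sent to Java."""

    def __init__(self, java_backend: Dict[str, Any]) -> None:
        self._java_backend = java_backend
        self.unsent: List[Tuple[str, Any]] = []

    def write_state(self, name: str, value: Any) -> None:
        # Writes may be buffered instead of being sent immediately.
        self.unsent.append((name, value))

    def flush(self) -> None:
        # Send every outstanding mutation request; the Java side applies it.
        for name, value in self.unsent:
            self._java_backend[name] = value
        self.unsent.clear()


class OperatorSketch:
    """Hypothetical Java operator owning the real state backend."""

    def __init__(self) -> None:
        self.backend: Dict[str, Any] = {}
        self.worker = PythonWorkerSketch(self.backend)

    def snapshot_state(self) -> Dict[str, Any]:
        # Before snapshotting, make sure no mutation is still on the Python side.
        self.worker.flush()
        return dict(self.backend)


op = OperatorSketch()
op.worker.write_state("count", 1)
op.worker.write_state("count", 2)
print(op.backend)           # {}
print(op.snapshot_state())  # {'count': 2}
```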

How to handle each kind of state 

PyFlink State Type | Beam proto message | Flink State Type | Remark
-------------------|--------------------|------------------|-------
ValueState         | BagUserState       | ListState        | Can be seen as a list with only a single element.
ListState          | BagUserState       | ListState        | Matches perfectly.
MapState           | MultimapSideInput  | MapState         | 1) There is still no built-in gRPC message support in Beam for MapState; see the discussion for details. 2) We will propose a discussion/design in Beam to add MapState support. 3) Until then, we can reuse the gRPC message MultimapSideInput for MapState support.
ReducingState      | BagUserState       | ListState        | The reduce function is written in Python, so there is no Java ReduceFunction to provide and it cannot be mapped to the Java ReducingState directly. ReducingState will be simulated on top of the Java ListState.
AggregatingState   | BagUserState       | ListState        | The aggregate function is written in Python, so there is no Java AggregateFunction to provide and it cannot be mapped to the Java AggregatingState directly. AggregatingState will be simulated on top of the Java ListState.
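Because the reduce and aggregate functions live in Python, the simulation can happen on the Python side over a bag that holds at most one element. The following is one plausible approach, not necessarily the implementation that will be adopted: ValueState keeps a single element; ReducingState reads the current value, reduces it with the new one and writes the result back (Clear, then Append); AggregatingState does the same with an accumulator and applies get_result on read. The `Bag` class is a hypothetical stand-in for a Java ListState reached through Get / Append / Clear.

```python
from typing import Any, Callable, List, Optional


class Bag:
    """Hypothetical stand-in for a Java ListState reached through Get / Append / Clear."""

    def __init__(self) -> None:
        self._items: List[Any] = []

    def get(self) -> List[Any]:
        return list(self._items)

    def append(self, value: Any) -> None:
        self._items.append(value)

    def clear(self) -> None:
        self._items.clear()


class ListBackedValueState:
    def __init__(self, bag: Bag) -> None:
        self._bag = bag

    def value(self) -> Optional[Any]:
        items = self._bag.get()
        return items[0] if items else None

    def update(self, value: Any) -> None:
        # Keep the bag at a single element: Clear, then Append.
        self._bag.clear()
        self._bag.append(value)


class ListBackedReducingState:
    def __init__(self, bag: Bag, reduce_fn: Callable[[Any, Any], Any]) -> None:
        self._value = ListBackedValueState(bag)
        self._reduce = reduce_fn

    def get(self) -> Optional[Any]:
        return self._value.value()

    def add(self, value: Any) -> None:
        current = self._value.value()
        self._value.update(value if current is None else self._reduce(current, value))


class ListBackedAggregatingState:
    def __init__(self, bag: Bag, agg_fn: Any) -> None:
        # agg_fn follows the AggregateFunction interface above.
        self._acc = ListBackedValueState(bag)
        self._agg = agg_fn

    def get(self) -> Optional[Any]:
        acc = self._acc.value()
        return None if acc is None else self._agg.get_result(acc)

    def add(self, value: Any) -> None:
        acc = self._acc.value()
        if acc is None:
            acc = self._agg.create_accumulator()
        self._acc.update(self._agg.add(acc, value))


class AverageAggregate:
    """Example aggregate function; the accumulator is (sum, count)."""

    def create_accumulator(self):
        return (0, 0)

    def add(self, acc, value):
        return (acc[0] + value, acc[1] + 1)

    def get_result(self, acc):
        return acc[0] / acc[1]

    def merge(self, acc1, acc2):
        return (acc1[0] + acc2[0], acc1[1] + acc2[1])


total = ListBackedReducingState(Bag(), lambda a, b: a + b)
for x in [1, 2, 3]:
    total.add(x)
print(total.get())  # 6

avg = ListBackedAggregatingState(Bag(), AverageAggregate())
for x in [2, 4]:
    avg.add(x)
print(avg.get())  # 3.0
```

In this sketch None doubles as "no value yet", so a user who stores None explicitly cannot be distinguished from an empty state; a real implementation would need a separate emptiness check.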

Example

The following example shows how to use state access in the Python DataStream API:

  • It first defines a DataStream from Row elements consisting of three fields (String, Int, Int).
  • It then transforms it into a KeyedStream keyed by the first field of each record and deduplicates records by that key, keeping only the first record seen for each key.

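from pyflink.common import Row
from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.functions import FlatMapFunction
from pyflink.datastream.state import ValueStateDescriptor


class DuplicateFilter(FlatMapFunction):

    def open(self, runtime_context):
        # Per-key flag recording whether this key has been seen before.
        descriptor = ValueStateDescriptor("seen", Types.BOOLEAN())
        self.value_state = runtime_context.get_state(descriptor)

    def flat_map(self, value):
        # value() is empty for a key that has not been seen yet.
        if not self.value_state.value():
            self.value_state.update(True)
            yield value


env = StreamExecutionEnvironment.get_execution_environment()
row_type = Types.ROW([Types.STRING(), Types.INT(), Types.INT()])
ds = env.from_collection(
    [Row('hello', 1, 2), Row('hello', 1, 3), Row('flink', 2, 1)],
    type_info=row_type)

# The key is the first (String) field; the output keeps the input Row type.
(ds.key_by(lambda x: x[0], Types.STRING())
   .flat_map(DuplicateFilter(), output_type=row_type)
   .print())  # a user-defined sink could be attached with add_sink() instead

env.execute("deduplication")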

Implementation Plan

  1. Support ValueState in Python DataStream API
  2. Support ListState in Python DataStream API
  3. Support MapState in Python DataStream API
  4. Support ReducingState in Python DataStream API
  5. Support AggregatingState in Python DataStream API