Discussion thread: https://lists.apache.org/thread/zgbyp5hjhsp2bs3t1txq2p1l5t3c08yt
Vote thread: https://lists.apache.org/thread/pq1ot6wj1j87jxm4tqydl4vf6klqsy4l
JIRA: FLINK-24354
Release: ml-2.0.0


Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

Table of Contents

[This FLIP proposal is a joint work between Dong Lin and Zhipeng Zhang]

Motivation

An important part of the Flink ML library infrastructure is the APIs to define, set and get parameters of machine learning stages (e.g. Transformer, Estimator, AlgoOperator). Currently Flink ML library provides these APIs through the public methods of WithParams (an interface), ParamInfo, ParamInfoFactory and Params (three classes). In this FLIP we propose to simplify the Flink ML library infrastructure by reducing the number of classes as well as the number of public methods on those classes, and still support all the expected use-cases.

...

3) The ParamInfoFactory class provides APIs to build a ParamInfo.

Code Block
languagejava

public class ParamInfoFactory {

    public static <V> ParamInfoBuilder<V> createParamInfo(String name, Class<V> valueClass) {...}

    public static class ParamInfoBuilder<V> {
        ParamInfoBuilder(String name, Class<V> valueClass) {...}

        public ParamInfoBuilder<V> setAlias(String[] alias) {...}

        public ParamInfoBuilder<V> setDescription(String description) {...}

        public ParamInfoBuilder<V> setOptional() {...}

        public ParamInfoBuilder<V> setRequired() {...}

        public ParamInfoBuilder<V> setHasDefaultValue(V defaultValue) {...}

        public ParamInfoBuilder<V> setValidator(ParamValidator<V> validator) {...}

        public ParamInfo<V> build() {...}
    }
}

...

Here are the usability issues with the existing APIs:

1) Params::loadJson() and Params::toJson(), which are used to save/load a stage, cannot guarantee that the same parameter values will be used if the default value defined in the ParamInfo changes.

Params::contains and Params::toJson only recognize parameters whose value has been explicitly set by Params::set(...). This means that even if users want to use the default value specified in the parameter definition, they need to explicitly call Params::set(paramInfo, paramInfo.defaultValue) in order for this parameter to be recognized. Otherwise the default value is not part of the saved JSON, and a later change to the default changes the values used by the loaded stage.

This behavior makes it hard to guarantee consistent accuracy/performance of an existing Transformer/Estimator.

2) The existing Params::get() and Params::set() implementations always convert the value from/to a json-formatted string, which could incur unnecessary performance overhead.

An algorithm may set and get parameters one or more times before saving the model to disk. Ideally this should not involve value serialization and deserialization overhead.

...

Code Block
languagejava
/**
 * Definition of a parameter, including name, class, description, default value and the validator.
 *
 * @param <T> The type of the parameter value
 */
public class Param<T> implements Serializable {
    public final String name;
    public final Class<T> clazz;
    public final String description;
    public final T defaultValue;
    public final ParamValidator<T> validator;

    public Param(String name, Class<T> clazz, String description, T defaultValue, ParamValidator<T> validator) {...}

    // Encodes the given object into a json-formatted string
    public String jsonEncode(T value) throws IOException {...}

    // Decodes the json-formatted string into an object of the given type.
    public T jsonDecode(String json) throws IOException {...}
}


...

Code Block
languagejava
/**
 * Interface for classes that take parameters. It provides APIs to set and get parameters.
 *
 * @param <T> The class type of WithParams implementation itself.
 */
@PublicEvolving
public interface WithParams<T> {
    /**
     * Gets the parameter by its name.
     *
     * @param name The parameter name.
     * @param <V> The class type of the parameter value.
     * @return The parameter.
     */
    default <V> Param<V> getParam(String name) {...}

    /**
     * Sets the value of the parameter.
     *
     * @param param The parameter.
     * @param value The parameter value.
     * @return The WithParams instance itself.
     */
    @SuppressWarnings("unchecked")
    default <V> T set(Param<V> param, V value) {...}

    /**
     * Gets the value of the parameter.
     *
     * @param param The parameter.
     * @param <V> The class type of the parameter value.
     * @return The parameter value.
     */
    @SuppressWarnings("unchecked")
    default <V> V get(Param<V> param) {...}

    /**
     * Returns a map which should contain value for every parameter that meets one of the following
     * conditions.
     *
     * <p>1) set(...) has been called to set value for this parameter.
     *
     * <p>2) The parameter is a public final field of this WithParams instance. This includes fields
     * inherited from its interfaces and super-classes.
     *
     * <p>The subclass which implements this interface could meet this requirement by returning a
     * member field of the given map type, after having initialized this member field using the
     * {@link ParamUtils#initializeMapWithDefaultValues(Map, WithParams)} method.
     *
     * @return A map which maps parameter definition to parameter value.
     */
    Map<Param<?>, Object> getParamMap();
}
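
As a rough illustration of how the default methods of WithParams could be built on top of getParamMap(), here is a minimal, self-contained sketch. It is not the Flink ML implementation: the nested Param class is a simplified stand-in, java.util.function.Predicate plays the role of ParamValidator, and the names WithParamsSketch and DemoStage are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;

public class WithParamsSketch {
    /** Simplified stand-in for Param<T>; a Predicate is assumed in place of ParamValidator<T>. */
    public static class Param<T> {
        public final String name;
        public final T defaultValue;
        public final Predicate<T> validator;

        public Param(String name, T defaultValue, Predicate<T> validator) {
            this.name = name;
            this.defaultValue = defaultValue;
            this.validator = validator;
        }
    }

    /** Sketch of the proposed interface: only getParamMap() is abstract. */
    public interface WithParams<T> {
        Map<Param<?>, Object> getParamMap();

        // Looks the parameter up by name among the keys of the parameter map.
        @SuppressWarnings("unchecked")
        default <V> Param<V> getParam(String name) {
            for (Param<?> param : getParamMap().keySet()) {
                if (param.name.equals(name)) {
                    return (Param<V>) param;
                }
            }
            return null;
        }

        // Validates the value, stores it as a plain object (no JSON round trip),
        // and returns the instance itself to allow call chaining.
        @SuppressWarnings("unchecked")
        default <V> T set(Param<V> param, V value) {
            if (!param.validator.test(value)) {
                throw new IllegalArgumentException(
                        "Invalid value " + value + " for parameter " + param.name);
            }
            getParamMap().put(param, value);
            return (T) this;
        }

        // Reads the value straight from the map; defaults are expected to be in the map already.
        @SuppressWarnings("unchecked")
        default <V> V get(Param<V> param) {
            return (V) getParamMap().get(param);
        }
    }

    /** Hypothetical stage that seeds its map with the default value by hand. */
    public static class DemoStage implements WithParams<DemoStage> {
        public static final Param<Integer> MAX_ITER =
                new Param<Integer>("maxIter", 10, v -> v != null && v > 0);

        private final Map<Param<?>, Object> paramMap = new HashMap<>();

        public DemoStage() {
            paramMap.put(MAX_ITER, MAX_ITER.defaultValue);
        }

        @Override
        public Map<Param<?>, Object> getParamMap() {
            return paramMap;
        }
    }

    public static void main(String[] args) {
        DemoStage stage = new DemoStage();
        System.out.println(stage.get(DemoStage.MAX_ITER));
        System.out.println(stage.set(DemoStage.MAX_ITER, 20).get(DemoStage.MAX_ITER));
    }
}
```

In this sketch, get() never consults the parameter definition for a default. It relies on the map having been populated up front, which is the contract described for getParamMap().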


2) We propose to add the following subclasses of Param<?> to simplify the creation of parameters with primitive-typed values (e.g. long, int, boolean). 

Code Block
languagejava

public class BooleanParam extends Param<Boolean> {
  ...
}

public class IntParam extends Param<Integer> {
  ...
}

public class LongParam extends Param<Long> {
  ...
}

public class FloatParam extends Param<Float> {
  ...
}

public class DoubleParam extends Param<Double> {
  ...
}

public class StringParam extends Param<String> {
  ...
}

public class IntArrayParam extends Param<Integer[]> {
  ...
}

public class LongArrayParam extends Param<Long[]> {
  ...
}

public class FloatArrayParam extends Param<Float[]> {
  ...
}

public class DoubleArrayParam extends Param<Double[]> {
  ...
}

...

3) We propose to add the following util method to facilitate the initialization of the parameter map with default parameter values.

Code Block
languagejava
/** Utility methods for reading and writing stages. */
public class ParamUtils {
    /**
     * Updates the paramMap with default values of all public final Param-typed fields of the given
     * instance. A parameter's value will not be updated if this parameter is already found in the
     * map.
     *
     * <p>Note: This method should be called after all public final Param-typed fields of the given
     * instance have been defined. A good choice is to call this method in the constructor of the
     * given instance.
     */
    public static void initializeMapWithDefaultValues(Map<Param<?>, Object> paramMap, WithParams<?> instance) {...}
}  


4) We propose to add the following subclasses of ParamValidator<?> to simplify the creation of parameter validators with numerical values.

Code Block
languagejava
/**
 * Factory methods for common validation functions. The numerical methods only support Int, Long,
 * Float, and Double.
 */
public class ParamValidators {
    // Always return true.
    public static <T> ParamValidator<T> alwaysTrue() {...}

    // Check if the parameter value is greater than lowerBound.
    public static <T> ParamValidator<T> gt(double lowerBound) {...}

    // Check if the parameter value is greater than or equal to lowerBound.
    public static <T> ParamValidator<T> gtEq(double lowerBound) {...}

    // Check if the parameter value is less than upperBound.
    public static <T> ParamValidator<T> lt(double upperBound) {...}

    // Check if the parameter value is less than or equal to upperBound.
    public static <T> ParamValidator<T> ltEq(double upperBound) {...}

    /**
     * Check if the parameter value is in the range from lowerBound to upperBound.
     *
     * @param lowerInclusive if true, range includes value = lowerBound
     * @param upperInclusive if true, range includes value = upperBound
     */
    public static <T> ParamValidator<T> inRange(double lowerBound, double upperBound, boolean lowerInclusive, boolean upperInclusive) {...}

    // Check if the parameter value is in the range [lowerBound, upperBound].
    public static <T> ParamValidator<T> inRange(double lowerBound, double upperBound) {...}

    // Check if the parameter value is in the array of allowed values.
    public static <T> ParamValidator<T> inArray(T... allowed) {...}

    // Check if the parameter value is not null.
    public static <T> ParamValidator<T> notNull() {...}
}
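
To make the factory methods more concrete, here is one possible way a few of them could be written as lambdas. This is only a sketch: the shape of ParamValidator (a single `validate` method returning a boolean) is an assumption, since the interface is not shown here, and the class name ParamValidatorsSketch is hypothetical.

```java
import java.util.Arrays;

public class ParamValidatorsSketch {
    /** Assumed shape of ParamValidator: a single boolean check on the value. */
    @FunctionalInterface
    public interface ParamValidator<T> {
        boolean validate(T value);
    }

    // Accepts every value, including null.
    public static <T> ParamValidator<T> alwaysTrue() {
        return value -> true;
    }

    // Rejects null values only.
    public static <T> ParamValidator<T> notNull() {
        return value -> value != null;
    }

    // Numerical checks compare via double, so only Number subclasses are supported.
    public static <T> ParamValidator<T> gt(double lowerBound) {
        return value -> value != null && toDouble(value) > lowerBound;
    }

    // Range check with configurable inclusiveness at both ends.
    public static <T> ParamValidator<T> inRange(
            double lowerBound, double upperBound, boolean lowerInclusive, boolean upperInclusive) {
        return value -> {
            if (value == null) {
                return false;
            }
            double v = toDouble(value);
            boolean aboveLower = lowerInclusive ? v >= lowerBound : v > lowerBound;
            boolean belowUpper = upperInclusive ? v <= upperBound : v < upperBound;
            return aboveLower && belowUpper;
        };
    }

    // Membership check against a fixed set of allowed values.
    @SafeVarargs
    public static <T> ParamValidator<T> inArray(T... allowed) {
        return value -> Arrays.asList(allowed).contains(value);
    }

    private static double toDouble(Object value) {
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        throw new IllegalArgumentException("Unsupported numerical type: " + value.getClass());
    }

    public static void main(String[] args) {
        ParamValidator<Integer> positive = gt(0);
        System.out.println(positive.validate(5));
        System.out.println(positive.validate(0));
    }
}
```

Comparing through double keeps the factory signatures independent of the concrete numerical type, at the cost of precision for very large Long values.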

Proposed Changes

We make the following notes regarding the implementation and the usage of the proposed interfaces:

1) With the proposed interface, algorithm developers can define multiple parameters by calling e.g. Param<Boolean> BOOLEAN_PARAM = new BooleanParam(...). In order to have WithParams::getParamMap() return those parameter values, regardless of whether algorithm users have explicitly called WithParams::set(...) to set parameter values, algorithm developers should make sure to initialize the paramMap in the constructor.

This can be achieved by calling the util method ParamUtils.initializeMapWithDefaultValues(paramMap, withParamsInstance).


2) The initializeMapWithDefaultValues(paramMap, withParamsInstance) method will use Java reflection to enumerate all public final fields of withParamsInstance, find those fields whose type is assignable to the Param class, and update the given paramMap with the default value for those Param fields that are not already found in the paramMap.

In order for this to work correctly, initializeMapWithDefaultValues(...) should be called after all public final Param-typed fields of the given WithParams instance have been defined. A good choice is to call this method in the constructor of the WithParams instance.
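
The reflection step described above could look roughly like the following sketch. It is not the ParamUtils implementation: the nested Param class is a simplified stand-in, the method takes a plain Object instead of WithParams<?>, and the names DefaultValueInitSketch, HasMaxIter and DemoStage are hypothetical.

```java
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.Map;

public class DefaultValueInitSketch {
    /** Simplified stand-in for the proposed Param<T> class. */
    public static class Param<T> {
        public final String name;
        public final T defaultValue;

        public Param(String name, T defaultValue) {
            this.name = name;
            this.defaultValue = defaultValue;
        }
    }

    /** Interface-level parameter; interface fields are implicitly public static final. */
    public interface HasMaxIter {
        Param<Integer> MAX_ITER = new Param<>("maxIter", 10);
    }

    /** Example stage with one inherited and one instance-level parameter. */
    public static class DemoStage implements HasMaxIter {
        public final Param<Double> learningRate = new Param<>("learningRate", 0.1);
        public final Map<Param<?>, Object> paramMap = new HashMap<>();

        public DemoStage() {
            // Runs after the field initializers, so every Param field is already assigned.
            initializeMapWithDefaultValues(paramMap, this);
        }
    }

    /** Adds the default value of every public final Param field that is not yet in the map. */
    public static void initializeMapWithDefaultValues(
            Map<Param<?>, Object> paramMap, Object instance) {
        // getFields() returns public fields only, including those inherited from
        // super-classes and super-interfaces.
        for (Field field : instance.getClass().getFields()) {
            if (!Param.class.isAssignableFrom(field.getType())
                    || !Modifier.isFinal(field.getModifiers())) {
                continue;
            }
            try {
                // For static fields the instance argument is ignored.
                Param<?> param = (Param<?>) field.get(instance);
                paramMap.putIfAbsent(param, param.defaultValue);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("Cannot read field " + field.getName(), e);
            }
        }
    }

    public static void main(String[] args) {
        DemoStage stage = new DemoStage();
        System.out.println(stage.paramMap.size());
    }
}
```

Because putIfAbsent is used, a value that is already in the map is never overwritten by the default, which matches the behavior described for parameters already found in the paramMap.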


Example Usage

In the following we provide an example code snippet that shows how to define, set and get parameter values with various types.

More specifically, the code snippet covers the following features:

1) How to define parameters of various primitive types (e.g. long, int array, string).

2) How to define a parameter as a static field in an interface and access this parameter.

3) How to define a parameter as a non-static field in a class and access this parameter.

4) How to access a parameter by a Param<?> variable.

5) How to access a parameter by its name.


Code Block
languagejava
// An example interface that provides pre-defined parameters.
public interface MyParams<T> extends WithParams<T> {
    Param<Boolean> BOOLEAN_PARAM = new BooleanParam("booleanParam", "Description", false);

    Param<Integer> INT_PARAM = new IntParam("intParam", "Description", 1, ParamValidators.lt(100));

    Param<Long> LONG_PARAM = new LongParam("longParam", "Description", 2L, ParamValidators.lt(100));

    Param<Integer[]> INT_ARRAY_PARAM = new IntArrayParam("intArrayParam", "Description", new Integer[] {3, 4});

    Param<String[]> STRING_ARRAY_PARAM = new StringArrayParam("stringArrayParam", "Description", new String[] {"5", "6"});
}

// An example stage class that defines its own parameters and also inherits parameters from MyParams.
public static class MyStage implements Stage<MyStage>, MyParams<MyStage> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    public final Param<Integer> extraIntParam = new IntParam("extraIntParam", "Description", 100, ParamValidators.alwaysTrue());

    public MyStage() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

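    // Returns the map initialized in the constructor, as required by WithParams.
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    // Skipped implementation of save() and load().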
}


public static void main(String[] args) {
    MyStage stage = new MyStage();

    // Gets the value of a parameter defined in the MyParams interface without first setting its value.
    Integer[] intArrayValue = stage.get(MyParams.INT_ARRAY_PARAM);

    // Sets and gets value of a parameter defined in the MyParams interface.
    stage.set(MyParams.INT_PARAM, 1);
    Integer intValue = stage.get(MyParams.INT_PARAM);

    // Sets and gets value of a parameter defined in the MyStage class.
    stage.set(stage.extraIntParam, 2);
    Integer extraIntValue = stage.get(stage.extraIntParam);

    // Sets and gets value of a parameter identified by its name string.
    Param<Long> longParam = stage.getParam("longParam");
    stage.set(longParam, 3L);
    Long longValue = stage.get(longParam);
}


Compatibility, Deprecation, and Migration Plan

The changes proposed in this FLIP are backward incompatible with the existing APIs of WithParams/ParamInfo/ParamInfoFactory/Params. We propose to change the APIs directly without a deprecation period.

Since there is no implementation of Estimator/Transformer (excluding test-only implementations) in the existing Flink codebase, no work is needed to migrate the existing Flink codebase.

Test Plan

We will provide unit tests to validate the proposed changes.

Rejected Alternatives


1) Add the method "Map<Param<?>, Object> getInternalParmMap()" in the WithParams interface, and do not add the initializeMapWithDefaultValues(...) util method.

In comparison to the proposed approach, this alternative approach makes life a bit easier for algorithm developers. They would not need to write code to invoke initializeMapWithDefaultValues(). They would just need to override getInternalParmMap() to return a member field of type Map<Param<?>, Object>.

The downside of this alternative approach is that algorithm users will see the getInternalParmMap() API in the WithParams interface, which is never useful to them. The existence of getInternalParmMap() and getParamMap() on the same interface could be confusing to algorithm users.

We choose not to use this approach because we believe there will be many more algorithm users than algorithm developers, and it is more important to optimize the algorithm users' experience.