Current state: Not ready for discussion.

Discussion thread: To be added

JIRA: To be added

Released: Not released yet.

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

Motivation

An important part of the Flink ML library infrastructure is the set of APIs for defining, setting and getting parameters of machine learning stages (e.g. Transformer, Estimator, AlgoOperator). Currently, the Flink ML library provides these APIs through the public methods of WithParams (an interface) and ParamInfo, ParamInfoFactory and Params (three classes). In this FLIP we propose to simplify this infrastructure by reducing the number of classes and the number of public methods on those classes, while still supporting all the expected use cases.

The goal of this FLIP is to increase developer velocity by making the Flink ML library easier to use.

Issues with the existing parameter-related interface and classes

In this section we first summarize the public methods of the existing parameter-related interface and classes, and then explain why they could be simplified.

1) The ParamInfo class provides the definition of a parameter, including its name, type, default value and so on.

public class ParamInfo<V> {
    private final String name;
    private final String[] alias;
    private final String description;
    private final boolean isOptional;
    private final boolean hasDefaultValue;
    private final V defaultValue;
    private final ParamValidator<V> validator;
    private final Class<V> valueClass;

    ParamInfo(String name, String[] alias, String description, boolean isOptional, boolean hasDefaultValue, V defaultValue, ParamValidator<V> validator, Class<V> valueClass) {...}

    public String getName() {...}

    public String[] getAlias() {...}

    public String getDescription() {...}

    public boolean isOptional() {...}

    public boolean hasDefaultValue() {...}

    public V getDefaultValue() {...}

    public ParamValidator<V> getValidator() {...}

    public Class<V> getValueClass() {...}
}


2) The WithParams interface provides APIs to set and get parameter values.

public interface WithParams<T> {

    Params getParams();

    default <V> T set(ParamInfo<V> info, V value) {
        getParams().set(info, value);
        return (T) this;
    }

    default <V> V get(ParamInfo<V> info) {
        return getParams().get(info);
    }
}


3) The ParamInfoFactory class provides APIs to build a ParamInfo.

public class ParamInfoFactory {

    public static <V> ParamInfoBuilder<V> createParamInfo(String name, Class<V> valueClass) {...}

    public static class ParamInfoBuilder<V> {
        ParamInfoBuilder(String name, Class<V> valueClass) {...}

        public ParamInfoBuilder<V> setAlias(String[] alias) {...}

        public ParamInfoBuilder<V> setDescription(String description) {...}

        public ParamInfoBuilder<V> setOptional() {...}

        public ParamInfoBuilder<V> setRequired() {...}

        public ParamInfoBuilder<V> setHasDefaultValue(V defaultValue) {...}

        public ParamInfoBuilder<V> setValidator(ParamValidator<V> validator) {...}

        public ParamInfo<V> build() {...}
    }
}


4) The Params class stores the mapping from parameters to values. It also provides APIs to convert this mapping to/from a JSON-formatted string.

public class Params implements Serializable, Cloneable {
    public int size() {...}

    public void clear() {...}

    public boolean isEmpty() {...}

    public <V> V get(ParamInfo<V> info) {...}

    public <V> Params set(ParamInfo<V> info, V value) {...}

    public <V> void remove(ParamInfo<V> info) {...}

    public <V> boolean contains(ParamInfo<V> info) {...}

    public String toJson() {...}

    public void loadJson(String json) {...}

    public static Params fromJson(String json) {...}

    public Params merge(Params otherParams) {...}

    public Params clone() {...}
}


Here are the redundancy issues with the existing APIs:

1) ParamInfo does not need the "String[] alias" field.

An alias is typically needed only when we need to migrate a parameter of an existing machine learning algorithm to a new name. Since there is no machine learning algorithm in the Flink ML library yet, we do not have any use case for this field.

Many other frameworks/libraries, such as Apache Spark ML and Apache Kafka, support their users without providing aliases for their configs. It would be better to follow this pattern and avoid changing parameter names, instead of supporting aliases from the very beginning.

2) ParamInfo does not need the "boolean hasDefaultValue" field and the "boolean isOptional" field.

As far as we can tell, there is no use-case that requires these two fields.

The parameter definition itself does not specify how its value should be used. It is up to the algorithm to decide how to use the parameter value. Thus the algorithm, not the parameter definition, should decide whether the parameter is optional.

"hasDefaultValue" appears to be redundant because this information can be derived by checking whether defaultValue is null.

3) ParamInfo does not need those getter methods. All its member fields could be declared as "public final" since there is no need to change those fields after ParamInfo is constructed.

4) ParamInfoFactory is unnecessary; we can construct a ParamInfo simply by calling the ParamInfo constructors.

5) Params is unnecessary after we remove the alias field from ParamInfo.

After we remove the alias field from ParamInfo, we can just replace Params with Map<ParamInfo<?>, Object>, which effectively contains the mapping from parameter definitions to parameter values.
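The map-based replacement described above can be sketched as follows. This is only an illustration under stated assumptions: `SimpleParam` and `ParamMapSketch` are hypothetical stand-ins introduced here, not classes from the Flink ML library, and the fallback to the default value is one possible lookup policy.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for a parameter definition; not the actual Flink ML class.
final class SimpleParam<V> {
    final String name;
    final Class<V> valueClass;
    final V defaultValue;

    SimpleParam(String name, Class<V> valueClass, V defaultValue) {
        this.name = name;
        this.valueClass = valueClass;
        this.defaultValue = defaultValue;
    }
}

// Hypothetical holder that uses a plain map instead of a dedicated Params class.
final class ParamMapSketch {
    // Maps each parameter definition to the value that was explicitly set for it.
    private final Map<SimpleParam<?>, Object> values = new HashMap<>();

    <V> void set(SimpleParam<V> param, V value) {
        values.put(param, value);
    }

    <V> V get(SimpleParam<V> param) {
        Object value = values.get(param);
        // Fall back to the definition's default when no value was set explicitly.
        return value == null ? param.defaultValue : param.valueClass.cast(value);
    }

    public static void main(String[] args) {
        SimpleParam<Integer> maxIter = new SimpleParam<>("maxIter", Integer.class, 10);
        ParamMapSketch stage = new ParamMapSketch();
        System.out.println(stage.get(maxIter)); // prints 10 (the default)
        stage.set(maxIter, 20);
        System.out.println(stage.get(maxIter)); // prints 20 (the explicit value)
    }
}
```

Because the map stores the value objects directly, reading a value back in this sketch involves no serialization step.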

Here are the usability issues with the existing APIs:

1) Params::loadJson() and Params::toJson(), which are used to save/load a stage, cannot guarantee that the same parameter values will be used if the default value defined in the ParamInfo changes between saving and loading.

This behavior makes it hard to guarantee consistent accuracy/performance of an existing Transformer/Estimator.

2) The existing Params::get() and Params::set() implementations always convert the value from/to json-formatted string, which could incur unnecessary performance overhead.

It is possible that an algorithm will want to set and get parameters one or more times before saving the model to disk. Ideally this should not involve value serialization and deserialization overhead.
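To make the first usability issue concrete, here is a minimal sketch of one way to keep loaded values stable: record the resolved value of every parameter (explicit value or, failing that, the default in effect) at save time. `SaveLoadSketch`, `resolve` and the `maxIter` parameter are hypothetical names used only for illustration; the sketch uses plain in-memory maps and does not show the FLIP's actual save/load format.

```java
import java.util.HashMap;
import java.util.Map;

final class SaveLoadSketch {
    // Resolves every known parameter to a concrete value:
    // explicitly provided values win, otherwise the given default is used.
    static Map<String, Object> resolve(Map<String, Object> defaults, Map<String, Object> explicit) {
        Map<String, Object> resolved = new HashMap<>(defaults);
        resolved.putAll(explicit);
        return resolved;
    }

    public static void main(String[] args) {
        Map<String, Object> defaultsAtSave = new HashMap<>();
        defaultsAtSave.put("maxIter", 10);
        // Nothing was set explicitly, so the snapshot records the default in effect at save time.
        Map<String, Object> saved = resolve(defaultsAtSave, new HashMap<>());

        // Assume a later release changes the default value.
        Map<String, Object> defaultsAtLoad = new HashMap<>();
        defaultsAtLoad.put("maxIter", 50);
        Map<String, Object> loaded = resolve(defaultsAtLoad, saved);
        System.out.println(loaded.get("maxIter")); // prints 10, not 50
    }
}
```

If only explicitly set values were saved instead, the loaded stage in this sketch would silently pick up the new default of 50.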

Public Interfaces

1) We propose to replace WithParams/ParamInfo/ParamInfoFactory/Params with the Param class and the WithParams interface shown below. The proposed class and interface address all the issues described above.

/**
 * Definition of a parameter, including name, class, description, default value and the validator.
 *
 * @param <T> The type of the parameter value
 */
public class Param<T> implements Serializable {
    public final String name;
    public final Class<T> clazz;
    public final String description;
    public final T defaultValue;
    public final ParamValidator<T> validator;

    public Param(String name, Class<T> clazz, String description, T defaultValue, ParamValidator<T> validator) {...}

    // Encodes the given object into a json-formatted string
    public String jsonEncode(T value) throws IOException {...}

    // Decodes the json-formatted string into an object of the given type.
    public T jsonDecode(String json) throws IOException {...}
}



/**
 * Interface for classes that take parameters. It provides APIs to set and get parameters.
 *
 * @param <T> The class type of WithParams implementation itself.
 */
@PublicEvolving
public interface WithParams<T> {
    /** Gets a param by name. */
    default <V> Param<V> getParam(String name) {...}

    /**
     * Sets the value of the given parameter in the user-defined map.
     *
     * @param param the parameter
     * @param value the value
     * @return the WithParams instance itself
     */
    default <V> T set(Param<V> param, V value) {...}

    /**
     * Gets the value of the given parameter. Returns the value from the user-defined map if set(...) has been
     * explicitly called to set value for this parameter. Otherwise, returns the default value from the definition of
     * this parameter.
     *
     * @param param the parameter
     * @param <V> the type of the parameter
     * @return the value of the parameter
     */
    default <V> V get(Param<V> param) {...}

    /**
     * Returns an immutable map that contains value for every parameter that meets one of the following conditions:
     * 1) set(...) has been called to set value for this parameter.
     * 2) The parameter is a field of this WithParams instance. This includes public, protected and private fields. And
     * this also includes fields inherited from its interfaces and super-classes.
     *
     * @return an immutable map of parameters and values.
     */
    default Map<Param<?>, Object> getParamMap() {...}

    /**
     * Returns a mutable map that can be used to set values for parameters. A subclass of this interface should override
     * this method if it wants to support users to set non-default parameter values.
     *
     * @return a mutable map of parameters and value overrides.
     */
    default Map<Param<?>, Object> getUserDefinedParamMap() {
        return null;
    }
}
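The Javadoc of getParamMap() says that parameters declared as fields, including inherited ones, are part of the returned map. One way this could be done is with reflection, as in the minimal sketch below. All names here (`DemoParam`, `DemoParams`, `DemoStage`, `ParamDiscoverySketch`) are hypothetical, and the FLIP does not state that the real implementation works this way.

```java
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hypothetical simplified parameter definition used only in this sketch.
class DemoParam<T> {
    final String name;
    final T defaultValue;

    DemoParam(String name, T defaultValue) {
        this.name = name;
        this.defaultValue = defaultValue;
    }
}

// Parameter declared as a (implicitly static) field of an interface.
interface DemoParams {
    DemoParam<Integer> MAX_ITER = new DemoParam<>("maxIter", 10);
}

// Stage with its own non-static parameter field and a map of explicitly set values.
class DemoStage implements DemoParams {
    final DemoParam<Double> tol = new DemoParam<>("tol", 1e-3);
    final Map<DemoParam<?>, Object> userDefined = new HashMap<>();
}

final class ParamDiscoverySketch {
    // Walks the class, its super-classes and its interfaces, collecting DemoParam-typed fields.
    private static void collect(Class<?> clazz, Object instance, Map<DemoParam<?>, Object> out) {
        if (clazz == null) {
            return;
        }
        for (Field field : clazz.getDeclaredFields()) {
            if (!DemoParam.class.isAssignableFrom(field.getType())) {
                continue;
            }
            try {
                field.setAccessible(true);
                DemoParam<?> param = (DemoParam<?>) field.get(instance);
                if (param != null) {
                    // Explicitly set values are already in the map and must not be overwritten.
                    out.putIfAbsent(param, param.defaultValue);
                }
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("Cannot read field " + field.getName(), e);
            }
        }
        collect(clazz.getSuperclass(), instance, out);
        for (Class<?> iface : clazz.getInterfaces()) {
            collect(iface, instance, out);
        }
    }

    // Returns an immutable view: explicit values first, then defaults of all declared parameters.
    static Map<DemoParam<?>, Object> paramMap(DemoStage stage) {
        Map<DemoParam<?>, Object> result = new HashMap<>(stage.userDefined);
        collect(stage.getClass(), stage, result);
        return Collections.unmodifiableMap(result);
    }

    // Looks a parameter up by name among the discovered parameters; returns null if absent.
    static DemoParam<?> byName(DemoStage stage, String name) {
        for (DemoParam<?> param : paramMap(stage).keySet()) {
            if (param.name.equals(name)) {
                return param;
            }
        }
        return null;
    }
}
```

In this sketch a name lookup simply scans the discovered parameters, which also suggests why parameter names would need to be unique within a stage.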



2) We propose to add the following subclasses of Param<?> to simplify the creation of parameters with primitive-typed values (e.g. long, int, boolean). 

public class BooleanParam extends Param<Boolean> {
  ...
}

public class IntParam extends Param<Integer> {
  ...
}

public class LongParam extends Param<Long> {
  ...
}

public class FloatParam extends Param<Float> {
  ...
}

public class DoubleParam extends Param<Double> {
  ...
}

public class StringParam extends Param<String> {
  ...
}

public class IntArrayParam extends Param<Integer[]> {
  ...
}

public class LongArrayParam extends Param<Long[]> {
  ...
}

public class FloatArrayParam extends Param<Float[]> {
  ...
}

public class DoubleArrayParam extends Param<Double[]> {
  ...
}
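The benefit of these typed subclasses is that callers no longer pass the Class token themselves. The following sketch shows the idea with hypothetical simplified classes (`SketchParam`, `SketchIntParam`, `SketchIntArrayParam`); the constructor signatures are assumptions, and the validator argument is omitted for brevity.

```java
// Hypothetical simplified base class mirroring some fields of the proposed Param<T>.
class SketchParam<T> {
    final String name;
    final Class<T> clazz;
    final String description;
    final T defaultValue;

    SketchParam(String name, Class<T> clazz, String description, T defaultValue) {
        this.name = name;
        this.clazz = clazz;
        this.description = description;
        this.defaultValue = defaultValue;
    }
}

// The subclass fixes the type argument and supplies the Class token itself.
class SketchIntParam extends SketchParam<Integer> {
    SketchIntParam(String name, String description, Integer defaultValue) {
        super(name, Integer.class, description, defaultValue);
    }
}

// Array-valued parameters follow the same pattern with an array Class token.
class SketchIntArrayParam extends SketchParam<Integer[]> {
    SketchIntArrayParam(String name, String description, Integer[] defaultValue) {
        super(name, Integer[].class, description, defaultValue);
    }
}

final class TypedParamSketch {
    public static void main(String[] args) {
        SketchParam<Integer> maxIter =
                new SketchIntParam("maxIter", "Maximum number of iterations", 10);
        System.out.println(maxIter.clazz.getSimpleName() + " " + maxIter.defaultValue); // prints "Integer 10"
    }
}
```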


3) We propose to add the following ParamValidators class, whose factory methods simplify the creation of validators for parameters with numerical values.

/**
 * Factory methods for common validation functions. The numerical methods only support Int, Long,
 * Float, and Double.
 */
public class ParamValidators {
    // Always return true.
    public static <T> ParamValidator<T> alwaysTrue() {...}

    // Check if the parameter value is greater than lowerBound.
    public static <T> ParamValidator<T> gt(double lowerBound) {...}

    // Check if the parameter value is greater than or equal to lowerBound.
    public static <T> ParamValidator<T> gtEq(double lowerBound) {...}

    // Check if the parameter value is less than upperBound.
    public static <T> ParamValidator<T> lt(double upperBound) {...}

    // Check if the parameter value is less than or equal to upperBound.
    public static <T> ParamValidator<T> ltEq(double upperBound) {...}

    /**
     * Check if the parameter value is in the range from lowerBound to upperBound.
     *
     * @param lowerInclusive if true, range includes value = lowerBound
     * @param upperInclusive if true, range includes value = upperBound
     */
    public static <T> ParamValidator<T> inRange(double lowerBound, double upperBound, boolean lowerInclusive, boolean upperInclusive) {...}

    // Check if the parameter value is in the range [lowerBound, upperBound].
    public static <T> ParamValidator<T> inRange(double lowerBound, double upperBound) {...}
}
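A few of these factory methods could be implemented as in the sketch below. The FLIP does not show the methods of ParamValidator, so the single-method `Validator` interface here is an assumption, as are the class name `ValidatorsSketch` and the choice to reject null values in the numeric checks.

```java
// Hypothetical validator interface; the FLIP does not show ParamValidator's methods.
interface Validator<T> {
    boolean validate(T value);
}

final class ValidatorsSketch {
    // Accepts every value.
    static <T> Validator<T> alwaysTrue() {
        return value -> true;
    }

    // Accepts numeric values strictly greater than lowerBound.
    static <T> Validator<T> gt(double lowerBound) {
        return value -> value != null && toDouble(value) > lowerBound;
    }

    // Accepts numeric values inside the range, with configurable inclusiveness at both ends.
    static <T> Validator<T> inRange(
            double lowerBound, double upperBound, boolean lowerInclusive, boolean upperInclusive) {
        return value -> {
            if (value == null) {
                return false;
            }
            double v = toDouble(value);
            boolean aboveLower = lowerInclusive ? v >= lowerBound : v > lowerBound;
            boolean belowUpper = upperInclusive ? v <= upperBound : v < upperBound;
            return aboveLower && belowUpper;
        };
    }

    // Converts Integer, Long, Float and Double (all subclasses of Number) to double for comparison.
    // Note that very large Long values may lose precision in this conversion.
    private static double toDouble(Object value) {
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        throw new IllegalArgumentException("Unsupported type: " + value.getClass().getName());
    }

    public static void main(String[] args) {
        Validator<Integer> positive = gt(0);
        System.out.println(positive.validate(5)); // prints true
        System.out.println(positive.validate(0)); // prints false
    }
}
```

Comparing through `double` keeps the factory signatures identical for all four numeric types, at the cost of the precision caveat noted in the comment.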


Example Usage

In the following we provide an example code snippet that shows how to define, set and get parameter values with various types.

More specifically, the code snippet covers the following features:

1) How to define parameters of various primitive types (e.g. long, int array, string).

2) How to define a parameter as a static field in an interface and access this parameter.

3) How to define a parameter as a non-static field in a class and access this parameter.

4) How to access a parameter by a Param<?> variable.

5) How to access a parameter by its name.


// An example interface that provides pre-defined parameters.
public interface MyParams<T> extends WithParams<T> {
    Param<Boolean> BOOLEAN_PARAM = new BooleanParam("booleanParam", "Description", false);

    Param<Integer> INT_PARAM = new IntParam("intParam", "Description", 1, ParamValidators.lt(100));

    Param<Long> LONG_PARAM = new LongParam("longParam", "Description", 2L, ParamValidators.lt(100));

    Param<Integer[]> INT_ARRAY_PARAM = new IntArrayParam("intArrayParam", "Description", new Integer[] {3, 4});

    Param<String[]> STRING_ARRAY_PARAM = new StringArrayParam("stringArrayParam", "Description", new String[] {"5", "6"});
}

// An example stage class that defines its own parameters and also inherits parameters from MyParams.
public static class MyStage implements Stage<MyStage>, MyParams<MyStage> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    Param<Integer> extraIntParam = new IntParam("extraIntParam", "Description", 100, ParamValidators.alwaysTrue());

    public MyStage() {}

    @Override
    public Map<Param<?>, Object> getUserDefinedParamMap() {
        return paramMap;
    }

    // Skipped implementation of save() and load().
}


public static void main(String[] args) {
    MyStage stage = new MyStage();

    // Gets the value of a parameter defined in the MyParams interface without first setting its value.
    Integer[] intArrayValue = stage.get(MyParams.INT_ARRAY_PARAM);

    // Sets and gets value of a parameter defined in the MyParams interface.
    stage.set(MyParams.INT_PARAM, 1);
    Integer intValue = stage.get(MyParams.INT_PARAM);

    // Sets and gets value of a parameter defined in the MyStage class.
    stage.set(stage.extraIntParam, 2);
    Integer extraIntValue = stage.get(stage.extraIntParam);

    // Sets and gets value of a parameter identified by its name string.
    // The explicit Param<Long> type lets set(...) accept a Long value without an unchecked wildcard capture.
    Param<Long> longParam = stage.getParam("longParam");
    stage.set(longParam, 3L);
    Long longValue = stage.get(longParam);
}


Compatibility, Deprecation, and Migration Plan

The changes proposed in this FLIP are backward incompatible with the existing APIs of WithParams/ParamInfo/ParamInfoFactory/Params. We propose to change the APIs directly without a deprecation period.

Since there is no implementation of Estimator/Transformer (excluding test-only implementations) in the existing Flink codebase, no work is needed to migrate the existing Flink codebase.

Test Plan

We will provide unit tests to validate the proposed changes.

Rejected Alternatives

There are no rejected alternatives to be listed here yet.

