...
- register the module
- define the SQL functions
- implement the functions in C++
- register the C++ header files
1. Register the module
Add the following line to the file `Modules.yml` under `./src/config/`:
```yaml
- name: hello_world
```
Then create two folders, `./src/ports/postgres/modules/hello_world` and `./src/modules/hello_world`. The folder names must match the module name specified in `Modules.yml`.
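The naming rule can be expressed as a small C++17 helper. This is a hypothetical sketch, not part of MADlib's build scripts: `module_dirs` and `create_module_dirs` are illustrative names, and `root` is assumed to be the top of the MADlib source tree.

```cpp
#include <cassert>
#include <filesystem>
#include <string>
#include <vector>

namespace fs = std::filesystem;

// Hypothetical helper (not part of MADlib): the two directories that a module
// listed in Modules.yml as "- name: <name>" is expected to have.
std::vector<fs::path> module_dirs(const fs::path& root, const std::string& name) {
    assert(!name.empty());
    return {
        root / "src" / "ports" / "postgres" / "modules" / name,  // SQL files (*.sql_in)
        root / "src" / "modules" / name                           // C++ files (*.hpp, *.cpp)
    };
}

// Creates both directories; create_directories is a no-op if they already exist.
void create_module_dirs(const fs::path& root, const std::string& name) {
    for (const fs::path& dir : module_dirs(root, name))
        fs::create_directories(dir);
}
```

Deriving both paths from the single name in `Modules.yml` is what keeps the SQL folder and the C++ folder in sync.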
2. Define the SQL functions
Create the file `avg_var.sql_in` under the folder `./src/ports/postgres/modules/hello_world`. Inside this file we define the aggregate function and other helper functions for computing the mean and variance. The actual implementations of those functions live in separate C++ files, which we describe in the next section.
...
We define the aggregate function `avg_var` using the built-in PostgreSQL command `CREATE AGGREGATE`.
```sql
DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.avg_var(DOUBLE PRECISION);
CREATE AGGREGATE MADLIB_SCHEMA.avg_var(DOUBLE PRECISION) (
    SFUNC=MADLIB_SCHEMA.avg_var_transition,
    STYPE=double precision[],
    FINALFUNC=MADLIB_SCHEMA.avg_var_final,
    m4_ifdef(`__POSTGRESQL__', `', `PREFUNC=MADLIB_SCHEMA.avg_var_merge_states,')
    INITCOND='{0, 0, 0}'
);
```
We also describe the parameters passed to `CREATE AGGREGATE`:
...
The transition, merge, and final functions are defined in the same file avg_var.sql_in
as the aggregate function. More details about those functions can be found in the PostgreSQL documentation.
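To see how the database uses these parameters, here is a toy model in plain C++. It is not PostgreSQL or MADlib code: `Aggregate`, `run_aggregate`, and `make_avg` are hypothetical names, and the example aggregate is a simple average over a `{sum, count}` state rather than `avg_var`. Each partition starts from `INITCOND` and folds its rows with `SFUNC`; the partial states are combined with `PREFUNC` (on Greenplum, partitions roughly correspond to segments, while the `m4_ifdef` above omits `PREFUNC` on PostgreSQL, which amounts to a single partition here), and `FINALFUNC` turns the merged state into the result.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical model of an aggregate declared with CREATE AGGREGATE.
template <typename State, typename Result>
struct Aggregate {
    State initcond;                                         // INITCOND
    std::function<State(State, double)> sfunc;              // SFUNC: fold one row into the state
    std::function<State(State, State)> prefunc;             // PREFUNC: merge two partial states
    std::function<std::optional<Result>(State)> finalfunc;  // FINALFUNC: state -> result (nullopt = NULL)
};

// Folds each partition from INITCOND, merges the partial states, then finalizes once.
template <typename State, typename Result>
std::optional<Result> run_aggregate(const Aggregate<State, Result>& agg,
                                    const std::vector<std::vector<double>>& partitions) {
    assert(!partitions.empty());
    std::vector<State> partial;
    for (const std::vector<double>& rows : partitions) {
        State s = agg.initcond;
        for (double x : rows) s = agg.sfunc(s, x);
        partial.push_back(s);
    }
    State merged = partial[0];
    for (std::size_t i = 1; i < partial.size(); ++i)
        merged = agg.prefunc(merged, partial[i]);
    return agg.finalfunc(merged);
}

// Example aggregate: a plain average with state {sum, count}.
struct SumCount {
    double sum;
    double count;
};

Aggregate<SumCount, double> make_avg() {
    Aggregate<SumCount, double> agg{};
    agg.initcond = SumCount{0.0, 0.0};
    agg.sfunc = [](SumCount s, double x) { return SumCount{s.sum + x, s.count + 1.0}; };
    agg.prefunc = [](SumCount a, SumCount b) { return SumCount{a.sum + b.sum, a.count + b.count}; };
    agg.finalfunc = [](SumCount s) -> std::optional<double> {
        if (s.count == 0.0) return std::nullopt;  // empty input -> NULL
        return s.sum / s.count;
    };
    return agg;
}
```

With a single partition, `prefunc` is never called, which mirrors a build in which no merge function is declared.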
3. Implement the functions in C++
Create the header and source files, `avg_var.hpp` and `avg_var.cpp`, under the folder `./src/modules/hello_world`. In the header file we declare the transition, merge, and final functions using the macro `DECLARE_UDF(MODULE, NAME)`. For example, the transition function `avg_var_transition` is declared as `DECLARE_UDF(hello_world, avg_var_transition)`. The macro `DECLARE_UDF` is defined in the file `dbconnector.hpp` under `./src/ports/postgres/dbconnector`.
Under the hood, each of the three UDFs is declared as a subclass of `dbconnector::postgres::UDF`. The behavior of each UDF is determined solely by its member function:

```cpp
AnyType run(AnyType &args);
```
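The pattern can be illustrated with a self-contained toy. Everything here is a stand-in under stated assumptions: `AnyType` is modeled as `std::vector<double>`, `UDF` as a bare abstract base class, and `TOY_DECLARE_UDF` is a hypothetical macro in the spirit of `DECLARE_UDF`, not the definition from `dbconnector.hpp`. The UDF `sum_of_args` is made up for the example.

```cpp
#include <cassert>
#include <vector>

// Stand-in for MADlib's AnyType (assumption: just a vector of doubles here).
using AnyType = std::vector<double>;

// Stand-in for dbconnector::postgres::UDF: every UDF exposes run().
struct UDF {
    virtual ~UDF() = default;
    virtual AnyType run(AnyType& args) = 0;
};

// Hypothetical macro: declares class NAME in namespace MODULE as a UDF
// subclass whose only member to implement is run().
#define TOY_DECLARE_UDF(MODULE, NAME)          \
    namespace MODULE {                         \
    struct NAME : public UDF {                 \
        AnyType run(AnyType& args) override;   \
    };                                         \
    }

TOY_DECLARE_UDF(hello_world, sum_of_args)

// Only the body of run() is written by hand.
AnyType hello_world::sum_of_args::run(AnyType& args) {
    double total = 0.0;
    for (double v : args) total += v;
    return AnyType{total};
}
```

Because the macro generates the class declaration, a module author writes only the `run()` body, which matches the workflow described above.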
In other words, we only need to implement the following methods in the `avg_var.cpp` file:

```cpp
AnyType avg_var_transition::run(AnyType& args);
AnyType avg_var_merge_states::run(AnyType& args);
AnyType avg_var_final::run(AnyType& args);
```
Here the `AnyType` class is used both for passing data from the DBMS to the C++ function and for returning values from C++ back to the DBMS. Refer to `TypeTraits_impl.hpp` for more details.

Transition function
```cpp
AnyType
avg_var_transition::run(AnyType& args) {
    // get current state value
    AvgVarTransitionState<MutableArrayHandle<double> > state = args[0];
    // get current row value
    double x = args[1].getAs<double>();
    // deviation of x from the old mean
    double d = (x - state.avg);
    // online update mean
    state.avg += d / static_cast<double>(state.numRows + 1);
    // deviation of x from the new mean
    double new_d = (x - state.avg);
    double a = static_cast<double>(state.numRows) / static_cast<double>(state.numRows + 1);
    // online update variance
    state.var = state.var * a + d * new_d / static_cast<double>(state.numRows + 1);
    state.numRows ++;
    return state;
}
```
- There are two arguments for `avg_var_transition`, as specified in `avg_var.sql_in`. The first is an array of SQL doubles holding the current mean, variance, and number of rows traversed; the second is a double holding the value of the current row.
- We will describe the class `AvgVarTransitionState` later. Basically, it takes `args[0]`, a SQL double array, maps the data to the appropriate C++ types, and stores them in the `state` instance.
- Both the mean and the variance are updated in an online manner to avoid accumulating a large intermediate sum.
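A quick way to check the online update is to run it outside the database and compare it with a two-pass computation. In this sketch `AvgVarState` is a plain struct standing in for `AvgVarTransitionState` (an assumption about its fields, based on the members used above), and `var` is the population variance, which is what the update formula maintains.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical stand-in for AvgVarTransitionState.
struct AvgVarState {
    double avg = 0.0;
    double var = 0.0;       // population variance of the rows seen so far
    long long numRows = 0;
};

// Same arithmetic as avg_var_transition::run, without the DBMS plumbing.
void transition(AvgVarState& s, double x) {
    assert(s.numRows >= 0);
    const double n1 = static_cast<double>(s.numRows + 1);
    const double d = x - s.avg;                  // deviation from the old mean
    s.avg += d / n1;                             // online mean update
    const double new_d = x - s.avg;              // deviation from the new mean
    const double a = static_cast<double>(s.numRows) / n1;
    s.var = s.var * a + d * new_d / n1;          // online variance update
    ++s.numRows;
}

// Two-pass reference: compute the mean first, then the mean squared deviation.
AvgVarState two_pass(const std::vector<double>& xs) {
    AvgVarState r;
    r.numRows = static_cast<long long>(xs.size());
    if (xs.empty()) return r;
    for (double x : xs) r.avg += x;
    r.avg /= static_cast<double>(xs.size());
    for (double x : xs) r.var += (x - r.avg) * (x - r.avg);
    r.var /= static_cast<double>(xs.size());
    return r;
}
```

For the values 2, 4, 4, 4, 5, 5, 7, 9 both approaches give a mean of 5 and a population variance of 4, but the online version never stores a running sum of squares.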
Merge function
```cpp
AnyType
avg_var_merge_states::run(AnyType& args) {
    AvgVarTransitionState<MutableArrayHandle<double> > stateLeft = args[0];
    AvgVarTransitionState<ArrayHandle<double> > stateRight = args[1];

    // Merge states together and return
    stateLeft += stateRight;
    return stateLeft;
}
```
- Again, the arguments contained in `AnyType& args` are defined in `avg_var.sql_in`.
- The details are hidden in the class `AvgVarTransitionState`, which overloads `operator+=`.
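The body of `operator+=` is not shown here. One standard way to merge two partial (mean, variance, count) states is the pairwise update below: recover each side's sum of squared deviations as `var * n`, add them, and add a between-group term `delta^2 * nA * nB / n`, where `delta` is the difference of the two means. Treat it as an assumption about what such a merge must compute, not as MADlib's implementation; `AvgVarState` is again a hypothetical stand-in for `AvgVarTransitionState`.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-in for AvgVarTransitionState.
struct AvgVarState {
    double avg = 0.0;
    double var = 0.0;       // population variance
    long long numRows = 0;

    // Merge another partial state into this one (pairwise mean/variance update).
    AvgVarState& operator+=(const AvgVarState& other) {
        assert(numRows >= 0 && other.numRows >= 0);
        if (other.numRows == 0) return *this;          // nothing to merge
        if (numRows == 0) { *this = other; return *this; }
        const double nA = static_cast<double>(numRows);
        const double nB = static_cast<double>(other.numRows);
        const double n = nA + nB;
        const double delta = other.avg - avg;
        // Pooled sums of squared deviations plus the between-group term.
        const double m2 = var * nA + other.var * nB + delta * delta * nA * nB / n;
        avg += delta * nB / n;
        var = m2 / n;
        numRows += other.numRows;
        return *this;
    }
};
```

For example, merging the states of {2, 4, 4, 4} (mean 3.5, variance 0.75) and {5, 5, 7, 9} (mean 6.5, variance 2.75) yields mean 5 and variance 4, the same as aggregating all eight values in one pass.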
Final function
```cpp
AnyType
avg_var_final::run(AnyType& args) {
    AvgVarTransitionState<MutableArrayHandle<double> > state = args[0];

    // If we haven't seen any data, just return Null. This is the standard
    // behavior of aggregate functions on empty data sets (compare, e.g.,
    // how PostgreSQL handles sum or avg on empty inputs).
    if (state.numRows == 0)
        return Null();

    return state;
}
```
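The NULL-on-empty convention can be mirrored in plain C++ with `std::optional`, where `std::nullopt` plays the role of SQL NULL. This is a hypothetical stand-in for `avg_var_final::run`, assuming the returned state is ordered as mean, variance, and row count, as described for the transition function's first argument.

```cpp
#include <array>
#include <cassert>
#include <optional>

// Hypothetical stand-in for AvgVarTransitionState.
struct AvgVarState {
    double avg = 0.0;
    double var = 0.0;
    long long numRows = 0;
};

// Returns {mean, variance, row count}, or std::nullopt (SQL NULL) for empty input.
std::optional<std::array<double, 3>> avg_var_final(const AvgVarState& s) {
    assert(s.numRows >= 0);
    if (s.numRows == 0)
        return std::nullopt;  // no rows seen: behave like sum/avg on empty input
    return std::array<double, 3>{s.avg, s.var, static_cast<double>(s.numRows)};
}
```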
...