"""
Defines the built-in aggregation functions, along with methods for using them
and for adding new user-defined ones.
"""
from __future__ import division

import math
import sys
import warnings
from operator import mul
from pprint import saferepr

from neat.multiparameter import MultiParameterSet
from neat.multiparameter import BadFunctionError as InvalidAggregationFunction # pylint: disable=unused-import
from neat.math_util import mean, median2, tmean

if sys.version_info[0] > 2:
    from functools import reduce
def product_aggregation(x):  # note: `x` is a list or other iterable
    """Return the product of all inputs, starting from 1.0.

    An empty input yields 1.0. The running value is kept on the left of each
    multiplication, i.e. ``((1.0 * x0) * x1) * ...``.
    """
    total = 1.0
    for value in x:
        total = total * value
    return total
def sum_aggregation(x):
    """Return the plain sum of the inputs (0 for an empty input).

    The builtin ``sum`` is used on purpose rather than a manual loop, so
    results match the interpreter's own float summation.
    """
    total = sum(x)
    return total
def max_aggregation(x):
    """Return the largest input; an empty input raises ValueError (from ``max``)."""
    largest = max(x)
    return largest
def min_aggregation(x):
    """Return the smallest input; an empty input raises ValueError (from ``min``)."""
    smallest = min(x)
    return smallest
def maxabs_aggregation(x):
    """Return the input with the largest magnitude, keeping its original sign.

    Ties go to the earliest such input, as with the builtin ``max``; an empty
    input raises ValueError.
    """
    by_magnitude = abs
    return max(x, key=by_magnitude)
def median_aggregation(x):
    """Return ``neat.math_util.median2`` of the inputs.

    ``median2`` presumably returns the median (likely averaging the two middle
    values for an even count) -- confirm in ``neat.math_util``.
    """
    return median2(x)


def mean_aggregation(x):
    """Return the arithmetic mean of the inputs via ``neat.math_util.mean``."""
    return mean(x)
def tmean_aggregation(x):
    """Return ``neat.math_util.tmean`` of the inputs with a fixed ``trim`` of 0.25.

    ``multiparam_tmean_aggregation`` is the variant whose trim is an evolvable
    parameter instead of this constant.
    """
    fixed_trim = 0.25
    return tmean(x, trim=fixed_trim)
def _check_value_range(a, min_val, max_val, caller, var_name): # TEST NEEDED!
if not min_val <= a <= max_val:
raise ValueError(
"{0} for {1}_aggregation must be between {2:n} and {3:n}, not {4!r}".format(
var_name, caller, min_val, max_val, a))
def maxabs_mean_aggregation(x, a):
    """Blend the mean of the inputs with their largest-magnitude value.

    Returns ``(1 - a) * mean(x) + a * maxabs_aggregation(x)``, so ``a == 0``
    gives the plain mean and ``a == 1`` gives the signed largest-magnitude
    input.

    :param x: list or other iterable of inputs.
    :param a: blend factor in [0.0, 1.0].
    :raises ValueError: if ``a`` is out of range.
    """
    _check_value_range(a, 0.0, 1.0, 'maxabs_mean', 'a')
    # The inputs are read twice (mean, then maxabs). Materialize them once so a
    # one-shot iterator consumed by the mean still leaves data for maxabs.
    values = list(x)
    return ((1.0 - a) * mean(values)) + (a * maxabs_aggregation(values))
def multiparam_tmean_aggregation(x, a):
    """Return ``neat.math_util.tmean`` of the inputs with ``trim=a``.

    :param a: trim fraction in [0.0, 0.5].
    :raises ValueError: if ``a`` is out of range.
    """
    _check_value_range(a, 0.0, 0.5, 'multiparam_tmean', 'a')
    return tmean(x, trim=a)


def max_median_min_aggregation(x, a):
    """Blend the median of the inputs toward their maximum or minimum.

    * ``a > 0``: ``(1 - a) * median + a * max`` (``a == 1`` gives the max)
    * ``a < 0``: ``(1 + a) * median + (-a) * min`` (``a == -1`` gives the min)
    * ``a == 0``: the median itself

    :param x: list or other iterable of inputs. It is converted to a list of
        floats once, because the median and max/min each read it.
    :param a: blend factor in [-1.0, 1.0].
    :raises ValueError: if ``a`` is out of range.
    """
    _check_value_range(a, -1.0, 1.0, 'max_median_min', 'a')
    values = list(map(float, x))
    middle = median2(values)
    if a > 0.0:
        return ((1.0 - a) * middle) + (a * max(values))
    if a < 0.0:
        return ((1.0 + a) * middle) + ((-a) * min(values))
    return middle
def maxabs_tmean_aggregation(x, a):
    """Dispatch on the sign of ``a`` between two blended aggregations.

    * ``a < 0``: ``multiparam_tmean_aggregation`` with trim ``|a| / 2``, which
      lies in (0, 0.5].
    * ``a >= 0``: ``maxabs_mean_aggregation`` with the same ``a``.

    :param a: control value in [-1.0, 1.0].
    :raises ValueError: if ``a`` is out of range.
    """
    _check_value_range(a, -1.0, 1.0, 'maxabs_tmean', 'a')
    # The range check has already rejected NaN, so "a < 0" is the exact
    # complement of "a >= 0" here.
    if a < 0.0:
        return multiparam_tmean_aggregation(x, a=abs(a/2))
    return maxabs_mean_aggregation(x, a=a)
def sum_product_aggregation(x, a):
    """Blend the product of the inputs (``a == 0``) with their sum (``a == 1``).

    Returns ``(1 - a) * product(x) + a * sum(x)``.

    :param x: list or other iterable of inputs.
    :param a: blend factor in [0.0, 1.0].
    :raises ValueError: if ``a`` is out of range.
    """
    _check_value_range(a, 0.0, 1.0, 'sum_product', 'a')
    # Product and sum each iterate the inputs. Materialize them once so an
    # iterator input is not exhausted by the product before the sum runs.
    values = list(x)
    return ((1.0 - a) * product_aggregation(values)) + (a * sum(values))
def sum_mean_aggregation(x, a):
    """Blend between the mean (``a == 0``) and the sum (``a == 1``) of the inputs.

    With ``n`` float-converted inputs, returns ``sum * (1 + a*(n - 1)) / n``.
    A single input is returned as-is (as a float). An empty input fails with
    ZeroDivisionError when the scale factor is computed.

    :param a: blend factor in [0.0, 1.0].
    :raises ValueError: if ``a`` is out of range.
    """
    _check_value_range(a, 0.0, 1.0, 'sum_mean', 'a')
    values = [float(item) for item in x]
    count = len(values)
    if count == 1:
        return values[0]
    scale = (1 + (a * (count - 1))) / count
    return sum(values) * scale
def product_mean_aggregation(x, a, use_median):
    """Blend the geometric mean of magnitudes (``a == 0``) with the product (``a == 1``).

    With ``n`` float-converted inputs, the magnitude is
    ``|product| ** ((1 + a*(n - 1)) / n)``. Its sign is taken from the median
    of the inputs when ``use_median`` is True, otherwise from the product
    itself. A single input is returned as-is (as a float). An empty input
    fails with ZeroDivisionError when the exponent is computed.

    :param a: blend factor in [0.0, 1.0].
    :param use_median: must be a real ``bool``; selects the sign source.
    :raises ValueError: if ``a`` is out of range.
    :raises TypeError: if ``use_median`` is not a bool.
    """
    _check_value_range(a, 0.0, 1.0, 'product_mean', 'a')
    if not isinstance(use_median, bool):
        raise TypeError(
            "Type of use_median must be bool, not {0!s} (for {1!r})".format(
                type(use_median), use_median))
    values = [float(item) for item in x]
    count = len(values)
    if count == 1:
        return values[0]
    exponent = (1 + (a * (count - 1))) / count
    product = product_aggregation(values)
    magnitude = math.pow(abs(product), exponent)
    # The conditional expression is lazy, so median2 runs only when requested.
    sign_source = median2(values) if use_median else product
    return math.copysign(magnitude, sign_source)
def sum_product_mean_aggregation(x, a, b, use_median):
    """Blend ``sum_mean_aggregation`` and ``product_mean_aggregation``.

    Returns ``b * sum_mean(x, a) + (1 - b) * product_mean(x, a, use_median)``.

    :param x: list or other iterable of inputs.
    :param a: blend factor in [0.0, 1.0], passed to both sub-aggregations.
    :param b: weight in [0.0, 1.0] of the sum/mean part.
    :param use_median: bool forwarded to ``product_mean_aggregation``.
    :raises ValueError: if ``a`` or ``b`` is out of range.
    :raises TypeError: (from ``product_mean_aggregation``) if ``use_median``
        is not a bool.
    """
    _check_value_range(a, 0.0, 1.0, 'sum_product_mean', 'a')
    _check_value_range(b, 0.0, 1.0, 'sum_product_mean', 'b')
    # Both sub-aggregations iterate the inputs. Without a shared list, an
    # iterator would be exhausted by sum_mean and product_mean would see zero
    # inputs (ZeroDivisionError).
    values = list(x)
    return ((b * sum_mean_aggregation(values, a)) +
            ((1.0 - b) * product_mean_aggregation(values, a, use_median)))
class AggregationFunctionSet(object):
    """Contains aggregation functions and methods to add and retrieve them.

    Storage and lookup are delegated to a ``MultiParameterSet``; every call
    passes the function-type tag ``'aggregation'``.
    """

    def __init__(self, multiparameterset=None):
        """Register all built-in aggregation functions.

        :param multiparameterset: shared ``MultiParameterSet`` to register
            into. If omitted, a new ``MultiParameterSet('aggregation')`` is
            created and a warning is issued, because independently created
            sets may then coexist.
        """
        if multiparameterset is None:
            warn_string = ("Aggregation init called without multiparameterset:" +
                           " may cause multiple instances of it")
            multiparameterset = MultiParameterSet('aggregation')
            # stacklevel=2 attributes the warning to the code that constructed
            # this object, which is where the missing argument must be supplied.
            warnings.warn(warn_string, stacklevel=2)
        self.multiparameterset = multiparameterset
        # Aggregations that take only the list of inputs.
        self.add('product', product_aggregation)
        self.add('sum', sum_aggregation)
        self.add('max', max_aggregation)
        self.add('min', min_aggregation)
        self.add('maxabs', maxabs_aggregation)
        self.add('median', median_aggregation)
        self.add('mean', mean_aggregation)
        self.add('tmean', tmean_aggregation)
        # Aggregations with extra parameters. Each keyword names an extra
        # argument of the function; its dict is presumably the range
        # (min_value/max_value) or type spec consumed by add_func.
        self.add('multiparam_tmean', multiparam_tmean_aggregation,
                 a={'min_value':0.0, 'max_value':0.5})
        self.add('max_median_min', max_median_min_aggregation,
                 a={'min_value':-1.0, 'max_value':1.0})
        self.add('maxabs_mean', maxabs_mean_aggregation,
                 a={'min_value':0.0, 'max_value':1.0})
        self.add('maxabs_tmean', maxabs_tmean_aggregation,
                 a={'min_value':-1.0, 'max_value':1.0})
        self.add('sum_mean', sum_mean_aggregation,
                 a={'min_value':0.0, 'max_value':1.0})
        self.add('product_mean', product_mean_aggregation,
                 a={'min_value':0.0, 'max_value':1.0},
                 use_median={'param_type': 'bool'})
        self.add('sum_product', sum_product_aggregation,
                 a={'min_value':0.0, 'max_value':1.0})
        self.add('sum_product_mean', sum_product_mean_aggregation,
                 a={'min_value':0.0, 'max_value':1.0},
                 b={'min_value':0.0, 'max_value':1.0},
                 use_median={'param_type': 'bool'})

    def add(self, name, function, **kwargs):
        """Register ``function`` under ``name`` as an aggregation.

        Extra keyword arguments are forwarded unchanged to
        ``MultiParameterSet.add_func``; ``__init__`` uses them for per-parameter
        specs such as ``a={'min_value': 0.0, 'max_value': 1.0}``.
        """
        self.multiparameterset.add_func(name, function, 'aggregation', **kwargs)

    def get(self, name):
        """Return whatever ``MultiParameterSet.get_func`` gives for ``name``.

        NOTE(review): unknown names presumably raise
        ``InvalidAggregationFunction`` (imported at module top); confirm in
        ``neat.multiparameter``.
        """
        return self.multiparameterset.get_func(name, 'aggregation')

    def __getitem__(self, index):
        """Deprecated alias for ``get``; emits a DeprecationWarning first."""
        # stacklevel=2 points the warning at the caller's indexing expression,
        # which is the line that needs changing.
        warnings.warn(
            "Use get, not indexing ([{!s}]), for aggregation functions".format(saferepr(index)),
            DeprecationWarning, stacklevel=2)
        return self.get(index)

    def is_valid(self, name):
        """Return ``MultiParameterSet.is_valid_func`` for ``name`` as an aggregation."""
        return self.multiparameterset.is_valid_func(name, 'aggregation')