Source code for aggregations

"""
Has the built-in aggregation functions, methods for using them,
and methods for adding new user-defined ones.
"""
from __future__ import division

import math
import sys
import warnings

from operator import mul
from pprint import saferepr

from neat.multiparameter import MultiParameterSet
from neat.multiparameter import BadFunctionError as InvalidAggregationFunction # pylint: disable=unused-import
from neat.math_util import mean, median2, tmean

if sys.version_info[0] > 2:
    from functools import reduce

def product_aggregation(x):
    """Multiply together every value in *x*.

    `x` is a list or other iterable of numbers.  The running product is
    seeded with 1.0, so an empty input yields 1.0 and numeric inputs
    produce a float result.
    """
    running = 1.0
    for value in x:
        running = running * value
    return running
def sum_aggregation(x):
    """Return the sum of the values in *x* (0 for an empty input)."""
    total = sum(x, 0)  # explicit start value; identical to the builtin default
    return total
def max_aggregation(x):
    """Return the largest value in *x*.

    Raises ValueError (from the builtin ``max``) when *x* is empty.
    """
    largest = max(x)
    return largest
def min_aggregation(x):
    """Return the smallest value in *x*.

    Raises ValueError (from the builtin ``min``) when *x* is empty.
    """
    smallest = min(x)
    return smallest
def maxabs_aggregation(x):
    """Return the element of *x* with the largest magnitude.

    The element's original signed value is returned, not its absolute value.
    When several elements share the largest magnitude, the first one wins.
    """
    return max(x, key=lambda value: abs(value))
def median_aggregation(x):
    """Return the median of *x* as computed by ``neat.math_util.median2``."""
    middle = median2(x)
    return middle
def mean_aggregation(x):
    """Return the mean of *x* as computed by ``neat.math_util.mean``."""
    average = mean(x)
    return average
def tmean_aggregation(x):
    """Return ``neat.math_util.tmean`` of *x* with a fixed trim of 0.25.

    ``multiparam_tmean`` is the variant with an adjustable trim.
    """
    fixed_trim = 0.25
    return tmean(x, trim=fixed_trim)
def _check_value_range(a, min_val, max_val, caller, var_name): # TEST NEEDED! if not min_val <= a <= max_val: raise ValueError( "{0} for {1}_aggregation must be between {2:n} and {3:n}, not {4!r}".format( var_name, caller, min_val, max_val, a))
def maxabs_mean_aggregation(x, a):
    """Blend the mean of *x* with its largest-magnitude element.

    Returns ``(1 - a) * mean(x) + a * maxabs(x)``.  With ``a = 0`` only the
    mean is weighted; with ``a = 1`` only the signed largest-magnitude value.

    :param x: list or other iterable of numbers.  It is read twice (once for
              the mean, once for maxabs), so it is materialized into a list
              first.  Otherwise a one-shot iterator would be exhausted by
              the first pass.
    :param a: blend weight; must lie in [0.0, 1.0].
    :raises ValueError: if *a* is outside [0.0, 1.0].
    """
    _check_value_range(a, 0.0, 1.0, 'maxabs_mean', 'a')
    values = list(x)
    return ((1.0 - a) * mean(values)) + (a * maxabs_aggregation(values))
def multiparam_tmean_aggregation(x, a):
    """Return ``neat.math_util.tmean`` of *x* with a caller-chosen trim.

    *a* is forwarded as ``trim`` and must lie in [0.0, 0.5].  The
    fixed-parameter ``tmean`` aggregation uses 0.25.

    :raises ValueError: if *a* is outside [0.0, 0.5].
    """
    _check_value_range(a, 0.0, 0.5, 'multiparam_tmean', 'a')
    trimmed = tmean(x, trim=a)
    return trimmed
def maxabs_tmean_aggregation(x, a):
    """Select between maxabs_mean and a trimmed mean based on the sign of *a*.

    Negative *a* delegates to ``multiparam_tmean_aggregation`` with a trim of
    ``abs(a/2)``, which lies in (0, 0.5].  Non-negative *a* delegates to
    ``maxabs_mean_aggregation`` with the same *a*.

    :raises ValueError: if *a* is outside [-1.0, 1.0].
    """
    _check_value_range(a, -1.0, 1.0, 'maxabs_tmean', 'a')
    if a < 0.0:
        return multiparam_tmean_aggregation(x, a=abs(a/2))
    return maxabs_mean_aggregation(x, a=a)
def sum_product_aggregation(x, a):
    """Blend the product of *x* with its sum.

    Returns ``(1 - a) * product(x) + a * sum(x)``.

    :param x: list or other iterable of numbers.  It is read twice, so it is
              materialized into a list first.  Otherwise a one-shot
              iterator would be consumed by the product and the sum would
              see nothing.
    :param a: blend weight; must lie in [0.0, 1.0].
    :raises ValueError: if *a* is outside [0.0, 1.0].
    """
    _check_value_range(a, 0.0, 1.0, 'sum_product', 'a')
    values = list(x)
    return ((1.0 - a) * product_aggregation(values)) + (a * sum(values))
def max_median_min_aggregation(x, a):
    """Blend the median of *x* with its maximum (a > 0) or minimum (a < 0).

    The extreme value gets weight ``1 - median_weight`` and the median gets
    ``median_weight = 1 - |a|``.  ``a == 0`` returns the median alone.

    :param x: list or other iterable of numbers.  It is read twice (extreme
              value and median), so it is materialized into a list first.
              Otherwise a one-shot iterator would be exhausted before the
              median is computed.
    :param a: must lie in [-1.0, 1.0].
    :raises ValueError: if *a* is outside [-1.0, 1.0].
    """
    _check_value_range(a, -1.0, 1.0, 'max_median_min', 'a')
    values = list(x)
    median_weight = 1.0 - abs(a)
    if a > 0.0:
        return (((1.0 - median_weight) * max(values)) +
                (median_weight * median2(values)))
    if a < 0.0:
        return (((1.0 - median_weight) * min(values)) +
                (median_weight * median2(values)))
    return median2(values)
def sum_mean_aggregation(x, a):
    """Scale the float sum of *x* between the mean (a = 0) and the sum (a = 1).

    With ``n`` inputs, the sum is multiplied by ``(1 + a*(n - 1)) / n``, so
    ``a = 0`` scales by ``1/n`` and ``a = 1`` scales by 1.  A single input is
    returned unchanged (as a float).  An empty input divides by zero and
    raises ZeroDivisionError.

    :raises ValueError: if *a* is outside [0.0, 1.0].
    """
    _check_value_range(a, 0.0, 1.0, 'sum_mean', 'a')
    values = [float(item) for item in x]
    count = len(values)
    if count == 1:
        return values[0]
    scale = (1 + (a * (count - 1))) / count
    return sum(values) * scale
def product_mean_aggregation(x, a, use_median):
    """Scale the magnitude of the product of *x* between a geometric mean and the product.

    With ``n`` inputs, the result's magnitude is
    ``|prod(x)| ** ((1 + a*(n - 1)) / n)``.  That is the geometric mean of
    the magnitudes for ``a = 0`` and ``|prod(x)|`` for ``a = 1``.  The sign
    is copied from ``median2(x)`` when *use_median* is True, and from the
    product otherwise.  A single input is returned unchanged (as a float).
    An empty input divides by zero and raises ZeroDivisionError.

    :raises ValueError: if *a* is outside [0.0, 1.0].
    :raises TypeError: if *use_median* is not a bool.
    """
    _check_value_range(a, 0.0, 1.0, 'product_mean', 'a')
    if not isinstance(use_median, bool):
        raise TypeError(
            "Type of use_median must be bool, not {0!s} (for {1!r})".format(
                type(use_median), use_median))
    values = [float(item) for item in x]
    count = len(values)
    if count == 1:
        return values[0]
    exponent = (1 + (a * (count - 1))) / count
    product = product_aggregation(values)
    magnitude = math.pow(abs(product), exponent)
    # median2 is only evaluated when its sign is actually wanted
    sign_source = median2(values) if use_median else product
    return math.copysign(magnitude, sign_source)
def sum_product_mean_aggregation(x, a, b, use_median):
    """Blend sum_mean and product_mean of *x* with weight *b*.

    Returns ``b * sum_mean(x, a) + (1 - b) * product_mean(x, a, use_median)``.

    :param x: list or other iterable of numbers.  It is handed to two
              aggregations, so it is materialized into a list first.
              Otherwise a one-shot iterator would be exhausted by
              ``sum_mean`` before ``product_mean`` sees it.
    :param a: passed to both sub-aggregations; must lie in [0.0, 1.0].
    :param b: weight of the sum_mean term; must lie in [0.0, 1.0].
    :param use_median: forwarded to ``product_mean_aggregation`` (must be a bool).
    :raises ValueError: if *a* or *b* is outside [0.0, 1.0].
    :raises TypeError: if *use_median* is not a bool (raised by product_mean).
    """
    _check_value_range(a, 0.0, 1.0, 'sum_product_mean', 'a')
    _check_value_range(b, 0.0, 1.0, 'sum_product_mean', 'b')
    values = list(x)
    return ((b * sum_mean_aggregation(values, a)) +
            ((1.0 - b) * product_mean_aggregation(values, a, use_median)))
class AggregationFunctionSet(object):
    """Contains aggregation functions and methods to add and retrieve them.

    Storage and lookup are delegated to a ``MultiParameterSet``.  Every call
    into it passes the kind ``'aggregation'``.
    """

    def __init__(self, multiparameterset=None):
        """Register all built-in aggregation functions.

        :param multiparameterset: the ``MultiParameterSet`` to register into.
            If omitted, a new ``MultiParameterSet('aggregation')`` is created
            and a warning is issued, because a private instance may lead to
            multiple instances of it.
        """
        if multiparameterset is None:
            warn_string = ("Aggregation init called without multiparameterset:" +
                           " may cause multiple instances of it")
            multiparameterset = MultiParameterSet('aggregation')
            # stacklevel=2: attribute the warning to the code constructing
            # this object, not to this module.
            warnings.warn(warn_string, stacklevel=2)
        self.multiparameterset = multiparameterset
        # Fixed (parameterless) aggregations.
        self.add('product', product_aggregation)
        self.add('sum', sum_aggregation)
        self.add('max', max_aggregation)
        self.add('min', min_aggregation)
        self.add('maxabs', maxabs_aggregation)
        self.add('median', median_aggregation)
        self.add('mean', mean_aggregation)
        self.add('tmean', tmean_aggregation)
        # Parameterized aggregations; the keyword dicts describe each
        # parameter's range/type to the MultiParameterSet.
        self.add('multiparam_tmean', multiparam_tmean_aggregation,
                 a={'min_value':0.0, 'max_value':0.5})
        self.add('max_median_min', max_median_min_aggregation,
                 a={'min_value':-1.0, 'max_value':1.0})
        self.add('maxabs_mean', maxabs_mean_aggregation,
                 a={'min_value':0.0, 'max_value':1.0})
        self.add('maxabs_tmean', maxabs_tmean_aggregation,
                 a={'min_value':-1.0, 'max_value':1.0})
        self.add('sum_mean', sum_mean_aggregation,
                 a={'min_value':0.0, 'max_value':1.0})
        self.add('product_mean', product_mean_aggregation,
                 a={'min_value':0.0, 'max_value':1.0},
                 use_median={'param_type': 'bool'})
        self.add('sum_product', sum_product_aggregation,
                 a={'min_value':0.0, 'max_value':1.0})
        self.add('sum_product_mean', sum_product_mean_aggregation,
                 a={'min_value':0.0, 'max_value':1.0},
                 b={'min_value':0.0, 'max_value':1.0},
                 use_median={'param_type': 'bool'})

    def add(self, name, function, **kwargs):
        """Register *function* under *name*; *kwargs* describe its parameters."""
        self.multiparameterset.add_func(name, function, 'aggregation', **kwargs)

    def get(self, name):
        """Return the aggregation registered under *name*."""
        return self.multiparameterset.get_func(name, 'aggregation')

    def __getitem__(self, index):
        """Deprecated alias for :meth:`get`; emits a DeprecationWarning."""
        # stacklevel=2 points the warning at the indexing expression in the
        # caller.  DeprecationWarning is only shown by default when it is
        # attributed to __main__, so without it callers never see it.
        warnings.warn(
            "Use get, not indexing ([{!s}]), for aggregation functions".format(saferepr(index)),
            DeprecationWarning, stacklevel=2)
        return self.get(index)

    def is_valid(self, name):
        """Return the MultiParameterSet's validity answer for *name*."""
        return self.multiparameterset.is_valid_func(name, 'aggregation')