Source code for attributes

"""Deals with the attributes (variable parameters) of genes"""
#from __future__ import print_function
import copy
import sys
import warnings

from random import choice, gauss, random, uniform

from neat.config import ConfigParameter
from neat.six_util import iterkeys, iteritems

# Python 3 has no ``unicode`` builtin; alias it to ``str`` so that the
# ``isinstance(..., (str, unicode))`` checks in this module work on both
# major versions.
if sys.version_info.major > 2:
    unicode = str # pylint: disable=redefined-builtin

# TODO: There is probably a lot of room for simplification of these classes using metaprogramming.


class BaseAttribute(object):
    """Superclass for the type-specialized attribute subclasses, used by genes.

    Subclasses supply a class-level ``_config_items_init`` dict that maps each
    config item base name to a ``[type, default]`` pair. Every instance works
    on its own deep copy of that dict, so per-instance default overrides do
    not leak back into the class.
    """

    def __init__(self, name, default_ok=False, **default_dict):
        """Set up the per-instance config items and their ``*_name`` attributes.

        :param name: Attribute name; used as the prefix of every config item name.
        :param default_ok: If true, each non-``None`` default supplied through
            ``default_dict`` is flagged as ``True`` in ``self.default_ok``.
        :param default_dict: Overrides for config item defaults, keyed by
            config item base name.
        :raises KeyError: If an override names a config item that the
            subclass does not define.
        """
        self.name = name
        self._config_items = copy.deepcopy(self._config_items_init)
        self.default_ok = dict.fromkeys(self._config_items, False)

        for n, default in default_dict.items():
            if n not in self._config_items:
                raise KeyError(
                    "Unknown config item {0!r} for attribute {1!r}".format(n, name))
            # Keep the declared type, replace only the default.
            self._config_items[n] = [self._config_items[n][0], default]
            if default_ok and (default is not None):
                self.default_ok[n] = True

        # Expose e.g. ``self.mutate_rate_name`` so subclasses can look the
        # corresponding value up on a config object with getattr().
        for n in self._config_items:
            setattr(self, n + "_name", self.config_item_name(n))

    def config_item_name(self, config_item_base_name):
        """Return the full config item name, ``<attribute name>_<base name>``."""
        return "{0}_{1}".format(self.name, config_item_base_name)

    def get_config_params(self):
        """Return a list with one ``ConfigParameter`` per config item."""
        return [ConfigParameter(self.config_item_name(n),
                                item[0],
                                default=item[1],
                                default_ok=self.default_ok[n])
                for n, item in self._config_items.items()]
class FloatAttribute(BaseAttribute):
    """
    Class for numeric attributes,
    such as the response of a node or the weight of a connection.
    """
    _config_items_init = {"init_mean": [float, None],
                          "init_stdev": [float, None],
                          "init_type": [str, 'gaussian'],
                          "replace_rate": [float, None],
                          "mutate_rate": [float, None],
                          "mutate_power": [float, None],
                          "max_value": [float, None],
                          "min_value": [float, None]}

    def clamp(self, value, config):
        """Return ``value`` limited to the configured min_value/max_value range."""
        min_value = getattr(config, self.min_value_name)
        max_value = getattr(config, self.max_value_name)
        return max(min(value, max_value), min_value)

    def init_value(self, config):
        """Draw an initial value according to the configured ``init_type``.

        An ``init_type`` containing ``gauss`` or ``normal`` samples from a
        gaussian with the configured mean and standard deviation. One
        containing ``uniform`` samples uniformly within two standard
        deviations of the mean, limited by the configured bounds. Either
        result is clamped to the min_value/max_value range.

        :raises RuntimeError: If ``init_type`` matches neither distribution.
        """
        mean = getattr(config, self.init_mean_name)
        stdev = getattr(config, self.init_stdev_name)
        init_type = getattr(config, self.init_type_name).lower()

        if ('gauss' in init_type) or ('normal' in init_type):
            return self.clamp(gauss(mean, stdev), config)

        if 'uniform' in init_type:
            min_value = max(getattr(config, self.min_value_name),
                            (mean - (2.0 * stdev)))
            max_value = min(getattr(config, self.max_value_name),
                            (mean + (2.0 * stdev)))
            # If the mean lies far outside the configured bounds, min_value
            # ends up above max_value and uniform() samples between the two,
            # i.e. outside the allowed range. Clamp, as the gaussian branch
            # does; for in-range limits this leaves the sample unchanged.
            return self.clamp(uniform(min_value, max_value), config)

        raise RuntimeError("Unknown init_type {!r} for {!s}".format(
            getattr(config, self.init_type_name), self.init_type_name))

    def mutate_value(self, value, config):
        """Return ``value`` perturbed, replaced, or unchanged.

        A single random draw ``r`` selects the outcome: below ``mutate_rate``
        the value receives zero-mean gaussian noise with standard deviation
        ``mutate_power`` and is clamped; below ``mutate_rate + replace_rate``
        it is re-initialised via :meth:`init_value`; otherwise it is returned
        as is.

        :raises AttributeError: If ``config`` lacks the mutate_power item
            (re-raised after emitting a warning that lists ``dir(config)``).
        """
        # mutate_rate is usually no lower than replace_rate,
        # and frequently higher - so put first for efficiency
        mutate_rate = getattr(config, self.mutate_rate_name)

        r = random()
        if r < mutate_rate:
            try:
                mutate_power = getattr(config, self.mutate_power_name)
            except AttributeError:
                warnings.warn("Config {0!r} __dict__ is:\n\t".format(config)
                              + "\n\t".join(dir(config)))
                raise
            return self.clamp(value + gauss(0.0, mutate_power), config)

        replace_rate = getattr(config, self.replace_rate_name)

        if r < replace_rate + mutate_rate:
            return self.init_value(config)

        return value
## def validate(self, config): # pragma: no cover ## pass
class BoolAttribute(BaseAttribute):
    """Class for boolean attributes such as whether a connection is enabled or not."""
    _config_items_init = {
        "default": [str, None],
        "mutate_rate": [float, None],
        "rate_to_true_add": [float, 0.0],
        "rate_to_false_add": [float, 0.0],
    }

    def init_value(self, config):
        """Translate the configured default into an initial boolean.

        The configured value is stringified and lower-cased. ``1``/``on``/
        ``yes``/``true`` give ``True``, ``0``/``off``/``no``/``false`` give
        ``False``, and ``random``/``none`` give a fair coin flip.

        :raises RuntimeError: For any other configured value.
        """
        setting = str(getattr(config, self.default_name)).lower()

        if setting in ('1', 'on', 'yes', 'true'):
            return True
        if setting in ('0', 'off', 'no', 'false'):
            return False
        if setting in ('random', 'none'):
            return bool(random() < 0.5)

        raise RuntimeError("Unknown default value {!r} for {!s}".format(setting,
                                                                        self.name))

    def mutate_value(self, value, config):
        """Return a possibly re-drawn boolean for ``value``.

        The effective rate is ``mutate_rate`` plus ``rate_to_false_add`` when
        ``value`` is true, or plus ``rate_to_true_add`` when it is false.
        """
        base_rate = getattr(config, self.mutate_rate_name)
        extra_name = (self.rate_to_false_add_name if value
                      else self.rate_to_true_add_name)
        total_rate = base_rate + getattr(config, extra_name)

        if total_rate > 0 and random() < total_rate:
            # A fresh random value is drawn instead of flipping the old one,
            # so a mutation *may* change the value but is not guaranteed to.
            return random() < 0.5

        return value
## def validate(self, config): # pragma: no cover ## pass
class StringAttribute(BaseAttribute):
    """
    Class for string attributes (such as, previously, the aggregation function
    of a node) that are selected from a list of options.
    """
    _config_items_init = {
        "default": [str, 'random'],
        "options": [list, None],
        "mutate_rate": [float, None],
    }

    def init_value(self, config):
        """Return the configured default, or a random option.

        A default of ``none`` or ``random`` (compared case-insensitively)
        selects a random entry from the configured options.
        """
        configured = getattr(config, self.default_name)

        if configured.lower() not in ('none','random'):
            return configured

        return choice(getattr(config, self.options_name))

    def mutate_value(self, value, config):
        """With probability ``mutate_rate``, return a randomly chosen option.

        The new choice is drawn from all options, so it may equal ``value``.
        Otherwise ``value`` is returned unchanged.
        """
        rate = getattr(config, self.mutate_rate_name)

        if rate > 0 and random() < rate:
            return choice(getattr(config, self.options_name))

        return value
## def validate(self, config): # pragma: no cover ## pass
class FuncAttribute(BaseAttribute):
    """
    Handle attributes that may be simple strings
    or may be functions needing multiparameter handling.
    """
    _config_items_init = copy.deepcopy(
        StringAttribute._config_items_init) # pylint: disable=protected-access

    ## def copy(self):
    ##     return copy.deepcopy(self)

    def init_value(self, config):
        """Return the initial value for this attribute.

        A configured default of ``None``, ``'none'`` or ``'random'`` (strings
        compared case-insensitively, matching ``StringAttribute``) selects a
        random entry from the configured options. The selected value is then
        handled as follows:

        * if it has an ``init_value`` attribute, that is called with
          ``config`` (its return value is ignored);
        * otherwise it must be a string, and if ``config`` has a
          ``multiparameterset`` that reports the string as a multiparameter
          for this attribute, the value is replaced by the result of
          ``init_multiparameter``.

        :raises RuntimeError: If the value is neither a string nor an object
            with an ``init_value`` attribute.
        """
        default = getattr(config, self.default_name)

        # Accept the same spellings as StringAttribute.init_value. The old
        # check only matched None and the exact string 'random', so 'none'
        # or 'Random' was returned as if it were a function name.
        pick_random = default is None or (
            isinstance(default, (str, unicode))
            and default.lower() in ('none', 'random'))
        if pick_random:
            options = getattr(config, self.options_name)
            default = choice(options)

        if hasattr(default, 'init_value'): # pragma: no cover
            default.init_value(config)
        elif not isinstance(default, (str, unicode)): # put in test for
            raise RuntimeError("Unknown what to do with value {0!r} for {1!s}".format(default,
                                                                                      self.name))
        elif hasattr(config, 'multiparameterset'):
            multiparam = config.multiparameterset
            if multiparam.is_multiparameter(default, self.name):
                default = multiparam.init_multiparameter(default, self, config)

        return default

    def mutate_value(self, value, config):
        """Return the possibly replaced and/or mutated value.

        With probability ``mutate_rate`` the value is first replaced by a
        randomly chosen option. The resulting value is then handled as
        follows:

        * if it has a ``mutate_value`` attribute, that is called with
          ``config`` (its return value is ignored);
        * otherwise it must be a string, and if ``config`` has a
          ``multiparameterset`` that reports the string as a multiparameter
          for this attribute, the value is replaced by the result of
          ``init_multiparameter``.

        :raises RuntimeError: If the value is neither a string nor an object
            with a ``mutate_value`` attribute.
        """
        mutate_rate = getattr(config, self.mutate_rate_name)

        if mutate_rate > 0:
            r = random()
            if r < mutate_rate:
                options = getattr(config, self.options_name)
                value = choice(options)

        # NOTE(review): the handling below is assumed to apply to the final
        # value whether or not a replacement was drawn above - confirm
        # against the upstream layout of this method.
        if hasattr(value, 'mutate_value'):
            #print("Accessing mutate_value function of {!r}".format(value))
            value.mutate_value(config)
        elif not isinstance(value, (str, unicode)): # put in test for
            raise RuntimeError("Unknown what to do with value {0!r} for {1!s}".format(value,
                                                                                      self.name))
        elif hasattr(config, 'multiparameterset'):
            multiparam = config.multiparameterset
            if multiparam.is_multiparameter(value, self.name):
                value = multiparam.init_multiparameter(value, self, config)

        return value