Source code for nn.feed_forward

"""A straightforward feed-forward neural network NEAT implementation."""
from neat.graphs import feed_forward_layers
from neat.six_util import itervalues


class FeedForwardNetwork(object):
    """Phenotype of a genome, evaluated in a single forward pass.

    Nodes are evaluated in the order given by ``node_evals``; each evaluated
    node's value is ``act_func(bias + response * agg_func(weighted_inputs))``.
    """

    def __init__(self, inputs, outputs, node_evals):
        """Build a network from precomputed per-node evaluation records.

        :param inputs: keys of the input nodes, in the order values are
            supplied to :meth:`activate`.
        :param outputs: keys of the output nodes, in the order results are
            returned by :meth:`activate`.
        :param node_evals: sequence of ``(node, act_func, agg_func, bias,
            response, links)`` tuples, evaluated in order; ``links`` is a
            sequence of ``(input_node, weight)`` pairs.
        """
        self.input_nodes = inputs
        self.output_nodes = outputs
        self.node_evals = node_evals
        # Inputs and outputs start at 0.0, so an output node that no entry of
        # node_evals writes still yields a value from activate(). Converting
        # to lists lets callers pass different sequence types for the keys.
        self.values = {key: 0.0 for key in list(inputs) + list(outputs)}

    def activate(self, inputs):
        """Propagate one set of input values through the network.

        :param inputs: one value per input node, in ``input_nodes`` order.
        :return: list of output node values, in ``output_nodes`` order.
        :raises RuntimeError: if ``len(inputs)`` differs from the number of
            input nodes.
        """
        if len(self.input_nodes) != len(inputs):
            raise RuntimeError(
                "Expected {0:n} inputs, got {1:n}".format(len(self.input_nodes), len(inputs)))

        values = self.values
        for key, value in zip(self.input_nodes, inputs):
            values[key] = value

        for node, act_func, agg_func, bias, response, links in self.node_evals:
            # Each link contributes its source node's current value times its weight.
            s = agg_func([values[i] * w for i, w in links])
            values[node] = act_func(bias + response * s)

        return [values[i] for i in self.output_nodes]

    def reset(self):
        """Do nothing; kept for interface compatibility.

        ``activate`` overwrites the value of every node in ``node_evals`` on
        each call, so there is no recurrent state to clear. Presumably this
        exists so a feed-forward network can be used where a recurrent
        network's ``reset`` is expected.
        """
        pass

    @staticmethod
    def create(genome, config):
        """Receive a genome and return its phenotype (a FeedForwardNetwork).

        Only enabled connections are expressed. Nodes are evaluated layer by
        layer, in the order produced by ``feed_forward_layers``.

        :param genome: genome whose ``connections`` and ``nodes`` are read.
        :param config: configuration providing ``genome_config`` (input/output
            keys and the aggregation/activation function tables).
        :return: a new :class:`FeedForwardNetwork`.
        """
        genome_config = config.genome_config

        # Gather expressed connections.
        connections = [cg.key for cg in itervalues(genome.connections) if cg.enabled]

        layers = feed_forward_layers(genome_config.input_keys, genome_config.output_keys, connections)

        # Group incoming (source, weight) pairs by destination node in a single
        # pass instead of rescanning every connection for every node. The
        # per-node order follows ``connections``, as in a per-node scan.
        incoming = {}
        for conn_key in connections:
            inode, onode = conn_key
            incoming.setdefault(onode, []).append((inode, genome.connections[conn_key].weight))

        node_evals = []
        for layer in layers:
            for node in layer:
                ng = genome.nodes[node]
                aggregation_function = genome_config.aggregation_defs.get(ng.aggregation)
                activation_function = genome_config.activation_defs.get(ng.activation)
                node_evals.append((node, activation_function, aggregation_function,
                                   ng.bias, ng.response, incoming.get(node, [])))

        return FeedForwardNetwork(genome_config.input_keys, genome_config.output_keys, node_evals)