Source code for pymlp.mlp.NetworkVisualizer

import pylab
import numpy as np


class NetworkVisualizer(object):
    """Visualizes the weights of a network.

    The visualizer owns one figure containing a grid of subplots. Each cell
    of the grid can hold one weight array, stored via ``setData``. Every
    populated cell is drawn with ``pcolormesh`` when ``visualize`` is called.

    NOTE: the constructor arguments are named the other way round from how
    they are used. ``rows`` is passed to ``pylab.subplots`` as the number of
    grid *columns*, and ``columns`` as the number of grid *rows*. ``setData``
    follows the grid layout: its ``row`` argument must be smaller than
    ``columns``, and its ``column`` argument smaller than ``rows``. The
    parameter names are kept unchanged for backward compatibility.
    """

    def __init__(self, rows, columns, colormap=pylab.cm.RdBu):
        """Create the figure and an empty grid of subplots.

        :param rows: number of subplot columns in the grid (see class note).
        :param columns: number of subplot rows in the grid (see class note).
        :param colormap: colormap passed to ``pcolormesh`` for every cell.
        """
        # Enable matplotlib's interactive mode.
        pylab.ion()
        # squeeze=False always returns a 2-D array of Axes with shape
        # (columns, rows), even for 1xN, Nx1 or 1x1 grids.
        self.figure, self.subplots = pylab.subplots(columns, rows,
                                                    squeeze=False)
        # One data slot per subplot, same shape as self.subplots.
        # None marks a cell that has no data.
        self.data = [[None for _ in range(rows)] for _ in range(columns)]
        self.colormap = colormap

    def setData(self, weights, column, row):
        """Store ``weights`` for the cell in grid row ``row``, column ``column``.

        The data is only drawn on the next call to ``visualize``.
        """
        self.data[row][column] = weights

    def removeUnnecessaryPlots(self):
        """Removes the unused subfigures so there are no empty graphs.

        Helpful for modular networks. Cells whose data is still None are
        deleted from the figure. Calling this more than once is harmless.
        """
        self._mapOverData(self._filter)

    def _filter(self, i, j):
        """Delete subplot (i, j) from the figure if it has no data."""
        axes = self.subplots[i, j]
        # Skip cells already removed by an earlier call. Otherwise
        # pylab.delaxes(None) would delete the *current* axes instead.
        if self.data[i][j] is None and axes is not None:
            # Remove from this visualizer's own figure rather than from
            # whichever figure happens to be current.
            self.figure.delaxes(axes)
            # Remove from the member array as well.
            self.subplots[i, j] = None

    def visualize(self):
        """Redraw every cell that has data, then refresh this figure."""
        self._mapOverData(self._plot)
        # Same as pylab.draw(), but targets this figure even when another
        # figure is the current one.
        self.figure.canvas.draw_idle()

    def _plot(self, i, j):
        """Replace the contents of subplot (i, j) with a mesh of its data."""
        weights = self.data[i][j]
        if weights is not None:
            axes = self.subplots[i, j]
            axes.cla()  # clear old data
            axes.pcolormesh(weights, cmap=self.colormap)

    def _mapOverData(self, f):
        """Call ``f(i, j)`` for every cell of the grid, row by row."""
        for i, cells in enumerate(self.data):
            for j in range(len(cells)):
                f(i, j)
def visualize():
    """Demonstrate how NetworkVisualizer is used.

    The demo runs these steps:

    1. Opens a grid of 4 subplot rows by 3 subplot columns.
    2. Fills three cells and removes the unused ones.
    3. Briefly animates one cell by alternating two weight matrices.
    4. Closes the demo's figure.
    """
    import time

    visualizer = NetworkVisualizer(3, 4)
    try:
        weights = np.array([[2, 3, 4], [-1, 3, -4], [0, 1, 5]])
        visualizer.setData(weights, 0, 3)
        weights2 = np.array([[-1, 2, 3], [4, -5, 6]])
        visualizer.setData(weights2, 1, 1)
        visualizer.setData(weights2 * 10, 2, 0)  # test scaled values
        visualizer.removeUnnecessaryPlots()
        visualizer.visualize()
        time.sleep(1)

        # Play around a little to show how updates work: the same cell
        # alternates between a changing matrix and the original one.
        weights3 = np.array([[0, 0, 0], [1, 2, 3], [-10, 2, 3], [1, 2, 3]])
        for _ in range(100):
            visualizer.setData(weights3, 0, 3)
            visualizer.visualize()
            weights3[0, 0] += 1
            time.sleep(0.001)
            visualizer.setData(weights, 0, 3)
            visualizer.visualize()
            time.sleep(0.001)
    finally:
        # Close the demo's own figure, not whichever figure is current,
        # even if the demo fails part-way.
        pylab.close(visualizer.figure)