"""
A module for :any:`Snapshot` objects used to visualize the protocol during or after
the simulation has run.
:any:`Snapshot` is a base class for snapshot objects that are updated by :any:`Simulation`.
:any:`Plotter` is a subclass of :any:`Snapshot` that creates a matplotlib figure and axis.
It also gives the option for a state_map function which maps states to the categories which
will show up in the plot.
:any:`StatePlotter` is a subclass of :any:`Plotter` that creates a barplot of the counts
in categories.
:any:`HistoryPlotter` is a subclass of :any:`Plotter` that creates a lineplot of the counts
in categories over time.
"""
from typing import Optional, Callable, Hashable, Any
from natsort import natsorted
import numpy as np
import pandas as pd # type: ignore
from tqdm import tqdm
# Type alias for simulation states: any hashable object may serve as a state.
State = Hashable
class Snapshot:
    """Common base for snapshot objects driven by a :any:`Simulation`.

    A snapshot records the simulation time and configuration each time
    :any:`Snapshot.update` is called; subclasses build on that to report
    progress or draw plots.

    Attributes:
        simulation: The :any:`Simulation` that owns and updates this
            :any:`Snapshot`. It is ``None`` until the :any:`Simulation`
            attaches the snapshot via :any:`add_snapshot`.
        update_time: Number of seconds between successive calls to update while
            :any:`Simulation.run` of :any:`simulation` is executing.
        time: Simulation time captured by the most recent :any:`Snapshot.update`.
        config: Configuration array captured by the most recent
            :any:`Snapshot.update`.
    """

    def __init__(self) -> None:
        """Set up the attributes that do not depend on a :any:`Simulation`.

        Subclasses may take parameters here and set anything that can be
        configured before the snapshot is attached, such as :any:`update_time`.
        """
        self.simulation = None  # attached later, when the Simulation adds this snapshot
        self.update_time = 0.1
        self.time = None
        self.config = None

    def initialize(self) -> None:
        """Hook run once during :any:`add_snapshot`.

        Any setup that needs to read data from :any:`simulation` belongs here.

        Raises:
            ValueError: If no :any:`simulation` has been attached yet.
        """
        if self.simulation is not None:
            return
        raise ValueError('self.simulation is None, cannot call self.initialize until using sim.add_snapshot')

    def update(self, index: Optional[int] = None) -> None:
        """Capture the simulation's time and configuration.

        Args:
            index: Optional position in the simulation's recorded history. When
                given, the configuration is taken from :any:`configs` ``[index]``
                and the time from :any:`times` ``[index]``. When omitted, the
                live :any:`config_array` and current time are used.

        Raises:
            ValueError: If no :any:`simulation` has been attached yet.
        """
        sim = self.simulation
        if sim is None:
            raise ValueError('self.simulation is None, cannot call self.update until using sim.add_snapshot')
        if index is None:
            self.time = sim.time
            self.config = sim.config_array
        else:
            self.time = sim.times[index]
            self.config = sim.configs[index]
class TimeUpdate(Snapshot):
    """Simple :any:`Snapshot` that reports simulated time on a tqdm progress bar.

    When calling :any:`Simulation.run`, if :any:`snapshots` is empty, then
    this object will get added to provide a basic progress update.

    Attributes:
        pbar: The tqdm progress bar. Its counter tracks the simulated time elapsed
            since :any:`initialize` was called, rounded to 3 decimal places.
        start_time: The simulation time recorded when :any:`initialize` ran.
    """

    def __init__(self, time_bound: Optional[float] = None, update_time: float = 0.2) -> None:
        """Create the progress bar.

        Args:
            time_bound: Optional total amount of simulated time, used as the
                progress bar's ``total``. If None, the bar has no total.
            update_time: Seconds between updates while the simulation runs.
                Defaults to 0.2.
        """
        # Set the base attributes (simulation, time, config) so the None-checks in
        # Snapshot.initialize/update work even before add_snapshot is called.
        super().__init__()
        self.pbar = tqdm(total=time_bound, position=0, leave=False, unit=' time simulated')
        self.update_time = update_time

    def initialize(self) -> None:
        """Remember the simulation time at which progress reporting starts.

        Raises:
            ValueError: If no simulation has been attached yet.
        """
        super().initialize()
        self.start_time = self.simulation.time

    def update(self, index: Optional[int] = None) -> None:
        """Advance the progress bar to the simulated time elapsed since initialize."""
        super().update(index)
        new_n = round(self.time - self.start_time, 3)
        # tqdm.update takes an increment, so pass the difference from the bar's current count.
        self.pbar.update(new_n - self.pbar.n)
class Plotter(Snapshot):
    """Base class for a :any:`Snapshot` which will make a plot.

    Gives the option to map states to categories, for an easy way to visualize
    relevant subsets of the states rather than the whole state set.
    These require an interactive matplotlib backend to work.

    Attributes:
        fig: The matplotlib figure that is created.
        ax: The matplotlib axis object that is created. Modifying properties
            of this object is the most direct way to modify the plot.
        yscale: The scale used for the yaxis, passed into ax.set_yscale.
        state_map: A function mapping states to categories, which acts as a filter
            to view a subset of the states or just one field of the states.
            States mapped to ``None`` are excluded from every category.
        categories: A list which holds the set ``{state_map(state)}`` (excluding
            ``None``) for all states in :any:`state_list`, naturally sorted by
            ``repr``. Without a state_map, this is :any:`state_list` itself.
        sort_by: Stored from the constructor; not read by :any:`Plotter`'s own
            methods. NOTE(review): confirm intended use.
        _matrix: A (# states)x(# categories) 0/1 matrix such that, for the
            configuration array (indexed by states), ``config @ _matrix`` gives
            an array of counts of categories. ``None`` when no state_map is used.
    """

    def __init__(self, state_map: Optional[Callable[[State], Any]] = None, update_time=0.5, yscale='linear',
                 sort_by: str = 'categories') -> None:
        """Initializes the :any:`Plotter`.

        Args:
            state_map: An optional function mapping states to categories.
            update_time: Seconds between updates while the simulation runs.
                Defaults to 0.5.
            yscale: The scale used for the yaxis, passed into ax.set_yscale.
                Defaults to 'linear'.
            sort_by: Stored as :any:`sort_by`. Defaults to 'categories'.
        """
        # Set the base attributes (simulation, time, config) so the None-checks in
        # Snapshot.initialize/update work even before add_snapshot is called.
        super().__init__()
        self._matrix = None
        self.state_map = state_map
        self.update_time = update_time
        self.yscale = yscale
        self.sort_by = sort_by

    def _add_state_map(self, state_map):
        """An internal function called to update :any:`categories` and `_matrix`."""
        states = self.simulation.state_list
        # Call state_map exactly once per state and reuse the results below.
        mapped = [state_map(state) for state in states]
        # dict.fromkeys de-duplicates in first-seen order; categories must be
        # hashable anyway because they are used as dict keys below.
        unique_categories = dict.fromkeys(m for m in mapped if m is not None)
        self.categories = natsorted(unique_categories, key=lambda x: repr(x))
        categories_dict = {category: i for i, category in enumerate(self.categories)}
        self._matrix = np.zeros((len(states), len(self.categories)), dtype=np.int64)
        for i, m in enumerate(mapped):
            if m is not None:
                self._matrix[i, categories_dict[m]] += 1

    def initialize(self) -> None:
        """Initializes the plotter by creating a fig and ax.

        Also (re)computes :any:`categories` and `_matrix` from the current
        :any:`state_map`, so calling this again picks up a changed state_map.

        Raises:
            ValueError: If no simulation has been attached yet.
        """
        super().initialize()
        # Only do matplotlib import when necessary
        from matplotlib import pyplot as plt
        self.fig, self.ax = plt.subplots()
        if self.state_map is not None:
            self._add_state_map(self.state_map)
        else:
            # Each state is its own category. Clear any matrix left over from a
            # previous state_map so subclasses use the raw configuration counts.
            self._matrix = None
            self.categories = self.simulation.state_list
class StatePlotter(Plotter):
    """:any:`Plotter` which produces a barplot of counts."""

    def initialize(self) -> None:
        """Initializes the barplot, with one zero-height bar per category.

        If :any:`state_map` gets changed, call :any:`initialize` to update the barplot to
        show the new set :any:`categories` (a new figure is created each time).
        """
        super().initialize()
        import seaborn as sns  # deferred import: only needed once a barplot is created
        labels = [str(c) for c in self.categories]
        # Draw explicitly on the axis created by super().initialize() rather than
        # on whatever axis matplotlib currently considers active.
        self.ax = sns.barplot(x=labels, y=np.zeros(len(self.categories)), ax=self.ax)
        # Rotate the x-axis labels if any of the label strings have more than 2 characters.
        # any() also handles an empty category list, where max([]) would raise ValueError.
        if any(len(label) > 2 for label in labels):
            for tick in self.ax.get_xticklabels():
                tick.set_rotation(90)
        self.ax.set_yscale(self.yscale)
        # NOTE(review): matplotlib ignores (with a warning) a non-positive bottom
        # limit on a 'log' axis; consider a positive lower bound for yscale='log'.
        if self.yscale in ['symlog', 'log']:
            self.ax.set_ylim(0, 2 * self.simulation.simulator.n)
        else:
            self.ax.set_ylim(0, self.simulation.simulator.n)

    def update(self, index: Optional[int] = None) -> None:
        """Update the heights of all bars in the plot and redraw the figure.

        Args:
            index: Optional history index forwarded to :any:`Snapshot.update`.
        """
        super().update(index)
        if self._matrix is not None:
            # Per-state counts @ (states x categories) matrix -> per-category counts.
            heights = np.matmul(self.config, self._matrix)
        else:
            heights = self.config
        for i, rect in enumerate(self.ax.patches):
            rect.set_height(heights[i])
        self.ax.set_title(f'Time {self.time: .3f}')
        self.fig.tight_layout()
        self.fig.canvas.draw()
class HistoryPlotter(Plotter):
    """Plotter that draws the simulation's recorded counts as a line plot over time."""

    def update(self, index: Optional[int] = None) -> None:
        """Clear the axis and redraw the full history plot.

        Args:
            index: Optional history index forwarded to :any:`Snapshot.update`.
        """
        super().update(index)
        self.ax.clear()
        if self._matrix is None:
            frame = self.simulation.history
        else:
            history = self.simulation.history
            # Collapse per-state columns into per-category columns.
            category_counts = np.matmul(history.to_numpy(), self._matrix)
            frame = pd.DataFrame(data=category_counts,
                                 columns=self.categories,
                                 index=history.index)
        frame.plot(ax=self.ax)
        self.ax.set_yscale(self.yscale)
        population = self.simulation.simulator.n
        top = 2 * population if self.yscale in ['symlog', 'log'] else 1.1 * population
        self.ax.set_ylim(0, top)
        # Time-unit x labels are long, so tilt them to keep them readable.
        if self.simulation.time_units:
            for label in self.ax.get_xticklabels():
                label.set_rotation(45)
        self.fig.tight_layout()
        self.fig.canvas.draw()