from inspect import signature
from types import FunctionType
import numpy as np
from pyqtgraph import IsocurveItem
import data_slicer.utilities as util
class ModelError(Exception):
    """ Base class for :class:`Model <data_slicer.model.Model>` related
    errors.

    It adds no behavior of its own; it only provides a common type that the
    more specific model errors derive from.
    """
class UndefinedModelError(ModelError):
    """ Error raised when an operation would require the *model* attribute of
    :class:`Model <data_slicer.model.Model>` but it is not found.
    """
class UndefinedAxisError(ModelError):
    """ Error raised when an operation would require all the axes in the
    *axes* attribute of :class:`Model <data_slicer.model.Model>` but at least
    one of them is not found.
    """
class Model:
    """
    General object that allows calculating a model over some input axes. It
    also provides functionalities to extract different slices from the
    calculated data.

    **Attributes**

    ============== ========================================================
    model          the model function or *None* if none has been set.
    n_args         number of required positional arguments (= axes) of
                   *model*. Only present once a model has been set.
    n_kwargs       number of named parameters of *model* that are not
                   axes. Only present once a model has been set.
    has_var_kwargs whether *model* accepts arbitrary ``**kwargs``. Only
                   present once a model has been set.
    axes           list of length *n_args*; entries are *None* until the
                   respective axis has been set.
    data           result of the last model evaluation or *None*.
    meshes         coordinate meshes used for the last model evaluation.
                   Only present after a model evaluation.
    ============== ========================================================
    """
    # Axes with fewer points than this are replaced by a linearly spaced
    # axis of exactly this many points (see `set_axis`).
    MIN_AXIS_LENGTH = 100

    def __init__(self, model=None):
        """
        **Parameters**

        =====  ============================================================
        model  callable; a python function representing the model.
        =====  ============================================================

        .. seealso::
            :meth:`set_model <data_slicer.model.Model.set_model>`
        """
        self.data = None
        if model is not None:
            self.set_model(model)
        else:
            self.model = None

    def __repr__(self):
        base = '<Model: {}>'
        try:
            self._check_if_model_defined()
        except UndefinedModelError:
            return base.format('undefined')
        inner = 'n_args={}, n_kwargs={}'.format(self.n_args, self.n_kwargs)
        return base.format(inner)

    def set_model(self, model):
        """ Specify the function that represents the model. This should be a
        function with the call signature::

            model(axis1, axis2, ..., axisN, kwarg1=kwarg1_default, ...,
                  kwargN=kwargN_default, **kwargs)

        I.e. all the positional arguments (*axis1* to *axisN*) correspond to
        the required input variables. The keyword arguments (*kwarg1* to
        *kwargN*) as well as further unspecified *kwargs* can be used for the
        model parameters.
        Information about the number of arguments is obtained through
        introspection.

        Setting a model resets all axes and discards previously calculated
        data, as neither is guaranteed to fit the new model.

        **Parameters**

        =====  ============================================================
        model  callable; a python function representing the model.
        =====  ============================================================

        **Raises**

        =========  ========================================================
        TypeError  if *model* is not a python function.
        =========  ========================================================
        """
        # Check if a function was supplied
        if not isinstance(model, FunctionType):
            raise TypeError('*model* has to be a function.')

        # Count the number of different arguments
        n_args = 0
        n_kwargs = 0
        has_var_kwargs = False
        for param in signature(model).parameters.values():
            kind = param.kind
            is_positional = kind in (param.POSITIONAL_ONLY,
                                     param.POSITIONAL_OR_KEYWORD)
            if is_positional and param.default is param.empty:
                # Required positional argument -> this is an axis
                n_args += 1
            elif kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY):
                # Anything that can be passed by name and is not an axis is
                # a model parameter. This includes arguments of the form
                # `kwarg=default`, which must not be mistaken for axes.
                n_kwargs += 1
            elif kind == param.VAR_KEYWORD:
                has_var_kwargs = True
            # Everything else (*args, positional-only arguments with a
            # default) is neither an axis nor settable by keyword.

        self.model = model
        self.n_args = n_args
        self.n_kwargs = n_kwargs
        self.has_var_kwargs = has_var_kwargs
        # Prepare a container for the axes
        self.axes = n_args * [None]
        # Data calculated with a previous model is no longer valid
        self.data = None

    def _check_if_model_defined(self):
        """ Raise an appropriate error if no model function is found. """
        if getattr(self, 'model', None) is None:
            raise UndefinedModelError('No model has been defined. Use the '
                                      'set_model() attribute.')

    def _check_if_axes_defined(self):
        """ Raise an appropriate error if at least one of the necessary axes
        has not been properly defined.
        """
        axes = getattr(self, 'axes', None)
        if axes is None or any(axis is None for axis in axes):
            message = ('At least one out of {} axes has not been defined. '
                       'Use the set_axis() attribute.')
            # *n_args* only exists once a model has been set; fall back to
            # a generic wording instead of failing while building the
            # error message.
            n_args = getattr(self, 'n_args', 'the required')
            raise UndefinedAxisError(message.format(n_args))

    def set_axes(self, axes):
        """ Set all axes (inputs for the model). The axes supplied here can
        either be 1 dimensional array-like objects representing the actual
        values at which the model should be evaluated or simply the start and
        stop values for the range in which the model will be evaluated.
        If the length of any axis is smaller than Model.MIN_AXIS_LENGTH or if
        just start and stop values are given, linearly spaced values between
        start and stop will be used.

        **Parameters**

        ====  =============================================================
        axes  list of len(self.n_args); an error will be thrown if the
              number of supplied axes does not match what is necessary for
              self.model.
        ====  =============================================================

        **Raises**

        ===================  ==============================================
        UndefinedModelError  if no model has been set yet.
        ValueError           if the number of supplied axes is wrong.
        ===================  ==============================================

        .. seealso::
            :meth:`set_axis <data_slicer.model.Model.set_axis>` to set just
            one specific axis.
        """
        # Without a model, *n_args* does not exist yet
        self._check_if_model_defined()
        if len(axes) != self.n_args:
            message = ('The number of supplied axes ({}) does not match the '
                       'number of required axes ({}).')
            raise ValueError(message.format(len(axes), self.n_args))
        for i, axis in enumerate(axes):
            self.set_axis(axis, dim=i)

    def set_axis(self, axis, dim=0):
        """ Set the axis (input to the model) at position *dim*. See
        documentation of :meth:`set_axes <data_slicer.model.Model.set_axes>`
        for more details.

        **Parameters**

        ====  =============================================================
        axis  1d array-like; the values at which the model should be
              evaluated along this dimension. If the length is smaller than
              Model.MIN_AXIS_LENGTH, linearly spaced values between the
              first and last value in *axis* will be created.
        dim   int; position of the axis in the argument list of the model.
        ====  =============================================================

        .. seealso::
            :meth:`set_axes <data_slicer.model.Model.set_axes>` to set all
            axes at once.
        """
        self._check_if_model_defined()
        if len(axis) < self.MIN_AXIS_LENGTH:
            axis = np.linspace(axis[0], axis[-1], self.MIN_AXIS_LENGTH)
        self.axes[dim] = axis

    def get_axes_dims(self):
        """ Return the length of all axes that are defined. Axes that have
        not been set yet are represented by *None*.
        """
        # If a model is defined, the axes will at least have been initialized
        self._check_if_model_defined()
        return [None if axis is None else len(axis) for axis in self.axes]

    def calculate_model_data(self, axes=None, **kwargs):
        """
        Evaluate the given model function :meth:`model
        <data_slicer.model.Model.model>` at every point in the hypervolume
        defined by the *axes*. This means that every possible combination of
        coordinates of all *axes* is created, and the model evaluated at
        every such point.

        The coordinate meshes are created with :func:`numpy.meshgrid` using
        its default settings and stored in *self.meshes*; the result is
        stored in *self.data*.

        **Parameters**

        ======  ===========================================================
        axes    if specified, this is passed on to :meth:`set_axes
                <data_slicer.model.Model.set_axes>`. Otherwise the
                previously set axes will be used.
        kwargs  all keyword arguments are passed to the model function.
        ======  ===========================================================

        **Returns**

        ====  =============================================================
        data  whatever the model function returns for the meshes.
        ====  =============================================================
        """
        self._check_if_model_defined()
        if axes is None:
            self._check_if_axes_defined()
        else:
            self.set_axes(axes)

        self.meshes = np.meshgrid(*self.axes)
        data = self.model(*self.meshes, **kwargs)
        self.data = data
        return data

    def make_slice(self, dim, index, integrate=0, silent=False):
        """ Return a slice out of the model data. If the data has not yet
        been calculated, try to do it first.
        This wraps :func:`make_slice <data_slicer.utilities.make_slice>`.
        Confer respective documentation for information on the arguments.
        """
        # *data* is initialized to None, so a missing calculation has to be
        # detected by value rather than by a missing attribute.
        data = getattr(self, 'data', None)
        if data is None:
            data = self.calculate_model_data()
        return util.make_slice(data, dim, index, integrate, silent)

    def get_isocurve(self, level, pen=dict(color='r', width=2), **kwargs):
        """
        .. warning::
            Only possible for 2D models (i.e. self.n_args==2).

        Return an isocurve (:class:`IsocurveItem
        <pyqtgraph.graphicsItems.IsocurveItem>`) of the model data at the
        selected *level*.
        This uses pyqtgraph's isocurve function, which is based on the
        marching squares algorithm.
        If the data has not yet been calculated, try to do it first.

        **Parameters**

        ======  ===========================================================
        level   float; value at which the isocurve is generated.
        pen     arguments for the visual properties of the isocurve. Can
                be anything which is valid for :func:`mkPen
                <pyqtgraph.mkPen>`.
        kwargs  all further keyword arguments are passed on to
                :class:`IsocurveItem
                <pyqtgraph.graphicsItems.IsocurveItem>`.
        ======  ===========================================================

        **Returns**

        ============  =====================================================
        isocurveItem  :class:`IsocurveItem
                      <pyqtgraph.graphicsItems.IsocurveItem>`
        ============  =====================================================

        **Raises**

        ==========  =======================================================
        ModelError  if the model is not 2 dimensional.
        ==========  =======================================================
        """
        # Check for correct dimensionality
        self._check_if_model_defined()
        if self.n_args != 2:
            raise ModelError('Isocurves can only be generated for 2D models, '
                             'but n_args is {}'.format(self.n_args))
        # Check if data has been calculated
        if self.data is None:
            self.calculate_model_data()
        return IsocurveItem(data=self.data, level=level, pen=pen, **kwargs)

    def get_values_around(self, value, eps):
        """
        .. warning::
            unfinished

        Find all points of the model data that lie closer than *eps* to
        *value*. If the data has not yet been calculated, try to do it
        first.

        **Parameters**

        =====  ============================================================
        value  float; the value to look for in the model data.
        eps    float; points with ``abs(data - value) < eps`` are selected.
        =====  ============================================================

        **Returns**

        ======  ===========================================================
        coords  list with one array per axis; the coordinates of the
                selected points.
        points  array; the model data at the selected points.
        ======  ===========================================================
        """
        # Check if data has been calculated
        if self.data is None:
            self.calculate_model_data()
        mask = np.abs(self.data - value) < eps
        points = self.data[mask]
        return [mesh[mask] for mesh in self.meshes], points
# Minimal manual check; only executed when this file is run as a script
if __name__ == '__main__':
    def paraboloid(x, y):
        """ Simple 2D test model: squared distance from the origin. """
        return x**2 + y**2

    demo = Model()
    demo.set_model(paraboloid)
    # Only start and stop values are given; the axes get filled up with
    # linearly spaced values
    demo.set_axes([[0, 100], [0, 150]])
    data = demo.calculate_model_data()
    meshes, points = demo.get_values_around(50, 1)