# Source code for spam.plotting.multivariateGaussians

"""
Library of SPAM functions for plotting multivariate gaussians
Copyright (C) 2020 SPAM Contributors

This program is free software: you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the Free
Software Foundation, either version 3 of the License, or (at your option)
any later version.

This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
more details.

You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.
"""


import matplotlib.pyplot as plt
import numpy
from matplotlib import cm


def _evaluateGaussianGrid(phi, mean, hessian, n):
    """
    Evaluate ``phi * exp(-0.5 * (h - mean)^T . hessian . (h - mean))`` on a regular n x n grid.

    Parameters
    ----------
        phi : float
            Scale factor multiplying the exponential.

        mean : sequence of 2 floats
            Centre of the grid and of the gaussian, as (x, y).

        hessian : 2x2 array-like
            Matrix of the quadratic form in the exponent.

        n : int
            Number of sample points along each axis.

    Returns
    -------
        X, Y, gauss : three (n, n) numpy arrays
            Grid coordinates as returned by ``numpy.meshgrid`` and the function
            value at each of them, so that ``gauss[i, j]`` is the value at
            ``(X[i, j], Y[i, j])``.
    """
    mean = numpy.asarray(mean, dtype=float)
    hessian = numpy.asarray(hessian, dtype=float)

    # The sampled window spans +/- 0.1 * (diagonal hessian term) around the mean on each axis
    halfWidthX = 0.1 * hessian[0, 0]
    halfWidthY = 0.1 * hessian[1, 1]
    xs = numpy.linspace(mean[0] - halfWidthX, mean[0] + halfWidthX, n)
    ys = numpy.linspace(mean[1] - halfWidthY, mean[1] + halfWidthY, n)

    # Default "xy" indexing: x varies along the last axis, y along the first
    X, Y = numpy.meshgrid(xs, ys)

    # Offsets from the mean for every grid node, shape (n, n, 2)
    offsets = numpy.stack((X - mean[0], Y - mean[1]), axis=-1)

    # Quadratic form d^T . H . d evaluated for all nodes at once
    quadratic = numpy.einsum("...i,ij,...j->...", offsets, hessian, offsets)

    gauss = float(phi) * numpy.exp(-0.5 * quadratic)
    return X, Y, gauss


def plotMultivariateGaussians(phi, mean, hessian, n=100):
    """
    Show a 3D surface plot of a two-dimensional gaussian.

    The plotted function is
    ``phi * exp(-0.5 * (h - mean)^T . hessian . (h - mean))`` with ``h = (x, y)``,
    sampled on an n x n grid spanning ``mean[i] +/- 0.1 * hessian[i][i]`` along
    each axis.

    Parameters
    ----------
        phi : float
            Scale factor multiplying the exponential.

        mean : sequence of 2 floats
            Centre of the gaussian, as (x, y).

        hessian : 2x2 array-like
            Matrix of the quadratic form in the exponent.

        n : int, optional
            Number of sample points along each axis.
            Default = 100

    Returns
    -------
        None
            A matplotlib figure is created and displayed with ``plt.show()``.
    """
    X, Y, gauss = _evaluateGaussianGrid(phi, mean, hessian, n)

    fig = plt.figure()
    # Figure.gca() no longer accepts keyword arguments in recent matplotlib,
    # so the 3D axes is created explicitly
    ax = fig.add_subplot(1, 1, 1, projection="3d")
    ax.plot_surface(X, Y, gauss, cmap=cm.coolwarm, linewidth=0, antialiased=False)
    plt.show()