
utils.draw.plot_scatter

Function · Source

mdnc.utils.draw.plot_scatter(
    gen,
    xlabel=None, ylabel='value', x_log=None, y_log=False,
    figure_size=(6, 5.5), legend_loc=None, legend_col=None,
    fig=None, ax=None
)

Plot a scatter graph for multiple data groups. Each group is given by either:

  • two 1D arrays (the x and y coordinates), or
  • a 2D array of shape (N, 2), where the second axis holds the x and y coordinates.
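The two group formats can be illustrated with a small generator that yields one group of each kind. This is a sketch, not part of mdnc: the name data_groups is hypothetical, and mixing both formats in one generator is assumed to be allowed, as the list above suggests. The trailing dict follows the pattern used in the Examples section.

```python
import numpy as np


def data_groups(n=50, seed=0):
    """Hypothetical generator yielding one group per supported format."""
    rng = np.random.default_rng(seed)
    # Format 1: two 1D arrays, x first, then y.
    yield rng.normal(size=n), rng.normal(size=n)
    # Format 2: one (N, 2) array; column 0 holds x, column 1 holds y.
    yield rng.normal(loc=3.0, size=(n, 2))
    # Either format may be followed by a dict of extra keyword arguments,
    # e.g. a legend label.
    yield rng.normal(size=n), rng.normal(size=n), {'label': 'group 3'}
```

The generator object (data_groups()), not the function itself, is what would be passed as gen.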

Arguments

Requires

Argument Type Description
gen object A generator object (for example, the result of calling a generator function; see Examples). Each iteration yields one data group as either two 1D arrays or one 2D array, optionally followed by an extra kwargs dict for that group.
xlabel str The x axis label.
ylabel str The y axis label.
x_log bool A flag. If True, use a logarithmic scale for the x axis.
y_log bool A flag. If True, use a logarithmic scale for the y axis.
figure_size (float, float) A tuple (width, height) giving the size of the output figure in inches.
legend_loc str, int, or (float, float) The location of the legend; see matplotlib.pyplot.legend for details. (The legend is only shown when a label is passed in each iteration.)
legend_col int The number of columns in the legend; see matplotlib.pyplot.legend for details. (The legend is only shown when a label is passed in each iteration.)
fig object A matplotlib figure instance. If not given, plt.gcf() is used instead.
ax object A matplotlib subplot instance. If not given, plt.gca() is used instead.
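To make the arguments above concrete, here is a minimal sketch of how a function with this signature could be built on plain Matplotlib. It is not mdnc's implementation: the name scatter_groups is hypothetical, and the handling of the trailing kwargs dict, the fig/ax fallbacks, the figure resizing, and the label-driven legend are assumptions read off the table and the example.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def scatter_groups(gen, xlabel=None, ylabel='value', x_log=False, y_log=False,
                   figure_size=(6, 5.5), legend_loc=None, legend_col=None,
                   fig=None, ax=None):
    """Hypothetical illustration only; not the mdnc source code."""
    # Fall back to the current figure/axes, as the table describes.
    fig = fig if fig is not None else plt.gcf()
    ax = ax if ax is not None else plt.gca()
    fig.set_size_inches(*figure_size)
    has_label = False
    for item in gen:
        # Assumed convention: an optional trailing dict holds extra kwargs.
        kwargs = {}
        if isinstance(item, tuple) and isinstance(item[-1], dict):
            kwargs = item[-1]
            item = item[:-1]
        if isinstance(item, tuple) and len(item) == 2:
            x, y = item  # two 1D arrays
        else:
            data = np.asarray(item[0] if isinstance(item, tuple) else item)
            x, y = data[:, 0], data[:, 1]  # one (N, 2) array
        ax.scatter(x, y, **kwargs)
        has_label = has_label or ('label' in kwargs)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if x_log:
        ax.set_xscale('log')
    if y_log:
        ax.set_yscale('log')
    # Only draw a legend when at least one group supplied a label.
    if has_label:
        ax.legend(loc=legend_loc, ncol=legend_col if legend_col else 1)
    return ax
```

Passing fig and ax explicitly lets several such plots share one figure, while leaving them as None draws into whatever figure and axes are currently active.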

Examples

Example
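import numpy as np
import matplotlib.pyplot as plt
import mdnc

# setFigure sets the plotting style and font size for the decorated function.
# Note: in Matplotlib >= 3.6 this style is named 'seaborn-v0_8-darkgrid'.
@mdnc.utils.draw.setFigure(style='seaborn-darkgrid', font_size=16)
def test_scatter():
    def func_gen():
        size = 100
        for i in range(3):
            # Random center and spread (in log10 units) for each group.
            center = -4.0 + 4.0 * np.random.rand(2)
            scale = 0.5 + 2.0 * np.random.rand(2)
            x1 = np.random.normal(loc=center[0], scale=scale[0], size=size)
            x2 = np.random.normal(loc=center[1], scale=scale[1], size=size)
            # Yield x and y as two 1D arrays, plus a kwargs dict whose 'label' enables the legend.
            yield np.power(10, x1), np.power(10, x2), {'label': r'$x_{' + str(i + 1) + r'}$'}

    # Pass the generator object func_gen(), not the function itself.
    mdnc.utils.draw.plot_scatter(func_gen(), x_log=True, y_log=True,
                                 xlabel='Metric 1', ylabel='Metric 2')
    plt.show()

test_scatter()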


Last update: March 14, 2021
