
utils.draw.plot_training_records

Function · Source

mdnc.utils.draw.plot_training_records(
    gen,
    xlabel=None, ylabel='value', x_mark_num=None, y_log=False,
    figure_size=(6, 5.5), legend_loc=None, legend_col=None,
    fig=None, ax=None
)

Plot training curves for multiple data groups. Each group is given in one of the following forms:

  • 4 1D arrays: the x axis of the training metrics, the training metric values, the x axis of the validation metrics, and the validation metric values, respectively.
  • or 2 2D arrays, each with a shape of (N, 2). The first array holds the x axis and the training metric values; the second holds the x axis and the validation metric values.
  • or 2 1D arrays. In this case, no validation data is provided. The two arrays represent the x axis of the training metrics and the training metric values, respectively.
  • or a 4D array. Its 4 columns represent the x axis of the training metrics, the training metric values, the x axis of the validation metrics, and the validation metric values, respectively.
  • or a 2D array. In this case, no validation data is provided. Its two columns represent the x axis of the training metrics and the training metric values, respectively.
  • or a 1D array. In this case, no validation data is provided. The array holds the training metric values, and the x axis is generated automatically.
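The accepted forms can be summarized by mapping every group to the same four series. The sketch below uses a hypothetical helper, `normalize_group`, that is not part of mdnc and is not its implementation. It assumes that a single 2D array stacks its series along axis 0, as `np.stack(..., axis=0)` produces in the Examples, while also accepting a column-wise layout, and that the auto-generated x axis for a 1D array is 0, 1, ..., N - 1.

```python
import numpy as np


def normalize_group(*arrays):
    """Hypothetical helper (not part of mdnc): map one data group to
    (train_x, train_y, val_x, val_y). val_x and val_y are None when the
    group carries no validation data."""
    arrays = [np.asarray(a) for a in arrays]
    if len(arrays) == 4:
        # 4 1D arrays: train x, train values, validation x, validation values.
        train_x, train_y, val_x, val_y = arrays
        return train_x, train_y, val_x, val_y
    if len(arrays) == 2:
        a, b = arrays
        if a.ndim == 2 and b.ndim == 2:
            # 2 arrays of shape (N, 2): column 0 is x, column 1 is the metric.
            return a[:, 0], a[:, 1], b[:, 0], b[:, 1]
        # 2 1D arrays: training x and training values only.
        return a, b, None, None
    if len(arrays) == 1:
        a = arrays[0]
        if a.ndim == 1:
            # Assumption: the auto-generated x axis is 0, 1, ..., N - 1.
            return np.arange(a.shape[0]), a, None, None
        if a.ndim == 2:
            # Assumption: series are stacked along axis 0, shape (4, N) or (2, N),
            # as in the Examples; a column-wise (N, 4) or (N, 2) layout is
            # transposed first.
            if a.shape[0] not in (2, 4) and a.shape[1] in (2, 4):
                a = a.T
            if a.shape[0] == 4:
                return a[0], a[1], a[2], a[3]
            if a.shape[0] == 2:
                return a[0], a[1], None, None
    raise ValueError('Unsupported data group layout.')
```

Every form reduces to the same tuple, so a plotting routine would only need to handle the "validation present" and "validation absent" cases.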

Arguments

Requires

Argument | Type | Description
gen | object | A generator object (for example, the result of calling a generator function); each iteration yields one data group. A group is given as 4 1D arrays, 2 2D arrays, 2 1D arrays, a 4D array, a 2D array, or a 1D array, optionally followed by a dict of extra kwargs for that group (see Examples).
xlabel | str | The x axis label.
ylabel | str | The y axis label.
x_mark_num | int | The number of markers on the x axis.
y_log | bool | If True, the y axis uses a logarithmic scale.
figure_size | (float, float) | The (width, height) of the output figure, in inches.
legend_loc | str, int, or (float, float) | The location of the legend; see matplotlib.pyplot.legend for details. The legend is only shown when each iteration passes a label.
legend_col | int | The number of columns of the legend; see matplotlib.pyplot.legend for details. The legend is only shown when each iteration passes a label.
fig | object | A matplotlib figure instance. If not given, plt.gcf() is used instead.
ax | object | A matplotlib subplot instance. If not given, plt.gca() is used instead.
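The gen argument mixes data arrays with an optional trailing kwargs dict in each yield. The sketch below shows one way such a yield could be split apart; `split_yield` and `make_gen` are hypothetical names used only for illustration, and the sketch assumes that a dict in the last position is the kwargs dict, as in the Examples.

```python
import numpy as np


def split_yield(item):
    """Hypothetical helper (not part of mdnc): separate the data arrays of
    one yielded item from its optional trailing kwargs dict."""
    if not isinstance(item, tuple):
        # A bare array (e.g. a single 1D array) is wrapped into a 1-tuple.
        item = (item,)
    if item and isinstance(item[-1], dict):
        return item[:-1], item[-1]
    return item, {}


def make_gen():
    """Hypothetical generator yielding two groups in different forms."""
    x = np.arange(10)
    # 2 1D arrays plus a label, which would make the group appear in the legend.
    yield x, np.exp(-x / 5.0), {'label': 'run A'}
    # A bare 1D array with no kwargs dict.
    yield np.exp(-x / 3.0)
```

Iterating `make_gen()` through `split_yield` gives `(arrays, kwargs)` pairs: the first group carries `{'label': 'run A'}`, the second an empty dict.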

Examples

Example
import numpy as np
import matplotlib.pyplot as plt
import mdnc

# setFigure is used here to set the plot style and font size for the decorated function.
@mdnc.utils.draw.setFigure(style='Solarize_Light2', font_size=14)
def test_training_records():
    def func_gen_batch():
        # Three groups, each given as 2 1D arrays (no validation data).
        size = 100
        x = np.arange(start=0, stop=size)
        for i in range(3):
            begin = 1 + 99.0 * np.random.rand()  # initial value in [1, 100)
            end = 2 + 10 * np.random.rand()  # random decay rate
            v = begin * np.exp((np.square((x - size) / size) - 1.0) * end)
            # The trailing dict provides extra kwargs; 'label' enables the legend.
            yield x, v, {'label': r'$x_{' + str(i + 1) + r'}$'}

    def func_gen_epoch():
        # Three groups, each given as one array holding training and validation data.
        size = 10
        x = np.arange(start=0, stop=size)
        for i in range(3):
            begin = 1 + 99.0 * np.random.rand()
            end = 2 + 10 * np.random.rand()
            v = begin * np.exp((np.square((x - size) / size) - 1.0) * end)
            val_v = begin * np.exp((np.square((x - size) / size) - 1.0) * (end - 1))
            # Stack x, training values, x, validation values into shape (4, size).
            data = np.stack([x, v, x, val_v], axis=0)
            yield data, {'label': r'$x_{' + str(i + 1) + r'}$'}

    mdnc.utils.draw.plot_training_records(func_gen_batch(), y_log=True, x_mark_num=10,
                                          xlabel='Step', ylabel=r'Batch $\mathcal{L}$')
    plt.show()
    mdnc.utils.draw.plot_training_records(func_gen_epoch(), y_log=True, x_mark_num=10,
                                          xlabel='Epoch', ylabel=r'Epoch $\mathcal{L}$')
    plt.show()

test_training_records()
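In practice, the metrics usually come from a training loop rather than synthetic curves. The sketch below assumes each run is logged as hypothetical `{step: value}` dicts for training and validation, and converts them into the "2 2D arrays of shape (N, 2)" form; `records_to_pairs` and `gen_runs` are illustrative names, not mdnc functions.

```python
import numpy as np


def records_to_pairs(train_log, val_log):
    """Hypothetical helper: turn {step: value} logs into two (N, 2) arrays,
    with column 0 holding the step and column 1 the metric value."""
    train = np.array(sorted(train_log.items()), dtype=float)
    val = np.array(sorted(val_log.items()), dtype=float)
    return train, val


def gen_runs(runs):
    """Yield one data group per run; runs maps a run name to
    (train_log, val_log). The name is passed as the legend label."""
    for name, (train_log, val_log) in runs.items():
        train, val = records_to_pairs(train_log, val_log)
        yield train, val, {'label': name}
```

The resulting generator, e.g. `gen_runs(runs)`, could then be passed as the gen argument. Validation is often logged less frequently than training, so the two arrays may have different lengths N.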


Last update: March 14, 2021
