This class provides a unified interface for visualizing machine learning models on both 1D and 2D tasks. It automatically detects the dimensionality and creates appropriate visualizations using ggplot2.
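
As a rough illustration of the dimensionality check, the sketch below dispatches on the number of feature columns. It is a minimal base R sketch, assuming the features arrive as a plain data frame. `detect_dim()` is a hypothetical helper, not part of the vistool API, and rejecting other feature counts is an assumption.

```r
# Hypothetical helper, not part of vistool: classify a feature table by its
# number of columns. Assumes `features` holds only the feature columns.
detect_dim <- function(features) {
  n <- ncol(features)
  if (n == 1L) {
    "1D"
  } else if (n == 2L) {
    "2D"
  } else {
    # Assumption: tasks with any other number of features are rejected.
    stop("expected 1 or 2 features, got ", n)
  }
}

detect_dim(data.frame(x1 = c(0.1, 0.5, 0.9)))               # returns "1D"
detect_dim(data.frame(x1 = c(0.1, 0.5), x2 = c(1.0, 2.0)))  # returns "2D"
```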

Value

Returns self invisibly.

Super class

vistool::Visualizer -> VisualizerModel

Public fields

task

(`mlr3::Task`)
Task used to train the model.

learner

(`mlr3::Learner`)
Learner used to train the model.

Methods


Method new()

Creates a new instance of this R6 class.

Usage

VisualizerModel$new(
  task,
  learner,
  x1_limits = NULL,
  x2_limits = NULL,
  padding = 0,
  n_points = 100L
)

Arguments

task

(`mlr3::Task`)
The task to train the model on.

learner

(`mlr3::Learner`)
The learner to train the model with.

x1_limits

(`numeric(2)`)
Limits for the first feature axis. For 1D tasks, this controls the x-axis range. If `NULL`, the limits are determined from the task data.

x2_limits

(`numeric(2)`)
Limits for the second feature axis (2D tasks only). Ignored for 1D tasks. If `NULL`, the limits are determined from the task data.
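
One plausible reading of "determined from the task data" is the observed range of the feature. The sketch below assumes exactly that; `resolve_limits()` is a hypothetical helper, and vistool may derive its default limits differently.

```r
# Hypothetical helper: use explicit limits when given, otherwise fall back
# to the observed range of the feature values (an assumption).
resolve_limits <- function(limits, values) {
  if (is.null(limits)) range(values) else limits
}

x1 <- c(2.5, -1, 4, 0.5)
resolve_limits(NULL, x1)      # returns c(-1, 4)
resolve_limits(c(-5, 5), x1)  # returns c(-5, 5)
```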

padding

(`numeric(1)`)
A margin that is added to `x1_limits` and `x2_limits`. The x1 margin is calculated as `(max(x1_limits) - min(x1_limits)) * padding`.
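
The margin formula can be checked with a small worked example. This is a minimal sketch: the formula comes from the description above, while applying the margin to both ends of the range (subtracting it from the lower limit and adding it to the upper limit) is an assumption. `pad_limits()` is a hypothetical helper.

```r
# Hypothetical helper: widen a limit vector by a relative margin.
# Assumption: the margin is applied symmetrically to both ends.
pad_limits <- function(limits, padding = 0) {
  margin <- (max(limits) - min(limits)) * padding
  c(min(limits) - margin, max(limits) + margin)
}

pad_limits(c(0, 10), padding = 0.1)  # returns c(-1, 11)
pad_limits(c(0, 10))                 # padding = 0 leaves the limits unchanged
```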

n_points

(`integer(1)`)
The number of generated points per dimension. Note that for 2D tasks a grid of `n_points^2` points is generated and evaluated with the trained learner to plot the prediction surface.
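
The grid size can be sketched in base R. This assumes evenly spaced points between the limits via `seq()`, crossed into a full grid with `expand.grid()`; the limits are made-up values, and how vistool builds its grid internally is not documented here.

```r
n_points <- 100L
x1_limits <- c(-2, 2)  # made-up limits for illustration
x2_limits <- c(0, 1)

# One evenly spaced sequence per axis ...
x1 <- seq(x1_limits[1], x1_limits[2], length.out = n_points)
x2 <- seq(x2_limits[1], x2_limits[2], length.out = n_points)

# ... crossed into a full grid with one row per evaluation point.
grid <- expand.grid(x1 = x1, x2 = x2)
nrow(grid)  # returns 10000, i.e. n_points^2
```

With the default `n_points = 100L` this means 10,000 predictions; doubling `n_points` quadruples the number of learner evaluations.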


Method add_training_data()

Adds the training data to the plot.

Usage

VisualizerModel$add_training_data(
  color = "auto",
  size = NULL,
  shape = 19,
  alpha = NULL,
  show_labels = FALSE,
  label_size = NULL
)

Arguments

color

(`character(1)` or named `character`)
Color of the points. For classification tasks:
- `character(1)`: A single color for all points (e.g., `"blue"`), or `"auto"` for automatic color assignment.
- Named `character`: A vector mapping class labels to colors (e.g., `c(pos = "red", neg = "blue")`).
For regression tasks, only single colors are supported. Default is `"auto"`.
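
For classification tasks, the three accepted forms of `color` can be sketched as a lookup that yields one color per point. This is a minimal base R sketch: `resolve_point_colors()` is a hypothetical helper, and using `grDevices::hcl.colors()` for `"auto"` is an assumption, not vistool's documented palette.

```r
# Hypothetical helper: turn the `color` argument into one color per point.
resolve_point_colors <- function(color, labels) {
  labels <- as.character(labels)
  classes <- sort(unique(labels))
  if (length(color) == 1L && is.null(names(color))) {
    if (identical(color, "auto")) {
      # Assumption: "auto" assigns one palette color per class.
      pal <- setNames(grDevices::hcl.colors(length(classes)), classes)
      return(unname(pal[labels]))
    }
    # A single color is recycled for every point.
    return(rep(color, length(labels)))
  }
  # Named vector: every class label needs a color.
  missing <- setdiff(classes, names(color))
  if (length(missing) > 0L) {
    stop("no color given for class(es): ", paste(missing, collapse = ", "))
  }
  unname(color[labels])
}

y <- c("pos", "neg", "pos")
resolve_point_colors(c(pos = "red", neg = "blue"), y)  # returns c("red", "blue", "red")
resolve_point_colors("blue", y)                        # returns rep("blue", 3)
```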

size

(`numeric(1)`)
Size of the points. If NULL, uses theme$point_size. Default is NULL.

shape

(`numeric(1)` or named `numeric`)
Shape of the points. For classification tasks, can be a named vector mapping class labels to shapes. Default is 19 (filled circle).

alpha

(`numeric(1)`)
Alpha transparency of the points. If NULL, uses theme$alpha. Default is NULL.
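
The "if `NULL`, use the theme value" rule for `size` and `alpha` is a null-coalescing fallback. The sketch below assumes a theme stored as a plain list; `theme` here is a hypothetical stand-in, not vistool's actual theme object. An equivalent `%||%` operator also exists in rlang and in base R 4.4.0 and later.

```r
# Null-coalescing operator: return `y` when `x` is NULL.
`%||%` <- function(x, y) if (is.null(x)) y else x

# Hypothetical theme standing in for the visualizer's theme settings.
theme <- list(point_size = 2, alpha = 0.8)

size <- NULL   # not supplied, so the theme value is used
alpha <- 0.3   # supplied, so it overrides the theme
resolved <- c(size = size %||% theme$point_size, alpha = alpha %||% theme$alpha)
resolved  # returns c(size = 2, alpha = 0.3)
```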

show_labels

(`logical(1)`)
Whether to show data point labels. Default is FALSE.

label_size

(`numeric(1)`)
Size of data point labels. If NULL, defaults to smaller text.


Method add_boundary()

Adds boundary lines/contours to the plot.

Usage

VisualizerModel$add_boundary(
  values = NULL,
  color = "black",
  linetype = "dashed",
  linewidth = NULL,
  alpha = NULL
)

Arguments

values

(`numeric()`)
Values at which to draw boundaries. For 1D: horizontal lines (y-values). For 2D: contour lines (z-values). If NULL, uses sensible defaults based on prediction type.
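
For 2D tasks, a boundary at a given z-value is a contour line of the prediction surface. The sketch below uses base R's `grDevices::contourLines()` on a made-up probability surface to show what a boundary at 0.5 means; the surface is a toy, and treating 0.5 as the value to draw is an assumption for illustration, not vistool's documented default.

```r
x1 <- seq(-1, 1, length.out = 50)
x2 <- seq(-1, 1, length.out = 50)

# Toy "predicted probability" surface: a logistic function of x1 + x2,
# stored as a matrix with rows indexed by x1 and columns by x2.
prob <- outer(x1, x2, function(a, b) 1 / (1 + exp(-5 * (a + b))))

# Extract the 0.5 contour; for this surface it runs along x1 + x2 = 0.
boundary <- grDevices::contourLines(x1, x2, prob, levels = 0.5)
length(boundary)
head(cbind(x1 = boundary[[1]]$x, x2 = boundary[[1]]$y))
```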

color

(`character(1)`)
Color of the boundary lines. Default is "black".

linetype

(`character(1)`)
Line type for boundaries. For 1D: ggplot2 linetypes. For 2D: contour line types. Default is "dashed".

linewidth

(`numeric(1)`)
Width of boundary lines. If NULL, uses theme$line_width.

alpha

(`numeric(1)`)
Alpha transparency of boundary lines. If NULL, uses theme$alpha.


Method plot()

Creates and returns the ggplot2 plot with model-specific layers.

Usage

VisualizerModel$plot(...)

Arguments

...

Additional arguments passed to the parent plot method.

Returns

A ggplot2 object.


Method clone()

The objects of this class are cloneable with this method.

Usage

VisualizerModel$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.