Source code for sklearn_nominal.sklearn.tree_base

from math import floor

import numpy as np

from sklearn_nominal.backend.core import ColumnType, Dataset
from sklearn_nominal.tree.pruning import PruneCriteria

from .. import shared, tree


[docs] class BaseTree: """Base class for decision trees with native nominal attribute support. This class serves as a shared foundation for both classification and regression trees, managing the hyperparameters that control tree growth, node splitting, and pruning strategies. Architectural Context --------------------- `BaseTree` coordinates the high-level tree-building process. It acts as a configuration hub that translates scikit-learn compatible parameters into the specialized components required by the internal backend trainers. Specifically, it handles: 1. **Splitter Construction**: Mapping numerical and nominal features to appropriate `ColumnError` scorers. 2. **Pruning Configuration**: Translating user-defined constraints (like `max_depth` or `min_samples_leaf`) into a `PruneCriteria` object. 3. **Attribute Penalization**: Implementing strategies like Gain Ratio to compensate for the bias of many-valued nominal attributes. Examples -------- >>> from sklearn_nominal.sklearn.tree_base import BaseTree >>> class MyTree(BaseTree): ... def __init__(self, **kwargs): ... super().__init__(**kwargs) ... def fit(self, X, y): ... # implementation that uses self.build_splitter() ... # and self.build_prune_criteria() ... pass Parameters ---------- criterion : str, default="" The function to measure the quality of a split. Supported criteria are "gini" for Gini impurity, and "entropy" or "gain_ratio" for Shannon information gain. splitter : str or int, default="best" The strategy used to choose the split at each numeric node. - "best": evaluates all possible split points. - int: limits the maximum number of splits to consider per node. max_depth : int, optional The maximum depth of the tree. If None, nodes are expanded until all leaves are pure or contain fewer than `min_samples_split` samples. min_samples_split : int or float, default=2 The minimum number of samples required to split an internal node. - If int, treated as a constant count. - If float, treated as a fraction: `ceil(min_samples_split * n_samples)`. min_samples_leaf : int or float, default=1 The minimum number of samples required to be at a leaf node. - If int, treated as a constant count. - If float, treated as a fraction: `ceil(min_samples_leaf * n_samples)`. min_error_decrease : float, default=0.0 A node will be split if the split induces a decrease of the error greater than or equal to this value. Notes ----- The weighted error decrease is calculated as: ``Δerror = error - Σ_i (N_i / N) * error_i`` where ``N`` is the total samples, ``N_i`` is the samples in branch ``i``, and ``error_i`` is the error in that branch. References ---------- .. [1] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification and Regression Trees", Wadsworth, Belmont, CA, 1984. .. [2] T. Hastie, R. Tibshirani and J. Friedman. "Elements of Statistical Learning", Springer, 2009. """ def __init__( self, criterion="", splitter="best", max_depth: int | None = None, min_samples_split: int | float = 2, min_samples_leaf=1, min_error_decrease=0.0, attribute_penalization_importance=1.0, nominal_split="multi", ): """Initializes the BaseTree with growth and splitting parameters. Parameters ---------- criterion : str Splitting quality measure. splitter : str or int Numeric node splitting strategy. max_depth : int, optional Maximum tree depth. min_samples_split : int or float Minimum samples per internal node. min_samples_leaf : int or float Minimum samples per leaf node. min_error_decrease : float Minimum required error reduction for splitting. nominal_split : str, default="multi" The strategy used to split nominal attributes. - "multi": creates one branch for each unique value. - "binary": creates a binary (One-vs-Rest) split for each node. """ self.criterion = criterion self.splitter = splitter self.max_depth = max_depth self.min_samples_leaf = min_samples_leaf self.min_samples_split = min_samples_split self.min_error_decrease = min_error_decrease self.nominal_split = nominal_split self.attribute_penalization_importance = attribute_penalization_importance
[docs] def build_attribute_penalizer(self): """Determines the penalization strategy for multi-valued attributes. This is used to implement "gain_ratio", which penalizes nominal attributes with many levels to prevent overfitting. Returns ------- sklearn_nominal.shared.ColumnPenalization The penalization strategy (e.g., `GainRatioPenalization` or `NoPenalization`). """ if self.criterion == "gain_ratio": return shared.DivisionInfoPenalization(self.attribute_penalization_importance) else: return shared.NoPenalization()
[docs] def build_splitter(self, e: shared.TargetError, p: shared.ColumnPenalization): """Constructs specialized split scorers for different column types. This method maps the general `splitter` and `criterion` parameters into concrete `ColumnError` implementations for both Numeric and Nominal features. Detailed Method Explanation --------------------------- For Numeric columns, it creates a `NumericColumnError` which can be limited by `max_evals` if the `splitter` parameter is an integer. For Nominal columns, it creates a `NominalColumnError` (multi-way split) or `BinaryNominalColumnError` (binary split) depending on the `nominal_split` parameter. Parameters ---------- e : sklearn_nominal.shared.TargetError The target error function (e.g., Gini, Entropy, or MSE). p : sklearn_nominal.shared.ColumnPenalization The column-level penalization strategy. Returns ------- dict A dictionary mapping `ColumnType` to its corresponding `ColumnError` scorer. Raises ------ ValueError If the `splitter` value is neither "best" nor an integer, or if `nominal_split` is invalid. """ if self.splitter == "best": max_evals = np.iinfo(np.int64).max elif isinstance(self.splitter, int): max_evals = self.splitter else: raise ValueError(f"Invalid value '{self.splitter}' for splitter; expected integer or 'best'") if self.nominal_split == "multi": nominal_scorer = shared.NominalColumnError(e, p) elif self.nominal_split == "binary": nominal_scorer = shared.BinaryNominalColumnError(e, p) else: raise ValueError(f"Invalid value '{self.nominal_split}' for nominal_split; expected 'multi' or 'binary'") scorers = { ColumnType.Numeric: shared.NumericColumnError(e, p, max_evals=max_evals), ColumnType.Nominal: nominal_scorer, } return scorers
[docs] def build_prune_criteria(self, d: Dataset) -> PruneCriteria: """Translates tree constraints into internal pruning criteria. This method converts user-facing parameters (which can be counts or fractions) into the absolute integer values required by the tree builders. Detailed Method Explanation --------------------------- - **Fractional conversion**: If `min_samples_leaf` or `min_samples_split` are floats, they are multiplied by the total number of samples in dataset `d` and floored. - **Constraint packaging**: All constraints are encapsulated into a `PruneCriteria` object which is passed to the recursive tree trainer. Parameters ---------- d : Dataset The dataset used to calculate relative sample counts. Returns ------- PruneCriteria The consolidated criteria used for pruning. """ min_samples_leaf = self.min_samples_leaf if isinstance(min_samples_leaf, float): min_samples_leaf = int(floor(d.n * min_samples_leaf)) min_samples_split = self.min_samples_split if isinstance(min_samples_split, float): min_samples_split = int(floor(d.n * min_samples_split)) return PruneCriteria( max_height=self.max_depth, min_samples_leaf=min_samples_leaf, min_error_decrease=self.min_error_decrease, min_samples_split=min_samples_split, )
[docs] def pretty_print(self, class_names=None): """Returns a human-readable string representation of the tree. Parameters ---------- class_names : list of str, optional The names of the classes for display. Returns ------- str The pretty-printed tree structure. """ return self.model_.pretty_print(class_names=class_names)
[docs] def export_dot(self, class_names=None, title=""): """Exports the tree as a Graphviz dot string. Parameters ---------- class_names : list of str, optional The names of the classes for display. title : str, default="" The title for the graph. Returns ------- str The tree in Graphviz dot format. """ return tree.export_dot(self.model_, title=title, class_names=class_names)
[docs] def export_dot_file(self, filepath, class_names=None, title=""): """Exports the tree as a Graphviz dot file. Parameters ---------- filepath : str The path to the file to save. class_names : list of str, optional The names of the classes for display. title : str, default="" The title for the graph. """ tree.export_dot_file(self.model_, filepath, title=title, class_names=class_names)
[docs] def export_image(self, filepath, class_names=None, title=""): """Exports the tree as an image file. This requires Graphviz to be installed on the system. Parameters ---------- filepath : str The path to the image file (e.g., "tree.png"). class_names : list of str, optional The names of the classes for display. title : str, default="" The title for the graph. """ tree.export_image(self.model_, filepath, title=title, class_names=class_names)
[docs] def display(self, class_names=None, title=""): """Displays the tree using the default system viewer or notebook output. Parameters ---------- class_names : list of str, optional The names of the classes for display. title : str, default="" The title for the graph. Returns ------- Any The image object for display in interactive environments. """ return tree.display(self.model_, title=title, class_names=class_names)