Source code for imblearn.utils.validation

"""Utilities for input validation"""

# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT
import warnings
from collections import Counter
from numbers import Real, Integral

import numpy as np

from sklearn.neighbors.base import KNeighborsMixin
from sklearn.neighbors import NearestNeighbors
from sklearn.externals import six, joblib
from sklearn.utils import deprecated
from sklearn.utils.multiclass import type_of_target

from ..exceptions import raise_isinstance_error

SAMPLING_KIND = ('over-sampling', 'under-sampling', 'clean-sampling',
                 'ensemble')
TARGET_KIND = ('binary', 'multiclass')


[docs]def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): """Check the objects is consistent to be a NN. Several methods in imblearn relies on NN. Until version 0.4, these objects can be passed at initialisation as an integer or a KNeighborsMixin. After only KNeighborsMixin will be accepted. This utility allows for type checking and raise if the type is wrong. Parameters ---------- nn_name : str, The name associated to the object to raise an error if needed. nn_object : int or KNeighborsMixin, The object to be checked additional_neighbor : int, optional (default=0) Sometimes, some algorithm need an additional neighbors. Returns ------- nn_object : KNeighborsMixin The k-NN object. """ if isinstance(nn_object, Integral): return NearestNeighbors(n_neighbors=nn_object + additional_neighbor) elif isinstance(nn_object, KNeighborsMixin): return nn_object else: raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object)
def check_target_type(y): """Check the target types to be conform to the current samplers. The current samplers should be compatible with ``'binary'`` and ``'multiclass'`` targets only. Parameters ---------- y : ndarray, The array containing the target Returns ------- y : ndarray, The returned target. """ if type_of_target(y) not in TARGET_KIND: # FIXME: perfectly we should raise an error but the sklearn API does # not allow for it warnings.warn("'y' should be of types {} only. Got {} instead.".format( TARGET_KIND, type_of_target(y))) return y
[docs]def hash_X_y(X, y, n_samples=1000): """Compute hash of the input arrays. Parameters ---------- X : ndarray, shape (n_samples, n_features) The ``X`` array. y : ndarray, shape (n_samples) Returns ------- X_hash: str Hash identifier of the ``X`` matrix. y_hash: str Hash identifier of the ``y`` matrix. """ rng = np.random.RandomState(0) raw_idx = rng.randint(X.shape[0], size=n_samples) col_idx = rng.randint(X.shape[1], size=n_samples) return joblib.hash(X[raw_idx, col_idx]), joblib.hash(y[raw_idx])
def _ratio_all(y, sampling_type): """Returns ratio by targeting all classes.""" target_stats = Counter(y) if sampling_type == 'over-sampling': n_sample_majority = max(target_stats.values()) ratio = {key: n_sample_majority - value for (key, value) in target_stats.items()} elif (sampling_type == 'under-sampling' or sampling_type == 'clean-sampling'): n_sample_minority = min(target_stats.values()) ratio = {key: n_sample_minority for key in target_stats.keys()} return ratio def _ratio_majority(y, sampling_type): """Returns ratio by targeting the majority class only.""" if sampling_type == 'over-sampling': raise ValueError("'ratio'='majority' cannot be used with" " over-sampler.") elif (sampling_type == 'under-sampling' or sampling_type == 'clean-sampling'): target_stats = Counter(y) class_majority = max(target_stats, key=target_stats.get) n_sample_minority = min(target_stats.values()) ratio = {key: n_sample_minority for key in target_stats.keys() if key == class_majority} return ratio def _ratio_not_minority(y, sampling_type): """Returns ratio by targeting all classes but not the minority.""" target_stats = Counter(y) if sampling_type == 'over-sampling': n_sample_majority = max(target_stats.values()) class_minority = min(target_stats, key=target_stats.get) ratio = {key: n_sample_majority - value for (key, value) in target_stats.items() if key != class_minority} elif (sampling_type == 'under-sampling' or sampling_type == 'clean-sampling'): n_sample_minority = min(target_stats.values()) class_minority = min(target_stats, key=target_stats.get) ratio = {key: n_sample_minority for key in target_stats.keys() if key != class_minority} return ratio def _ratio_minority(y, sampling_type): """Returns ratio by targeting the minority class only.""" target_stats = Counter(y) if sampling_type == 'over-sampling': n_sample_majority = max(target_stats.values()) class_minority = min(target_stats, key=target_stats.get) ratio = {key: n_sample_majority - value for (key, value) in target_stats.items() if key == class_minority} elif (sampling_type == 'under-sampling' or sampling_type == 'clean-sampling'): raise ValueError("'ratio'='minority' cannot be used with" " under-sampler and clean-sampler.") return ratio def _ratio_auto(y, sampling_type): """Returns ratio auto for over-sampling and not-minority for under-sampling.""" if sampling_type == 'over-sampling': return _ratio_all(y, sampling_type) elif (sampling_type == 'under-sampling' or sampling_type == 'clean-sampling'): return _ratio_not_minority(y, sampling_type) def _ratio_dict(ratio, y, sampling_type): """Returns ratio by converting the dictionary depending of the sampling.""" target_stats = Counter(y) # check that all keys in ratio are also in y set_diff_ratio_target = set(ratio.keys()) - set(target_stats.keys()) if len(set_diff_ratio_target) > 0: raise ValueError("The {} target class is/are not present in the" " data.".format(set_diff_ratio_target)) ratio_ = {} if sampling_type == 'over-sampling': n_samples_majority = max(target_stats.values()) class_majority = max(target_stats, key=target_stats.get) for class_sample, n_samples in ratio.items(): if n_samples < target_stats[class_sample]: raise ValueError("With over-sampling methods, the number" " of samples in a class should be greater" " or equal to the original number of samples." " Originally, there is {} samples and {}" " samples are asked.".format( target_stats[class_sample], n_samples)) if n_samples > n_samples_majority: warnings.warn("After over-sampling, the number of samples ({})" " in class {} will be larger than the number of" " samples in the majority class (class #{} ->" " {})".format(n_samples, class_sample, class_majority, n_samples_majority)) ratio_[class_sample] = n_samples - target_stats[class_sample] elif sampling_type == 'under-sampling': for class_sample, n_samples in ratio.items(): if n_samples > target_stats[class_sample]: raise ValueError("With under-sampling methods, the number of" " samples in a class should be less or equal" " to the original number of samples." " Originally, there is {} samples and {}" " samples are asked.".format( target_stats[class_sample], n_samples)) ratio_[class_sample] = n_samples elif sampling_type == 'clean-sampling': # clean-sampling can be more permissive since those samplers do not # use samples for class_sample, n_samples in ratio.items(): ratio_[class_sample] = n_samples return ratio_ @deprecated("Use a float for 'ratio' is deprecated from version 0.2." " The support will be removed in 0.4. Use a dict, str," " or a callable instead.") def _ratio_float(ratio, y, sampling_type): """TODO: Deprecated in 0.2. Remove in 0.4.""" target_stats = Counter(y) if sampling_type == 'over-sampling': n_sample_majority = max(target_stats.values()) class_majority = max(target_stats, key=target_stats.get) ratio = {key: int(n_sample_majority * ratio - value) for (key, value) in target_stats.items() if key != class_majority} elif (sampling_type == 'under-sampling' or sampling_type == 'clean-sampling'): n_sample_minority = min(target_stats.values()) class_minority = min(target_stats, key=target_stats.get) ratio = {key: int(n_sample_minority / ratio) for (key, value) in target_stats.items() if key != class_minority} return ratio
[docs]def check_ratio(ratio, y, sampling_type): """Ratio validation for samplers. Checks ratio for consistent type and return a dictionary containing each targeted class with its corresponding number of pixel. Parameters ---------- ratio : str, dict or callable, Ratio to use for resampling the data set. - If ``str``, has to be one of: (i) ``'minority'``: resample the minority class; (ii) ``'majority'``: resample the majority class, (iii) ``'not minority'``: resample all classes apart of the minority class, (iv) ``'all'``: resample all classes, and (v) ``'auto'``: correspond to ``'all'`` with for over-sampling methods and ``'not minority'`` for under-sampling methods. The classes targeted will be over-sampled or under-sampled to achieve an equal number of sample with the majority or minority class. - If ``dict``, the keys correspond to the targeted classes. The values correspond to the desired number of samples. - If callable, function taking ``y`` and returns a ``dict``. The keys correspond to the targeted classes. The values correspond to the desired number of samples. y : ndarray, shape (n_samples,) The target array. sampling_type : str, The type of sampling. Can be either ``'over-sampling'`` or ``'under-sampling'``. Returns ------- ratio_converted : dict, The converted and validated ratio. Returns a dictionary with the key being the class target and the value being the desired number of samples. """ if sampling_type not in SAMPLING_KIND: raise ValueError("'sampling_type' should be one of {}. Got '{}'" " instead.".format(SAMPLING_KIND, sampling_type)) if np.unique(y).size <= 1: raise ValueError("The target 'y' needs to have more than 1 class." " Got {} class instead".format(np.unique(y).size)) if sampling_type == 'ensemble': return ratio if isinstance(ratio, six.string_types): if ratio not in RATIO_KIND.keys(): raise ValueError("When 'ratio' is a string, it needs to be one of" " {}. Got '{}' instead.".format(RATIO_KIND, ratio)) return RATIO_KIND[ratio](y, sampling_type) elif isinstance(ratio, dict): return _ratio_dict(ratio, y, sampling_type) elif isinstance(ratio, Real): if ratio <= 0 or ratio > 1: raise ValueError("When 'ratio' is a float, it should in the range" " (0, 1]. Got {} instead.".format(ratio)) return _ratio_float(ratio, y, sampling_type) elif callable(ratio): ratio_ = ratio(y) return _ratio_dict(ratio_, y, sampling_type)
RATIO_KIND = {'minority': _ratio_minority, 'majority': _ratio_majority, 'not minority': _ratio_not_minority, 'all': _ratio_all, 'auto': _ratio_auto}