# Source code for imblearn.utils.estimator_checks

"""Utils to check the samplers and compatibility with scikit-learn"""

# Adapted from scikit-learn
# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT

from __future__ import division

import sys
import traceback

from collections import Counter

import numpy as np

from sklearn.datasets import make_classification
from sklearn.utils.estimator_checks import _yield_all_checks \
    as sklearn_yield_all_checks, check_estimator \
    as sklearn_check_estimator, check_parameters_default_constructible
from sklearn.exceptions import NotFittedError
from sklearn.utils.testing import (assert_warns, assert_raises_regex,
                                   assert_true, set_random_state,
                                   assert_equal)

from imblearn.base import SamplerMixin
from imblearn.over_sampling.base import BaseOverSampler
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
from imblearn.ensemble.base import BaseEnsembleSampler


def _yield_sampler_checks(name, Estimator):
    """Yield the sampler-specific checks to run against ``Estimator``.

    Parameters
    ----------
    name : str
        Name of the estimator class.

    Estimator : class
        Sampler class to check (not an instance).
    """
    sampler_checks = (
        check_target_type,
        check_samplers_one_label,
        check_samplers_no_fit_error,
        check_samplers_X_consistancy_sample,
        check_samplers_fit,
        check_samplers_fit_sample,
        check_samplers_ratio_fit_sample,
    )
    for check in sampler_checks:
        yield check

def _yield_all_checks(name, Estimator):
    """Yield every check applicable to ``Estimator``.

    Parameters
    ----------
    name : str
        Name of the estimator class.

    Estimator : class
        Class to check (not an instance).
    """
    # Bug fix: ``sklearn_yield_all_checks`` is a generator function, so
    # merely calling it discarded all of the scikit-learn common checks.
    # Iterate over it and re-yield each check so they actually run.
    for check in sklearn_yield_all_checks(name, Estimator):
        yield check
    # trigger our checks if this is a SamplerMixin
    if issubclass(Estimator, SamplerMixin):
        for check in _yield_sampler_checks(name, Estimator):
            yield check
    # FIXME already present in scikit-learn 0.19
    yield check_dont_overwrite_parameters


def check_estimator(Estimator):
    """Check if estimator adheres to scikit-learn conventions and
    imbalanced-learn

    This estimator will run an extensive test-suite for input validation,
    shapes, etc.
    Additional tests samplers if the Estimator inherits from the
    corresponding mixin from imblearn.base

    Parameters
    ----------
    Estimator : class
        Class to check. Estimator is a class object (not an instance).
    """
    name = Estimator.__name__
    # test scikit-learn compatibility
    sklearn_check_estimator(Estimator)
    check_parameters_default_constructible(name, Estimator)
    # run the combined scikit-learn + imbalanced-learn checks
    for check in _yield_all_checks(name, Estimator):
        check(name, Estimator)
def check_target_type(name, Estimator):
    """Check that fitting on a continuous-looking target raises a warning."""
    X = np.random.random((20, 2))
    y = np.linspace(0, 1, 20)
    estimator = Estimator()
    set_random_state(estimator)
    assert_warns(UserWarning, estimator.fit, X, y)


def check_multiclass_warning(name, Estimator):
    """Check that fitting on a multiclass target raises a warning."""
    X = np.random.random((20, 2))
    y = np.array([0] * 3 + [1] * 2 + [2] * 15)
    estimator = Estimator()
    set_random_state(estimator)
    assert_warns(UserWarning, estimator.fit, X, y)


def multioutput_estimator_convert_y_2d(name, y):
    """Reshape ``y`` to 2-D for multi-task estimators.

    Estimators in mono_output_task_error raise ValueError if y is of 1-D.
    Convert into a 2-D y for those estimators.
    """
    if "MultiTask" in name:
        return np.reshape(y, (-1, 1))
    return y


def is_public_parameter(attr):
    """Return True if ``attr`` has neither a leading nor a trailing ``_``."""
    return not (attr.startswith('_') or attr.endswith('_'))


def check_dont_overwrite_parameters(name, Estimator):
    """Check that ``fit`` only changes or sets private attributes."""
    if hasattr(Estimator.__init__, "deprecated_original"):
        # to not check deprecated classes
        return
    rnd = np.random.RandomState(0)
    X = 3 * rnd.uniform(size=(20, 3))
    # ``np.int`` was a deprecated (now removed) alias of the builtin int
    y = X[:, 0].astype(int)
    y = multioutput_estimator_convert_y_2d(name, y)
    estimator = Estimator()

    if hasattr(estimator, "n_components"):
        estimator.n_components = 1
    if hasattr(estimator, "n_clusters"):
        estimator.n_clusters = 1

    set_random_state(estimator, 1)
    dict_before_fit = estimator.__dict__.copy()
    estimator.fit(X, y)
    dict_after_fit = estimator.__dict__

    public_keys_after_fit = [key for key in dict_after_fit.keys()
                             if is_public_parameter(key)]

    attrs_added_by_fit = [key for key in public_keys_after_fit
                          if key not in dict_before_fit.keys()]

    # check that fit doesn't add any public attribute
    assert_true(not attrs_added_by_fit,
                ('Estimator adds public attribute(s) during'
                 ' the fit method.'
                 ' Estimators are only allowed to add private attributes'
                 ' either started with _ or ended'
                 ' with _ but %s added' % ', '.join(attrs_added_by_fit)))

    # check that fit doesn't change any public attribute
    attrs_changed_by_fit = [key for key in public_keys_after_fit
                            if (dict_before_fit[key]
                                is not dict_after_fit[key])]

    assert_true(not attrs_changed_by_fit,
                ('Estimator changes public attribute(s) during'
                 ' the fit method. Estimators are only allowed'
                 ' to change attributes started'
                 ' or ended with _, but'
                 ' %s changed' % ', '.join(attrs_changed_by_fit)))


def check_samplers_one_label(name, Sampler):
    """Check that fitting on a single-class target raises a ValueError."""
    error_string_fit = "Sampler can't balance when only one class is present."
    sampler = Sampler()
    X = np.random.random((20, 2))
    y = np.zeros(20)
    try:
        sampler.fit(X, y)
    except ValueError as e:
        if 'class' not in repr(e):
            print(error_string_fit, Sampler, e)
            traceback.print_exc(file=sys.stdout)
            raise e
        else:
            return
    except Exception as exc:
        # Bug fix: previously printed the ``traceback`` module object
        # instead of the offending Sampler class (cf. ValueError branch).
        print(error_string_fit, Sampler, exc)
        traceback.print_exc(file=sys.stdout)
        raise exc


def check_samplers_no_fit_error(name, Sampler):
    """Check that calling ``sample`` before ``fit`` raises NotFittedError."""
    sampler = Sampler()
    X = np.random.random((20, 2))
    y = np.array([1] * 5 + [0] * 15)
    assert_raises_regex(NotFittedError, "instance is not fitted yet.",
                        sampler.sample, X, y)


def check_samplers_X_consistancy_sample(name, Sampler):
    """Check that ``sample`` rejects data different from the fitted data."""
    sampler = Sampler()
    X = np.random.random((30, 2))
    y = np.array([1] * 20 + [0] * 10)
    sampler.fit(X, y)
    X_different = np.random.random((40, 2))
    # Bug fix: was ``y_different = y = ...`` which needlessly rebound ``y``
    y_different = np.array([1] * 25 + [0] * 15)
    assert_raises_regex(RuntimeError, "X and y need to be same array earlier",
                        sampler.sample, X_different, y_different)


def check_samplers_fit(name, Sampler):
    """Check that fitting sets the ``ratio_`` attribute."""
    sampler = Sampler()
    X = np.random.random((30, 2))
    y = np.array([1] * 20 + [0] * 10)
    sampler.fit(X, y)
    assert_true(hasattr(sampler, 'ratio_'))


def check_samplers_fit_sample(name, Sampler):
    """Check that ``fit_sample`` rebalances classes per the sampler family."""
    sampler = Sampler(random_state=0)
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    target_stats = Counter(y)
    X_res, y_res = sampler.fit_sample(X, y)
    if isinstance(sampler, BaseOverSampler):
        # every class is brought up to (at least) the majority count
        n_samples = max(target_stats.values())
        assert_true(all(value >= n_samples
                        for value in Counter(y_res).values()))
    elif isinstance(sampler, BaseUnderSampler):
        # every class is brought down to the minority count
        n_samples = min(target_stats.values())
        assert_true(all(value == n_samples
                        for value in Counter(y_res).values()))
    elif isinstance(sampler, BaseCleaningSampler):
        # all classes but the minority one lose samples
        target_stats_res = Counter(y_res)
        class_minority = min(target_stats, key=target_stats.get)
        assert_true(
            all(target_stats[class_sample] > target_stats_res[class_sample]
                for class_sample in target_stats.keys()
                if class_sample != class_minority))
    elif isinstance(sampler, BaseEnsembleSampler):
        # each subset of the ensemble is balanced to the minority count
        y_ensemble = y_res[0]
        n_samples = min(target_stats.values())
        assert_true(all(value == n_samples
                        for value in Counter(y_ensemble).values()))


def check_samplers_ratio_fit_sample(name, Sampler):
    """Check that a dict ``ratio`` leaves the untargeted class untouched."""
    # in this test we will force all samplers to not change the class 1
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    target_stats = Counter(y)
    sampler = Sampler(random_state=0)
    if isinstance(sampler, BaseOverSampler):
        ratio = {2: 498, 0: 498}
        sampler.set_params(ratio=ratio)
        X_res, y_res = sampler.fit_sample(X, y)
        assert_equal(target_stats[1], Counter(y_res)[1])
    elif isinstance(sampler, (BaseUnderSampler, BaseCleaningSampler)):
        # under-sampling and cleaning samplers share the same expectation
        ratio = {2: 201, 0: 201}
        sampler.set_params(ratio=ratio)
        X_res, y_res = sampler.fit_sample(X, y)
        assert_equal(target_stats[1], Counter(y_res)[1])
    elif isinstance(sampler, BaseEnsembleSampler):
        ratio = {2: 201, 0: 201}
        sampler.set_params(ratio=ratio)
        X_res, y_res = sampler.fit_sample(X, y)
        # only the first subset of the ensemble is inspected
        y_ensemble = y_res[0]
        assert_equal(target_stats[1], Counter(y_ensemble)[1])