# Source code for sklearn.pipeline

"""
The :mod:`sklearn.pipeline` module implements utilities to build a composite
estimator, as a chain of transforms and estimators.
"""
# Author: Edouard Duchesnay
#         Gael Varoquaux
#         Virgile Fritsch
#         Alexandre Gramfort
#         Lars Buitinck
# License: BSD

from collections import defaultdict
from warnings import warn
from abc import ABCMeta, abstractmethod

import numpy as np
from scipy import sparse

from .base import BaseEstimator, TransformerMixin
from .externals.joblib import Parallel, delayed
from .externals import six
from .utils import tosequence
from .utils.metaestimators import if_delegate_has_method

__all__ = ['Pipeline', 'FeatureUnion']


class _BasePipeline(six.with_metaclass(ABCMeta, BaseEstimator)):
    """Handles parameter management for classifiers composed of named steps.

    Subclasses store their steps as a sequence of ``(name, estimator)`` pairs
    in an attribute and pass that attribute's name as ``steps_attr`` to the
    helpers defined here.
    """

    @abstractmethod
    def __init__(self):
        pass

    def _replace_step(self, steps_attr, name, new_val):
        """Replace the estimator of the step called ``name`` with ``new_val``.

        The step sequence is copied into a new list, the first step whose
        name equals ``name`` is replaced, and the list is assigned back to
        ``steps_attr``. If no step matches, the copy is assigned unchanged.
        """
        # assumes `name` is a valid step name
        # list(...) rather than [:] so that the copy supports item assignment
        # even when the steps were supplied as a tuple.
        new_steps = list(getattr(self, steps_attr))
        for i, (step_name, _) in enumerate(new_steps):
            if step_name == name:
                new_steps[i] = (name, new_val)
                break
        setattr(self, steps_attr, new_steps)

    def _get_params(self, steps_attr, deep=True):
        """Return the constructor parameters, optionally with step parameters.

        With ``deep=False`` only the estimator's own parameters are returned.
        With ``deep=True`` the result also maps each step name to its
        estimator and each ``<step>__<param>`` key to the corresponding
        parameter of that step; steps set to ``None`` contribute no nested
        parameters.
        """
        out = super(_BasePipeline, self).get_params(deep=False)
        if not deep:
            return out
        steps = getattr(self, steps_attr)
        out.update(steps)
        for name, estimator in steps:
            if estimator is None:
                continue
            for key, value in six.iteritems(estimator.get_params(deep=True)):
                out['%s__%s' % (name, key)] = value
        return out

    def _set_params(self, steps_attr, **params):
        """Set parameters, handling whole-list and per-step replacement first.

        Returns
        -------
        self
        """
        # Ensure strict ordering of parameter setting:
        # 1. All steps
        if steps_attr in params:
            setattr(self, steps_attr, params.pop(steps_attr))
        # 2. Step replacement
        # Collect the names without unpacking zip(*steps), which fails with
        # an opaque ValueError when the step sequence is empty.
        step_names = [step_name for step_name, _ in getattr(self, steps_attr)]
        for name in list(six.iterkeys(params)):
            if '__' not in name and name in step_names:
                self._replace_step(steps_attr, name, params.pop(name))
        # 3. Step parameters and other initialisation arguments
        super(_BasePipeline, self).set_params(**params)
        return self

    def _validate_names(self, names):
        """Check that step names are usable as parameter names.

        Raises
        ------
        ValueError
            If names are duplicated, clash with a constructor argument of
            this estimator, or contain ``__``.
        """
        if len(set(names)) != len(names):
            raise ValueError('Names provided are not unique: '
                             '{0!r}'.format(list(names)))
        invalid_names = set(names).intersection(self.get_params(deep=False))
        if invalid_names:
            raise ValueError('Step names conflict with constructor arguments: '
                             '{0!r}'.format(sorted(invalid_names)))
        invalid_names = [name for name in names if '__' in name]
        if invalid_names:
            raise ValueError('Step names must not contain __: got '
                             '{0!r}'.format(invalid_names))


class Pipeline(_BasePipeline):
    """Pipeline of transforms with a final estimator.

    Sequentially apply a list of transforms and a final estimator.
    Intermediate steps of the pipeline must be 'transforms', that is, they
    must implement fit and transform methods.
    The final estimator only needs to implement fit.

    The purpose of the pipeline is to assemble several steps that can be
    cross-validated together while setting different parameters.
    For this, it enables setting parameters of the various steps using their
    names and the parameter name separated by a '__', as in the example below.
    A step's estimator may be replaced entirely by setting the parameter
    with its name to another estimator, or a transformer removed by setting
    to None.

    Read more in the :ref:`User Guide <pipeline>`.

    Parameters
    ----------
    steps : list
        List of (name, transform) tuples (implementing fit/transform) that are
        chained, in the order in which they are chained, with the last object
        an estimator.

    Attributes
    ----------
    named_steps : dict
        Read-only attribute to access any step parameter by user given name.
        Keys are step names and values are steps parameters.

    Examples
    --------
    >>> from sklearn import svm
    >>> from sklearn.datasets import samples_generator
    >>> from sklearn.feature_selection import SelectKBest
    >>> from sklearn.feature_selection import f_regression
    >>> from sklearn.pipeline import Pipeline
    >>> # generate some data to play with
    >>> X, y = samples_generator.make_classification(
    ...     n_informative=5, n_redundant=0, random_state=42)
    >>> # ANOVA SVM-C
    >>> anova_filter = SelectKBest(f_regression, k=5)
    >>> clf = svm.SVC(kernel='linear')
    >>> anova_svm = Pipeline([('anova', anova_filter), ('svc', clf)])
    >>> # You can set the parameters using the names issued
    >>> # For instance, fit using a k of 10 in the SelectKBest
    >>> # and a parameter 'C' of the svm
    >>> anova_svm.set_params(anova__k=10, svc__C=.1).fit(X, y)
    ...                                              # doctest: +ELLIPSIS
    Pipeline(steps=[...])
    >>> prediction = anova_svm.predict(X)
    >>> anova_svm.score(X, y)                        # doctest: +ELLIPSIS
    0.829...
    >>> # getting the selected features chosen by anova_filter
    >>> anova_svm.named_steps['anova'].get_support()
    ... # doctest: +NORMALIZE_WHITESPACE
    array([False, False,  True,  True, False, False, True,  True, False,
           True,  False,  True,  True, False, True,  False, True, True,
           False, False], dtype=bool)
    """

    # BaseEstimator interface

    def __init__(self, steps):
        # shallow copy of steps
        self.steps = tosequence(steps)
        self._validate_steps()

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep: boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        return self._get_params('steps', deep=deep)

    def set_params(self, **kwargs):
        """Set the parameters of this estimator.

        Valid parameter keys can be listed with ``get_params()``.

        Returns
        -------
        self
        """
        self._set_params('steps', **kwargs)
        return self

    def _validate_steps(self):
        """Check step names and that each step offers the required methods.

        Raises
        ------
        TypeError
            If an intermediate step lacks ``transform`` or lacks both ``fit``
            and ``fit_transform``, or if the last step is not ``None`` and
            lacks ``fit``.
        """
        names, estimators = zip(*self.steps)

        # validate names
        self._validate_names(names)

        # validate estimators
        transformers = estimators[:-1]
        estimator = estimators[-1]

        for t in transformers:
            if t is None:
                continue
            if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or
                    not hasattr(t, "transform")):
                raise TypeError("All intermediate steps should be "
                                "transformers and implement fit and transform."
                                " '%s' (type %s) doesn't" % (t, type(t)))

        # We allow last estimator to be None as an identity transformation
        if estimator is not None and not hasattr(estimator, "fit"):
            raise TypeError("Last step of Pipeline should implement fit. "
                            "'%s' (type %s) doesn't"
                            % (estimator, type(estimator)))

    @property
    def _estimator_type(self):
        return self.steps[-1][1]._estimator_type

    @property
    def named_steps(self):
        return dict(self.steps)

    @property
    def _final_estimator(self):
        return self.steps[-1][1]

    # Estimator interface

    def _fit(self, X, y=None, **fit_params):
        """Fit and apply every step but the last one.

        Returns the transformed data together with the fit parameters
        addressed to the final step (an empty dict when the final step is
        ``None``).

        Raises
        ------
        ValueError
            If a fit parameter name does not contain ``__``.
        """
        self._validate_steps()
        fit_params_steps = dict((name, {}) for name, step in self.steps
                                if step is not None)
        for pname, pval in six.iteritems(fit_params):
            if '__' not in pname:
                # Splitting such a name would fail with an opaque unpacking
                # error; raise the same exception type with an explanation.
                raise ValueError(
                    "Pipeline.fit does not accept the {0} parameter. "
                    "You can pass parameters to specific steps of your "
                    "pipeline using the stepname__parameter format."
                    .format(pname))
            step, param = pname.split('__', 1)
            fit_params_steps[step][param] = pval
        Xt = X
        for name, transform in self.steps[:-1]:
            if transform is None:
                pass
            elif hasattr(transform, "fit_transform"):
                Xt = transform.fit_transform(Xt, y, **fit_params_steps[name])
            else:
                Xt = transform.fit(Xt, y, **fit_params_steps[name]) \
                              .transform(Xt)
        if self._final_estimator is None:
            return Xt, {}
        return Xt, fit_params_steps[self.steps[-1][0]]

    def fit(self, X, y=None, **fit_params):
        """Fit the model

        Fit all the transforms one after the other and transform the
        data, then fit the transformed data using the final estimator.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of the
            pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps of
            the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        self : Pipeline
            This estimator
        """
        Xt, fit_params = self._fit(X, y, **fit_params)
        if self._final_estimator is not None:
            self._final_estimator.fit(Xt, y, **fit_params)
        return self

    def fit_transform(self, X, y=None, **fit_params):
        """Fit the model and transform with the final estimator

        Fits all the transforms one after the other and transforms the
        data, then uses fit_transform on transformed data with the final
        estimator.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of the
            pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps of
            the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        Xt : array-like, shape = [n_samples, n_transformed_features]
            Transformed samples
        """
        last_step = self._final_estimator
        Xt, fit_params = self._fit(X, y, **fit_params)
        if hasattr(last_step, 'fit_transform'):
            return last_step.fit_transform(Xt, y, **fit_params)
        elif last_step is None:
            return Xt
        else:
            return last_step.fit(Xt, y, **fit_params).transform(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def predict(self, X):
        """Apply transforms to the data, and predict with the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_pred : array-like
        """
        Xt = X
        for name, transform in self.steps[:-1]:
            if transform is not None:
                Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def fit_predict(self, X, y=None, **fit_params):
        """Applies fit_predict of last step in pipeline after transforms.

        Applies fit_transforms of a pipeline to the data, followed by the
        fit_predict method of the final estimator in the pipeline. Valid
        only if the final estimator implements fit_predict.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of
            the pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps
            of the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        y_pred : array-like
        """
        Xt, fit_params = self._fit(X, y, **fit_params)
        return self.steps[-1][-1].fit_predict(Xt, y, **fit_params)

    @if_delegate_has_method(delegate='_final_estimator')
    def predict_proba(self, X):
        """Apply transforms, and predict_proba of the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_proba : array-like, shape = [n_samples, n_classes]
        """
        Xt = X
        for name, transform in self.steps[:-1]:
            if transform is not None:
                Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict_proba(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def decision_function(self, X):
        """Apply transforms, and decision_function of the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_score : array-like, shape = [n_samples, n_classes]
        """
        Xt = X
        for name, transform in self.steps[:-1]:
            if transform is not None:
                Xt = transform.transform(Xt)
        return self.steps[-1][-1].decision_function(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def predict_log_proba(self, X):
        """Apply transforms, and predict_log_proba of the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_score : array-like, shape = [n_samples, n_classes]
        """
        Xt = X
        for name, transform in self.steps[:-1]:
            if transform is not None:
                Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict_log_proba(Xt)

    @property
    def transform(self):
        """Apply transforms, and transform with the final estimator

        This also works where final estimator is ``None``: all prior
        transformations are applied.

        Parameters
        ----------
        X : iterable
            Data to transform. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        Xt : array-like, shape = [n_samples, n_transformed_features]
        """
        # _final_estimator is None or has transform, otherwise attribute error
        if self._final_estimator is not None:
            self._final_estimator.transform
        return self._transform

    def _transform(self, X):
        """Pass ``X`` through ``transform`` of every step that is not None."""
        Xt = X
        for name, transform in self.steps:
            if transform is not None:
                Xt = transform.transform(Xt)
        return Xt

    @property
    def inverse_transform(self):
        """Apply inverse transformations in reverse order

        All estimators in the pipeline must support ``inverse_transform``.

        Parameters
        ----------
        Xt : array-like, shape = [n_samples, n_transformed_features]
            Data samples, where ``n_samples`` is the number of samples and
            ``n_features`` is the number of features. Must fulfill
            input requirements of last step of pipeline's
            ``inverse_transform`` method.

        Returns
        -------
        Xt : array-like, shape = [n_samples, n_features]
        """
        # raise AttributeError if necessary for hasattr behaviour
        for name, transform in self.steps:
            if transform is not None:
                transform.inverse_transform
        return self._inverse_transform

    def _inverse_transform(self, X):
        """Apply ``inverse_transform`` of each non-None step, last step first.

        An input with ``ndim == 1`` is reshaped to a single row and a
        ``FutureWarning`` is issued.
        """
        if hasattr(X, 'ndim') and X.ndim == 1:
            warn("From version 0.19, a 1d X will not be reshaped in"
                 " pipeline.inverse_transform any more.", FutureWarning)
            X = X[None, :]
        Xt = X
        for name, transform in self.steps[::-1]:
            if transform is not None:
                Xt = transform.inverse_transform(Xt)
        return Xt

    @if_delegate_has_method(delegate='_final_estimator')
    def score(self, X, y=None):
        """Apply transforms, and score with the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        y : iterable, default=None
            Targets used for scoring. Must fulfill label requirements for all
            steps of the pipeline.

        Returns
        -------
        score : float
        """
        Xt = X
        for name, transform in self.steps[:-1]:
            if transform is not None:
                Xt = transform.transform(Xt)
        return self.steps[-1][-1].score(Xt, y)

    @property
    def classes_(self):
        return self.steps[-1][-1].classes_

    @property
    def _pairwise(self):
        # check if first estimator expects pairwise input
        return getattr(self.steps[0][1], '_pairwise', False)


def _name_estimators(estimators):
    """Generate names for estimators.

    Each name is the lowercased type name of the estimator. Names that occur
    more than once get a ``-<k>`` suffix, numbered from 1 in input order.
    Returns a list of ``(name, estimator)`` tuples.
    """

    names = [type(estimator).__name__.lower() for estimator in estimators]
    namecount = defaultdict(int)
    for name in names:
        namecount[name] += 1

    # Only names that occur more than once need a numeric suffix.
    for k, v in list(six.iteritems(namecount)):
        if v == 1:
            del namecount[k]

    # Walk backwards so the remaining count is the suffix of each occurrence.
    for i in reversed(range(len(estimators))):
        name = names[i]
        if name in namecount:
            names[i] += "-%d" % namecount[name]
            namecount[name] -= 1

    return list(zip(names, estimators))


def make_pipeline(*steps):
    """Construct a Pipeline from the given estimators.

    This is a shorthand for the Pipeline constructor; it does not require, and
    does not permit, naming the estimators. Instead, their names will be set
    to the lowercase of their types automatically.

    Examples
    --------
    >>> from sklearn.naive_bayes import GaussianNB
    >>> from sklearn.preprocessing import StandardScaler
    >>> make_pipeline(StandardScaler(), GaussianNB(priors=None))
    ...     # doctest: +NORMALIZE_WHITESPACE
    Pipeline(steps=[('standardscaler',
                     StandardScaler(copy=True, with_mean=True, with_std=True)),
                    ('gaussiannb', GaussianNB(priors=None))])

    Returns
    -------
    p : Pipeline
    """
    return Pipeline(_name_estimators(steps))


def _fit_one_transformer(transformer, X, y):
    """Fit ``transformer`` on ``(X, y)`` and return the result of ``fit``."""
    return transformer.fit(X, y)


def _transform_one(transformer, name, weight, X):
    """Transform ``X`` and scale the output by ``weight`` unless it is None."""
    res = transformer.transform(X)
    # if we have a weight for this transformer, multiply output
    if weight is None:
        return res
    return res * weight


def _fit_transform_one(transformer, name, weight, X, y,
                       **fit_params):
    """Fit and transform ``X``; return ``(output, transformer)``.

    The output is scaled by ``weight`` unless ``weight`` is None.
    """
    if hasattr(transformer, 'fit_transform'):
        res = transformer.fit_transform(X, y, **fit_params)
    else:
        res = transformer.fit(X, y, **fit_params).transform(X)
    # if we have a weight for this transformer, multiply output
    if weight is None:
        return res, transformer
    return res * weight, transformer


class FeatureUnion(_BasePipeline, TransformerMixin):
    """Concatenates results of multiple transformer objects.

    This estimator applies a list of transformer objects in parallel to the
    input data, then concatenates the results. This is useful to combine
    several feature extraction mechanisms into a single transformer.

    Parameters of the transformers may be set using its name and the parameter
    name separated by a '__'. A transformer may be replaced entirely by
    setting the parameter with its name to another transformer,
    or removed by setting to ``None``.

    Read more in the :ref:`User Guide <feature_union>`.

    Parameters
    ----------
    transformer_list: list of (string, transformer) tuples
        List of transformer objects to be applied to the data. The first
        half of each tuple is the name of the transformer.

    n_jobs: int, optional
        Number of jobs to run in parallel (default 1).

    transformer_weights: dict, optional
        Multiplicative weights for features per transformer.
        Keys are transformer names, values the weights.

    """
    def __init__(self, transformer_list, n_jobs=1, transformer_weights=None):
        self.transformer_list = tosequence(transformer_list)
        self.n_jobs = n_jobs
        self.transformer_weights = transformer_weights
        self._validate_transformers()

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep: boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        return self._get_params('transformer_list', deep=deep)

    def set_params(self, **kwargs):
        """Set the parameters of this estimator.

        Valid parameter keys can be listed with ``get_params()``.

        Returns
        -------
        self
        """
        self._set_params('transformer_list', **kwargs)
        return self

    def _validate_transformers(self):
        """Check transformer names and that each one can fit and transform.

        Raises
        ------
        TypeError
            If a transformer that is not ``None`` lacks ``transform`` or
            lacks both ``fit`` and ``fit_transform``.
        """
        names, transformers = zip(*self.transformer_list)

        # validate names
        self._validate_names(names)

        # validate estimators
        for t in transformers:
            if t is None:
                continue
            if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or
                    not hasattr(t, "transform")):
                raise TypeError("All estimators should implement fit and "
                                "transform. '%s' (type %s) doesn't" %
                                (t, type(t)))

    def _iter(self):
        """Generate (name, est, weight) tuples excluding None transformers
        """
        get_weight = (self.transformer_weights or {}).get
        return ((name, trans, get_weight(name))
                for name, trans in self.transformer_list
                if trans is not None)

    def get_feature_names(self):
        """Get feature names from all transformers.

        Returns
        -------
        feature_names : list of strings
            Names of the features produced by transform.
        """
        feature_names = []
        for name, trans, weight in self._iter():
            if not hasattr(trans, 'get_feature_names'):
                raise AttributeError("Transformer %s (type %s) does not "
                                     "provide get_feature_names."
                                     % (str(name), type(trans).__name__))
            feature_names.extend([name + "__" + f for f in
                                  trans.get_feature_names()])
        return feature_names

    def fit(self, X, y=None):
        """Fit all transformers using X.

        Parameters
        ----------
        X : iterable or array-like, depending on transformers
            Input data, used to fit transformers.

        y : array-like, shape (n_samples, ...), optional
            Targets for supervised learning.

        Returns
        -------
        self : FeatureUnion
            This estimator
        """
        self._validate_transformers()
        transformers = Parallel(n_jobs=self.n_jobs)(
            delayed(_fit_one_transformer)(trans, X, y)
            for _, trans, _ in self._iter())
        self._update_transformer_list(transformers)
        return self

    def fit_transform(self, X, y=None, **fit_params):
        """Fit all transformers, transform the data and concatenate results.

        Parameters
        ----------
        X : iterable or array-like, depending on transformers
            Input data to be transformed.

        y : array-like, shape (n_samples, ...), optional
            Targets for supervised learning.

        Returns
        -------
        X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
            hstack of results of transformers. sum_n_components is the
            sum of n_components (output dimension) over transformers.
        """
        self._validate_transformers()
        result = Parallel(n_jobs=self.n_jobs)(
            delayed(_fit_transform_one)(trans, name, weight, X, y,
                                        **fit_params)
            for name, trans, weight in self._iter())

        if not result:
            # All transformers are None
            return np.zeros((X.shape[0], 0))
        Xs, transformers = zip(*result)
        self._update_transformer_list(transformers)
        if any(sparse.issparse(f) for f in Xs):
            Xs = sparse.hstack(Xs).tocsr()
        else:
            Xs = np.hstack(Xs)
        return Xs

    def transform(self, X):
        """Transform X separately by each transformer, concatenate results.

        Parameters
        ----------
        X : iterable or array-like, depending on transformers
            Input data to be transformed.

        Returns
        -------
        X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
            hstack of results of transformers. sum_n_components is the
            sum of n_components (output dimension) over transformers.
        """
        Xs = Parallel(n_jobs=self.n_jobs)(
            delayed(_transform_one)(trans, name, weight, X)
            for name, trans, weight in self._iter())
        if not Xs:
            # All transformers are None
            return np.zeros((X.shape[0], 0))
        if any(sparse.issparse(f) for f in Xs):
            Xs = sparse.hstack(Xs).tocsr()
        else:
            Xs = np.hstack(Xs)
        return Xs

    def _update_transformer_list(self, transformers):
        """Store fitted transformers back into ``transformer_list`` in place.

        ``transformers`` supplies one item per entry that is not ``None``, in
        order; ``None`` entries are kept as they are.
        """
        transformers = iter(transformers)
        self.transformer_list[:] = [
            (name, None if old is None else next(transformers))
            for name, old in self.transformer_list
        ]


# XXX it would be nice to have a keyword-only n_jobs argument to this function,
# but that's not allowed in Python 2.x.
def make_union(*transformers):
    """Construct a FeatureUnion from the given transformers.

    This is a shorthand for the FeatureUnion constructor; it does not require,
    and does not permit, naming the transformers. Instead, they will be given
    names automatically based on their types. It also does not allow weighting.

    Examples
    --------
    >>> from sklearn.decomposition import PCA, TruncatedSVD
    >>> make_union(PCA(), TruncatedSVD())    # doctest: +NORMALIZE_WHITESPACE
    FeatureUnion(n_jobs=1,
           transformer_list=[('pca',
                              PCA(copy=True, iterated_power='auto',
                                  n_components=None, random_state=None,
                                  svd_solver='auto', tol=0.0, whiten=False)),
                             ('truncatedsvd',
                              TruncatedSVD(algorithm='randomized',
                              n_components=2, n_iter=5,
                              random_state=None, tol=0.0))],
           transformer_weights=None)

    Returns
    -------
    f : FeatureUnion
    """
    return FeatureUnion(_name_estimators(transformers))