Difference between transform, fit_transform, predict, and fit

Chanchala Gorale
Jun 24, 2024

In machine learning, fit, transform, fit_transform, and predict are methods commonly used in preprocessing, model training, and inference. They are best known from scikit-learn, where transformers implement fit and transform, and estimators such as regression models implement fit and predict. Here's a detailed explanation of each, including how they work in the backend:

1. fit

  • Purpose: The fit method is used to train a model or to calculate the parameters necessary for a transformation.
  • Usage: model.fit(X, y) or transformer.fit(X)
  • Backend Process:
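      • For estimators (like linear regression, decision trees, etc.), fit finds the model parameters that best map the input features X to the target variable y. How it does this depends on the model: scikit-learn's LinearRegression solves ordinary least squares directly, decision trees greedily choose splits, and models such as SGDRegressor use iterative optimization like stochastic gradient descent.
      • For transformers (like StandardScaler, PCA, etc.), fit computes the necessary statistics from the input data X (e.g., mean and variance for scaling; the mean and the principal components, i.e. the eigenvectors of the covariance matrix, for PCA).
      • In scikit-learn, fit stores what it learns in attributes whose names end in an underscore (e.g., mean_, coef_) and returns the object itself, so calls can be chained.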
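A minimal sketch of fit on both kinds of objects, assuming scikit-learn and NumPy are installed. The toy data is made up so that y = 2*x1 + 3*x2 + 1 holds exactly; after fit, the learned values can be read from the trailing-underscore attributes.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

# Made-up toy data where y = 2*x1 + 3*x2 + 1 exactly
X = np.array([[0.0, 1.0], [1.0, 0.0], [2.0, 2.0], [3.0, 1.0]])
y = 2 * X[:, 0] + 3 * X[:, 1] + 1

# Transformer: fit only computes statistics from X (no target needed)
scaler = StandardScaler().fit(X)  # fit returns the scaler itself
print(scaler.mean_)   # per-feature mean
print(scaler.scale_)  # per-feature (population) standard deviation

# Estimator: fit learns the parameters that map X to y
model = LinearRegression().fit(X, y)
print(model.coef_, model.intercept_)  # approximately [2. 3.] and 1.0
```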

2. transform

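  • Purpose: The transform method applies the transformation defined by the fitted parameters to data, whether that is the training set itself or new data such as a test set. It does not learn or update any parameters.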
  • Usage: transformed_data = transformer.transform(X)
  • Backend Process:
      • The method uses the parameters calculated during fit to transform the input data X. For example, a StandardScaler subtracts the mean and divides by the standard deviation computed during fit.
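To make the StandardScaler case concrete, the sketch below (made-up data, assuming scikit-learn and NumPy) checks that transform on new data is just (X - mean_) / scale_ using the statistics stored from the training data, not statistics of the new data.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Made-up training data and one new sample
X_train = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
X_new = np.array([[4.0, 40.0]])

scaler = StandardScaler().fit(X_train)  # learns mean_ and scale_ from X_train only

# transform reuses the stored statistics; it never looks at X_new's own mean/std
X_new_scaled = scaler.transform(X_new)
manual = (X_new - scaler.mean_) / scaler.scale_
print(np.allclose(X_new_scaled, manual))  # True
```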

3. fit_transform

  • Purpose: fit_transform is a convenience method that combines fit and transform into a single step.
  • Usage: transformed_data = transformer.fit_transform(X)
  • Backend Process:
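      • This method first calls fit to calculate the necessary parameters and then immediately applies transform to the same input data X, so it is equivalent to fit(X).transform(X). Some transformers (e.g., PCA) override it with a faster combined implementation. It is useful for preprocessing pipelines where you want to fit and transform the training data in one go; validation or test data should go through transform only, so they are processed with the statistics learned from the training data.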
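The "first fit, then transform" behavior can be seen in a small custom transformer. This is a sketch: MeanCenterer is a hypothetical class written for illustration, and it only defines fit and transform; fit_transform is inherited from scikit-learn's TransformerMixin, which implements it by chaining the two.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class MeanCenterer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer that subtracts the per-column mean learned in fit."""

    def fit(self, X, y=None):
        # Learn the statistic; the trailing underscore marks it as fitted state
        self.mean_ = np.asarray(X, dtype=float).mean(axis=0)
        return self  # returning self is what allows fit(X).transform(X)

    def transform(self, X):
        # Apply the statistic learned in fit; nothing is re-learned here
        return np.asarray(X, dtype=float) - self.mean_


X = np.array([[1.0, 2.0], [3.0, 6.0]])

# fit_transform is not defined above; TransformerMixin supplies it
a = MeanCenterer().fit_transform(X)
b = MeanCenterer().fit(X).transform(X)
print(np.array_equal(a, b))  # True
```

Returning self from fit is the design choice that makes the inherited fit_transform (and chained calls in general) work.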

4. predict

  • Purpose: The predict method is used to make predictions on new data after the model has been trained.
  • Usage: predictions = model.predict(X)
  • Backend Process:
      • The method uses the parameters learned during fit to predict the target variable for new input data X; nothing is re-learned. For example, a linear regression model multiplies the input features by the learned weights and adds the learned intercept (bias), which in scikit-learn is X @ coef_ + intercept_.
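A short check of that formula, assuming scikit-learn and NumPy; the data is made up so that y = 2x + 1. The values returned by predict match the learned weights and intercept applied by hand.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Made-up training data following y = 2x + 1
X_train = np.array([[0.0], [1.0], [2.0], [3.0]])
y_train = np.array([1.0, 3.0, 5.0, 7.0])
X_new = np.array([[4.0], [10.0]])

model = LinearRegression().fit(X_train, y_train)

# predict applies the learned parameters: y_hat = X @ coef_ + intercept_
predictions = model.predict(X_new)
manual = X_new @ model.coef_ + model.intercept_
print(np.allclose(predictions, manual))  # True
print(predictions)  # approximately 9 and 21
```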

Detailed Example

Let’s consider a simple workflow involving a standard scaler and a linear regression model:

StandardScaler (transformer)

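  • fit: scaler.fit(X_train)
      • Computes the mean and standard deviation of X_train.
  • transform: X_train_scaled = scaler.transform(X_train)
      • Uses the computed mean and standard deviation to scale X_train.
  • fit_transform: X_train_scaled = scaler.fit_transform(X_train)
      • Computes the mean and standard deviation, then scales X_train in one step (equivalent to the two calls above).
  • transform on test data: X_test_scaled = scaler.transform(X_test)
      • Scales X_test with the mean and standard deviation learned from X_train. Calling fit or fit_transform on X_test instead would put the test features on a different scale from the one the model was trained on.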
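The sketch below (made-up one-feature data, assuming scikit-learn and NumPy) checks that fit_transform matches fit followed by transform on the training set, and shows why the test set should only be transformed: re-fitting a scaler on X_test produces different values.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Made-up data: the test set deliberately has a different range than the training set
X_train = np.array([[1.0], [2.0], [3.0], [4.0]])
X_test = np.array([[10.0], [12.0]])

# Two equivalent ways to scale the training data
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
same = StandardScaler().fit(X_train).transform(X_train)
print(np.allclose(X_train_scaled, same))  # True

# Consistent: scale the test data with the statistics learned from X_train
X_test_scaled = scaler.transform(X_test)

# Inconsistent: re-fitting on X_test computes new statistics from the test data
X_test_refit = StandardScaler().fit_transform(X_test)
print(np.allclose(X_test_scaled, X_test_refit))  # False
```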

LinearRegression (estimator)

  • fit: model.fit(X_train_scaled, y_train)
      • Learns the weights (coef_) and bias term (intercept_) that minimize the sum of squared errors between the predictions and y_train (ordinary least squares).
  • predict: predictions = model.predict(X_test_scaled)
      • Applies the learned weights and bias term to the test data X_test_scaled to generate predictions.
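Putting the two objects together, here is a sketch of the whole workflow, assuming scikit-learn and NumPy. The data is synthetic (y = 3*x1 - 2*x2 + 5 plus small noise), and train_test_split is used only to produce the X_train/X_test split the example refers to.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Synthetic data for illustration: y = 3*x1 - 2*x2 + 5 plus a little noise
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(100, 2))
y = 3 * X[:, 0] - 2 * X[:, 1] + 5 + rng.normal(0, 0.1, size=100)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Transformer: learn scaling statistics on the training set only
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)  # reuse the training mean/std

# Estimator: learn weights and intercept, then predict on unseen data
model = LinearRegression()
model.fit(X_train_scaled, y_train)
predictions = model.predict(X_test_scaled)

print(predictions.shape)  # (20,)
print(model.score(X_test_scaled, y_test))  # R^2 on the test set
```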

Summary

  • fit: Learns parameters from the data.
  • transform: Applies the learned transformation to the data.
  • fit_transform: Combines fit and transform in one call, for convenience and, for some transformers, efficiency.
  • predict: Uses a trained model to make predictions on new data.

Understanding these methods and what they do in the backend helps you design and implement machine learning pipelines correctly and effectively.
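In scikit-learn, a Pipeline applies these rules automatically. A minimal sketch, assuming scikit-learn and NumPy, with made-up data where y = 2*x1 + 3*x2 + 1: calling fit on the pipeline runs fit_transform on the scaler and fit on the regressor, while predict runs only transform on the scaler before the regressor's predict.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Made-up data following y = 2*x1 + 3*x2 + 1
X_train = np.array([[0.0, 1.0], [1.0, 0.0], [2.0, 2.0], [3.0, 1.0]])
y_train = 2 * X_train[:, 0] + 3 * X_train[:, 1] + 1
X_test = np.array([[4.0, 0.0]])

pipe = make_pipeline(StandardScaler(), LinearRegression())

# fit: StandardScaler.fit_transform on X_train, then LinearRegression.fit
pipe.fit(X_train, y_train)

# predict: StandardScaler.transform (no re-fitting), then LinearRegression.predict
print(pipe.predict(X_test))  # approximately 9.0
```

Because the scaler is only ever fitted inside pipe.fit, the test data cannot accidentally be used to compute the scaling statistics.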


Written by Chanchala Gorale

Founder | Product Manager | Software Developer
