Here's what stands out to me:
You're scaling the data, but tree-based methods like random forests do not need this step.
You are doing your own tuning loop instead of using sklearn.model_selection.GridSearchCV. This is fine, but it can get quite fiddly (imagine wanting to sweep over another hyperparameter).
If you use GridSearchCV, you don't need to do your own cross validation.
You're using accuracy for evaluation, which is usually not a great metric for multi-class classification. Weighted F1 is better.
If you're doing cross validation, you need to put the scaler inside the CV loop (e.g. using a pipeline), because otherwise the scaler has already seen the validation data... but you don't need a scaler for this learning algorithm, so the point is moot. There's a short pipeline sketch below anyway, in case you ever swap in an estimator that does need scaling.
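For that hypothetical case only, a minimal sketch of keeping the scaler inside the CV loop with a Pipeline might look like this (LogisticRegression is just an illustrative scale-sensitive estimator, not part of your setup):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification()

# The scaler is fit on each training fold only, so the validation folds stay unseen.
pipe = Pipeline([('scale', StandardScaler()),
                 ('model', LogisticRegression(max_iter=1000))])
scores = cross_val_score(pipe, X, y, cv=5, scoring='f1_weighted')
print(scores.mean())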
I would probably do something like this:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, RepeatedStratifiedKFold, train_test_split

X, y = make_classification()

# Split off a held-out test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.05, shuffle=True, random_state=0)

# Set up the cross-validation scheme and the hyperparameter grid.
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
param_grid = {'max_depth': np.arange(3, 8)}
model = RandomForestClassifier(random_state=1)

# Run the grid search over the CV folds.
clf = GridSearchCV(model, param_grid,
                   scoring='f1_weighted',
                   cv=cv, verbose=3)
clf.fit(X_train, y_train)
Take a look at clf.cv_results_ for the scores etc., which you can plot if you want. By default GridSearchCV refits a final model on the best hyperparameters, so you can make predictions with clf.
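For example, continuing from the snippet above (pandas is only used here to print cv_results_, which is just a dict of arrays, as a table):

import pandas as pd

# One row per candidate max_depth, with mean/std validation scores across the CV repeats.
results = pd.DataFrame(clf.cv_results_)
print(results[['param_max_depth', 'mean_test_score', 'std_test_score']])

# Best hyperparameters and their mean CV score.
print(clf.best_params_, clf.best_score_)

# The refit best model is what clf uses when you predict.
y_pred = clf.predict(X_test)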
Almost forgot... you asked about improving the model :) Here are some ideas:
The above will help you tune over more hyperparameters (e.g. max_features, n_estimators, and min_samples_leaf); see the first sketch after this list. But don't get too carried away with hyperparameter tuning.
You could try transforming some features (columns in X), or adding new ones.
Look for more data, e.g. more rows, higher quality labels, etc.
Address any issues with class imbalance (the class_weight option in the first sketch below is one simple lever).
Try a more sophisticated algorithm, like gradient boosted trees (there are models in sklearn, or take a look at xgboost); see the second sketch below.
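To tune more hyperparameters, you can simply widen param_grid. This sketch reuses np, cv, X_train, and y_train from the snippet above; the specific grid values are only illustrative starting points, not recommendations:

# A wider (illustrative) hyperparameter grid.
param_grid = {
    'max_depth': np.arange(3, 8),
    'n_estimators': [100, 300],
    'max_features': ['sqrt', 0.5],
    'min_samples_leaf': [1, 5],
}
# class_weight='balanced' is one simple lever for class imbalance.
model = RandomForestClassifier(class_weight='balanced', random_state=1)
clf = GridSearchCV(model, param_grid, scoring='f1_weighted', cv=cv, verbose=3)
clf.fit(X_train, y_train)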
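And for gradient boosted trees, sklearn's HistGradientBoostingClassifier slots into the same grid-search pattern (xgboost's XGBClassifier works the same way if you have it installed); again the grid values here are only illustrative:

from sklearn.ensemble import HistGradientBoostingClassifier

# Same setup as above, different estimator and grid.
gb_model = HistGradientBoostingClassifier(random_state=1)
gb_grid = {'max_depth': [None, 3, 5], 'learning_rate': [0.05, 0.1]}
gb_clf = GridSearchCV(gb_model, gb_grid, scoring='f1_weighted', cv=cv, verbose=3)
gb_clf.fit(X_train, y_train)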