Getting the names of the columns a BigQuery XGBoost classifier was trained on

Posted on 2025-01-22 05:31:56

I am training an XGBoost classifier in BigQuery. The model trains fine, and the bst (saved model) file is then imported into a Python notebook for plotting. I want to plot the trees in the model to get an idea of how it makes its predictions.
When I plot the model, I get the result shown below:

[Image: model visualization]

I am doing it like this:

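import matplotlib.pyplot as plt
import xgboost as xgb

# Load the saved model file exported from BigQuery
bst = xgb.Booster(model_file='model.bst')

# Plot the tree at index 4 on a large canvas
fig, ax = plt.subplots(figsize=(30, 30))
xgb.plot_tree(bst, num_trees=4, ax=ax)
plt.show()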

I have learned that the column names are masked: f182, for example, stands for the feature at index 182 that the model was trained on (XGBoost's default names start at f0). I would like to map these placeholders in the trees to the actual column names that were used for training the model. The query used to train the model is given below:

CREATE OR REPLACE MODEL `d1.boost_clf1`
OPTIONS(
    MODEL_TYPE='BOOSTED_TREE_CLASSIFIER',
    INPUT_LABEL_COLS=['y'],
    DATA_SPLIT_METHOD='CUSTOM',
    DATA_SPLIT_COL='isTrain',
    AUTO_CLASS_WEIGHTS=TRUE,
    EARLY_STOP=TRUE,
    L2_REG = 0.3,
    ENABLE_GLOBAL_EXPLAIN = TRUE
) AS
SELECT
    * except(isTrain, x1,x2,x3_timestamp,x4_timestamp, y)
    ,isTrain = 1 as isTrain
FROM d1.t1_preprocessed;
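
One possible way to build the mapping described above is an XGBoost feature map file, which `xgb.plot_tree` accepts through its `fmap` parameter. The sketch below is only an illustration: the column names are hypothetical, the helper `write_feature_map` is not part of any library, and it assumes the feature indices follow the order of the feature columns the model saw. That order should be verified, since BigQuery may preprocess or reorder columns before training.

```python
import tempfile
from pathlib import Path
from typing import Dict, Sequence


def write_feature_map(columns: Sequence[str], path: str) -> Dict[str, str]:
    """Write an XGBoost feature map file and return the fN -> column mapping."""
    mapping = {}
    lines = []
    for index, name in enumerate(columns):
        # Fields in a feature map are whitespace-separated,
        # so spaces inside a name are replaced with underscores.
        safe_name = name.replace(" ", "_")
        mapping[f"f{index}"] = safe_name
        # "q" marks the feature as quantitative (numeric threshold splits).
        lines.append(f"{index}\t{safe_name}\tq")
    Path(path).write_text("\n".join(lines) + "\n", encoding="utf-8")
    return mapping


# Hypothetical column names, listed in the order assumed to match f0, f1, f2.
example_columns = ["age", "income", "days since signup"]
fmap_path = str(Path(tempfile.gettempdir()) / "featmap.txt")
feature_map = write_feature_map(example_columns, fmap_path)
print(feature_map)
```

With the file written, the plot call from the question would become `xgb.plot_tree(bst, fmap=fmap_path, num_trees=4, ax=ax)`. The returned dictionary can also be used to look up a single placeholder such as `f182` by hand.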

I have tried to print bst.feature_names but that doesn't print anything.
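
An empty `bst.feature_names` suggests that the saved model file carries no names, so they would have to be supplied after loading. A minimal sketch of that idea follows, assuming a recent XGBoost version in which `Booster.feature_names` can be assigned and `Booster.num_features()` is available. The helper `attach_feature_names` is hypothetical, and the ordered column list must come from your own knowledge of the training data.

```python
from typing import Any, List, Sequence


def attach_feature_names(booster: Any, columns: Sequence[str]) -> List[str]:
    """Set real column names on a loaded Booster-like object and return them."""
    names = list(columns)
    # Refuse to label the trees if the list cannot line up with the model.
    expected = booster.num_features()
    if len(names) != expected:
        raise ValueError(
            f"model expects {expected} features, got {len(names)} names"
        )
    if len(set(names)) != len(names):
        raise ValueError("feature names must be unique")
    # After this assignment, plots and dumps use the names instead of f0, f1, ...
    booster.feature_names = names
    return names
```

With the `bst` object from the question, this would be called as `attach_feature_names(bst, columns)` before `xgb.plot_tree`. The length check only catches a wrong number of names; it cannot detect names given in the wrong order.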

Any help in finding a way to plot the trees of XGBoost with actual column names would be highly appreciated. Thanks!
