Spark XGBoost4J如何在每个回合中打印出火车和验证损失

发布于 2025-01-17 16:39:54 字数 1342 浏览 1 评论 0 原文

我正在使用 Spark xgboost4j 0.80 和 scala,想知道如何打印每轮的训练和验证损失(例如将 num_round 设置为 100,在训练期间打印 100 轮的训练和验证损失)。我之前在非分布式版本 xgboost 上做过这样的 this

输入图片这里的描述

虽然我不知道如何在 xgboost4j 中实现这一点,但我发现了一个 示例代码,而在我创建评估集并在 XGBoostClassifier 的参数映射中指定它之后,标准输出中没有任何更改。我在网上搜索了很多,但没有找到类似的问题和解决方案。

我的代码是这样的:

val params = Map("eta"       -> 0.1f,
        "lambda"      -> 0.1f,
        "max_depth"   -> 6,
        "objective"   -> "binary:logistic",
        "num_workers" -> 10
        "num_round"   -> 100,
        // "num_early_stopping_rounds" -> 10,
        // "maximize_evaluation_metrics" -> false,
        // "eval_metric" -> "logloss",
        "eval_sets" -> Map("eval" -> eval)
      )
val xgb = new XGBoostClassifier(params).
              setFeaturesCol("features").
              setLabelCol("label")

val model = xgb.fit(train)

提前感谢您的帮助。

I am using spark xgboost4j 0.80 with scala, wonder how to print out the train and validation loss in each round (for example set num_round as 100, print out train and validation loss in 100 rounds during training). I have made it before on the non distributed version xgboost like this:

enter image description here

While I don't figure out how to achieve this in xgboost4j, I found an example code while there's no changes in the stdout after I create the evaluation set and specify it in the XGBoostClassifier's parameters map. I have searched a lot on the Internet while don't find similar questions and solutions.

My code is like:

val params = Map("eta"       -> 0.1f,
        "lambda"      -> 0.1f,
        "max_depth"   -> 6,
        "objective"   -> "binary:logistic",
        "num_workers" -> 10
        "num_round"   -> 100,
        // "num_early_stopping_rounds" -> 10,
        // "maximize_evaluation_metrics" -> false,
        // "eval_metric" -> "logloss",
        "eval_sets" -> Map("eval" -> eval)
      )
val xgb = new XGBoostClassifier(params).
              setFeaturesCol("features").
              setLabelCol("label")

val model = xgb.fit(train)

Thanks in advance for your help.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文