How to group by count based on a condition on aggregate functions in PySpark?

Posted on 2025-01-13 12:21:57

Suppose I build the following example dataset:

import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from datetime import datetime

spark = (
    SparkSession.builder
    .config("spark.driver.memory", "10g")
    .config("spark.sql.repl.eagerEval.enabled", True)  # to display df in pretty HTML
    .getOrCreate()
)

df = spark.createDataFrame(
    [
        ("US", "US_SL_A", datetime(2022, 1, 1), 3.8),
        ("US", "US_SL_A", datetime(2022, 1, 2), 4.3),
        ("US", "US_SL_A", datetime(2022, 1, 3), 4.3),
        ("US", "US_SL_A", datetime(2022, 1, 4), 3.95),
        ("US", "US_SL_A", datetime(2022, 1, 5), 1.),
        ("US", "US_SL_B", datetime(2022, 1, 1), 4.3),
        ("US", "US_SL_B", datetime(2022, 1, 2), 3.8),
        ("US", "US_SL_B", datetime(2022, 1, 3), 9.),
        ("US", "US_SL_C", datetime(2022, 1, 1), 1.),
        ("ES", "ES_SL_A", datetime(2022, 1, 1), 4.2),
        ("ES", "ES_SL_A", datetime(2022, 1, 2), 1.),
        ("ES", "ES_SL_B", datetime(2022, 1, 1), 2.),
        ("FR", "FR_SL_A", datetime(2022, 1, 1), 2.),
    ],
    schema = ("country", "platform", "timestamp", "size")
)

>> df.show()
+-------+--------+-------------------+----+
|country|platform|          timestamp|size|
+-------+--------+-------------------+----+
|     US| US_SL_A|2022-01-01 00:00:00| 3.8|
|     US| US_SL_A|2022-01-02 00:00:00| 4.3|
|     US| US_SL_A|2022-01-03 00:00:00| 4.3|
|     US| US_SL_A|2022-01-04 00:00:00|3.95|
|     US| US_SL_A|2022-01-05 00:00:00| 1.0|
|     US| US_SL_B|2022-01-01 00:00:00| 4.3|
|     US| US_SL_B|2022-01-02 00:00:00| 3.8|
|     US| US_SL_B|2022-01-03 00:00:00| 9.0|
|     US| US_SL_C|2022-01-01 00:00:00| 1.0|
|     ES| ES_SL_A|2022-01-01 00:00:00| 4.2|
|     ES| ES_SL_A|2022-01-02 00:00:00| 1.0|
|     ES| ES_SL_B|2022-01-01 00:00:00| 2.0|
|     FR| FR_SL_A|2022-01-01 00:00:00| 2.0|
+-------+--------+-------------------+----+

My goal is to count the outliers in the size column per country and platform group. As the criterion I want to use the interquartile range (IQR): I want to count all sizes whose value is less than the 0.25 quantile (Q1) minus 1.5 times the IQR.
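For reference, the criterion is threshold = Q1 - 1.5 * (Q3 - Q1). The plain-Python sketch below works through it for the (US, US_SL_A) sizes. It assumes that on groups this small percentile_approx returns the exact nearest-rank value, meaning the smallest value with at least a fraction p of the data at or below it. This assumption agrees with the q1/q3 columns in the table below. The helper names nearest_rank and lower_fence are hypothetical.

```python
import math

def nearest_rank(values, p):
    """Smallest value such that at least a fraction p of the values are <= it.

    Assumption: stands in for what percentile_approx returns on tiny groups.
    """
    ordered = sorted(values)
    rank = max(1, math.ceil(p * len(ordered)))  # 1-based rank
    return ordered[rank - 1]

def lower_fence(values, k=1.5):
    """Lower outlier threshold: Q1 - k * IQR."""
    q1 = nearest_rank(values, 0.25)
    q3 = nearest_rank(values, 0.75)
    return q1 - k * (q3 - q1)

us_sl_a = [3.8, 4.3, 4.3, 3.95, 1.0]
fence = lower_fence(us_sl_a)
outliers = [v for v in us_sl_a if v < fence]
print(round(fence, 2), outliers)  # 3.05 [1.0]
```

Here Q1 = 3.8 and Q3 = 4.3, so 1.5 * IQR = 0.75 and only 1.0 falls below the threshold.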

I can get the different quantile parameters and desired threshold per group by doing:

>> (df.groupBy(["country", "platform"])
    .agg(
        F.round(1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)), 2).alias("1.5xInterquartile"),
        F.percentile_approx("size", 0.25).alias("q1"),
        F.percentile_approx("size", 0.75).alias("q3"),
    )
    .withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`"))  # Q1 - 1.5*IQR
    .show())
+-------+--------+-----------------+---+---+---------+
|country|platform|1.5xInterquartile| q1| q3|threshold|
+-------+--------+-----------------+---+---+---------+
|     US| US_SL_A|             0.75|3.8|4.3|     3.05|
|     US| US_SL_B|              7.8|3.8|9.0|     -4.0|
|     US| US_SL_C|              0.0|1.0|1.0|      1.0|
|     ES| ES_SL_A|              4.8|1.0|4.2|     -3.8|
|     FR| FR_SL_A|              0.0|2.0|2.0|      2.0|
|     ES| ES_SL_B|              0.0|2.0|2.0|      2.0|
+-------+--------+-----------------+---+---+---------+

But this is not quite what I want. Instead of aggregating the quartiles, I want to aggregate a count of the rows in each group whose size is below that group's outlier threshold.

Desired output would be something like this:

+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
|     US| US_SL_A|         1|
|     US| US_SL_B|         0|
|     US| US_SL_C|         0|
|     ES| ES_SL_A|         0|
|     FR| FR_SL_A|         0|
|     ES| ES_SL_B|         0|
+-------+--------+----------+

This is because only the (US, US_SL_A) group has a value (1.0) below its outlier threshold (3.05).

Here's my attempt to achieve that:

>> df.groupBy(
    ["country", "platform"]
).agg(
    (
        F.count(
            F.when(
                F.col("size") < F.percentile_approx("size", 0.25) - 1.5*(F.percentile_approx("size", 0.75) -  F.percentile_approx("size", 0.25)),
                True
            )
        )
    ).alias("n_outliers"),
)

But I get an error, which states:

AnalysisException: It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;
Aggregate [country#0, platform#1], [country#0, platform#1, count(CASE WHEN (size#3 < (percentile_approx(size#3, 0.25, 10000, 0, 0) - ((percentile_approx(size#3, 0.75, 10000, 0, 0) - percentile_approx(size#3, 0.25, 10000, 0, 0)) * 1.5))) THEN true END) AS n_outliers#732L]
+- LogicalRDD [country#0, platform#1, timestamp#2, size#3], false
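The error message points at the usual fix. First compute the quartiles in one aggregation (the "sub-query"). Then compare each row against its group's threshold in a second pass. The plain-Python sketch below only illustrates that two-stage shape on the example rows. It is not Spark code, and it again assumes nearest-rank quartiles as a stand-in for percentile_approx. In PySpark the same shape would be a groupBy/agg for the thresholds, a join back on country and platform, and then the count(when(...)) aggregation.

```python
import math
from collections import defaultdict

# (country, platform, size) from the example DataFrame; timestamps omitted
rows = [
    ("US", "US_SL_A", 3.8), ("US", "US_SL_A", 4.3), ("US", "US_SL_A", 4.3),
    ("US", "US_SL_A", 3.95), ("US", "US_SL_A", 1.0),
    ("US", "US_SL_B", 4.3), ("US", "US_SL_B", 3.8), ("US", "US_SL_B", 9.0),
    ("US", "US_SL_C", 1.0),
    ("ES", "ES_SL_A", 4.2), ("ES", "ES_SL_A", 1.0),
    ("ES", "ES_SL_B", 2.0),
    ("FR", "FR_SL_A", 2.0),
]

def nearest_rank(values, p):
    # Assumed stand-in for percentile_approx on tiny groups.
    ordered = sorted(values)
    return ordered[max(1, math.ceil(p * len(ordered))) - 1]

# Stage 1 ("inner aggregate"): one threshold per (country, platform).
groups = defaultdict(list)
for country, platform, size in rows:
    groups[(country, platform)].append(size)

thresholds = {}
for key, sizes in groups.items():
    q1, q3 = nearest_rank(sizes, 0.25), nearest_rank(sizes, 0.75)
    thresholds[key] = q1 - 1.5 * (q3 - q1)

# Stage 2 ("outer aggregate"): count rows below their group's threshold.
n_outliers = {key: 0 for key in groups}
for country, platform, size in rows:
    if size < thresholds[(country, platform)]:
        n_outliers[(country, platform)] += 1

for (country, platform), n in n_outliers.items():
    print(country, platform, n)
```

Because the thresholds are fully computed before any row is compared, no aggregate ever has to be evaluated inside another one.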


Comments (2)

远昼 2025-01-20 12:21:57

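The key here is to use window functions before the aggregation. Each row then carries its group's q1 and q3 as ordinary columns, so the outer agg only needs a plain conditional count.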

import pyspark.sql.window as W

w = W.Window.partitionBy(["country", "platform"])

(df
 .withColumn("1.5xInterquartile", F.round(1.5*(F.percentile_approx("size", 0.75).over(w) -  F.percentile_approx("size", 0.25).over(w)), 2))
 .withColumn("q1",F.percentile_approx("size", 0.25).over(w))
 .withColumn("q3",F.percentile_approx("size", 0.75).over(w))
 .withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`")) # Q1 - 1.5*IQR
 .groupBy(["country", "platform"])
 .agg(F.count(F.when(F.col("size") < F.col("q1") - 1.5*(F.col("q3") - F.col("q1")), 1)).alias("n_outliers"))
 .show()
)

+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
|     ES| ES_SL_A|         0|
|     ES| ES_SL_B|         0|
|     FR| FR_SL_A|         0|
|     US| US_SL_A|         1|
|     US| US_SL_B|         0|
|     US| US_SL_C|         0|
+-------+--------+----------+
丑丑阿 2025-01-20 12:21:57

Your count and percentile_approx are both aggregate functions, and a single agg call cannot evaluate one inside the other.

You can try using window functions for all of the aggregations, which adds the n_outliers count to every record. Then you can use distinct to keep only one record per group.

from pyspark.sql import Window
import pyspark.sql.functions as F

w = Window.partitionBy("country", "platform")

df = (df.withColumn('n_outliers', 
         F.count(F.when(
             F.col("size") < (F.percentile_approx("size", 0.25).over(w) - 1.5*(F.percentile_approx("size", 0.75).over(w) -  F.percentile_approx("size", 0.25).over(w))),
             1
         )).over(w))
     .select('country', 'platform', 'n_outliers')
     .distinct())