如何根据 Pyspark 中聚合函数的条件按计数进行分组?
假设我构建以下示例数据集:
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from datetime import datetime
spark = SparkSession.builder\
.config("spark.driver.memory", "10g")\
.config('spark.sql.repl.eagerEval.enabled', True)\ # to display df in pretty HTML
.getOrCreate()
df = spark.createDataFrame(
[
("US", "US_SL_A", datetime(2022, 1, 1), 3.8),
("US", "US_SL_A", datetime(2022, 1, 2), 4.3),
("US", "US_SL_A", datetime(2022, 1, 3), 4.3),
("US", "US_SL_A", datetime(2022, 1, 4), 3.95),
("US", "US_SL_A", datetime(2022, 1, 5), 1.),
("US", "US_SL_B", datetime(2022, 1, 1), 4.3),
("US", "US_SL_B", datetime(2022, 1, 2), 3.8),
("US", "US_SL_B", datetime(2022, 1, 3), 9.),
("US", "US_SL_C", datetime(2022, 1, 1), 1.),
("ES", "ES_SL_A", datetime(2022, 1, 1), 4.2),
("ES", "ES_SL_A", datetime(2022, 1, 2), 1.),
("ES", "ES_SL_B", datetime(2022, 1, 1), 2.),
("FR", "FR_SL_A", datetime(2022, 1, 1), 2.),
],
schema = ("country", "platform", "timestamp", "size")
)
>> df.show()
+-------+--------+-------------------+----+
|country|platform| timestamp|size|
+-------+--------+-------------------+----+
| US| US_SL_A|2022-01-01 00:00:00| 3.8|
| US| US_SL_A|2022-01-02 00:00:00| 4.3|
| US| US_SL_A|2022-01-03 00:00:00| 4.3|
| US| US_SL_A|2022-01-04 00:00:00|3.95|
| US| US_SL_A|2022-01-05 00:00:00| 1.0|
| US| US_SL_B|2022-01-01 00:00:00| 4.3|
| US| US_SL_B|2022-01-02 00:00:00| 3.8|
| US| US_SL_B|2022-01-03 00:00:00| 9.0|
| US| US_SL_C|2022-01-01 00:00:00| 1.0|
| ES| ES_SL_A|2022-01-01 00:00:00| 4.2|
| ES| ES_SL_A|2022-01-02 00:00:00| 1.0|
| ES| ES_SL_B|2022-01-01 00:00:00| 2.0|
| FR| FR_SL_A|2022-01-01 00:00:00| 2.0|
+-------+--------+-------------------+----+
我的目标是检测大小列中的异常值数量,但之前按国家/地区和平台进行分组。为此,我想使用四分位数范围作为标准;也就是说,我想计算所有那些值小于分位数 0.25 的 1.5 倍减去四分位数范围的大小。
我可以通过以下方式获得不同的分位数参数和每组所需的阈值:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.round(1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)), 2)
).alias("1.5xInterquartile"),
F.percentile_approx("size", 0.25).alias("q1"),
F.percentile_approx("size", 0.75).alias("q3"),
)\
.withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`"))\ # Q1 - 1.5*IQR
.show()
+-------+--------+-----------------+---+---+---------+
|country|platform|1.5xInterquartile| q1| q3|threshold|
+-------+--------+-----------------+---+---+---------+
| US| US_SL_A| 0.75|3.8|4.3| 3.05|
| US| US_SL_B| 7.8|3.8|9.0| -4.0|
| US| US_SL_C| 0.0|1.0|1.0| 1.0|
| ES| ES_SL_A| 4.8|1.0|4.2| -3.8|
| FR| FR_SL_A| 0.0|2.0|2.0| 2.0|
| ES| ES_SL_B| 0.0|2.0|2.0| 2.0|
+-------+--------+-----------------+---+---+---------+
但这并不是我想要得到的。我想要的是,不是按四分位数聚合,而是按每组满足低于异常值阈值条件的行数计数进行聚合。
所需的输出将是这样的:
+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
| US| US_SL_A| 1 |
| US| US_SL_B| 0 |
| US| US_SL_C| 0 |
| ES| ES_SL_A| 0 |
| FR| FR_SL_A| 0 |
| ES| ES_SL_B| 0 |
+-------+--------+----------+
这是因为只有 (US, US_SL_A)
组的一个值 (1.) 低于此类组的异常值阈值
这是我实现这一目标的尝试:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.count(
F.when(
F.col("size") < F.percentile_approx("size", 0.25) - 1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)),
True
)
)
).alias("n_outliers"),
)
但我得到一个错误,其中指出:
AnalysisException: It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;
Aggregate [country#0, platform#1], [country#0, platform#1, count(CASE WHEN (size#3 < (percentile_approx(size#3, 0.25, 10000, 0, 0) - ((percentile_approx(size#3, 0.75, 10000, 0, 0) - percentile_approx(size#3, 0.25, 10000, 0, 0)) * 1.5))) THEN true END) AS n_outliers#732L]
+- LogicalRDD [country#0, platform#1, timestamp#2, size#3], false
Suppose I build the following example dataset:
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from datetime import datetime
spark = SparkSession.builder\
.config("spark.driver.memory", "10g")\
.config('spark.sql.repl.eagerEval.enabled', True)\ # to display df in pretty HTML
.getOrCreate()
df = spark.createDataFrame(
[
("US", "US_SL_A", datetime(2022, 1, 1), 3.8),
("US", "US_SL_A", datetime(2022, 1, 2), 4.3),
("US", "US_SL_A", datetime(2022, 1, 3), 4.3),
("US", "US_SL_A", datetime(2022, 1, 4), 3.95),
("US", "US_SL_A", datetime(2022, 1, 5), 1.),
("US", "US_SL_B", datetime(2022, 1, 1), 4.3),
("US", "US_SL_B", datetime(2022, 1, 2), 3.8),
("US", "US_SL_B", datetime(2022, 1, 3), 9.),
("US", "US_SL_C", datetime(2022, 1, 1), 1.),
("ES", "ES_SL_A", datetime(2022, 1, 1), 4.2),
("ES", "ES_SL_A", datetime(2022, 1, 2), 1.),
("ES", "ES_SL_B", datetime(2022, 1, 1), 2.),
("FR", "FR_SL_A", datetime(2022, 1, 1), 2.),
],
schema = ("country", "platform", "timestamp", "size")
)
>> df.show()
+-------+--------+-------------------+----+
|country|platform| timestamp|size|
+-------+--------+-------------------+----+
| US| US_SL_A|2022-01-01 00:00:00| 3.8|
| US| US_SL_A|2022-01-02 00:00:00| 4.3|
| US| US_SL_A|2022-01-03 00:00:00| 4.3|
| US| US_SL_A|2022-01-04 00:00:00|3.95|
| US| US_SL_A|2022-01-05 00:00:00| 1.0|
| US| US_SL_B|2022-01-01 00:00:00| 4.3|
| US| US_SL_B|2022-01-02 00:00:00| 3.8|
| US| US_SL_B|2022-01-03 00:00:00| 9.0|
| US| US_SL_C|2022-01-01 00:00:00| 1.0|
| ES| ES_SL_A|2022-01-01 00:00:00| 4.2|
| ES| ES_SL_A|2022-01-02 00:00:00| 1.0|
| ES| ES_SL_B|2022-01-01 00:00:00| 2.0|
| FR| FR_SL_A|2022-01-01 00:00:00| 2.0|
+-------+--------+-------------------+----+
My goal is to detect the number of outliers in the size column, but previously grouping by country and platform. For this I want to use the interquartile range as a criterion; that is, I want to count all those sizes whose value is less than 1.5 times the quantile 0.25 minus the interquartile range.
I can get the different quantile parameters and desired threshold per group by doing:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.round(1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)), 2)
).alias("1.5xInterquartile"),
F.percentile_approx("size", 0.25).alias("q1"),
F.percentile_approx("size", 0.75).alias("q3"),
)\
.withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`"))\ # Q1 - 1.5*IQR
.show()
+-------+--------+-----------------+---+---+---------+
|country|platform|1.5xInterquartile| q1| q3|threshold|
+-------+--------+-----------------+---+---+---------+
| US| US_SL_A| 0.75|3.8|4.3| 3.05|
| US| US_SL_B| 7.8|3.8|9.0| -4.0|
| US| US_SL_C| 0.0|1.0|1.0| 1.0|
| ES| ES_SL_A| 4.8|1.0|4.2| -3.8|
| FR| FR_SL_A| 0.0|2.0|2.0| 2.0|
| ES| ES_SL_B| 0.0|2.0|2.0| 2.0|
+-------+--------+-----------------+---+---+---------+
But this is not exactly what I want to get. What I would want is, instead of aggregating by interquartiles, to aggregate by a count of the number of rows per group that satisfy the condition of being below the outlier threshold.
Desired output would be something like this:
+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
| US| US_SL_A| 1 |
| US| US_SL_B| 0 |
| US| US_SL_C| 0 |
| ES| ES_SL_A| 0 |
| FR| FR_SL_A| 0 |
| ES| ES_SL_B| 0 |
+-------+--------+----------+
This is because only (US, US_SL_A)
group has one value (1.) below the outlier threshold for such a group
Here's my attempt to achieve that:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.count(
F.when(
F.col("size") < F.percentile_approx("size", 0.25) - 1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)),
True
)
)
).alias("n_outliers"),
)
But I get an error, which states:
AnalysisException: It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;
Aggregate [country#0, platform#1], [country#0, platform#1, count(CASE WHEN (size#3 < (percentile_approx(size#3, 0.25, 10000, 0, 0) - ((percentile_approx(size#3, 0.75, 10000, 0, 0) - percentile_approx(size#3, 0.25, 10000, 0, 0)) * 1.5))) THEN true END) AS n_outliers#732L]
+- LogicalRDD [country#0, platform#1, timestamp#2, size#3], false
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(2)
这里的关键是在聚合之前使用窗口函数
The key here is the use of windows functions previous to the aggregation
您的
count
和percentile_approx
都需要聚合,但看起来顶部的agg
并不处理这些。您可以尝试对所有聚合使用窗口函数,这将为每条记录添加 n_outliers 计数。然后,稍后您可以使用
distinct
来仅获取每组的 1 条记录。Your
count
andpercentile_approx
both need aggregation but looks like theagg
on top doesn't take care of those.You can try using window functions for all of aggregations which will add
n_outliers
count for each records. Then, later you can usedistinct
to get only the 1 record per group.