Check whether a value is found within a group in a PySpark dataframe

Let’s say I have the following DataFrame:

df = spark.createDataFrame([
  ("a", "apple"),
  ("a", "pear"),
  ("b", "pear"),
  ("c", "carrot"),
  ("c", "apple"),
], ["id", "fruit"])

+---+-------+
| id|  fruit|
+---+-------+
|  a|  apple|
|  a|   pear|
|  b|   pear|
|  c| carrot|
|  c|  apple|
+---+-------+

I now want to create a boolean flag which is TRUE for every row of each id that has at least one row with "pear" in the fruit column.

The desired output would look like this:

+---+-------+------+
| id|  fruit|  flag|
+---+-------+------+
|  a|  apple|  True|
|  a|   pear|  True|
|  b|   pear|  True|
|  c| carrot| False|
|  c|  apple| False|
+---+-------+------+

For pandas I found a solution using groupby().transform(), but I don’t understand how to translate that into PySpark.

Answer

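Use the max window function. The comparison fruit = 'pear' yields a boolean for each row, and because true orders above false, the max over each id partition is true whenever at least one row in that partition matches: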

df.selectExpr("*", "max(fruit = 'pear') over (partition by id) as flag").show()

+---+------+-----+
| id| fruit| flag|
+---+------+-----+
|  c|carrot|false|
|  c| apple|false|
|  b|  pear| true|
|  a| apple| true|
|  a|  pear| true|
+---+------+-----+
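
If it helps to see the idea outside of Spark, here is a minimal plain-Python sketch of what the window expression computes. It is only an illustration, not how Spark executes the query; the names rows, flag_groups and targets are made up for this example.

```python
from collections import defaultdict

# Same data as the question's DataFrame, as plain (id, fruit) tuples.
rows = [
    ("a", "apple"),
    ("a", "pear"),
    ("b", "pear"),
    ("c", "carrot"),
    ("c", "apple"),
]


def flag_groups(rows, targets):
    """Hypothetical helper mimicking max(fruit in targets) over (partition by id)."""
    # Collect the per-row booleans for each id (the "partition").
    matches = defaultdict(list)
    for id_, fruit in rows:
        matches[id_].append(fruit in targets)
    # False < True, so max() over a group is True if any row matched.
    group_flag = {id_: max(flags) for id_, flags in matches.items()}
    # Attach each group's flag to every row of that group.
    return [(id_, fruit, group_flag[id_]) for id_, fruit in rows]


flagged = flag_groups(rows, {"pear"})
for row in flagged:
    print(row)
```

Every row of a and b gets True and every row of c gets False, which matches the flag column above (Spark just returns the rows in a different order).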

If you need to check for multiple fruits, you can use the in operator. For instance, to flag ids that contain carrot or apple:

df.selectExpr("*", "max(fruit in ('carrot', 'apple')) over (partition by id) as flag").show()

+---+------+-----+
| id| fruit| flag|
+---+------+-----+
|  c|carrot| true|
|  c| apple| true|
|  b|  pear|false|
|  a| apple| true|
|  a|  pear| true|
+---+------+-----+

If you prefer Python syntax over a SQL expression string:

from pyspark.sql.window import Window
import pyspark.sql.functions as f

df.select("*", 
  f.max(
    f.col('fruit').isin(['carrot', 'apple'])
  ).over(Window.partitionBy('id')).alias('flag')
).show()

+---+------+-----+
| id| fruit| flag|
+---+------+-----+
|  c|carrot| true|
|  c| apple| true|
|  b|  pear|false|
|  a| apple| true|
|  a|  pear| true|
+---+------+-----+