PySpark – Increment value based on previous row value

I have a dataframe in which I need to fix some parcel dates based on previous date values. Here’s an example:

+-------------------+---------------------+-----------------------+----------+
|account            |contract             |contract_parcel        |  date    |
+-------------------+---------------------+-----------------------+----------+
|              92397|                    1|                      1|2020-12-07|
|              92397|                    1|                      2|      null|
|              92397|                    2|                      1|2020-12-07|
|              92397|                    2|                      2|      null|
|              92397|                    2|                      3|      null|
|              92397|                    2|                      4|      null|
|              92397|                    2|                      5|      null|
|              92397|                    2|                      6|      null|
|              92397|                    3|                      1|2021-01-04|
|              92397|                    3|                      2|2021-02-01|
+-------------------+---------------------+-----------------------+----------+

For each account there are multiple contracts, each with multiple parcels. Wherever the date column is null, I need to take the previous parcel's date and add one month to it, and so on for every subsequent null parcel. I tried using a Window with the lag and last functions, but I could not manage to update the date based on the previous value; I only managed to copy it.

I need an output like this one below:

+-------------------+---------------------+-----------------------+----------+
|account            |contract             |contract_parcel        |  date    |
+-------------------+---------------------+-----------------------+----------+
|              92397|                    1|                      1|2020-12-07|
|              92397|                    1|                      2|2021-01-07|
|              92397|                    2|                      1|2020-12-07|
|              92397|                    2|                      2|2021-01-07|
|              92397|                    2|                      3|2021-02-07|
|              92397|                    2|                      4|2021-03-07|
|              92397|                    2|                      5|2021-04-07|
|              92397|                    2|                      6|2021-05-07|
|              92397|                    3|                      1|2021-01-04|
|              92397|                    3|                      2|2021-02-01|
+-------------------+---------------------+-----------------------+----------+

I also tried iterating through the dataframe row by row, but the performance was very poor.

Answer

The main idea is to first compute a cumulative sum that yields, for each null row, the number of months to add to the last non-null date. That count can then be passed to the add_months function to replace the null value.
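Before translating this into window functions, the grouping idea can be illustrated in plain Python (a minimal sketch with hypothetical sample data mirroring contract 2 above): each non-null date starts a new group and becomes its anchor, and each null row's fill value is the anchor plus its offset within the group.

```python
from calendar import monthrange
from datetime import date

def add_months(d, n):
    # Minimal stand-in for Spark's add_months: advance n months,
    # clamping the day to the target month's length.
    month_index = d.month - 1 + n
    year, month = d.year + month_index // 12, month_index % 12 + 1
    return date(year, month, min(d.day, monthrange(year, month)[1]))

# Hypothetical sample mirroring contract 2: one real date, then nulls.
dates = [date(2020, 12, 7), None, None, None, None, None]

filled, anchor, offset = [], None, 0
for d in dates:
    if d is not None:        # a non-null date starts a new group
        anchor, offset = d, 0
        filled.append(d)
    else:                    # null row: one more month past the group's anchor
        offset += 1
        filled.append(add_months(anchor, offset))

print(filled[-1])  # → 2021-05-07
```

The two window functions below compute exactly these `anchor` and `offset` values, but over the whole dataframe at once instead of a Python loop.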

from pyspark.sql import Window
import pyspark.sql.functions as f

# Each non-null date starts a new "group"; within a group, the running count
# of nulls gives how many months to add to the group's anchor date.
group_window = Window.partitionBy('account', 'contract').orderBy('contract_parcel')
add_month_window = Window.partitionBy('account', 'contract', 'group').orderBy('contract_parcel')

cumulative_df = (df
                 .withColumn('group', f.sum(f.col('date').isNotNull().cast('int')).over(group_window))
                 .withColumn('add_month', f.sum(f.col('date').isNull().cast('int')).over(add_month_window)))
cumulative_df.show(truncate=False)
+-------+--------+---------------+----------+-----+---------+
|account|contract|contract_parcel|date      |group|add_month|
+-------+--------+---------------+----------+-----+---------+
|92397  |1       |1              |2020-12-07|1    |0        |
|92397  |1       |2              |null      |1    |1        |
|92397  |2       |1              |2020-12-07|1    |0        |
|92397  |2       |2              |null      |1    |1        |
|92397  |2       |3              |null      |1    |2        |
|92397  |2       |4              |null      |1    |3        |
|92397  |2       |5              |null      |1    |4        |
|92397  |2       |6              |null      |1    |5        |
|92397  |3       |1              |2021-01-04|1    |0        |
|92397  |3       |2              |2021-02-01|2    |0        |
+-------+--------+---------------+----------+-----+---------+

replace_df = (cumulative_df
              .withColumn('date', f.first('date').over(add_month_window))  # anchor date of the group
              .withColumn('date', f.expr('add_months(`date`, `add_month`)')))
replace_df.show(truncate=False)
+-------+--------+---------------+----------+-----+---------+
|account|contract|contract_parcel|date      |group|add_month|
+-------+--------+---------------+----------+-----+---------+
|92397  |1       |1              |2020-12-07|1    |0        |
|92397  |1       |2              |2021-01-07|1    |1        |
|92397  |2       |1              |2020-12-07|1    |0        |
|92397  |2       |2              |2021-01-07|1    |1        |
|92397  |2       |3              |2021-02-07|1    |2        |
|92397  |2       |4              |2021-03-07|1    |3        |
|92397  |2       |5              |2021-04-07|1    |4        |
|92397  |2       |6              |2021-05-07|1    |5        |
|92397  |3       |1              |2021-01-04|1    |0        |
|92397  |3       |2              |2021-02-01|2    |0        |
+-------+--------+---------------+----------+-----+---------+

output_df = replace_df.drop('group', 'add_month')
output_df.show(truncate=False)
+-------+--------+---------------+----------+
|account|contract|contract_parcel|date      |
+-------+--------+---------------+----------+
|92397  |1       |1              |2020-12-07|
|92397  |1       |2              |2021-01-07|
|92397  |2       |1              |2020-12-07|
|92397  |2       |2              |2021-01-07|
|92397  |2       |3              |2021-02-07|
|92397  |2       |4              |2021-03-07|
|92397  |2       |5              |2021-04-07|
|92397  |2       |6              |2021-05-07|
|92397  |3       |1              |2021-01-04|
|92397  |3       |2              |2021-02-01|
+-------+--------+---------------+----------+