PySpark DataFrame: find the index of the array column element closest to an integer column's value

I have a PySpark DataFrame with an array column and an integer column. For each row, I want to find the position in the array whose element is closest in value to the integer column. For example:

df = spark.createDataFrame(
[
    (1, [5, 20, 100, 250], 2),  
    (2, [16, 53, 120, 180], 168),
    (3, [100, 200, 1000, 2500], 3500),
],
["id", "array_col", "int_col"]  
)

I want to add a new column containing the index in array_col whose element is closest to the value of int_col, producing a new DataFrame like this:

| id      | array_col              | int_col | closest_index |
| 1       | [5, 20, 100, 250]      | 2       | 0             |
| 2       | [16, 53, 120, 180]     | 168     | 3             |
| 3       | [100, 200, 1000, 2500] | 3500    | 3             |

I’ve tried doing something like this:

def find_nearest(value):
    res = bin_array[np.newaxis, :] - value.values[:, np.newaxis]
    ret_vals = [bin_array[np.argmin(np.abs(i))] for i in res]
    return pd.Series(ret_vals)

From there, I planned to use the array_position function to locate the index, but I couldn't get this to work on a DataFrame. Any help would be much appreciated!

Answer

You could define a UDF that finds the index of the nearest array element and apply it to each row.

Here’s an example:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F


def find_nearest_index(array, value):
    return min(range(len(array)), key=lambda i: abs(array[i] - value))


if __name__ == "__main__":
    spark = SparkSession.builder.master("local").appName("Test").getOrCreate()
    df = spark.createDataFrame(
        [
            (1, [5, 20, 100, 250], 2),
            (2, [16, 53, 120, 180], 168),
            (3, [100, 200, 1000, 2500], 3500),
        ],
        ["id", "array_col", "int_col"],
    )
    # Declare the return type explicitly; without it, F.udf returns strings.
    nearest_index_udf = F.udf(find_nearest_index, "int")
    df = df.withColumn(
        "Nearest Index", nearest_index_udf(F.col("array_col"), F.col("int_col"))
    )
    df.show()

This gives the following result:

+---+--------------------+-------+-------------+
| id|           array_col|int_col|Nearest Index|
+---+--------------------+-------+-------------+
|  1|   [5, 20, 100, 250]|      2|            0|
|  2|  [16, 53, 120, 180]|    168|            3|
|  3|[100, 200, 1000, ...|   3500|            3|
+---+--------------------+-------+-------------+
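
One thing to watch for: Spark passes SQL NULLs to a Python UDF as None, and min() raises ValueError on an empty sequence, so the helper above would fail on rows with a null or empty array. Below is a minimal sketch of a more defensive variant in plain Python. The name find_nearest_index_safe is hypothetical, and returning None for such rows is an assumption about the behaviour you want. It also shows that ties resolve to the lowest index, because min() returns the first minimal item it finds.

```python
from typing import Optional, Sequence


def find_nearest_index_safe(
    array: Optional[Sequence[float]], value: Optional[float]
) -> Optional[int]:
    """Return the 0-based index of the element closest to value, or None."""
    # Guard against a null value, a null array, and an empty array.
    # (Null elements inside the array are not handled in this sketch.)
    if value is None or not array:
        return None
    return min(range(len(array)), key=lambda i: abs(array[i] - value))


# The three sample rows from the question.
rows = [
    ([5, 20, 100, 250], 2),
    ([16, 53, 120, 180], 168),
    ([100, 200, 1000, 2500], 3500),
]
results = [find_nearest_index_safe(a, v) for a, v in rows]
print(results)  # [0, 3, 3]

# 15 is equally close to 10 and 20; the first (lowest) index wins.
print(find_nearest_index_safe([10, 20], 15))  # 0
```

If that behaviour suits your data, the function can be wrapped with F.udf in the same way as find_nearest_index above; a returned None becomes a null in the resulting column.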