Extract values from an array that sum to a certain value pyspark

I have a DataFrame with an array column of doubles. In each array, either a single element or a combination of elements sums to a target value. I want to extract those elements, i.e. the ones that equal the target on their own or that can be summed to equal it. I'd like to be able to do this in PySpark.

| Array                  | Target    | NewArray         |
| -----------------------|-----------|------------------|
| [0.0001,2.5,3.0,0.0031]| 0.0032    | [0.0001,0.0031]  |
| [2.5,1.0,0.5,3.0]      | 3.0       | [2.5, 0.5, 3.0]  |
| [1.0,1.0,1.5,1.0]      | 4.5       | [1.0,1.0,1.5,1.0]|
arrays extract pyspark sum
2021-11-23 19:39:03

You can encapsulate the logic in a UDF and create NewArray from it. I have borrowed the logic for identifying the array elements that sum to a target value from here.


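from decimal import Decimal

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DoubleType

# Reuse the active session (e.g. in a notebook or shell) or start a new one.
spark = SparkSession.builder.getOrCreate()

# Sample rows, including an empty array and null inputs.
data = [
    ([0.0001, 2.5, 3.0, 0.0031], 0.0032),
    ([2.5, 1.0, 0.5, 3.0], 3.0),
    ([1.0, 1.0, 1.5, 1.0], 4.5),
    ([], 1.0),
    (None, 1.0),
    ([1.0, 2.0], None),
]

df = spark.createDataFrame(data, ("Array", "Target"))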


@udf(returnType=ArrayType(DoubleType()))
def find_values_summing_to_target(array, target):
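    def subset_sum(numbers, target, partial, result):
        s = sum(partial)
        # Record the partial subset when its sum equals the target.
        if s == target:
            result.extend(partial)
        # Stop extending once the sum reaches the target; this pruning
        # assumes the values are non-negative.
        if s >= target:
            return

        # Try each remaining number as the next element of the subset.
        for i in range(len(numbers)):
            n = numbers[i]
            remaining = numbers[i+1:]
            subset_sum(remaining, target, partial + [n], result)

    result = []
    # Return an empty array when either input is null.
    if array is not None and target is not None:
        # Convert via str() so that sums are compared in exact decimal
        # arithmetic instead of binary floating point.
        array = [Decimal(str(a)) for a in array]
        subset_sum(array, Decimal(str(target)), [], result)
        result = [float(r) for r in result]
    return result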

df.withColumn("NewArray", find_values_summing_to_target("Array", "Target")).show(200, False)

Output

+--------------------------+------+--------------------+
|Array                     |Target|NewArray            |
+--------------------------+------+--------------------+
|[1.0E-4, 2.5, 3.0, 0.0031]|0.0032|[1.0E-4, 0.0031]    |
|[2.5, 1.0, 0.5, 3.0]      |3.0   |[2.5, 0.5, 3.0]     |
|[1.0, 1.0, 1.5, 1.0]      |4.5   |[1.0, 1.0, 1.5, 1.0]|
|[]                        |1.0   |[]                  |
|null                      |1.0   |[]                  |
|[1.0, 2.0]                |null  |[]                  |
+--------------------------+------+--------------------+
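The search logic can also be checked without a Spark session. The following is a minimal sketch in plain Python, not the answer's exact code: the function name values_summing_to_target and the helper search are hypothetical names chosen for illustration, and the pruning step assumes non-negative values, as in the UDF above.

```python
from decimal import Decimal


def values_summing_to_target(values, target):
    """Collect the elements of every subset of `values` that sums to `target`."""
    # Mirror the UDF's null handling: a missing input yields an empty list.
    if values is None or target is None:
        return []

    # Go through str() so sums are compared in exact decimal arithmetic.
    numbers = [Decimal(str(v)) for v in values]
    goal = Decimal(str(target))
    result = []

    def search(start, partial, partial_sum):
        if partial and partial_sum == goal:
            result.extend(partial)
        if partial_sum >= goal:
            return  # pruning is only valid for non-negative values
        for i in range(start, len(numbers)):
            search(i + 1, partial + [numbers[i]], partial_sum + numbers[i])

    search(0, [], Decimal(0))
    return [float(r) for r in result]


print(values_summing_to_target([0.0001, 2.5, 3.0, 0.0031], 0.0032))
print(values_summing_to_target([2.5, 1.0, 0.5, 3.0], 3.0))
print(values_summing_to_target([1.0, 1.0, 1.5, 1.0], 4.5))
```

Like the UDF, this explores subsets recursively, so the running time grows exponentially with the array length; it is meant for short arrays such as the ones in the question.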
2021-11-29 17:22:52

Thanks for your help, it's definitely putting me on the right track. However, I'm having trouble at this line: if s >= target: return. With it left in, I get an error: TypeError: '>=' not supported between instances of 'int' and 'NoneType'. When I take it out the code runs, but it does not return all of the values that sum to the target; it only returns a value when a single element equals the target by itself.
Alex Triece

Additionally, the issue could be that the decimals I'm using are much smaller (in the .0031 and .0001 range). I noticed that when I substituted decimals like these into the example data, it returned empty arrays. Any thoughts on that?
Alex Triece

For the first issue, I think you have None values in the Target column. I will update the answer to return an empty array when this happens.
Nithish

You were absolutely right about that first issue. I changed the NAs to 0 and it works fine. However, it still doesn't handle the smaller decimals. I'm OK with 0's in the target column, so no need to spend too much time on that issue, unless you want to for others' sake.
Alex Triece

The code in the answer is now null-safe. For the precision issue I would need an example; I tried smaller values too, down to 6 decimal digits, and it still works. An example would help me replicate the problem.
Nithish

I just changed the example in the question to show what I'm seeing; it's really just the first row. When I plug this in, I get correct results for everything except the top row.
Alex Triece

The problem is due to floating-point precision error: in Python, 0.0001 + 0.0031 is 0.0031999999999999997 (see stackoverflow.com/questions/11950819/python-math-is-wrong/…). I have updated the answer to use decimal arithmetic to support your use case.
Nithish
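The precision issue described in the comment above can be reproduced outside Spark. This is a small sketch using only the standard library; the variable names float_sum and decimal_sum are illustrative. Decimal comes from Python's built-in decimal module and is made available with from decimal import Decimal.

```python
from decimal import Decimal

# Binary floating point cannot represent 0.0001 or 0.0031 exactly, so the
# sum may differ from the literal 0.0032 in its last digits.
float_sum = 0.0001 + 0.0031
print(float_sum)
print(float_sum == 0.0032)

# Building each Decimal from the value's string form keeps the arithmetic
# exact in base 10, so the comparison against the target succeeds.
decimal_sum = Decimal(str(0.0001)) + Decimal(str(0.0031))
print(decimal_sum)
print(decimal_sum == Decimal(str(0.0032)))
```

Passing the float directly, as in Decimal(0.0001), would carry over the binary rounding error, which is why the values go through str() first.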

Thanks, that helps. However, it throws an error with the Decimal() function. Is there something that needs to be imported for that to be recognized?
Alex Triece
