Choose from multinomial distribution


I have a series of values and a probability with which I want each of those values sampled. Is there a PySpark method to sample from that distribution for each row? I know how to hard-code it with a random number generator, but I want the method to be flexible for any number of assignment values and probabilities:

assignment_values = ["foo", "buzz", "boo"]
value_probabilities = [0.3, 0.3, 0.4]

Hard-coded method with a random number generator:

from pyspark.sql import Row
from pyspark.sql import functions as F

data = [
    {"person": 1, "company": "5g"},
    {"person": 2, "company": "9s"},
    {"person": 3, "company": "1m"},
    {"person": 4, "company": "3l"},
    {"person": 5, "company": "2k"},
    {"person": 6, "company": "7c"},
    {"person": 7, "company": "3m"},
    {"person": 8, "company": "2p"},
    {"person": 9, "company": "4s"},
    {"person": 10, "company": "8y"},
]
df = spark.createDataFrame(Row(**x) for x in data)

(
    df
    .withColumn("rand", F.rand())
    .withColumn(
        "assignment",
        # thresholds are cumulative probabilities: 0.3, then 0.3 + 0.3 = 0.6
        F.when(F.col("rand") < F.lit(0.3), "foo")
        .when(F.col("rand") < F.lit(0.6), "buzz")
        .otherwise("boo")
    )
    .show()
)
+-------+------+-------------------+----------+
|company|person|               rand|assignment|
+-------+------+-------------------+----------+
|     5g|     1| 0.8020603266148111|       boo|
|     9s|     2| 0.1297179045352752|       foo|
|     1m|     3|0.05170251723736685|       foo|
|     3l|     4|0.07978240998283603|       foo|
|     2k|     5| 0.5931269297050258|      buzz|
|     7c|     6|0.44673560271164037|      buzz|
|     3m|     7| 0.1398027427612647|       foo|
|     2p|     8| 0.8281404801171598|       boo|
|     4s|     9|0.15568513681001817|       foo|
|     8y|    10| 0.6173220502731542|       boo|
+-------+------+-------------------+----------+

CodePudding user response:

I think randomSplit may serve you. It randomly splits your DataFrame into several DataFrames and puts them all into a list.

df.randomSplit([0.3, 0.3, 0.4])

You can also provide a seed to it.
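
For example, passing a seed makes the split reproducible across runs:

df.randomSplit([0.3, 0.3, 0.4], seed=5)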

You can then union the DataFrames back together using reduce:

from pyspark.sql import functions as F
from functools import reduce

df = spark.createDataFrame(
    [(1, "5g"),
     (2, "9s"),
     (3, "1m"),
     (4, "3l"),
     (5, "2k"),
     (6, "7c"),
     (7, "3m"),
     (8, "2p"),
     (9, "4s"),
     (10, "8y")],
    ['person', 'company'])

assignment_values = ["foo", "buzz", "boo"]
value_probabilities = [0.3, 0.3, 0.4]

dfs = df.randomSplit(value_probabilities, 5)  # second argument is the seed
dfs = [df.withColumn('assignment', F.lit(assignment_values[i])) for i, df in enumerate(dfs)]
df = reduce(lambda a, b: a.union(b), dfs)

df.show()
# +------+-------+----------+
# |person|company|assignment|
# +------+-------+----------+
# |     1|     5g|       foo|
# |     2|     9s|       foo|
# |     6|     7c|       foo|
# |     4|     3l|      buzz|
# |     5|     2k|      buzz|
# |     8|     2p|      buzz|
# |     3|     1m|       boo|
# |     7|     3m|       boo|
# |     9|     4s|       boo|
# |    10|     8y|       boo|
# +------+-------+----------+
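
An alternative that keeps everything in a single DataFrame is to generalize the when-chain from the question to any number of values by building cumulative thresholds from the probability list. A minimal sketch, assuming the same two lists and the original pre-split DataFrame, here called original_df since df was reassigned above:

from itertools import accumulate
from pyspark.sql import functions as F

# cumulative upper bounds for each value, roughly [0.3, 0.6, 1.0]
cum_probs = list(accumulate(value_probabilities))

# build F.when(rand < 0.3, "foo").when(rand < 0.6, "buzz").when(rand < 1.0, "boo")
assignment_col = F.when(F.col("rand") < cum_probs[0], assignment_values[0])
for value, threshold in zip(assignment_values[1:], cum_probs[1:]):
    assignment_col = assignment_col.when(F.col("rand") < threshold, value)
# fallback in case rand lands exactly on the final bound
assignment_col = assignment_col.otherwise(assignment_values[-1])

(
    original_df
    .withColumn("rand", F.rand(seed=5))
    .withColumn("assignment", assignment_col)
    .show()
)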