I have a series of values and a probability with which I want each of those values sampled. Is there a PySpark method to sample from that distribution for each row? I know how to hard-code it with a random number generator, but I want the method to be flexible for any number of assignment values and probabilities:
assignment_values = ["foo", "buzz", "boo"]
value_probabilities = [0.3, 0.3, 0.4]
Hard-coded method with random number generator:
from pyspark.sql import Row
from pyspark.sql import functions as F
data = [
    {"person": 1, "company": "5g"},
    {"person": 2, "company": "9s"},
    {"person": 3, "company": "1m"},
    {"person": 4, "company": "3l"},
    {"person": 5, "company": "2k"},
    {"person": 6, "company": "7c"},
    {"person": 7, "company": "3m"},
    {"person": 8, "company": "2p"},
    {"person": 9, "company": "4s"},
    {"person": 10, "company": "8y"},
]
df = spark.createDataFrame(Row(**x) for x in data)
(
    df
    .withColumn("rand", F.rand())
    .withColumn(
        "assignment",
        F.when(F.col("rand") < F.lit(0.3), "foo")
        .when(F.col("rand") < F.lit(0.6), "buzz")
        .otherwise("boo")
    )
    .show()
)
+-------+------+-------------------+----------+
|company|person|               rand|assignment|
+-------+------+-------------------+----------+
|     5g|     1| 0.8020603266148111|       boo|
|     9s|     2| 0.1297179045352752|       foo|
|     1m|     3|0.05170251723736685|       foo|
|     3l|     4|0.07978240998283603|       foo|
|     2k|     5| 0.5931269297050258|      buzz|
|     7c|     6|0.44673560271164037|      buzz|
|     3m|     7| 0.1398027427612647|       foo|
|     2p|     8| 0.8281404801171598|       boo|
|     4s|     9|0.15568513681001817|       foo|
|     8y|    10| 0.6173220502731542|       boo|
+-------+------+-------------------+----------+
CodePudding user response:
I think randomSplit may serve you. It randomly splits your dataframe into several dataframes according to the given weights (normalized if they don't sum to 1) and returns them in a list:
df.randomSplit([0.3, 0.3, 0.4])
You can also provide a seed to it. After tagging each split with its assignment value, you can union the dataframes back together using reduce:
from pyspark.sql import functions as F
from functools import reduce
df = spark.createDataFrame(
    [(1, "5g"),
     (2, "9s"),
     (3, "1m"),
     (4, "3l"),
     (5, "2k"),
     (6, "7c"),
     (7, "3m"),
     (8, "2p"),
     (9, "4s"),
     (10, "8y")],
    ['person', 'company'])
assignment_values = ["foo", "buzz", "boo"]
value_probabilities = [0.3, 0.3, 0.4]
dfs = df.randomSplit(value_probabilities, 5)
dfs = [df.withColumn('assignment', F.lit(assignment_values[i])) for i, df in enumerate(dfs)]
df = reduce(lambda a, b: a.union(b), dfs)
df.show()
# +------+-------+----------+
# |person|company|assignment|
# +------+-------+----------+
# |     1|     5g|       foo|
# |     2|     9s|       foo|
# |     6|     7c|       foo|
# |     4|     3l|      buzz|
# |     5|     2k|      buzz|
# |     8|     2p|      buzz|
# |     3|     1m|       boo|
# |     7|     3m|       boo|
# |     9|     4s|       boo|
# |    10|     8y|       boo|
# +------+-------+----------+
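Note that randomSplit only approximates the weights across the whole dataframe. If you want an independent draw per row (which is what the question's hard-coded when chain does), here is a minimal sketch that builds that chain dynamically from cumulative probabilities; assign_from_distribution and the __rand column name are hypothetical, not part of the PySpark API:

from itertools import accumulate
from pyspark.sql import functions as F

def assign_from_distribution(df, values, probabilities, seed=None):
    # Cumulative upper bounds, e.g. [0.3, 0.3, 0.4] -> [0.3, 0.6, 1.0].
    bounds = list(accumulate(probabilities))
    # Materialize the draw in a column first, as the question does, so the
    # non-deterministic rand() is evaluated only once per row.
    df = df.withColumn("__rand", F.rand(seed))
    expr = F.when(F.col("__rand") < bounds[0], values[0])
    for value, bound in zip(values[1:-1], bounds[1:-1]):
        expr = expr.when(F.col("__rand") < bound, value)
    # The last value is the fallback, so floating-point error in the
    # cumulative sum can never leave a row unassigned.
    return df.withColumn("assignment", expr.otherwise(values[-1])).drop("__rand")

df = assign_from_distribution(df, assignment_values, value_probabilities, seed=5)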