I have a PySpark dataframe:
| user_id | item_id | last_watch_dt | total_dur | watched_pct |
| --- | --- | --- | --- | --- |
| 1 | 1 | 2021-05-11 | 4250 | 72 |
| 1 | 2 | 2021-05-11 | 80 | 99 |
| 2 | 3 | 2021-05-11 | 1000 | 80 |
| 2 | 4 | 2021-05-11 | 5000 | 40 |
I used this code:
df_new = df.pivot(index='user_id', columns='item_id', values='watched_pct')
To get this:
| user_id | 1 | 2 | 3 | 4 |
| --- | --- | --- | --- | --- |
| 1 | 72 | 99 | 0 | 0 |
| 2 | 0 | 0 | 80 | 40 |
But I got an error:
AttributeError: 'DataFrame' object has no attribute 'pivot'
What did I do wrong?
CodePudding user response:
You can only call `.pivot` on objects that have a `pivot` attribute (method or property). You tried `df.pivot`, so it would only work if `df` had such an attribute. You can inspect all the attributes of `df` (it's an object of the `pyspark.sql.DataFrame` class) in the PySpark API documentation. You will see many attributes there, but none of them is called `pivot`. That's why you get an `AttributeError`.
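As a quick sanity check (a minimal sketch; `df` is the dataframe from your question), you can confirm this from a REPL:

# DataFrame exposes no attribute called pivot:
print(hasattr(df, "pivot"))                            # False
# and nothing similar shows up when listing its attributes:
print([a for a in dir(df) if "pivot" in a.lower()])    # []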
`pivot` is a method of the `pyspark.sql.GroupedData` class. This means that, in order to use it, you must first create a `pyspark.sql.GroupedData` object from your `pyspark.sql.DataFrame` object. In your case, that's done with `.groupBy()`:
df.groupBy("user_id").pivot("item_id")
This creates yet another `pyspark.sql.GroupedData` object. To turn it into a dataframe, you need one of the methods of the `GroupedData` class, and `agg` is the one you want. Inside it, you provide the Spark aggregation function that will be applied to all the grouped elements (e.g. `sum`, `first`, etc.).
df.groupBy("user_id").pivot("item_id").agg(F.sum("watched_pct"))
Full example:
from pyspark.sql import functions as F

# Recreate the input dataframe from the question:
df = spark.createDataFrame(
    [(1, 1, '2021-05-11', 4250, 72),
     (1, 2, '2021-05-11', 80, 99),
     (2, 3, '2021-05-11', 1000, 80),
     (2, 4, '2021-05-11', 5000, 40)],
    ['user_id', 'item_id', 'last_watch_dt', 'total_dur', 'watched_pct'])

# Pivot item_id into columns, aggregating watched_pct per user:
df = df.groupBy("user_id").pivot("item_id").agg(F.sum("watched_pct"))
df.show()
# +-------+----+----+----+----+
# |user_id|   1|   2|   3|   4|
# +-------+----+----+----+----+
# |      1|  72|  99|null|null|
# |      2|null|null|  80|  40|
# +-------+----+----+----+----+
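As a side note, `pivot` also accepts an explicit list of pivot values; passing it spares Spark an extra pass over the data to discover the distinct item ids (a sketch, assuming the ids are known up front):

# Explicit pivot values avoid computing distinct item_id values first:
df.groupBy("user_id").pivot("item_id", [1, 2, 3, 4]).agg(F.sum("watched_pct"))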
If you want to replace nulls with `0`, use `fillna` of the `pyspark.sql.DataFrame` class.
df = df.fillna(0)
df.show()
# +-------+---+---+---+---+
# |user_id|  1|  2|  3|  4|
# +-------+---+---+---+---+
# |      1| 72| 99|  0|  0|
# |      2|  0|  0| 80| 40|
# +-------+---+---+---+---+
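Note that `fillna(0)` targets every numeric column at once. If you ever need to restrict it, `fillna` also takes a `subset` parameter (a sketch; the column names here are the pivoted item ids, which Spark stores as strings):

# Only fill nulls in the pivoted item columns:
df = df.fillna(0, subset=["1", "2", "3", "4"])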