AttributeError: 'DataFrame' object has no attribute 'pivot'

Time:09-22

I have a PySpark DataFrame:

user_id  item_id  last_watch_dt  total_dur  watched_pct
1        1        2021-05-11     4250       72
1        2        2021-05-11     80         99
2        3        2021-05-11     1000       80
2        4        2021-05-11     5000       40

I used this code:

df_new = df.pivot(index='user_id', columns='item_id', values='watched_pct')

To get this:

user_id  1   2   3   4
1        72  99  0   0
2        0   0   80  40

But I got an error:

AttributeError: 'DataFrame' object has no attribute 'pivot'

What did I do wrong?

Answer:

You can only call .pivot on objects that have a pivot attribute (a method or property). You called df.pivot, so it would only work if df had such an attribute. df is an instance of the pyspark.sql.DataFrame class, and the pyspark.sql.DataFrame API reference lists all of its attributes: there are many, but none of them is called pivot. That is why you get an AttributeError. (The keyword arguments you used, index=, columns= and values=, match pandas.DataFrame.pivot, which is a different library's API.)

pivot is a method of the pyspark.sql.GroupedData class. To use it, you first need to create a GroupedData object from your DataFrame. In your case, you do that with .groupBy():

df.groupBy("user_id").pivot("item_id")
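The attribute lookup described above can be reproduced without Spark. The sketch below uses two hypothetical stand-in classes, ToyDataFrame and ToyGroupedData (they are not PySpark's classes and do no real work), that only mimic which object exposes which method: the frame has groupBy but no pivot, and pivot lives on the grouped object. With a real pyspark.sql.DataFrame, hasattr(df, "pivot") should likewise return False.

```python
class ToyGroupedData:
    """Hypothetical stand-in for pyspark.sql.GroupedData (owns pivot)."""

    def __init__(self, group_cols, pivot_col=None):
        self.group_cols = group_cols
        self.pivot_col = pivot_col

    def pivot(self, pivot_col):
        # Mirrors GroupedData.pivot: the result is still a grouped object,
        # not a frame, so an aggregation step is still needed afterwards.
        return ToyGroupedData(self.group_cols, pivot_col)


class ToyDataFrame:
    """Hypothetical stand-in for pyspark.sql.DataFrame (no pivot method)."""

    def groupBy(self, *cols):
        # Mirrors DataFrame.groupBy: returns a grouped object.
        return ToyGroupedData(list(cols))


df = ToyDataFrame()
print(hasattr(df, "pivot"))                     # False
print(hasattr(df.groupBy("user_id"), "pivot"))  # True

try:
    # The lookup of df.pivot fails before any arguments are evaluated.
    df.pivot(index="user_id", columns="item_id", values="watched_pct")
except AttributeError as err:
    message = str(err)
print(message)  # 'ToyDataFrame' object has no attribute 'pivot'

grouped = df.groupBy("user_id").pivot("item_id")
print(type(grouped).__name__, grouped.group_cols, grouped.pivot_col)
# ToyGroupedData ['user_id'] item_id
```

Because each step returns a different type, the method you can call next depends on what the previous call returned, which is why pivot only becomes available after groupBy.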

pivot returns yet another GroupedData object. To turn it back into a DataFrame, call one of the GroupedData aggregation methods; agg is the one you need. Inside it, pass a Spark aggregation function that will be applied to the values falling into each (user_id, item_id) cell, e.g. F.sum or F.first, where F is pyspark.sql.functions:

df.groupBy("user_id").pivot("item_id").agg(F.sum("watched_pct"))
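To make the semantics of groupBy(...).pivot(...).agg(...) concrete, here is a plain-Python model of what the call computes on the question's data. pivot_agg is a hypothetical helper written for illustration, not a Spark API: it groups rows by one column, turns the distinct values of another column into output columns (sorted, matching the column order in the output shown in this answer), applies the aggregation function to the values in each cell, and leaves None where a group has no rows for a column (Spark shows these as null). Spark distributes this work across the cluster; the model only reproduces the result.

```python
from collections import defaultdict

# The question's rows: (user_id, item_id, last_watch_dt, total_dur, watched_pct)
ROWS = [
    (1, 1, "2021-05-11", 4250, 72),
    (1, 2, "2021-05-11", 80, 99),
    (2, 3, "2021-05-11", 1000, 80),
    (2, 4, "2021-05-11", 5000, 40),
]


def pivot_agg(rows, group_idx, pivot_idx, value_idx, agg=sum):
    """Hypothetical plain-Python model of
    groupBy(group).pivot(pivot).agg(agg(value)).

    Returns (columns, table): columns are the sorted distinct pivot values,
    and table maps each group key to one aggregated value per column,
    with None where the group has no rows for that column.
    """
    columns = sorted({row[pivot_idx] for row in rows})
    # cells[group_key][pivot_value] -> list of raw values in that cell
    cells = defaultdict(lambda: defaultdict(list))
    for row in rows:
        cells[row[group_idx]][row[pivot_idx]].append(row[value_idx])
    table = {}
    for key in sorted(cells):
        table[key] = [
            agg(cells[key][col]) if col in cells[key] else None
            for col in columns
        ]
    return columns, table


columns, table = pivot_agg(ROWS, group_idx=0, pivot_idx=1, value_idx=4)
print(columns)  # [1, 2, 3, 4]
for user_id, values in table.items():
    print(user_id, values)
# 1 [72, 99, None, None]
# 2 [None, None, 80, 40]

# A second watch of item 1 by user 1: sum adds both values, while a
# "first"-style rule keeps only the first one encountered.
extra = ROWS + [(1, 1, "2021-05-12", 100, 10)]
print(pivot_agg(extra, 0, 1, 4)[1][1])                            # [82, 99, None, None]
print(pivot_agg(extra, 0, 1, 4, agg=lambda vals: vals[0])[1][1])  # [72, 99, None, None]
```

The last two lines show why the choice of aggregation function matters once a cell can receive more than one row. Note that Spark's F.first depends on row order, which is not guaranteed after a shuffle, so it is best reserved for cells known to hold a single value.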

Full example:

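from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Reuse the active session (e.g. in the pyspark shell) or create a new one.
spark = SparkSession.builder.getOrCreate()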

df = spark.createDataFrame(
    [(1, 1, '2021-05-11', 4250, 72),
     (1, 2, '2021-05-11', 80, 99),
     (2, 3, '2021-05-11', 1000, 80),
     (2, 4, '2021-05-11', 5000, 40)],
    ['user_id', 'item_id', 'last_watch_dt', 'total_dur', 'watched_pct'])

df = df.groupBy("user_id").pivot("item_id").agg(F.sum("watched_pct"))
df.show()
# +-------+----+----+----+----+
# |user_id|   1|   2|   3|   4|
# +-------+----+----+----+----+
# |      1|  72|  99|null|null|
# |      2|null|null|  80|  40|
# +-------+----+----+----+----+

If you want to replace the nulls with 0, use the fillna method of the pyspark.sql.DataFrame class:

df = df.fillna(0)
df.show()
# +-------+---+---+---+---+
# |user_id|  1|  2|  3|  4|
# +-------+---+---+---+---+
# |      1| 72| 99|  0|  0|
# |      2|  0|  0| 80| 40|
# +-------+---+---+---+---+