How to convert all int dtypes to double simultaneously on PySpark

Time:10-29

Here's my dataset:

DataFrame[column1: double, column2: double, column3: int, column4: int, column5: int, ... , column300: int]

What I want is:

DataFrame[column1: double, column2: double, column3: double, column4: double, column5: double, ... , column300: double]

What I tried, casting one column at a time:

dataset.withColumn("column3", dataset.column3.cast(DoubleType()))

Repeating this for every int column is too manual. How can I cast all of them at once?

Answer:

First, pick out the integer-typed columns from the DataFrame's schema.

Then use functools.reduce to iterate over those column names, calling withColumn on each one to cast it to the type of your choice.

reduce is a general-purpose way to apply the same transformation repeatedly to a DataFrame, so the same pattern works for many other iterative use cases in Spark.
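To see what reduce actually does here, a minimal pure-Python sketch: reduce(f, items, initial) calls f(acc, item) for each item and feeds each result back in as the next acc. The class `FakeFrame` and its method `with_column_type` are hypothetical stand-ins, not Spark APIs; like withColumn, the method returns a new object instead of modifying the existing one, which is why the result of each step has to be threaded into the next.

```python
from functools import reduce


class FakeFrame:
    """Hypothetical stand-in for a Spark DataFrame: an immutable name -> type map."""

    def __init__(self, schema):
        self.schema = dict(schema)

    def with_column_type(self, name, new_type):
        # Like DataFrame.withColumn, return a *new* object rather than mutating self
        updated = dict(self.schema)
        updated[name] = new_type
        return FakeFrame(updated)


frame = FakeFrame({"id": "string", "col1": "bigint", "col2": "bigint", "col3": "bigint"})
to_cast = ["col1", "col2", "col3"]

# reduce threads the accumulator through each step:
# frame -> step(frame, "col1") -> step(that, "col2") -> step(that, "col3")
result = reduce(lambda acc, c: acc.with_column_type(c, "double"), to_cast, frame)

print(result.schema)
# {'id': 'string', 'col1': 'double', 'col2': 'double', 'col3': 'double'}
print(frame.schema["col1"])  # the original is untouched: bigint
```

The same fold with a real DataFrame is simply the lambda calling df.withColumn(c, F.col(c).cast(DoubleType())) at each step.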

Data Preparation

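import pandas as pd
from pyspark.sql import SparkSession

# The answer calls its SparkSession `sql`
sql = SparkSession.builder.getOrCreate()

df = pd.DataFrame({
    'id': [f'id{i}' for i in range(0, 10)],
    'col1': [i for i in range(80, 90)],
    'col2': [i for i in range(5, 15)],
    'col3': [6, 7, 5, 3, 4, 2, 9, 12, 4, 10]
})

sparkDF = sql.createDataFrame(df)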

sparkDF.printSchema()

root
 |-- id: string (nullable = true)
 |-- col1: long (nullable = true)
 |-- col2: long (nullable = true)
 |-- col3: long (nullable = true)

Identification

sparkDF.dtypes

## [('id', 'string'), ('col1', 'bigint'), ('col2', 'bigint'), ('col3', 'bigint')]

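# 'bigint' matches this example's long columns; 'int' covers IntegerType columns like the question's
long_double_list = [col for col, dtyp in sparkDF.dtypes if dtyp in ('int', 'bigint')]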

long_double_list

## ['col1', 'col2', 'col3']
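The sample data here ends up as bigint because pandas int64 columns are inferred as Spark LongType, while the question's columns are int (IntegerType), so a filter on 'bigint' alone would skip them. A minimal sketch, assuming you want every integral type: `integral_columns` and `INTEGRAL_TYPES` are hypothetical names, and the helper works on any list of (name, type_string) pairs shaped like sparkDF.dtypes, so it can be tried without a Spark session.

```python
# Spark's type-string names for its integral types (ByteType, ShortType, IntegerType, LongType)
INTEGRAL_TYPES = {"tinyint", "smallint", "int", "bigint"}


def integral_columns(dtypes):
    """Return the names of integer-typed columns from a DataFrame.dtypes-style list."""
    return [name for name, dtype in dtypes if dtype in INTEGRAL_TYPES]


# Same shape as sparkDF.dtypes above, plus an 'int' column as in the question
dtypes = [("id", "string"), ("col1", "bigint"), ("col2", "bigint"),
          ("col3", "bigint"), ("column3", "int"), ("score", "double")]
print(integral_columns(dtypes))  # ['col1', 'col2', 'col3', 'column3']
```

With a real DataFrame the call would be integral_columns(sparkDF.dtypes), and its result can replace long_double_list in the reduce step.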

Reduce

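from functools import reduce
import pyspark.sql.functions as F
from pyspark.sql.types import DoubleType

# Fold over the column names, casting one column per step
sparkDF = reduce(lambda df, c: df.withColumn(c, F.col(c).cast(DoubleType())),
                 long_double_list,
                 sparkDF)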

sparkDF.printSchema()

root
 |-- id: string (nullable = true)
 |-- col1: double (nullable = true)
 |-- col2: double (nullable = true)
 |-- col3: double (nullable = true)

Answer:

Alternatively, use a list comprehension to build the full list of output columns, casting only the int ones, and pass it to a single select.

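import pyspark.sql.functions as F
...
# Cast every 'int' column to double and keep all other columns unchanged;
# alias() keeps the original column names in the result
cols = [F.col(name).cast('double').alias(name) if dtype == 'int' else F.col(name)
        for name, dtype in df.dtypes]
df = df.select(cols)
df.printSchema()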
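The same single-pass projection can also be written as SQL expression strings for DataFrame.selectExpr. A minimal sketch, assuming that style is preferred: `cast_exprs` and `quote_ident` are hypothetical helpers that only build the strings, so they run without Spark; the result would be used as df.selectExpr(*cast_exprs(df.dtypes)). Names are wrapped in backticks, with any backtick inside a name doubled, so columns containing spaces or dots still parse.

```python
def quote_ident(name):
    """Quote a column name for Spark SQL; backticks inside the name are doubled."""
    return "`" + name.replace("`", "``") + "`"


def cast_exprs(dtypes, from_types=("int",), to_type="DOUBLE"):
    """Build selectExpr strings that cast matching columns and pass the rest through."""
    exprs = []
    for name, dtype in dtypes:
        q = quote_ident(name)
        if dtype in from_types:
            # Alias back to the original name so the schema keeps its column names
            exprs.append(f"CAST({q} AS {to_type}) AS {q}")
        else:
            exprs.append(q)
    return exprs


print(cast_exprs([("column1", "double"), ("column3", "int")]))
# ['`column1`', 'CAST(`column3` AS DOUBLE) AS `column3`']
```

Compared with chaining withColumn, both this and the select version above produce one projection for all columns, which keeps the query plan small when there are hundreds of columns to cast.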