I'm working with the COVID data set provided by USAFacts. Each column represents a single date, with the header in the format yyyy-MM-dd.
Your approach makes sense, except that if you keep Mondays and Sundays and use their difference to get the number of cases per week, you miss the new cases reported between each Sunday and the following Monday: you are only counting 6 days instead of 7. Instead, compute the difference between each Monday and the previous Monday.
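A quick way to see the missing day is a pure-Python sketch on hypothetical data: 10 new cases every day for three weeks, starting on Monday 2020-01-06 (none of these numbers come from the USAFacts file). The list of cumulative counts plays the role of one row of the data set:

```python
import datetime
from itertools import accumulate

# Hypothetical data: 10 new cases every day for three weeks
start = datetime.date(2020, 1, 6)  # a Monday
daily_new = [10] * 21
cumulative = list(accumulate(daily_new))  # cumulative counts, like the CSV columns
days = [start + datetime.timedelta(days=i) for i in range(len(daily_new))]

mondays = [i for i, d in enumerate(days) if d.weekday() == 0]
sundays = [i for i, d in enumerate(days) if d.weekday() == 6]

# Sunday minus the Monday of the same week: only 6 days of new cases
mon_to_sun = [cumulative[s] - cumulative[m] for m, s in zip(mondays, sundays)]
# Monday minus the previous Monday: a full 7 days of new cases
mon_to_mon = [cumulative[b] - cumulative[a] for a, b in zip(mondays, mondays[1:])]

print(mon_to_sun)  # [60, 60, 60]
print(mon_to_mon)  # [70, 70]
```

Each Sunday-minus-Monday difference captures only 60 of the 70 cases reported that week.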
One way to go about it is to:
- Only keep Mondays (drop the other date columns)
- Compute the differences between consecutive Mondays
- Optionally, explode the dataset (one row per week) so that it is easier to use and keeps a stable schema even when more weeks are added. The first two steps may be enough for you, though.
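The three steps can be sketched in plain Python on a single hypothetical row. The county values and the 5-cases-per-day growth are made up for illustration and are not taken from the data set; the Spark code applies the same logic to column expressions instead of dictionary lookups:

```python
import datetime


def is_monday(date_str):
    # date_str is assumed to be in yyyy-MM-dd format, like the CSV headers
    return datetime.date.fromisoformat(date_str).weekday() == 0


# Hypothetical wide row: 4 id columns followed by one cumulative count per date
start = datetime.date(2020, 1, 5)  # a Sunday
dates = [(start + datetime.timedelta(days=i)).isoformat() for i in range(16)]
id_cols = ["countyFIPS", "County Name", "State", "StateFIPS"]
row = {"countyFIPS": 1001, "County Name": "Example County", "State": "AL", "StateFIPS": 1}
row.update({d: 5 * i for i, d in enumerate(dates)})  # cumulative cases grow by 5 per day

# Step 1: keep only the Monday columns
mondays = [d for d in dates if is_monday(d)]

# Step 2: difference between consecutive Mondays
# (the first Monday is compared with the first available date)
diffs = [row[mondays[0]] - row[dates[0]]] + [
    row[mondays[k]] - row[mondays[k - 1]] for k in range(1, len(mondays))
]

# Step 3: "explode" into one record per week, keyed by its Monday
long_rows = [
    {**{c: row[c] for c in id_cols}, "date": m, "value": v}
    for m, v in zip(mondays, diffs)
]
```

Here mondays is ["2020-01-06", "2020-01-13", "2020-01-20"] and diffs is [5, 35, 35]: the first value only covers the single day since the first date, while the others each cover a full week.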
```python
import datetime

from pyspark.sql import SparkSession
from pyspark.sql import functions as f

spark = SparkSession.builder.getOrCreate()

# A function that tells whether a "yyyy-MM-dd" string is a Monday
def is_monday(date):
    year, month, day = date.split("-")
    return datetime.date(int(year), int(month), int(day)).weekday() == 0

# Reading the data
covid = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("covid_confirmed_usafacts.csv")
)

# The first 4 columns identify the county; all the other columns are dates
cols = covid.columns[0:4]
dates = covid.columns[4:]
mondays = [d for d in dates if is_monday(d)]

# For each Monday, we compute the difference with the previous one.
# For the first Monday, we compute the diff with the first date we have.
# You may remove that first part if you don't need it.
diffs = (
    [f.col(mondays[0]) - f.col(dates[0])]
    + [f.col(mondays[d]) - f.col(mondays[d - 1]) for d in range(1, len(mondays))]
)

# Simply naming each week column after its Monday
named_diffs = [diffs[d].alias(mondays[d]) for d in range(len(mondays))]
result_1 = covid.select(cols + named_diffs)

# Step 3: exploding the dataframe (one row per county and per week).
# Each struct pairs a week's Monday with that week's difference.
structs = [
    f.struct(f.lit(mondays[d]).alias("date"), diffs[d].alias("value"))
    for d in range(len(mondays))
]
result_2 = (
    covid
    .withColumn("s", f.explode(f.array(*structs)))
    .drop(*dates)
    .select(cols + ["s.*"])
)
result_2.printSchema()
```
```text
root
 |-- countyFIPS: integer (nullable = true)
 |-- County Name: string (nullable = true)
 |-- State: string (nullable = true)
 |-- StateFIPS: integer (nullable = true)
 |-- date: string (nullable = false)
 |-- value: integer (nullable = true)
```
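As an illustration of why the long format is easier to use, here is a pure-Python sketch that totals weekly new cases per state over a few hypothetical records shaped like result_2 (the counties and counts are invented). In Spark, the equivalent would presumably be result_2.groupBy("State", "date").agg(f.sum("value")):

```python
from collections import defaultdict

# Hypothetical long-format records with the same fields as result_2
records = [
    {"countyFIPS": 1001, "County Name": "County A", "State": "AL", "StateFIPS": 1, "date": "2020-01-13", "value": 35},
    {"countyFIPS": 1003, "County Name": "County B", "State": "AL", "StateFIPS": 1, "date": "2020-01-13", "value": 12},
    {"countyFIPS": 4001, "County Name": "County C", "State": "AZ", "StateFIPS": 4, "date": "2020-01-13", "value": 7},
    {"countyFIPS": 1001, "County Name": "County A", "State": "AL", "StateFIPS": 1, "date": "2020-01-20", "value": 40},
]

# Sum the weekly differences per (state, week) pair
weekly_by_state = defaultdict(int)
for r in records:
    weekly_by_state[(r["State"], r["date"])] += r["value"]

print(dict(weekly_by_state))
```

Because every week is a row rather than a column, the same grouping keeps working unchanged when new weeks are appended to the file.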