Companies can select sections of a road. Each section is denoted by a start and an end.
PySpark DataFrame below:
+--------------------+----------+--------+
|Road company        |start(km) |end(km) |
+--------------------+----------+--------+
|classA              |1         |3       |
|classA              |4         |7       |
|classA              |10        |15      |
|classA              |16        |20      |
|classB              |1         |3       |
|classB              |4         |7       |
|classB              |10        |15      |
+--------------------+----------+--------+
The classB company picks its sections of the road first. ClassA entries must not overlap with classB: classA companies cannot select any part of the road that has already been chosen by classB. The result should be as below:
+--------------------+----------+--------+
|Road company        |start(km) |end(km) |
+--------------------+----------+--------+
|classA              |16        |20      |
|classB              |1         |3       |
|classB              |4         |7       |
|classB              |10        |15      |
+--------------------+----------+--------+
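The rule above boils down to an interval-overlap test. Here is a minimal plain-Python sketch (not Spark code) of that logic, assuming inclusive kilometer ends and that classB always wins; the helper names `overlaps` and `resolve` are hypothetical.

```python
# Plain-Python illustration, not PySpark. Sections are (company, start_km, end_km)
# tuples with inclusive ends (an assumption based on the sample data).

def overlaps(a_start, a_end, b_start, b_end):
    """True if the inclusive intervals [a_start, a_end] and [b_start, b_end] share a km."""
    return a_start <= b_end and b_start <= a_end


def resolve(sections, first="classB", second="classA"):
    """Keep every `first` section; keep a `second` section only if it overlaps none of them."""
    first_rows = [r for r in sections if r[0] == first]
    second_rows = [
        r for r in sections
        if r[0] == second
        and not any(overlaps(r[1], r[2], s, e) for _, s, e in first_rows)
    ]
    return second_rows + first_rows


rows = [
    ("classA", 1, 3), ("classA", 4, 7), ("classA", 10, 15), ("classA", 16, 20),
    ("classB", 1, 3), ("classB", 4, 7), ("classB", 10, 15),
]
print(resolve(rows))
```

In PySpark the same predicate could serve as the condition of a `left_anti` join of the classA rows against the classB rows, which drops any classA section that touches a classB section.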
The distinct() function does not support splitting the DataFrame into several parts and applying the distinct operation within each part. How can I implement this?
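What is being asked for resembles a "distinct per (start, end), keeping the highest-priority company". A minimal plain-Python sketch of that idea follows; it assumes competing sections match exactly (partial overlaps are not detected), and the `PRIORITY` mapping and `keep_highest_priority` name are hypothetical.

```python
# Plain-Python illustration, not PySpark.
# Lower number = picks first (hypothetical mapping: classB before classA).
PRIORITY = {"classB": 1, "classA": 2}


def keep_highest_priority(sections, priority=PRIORITY):
    """For each exact (start, end) pair keep only the company with the best priority."""
    best = {}
    for company, start, end in sections:
        key = (start, end)
        if key not in best or priority[company] < priority[best[key]]:
            best[key] = company
    # Sort for a stable, readable result.
    return sorted((c, s, e) for (s, e), c in best.items())


rows = [
    ("classA", 1, 3), ("classA", 4, 7), ("classA", 10, 15), ("classA", 16, 20),
    ("classB", 1, 3), ("classB", 4, 7), ("classB", 10, 15),
]
print(keep_highest_priority(rows))
```

In PySpark this pattern usually maps to `row_number()` over a `Window.partitionBy("start(km)", "end(km)")` ordered by the priority expression, then keeping only rows where the row number is 1.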
Answer:
If a section of road can be partially allocated, here's a different (very similar) strategy:
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, explode, lit, sequence, when
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

start = "start(km)"
end = "end(km)"

def emptyDFr():
    schema = StructType([
        StructField(start, IntegerType(), True),
        StructField(end, IntegerType(), True),
        StructField("Road company", StringType(), True),
        StructField("ranged", IntegerType(), True)
    ])
    return spark.createDataFrame(sc.emptyRDD(), schema)

def dummyData():
    return sc.parallelize([["classA",1,3],["classA",4,7],["classA",8,15],["classA",16,20],["classB",1,3],["classB",4,7],["classB",8,17]]).toDF(['Road company','start(km)','end(km)'])

df = dummyData()
df.cache()

# Rank companies by priority (classB picks first). distinct() does not
# preserve order, so sort after de-duplicating.
df_ordered = df.select("Road company").distinct().orderBy(
    when(col("Road company") == "classB", 1)
    .when(col("Road company") == "classA", 2)
    .when(col("Road company") == "classC", 3)
)

# Create one row per kilometer covered by each section ('start' to 'end', inclusive).
ranged = df.withColumn("range", explode(sequence(col(start), col(end))))
whatsLeft = ranged.select(col("range")).distinct()
result = emptyDFr()

# Only use collect() on small, countable sets of data.
for company in df_ordered.collect():
    # Kilometers this company asked for that no higher-priority company has taken.
    taken = ranged.where(col("Road company") == lit(company[0])) \
        .join(whatsLeft, ["range"])
    whatsLeft = whatsLeft.subtract(taken.select(col("range")))
    # union() matches columns by position, so select them in the schema's order.
    result = result.union(
        taken.select(col(start), col(end), col("Road company"), col("range").alias("ranged"))
    )

# Convert the result back to the original style of records with starts and ends,
# then check whether each section got every kilometer it asked for.
(result.groupBy(start, end, "Road company").agg(count("ranged").alias("count"))
    .withColumn("Partial", ((col(end) + lit(1)) - col(start)) != col("count"))
    .withColumn("Maths", (col(end) + lit(1)) - col(start))  # only shows why this works; not required
    .show())
```
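The kilometer-level strategy above can be traced in plain Python (not Spark code) to see what the "Partial" column reports. This is a minimal sketch assuming inclusive ends, as with `sequence(start, end)`; `allocate_km` and its tuple layout are hypothetical. Like the inner join in the Spark version, a section that receives zero kilometers does not appear in the report.

```python
from collections import Counter


def allocate_km(sections, order):
    """Let companies claim kilometers in priority order; report what each section received."""
    taken = set()     # kilometers already claimed by a higher-priority company
    got = Counter()   # (company, start, end) -> kilometers actually received
    for company in order:
        claimed_now = set()
        for c, s, e in sections:
            if c != company:
                continue
            for km in range(s, e + 1):  # inclusive, like sequence(start, end)
                if km not in taken:
                    got[(c, s, e)] += 1
                    claimed_now.add(km)
        taken |= claimed_now  # a company's own sections don't block each other
    # (company, start, end, count, partial): partial means fewer km than requested.
    return [(c, s, e, n, n != (e + 1 - s)) for (c, s, e), n in got.items()]


rows = [
    ("classA", 1, 3), ("classA", 4, 7), ("classA", 8, 15), ("classA", 16, 20),
    ("classB", 1, 3), ("classB", 4, 7), ("classB", 8, 17),
]
for row in allocate_km(rows, ["classB", "classA"]):
    print(row)
```

With this data classB keeps everything it asked for, and classA's 16-20 section only gets kilometers 18-20 because classB's 8-17 section already covers 16 and 17.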
Answer:
If you can rely on sections never partially overlapping (two sections either match exactly or don't overlap at all), you can solve this with the logic below. You could likely optimize it to join on "start(km)" alone. Anything more involved than that, such as partial overlaps, gets more complicated.
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, when
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

def emptyDF():
    schema = StructType([
        StructField("start(km)", IntegerType(), True),
        StructField("end(km)", IntegerType(), True),
        StructField("Road company", StringType(), True)
    ])
    return spark.createDataFrame(sc.emptyRDD(), schema)

def dummyData():
    return sc.parallelize([["classA",1,3],["classA",4,7],["classA",8,15],["classA",16,20],["classB",1,3],["classB",4,7],["classB",8,15]]).toDF(['Road company','start(km)','end(km)'])

df = dummyData()
df.cache()

# Rank companies by priority (classB picks first). Sort after distinct(),
# because distinct() does not preserve order.
df_ordered = df.select("Road company").distinct().orderBy(
    when(col("Road company") == "classB", 1)
    .when(col("Road company") == "classA", 2)
    .when(col("Road company") == "classC", 3)
)

# Sections nobody has claimed yet.
whatsLeft = df.select(col("start(km)"), col("end(km)")).distinct()
result = emptyDF()

# Only use collect() on small, countable sets of data.
for company in df_ordered.collect():
    # This company's sections that are still available (join keys come first,
    # so the columns line up with emptyDF()'s schema for union()).
    taken = df.where(col("Road company") == lit(company[0])) \
        .join(whatsLeft, ["start(km)", "end(km)"])
    whatsLeft = whatsLeft.subtract(taken.drop("Road company"))
    result = result.union(taken)

result.show()
```
+---------+-------+------------+
|start(km)|end(km)|Road company|
+---------+-------+------------+
|        1|      3|      classB|
|        4|      7|      classB|
|        8|     15|      classB|
|       16|     20|      classA|
+---------+-------+------------+
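Because this approach depends on sections matching exactly, it may be worth checking the data for partial overlaps first. A minimal plain-Python sketch (not Spark code) of such a check follows; `partial_overlaps` is a hypothetical helper and inclusive ends are assumed. Any pair it returns would slip through the exact-match join, letting both companies keep overlapping road.

```python
def partial_overlaps(sections):
    """Pairs of sections from different companies that overlap without being identical."""
    found = []
    for i, (c1, s1, e1) in enumerate(sections):
        for c2, s2, e2 in sections[i + 1:]:
            if c1 == c2:
                continue  # only conflicts between different companies matter here
            overlapping = s1 <= e2 and s2 <= e1  # inclusive-interval overlap test
            if overlapping and (s1, e1) != (s2, e2):
                found.append(((c1, s1, e1), (c2, s2, e2)))
    return found


exact = [("classA", 1, 3), ("classA", 4, 7), ("classA", 8, 15), ("classA", 16, 20),
         ("classB", 1, 3), ("classB", 4, 7), ("classB", 8, 15)]
partial = [("classA", 1, 3), ("classA", 4, 7), ("classA", 8, 15), ("classA", 16, 20),
           ("classB", 1, 3), ("classB", 4, 7), ("classB", 8, 17)]
print(partial_overlaps(exact))
print(partial_overlaps(partial))
```

The pairwise loop is quadratic, so it is only meant for small, collected samples; the kilometer-based answer handles the partial case directly.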