Build a hierarchy table (id-to-parent relation) in PySpark (without GraphFrames)


My input table looks like this:

[Image: input table with columns base_id and parent_id]

From this table, I need to take each base_id and find its parent hierarchy. My output table should look like this:

[Image: expected output table]

Here, for each base_id, the parent (parent_1), grandparent (parent_2), and so on up to parent_n have to be calculated. Note that different base_ids can have different hierarchy depths, and the maximum depth is unknown.
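To pin down what the output means, here is a small plain-Python sketch (not Spark code) that walks the parent pointers of the sample rows used in the script and emits one (base_id, parent_id, relation_type) row per ancestor. It assumes a long output format (one row per ancestor) and an acyclic hierarchy; the names EDGES and hierarchy are illustrative only.

```python
# Sample rows from the question: (base_id, parent_id); None marks a root.
EDGES = [("479", None), ("977", "666"), ("666", "479"),
         ("555", "678"), ("678", "977")]


def hierarchy(edges):
    """Return (base_id, ancestor_id, "parent_k") rows for every base_id.

    Follows parent pointers one step at a time: parent_1 is the direct
    parent, parent_2 the grandparent, and so on. Assumes no cycles.
    """
    parent = dict(edges)
    rows = []
    for base_id in parent:
        level, current = 1, parent[base_id]
        while current is not None:
            rows.append((base_id, current, f"parent_{level}"))
            # An id that never appears as a base_id is treated as a root.
            current = parent.get(current)
            level += 1
    return rows


for row in hierarchy(EDGES):
    print(row)
```

This is the same per-id walk as the loop in the question. It is fine on a single machine, but in Spark it is exactly the part that needs to become set-based joins.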

I've written a script for this, attached below. It works, but the problem is that I'm running a for loop over each base_id to find its hierarchy. This is not efficient, since there can be more than 1 million base_ids, and I have no idea how long it would take to run against a million ids.

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType, StructField, StringType
import pyspark.sql.functions as F
from functools import reduce

spark = SparkSession.builder.getOrCreate()

data2 = [('479', None),
         ('977', '666'),
         ('666', '479'),
         ('555', '678'),
         ('678', '977')]

schema = StructType([
    StructField("base_id", StringType(), True),
    StructField("parent_id", StringType(), True)
])

base_df = spark.createDataFrame(data=data2, schema=schema)


def id_parent_relation():
    # Every id is related to itself
    itself_df = base_df.selectExpr(
        "base_id",
        "base_id as parent_id",
        "'itself' as relationship_type"
    )

    # Only rows that actually have a parent
    related_df = base_df.filter(F.col("parent_id").isNotNull())

    df_list = [itself_df]

    # Collects every id to the driver: this per-id loop is the bottleneck
    id_list = related_df.select("base_id").rdd.flatMap(lambda x: x).collect()

    for base_id in id_list:
        checker = True
        iteration = 1
        new_df = related_df.filter(F.col("base_id") == base_id)
        while checker:
            if iteration == 1:
                # Direct parent
                new_df = new_df.withColumn("relationship_type", F.lit("parent_1"))
            else:
                # Rename so the self-join below has unambiguous column names
                new_df = (
                    new_df
                    .withColumnRenamed("base_id", "base_id_new")
                    .withColumnRenamed("parent_id", "parent_id_new")
                )
                # Walk one level up: join the current ancestor to its own parent
                new_df = (
                    new_df
                    .join(related_df,
                          on=(F.col("parent_id_new") == F.col("base_id")),
                          how="inner")
                    .select(F.col("base_id_new").alias("base_id"), "parent_id")
                    .withColumn("relationship_type", F.lit("parent_{}".format(iteration)))
                )
            df_list.append(new_df)

            # Stop once the current ancestor has no parent
            if new_df.filter(F.col("parent_id").isNotNull()).rdd.isEmpty():
                checker = False
            else:
                iteration += 1

    union_df = reduce(DataFrame.unionAll, df_list)
    return union_df


if __name__ == "__main__":
    id_parent_relation().show()

I was hoping to somehow avoid the loops so that the script can run on a million records.

Any help would be appreciated. Thanks in advance!

I am no PySpark expert, so please feel free to critique my suggestion.
The join part should be fine, but I'm not sure how the stacking step will perform with a large number of ids.

import pyspark.sql.functions as F

# Data from original post
df0 = base_df.alias("df0")

# Keep track of the number of joins
i = 1
# Join as many times as necessary: stop once the newest parent column is all null
while df0.filter(~F.isnull(F.col("df" + str(i - 1) + ".parent_id"))).count() > 0:
    # Create a new alias for this level
    df = base_df.alias("df" + str(i))
    # Left join: attach the parent of the previous level
    df0 = df0.join(df, F.col("df" + str(i - 1) + ".parent_id") == F.col("df" + str(i) + ".base_id"), "left")
    i += 1

# Delete singletons (479 in the example)
df0 = df0.where(~F.isnull(F.col("df0.parent_id")))

# Create stack expression: one (base_id, ancestor, relation) triple per level
stack_expr = ""
for level in range(1, i):
    stack_expr += f" df0.base_id, df{level}.base_id, 'parent_{level}',"

# Remove last ,
stack_expr = stack_expr[:-1]

# Stack to display columns as rows; there are i - 1 triples, so i - 1 rows
df0.selectExpr("stack(" + str(i - 1) + ", " + stack_expr + ") as (base_id, parent_id, relation_type)") \
    .where("parent_id is not null") \
    .show()
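The join-then-stack idea can be checked on the sample data with a plain-Python stand-in for the DataFrames (a sketch, not Spark): each pass of the while loop plays the role of one left self-join, and the final comprehension plays the role of stack. The function name level_by_level is hypothetical.

```python
# Sample rows from the question: (base_id, parent_id); None marks a root.
EDGES = [("479", None), ("977", "666"), ("666", "479"),
         ("555", "678"), ("678", "977")]


def level_by_level(edges):
    """Plain-Python stand-in for the join-then-stack answer (not Spark)."""
    parent = dict(edges)
    # One "wide" row per base_id: [base_id, parent_1], like df0 before any join.
    wide = [[base_id, parent_id] for base_id, parent_id in edges]
    # Like the while-condition: keep joining while the newest column has an id.
    while any(row[-1] is not None for row in wide):
        for row in wide:
            # Left join: a null or unknown id yields None in the new column.
            row.append(parent.get(row[-1]))
    # Drop roots (479 here), like the answer's "delete singletons" filter.
    wide = [row for row in wide if row[1] is not None]
    # Stack: turn column k (k >= 1) into a (base_id, ancestor, parent_k) row.
    return [(row[0], row[k], f"parent_{k}")
            for row in wide
            for k in range(1, len(row))
            if row[k] is not None]


for row in level_by_level(EDGES):
    print(row)
```

The key point is that the loop runs once per hierarchy level (max depth joins), not once per base_id, so the number of Spark jobs no longer grows with the number of ids.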

To my knowledge, you cannot resolve a graph (compute its transitive closure) with a single query. You need to run the join multiple times; if each iteration joins the paths found so far with themselves, the number of loop cycles is about log(max_depth).
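The log(max_depth) bound comes from path doubling: instead of joining with the original edges, each round joins the set of known (node, ancestor, distance) paths with itself, so the longest known path roughly doubles per round. Below is a plain-Python sketch where each set comprehension stands in for one self-join; ancestors_by_doubling is a hypothetical name, and the input is assumed to be acyclic.

```python
from collections import defaultdict

# Sample rows from the question: (base_id, parent_id); None marks a root.
EDGES = [("479", None), ("977", "666"), ("666", "479"),
         ("555", "678"), ("678", "977")]


def ancestors_by_doubling(edges):
    """Return ({(node, ancestor, distance)}, rounds) using path doubling.

    Each round joins the known paths with themselves on
    path.ancestor == other.node, so the longest known path roughly doubles.
    Assumes the hierarchy is acyclic, so each (node, ancestor) has one distance.
    """
    paths = {(b, p, 1) for b, p in edges if p is not None}
    rounds = 0
    while True:
        # Index paths by their starting node (the "right side" of the self-join).
        by_node = defaultdict(list)
        for node, anc, dist in paths:
            by_node[node].append((anc, dist))
        # Self-join: node -> anc (d1) followed by anc -> anc2 (d2).
        joined = {(node, anc2, d1 + d2)
                  for node, anc, d1 in paths
                  for anc2, d2 in by_node.get(anc, [])}
        rounds += 1
        if joined <= paths:  # exit condition: no new path found
            return paths, rounds
        paths |= joined


paths, rounds = ancestors_by_doubling(EDGES)
for node, anc, dist in sorted(paths):
    print(node, anc, f"parent_{dist}")
print("rounds:", rounds)
```

On the sample data (max depth 4) this finishes in 3 rounds: two that find new paths and one that confirms nothing new appeared. In Spark, each round would be a self-join of a paths DataFrame, and the distance column maps directly to the parent_k label.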

Here is a simple example from the Spark examples that shows how to resolve a graph iteratively, with an exit condition that checks whether any new connections have been found.
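A minimal plain-Python sketch of that iterate-until-no-new-connections pattern, assuming an acyclic hierarchy: each round extends the paths found in the previous round by one parent (the equivalent of joining them with the original edge table) and stops when a round adds nothing. Extending only the newest paths, rather than the whole closure, is a common refinement and may differ from the Spark example; ancestors_iterative is a hypothetical name.

```python
# Sample rows from the question: (base_id, parent_id); None marks a root.
EDGES = [("479", None), ("977", "666"), ("666", "479"),
         ("555", "678"), ("678", "977")]


def ancestors_iterative(edges):
    """Return ({(node, ancestor, distance)}, rounds), one edge per round.

    Assumes an acyclic hierarchy. Only the paths found in the previous round
    (the frontier) are extended, by joining them with the original edges.
    """
    parent = {b: p for b, p in edges if p is not None}
    closure = {(b, p, 1) for b, p in parent.items()}
    frontier = set(closure)
    rounds = 0
    # Exit condition: stop once a round discovers no new connection.
    while frontier:
        frontier = {(node, parent[anc], dist + 1)
                    for node, anc, dist in frontier
                    if anc in parent} - closure
        closure |= frontier
        rounds += 1
    return closure, rounds


closure, rounds = ancestors_iterative(EDGES)
print(len(closure), "pairs in", rounds, "rounds")
```

Here the number of rounds equals the maximum depth (4 on the sample data, the last round confirming nothing new), versus roughly log2 of it for the doubling variant, at the cost of one smaller join per round.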

