Passing Arguments into a thread in Scala


I am learning Scala, and as an exercise I am translating some Python (PySpark) code into Scala (Spark/Scala) code. Everything was going fine until I started dealing with Scala threads. Do you know how I can rewrite the following code in Scala?

Thank you in advance!

import threading

def load_tables(table_name, spark):
    # Read the CSV for this table and register it as a temp view
    source_path = f"s3://data/tables/{table_name}"
    table = spark.read.format("csv").load(source_path)
    table.createOrReplaceTempView(table_name)

def read_initial_tables(spark):
    threads = []
    tables = ["table1", "table2", "table3"]
    for table in tables:
        t = threading.Thread(target=load_tables, args=(table, spark))
        threads.append(t)
    # Start all threads, then wait for all of them to finish
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

CodePudding user response:

...passing arguments into threads...

Scala uses the Java standard libraries, and starting a thread in Java is a little different from starting a thread in Python. The main difference is that in Python you can choose any target (i.e., any function or callable object) as the thread's top-level, and you can pass in whatever args you like. But when you start a Java thread, the top-level function must be a no-argument method named run() belonging to an object that implements java.lang.Runnable.

Your Python thread's top-level function is load_tables(table, spark). So, what you need in your Scala program is a thread whose top-level function is a run() function that calls load_tables(table, spark).

I don't actually know Scala, but maybe the example on this web page will steer you in the right direction: https://alvinalexander.com/scala/how-to-create-java-thread-runnable-in-scala/

Basically, I think all you have to do is follow his example, and put your load_tables(table, spark) call in the place where his example says, "your custom behavior here."
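For what it's worth, here is a rough sketch of what that pattern looks like with your code plugged in (untested; the makeLoaderThread helper name is mine, and it assumes a SparkSession is already available):

import org.apache.spark.sql.SparkSession

// Sketch: the thread's top-level must be a no-argument run() method,
// so the table name and SparkSession are captured by the closure
// rather than passed in as thread arguments.
def makeLoaderThread(tableName: String, spark: SparkSession): Thread =
  new Thread(new Runnable {
    override def run(): Unit = {
      // "your custom behavior here" from the linked example:
      val sourcePath = s"s3://data/tables/$tableName"
      spark.read.format("csv").load(sourcePath).createOrReplaceTempView(tableName)
    }
  })

Start one such thread per table and then join() them all, exactly as your Python version does.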

CodePudding user response:

Solomon is right. I could not have described it better. Taking advantage of the syntactic sugar Scala provides over Java, your Python code translated to Scala looks like this:

import org.apache.spark.sql.SparkSession

// load_tables returns a Runnable: the lambda () => { ... } is converted
// to a Runnable via SAM conversion, and its body becomes run()'s body.
def load_tables(table_name: String, spark: SparkSession): Runnable = () => {
  val source_path = s"s3://data/tables/$table_name"
  val table = spark.read.format("csv").load(source_path)
  table.createOrReplaceTempView(table_name)
}

def read_initial_tables(spark: SparkSession): Unit = {
  val tables = List("table1", "table2", "table3")
  val threads = for {
    table <- tables
  } yield new Thread(load_tables(table, spark))
  // Start all threads, then wait for each one to finish
  for (thread <- threads)
    thread.start()
  for (thread <- threads)
    thread.join()
}

You might ask where the run() method that Solomon was talking about went. Actually, the empty parentheses () after the = sign at the start of load_tables's body represent run's no-argument parameter list, while the block of code between the curly braces after the => sign is the body of the run method. So a call to load_tables actually returns a new Runnable instance.

This is called a Single Abstract Method (SAM) conversion, which is just syntactic sugar that gives the impression that load_tables is callable as in Python, but it isn't; only its return type is, because it returns a Runnable object. This short version is only achievable because Runnable is a functional interface, i.e., an interface with a single abstract method.
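To make that concrete, here is a tiny sketch showing the two equivalent forms (the println body is just a placeholder):

// SAM conversion: the lambda becomes a Runnable automatically,
// because Runnable has exactly one abstract method, run().
val viaLambda: Runnable = () => println("custom behavior")

// What the compiler effectively generates:
val viaAnonymousClass: Runnable = new Runnable {
  override def run(): Unit = println("custom behavior")
}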

I'm not a specialist in Spark, so I'm not sure whether this is the idiomatic way to write Scala with Spark, but it's a good starting point.

CodePudding user response:

Maybe not really what you are looking for, but it could be interesting: Scala has a very convenient mechanism for parallelizing collections with the .par method (on Scala 2.13 and later it lives in the separate scala-parallel-collections module):

// On Scala 2.13+, .par requires the scala-parallel-collections module:
import scala.collection.parallel.CollectionConverters._

val parallelizedList = List(1, 2, 3, 4).par
parallelizedList.foreach(i => println(i)) // executed in parallel, not sequentially
// possible output (order is nondeterministic):
// 2
// 4
// 1
// 3

So you can use this syntax with Spark to read multiple tables in parallel:

import org.apache.spark.sql.SparkSession
import scala.collection.parallel.CollectionConverters._ // for .par on Scala 2.13+

def loadTable(tableName: String, spark: SparkSession): Unit = {
  val sourcePath = s"s3://data/tables/$tableName"
  val table = spark.read.format("csv").load(sourcePath)
  table.createOrReplaceTempView(tableName)
}

val tableNames = List("table1", "table2", "table3")
tableNames.par.foreach(name => loadTable(name, spark))