How to handle race condition with Coroutines in Kotlin?

Time:06-01

I have a coroutine/flow problem that I'm trying to solve.

I have a method, getClosestRegion, that's supposed to do the following:

  1. Attempt to connect to every region
  2. The first region to connect (I use launch to attempt to connect to all of them concurrently) should be returned, and the rest of the region requests should be cancelled
  3. If all regions fail to connect OR a 30-second timeout elapses, throw an exception

This is what I currently have:

override suspend fun getClosestRegion(): Region {
    val regions = regionsRepository.getRegions()
    val firstSuccessResult = MutableSharedFlow<Region>(replay = 1)
    val scope = CoroutineScope(Dispatchers.IO)

    // Attempts to connect to every region until the first success
    scope.launch {
        regions.forEach { region ->
            launch {
                val retrofitClient = buildRetrofitClient(region.backendUrl)
                val regionAuthenticationAPI = retrofitClient.create(AuthenticationAPI::class.java)
                val response = regionAuthenticationAPI.canConnect()
                if (response.isSuccessful && scope.isActive) {
                    scope.cancel()
                    firstSuccessResult.emit(region)
                }
            }
        }
    }

    val result = withTimeoutOrNull(TimeUnit.SECONDS.toMillis(30)) { firstSuccessResult.first() }
    if (result != null)
        return result
    throw Exception("Failed to connect to any region")
}

Issues with the current code:

  1. If one region connects successfully, I expect the rest of the requests to be cancelled (by scope.cancel()), but in reality other regions that connect successfully AFTER the first one also emit values to the flow (scope.isActive returns true)
  2. I don't know how to throw an exception when all regions fail to connect or when the 30-second timeout expires, whichever happens first

Also, I'm pretty new to Kotlin Flow and coroutines, so I don't know if creating a flow is really necessary here.

Answer:

You don't need to create a CoroutineScope and manage it from within a coroutine. You can use the coroutineScope function instead.
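A small runnable sketch of what coroutineScope buys you, assuming a hypothetical slowRequest function that stands in for the Retrofit call: because the launched request is a child of the caller, cancelling the caller (here via withTimeoutOrNull) also cancels the request, and coroutineScope doesn't return until that cancellation has finished.

```kotlin
import kotlinx.coroutines.*

// Hypothetical stand-in for a slow network call; records how it ended.
suspend fun slowRequest(events: MutableList<String>) {
    try {
        delay(10_000)
        events += "finished"
    } catch (e: CancellationException) {
        events += "cancelled"
        throw e // always rethrow so cancellation keeps propagating
    }
}

// Returns false if the timeout fired. The launched request is a child of
// coroutineScope, so the timeout cancels it instead of leaving it running.
suspend fun requestWithin(timeoutMs: Long, events: MutableList<String>): Boolean =
    withTimeoutOrNull(timeoutMs) {
        coroutineScope { launch { slowRequest(events) } }
        true
    } ?: false

fun main() = runBlocking {
    val events = mutableListOf<String>()
    println(requestWithin(100, events)) // false
    println(events)                     // [cancelled]
}
```

With CoroutineScope(Dispatchers.IO).launch, by contrast, the request would have no link to the caller and would keep running after the timeout.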

I haven't tested any of the code below, so please excuse syntax errors and omitted <types> that the compiler can't infer.

Here's how you might do it using a select clause, but I think it's kind of awkward:

override suspend fun getClosestRegion(): Region = coroutineScope {
    val regions = regionsRepository.getRegions()
    val result = select<Region?> {
        onTimeout(30.seconds) { null }
        for (region in regions) {
            launch {
                val retrofitClient = buildRetrofitClient(region.backendUrl)
                val regionAuthenticationAPI = retrofitClient.create(AuthenticationAPI::class.java)
                val response = regionAuthenticationAPI.canConnect()
                if (!response.isSuccessful) {
                    delay(30.seconds) // prevent this one from being selected
                }
            }.onJoin { region }
        }
    }
    coroutineContext.cancelChildren() // Cancel any remaining async jobs
    requireNotNull(result) { "Failed to connect to any region" }
}
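A self-contained version of the select approach that can actually run, under stated assumptions: FakeRegion and fakeCanConnect are hypothetical stand-ins for Region and the Retrofit canConnect() call, and a failed check suspends with awaitCancellation() instead of delay(30.seconds), so it can never win the select. onTimeout(Duration) is marked experimental in kotlinx.coroutines, hence the opt-in.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.selects.onTimeout
import kotlinx.coroutines.selects.select
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

// Hypothetical stand-ins for Region and the Retrofit canConnect() call.
data class FakeRegion(val name: String, val latency: Duration, val ok: Boolean)

suspend fun fakeCanConnect(region: FakeRegion): Boolean {
    delay(region.latency)
    return region.ok
}

@OptIn(ExperimentalCoroutinesApi::class)
suspend fun closestBySelect(regions: List<FakeRegion>, timeout: Duration): FakeRegion? = coroutineScope {
    val result = select<FakeRegion?> {
        onTimeout(timeout) { null }
        for (region in regions) {
            launch {
                // A failed check never completes, so its onJoin clause can't be selected;
                // it stays suspended until cancelChildren() below.
                if (!fakeCanConnect(region)) awaitCancellation()
            }.onJoin { region }
        }
    }
    coroutineContext.cancelChildren() // stop the requests that lost the race
    result
}

fun main() = runBlocking {
    val regions = listOf(
        FakeRegion("eu", 300.milliseconds, true),
        FakeRegion("us", 50.milliseconds, true),
        FakeRegion("ap", 10.milliseconds, false),
    )
    println(closestBySelect(regions, 1.seconds)?.name) // us
}
```

Note that when every region fails, this version still waits for the full timeout, because the failed jobs never complete; the channelFlow version in this answer doesn't have that limitation.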

Here's how you could do it with channelFlow:

override suspend fun getClosestRegion(): Region = coroutineScope {
    val regions = regionsRepository.getRegions()
    val flow = channelFlow {
        for (region in regions) {
            launch {
                val retrofitClient = buildRetrofitClient(region.backendUrl)
                val regionAuthenticationAPI = retrofitClient.create(AuthenticationAPI::class.java)
                val result = regionAuthenticationAPI.canConnect()
                if (result.isSuccessful) {
                    send(region)
                }
            }
        }
    }
    val result = withTimeoutOrNull(30.seconds) { 
        flow.firstOrNull()
    }
    coroutineContext.cancelChildren() // Cancel any remaining async jobs
    requireNotNull(result) { "Failed to connect to any region" }
}
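A runnable sketch of the channelFlow approach, again with hypothetical FakeRegion/fakeCanConnect stand-ins. One design difference, offered as a simplification rather than a correction: firstOrNull() cancels the channelFlow producer once it receives a value, and the launches inside channelFlow are children of that producer, so they are cancelled without an explicit cancelChildren(). If every request fails, the channel closes when its last child finishes and firstOrNull() returns null right away instead of waiting for the timeout.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.firstOrNull
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

// Hypothetical stand-ins for Region and the Retrofit canConnect() call.
data class FakeRegion(val name: String, val latency: Duration, val ok: Boolean)

suspend fun fakeCanConnect(region: FakeRegion): Boolean {
    delay(region.latency)
    return region.ok
}

// null means "all failed" or "timed out"; the caller decides which error to throw.
suspend fun closestByChannelFlow(regions: List<FakeRegion>, timeout: Duration): FakeRegion? =
    withTimeoutOrNull(timeout) {
        channelFlow {
            for (region in regions) {
                // Each request is a child of the producer, so it is cancelled with it.
                launch { if (fakeCanConnect(region)) send(region) }
            }
        }.firstOrNull() // takes the first success and cancels the producer
    }

fun main() = runBlocking {
    val regions = listOf(
        FakeRegion("eu", 200.milliseconds, true),
        FakeRegion("us", 50.milliseconds, true),
        FakeRegion("ap", 10.milliseconds, false),
    )
    println(closestByChannelFlow(regions, 1.seconds)) // FakeRegion(name=us, latency=50ms, ok=true)
}
```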

I think your MutableSharedFlow technique could also work if you dropped the isActive check and used coroutineScope { } and cancelChildren() like I did above. But it seems awkward to create a shared flow that isn't shared by anything (it's only used by the same coroutine that created it).

Answer:

  1. If one region connects successfully, I expect the rest of the requests to be cancelled (by scope.cancel()), but in reality other regions that connect successfully AFTER the first one also emit values to the flow (scope.isActive returns true)

To quote the documentation...

Coroutine cancellation is cooperative. A coroutine code has to cooperate to be cancellable.

Once your request has started, cancelling the coroutine doesn't stop it by itself - the client has to notice the cancellation and interrupt what it's doing. That may not be happening inside of Retrofit.
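To see what "cooperative" means in practice, here is a small sketch where a busy loop stands in for a client that never checks for cancellation (the spin and msUntilStopped helpers are made up for illustration): without ensureActive(), cancelAndJoin() has to wait for the loop to finish on its own; with it, the job stops almost immediately.

```kotlin
import kotlinx.coroutines.*

// Hypothetical "client": spins for busyMs milliseconds. When cooperative, it calls
// ensureActive(), which throws CancellationException once the job is cancelled.
fun CoroutineScope.spin(busyMs: Long, cooperative: Boolean): Job = launch(Dispatchers.Default) {
    val end = System.currentTimeMillis() + busyMs
    while (System.currentTimeMillis() < end) {
        if (cooperative) ensureActive()
    }
}

// Returns how long cancelAndJoin() had to wait, in milliseconds.
suspend fun msUntilStopped(cooperative: Boolean): Long = coroutineScope {
    val job = spin(busyMs = 500, cooperative = cooperative)
    delay(50) // let the loop start before cancelling it
    val start = System.currentTimeMillis()
    job.cancelAndJoin()
    System.currentTimeMillis() - start
}

fun main() = runBlocking {
    println("without ensureActive(): ${msUntilStopped(false)} ms")
    println("with ensureActive():    ${msUntilStopped(true)} ms")
}
```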

I'll assume it isn't a problem that you're sending more requests than you need - otherwise you couldn't make the requests simultaneously.
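On the first issue, part of the problem is that checking scope.isActive and then calling scope.cancel() are two separate steps, so two requests that finish at nearly the same time on Dispatchers.IO can both pass the check. A first-wins guard has to be a single atomic operation. The code below relies on CompletableJob.complete(), which returns true only for the call that actually completes the job; CompletableDeferred.complete() has the same contract and also carries the winning value, as this sketch (with a made-up raceToComplete helper) shows:

```kotlin
import kotlinx.coroutines.*

// Hypothetical helper: n concurrent "requests" all try to claim the result.
// complete() returns true for exactly one caller; every later call returns false.
suspend fun raceToComplete(n: Int): Pair<String, Int> = coroutineScope {
    val winner = CompletableDeferred<String>()
    val claimed = (1..n).map { i ->
        async(Dispatchers.Default) { winner.complete("region$i") }
    }.awaitAll()
    winner.await() to claimed.count { it }
}

fun main() = runBlocking {
    val (region, winners) = raceToComplete(5)
    println("first success: $region, callers that won: $winners") // winners is always 1
}
```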


  2. I don't know how to throw an exception when all regions fail to connect or when the 30-second timeout expires, whichever happens first

As I understand it, there are three situations:

  1. There's one successful response - other responses should be ignored
  2. All responses are unsuccessful - an error should be thrown
  3. All responses take longer than 30 seconds - again, throw an error

Additionally, I don't want to keep track of how many requests are active/failed/successful. That requires shared state, which is complicated and brittle. Instead, I want to use parent-child relationships to manage this.
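As a minimal sketch of "parent-child relationships instead of counters" (the Attempt class and allFailed function are hypothetical): awaitAll() only resumes once every child has finished, so "did they all fail?" becomes a check on the returned list rather than on shared mutable state.

```kotlin
import kotlinx.coroutines.*

// Hypothetical request outcome: how long it takes and whether it succeeds.
data class Attempt(val latencyMs: Long, val ok: Boolean)

// The parent resumes only after every child has finished, so "all failed"
// is a check on the results, not on a shared counter.
suspend fun allFailed(attempts: List<Attempt>): Boolean = coroutineScope {
    attempts.map { attempt ->
        async {
            delay(attempt.latencyMs)
            attempt.ok
        }
    }.awaitAll().none { it }
}

fun main() = runBlocking {
    println(allFailed(listOf(Attempt(10, false), Attempt(30, false)))) // true
    println(allFailed(listOf(Attempt(10, false), Attempt(30, true))))  // false
}
```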

Timeout

The timeout is already handled by withTimeoutOrNull() - easy enough!
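A tiny sketch of that pattern in isolation, with a hypothetical fakeRequest standing in for a region check: withTimeoutOrNull returns null instead of throwing TimeoutCancellationException, and the elvis operator turns that null into an error of your choosing.

```kotlin
import kotlinx.coroutines.*
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

// Hypothetical request that answers after `latency`.
suspend fun fakeRequest(latency: Duration): String {
    delay(latency)
    return "connected"
}

// null from withTimeoutOrNull means "timed out"; convert it into our own exception.
suspend fun requestOrFail(latency: Duration, timeout: Duration): String =
    withTimeoutOrNull(timeout) { fakeRequest(latency) } ?: error("Timeout error - unable to get region")

fun main() = runBlocking {
    println(requestOrFail(10.milliseconds, 1.seconds)) // connected
    println(runCatching { requestOrFail(1.seconds, 50.milliseconds) }.exceptionOrNull()?.message)
}
```

Because null doubles as "timed out", this only works when the block itself never returns null.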

First success

Selects could be useful here, and I see @Tenfour04 has provided that answer. I'll give an alternative.

Using suspendCancellableCoroutine() provides a way to:

  1. return as soon as there's a success - resume(...)
  2. throw an error when all requests fail - resumeWithException(...)

suspend fun getClosestRegion(
  regions: List<Region>
): Region = withTimeoutOrNull(10.seconds) {

  // don't give the supervisor a parent; otherwise, once one response is successful,
  // the parent would still have to wait for the other children to finish
  val supervisorJob = SupervisorJob()

  // suspend the current coroutine. We'll use cont to continue when 
  // there's a definite outcome
  suspendCancellableCoroutine<Region> { cont ->

    launch(supervisorJob) {
      regions
        .map { region ->
          // note: use async instead of launch so we can do awaitAll()
          // to track when all tasks have completed, but none have resumed
          async(supervisorJob) {

            coroutineContext.job.invokeOnCompletion {
              log("cancelling async job for $region")
            }

            val retrofitClient = buildRetrofitClient(region)
            val response = retrofitClient.connect()
            
            // if there's a success, then try to complete the supervisor.
            // complete() prevents multiple jobs from continuing the suspended
            // coroutine
            if (response.isSuccess && supervisorJob.complete()) {
              log("got success for $region - resuming")
              // happy flow - we can return
              cont.resume(region)
            }
          }
        }.awaitAll()

      // uh-oh, nothing was a success
      if (supervisorJob.complete()) {
        log("no successful regions - throwing exception & resuming")
        cont.resumeWithException(Exception("no region response was successful"))
      }
    }
  }
} ?: error("Timeout error - unable to get region")
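One thing the question asked for that this version leaves out: the SupervisorJob has no parent, and complete() waits for its children rather than cancelling them, so the losing requests keep running after the first success, and on a timeout nothing cancels the in-flight requests. If that matters, one possible variant (a sketch using hypothetical FakeRegion/fakeCanConnect stand-ins, not a tested drop-in) is to cancel the supervisor in a finally block once the continuation has resumed or been cancelled:

```kotlin
import kotlinx.coroutines.*
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

// Hypothetical stand-ins for Region and the Retrofit call.
data class FakeRegion(val name: String, val latency: Duration, val ok: Boolean)

suspend fun fakeCanConnect(region: FakeRegion): Boolean {
    delay(region.latency)
    return region.ok
}

suspend fun closestWithCleanup(regions: List<FakeRegion>, timeout: Duration): FakeRegion =
    withTimeoutOrNull(timeout) {
        val supervisorJob = SupervisorJob()
        try {
            suspendCancellableCoroutine<FakeRegion> { cont ->
                launch(supervisorJob) {
                    regions.map { region ->
                        async(supervisorJob) {
                            // complete() is the atomic "first success wins" guard
                            if (fakeCanConnect(region) && supervisorJob.complete()) cont.resume(region)
                        }
                    }.awaitAll()
                    if (supervisorJob.complete()) {
                        cont.resumeWithException(IllegalStateException("no region response was successful"))
                    }
                }
            }
        } finally {
            // Runs on success, on failure and on timeout: cancels any request still in flight.
            supervisorJob.cancel()
        }
    } ?: error("Timeout error - unable to get region")

fun main() = runBlocking {
    val regions = listOf(FakeRegion("eu", 300.milliseconds, true), FakeRegion("us", 30.milliseconds, true))
    println(closestWithCleanup(regions, 1.seconds).name) // us
}
```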

Examples

All responses are successful

If all tasks are successful, it returns as soon as the fastest one succeeds.

getClosestRegion(
  List(5) {
    Region("attempt1-region$it", success = true)
  }
)

...

log("result for all success: $regionSuccess, time $time")
 got success for Region(name=attempt1-region1, success=true, delay=2s) - resuming
 cancelling async job for Region(name=attempt1-region3, success=true, delay=2s)
 result for all success: Region(name=attempt1-region1, success=true, delay=2s), time 2.131312600s
 cancelling async job for Region(name=attempt1-region1, success=true, delay=2s)

All responses fail

When all responses fail, it should take only as long as the slowest response (5 seconds in the run below), not the full timeout.

getClosestRegion(
  List(5) {
    Region("attempt2-region$it", success = false)
  }
)

...

log("failure: $allFailEx, time $time")
[DefaultDispatcher-worker-6 @all-fail#6] cancelling async job for Region(name=attempt2-region4, success=false, delay=1s)
[DefaultDispatcher-worker-4 @all-fail#4] cancelling async job for Region(name=attempt2-region2, success=false, delay=4s)
[DefaultDispatcher-worker-3 @all-fail#3] cancelling async job for Region(name=attempt2-region1, success=false, delay=4s)
[DefaultDispatcher-worker-6 @all-fail#5] cancelling async job for Region(name=attempt2-region3, success=false, delay=4s)
[DefaultDispatcher-worker-6 @all-fail#2] cancelling async job for Region(name=attempt2-region0, success=false, delay=5s)
[DefaultDispatcher-worker-6 @all-fail#1] no successful regions - throwing exception resuming
[DefaultDispatcher-worker-6 @all-fail#1] failure: java.lang.Exception: no region response was successful, time 5.225431500s

All responses time out

And if all responses take longer than the timeout (I reduced it to 10 seconds in my example), then an exception will be thrown.

getClosestRegion(
  List(5) {
    Region("attempt3-region$it", false, 100.seconds)
  }
)

...

log("timeout: $timeoutEx, time $time")
[kotlinx.coroutines.DefaultExecutor] timeout: java.lang.IllegalStateException: Timeout error - unable to get region, time 10.070052700s

Full demo code

import kotlin.coroutines.*
import kotlin.random.*
import kotlin.time.Duration.Companion.seconds
import kotlin.time.*
import kotlinx.coroutines.*


suspend fun main() {
  System.getProperties().setProperty("kotlinx.coroutines.debug", "")

  withContext(CoroutineName("all-success")) {
    val (regionSuccess, time) = measureTimedValue {
      getClosestRegion(
        List(5) {
          Region("attempt1-region$it", true)
        }
      )
    }
    log("result for all success: $regionSuccess, time $time")
  }

  log("\n------\n")

  withContext(CoroutineName("all-fail")) {
    val (allFailEx, time) = measureTimedValue {
      try {
        getClosestRegion(
          List(5) {
            Region("attempt2-region$it", false)
          }
        )
      } catch (exception: Exception) {
        exception
      }
    }
    log("failure: $allFailEx, time $time")
  }

  log("\n------\n")

  withContext(CoroutineName("timeout")) {
    val (timeoutEx, time) = measureTimedValue {
      try {
        getClosestRegion(
          List(5) {
            Region("attempt3-region$it", false, 100.seconds)
          }
        )
      } catch (exception: Exception) {
        exception
      }
    }
    log("timeout: $timeoutEx, time $time")
  }
}


suspend fun getClosestRegion(
  regions: List<Region>
): Region = withTimeoutOrNull(10.seconds) {

  val supervisorJob = SupervisorJob()

  suspendCancellableCoroutine<Region> { cont ->

    launch(supervisorJob) {
      regions
        .map { region ->
          async(supervisorJob) {

            coroutineContext.job.invokeOnCompletion {
              log("cancelling async job for $region")
            }

            val retrofitClient = buildRetrofitClient(region)
            val response = retrofitClient.connect()
            if (response.isSuccess && supervisorJob.complete()) {
              log("got success for $region - resuming")
              cont.resume(region)
            }
          }
        }.awaitAll()

      // uh-oh, nothing was a success
      if (supervisorJob.complete()) {
        log("no successful regions - throwing exception resuming")
        cont.resumeWithException(Exception("no region response was successful"))
      }
    }
  }
} ?: error("Timeout error - unable to get region")


////////////////////////////////////////////////////////////////////////////////////////////////////


data class Region(
  val name: String,
  val success: Boolean,
  val delay: Duration = Random(name.hashCode()).nextInt(1..5).seconds,
) {
  val backendUrl = "http://localhost/$name"
}


fun buildRetrofitClient(region: Region) = RetrofitClient(region)


class RetrofitClient(private val region: Region) {

  suspend fun connect(): ClientResponse {
    delay(region.delay)
    return ClientResponse(region.backendUrl, region.success)
  }
}

data class ClientResponse(
  val url: String,
  val isSuccess: Boolean,
)

fun log(msg: String) = println("[${Thread.currentThread().name}] $msg")