rlang::hash cannot differentiate between arrow queries

Time:11-26

I use the memoise package to cache queries to an arrow dataset but I sometimes get mismatches/"collisions" in hashes and therefore the wrong values are returned.

I have isolated the problem and replicated it in the MWE below. The issue is that the rlang::hash() (which memoise uses) of an arrow query that first filters then summarises does not depend on the filter.

My question is: is this something I can fix (because I am using it wrongly), or is it a bug in one of the packages? If it is a bug (I am happy to create an issue), should it be reported to arrow, rlang::hash(), or even R6?

MWE

For example, all three queries below have the same hash, although they should differ (the results clearly do):

library(arrow)
library(dplyr)

ds_file <- file.path(tempdir(), "mtcars")

write_dataset(mtcars, ds_file)
ds <- open_dataset(ds_file)

# 1) Create three different queries =======

# Query 1 with mpg > 25 ----
query1 <- ds |> 
  filter(mpg > 25) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg))

# Query 2 with mpg > 0 ----
query2 <- ds |> 
  filter(mpg > 0) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg))

# Query 3 with filter on cyl ----
query3 <- ds |> 
  filter(cyl == 4) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg))


# 2) Let's compare the hashes: the main issue ======
rlang::hash(query1)
#> [1] "f505339fd65df6ef53728fcc4b0e55f7"
rlang::hash(query2)
#> [1] "f505339fd65df6ef53728fcc4b0e55f7"
rlang::hash(query3)
#> [1] "f505339fd65df6ef53728fcc4b0e55f7"
# ERROR HERE: they should be different as the queries are different!

# 3) Let's also compare the results: clearly different =====
query1 |> collect()
#> # A tibble: 2 × 3
#>      vs     n mean_mpg
#>   <dbl> <int>    <dbl>
#> 1     1     5     30.9
#> 2     0     1     26

query2 |> collect()
#> # A tibble: 2 × 3
#>      vs     n mean_mpg
#>   <dbl> <int>    <dbl>
#> 1     0    18     16.6
#> 2     1    14     24.6

query3 |> collect()
#> # A tibble: 2 × 3
#>      vs     n mean_mpg
#>   <dbl> <int>    <dbl>
#> 1     1    10     26.7
#> 2     0     1     26

Note that the same error happens when I use digest.

When I print the queries, they are printed as if they were identical... (I reported this bug here to arrow)

query1
#> FileSystemDataset (query)
#> vs: double
#> n: int32
#> mean_mpg: double
#> 
#> See $.data for the source Arrow object

query2
#> FileSystemDataset (query)
#> vs: double
#> n: int32
#> mean_mpg: double
#> 
#> See $.data for the source Arrow object

query3
#> FileSystemDataset (query)
#> vs: double
#> n: int32
#> mean_mpg: double
#> 
#> See $.data for the source Arrow object

but when I inspect the $.data element of each query, I see that they are in fact different:

query1$.data
#> FileSystemDataset (query)
#> mpg: double
#> vs: double
#> 
#> * Aggregations:
#> n: sum(1)
#> mean_mpg: mean(mpg)
#> * Filter: (mpg > 25)            #<=========
#> * Grouped by vs
#> See $.data for the source Arrow object

query2$.data
#> FileSystemDataset (query)
#> mpg: double
#> vs: double
#> 
#> * Aggregations:
#> n: sum(1)
#> mean_mpg: mean(mpg)
#> * Filter: (mpg > 0)             #<=========
#> * Grouped by vs
#> See $.data for the source Arrow object

query3$.data
#> FileSystemDataset (query)
#> mpg: double
#> vs: double
#> 
#> * Aggregations:
#> n: sum(1)
#> mean_mpg: mean(mpg)
#> * Filter: (cyl == 4)            #<=========
#> * Grouped by vs
#> See $.data for the source Arrow object

but again rlang::hash() cannot find a difference:

rlang::hash(query1$.data)
#> [1] "b7f743cd635f7dc06356b827a6974df8"
rlang::hash(query2$.data)
#> [1] "b7f743cd635f7dc06356b827a6974df8"
rlang::hash(query3$.data)
#> [1] "b7f743cd635f7dc06356b827a6974df8"

If it helps, the query objects are R6 objects with class arrow_dplyr_query (see also its source code in apache/arrow)

Memoise use case

For completeness' sake and to put the problem into perspective, I use the following to cache the results. The two calls should return different values (see above) but don't!

library(arrow)
library(memoise)
library(dplyr)

ds_file <- file.path(tempdir(), "mtcars")

write_dataset(mtcars, ds_file)
ds <- open_dataset(ds_file)

collect_cached <- memoise::memoise(dplyr::collect,
                                   cache = cachem::cache_mem(logfile = stdout()))

# Query 1 with mpg > 25 ----
ds |> 
  filter(mpg > 25) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg)) |> 
  collect_cached()
#> [2022-11-25 09:16:28.586] cache_mem get: key "2edd901226498414056dcc54eaa49415"
#> [2022-11-25 09:16:28.586] cache_mem get: key "2edd901226498414056dcc54eaa49415" is missing
#> [2022-11-25 09:16:28.705] cache_mem set: key "2edd901226498414056dcc54eaa49415"
#> [2022-11-25 09:16:28.706] cache_mem prune
#> # A tibble: 2 × 3
#>      vs     n mean_mpg
#>   <dbl> <int>    <dbl>
#> 1     1     5     30.9
#> 2     0     1     26

# Query 2 with mpg > 0 ----
# this is wrongly matched to the first query and returns wrong results... 
ds |> 
  filter(mpg > 0) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg)) |> 
  collect_cached()
#> [2022-11-25 09:16:28.820] cache_mem get: key "2edd901226498414056dcc54eaa49415"
#> [2022-11-25 09:16:28.820] cache_mem get: key "2edd901226498414056dcc54eaa49415" found    #< ERROR HERE! as the hash is identical
#> # A tibble: 2 × 3
#>      vs     n mean_mpg
#>   <dbl> <int>    <dbl>
#> 1     1     5     30.9
#> 2     0     1     26

Note that we get the same result although the queries are different (yet their hashes are identical, hence this question).

CodePudding user response:

This is very much a hack ... but perhaps it'll be enough? I was able to find something unique enough about the intermediate "query" that includes its filter components: I capture the output from show_query and hash that inside a custom function passed as the hash= argument to memoise:

hashfun <- function(x) {
  x$x <- capture.output(show_query(x$x))
  rlang::hash(x)
}
collect_cached <- memoise::memoise(
  dplyr::collect,
  cache = cachem::cache_mem(logfile = stdout()),
  hash = hashfun)

ds |> 
  filter(mpg > 25) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg)) |> 
  collect_cached()
# [2022-11-25 08:14:56.596] cache_mem get: key "e6184e282e05875139e8afd2a071f329"
# [2022-11-25 08:14:56.596] cache_mem get: key "e6184e282e05875139e8afd2a071f329" is missing
# [2022-11-25 08:14:56.616] cache_mem set: key "e6184e282e05875139e8afd2a071f329"
# [2022-11-25 08:14:56.616] cache_mem prune
# # A tibble: 2 x 3
#      vs     n mean_mpg
#   <dbl> <int>    <dbl>
# 1     1     5     30.9
# 2     0     1     26  

#### different filter, should be a "miss"
ds |> 
  filter(mpg > 0) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg)) |> 
  collect_cached()
# [2022-11-25 08:15:06.745] cache_mem get: key "88312b31b29050ff029900f4dfc58a9f"
# [2022-11-25 08:15:06.745] cache_mem get: key "88312b31b29050ff029900f4dfc58a9f" is missing
# [2022-11-25 08:15:06.767] cache_mem set: key "88312b31b29050ff029900f4dfc58a9f"
# [2022-11-25 08:15:06.767] cache_mem prune
# # A tibble: 2 x 3
#      vs     n mean_mpg
#   <dbl> <int>    <dbl>
# 1     0    18     16.6
# 2     1    14     24.6

#### repeat of filter `mpg > 0`, should be a "hit"
ds |> 
  filter(mpg > 0) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg)) |> 
  collect_cached()
# [2022-11-25 08:15:24.825] cache_mem get: key "88312b31b29050ff029900f4dfc58a9f"
# [2022-11-25 08:15:24.825] cache_mem get: key "88312b31b29050ff029900f4dfc58a9f" found
# # A tibble: 2 x 3
#      vs     n mean_mpg
#   <dbl> <int>    <dbl>
# 1     0    18     16.6
# 2     1    14     24.6

The object passed to hashfun is a list, where the first element appears to be a checksum or salt of some sort (we'll ignore it), and all remaining elements (named or otherwise) are determined by the formals of the cached function. In our case, since we're caching collect, the list contains x= (which we see) and would contain anything passed via ... (which we don't use here):

debugonce(hashfun)
ds |> 
  filter(mpg > 0) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg)) |> 
  collect_cached()
# debugging in: encl$`_hash`(c(encl$`_f_hash`, args, lapply(encl$`_additional`, 
#     function(x) eval(x[[2L]], environment(x)))))
# debug at #1: {
#     x$x <- capture.output(show_query(x$x))
#     rlang::hash(x)
# }

x
# [[1]]
# [1] "1e4b92a7ebe8b4bcb1afbd44c9a72a72"
# 
# $x
# FileSystemDataset (query)
# vs: double
# n: int32
# mean_mpg: double
# 
# See $.data for the source Arrow object

show_query(x$x)
# ExecPlan with 6 nodes:
# 5:SinkNode{}
#   4:ProjectNode{projection=[vs, n, mean_mpg]}
#     3:GroupByNode{keys=["vs"], aggregates=[
#       hash_sum(n, {skip_nulls=true, min_count=1}),
#       hash_mean(mean_mpg, {skip_nulls=false, min_count=0}),
#     ]}
#       2:ProjectNode{projection=["n": 1, "mean_mpg": mpg, vs]}
#         1:FilterNode{filter=(mpg > 0)}
#           0:SourceNode{}
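
The shape of that list can also be checked without a debugger. The sketch below is only an illustration under assumptions: f, spy_hash and seen are hypothetical names, and a trivial function stands in for collect so that no arrow dataset is needed. It records whatever memoise hands to the hash= function and then inspects it.

```r
library(memoise)

# Hypothetical helper: remember the object memoise passes in, then hash it
# the same way memoise's default hash function would.
seen <- NULL
spy_hash <- function(x) {
  seen <<- x
  rlang::hash(x)
}

# Trivial stand-in for dplyr::collect with the same formals (x, ...).
f <- function(x, ...) x * 2

f_cached <- memoise::memoise(f, hash = spy_hash)
f_cached(21)

# Inspect what the hash function received: going by the debug session
# above, an unnamed first element followed by the evaluated argument x.
str(seen)
```

Because the arguments arrive already evaluated, the custom hash function sees the query object itself in x$x and is free to replace it with anything that identifies the query better.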

Just replacing x$x with the return from show_query(x$x) didn't seem to work since there appear to be things only in the printed form that are not readily available to rlang::hash, so I chose capture.output.
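
One limitation of hashfun as written is that it calls show_query on whatever is passed as x, which would likely fail for inputs that have no show_query method, such as a plain data frame. A slightly more defensive variant, sketched here under the assumption that the class name arrow_dplyr_query mentioned in the question is what identifies these queries (hashfun_safe is a hypothetical name):

```r
library(dplyr)    # provides the show_query() generic
library(memoise)

# Hypothetical variant of hashfun: only swap in the printed plan when the
# argument is an arrow query; everything else is hashed unchanged.
# Note: arrow must be loaded for show_query() to work on such queries.
hashfun_safe <- function(x) {
  if (inherits(x$x, "arrow_dplyr_query")) {
    x$x <- capture.output(show_query(x$x))
  }
  rlang::hash(x)
}

collect_cached <- memoise::memoise(dplyr::collect, hash = hashfun_safe)

# A plain data frame skips the show_query() branch and is hashed as usual.
res <- collect_cached(mtcars)
```

Keep in mind that the printed plan above shows SourceNode{} without any dataset path, so this key presumably cannot tell apart identical queries run against different datasets.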

CodePudding user response:

Edit: see the comments; this does not work!

I have modified your MWE below to call $.data or explain on the query, neither of which evaluates the query (pretty sure for .data, 100% for explain), but both seem to change the R6 object enough to create unique hashes. As you have opened an issue for the underlying problem over at GitHub, this should be a pretty simple workaround with no performance hit.

library(arrow)
library(dplyr)

ds_file <- file.path(tempdir(), "mtcars")

write_dataset(mtcars, ds_file)
ds <- open_dataset(ds_file)

# 1) Create three different queries =======

# Query 1 with mpg > 25 ----
query1 <- ds |> 
  filter(mpg > 25) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg))

# Query 2 with mpg > 0 ----
query2 <- ds |> 
  filter(mpg > 0) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg))

# Query 3 with filter on cyl ----
query3 <- ds |> 
  filter(cyl == 4) |> 
  group_by(vs) |> 
  summarise(n = n(), mean_mpg = mean(mpg))

query1$.data
#> FileSystemDataset (query)
#> mpg: double
#> vs: double
#> 
#> * Aggregations:
#> n: sum(1)
#> mean_mpg: mean(mpg)
#> * Filter: (mpg > 25)
#> * Grouped by vs
#> See $.data for the source Arrow object
explain(query2)
#> ExecPlan with 6 nodes:
#> 5:SinkNode{}
#>   4:ProjectNode{projection=[vs, n, mean_mpg]}
#>     3:GroupByNode{keys=["vs"], aggregates=[
#>      hash_sum(n, {skip_nulls=true, min_count=1}),
#>      hash_mean(mean_mpg, {skip_nulls=false, min_count=0}),
#>     ]}
#>       2:ProjectNode{projection=["n": 1, "mean_mpg": mpg, vs]}
#>         1:FilterNode{filter=(mpg > 0)}
#>           0:SourceNode{}
# 2) Let's compare the hashes: the main issue ======
rlang::hash(query1)
#> [1] "8bbf29208ccbc95fc1bc46f2f2dfe10d"
rlang::hash(query2)
#> [1] "ae5c80b8ed0cc884df40926f3a985b27"
rlang::hash(query3)
#> [1] "3826d824e4c9be046ac5f09dcb60959d"