Goal: write a function in PostgreSQL SQL that takes as input an integer array in which each element is 0, 1, or -1, and returns an array of the same length in which each element is the sum of the run of adjacent nonzero input values ending at that index (a 0 in the input resets the sum to 0).
For example, this input:
{0,1,1,1,1,0,-1,-1,0}
should produce this result:
{0,1,2,3,4,0,-1,-2,0}
Here is my attempt at such a function:
CREATE FUNCTION runs(input int[], output int[] DEFAULT '{}')
RETURNS int[] AS $$
    SELECT
        CASE WHEN cardinality(input) = 0 THEN output
        ELSE runs(input[2:],
                  array_append(output, CASE
                      WHEN input[1] = 0 THEN 0
                      ELSE output[cardinality(output)] + input[1]
                  END)
             )
        END
$$ LANGUAGE SQL;
Which gives unexpected (to me) output:
# select runs('{0,1,1,1,1,0,-1,-1,-1,0}');
runs
----------------------------------------
{0,1,2,3,4,5,6,0,0,0,-1,-2,-3,-4,-5,0}
(1 row)
I'm using PostgreSQL 14.4. While I don't know why there are more elements in the output array than in the input, the cardinality() call in the recursive branch seems to be causing it, as does using array_length() or array_upper() in the same place.
Question: how can I write a function that gives me the output I want (and why is the function I wrote failing to do that)?
Bonus extra: For context, this input array is coming from array_agg() invoked on a table column, and the output will go back into a table using unnest(). I'm converting to and from an array since I see no way to do it directly on the table, in particular because WITH RECURSIVE forbids references to the recursive table in either an outer join or a subquery. If there's a way to avoid arrays altogether (especially given the lack of tail-recursion optimization), that would answer the general question, though I am still very curious why I'm seeing the extra elements in the output array.
CodePudding user response:
Everything indicates that you have found a reportable Postgres bug. The function as written should work properly, yet a slight modification unexpectedly changes its behavior. Add an empty SELECT; statement right after the opening $$ to get the function to run as expected; see Db<>fiddle.
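For concreteness, here is a sketch of the workaround applied to the function from the question. The only change is the extra empty SELECT; as the first statement of the body; a SQL function returns the result of its last statement, so the return value is still the CASE expression. One possible reading, which is an assumption and not confirmed in this thread, is that the extra statement keeps the planner from inlining the body, since only single-SELECT SQL functions are candidates for inlining.

```sql
-- Same function as in the question, with an empty SELECT; added first.
-- The body now has two statements; the result of the last one is returned.
CREATE OR REPLACE FUNCTION runs(input int[], output int[] DEFAULT '{}')
RETURNS int[] AS $$
    SELECT;
    SELECT
        CASE WHEN cardinality(input) = 0 THEN output
        ELSE runs(input[2:],
                  array_append(output, CASE
                      WHEN input[1] = 0 THEN 0
                      ELSE output[cardinality(output)] + input[1]
                  END)
             )
        END
$$ LANGUAGE SQL;

-- Call it on the example input from the question.
SELECT runs('{0,1,1,1,1,0,-1,-1,0}');
```

According to the answer, this version produces the intended result for the example input, that is, {0,1,2,3,4,0,-1,-2,0}.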
A good alternative to a recursive solution is a simple iterative function. Handling arrays in PL/pgSQL is typically simpler and faster than recursion.
create or replace function loop_function(input int[])
returns int[] language plpgsql as $$
declare
    val int;
    tot int = 0;   -- running sum of the current nonzero run
    res int[];
begin
    foreach val in array input loop
        if val = 0 then tot := 0;   -- a zero resets the run
        else tot := tot + val;
        end if;
        res := res || tot;
    end loop;
    return res;
end $$;
Test it in Db<>fiddle.
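As a quick check, the function can be called on the example input from the question. This sketch assumes loop_function has been created as shown above, with the accumulation step written as tot + val.

```sql
-- Assumes loop_function(int[]) from the answer above already exists.
select loop_function('{0,1,1,1,1,0,-1,-1,0}'::int[]) as runs;
--          runs
-- -----------------------
--  {0,1,2,3,4,0,-1,-2,0}
```

Note that res starts out as NULL, so an empty input array returns NULL and not '{}'; declare res int[] = '{}' if an empty array is preferred.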
The OP wrote:
this input array is coming from array_agg() invoked on a table column and the output will go back into a table using unnest().
You can calculate these cumulative sums directly in the table with the help of window functions.
select id, val, sum(val) over w
from (
    select
        id,
        val,
        case val
            when 0 then 0
            else sum((val = 0)::int) over w
        end as series
    from my_table
    window w as (order by id)
) t
window w as (partition by series order by id)
order by id
Test it in Db<>fiddle.
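To try the window-function approach without the fiddle, the following sketch builds a small test table and collects the result back into an array. The table name my_table and its columns id and val are taken from the query above; the temporary table, the sample values (the example input from the question), and the alias run_sum are assumptions made only for illustration.

```sql
-- Hypothetical test data: the question's example array, one row per element.
create temp table my_table (id int primary key, val int);

insert into my_table (id, val)
select ord::int, v
from unnest('{0,1,1,1,1,0,-1,-1,0}'::int[]) with ordinality as u(v, ord);

-- Same logic as the query above, aggregated back into a single array.
select array_agg(run_sum order by id) as result
from (
    select id,
           sum(val) over (partition by series order by id) as run_sum
    from (
        select id,
               val,
               case val
                   when 0 then 0
                   else sum((val = 0)::int) over (order by id)
               end as series
        from my_table
    ) t
) s;
--         result
-- -----------------------
--  {0,1,2,3,4,0,-1,-2,0}
```

One caveat: nonzero rows that come before the first zero get series = 0, the same partition as the zero rows, so their sum would carry over into later zero rows. The example data starts with a zero, so that case does not arise here.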