I have a multiclass classification TensorFlow model imported into GCP BigQuery. When I make predictions, the output column holds the class probabilities as type FLOAT with mode REPEATED (i.e. an array of floats). What is the best way to get the index of the maximum value using SQL in BigQuery?
If you want to find the index of the maximum value in an array, a UDF is handy, I think.
-- Returns the zero-based offset of the largest element in the array.
CREATE TEMP FUNCTION index_of_max(probabilities ARRAY<FLOAT64>) AS ((
  SELECT i FROM UNNEST(probabilities) p WITH OFFSET i
  WHERE p = (SELECT MAX(p) FROM UNNEST(probabilities) p)
));
SELECT index_of_max(dense_1) index_of_max FROM UNNEST([
STRUCT([0.8611106872558594, 0.06648489832878113, 0.07240447402000427] AS dense_1),
STRUCT([0.6251607537269592, 0.2989124655723572, 0.07592668384313583]),
STRUCT([0.01427623350173235, 0.972910463809967, 0.01281337533146143])
]);
Output (the index is zero-based):
index_of_max
0
0
1
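Because the index is zero-based, it can be passed straight to `OFFSET` to look up a class name. A minimal sketch with inline data, assuming a hypothetical label array `['class_a', 'class_b', 'class_c']` listed in the same order as the model's output units; with a real model the rows would come from `ML.PREDICT(MODEL ..., TABLE ...)` instead, and `dense_1` would be whatever the model's output column is actually called.

```sql
-- Returns the zero-based offset of the largest element in the array.
CREATE TEMP FUNCTION index_of_max(probabilities ARRAY<FLOAT64>) AS ((
  SELECT i FROM UNNEST(probabilities) p WITH OFFSET i
  WHERE p = (SELECT MAX(p) FROM UNNEST(probabilities) p)
));

SELECT
  index_of_max(dense_1) AS class_index,
  -- Hypothetical labels; their order must match the model's output units.
  ['class_a', 'class_b', 'class_c'][OFFSET(index_of_max(dense_1))] AS class_label
FROM UNNEST([
  STRUCT([0.8611106872558594, 0.06648489832878113, 0.07240447402000427] AS dense_1),
  STRUCT([0.01427623350173235, 0.972910463809967, 0.01281337533146143])
]);
```

For these two rows the query returns `0, class_a` and `1, class_b`.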
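Consider also a slightly refactored version of the UDF proposed above. Sorting by probability and keeping the first row avoids the second pass over the array that computes MAX, and it always returns a single index; the `WHERE p = (SELECT MAX(p) ...)` version returns one row per tied maximum, so its scalar subquery fails if two probabilities are exactly equal: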
CREATE TEMP FUNCTION index_of_max(probabilities ARRAY<FLOAT64>) AS ((
  SELECT i FROM UNNEST(probabilities) p WITH OFFSET i
  ORDER BY p DESC LIMIT 1
));
or
CREATE TEMP FUNCTION index_of_max(probabilities ARRAY<FLOAT64>) AS ((
  SELECT i FROM UNNEST(probabilities) p WITH OFFSET i
  QUALIFY 1 = ROW_NUMBER() OVER(ORDER BY p DESC)
));
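Sorting only by `p DESC` does not say which offset wins when two probabilities are exactly equal, so the result for a tie is not guaranteed to be stable. If a deterministic answer matters, adding the offset as a secondary sort key picks the lowest index. A minimal sketch; the function name `index_of_max_stable` is hypothetical:

```sql
-- Returns the zero-based offset of the largest element;
-- on ties, the lowest offset wins because of the secondary sort key.
CREATE TEMP FUNCTION index_of_max_stable(probabilities ARRAY<FLOAT64>) AS ((
  SELECT i
  FROM UNNEST(probabilities) p WITH OFFSET i
  ORDER BY p DESC, i ASC
  LIMIT 1
));

-- Offsets 0 and 2 tie at 0.4; this returns 0.
SELECT index_of_max_stable([0.4, 0.2, 0.4]) AS idx;
```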