Home > front end >  pyspark when otherwise statement returning incorrect output
pyspark when otherwise statement returning incorrect output

Time:10-15

I have pasted my code below. I expect that when col2 = 7 it should return 1, but it sometimes returns 1 and sometimes returns 2. I am not doing any operations on col2 once it is set. Has anyone ever experienced this odd behavior? Or is the problem caused by the overlapping limits of the conditions?

 df = df.withColumn('col1', F.when(F.col('col2').between(1,7), 1)
                             .when(F.col('col2').between(7,14), 2)
                             .when(F.col('col2').between(14,21), 3)
                             .when(F.col('col2').between(21,28), 4)
                             .otherwise(5))

CodePudding user response:

I'd say this is unexpected, because the case-when is converted by CodeGen into a sequence of ifs that are evaluated in order, so the first matching condition wins. Hence you should always see 'col1' being 1 when 'col2' is 7.

You can review the actual code that Spark generated using QueryExecution.debug.codegen, something like this:

>>> df = spark.range(1000)
>>> from pyspark.sql.functions import *
>>> dff = df.withColumn('col1',when(col('id').between(1,7),1).when(col('id').between(7,14),2).otherwise(3))

>>> dff._jdf.queryExecution().debug().codegen()

Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 ==
*(1) Project [id#4L, CASE WHEN ((id#4L >= 1) && (id#4L <= 7)) THEN 1 WHEN ((id#4L >= 7) && (id#4L <= 14)) THEN 2 ELSE 3 END AS col1#6]
 - *(1) Range (0, 1000, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_number_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private int project_project_value_2_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 031 */
/* 032 */   }
/* 033 */
/* 034 */   private void project_doConsume_0(long project_expr_0_0) throws java.io.IOException {
/* 035 */     byte project_caseWhenResultState_0 = -1;
/* 036 */     do {
/* 037 */       boolean project_value_4 = false;
/* 038 */       project_value_4 = project_expr_0_0 >= 1L;
/* 039 */       boolean project_value_3 = false;
/* 040 */
/* 041 */       if (project_value_4) {
/* 042 */         boolean project_value_7 = false;
/* 043 */         project_value_7 = project_expr_0_0 <= 7L;
/* 044 */         project_value_3 = project_value_7;
/* 045 */       }
/* 046 */       if (!false && project_value_3) {
/* 047 */         project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
/* 048 */         project_project_value_2_0 = 1;
/* 049 */         continue;
/* 050 */       }
/* 051 */
/* 052 */       boolean project_value_12 = false;
/* 053 */       project_value_12 = project_expr_0_0 >= 7L;
/* 054 */       boolean project_value_11 = false;
/* 055 */
/* 056 */       if (project_value_12) {
/* 057 */         boolean project_value_15 = false;
/* 058 */         project_value_15 = project_expr_0_0 <= 14L;
/* 059 */         project_value_11 = project_value_15;
/* 060 */       }
/* 061 */       if (!false && project_value_11) {
/* 062 */         project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
/* 063 */         project_project_value_2_0 = 2;
/* 064 */         continue;
/* 065 */       }
/* 066 */
/* 067 */       project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
/* 068 */       project_project_value_2_0 = 3;
/* 069 */
/* 070 */     } while (false);
/* 071 */     // TRUE if any condition is met and the result is null, or no any condition is met.
/* 072 */     final boolean project_isNull_2 = (project_caseWhenResultState_0 != 0);
/* 073 */     range_mutableStateArray_0[2].reset();
/* 074 */
/* 075 */     range_mutableStateArray_0[2].zeroOutNullBytes();
/* 076 */
/* 077 */     range_mutableStateArray_0[2].write(0, project_expr_0_0);
/* 078 */
/* 079 */     range_mutableStateArray_0[2].write(1, project_project_value_2_0);
/* 080 */     append((range_mutableStateArray_0[2].getRow()));
/* 081 */
/* 082 */   }
/* 083 */
...

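We're interested in the method private void project_doConsume_0(...) (starting from line 34 of the generated code). When the first condition (id >= 1 and id <= 7) is true, the result is set to 1 and `continue` leaves the do { ... } while (false) block, so the second condition is never evaluated for that row.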

CodePudding user response:

First point: between is inclusive, and your intervals overlap (7 satisfies both the first and the second condition, because both intervals contain 7).

So this should improve:

 df = df.withColumn('col1', F.when(F.col('col2').between(1,7), 1)
                             .when(F.col('col2').between(8,14), 2)
                             .when(F.col('col2').between(15,21), 3)
                             .when(F.col('col2').between(22,28), 4)
                             .otherwise(5))

Also, when working with multiple F.when() calls, I have less trouble when I nest them inside .otherwise(F.when(...)), like the following:

 df = df.withColumn('col1', F.when(F.col('col2').between(1,7), 1)
                             .otherwise(F.when(F.col('col2').between(8,14), 2)
                             .otherwise(F.when(F.col('col2').between(15,21), 3)
                             .otherwise(F.when(F.col('col2').between(22,28), 4)
                             .otherwise(5)))))
  • Related