Count number of attributes in vector of features - Spark Scala


I have a dataframe with a column of normalized features like this:

+--------------------+
|        normFeatures|
+--------------------+
|(17412,[0,1,2,5,1...|
|(17412,[0,1,2,5,9...|
|(17412,[0,1,2,5,1...|
|(17412,[0,1,2,5,9...|
|(17412,[0,1,2,5,1...|
|(17412,[0,1,2,5,1...|
+--------------------+

These vectors have been obtained after applying StringIndexer, OneHotEncoder and VectorAssembler to the original columns of attributes.

I was wondering if it is possible to count the new number of attributes. I don't know if the new number is the size of the vector, or if the numbers inside the nested [] and () count as attributes as well.

Dataframe before vectorizing:

+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+
|DayOfWeek|  DepTime|UniqueCarrier|FlightNum|TailNum|ArrDelay|DepDelay|Distance|TaxiOut|    Date| Flight|
+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+
| Thursday|Afternoon|           WN|      588| N240WN|      16|      18|     393|      9|2008/1/3|HOU-LIT|
| Thursday|  Morning|           WN|     1343| N523SW|       2|       5|     441|      8|2008/1/3|HOU-MAF|
| Thursday|    Night|           WN|     3841| N280WN|      -4|      -6|     441|     14|2008/1/3|HOU-MAF|
| Thursday|  Morning|           WN|        3| N308SA|      -2|       8|     848|      7|2008/1/3|HOU-MCO|
| Thursday|Afternoon|           WN|       25| N462WN|      16|      23|     848|     10|2008/1/3|HOU-MCO|
| Thursday|    Night|           WN|       51| N483WN|       0|       4|     848|      7|2008/1/3|HOU-MCO|
| Thursday|  Evening|           WN|      940| N493WN|       3|       8|     848|      7|2008/1/3|HOU-MCO|
| Thursday|  Morning|           WN|     2621| N266WN|       5|       2|     848|     19|2008/1/3|HOU-MCO|
| Thursday|  Evening|           WN|      389| N266WN|      -5|      -1|     937|     15|2008/1/3|HOU-MDW|
| Thursday|Afternoon|           WN|      519| N514SW|      26|      28|     937|     13|2008/1/3|HOU-MDW|
+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+

Dataframe after vectorizing and normalizing:

+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+---------+------------+------------------+-----------+------------+--------------+--------------+-------------+-------------------+-------------+-------------------+---------------+----------------+-------------------+--------------------+--------------------+
|DayOfWeek|  DepTime|UniqueCarrier|FlightNum|TailNum|ArrDelay|DepDelay|Distance|TaxiOut|    Date| Flight|DateIndex|DepTimeIndex|UniqueCarrierIndex|FlightIndex|TailNumIndex|FlightNumIndex|DayOfWeekIndex| DayOfWeekVec|       FlightNumVec|   DepTimeVec|          FlightVec|        DateVec|UniqueCarrierVec|         TailNumVec|            features|        normFeatures|
+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+---------+------------+------------------+-----------+------------+--------------+--------------+-------------+-------------------+-------------+-------------------+---------------+----------------+-------------------+--------------------+--------------------+
| Thursday|Afternoon|           WN|      588| N240WN|      16|      18|     393|      9|2008/1/3|HOU-LIT|      9.0|         1.0|               0.0|     3631.0|       554.0|         399.0|           2.0|(6,[2],[1.0])| (7262,[399],[1.0])|(3,[1],[1.0])|(4974,[3631],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[554],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|  Morning|           WN|     1343| N523SW|       2|       5|     441|      8|2008/1/3|HOU-MAF|      9.0|         0.0|               0.0|     3060.0|      1256.0|        3961.0|           2.0|(6,[2],[1.0])|(7262,[3961],[1.0])|(3,[0],[1.0])|(4974,[3060],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])|(5025,[1256],[1.0])|(17412,[0,1,2,5,9...|(17412,[0,1,2,5,9...|
| Thursday|    Night|           WN|     3841| N280WN|      -4|      -6|     441|     14|2008/1/3|HOU-MAF|      9.0|         3.0|               0.0|     3060.0|       463.0|        1520.0|           2.0|(6,[2],[1.0])|(7262,[1520],[1.0])|    (3,[],[])|(4974,[3060],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[463],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|  Morning|           WN|        3| N308SA|      -2|       8|     848|      7|2008/1/3|HOU-MCO|      9.0|         0.0|               0.0|     1285.0|        93.0|          76.0|           2.0|(6,[2],[1.0])|  (7262,[76],[1.0])|(3,[0],[1.0])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])|  (5025,[93],[1.0])|(17412,[0,1,2,5,9...|(17412,[0,1,2,5,9...|
| Thursday|Afternoon|           WN|       25| N462WN|      16|      23|     848|     10|2008/1/3|HOU-MCO|      9.0|         1.0|               0.0|     1285.0|       497.0|         213.0|           2.0|(6,[2],[1.0])| (7262,[213],[1.0])|(3,[1],[1.0])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[497],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|    Night|           WN|       51| N483WN|       0|       4|     848|      7|2008/1/3|HOU-MCO|      9.0|         3.0|               0.0|     1285.0|       282.0|         204.0|           2.0|(6,[2],[1.0])| (7262,[204],[1.0])|    (3,[],[])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[282],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|  Evening|           WN|      940| N493WN|       3|       8|     848|      7|2008/1/3|HOU-MCO|      9.0|         2.0|               0.0|     1285.0|       342.0|        1455.0|           2.0|(6,[2],[1.0])|(7262,[1455],[1.0])|(3,[2],[1.0])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[342],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|  Morning|           WN|     2621| N266WN|       5|       2|     848|     19|2008/1/3|HOU-MCO|      9.0|         0.0|               0.0|     1285.0|       555.0|        2051.0|           2.0|(6,[2],[1.0])|(7262,[2051],[1.0])|(3,[0],[1.0])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[555],[1.0])|(17412,[0,1,2,5,9...|(17412,[0,1,2,5,9...|
| Thursday|  Evening|           WN|      389| N266WN|      -5|      -1|     937|     15|2008/1/3|HOU-MDW|      9.0|         2.0|               0.0|     1081.0|       555.0|        1016.0|           2.0|(6,[2],[1.0])|(7262,[1016],[1.0])|(3,[2],[1.0])|(4974,[1081],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[555],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|Afternoon|           WN|      519| N514SW|      26|      28|     937|     13|2008/1/3|HOU-MDW|      9.0|         1.0|               0.0|     1081.0|       133.0|         309.0|           2.0|(6,[2],[1.0])| (7262,[309],[1.0])|(3,[1],[1.0])|(4974,[1081],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[133],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+---------+------------+------------------+-----------+------------+--------------+--------------+-------------+-------------------+-------------+-------------------+---------------+----------------+-------------------+--------------------+--------------------+

Code:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.{Normalizer, OneHotEncoder, StringIndexer, VectorAssembler}

    val numeric_columns = Array("ArrDelay", "DepDelay", "TaxiOut", "Distance")
    val string_columns = df.columns.diff(numeric_columns)
    println("Getting vector of normalized features")
    val index_columns = string_columns.map(col => col + "Index")

    // StringIndexer
    val indexer = new StringIndexer()
      .setInputCols(string_columns)
      .setOutputCols(index_columns)

    val vec_columns = string_columns.map(col => col + "Vec")

    // OneHotEncoder
    val encoder = new OneHotEncoder()
      .setInputCols(index_columns)
      .setOutputCols(vec_columns)

    // VectorAssembler (ArrDelay is excluded from the feature columns)
    val num_vec_columns: Array[String] = numeric_columns.filter(!_.contains("ArrDelay")) ++ vec_columns
    val assembler = new VectorAssembler()
      .setInputCols(num_vec_columns)
      .setOutputCol("features")

    // Normalizer
    val normalizer = new Normalizer()
      .setInputCol("features")
      .setOutputCol("normFeatures")
      .setP(1.0)

    // All together in a pipeline
    val pipeline = new Pipeline()
      .setStages(Array(indexer, encoder, assembler, normalizer))
    df = pipeline.fit(df).transform(df)
    df.printSchema()
    df.show(10)
    println("Done")
    println("-------------------")

Thanks in advance.

CodePudding user response:

Some considerations here:

What you have here is a SparseVector representation that results from chaining several sparse-vector transformations. These vectors are created by the OneHotEncoder transformation (note that the original OneHotEncoder was deprecated in Spark 2.3, and its replacement, OneHotEncoderEstimator, was renamed back to OneHotEncoder in Spark 3.0). So when you have something like:

(7262,[399],[1.0])

It is a description indicating that you have a vector of 7262 positions with a 1.0 at position 399. The length here is 7262; the vector is simply stored in a sparse representation rather than a dense one.
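The encoding can be made concrete with a few lines of plain Scala. `SparseVec` below is a hypothetical stand-in for Spark's `org.apache.spark.ml.linalg.SparseVector`, just to illustrate the `(size, [indices], [values])` triple:

```scala
// Stand-in for Spark's SparseVector: (size, indices, values).
case class SparseVec(size: Int, indices: Array[Int], values: Array[Double]) {
  // Expand to the equivalent dense array of `size` doubles.
  def toDense: Array[Double] = {
    val dense = Array.fill(size)(0.0)
    for ((i, v) <- indices.zip(values)) dense(i) = v
    dense
  }
}

// (7262,[399],[1.0]): a 7262-position vector with a single 1.0 at index 399.
val vec = SparseVec(7262, Array(399), Array(1.0))
```

Expanding `vec.toDense` gives 7262 doubles, of which all but position 399 are zero, which is exactly why the sparse form is preferred here.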

The VectorAssembler concatenates these sparse representations, so you end up with a single sparse representation of length 17412. If you print the dataframe without truncation (`df.show(truncate = false)`) you will see the full positions and values of the normFeatures column.
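The concatenation itself can be sketched in plain Scala, with each sparse vector as a `(size, indices, values)` triple: sizes add up, and each vector's indices are shifted by the combined size of everything before it. This is a simplified sketch that ignores how Spark also handles plain numeric columns and metadata:

```scala
// Assembler-style concatenation of (size, indices, values) triples:
// sizes add, and each vector's indices are shifted by the running offset.
def concat(vecs: Seq[(Int, Array[Int], Array[Double])]): (Int, Array[Int], Array[Double]) = {
  var offset = 0
  val idx  = Array.newBuilder[Int]
  val vals = Array.newBuilder[Double]
  for ((size, is, vs) <- vecs) {
    is.foreach(i => idx += i + offset)
    vals ++= vs
    offset += size
  }
  (offset, idx.result(), vals.result())
}

// e.g. DayOfWeekVec (6,[2],[1.0]) followed by FlightNumVec (7262,[399],[1.0]):
val (size, indices, values) = concat(Seq(
  (6, Array(2), Array(1.0)),
  (7262, Array(399), Array(1.0))
))
// total size is 6 + 7262 = 7268, and index 399 is shifted to 6 + 399 = 405
```

Summing all the one-hot sizes and the plain numeric columns in the question is what produces the total length of 17412.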

If you want to extract the length of this field you could do something like:

import org.apache.spark.ml.linalg.SparseVector

val row = df.select("normFeatures").head
val vector = row.getAs[SparseVector](0)
val size = vector.size

However, the sparse representation length is not guaranteed to be fixed across the whole dataframe: you could have rows of different lengths, although with the same transformations applied that should not happen. Be careful if you perform a union with another dataframe whose applied transformations you can't track.
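Once the per-row sizes have been collected (in Spark you would first map each row's vector to its `.size`), the sanity check is a one-liner in plain Scala. The sizes below are illustrative values, not taken from the actual data:

```scala
// Sizes gathered from each row's sparse vector (illustrative values).
val sizes = Seq(17412, 17412, 17412)

// Rows are union-safe only if there is exactly one distinct size.
val uniform = sizes.distinct.size == 1
```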
