Home > Enterprise >  K-means Clustering in C Implementation Help and Improvement
K-means Clustering in C Implementation Help and Improvement

Time:09-07

I was given an array of colors, and my task is to estimate the two distinct underlying colors in that array. I decided to use the k-means algorithm to estimate the two colors and to label which input belongs to which, but my implementation doesn't find the centroids well. By that I mean the centroids are completely off.

What am I doing wrong?

Here is my full code:

#include <iostream>
#include <sstream>
#include <array>
#include <algorithm>
#include <cmath>
#include <cassert>
#include <vector>

using namespace std;

// An RGB color with one float per channel.
// The sample data in main() uses channel values in roughly the 0-255 range;
// nothing in this struct enforces a range.
struct color {
    float red;    // red channel
    float green;  // green channel
    float blue;   // blue channel
};

// Assignment step: for every color in *stick, records in *clustID the index of
// the closest centroid (distance measured with DistColors).
void kMeansClustering(vector <color>* stick, vector <color>* centroids, int k , vector <int>* clustID);
// Euclidean distance between two colors in RGB space.
float DistColors(color centroid, color c);
// Update step: recomputes each centroid from the colors currently labelled
// with its index in *clustID.
void UpdateCentroids(vector <color>* stick, int k , vector <color>* centroids , vector <int>* clustID);
// Driver: seeds k centroids from *stick, then alternates kMeansClustering and
// UpdateCentroids for `epochs` rounds. Per-color labels end up in *clustID.
color ColorEstimator(vector <color>* stick, int k, int epochs, vector <int>* clustID);

// Demo driver: builds a fixed table of 50 RGB samples and runs the k-means
// estimator on it with k = 2 clusters for 5 epochs.
// The color returned by ColorEstimator and the labels written to clustId are
// not printed or otherwise used here.
int main()
{

    // Hard-coded sample colors (red, green, blue). The literals are doubles that
    // are converted to the float channels of `color`.
    array<color, 50> input = { { {231.40743881294037, 237.44965622220946, 254.77121736965333}, {228.5534410910449, 235.464386448478, 242.19988902045463}, {230.30772169389454, 238.72320816537146, 250.49564182065964}, {227.15016686772202, 236.24691719267867, 252.90235409762704}, {230.41895618727227, 238.657597362693, 248.29153087604715}, {229.93684620291512, 234.01819943500396, 249.60623949753426}, {228.42175736330205, 238.34361201766305, 247.2204140886806}, {227.84022562034895, 236.5214355504492, 247.8538373236252}, {227.71544053823285, 240.16861817247062, 254.60425325391995}, {226.58842798166197, 237.68042876356262, 254.71957768401475}, {230.88840479065888, 236.7678042900178, 248.61150119469562}, {226.86912093334206, 235.8884424252573, 247.88671586964455}, {227.45517488635113, 240.3819316766157, 250.51481023872068}, {229.59861886575894, 235.83584959054974, 251.46889400577427}, {229.23367809030552, 237.24044401350025, 249.4279962431541}, {228.32192716816607, 237.05340803933723, 250.64537320750122}, {230.08628579311065, 235.0134454354286, 246.5735753630862}, {229.252340468657, 234.00910635835228, 249.34217504033197}, {225.22447511374702, 239.4709613428637, 246.4458975263566}, {225.61187564359977, 235.66097995340849, 249.02345892965656}, {229.79637417586468, 239.10337544176383, 246.92327185739796}, {227.58812892987137, 240.21670944169784, 250.60935853418786}, {228.41945997585944, 238.11672487783682, 253.91798978403156}, {228.96868361083293, 235.2705636041433, 251.18917992896004}, {229.28284008153832, 239.37264981015255, 250.3604105076603}, {231.5687396936171, 232.90968630866203, 243.5450028582501}, {229.40662599794854, 236.45161336682256, 248.05397175207972}, {228.70207834717965, 243.15024540608889, 250.5010381192849}, {229.95651741532535, 239.64157110660392, 248.80395810592591}, {228.9988967797162, 236.21108589663837, 246.07611157333568}, {122.69433611768231, 169.37473153784, 66.18694573878635}, {125.25289749728812, 169.7640666346183, 68.91076979782899}, 
{123.38953288543718, 168.98423240064878, 68.2114085083602}, {125.35729146507643, 169.51176403300119, 69.78421062439558}, {124.3884981048362, 168.94995025162697, 68.63663408419842}, {126.73444024584143, 171.7928274202038, 69.00474217104087}, {124.01034543440451, 167.87622211655415, 67.71426849138949}, {123.3329143640377, 170.10828953207476, 68.18084821738653}, {124.05385478199051, 169.44109340129506, 70.42607758635805}, {126.15353870421363, 170.69852488354542, 66.82441574127802}, {229.0877270736824, 236.21275756861388, 249.53741724579652}, {231.4238125605047, 238.34822768240278, 249.30582303433133}, {228.28016325966595, 232.5048867215949, 246.89611636914483}, {227.27386122243635, 241.11957952107088, 251.0571248642451}, {225.13067522766005, 235.50757722085427, 249.22992628681828}, {227.9106192249585, 233.08165289848347, 248.3099867450273}, {227.52684451385937, 238.51789816429803, 251.5834717176488}, {224.408858290033, 233.5616065285502, 247.7543836219995}, {226.51482613575666, 239.76326979626094, 247.35607749364993}, {228.60095253722585, 238.309990231983, 244.90373176067322} } };
    // Copy the fixed-size array into a vector, which is what the clustering
    // functions take (by pointer).
    vector<color> inputVector(input.begin(), input.end());
    // Receives one cluster label per input color.
    vector <int> clustId;
    int epochs = 5, k=2;
    ColorEstimator(&inputVector, k, epochs, &clustId);
    return 0;
}

void kMeansClustering(vector <color>* stick, vector <color>* centroids,int k, vector <int>* clustID)
{
    vector <float> minDistance;
    color c;
    float dist;
    int colorIndex, clusterId ;

    for (vector<color>::iterator i = centroids -> begin();
        i != centroids -> end();   i) {
        // quick hack to get cluster index
        clusterId = i - centroids->begin();
        c = *i;
        //cout << c.red << " " << c.green << " " << c.blue << endl;
        for (vector<color>::iterator it = stick->begin();
            it != stick->end();   it) {

            c = *it;
            //cout << c.red << " " << c.green << " " << c.blue << endl;
            dist = DistColors(*i , c);
            // quick hack to get color index
            colorIndex = it - stick->begin();
            if (minDistance.size() < stick->size())
            {
                minDistance.push_back(dist);
                clustID->push_back(clusterId);
                continue;
            }
            //cout << minDistance.at(colorIndex);
            if (dist < minDistance.at(colorIndex))
            {
                clustID->at(colorIndex) = clusterId;
                minDistance.at(colorIndex) = dist;
            }
        }
    }
}

// Euclidean distance between two colors in RGB space:
// sqrt(dRed^2 + dGreen^2 + dBlue^2).
//
// centroid, c - the two colors to compare; the result is symmetric in them.
// Returns the non-negative distance as a float.
float DistColors(color centroid, color c)
{
    // Accumulate in double and square by multiplication; pow(x, 2) is a
    // general-purpose routine and needlessly slow for a fixed exponent.
    const double dRed = centroid.red - c.red;
    const double dGreen = centroid.green - c.green;
    const double dBlue = centroid.blue - c.blue;
    return static_cast<float>(std::sqrt(dRed * dRed + dGreen * dGreen + dBlue * dBlue));
}

void UpdateCentroids(vector <color>* stick, int k, vector <color>* centroids, vector <int>* clustID)
{
    vector<int> nPoints;
    vector<double> sumRed, sumGreen, sumBlue;
    int colorIndex, clusterId;

    // Initialise with zeroes
    for (int j = 0; j < k;   j) {
        nPoints.push_back(0);
        sumRed.push_back(0.0);
        sumGreen.push_back(0.0);
        sumBlue.push_back(0.0);
    }

    // Iterate over colors to append data to centroids
    for (vector<color>::iterator it = stick->begin();
        it != stick->end();   it) {
        colorIndex = it - stick->begin();
        clusterId = clustID -> at(colorIndex);
        nPoints[clusterId]  = 1;
        sumRed[clusterId]  = it->red;
        sumGreen[clusterId]  = it->green;
        sumBlue[clusterId]  = it->blue;
    }
    cout << nPoints.size() << endl;
    // Compute the new centroids
    for (vector<color>::iterator c = centroids ->begin();
        c != centroids->end();   c) {
        clusterId = c - centroids->begin();
        c->red = sumRed[clusterId] / nPoints[clusterId];
        c->green = sumGreen[clusterId] / nPoints[clusterId];
        c->blue = sumBlue[clusterId] / nPoints[clusterId];
        //centroids->at(clusterId) = c;
        //cout << c->red << " " << c->green << " " << c->blue << " " << nPoints[clusterId] << endl;
    }
}

// Runs k-means on *stick and returns the centroid of the most populated
// cluster (ties go to the lowest cluster index).
//
// stick    - input colors (not modified).
// k        - number of clusters to fit.
// epochs   - number of assignment/update rounds to run.
// clustID  - output; after at least one epoch, clustID->at(i) is the cluster
//            index assigned to stick->at(i).
//
// Returns black {0, 0, 0} and leaves *clustID untouched when *stick is empty
// or k <= 0. With epochs <= 0 no clustering happens and the first randomly
// chosen seed color is returned.
//
// Seeds are drawn with rand(), so results depend on the caller's srand()
// state.
color ColorEstimator(vector <color>* stick, int k, int epochs, vector <int>* clustID)
{
    const size_t n = stick->size();
    if (n == 0 || k <= 0) {
        // Nothing to cluster; also avoids rand() % 0.
        return color{0.0f, 0.0f, 0.0f};
    }
    const size_t clusters = static_cast<size_t>(k);

    // Seed the centroids with randomly chosen input colors. While there are
    // enough inputs, never pick the same input twice: two identical seeds
    // would leave one cluster permanently empty.
    vector<color> centroids;
    vector<size_t> picked;
    centroids.reserve(clusters);
    picked.reserve(clusters);
    while (centroids.size() < clusters) {
        const size_t index = static_cast<size_t>(rand()) % n;
        const bool seen = std::find(picked.begin(), picked.end(), index) != picked.end();
        if (seen && clusters <= n) {
            continue;
        }
        picked.push_back(index);
        centroids.push_back((*stick)[index]);
    }

    for (int e = 0; e < epochs; ++e) {
        kMeansClustering(stick, &centroids, k, clustID);
        UpdateCentroids(stick, k, &centroids, clustID);
    }

    // Return a refined centroid rather than the initial seed: pick the
    // cluster with the most members.
    vector<size_t> members(clusters, 0);
    for (int label : *clustID) {
        if (label >= 0 && static_cast<size_t>(label) < clusters) {
            ++members[static_cast<size_t>(label)];
        }
    }
    size_t dominant = 0;
    for (size_t cluster = 1; cluster < clusters; ++cluster) {
        if (members[cluster] > members[dominant]) {
            dominant = cluster;
        }
    }
    return centroids[dominant];
}

CodePudding user response:

I found the bug in the code. The way I updated the cluster-index vector was wrong: instead of reassigning the existing entry for each color, I appended a new value every time, so the vector kept growing and the clustering didn't work.

That was the main reason the code I posted didn't work. I changed it to:

void kMeansClustering(vector <color>* stick, vector <color>* centroids,int k, vector <int>* clustID)
{
    vector <float> minDistance;
    color c;
    float dist;
    int colorIndex, clusterId ;

    for (vector<color>::iterator i = centroids -> begin();
        i != centroids -> end();   i) {
        // quick hack to get cluster index
        clusterId = i - centroids->begin();
        c = *i;
        for (vector<color>::iterator it = stick->begin();
            it != stick->end();   it) {

            c = *it;
            dist = DistColors(*i , c);
            // quick hack to get color index
            colorIndex = it - stick->begin();
            if (minDistance.size() < stick->size())
            {
                minDistance.push_back(dist);
                if(clustID->size() < stick->size())
                {
                    clustID->push_back(clusterId);
                }
                else {
                    clustID->at(colorIndex) = clusterId;
                }
                continue;
            }
            //cout << minDistance.at(colorIndex);
            if (dist < minDistance.at(colorIndex))
            {
                clustID->at(colorIndex) = clusterId;
                minDistance.at(colorIndex) = dist;
            }
        }
    }
}
  • Related