I'm trying to create a std::unordered_map that takes a std::pair as key and returns a size_t as value. The tricky part for me is that I want a custom hash function for my map that disregards the order of the members of the key std::pair, i.e.:
std::pair<int,int> p1 = std::make_pair(3, 4);
std::pair<int,int> p2 = std::make_pair(4, 3);
std::unordered_map<std::pair<int,int>, int> m;  // std::hash has no std::pair specialization, so this still needs a custom hash
m[p1] = 3;
// m[p2] should now also return 3!
This is not a clear-cut MWE, but it is an excerpt of what I'm trying to do in my program:
#include <vector>
#include <string>
#include <iostream>
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <functional>
class Point
{
public:
    static size_t id_counter;
    size_t id;
    Point() : id(id_counter++) {}  // give every Point a unique id
    ~Point() = default;
    bool operator==(const Point& rhs) const
    {
        return id == rhs.id;
    }
    friend std::ostream& operator<<(std::ostream& os, Point& P);
};
size_t Point::id_counter = 0;
class Hasch_point_pair
{
public:
    size_t operator()(const std::pair<Point*, Point*>* p) const
    {
        // XOR hash. We don't care about collisions, we're FREAKS
        auto h1 = std::hash<size_t>()(p->first->id);
        auto h2 = std::hash<size_t>()(p->second->id);
        return h1^h2;
    }
};
int main(int argc, char const *argv[])
{
    auto p1 = std::make_unique<Point>();
    auto p2 = std::make_unique<Point>();
    auto p3 = std::make_unique<Point>();
    auto p4 = std::make_unique<Point>();
    std::unordered_map<std::pair<Point*, Point*>*, size_t*, Hasch_point_pair> m;
    auto p = std::make_unique<std::pair<Point*, Point*>>(p1.get(), p2.get());
    auto p_hmm = std::make_unique<std::pair<Point*, Point*>>(p2.get(), p1.get());
    size_t value = 3;
    m[p.get()] = &value;
    std::cout << "m[p] = " << *m.at(p.get()) << std::endl;
    // Throws std::out_of_range: the default key equality (std::equal_to)
    // compares the pair pointers themselves, and p_hmm points to a different pair
    std::cout << "m[p_hmm] = " << *m.at(p_hmm.get()) << std::endl;
}
One thought I had was to compare the ids of the two Points and always use the Point with the larger id as the first hash, but I haven't gotten it to work. Does that make sense?
class Hasch_point_pair
{
public:
    size_t operator()(const std::pair<Point*, Point*>* p) const
    {
        if (p->first->id > p->second->id)
        {
            auto h1 = std::hash<size_t>()(p->first->id);
            auto h2 = std::hash<size_t>()(p->second->id);
            return h1^h2;
        }
        else
        {
            // Note switched order of h1 and h2!
            auto h2 = std::hash<size_t>()(p->first->id);
            auto h1 = std::hash<size_t>()(p->second->id);
            return h1^h2;
        }
    }
};
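For what it's worth, XOR is commutative, so h1 ^ h2 already equals h2 ^ h1: the first Hasch_point_pair was order-independent to begin with, and swapping the operands changes nothing. Ordering the ids first only pays off with a non-commutative combiner, which also avoids plain XOR's habit of hashing every {x, x} pair to 0. A minimal sketch on raw ids, assuming they are std::size_t as in Point; symmetric_xor, ordered_combine and check_combiners are hypothetical names, and the mixing constant is borrowed from Boost's hash_combine.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>

// Plain XOR: commutative, so argument order never matters.
std::size_t symmetric_xor(std::size_t a, std::size_t b) {
    return std::hash<std::size_t>{}(a) ^ std::hash<std::size_t>{}(b);
}

// Sort the ids first, then mix with a non-commutative combiner
// (Boost-style hash_combine). The result is still order-independent
// because the inputs are always combined smallest-first.
std::size_t ordered_combine(std::size_t a, std::size_t b) {
    if (a > b) std::swap(a, b);
    std::size_t seed = std::hash<std::size_t>{}(a);
    seed ^= std::hash<std::size_t>{}(b) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
}

// Usage check: both are symmetric, but XOR sends every {x, x} to 0.
void check_combiners() {
    assert(symmetric_xor(3, 4) == symmetric_xor(4, 3));
    assert(symmetric_xor(7, 7) == 0);
    assert(ordered_combine(3, 4) == ordered_combine(4, 3));
}
```

Either way, a symmetric hash only decides which bucket a key lands in; it does not make the map treat {a, b} and {b, a} as the same key.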
Answer:
@Drew put his finger on the problem when he told you that your issue is with operator== and not std::hash. So, how do you fix it?
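The reason is that std::unordered_map uses the Hash to pick a bucket, but to decide whether two keys are the same it calls its KeyEqual template parameter, which defaults to std::equal_to<Key>. A sketch with int pairs, under the assumption that ordering is the only thing to ignore (SymHash, SymEqual, the map aliases and check_equality_matters are hypothetical names), shows that a symmetric hash alone is not enough:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>

using IntPair = std::pair<int, int>;

// Order-independent hash: XOR is commutative.
struct SymHash {
    std::size_t operator()(const IntPair& p) const {
        return std::hash<int>{}(p.first) ^ std::hash<int>{}(p.second);
    }
};

// Order-independent equality: equal as-is, or equal after swapping.
struct SymEqual {
    bool operator()(const IntPair& a, const IntPair& b) const {
        return (a.first == b.first && a.second == b.second) ||
               (a.first == b.second && a.second == b.first);
    }
};

// Symmetric hash, default std::equal_to<IntPair> equality.
using HashOnlyMap = std::unordered_map<IntPair, int, SymHash>;
// Symmetric hash plus matching symmetric equality.
using FullMap = std::unordered_map<IntPair, int, SymHash, SymEqual>;

void check_equality_matters() {
    HashOnlyMap hash_only;
    hash_only[IntPair(3, 4)] = 3;
    // Same bucket, but std::equal_to says {3, 4} != {4, 3}.
    assert(hash_only.count(IntPair(4, 3)) == 0);

    FullMap full;
    full[IntPair(3, 4)] = 3;
    assert(full.at(IntPair(4, 3)) == 3);
}
```

The hash and the equality must agree: any two keys that KeyEqual considers equal have to produce the same hash, which SymHash guarantees here.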
Well, the obvious solution is to define your own type (which can inherit from std::pair) that defines operator== to work the way you need it to. Something like this, maybe:
#include <utility>
template <typename T> struct my_pair : std::pair<T, T>
{
    using std::pair<T, T>::pair;
    bool operator== (const my_pair &other) const
    {
        using std::swap;
        std::pair<T, T> p1 = *this;
        std::pair<T, T> p2 = other;
        if (p1.first < p1.second)
            swap (p1.first, p1.second);
        if (p2.first < p2.second)
            swap (p2.first, p2.second);
        return p1 == p2;
    }
};
Note that this code assumes that the same type is used for both first and second (because that seems to me to be a necessity) and that this type supports operator<. It could also be improved to do less work, but I wanted to keep things simple (see Deduplicator's comment for a better version).
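To use my_pair as an unordered_map key, it still needs an order-independent hash, because std::hash has no specialization for it. A self-contained sketch, assuming T is hashable with std::hash and supports operator<; MyPairHash, IntPairMap and check_my_pair_map are hypothetical names, and operator== is written here in an equivalent copy-free "same order or swapped" form rather than the sorting version above.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>

template <typename T> struct my_pair : std::pair<T, T> {
    using std::pair<T, T>::pair;

    // Unordered equality: same members in either order.
    bool operator==(const my_pair& other) const {
        return (this->first == other.first && this->second == other.second) ||
               (this->first == other.second && this->second == other.first);
    }
};

// Hash the smaller member first so the result does not depend on
// whether a value is stored in first or second.
template <typename T> struct MyPairHash {
    std::size_t operator()(const my_pair<T>& p) const {
        const T& lo = p.first < p.second ? p.first : p.second;
        const T& hi = p.first < p.second ? p.second : p.first;
        std::size_t seed = std::hash<T>{}(lo);
        seed ^= std::hash<T>{}(hi) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
};

// std::equal_to<my_pair<int>> picks up the member operator== above.
using IntPairMap = std::unordered_map<my_pair<int>, int, MyPairHash<int>>;

void check_my_pair_map() {
    IntPairMap m;
    m[my_pair<int>(3, 4)] = 3;
    assert(m.at(my_pair<int>(4, 3)) == 3);
    assert(m.size() == 1);
}
```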
Answer:
Using a custom class for equality testing:
class Equal_point_pair
{
public:
    bool operator()(
        const std::pair<Point *, Point *> *p1,
        const std::pair<Point *, Point *> *p2) const
    {
        // Check whether both pairs are in the same (ascending) order
        const bool p1Asc = p1->first->id < p1->second->id;
        const bool p2Asc = p2->first->id < p2->second->id;
        // If both pairs are in the same order, compare the same members;
        // otherwise, compare the swapped members
        return p1Asc == p2Asc ?
            *p1->first == *p2->first && *p1->second == *p2->second :
            *p1->first == *p2->second && *p1->second == *p2->first;
    }
};
Note that the code above does what I think you want. Also, I haven't tested it.
Then your map would be declared like this:
using PointMap = std::unordered_map<
    std::pair<Point*, Point*>*,
    size_t*,
    Hasch_point_pair,
    Equal_point_pair>;
PointMap m;
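Putting the pieces together with the question's pointer-to-pair keys, a lookup through the reversed pair now finds the stored value. This is a sketch, not tested against the asker's full program: it assumes each Point receives a unique id at construction (here from id_counter), Equal_point_pair is written in an equivalent "same order or swapped" form, and Key and lookup_reversed are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>

struct Point {
    static std::size_t id_counter;
    std::size_t id;
    // Assumption: every Point gets a unique id when constructed.
    Point() : id(id_counter++) {}
    bool operator==(const Point& rhs) const { return id == rhs.id; }
};
std::size_t Point::id_counter = 0;

using Key = std::pair<Point*, Point*>;

// Order-independent hash of the two ids (XOR is commutative).
struct Hasch_point_pair {
    std::size_t operator()(const Key* p) const {
        return std::hash<std::size_t>{}(p->first->id) ^
               std::hash<std::size_t>{}(p->second->id);
    }
};

// Order-independent equality on the pointed-to Points, not on the
// addresses of the pairs themselves.
struct Equal_point_pair {
    bool operator()(const Key* p1, const Key* p2) const {
        return (*p1->first == *p2->first && *p1->second == *p2->second) ||
               (*p1->first == *p2->second && *p1->second == *p2->first);
    }
};

using PointMap =
    std::unordered_map<Key*, std::size_t*, Hasch_point_pair, Equal_point_pair>;

// Mirrors the question's main(): store through p, read through p_hmm.
std::size_t lookup_reversed() {
    Point a, b;
    Key p(&a, &b);
    Key p_hmm(&b, &a);
    std::size_t value = 3;
    PointMap m;
    m[&p] = &value;
    return *m.at(&p_hmm);  // found: same bucket and Equal_point_pair says equal
}
```

Because the map stores pointers, the pairs and Points must outlive every lookup; that lifetime burden is one reason to prefer keys stored by value.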
By the way, I'm not sure why you are using (nested) pointers.
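A further option, not taken from either answer, is to normalize every key before it touches the map, so that {3, 4} and {4, 3} both become {3, 4}; then the default equality works and only a hash for std::pair is needed. A minimal sketch, assuming int members and value (not pointer) keys; normalized, IntPairHash, NormalizedMap and check_normalized_lookup are hypothetical names. The trade-off is that every insert and lookup must remember to call normalized.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>

// Put the smaller member first, so {3, 4} and {4, 3} both become {3, 4}.
std::pair<int, int> normalized(std::pair<int, int> p) {
    if (p.second < p.first) std::swap(p.first, p.second);
    return p;
}

// std::hash has no specialization for std::pair, so the map needs a
// hash functor. Keys are normalized before use, so this combiner does
// not have to be order-independent (Boost-style hash_combine mixing).
struct IntPairHash {
    std::size_t operator()(const std::pair<int, int>& p) const {
        std::size_t seed = std::hash<int>{}(p.first);
        seed ^= std::hash<int>{}(p.second) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
};

using NormalizedMap = std::unordered_map<std::pair<int, int>, int, IntPairHash>;

// Usage check: write through {3, 4}, read back through {4, 3}.
void check_normalized_lookup() {
    NormalizedMap m;
    m[normalized({3, 4})] = 3;
    assert(m.at(normalized({4, 3})) == 3);
    assert(m.size() == 1);
}
```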