I want to create a model that is symmetric in two fields. Let's call the model Balance:
class Balance(models.Model):
    """Amount flowing from ``payer`` to ``payee`` (sketch from the question).

    ``...`` marks constructor arguments left out of the sketch.
    """

    # NOTE(review): two ForeignKeys to the same model also need distinct
    # related_name values, otherwise Django reports a reverse-accessor clash.
    payer = models.ForeignKey(auth.User, ...)
    payee = models.ForeignKey(auth.User, ...)
    amount = models.DecimalField(...)
It should have the following property:
# Looking up the same pair of users in both directions...
balance_forward = Balance.objects.get(payer=USER_1, payee=USER_2)
balance_backward = Balance.objects.get(payer=USER_2, payee=USER_1)
# ...should yield mirrored amounts: this expression is meant to be True.
balance_forward.amount == -1 * balance_backward.amount
What is the best way to implement this?
Answer:
So, I came up with the following solution. Feel free to suggest other solutions.
class SymmetricPlayersQuerySet(models.query.QuerySet):
    """QuerySet that shows payer/payee rows from the side that was queried.

    ``Balance.save`` writes rows with the smaller primary key as ``payer``.
    ``get()`` with both participants, and ``filter()`` with exactly one of
    them, flip the returned instances (swap ``payer``/``payee`` and negate
    ``amount``) so the caller sees the pair in the orientation they asked for.

    The flips live only on the instances this queryset returns. Chaining a
    further ``.filter()``, ``.order_by()`` or ``.values()`` re-reads the
    stored orientation from the database.
    """

    @staticmethod
    def _pk(value):
        """Return the primary key of *value*, or *value* itself if it is a raw key."""
        return getattr(value, "pk", value)

    def do_swap(self, obj):
        """Flip *obj* in place: exchange payer and payee, negate amount."""
        obj.payer, obj.payee = obj.payee, obj.payer
        obj.amount *= -1

    def get(self, *args, **kwargs):
        """Like ``QuerySet.get``, but order-insensitive in ``payer``/``payee``.

        When both are given and ``payer`` has the larger key, the stored
        (reversed) row is fetched and then flipped before being returned.
        Positional Q objects and other lookups are passed through unchanged.
        """
        swap = (
            "payer" in kwargs
            and "payee" in kwargs
            and self._pk(kwargs["payer"]) > self._pk(kwargs["payee"])
        )
        if swap:
            # Query the orientation the row is actually stored in.
            kwargs["payer"], kwargs["payee"] = kwargs["payee"], kwargs["payer"]
        obj = super().get(*args, **kwargs)
        if swap:
            self.do_swap(obj)
        return obj

    def filter(self, *args, **kwargs):
        """Like ``QuerySet.filter``; a lone ``payer=``/``payee=`` matches either side.

        With exactly one of ``payer``/``payee``, this returns rows where that
        player appears on either side. Every other positional and keyword
        condition is still applied. The result is evaluated immediately so
        that rows holding the player on the opposite side can be flipped in
        the result cache.

        NOTE(review): Django's reverse related managers may also call
        ``filter(payer=...)`` through this queryset class, which would make
        them return both directions. Confirm against the Django version in use.
        """
        if ("payer" in kwargs) == ("payee" in kwargs):
            # Both sides (or neither) given: a plain filter is already exact.
            return super().filter(*args, **kwargs)

        key, other = ("payer", "payee") if "payer" in kwargs else ("payee", "payer")
        player = kwargs.pop(key)
        player_pk = self._pk(player)

        # Match the player on either side while keeping all other conditions.
        queryset = super().filter(
            models.Q(payer=player) | models.Q(payee=player), *args, **kwargs
        )
        for obj in queryset:
            # Compare raw ``*_id`` columns: no related-row query per result.
            on_other_side = getattr(obj, other + "_id") == player_pk
            # A self-pair (payer == payee) is already correctly oriented;
            # flipping it would only negate the amount.
            if on_other_side and getattr(obj, key + "_id") != player_pk:
                self.do_swap(obj)
        return queryset
class BalanceManager(models.Manager.from_queryset(SymmetricPlayersQuerySet)):
    """Manager for Balance built from SymmetricPlayersQuerySet.

    ``Manager.from_queryset`` copies the queryset's public methods onto the
    manager, so ``Balance.objects.get``/``filter`` use the symmetric versions.
    """
class Balance(models.Model):
    """Amount between two players, written in one canonical orientation.

    Rows are saved with the smaller primary key as ``payer``. Instances may be
    flipped in memory (``payer``/``payee`` exchanged and ``amount`` negated)
    so callers see the pair from their side. ``save`` and ``refresh_from_db``
    translate between the caller's orientation and the stored one.

    NOTE(review): nothing here enforces a single row per pair; a
    UniqueConstraint on (payer, payee) may be wanted.
    """

    objects = BalanceManager()
    payer = models.ForeignKey(
        Player,
        on_delete=models.CASCADE,
        related_name='balance_payer',
    )
    payee = models.ForeignKey(
        Player,
        on_delete=models.CASCADE,
        related_name='balance_payee',
    )
    amount = models.DecimalField(decimal_places=2, max_digits=1000, default=0)

    def do_swap(self):
        """Flip this instance in place: exchange payer and payee, negate amount."""
        self.payer, self.payee = self.payee, self.payer
        self.amount *= -1

    def _is_reversed(self):
        """Return True if the in-memory orientation differs from the stored one.

        Compares the raw ``payer_id``/``payee_id`` columns, so no related rows
        are fetched. Returns False when either id is unset, or when either is
        deferred: reading a deferred id would itself call ``refresh_from_db``
        and recurse.
        """
        deferred = self.get_deferred_fields()
        if "payer_id" in deferred or "payee_id" in deferred:
            return False
        if self.payer_id is None or self.payee_id is None:
            return False
        return self.payer_id > self.payee_id

    def _swap_loaded(self):
        """Flip only the fields currently loaded on this instance.

        A deferred ``amount`` is left alone. It is not in memory to flip, and
        reading it would trigger a nested ``refresh_from_db``.
        """
        amount_loaded = "amount" not in self.get_deferred_fields()
        self.payer, self.payee = self.payee, self.payer
        if amount_loaded:
            self.amount *= -1

    def save(self, *args, **kwargs):
        """Save in the canonical orientation, then restore the caller's view."""
        if not self._is_reversed():
            return super().save(*args, **kwargs)
        self._swap_loaded()
        try:
            return super().save(*args, **kwargs)
        finally:
            # Restore the caller's orientation even if the save raised.
            self._swap_loaded()

    def refresh_from_db(self, *args, **kwargs):
        """Reload from the database, keeping the caller's orientation.

        Loaded values are first flipped into the stored orientation. After a
        partial reload (``fields=[...]``), reloaded and untouched fields then
        agree once the instance is flipped back.
        """
        if not self._is_reversed():
            super().refresh_from_db(*args, **kwargs)
            return
        self._swap_loaded()
        try:
            super().refresh_from_db(*args, **kwargs)
        finally:
            self._swap_loaded()
Answer:
You can aggregate over the Balance objects with:
from decimal import Decimal

from django.conf import settings
from django.db import models
from django.db.models import Case, F, Q, Sum, When


class Balance(models.Model):
    """A single recorded flow of ``amount`` from ``payer`` to ``payee``."""

    # Two FKs to the same model need distinct ``related_name`` values,
    # otherwise Django's system checks report a reverse-accessor clash.
    payer = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="balances_paid",
    )
    payee = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="balances_received",
    )
    # Precision is a placeholder; choose what suits the domain.
    amount = models.DecimalField(max_digits=12, decimal_places=2)

    @classmethod
    def get_balance(cls, payer, payee):
        """Return the net amount that has flowed from *payer* to *payee*.

        Rows recorded in the opposite direction (payee -> payer) are
        subtracted. Returns ``Decimal("0")`` when the pair has no rows.
        """
        total = cls.objects.filter(
            Q(payer=payer, payee=payee) | Q(payer=payee, payee=payer)
        ).aggregate(
            total=Sum(
                Case(
                    When(payer=payer, then=F("amount")),
                    # ``default`` is the Case keyword for the ELSE branch.
                    # ``otherwise`` is not a Case parameter: it is swallowed
                    # as extra template context, leaving ELSE NULL, so
                    # reverse rows would silently drop out of the SUM.
                    default=-F("amount"),
                    output_field=models.DecimalField(
                        max_digits=12, decimal_places=2
                    ),
                )
            )
        )["total"]
        # SUM over zero rows yields NULL (None); report that as no flow.
        return total if total is not None else Decimal("0")
This looks up all Balance objects between the payer and the payee, and subtracts the ones in the opposite direction. Balance.get_balance(payer=foo, payee=bar) thus determines the total flow from foo to bar.
Note: It is normally better to use the settings.AUTH_USER_MODEL setting [Django-doc] to refer to the user model than to use the User model [Django-doc] directly. For more information, see the "Referencing the User model" section of the documentation.