Plot Venn diagram from a boolean set notation in Python

Time:03-08

The following code plots a Venn diagram by interpreting the set input in boolean form and shading the respective region of interest:

from matplotlib import pyplot as plt
from matplotlib_venn import venn3, venn3_circles

def list_regions(formula):
    from itertools import product
    for x in product([False, True], repeat=3):
        if eval(formula, {}, dict(zip('ABC', x))):
            yield ''.join(str(int(x_i)) for x_i in x)

def plot_diagram(formula):
    plt.figure(figsize=(6,6))
    v = venn3(subsets=[1]*7,
              set_colors=['white']*3,
              subset_label_formatter=lambda x: '')
    c = venn3_circles(subsets=[1]*7, linestyle='solid')
    for region in list_regions(formula):
        v.get_patch_by_id(region).set_color('black')
    plt.show()

This code works fine for most of the set formulas:

plot_diagram('A')
plot_diagram('B')
plot_diagram('C')
plot_diagram('A or B')
plot_diagram('A and B and C')
plot_diagram('(A or B) and C')

All of the above diagrams are plotted properly. Moreover, the code also works for complements:

plot_diagram('A and (not B)')
plot_diagram('(A or B) and (not C)')
plot_diagram('(A and B) and (not C)')

However, I encountered some issues. The code does not work for all complements:

plot_diagram('not B')
plot_diagram('A or (not B)')

Plotting the complement of a single set, or the union of one set with the complement of another, raises the same error in both cases:

...
/matplotlib_venn/_common.py", line 31, in get_patch_by_id
    return self.patches[self.id2idx[id]]
KeyError: '000'

Googling the exception did not show any relevant results, hence:

Does anyone know a more correct way to generate the region ids, other than using product from the itertools module, so that the corresponding region of interest gets shaded?

The code generally works but is flawed when it comes to the complements mentioned before.

Thanks in advance for any help on that matter.

CodePudding user response:

From the error (below) returned by your code when calling plot_diagram('not B'):

/usr/local/lib/python3.7/dist-packages/matplotlib_venn/_common.py in get_patch_by_id(self, id)
     29            A region id is a string '10', '01' or '11' for 2-circle diagram or a
     30            string like '001', '010', etc, for 3-circle diagram.'''
---> 31         return self.patches[self.id2idx[id]]
     32 
     33     def get_label_by_id(self, id):

KeyError: '000'

It looks like the issue comes from trying to color the region with string id '000', which is the area outside all three circles and has no patch in the diagram. Adding the condition if region != '000': solves the issue. See the code below:

from matplotlib import pyplot as plt
from matplotlib_venn import venn3, venn3_circles
from itertools import product

def list_regions(formula):
    for x in product([False, True], repeat=3):
        if eval(formula, {}, dict(zip('ABC', x))):
            yield ''.join(str(int(x_i)) for x_i in x)

def plot_diagram(formula):
    plt.figure(figsize=(6,6))
    v = venn3(subsets=[1]*7,
              set_colors=['white']*3,
              subset_label_formatter=lambda x: '')
    c = venn3_circles(subsets=[1]*7, linestyle='solid')
    for region in list_regions(formula):
        # '000' is the area outside all circles and has no patch to color
        if region != '000':
            v.get_patch_by_id(region).set_color('black')
    plt.show()
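
To see why the check is needed, it can help to print the ids that list_regions yields without plotting anything. The following is a minimal sketch that reuses the function from the question and needs only the standard library:

```python
from itertools import product

def list_regions(formula):
    # Try every membership combination for the sets A, B and C.
    for x in product([False, True], repeat=3):
        # eval is kept from the original post; only pass trusted formulas.
        if eval(formula, {}, dict(zip('ABC', x))):
            # Encode the combination as a 3-character id such as '101'.
            yield ''.join(str(int(x_i)) for x_i in x)

print(list(list_regions('A and (not B)')))  # ['100', '101']
print(list(list_regions('not B')))          # ['000', '001', '100', '101']
```

Only the second formula produces '000', which is the id that get_patch_by_id cannot find.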

And the output for the last two cases:

plot_diagram('not B')


plot_diagram('A or (not B)')


CodePudding user response:

The key '000' corresponds to the part of the universal set that lies outside all three circles. It has no patch of its own, so the axes background can be shaded instead:

from matplotlib import pyplot as plt
from matplotlib_venn import venn3, venn3_circles

def list_regions(formula):
    from itertools import product    
    for x in product([False, True], repeat=3):
        if eval(formula, {}, dict(zip('ABC', x))):
            yield ''.join(str(int(x_i)) for x_i in x)

def plot_diagram(formula):
    f = plt.figure(figsize=(6,6))
    v = venn3(alpha=1.0,
              subsets=[1]*7,
              set_colors=['white']*3,
              subset_label_formatter=lambda x: '')
    c = venn3_circles(subsets=[1]*7, linestyle='solid')
    for region in list_regions(formula):
        if region == '000':
            # No patch exists for '000', so color the axes background
            f.gca().set_facecolor('grey')
        else:
            v.get_patch_by_id(region).set_color('grey')
    plt.annotate('U', (-0.657, 0.557))
    plt.axis('on')
    plt.show()
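
The branching in the loop could also be moved out of the plotting code, so that plot_diagram only handles ids that have a patch. Below is a rough sketch of that idea. Note that split_regions and OUTSIDE are hypothetical names chosen for illustration, not part of matplotlib_venn:

```python
from itertools import product

# Id of the area outside all three circles (assumed name, for illustration).
OUTSIDE = '000'

def split_regions(formula):
    """Return (patch_ids, shade_outside) for a boolean formula over A, B, C."""
    patch_ids = []
    shade_outside = False
    for x in product([False, True], repeat=3):
        # eval is kept from the original post; only pass trusted formulas.
        if eval(formula, {}, dict(zip('ABC', x))):
            region = ''.join(str(int(x_i)) for x_i in x)
            if region == OUTSIDE:
                # No patch exists for this id; remember to shade the background.
                shade_outside = True
            else:
                patch_ids.append(region)
    return patch_ids, shade_outside

print(split_regions('A or (not B)'))
# (['001', '100', '101', '110', '111'], True)
```

In plot_diagram, each id in patch_ids would then go to v.get_patch_by_id(...), and shade_outside would decide whether to call set_facecolor on the axes.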