Home > Net >  Multiple grouped charts with altair
Multiple grouped charts with altair

Time:12-31

enter image description here

My data has 4 attributes: dataset (D1/D2), model (M1/M2), layer (L1/L2), scene (S1/S2). I can make a chart grouped by scene and then merge the plots horizontally and vertically (picture above). However, I would like to have a 'double grouping' by scene and dataset, i.e. merge the D1 and D2 plots by placing their blue/orange bars next to each other, distinguished by a different opacity or pattern/hatch.

Basically something like this (pretend that the black strokes are a hatch pattern).

enter image description here

Here is the code to reproduce the first plot:

import numpy as np
import itertools
import argparse
import pandas as pd
import matplotlib.pyplot as plt
import os
import altair as alt
# Select the 'altair_viewer' renderer for displaying charts.
alt.renderers.enable('altair_viewer')

# Fix the seed so the synthetic data (and therefore the chart) is reproducible.
np.random.seed(0)

################################################################################

# The four categorical attributes of the data and their possible values.
model_keys = ['M1', 'M2']
data_keys = ['D1', 'D2']
scene_keys = ['S1', 'S2']
layer_keys = ['L1', 'L2']

# Column buffers; each generated sample appends one entry to every list so
# the five lists always stay the same length.
ys = []
models = []
dataset = []
layers = []
scenes = []

# For every (scene, model, dataset, layer) combination draw 10 batches of 10
# uniform values in [0, 0.1). The lists must be *extended* on each pass:
# rebinding them (``ys = data_y``) would keep only the final batch and leave
# a single S2/M2/D2/L2 group in the frame.
# ``itertools.product`` iterates in the same order as the equivalent nested
# loops (rightmost factor varies fastest), so the random draws are consumed
# in the same sequence.
for sc, m, d, l, _ in itertools.product(
        scene_keys, model_keys, data_keys, layer_keys, range(10)):
    data_y = list(np.random.rand(10) / 10)
    n_samples = len(data_y)
    ys.extend(data_y)
    scenes.extend([sc] * n_samples)
    models.extend([m] * n_samples)
    dataset.extend([d] * n_samples)
    layers.extend([l] * n_samples)


# ------------------------------------------------------------------------------


# Long-format table: one row per sample, one column per attribute.
df = pd.DataFrame({'Y': ys,
                   'Model': models,
                   'Dataset': dataset,
                   'Layer': layers,
                   'Scenes': scenes})

# --- Bar layer: mean of Y for each scene, one colour per scene ---------------

# X channel: scenes are the grouping field; the axis itself is hidden
# (no grid, no title, no tick labels).
_scene_x = alt.X(
    'Scenes:N',
    title=None,
    axis=alt.Axis(grid=False, title=None, labels=False),
)

# Y channel: bar height is the mean of Y within the group.
_mean_y = alt.Y(
    'Y:Q',
    aggregate='mean',
    axis=alt.Axis(grid=True, title='Y', titleFontWeight='normal'),
)

# Draw order follows the scene name, ascending.
_scene_order = alt.Order('Scenes', sort='ascending')

# Colour segmentation by scene, with the legend placed underneath the chart.
_scene_color = alt.Color(
    'Scenes',
    legend=alt.Legend(orient='bottom', padding=-10),
    title=None,
)

bars = (
    alt.Chart(df, width=100, height=90)
    .mark_bar()
    .encode(x=_scene_x, y=_mean_y, order=_scene_order, color=_scene_color)
)

# --- Error-bar layer: 'ci' extent of Y for each scene -------------------------
error_bars = (
    alt.Chart(df)
    .mark_errorbar(extent='ci')
    .encode(x=alt.X('Scenes:N'), y=alt.Y('Y:Q'))
)

# --- Text layer: the mean value printed just above each bar -------------------
text = (
    alt.Chart(df)
    .mark_text(
        align='center',
        baseline='line-bottom',
        color='black',
        dy=-5,  # shift the label upwards so it does not touch the bar
    )
    .encode(
        x=alt.X('Scenes:N'),
        y=alt.Y('mean(Y):Q'),
        text=alt.Text('mean(Y):Q', format='.1f'),
    )
)

# Superimpose the three marks on shared axes. ``+`` is Altair's layering
# operator; without it the line is not valid Python.
chart_base = bars + error_bars + text

# Repeat the layered chart once per layer value, side by side.
chart_base = chart_base.facet(
    # field to use as the set of columns to be represented in each group
    column=alt.Column('Layer:N',
        # header=alt.Header(
            # labelFontStyle='bold',
        # ),
        title=None,
        # Order the columns by the layer values themselves. The facet field
        # is 'Layer', so a list of model names (in arbitrary set order) does
        # not describe a meaningful or reproducible column order.
        sort=layer_keys,
    ),
    spacing={"row": 0, "column": 15},
)

def unique(sequence):
    """Return the distinct items of *sequence* in first-seen order.

    Items must be hashable. The input is consumed once, so any iterable
    (including a one-shot iterator) is accepted.
    """
    already_seen = set()
    ordered = []
    for item in sequence:
        if item in already_seen:
            continue
        already_seen.add(item)
        ordered.append(item)
    return ordered

def _subset_chart(dataset_name, model_name):
    """Return ``chart_base`` filtered to a single dataset and a single model."""
    only_dataset = alt.FieldEqualPredicate(field='Dataset', equal=dataset_name)
    only_model = alt.FieldEqualPredicate(field='Model', equal=model_name)
    return chart_base.transform_filter(only_dataset).transform_filter(only_model)


# Styling of the 'D1' / 'D2' column headings shown above the first row only.
_column_title_style = {'align': 'center', 'anchor': 'middle', 'dy': -10}

# Build the grid one row per model: the D1 panel on the left, the D2 panel on
# the right, with the model name as a rotated title on the right-hand side.
for row_index, model_name in enumerate(unique(models)):
    is_first_row = row_index == 0

    left_panel = _subset_chart('D1', model_name)
    right_panel = _subset_chart('D2', model_name)

    if is_first_row:
        left_panel = left_panel.properties(
            title=alt.TitleParams('D1', **_column_title_style))
        right_panel = right_panel.properties(
            title=alt.TitleParams('D2', **_column_title_style))

    # Wrap the right panel in a concat so it can carry a second title: the
    # model name, rotated and anchored on the right edge of the row.
    row_title = alt.TitleParams(
        model_name,
        baseline='middle',
        orient='right',
        anchor='middle',
        angle=90,
        # dy=10,
        dx=30 if is_first_row else 0,
    )
    right_panel = alt.concat(right_panel, title=row_title)

    # Join the two panels horizontally, then stack the row under the rows
    # assembled so far.
    row = (left_panel | right_panel).resolve_scale(x='shared')
    chart = row if is_first_row else (chart & row)

chart.save('test.html')

CodePudding user response:

For now, I don't know a good answer, but once enter image description here

See enter image description here

  • Related