Python: Sort every row and accumulate weights

Time:12-23

I have the following dataframe:

import pandas as pd

df1 = pd.DataFrame(
{
"A_price": [10, 12, 15],
"B_price": [20, 19, 29],
"C_price": [23, 21, 4],
"D_price": [45, 47, 44],
},
index = ['01-01-2020', '01-02-2020', '01-03-2020']
)

df2 = pd.DataFrame(
{
"A_weight": [0.1, 0.2, 0.4],
"B_weight": [0.2, 0.5, 0.1],
"C_weight": [0.3, 0.2, 0.1],
"D_weight": [0.4, 0.1, 0.4],
},
index = ['01-01-2020', '01-02-2020', '01-03-2020']
)

out = pd.merge(df1, df2, left_index=True, right_index=True)
out.columns = out.columns.str.split('_', expand=True)
out = out.sort_index(axis=1)
out:
            A               B               C               D
            price   weight  price   weight  price   weight  price   weight
01-01-2020  10      0.1     20      0.2     23      0.3     45      0.4
01-02-2020  12      0.2     19      0.5     21      0.2     47      0.1
01-03-2020  15      0.4     29      0.1     4       0.1     44      0.4

I want to calculate the weighted median of each row. It is found by sorting the (weight, price) pairs by price and then accumulating the weights until the two prices that straddle the 50% cumulative weight point are found.

We then interpolate between those two (weight, price) pairs to find the price at 50% cumulative weight, and then put that price into a new DataFrame.

UPDATE: I changed my dataframe so it is more reflective of what I currently have.

The output I want is the weighted median for each row. For example, for the row with index "01-01-2020", I would expect the median to be price = 23, since accumulating the weights across the row in price order gives 0.1 + 0.2 + 0.3 > 0.5. So I would get a DataFrame of prices that looks like this:

df_prices:
             Price
01-01-2020   23
01-02-2020   19
01-03-2020   29

CodePudding user response:

IIUC:

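def wmedian(sr):
    # Reshape the row into one record per letter (columns: price, weight), ordered by price
    df = sr.unstack().sort_values('price')
    # Return the first price at which the cumulative weight exceeds 0.5
    return df.loc[df['weight'].cumsum() > 0.5, 'price'].iloc[0]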

out2 = out.apply(wmedian, axis=1)
print(out2)

# Output:
01-01-2020    23.0
01-02-2020    19.0
01-03-2020    29.0
dtype: float64
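
Note that wmedian returns the first price whose cumulative weight exceeds 0.5; it does not interpolate. If the interpolation described in the question is wanted, one possible sketch uses np.interp. The function name wmedian_interp is hypothetical, and the convention is an assumption: each price is placed at the cumulative weight up to and including its own weight. Other conventions (for example, placing each price at the midpoint of its weight interval) give different values, and the results will generally differ from the 23 / 19 / 29 shown above.

```python
import numpy as np

def wmedian_interp(prices, weights):
    """Interpolated weighted median of one row (sketch, assumed convention)."""
    prices = np.asarray(prices, dtype=float)
    weights = np.asarray(weights, dtype=float)
    # Sort the (price, weight) pairs by price
    order = np.argsort(prices)
    sorted_prices = prices[order]
    # Cumulative weight share reached at each sorted price (normalized to 1)
    cum_share = np.cumsum(weights[order]) / weights.sum()
    # Linearly interpolate the price at the 50% cumulative weight point
    return float(np.interp(0.5, cum_share, sorted_prices))

# First row of the example: prices and weights for A, B, C, D
first_row = wmedian_interp([10, 20, 23, 45], [0.1, 0.2, 0.3, 0.4])
print(first_row)
```

For the first row this interpolates between (0.3, 20) and (0.6, 23), giving a value of about 22. To apply it to out, something like out.apply(lambda sr: wmedian_interp(sr.xs('price', level=1), sr.xs('weight', level=1)), axis=1) should work, assuming the column layout shown in the question.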

https://en.wikipedia.org/wiki/Weighted_median
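
For larger frames, a row-wise apply can be slow. The same "first price whose cumulative weight exceeds 0.5" rule can be sketched with NumPy alone. This assumes the prices and weights are available as two aligned 2-D arrays (rebuilt here from the question's data) and that each row's weights sum to more than 0.5; otherwise argmax would silently pick the first column.

```python
import numpy as np
import pandas as pd

index = ['01-01-2020', '01-02-2020', '01-03-2020']
# Columns are A, B, C, D, in the same order in both arrays
prices = np.array([[10, 20, 23, 45],
                   [12, 19, 21, 47],
                   [15, 29, 4, 44]], dtype=float)
weights = np.array([[0.1, 0.2, 0.3, 0.4],
                    [0.2, 0.5, 0.2, 0.1],
                    [0.4, 0.1, 0.1, 0.4]])

# Sort each row by price and reorder the weights the same way
order = np.argsort(prices, axis=1)
sorted_prices = np.take_along_axis(prices, order, axis=1)
sorted_weights = np.take_along_axis(weights, order, axis=1)

# Position of the first cumulative weight above 0.5 in each row
# (argmax on a boolean array returns the index of the first True)
first_over = np.argmax(np.cumsum(sorted_weights, axis=1) > 0.5, axis=1)

wmedian_np = pd.Series(sorted_prices[np.arange(len(index)), first_over], index=index)
print(wmedian_np)
```

On the example data this reproduces the values from out.apply(wmedian, axis=1).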
