Matplotlib Stackplot of counts by groups

I am a complete beginner in Python and still quite confused about how to store data for plotting. I am trying to create a stacked area plot of counts over a group variable (continent) and an x variable (year_join). Here is example data to illustrate the structure:

import pandas as pd

df = pd.DataFrame({
    'id': ['1', '2', '3', '4', '5', '6', '7',
           '8', '9', '10', '11', '12', '13', '14',
           '15', '16', '17', '18', '19', '20', '21'],
    'year_join': ['2015', '2016', '2017', '2015', '2016', '2017', '2015',
                  '2015', '2016', '2017', '2015', '2016', '2017', '2015',
                  '2015', '2016', '2017', '2015', '2016', '2017', '2015'],
    'continent': ['Europe', 'Asia', 'Europe', 'Africa', 'Asia', 'Europe', 'Africa',
                  'Asia', 'Europe', 'Africa', 'Asia', 'Europe', 'Africa', 'Asia',
                  'Africa', 'Africa', 'Asia', 'Europe', 'Africa', 'Asia', 'Europe'],
})

After some fiddling, this code gives me a graph:

import matplotlib.pyplot as plt

# 1. Group data by year_join and continent into a new DataFrame (maybe too complicated; found the code on Stack Overflow)
grouped = (pd.DataFrame(df.groupby(['continent', 'year_join']).size().reset_index(name="count")).pivot(columns='continent', index='year_join', values='count'))

# 2. Put the value counts into a dictionary
result = {}
for columnName in grouped:
    result[columnName] = [*grouped[columnName]]

# 3. Create lists
year = grouped.index.values.tolist()
y = list(dict.values(result))

# 4. Create a stackplot from the lists
plt.stackplot(year, y)
plt.legend(loc='upper left')
plt.tight_layout()
plt.show()

However, first, the legend is not shown, and more generally, I doubt it makes sense to move the data from one DataFrame to another, then to a dictionary, and then to a list before plotting it. Does anybody have tips on how to improve this?

Answer

Consider using pandas' plotting API, DataFrame.plot, which draws pandas objects directly with Matplotlib. Additionally, groupby + pivot is essentially pivot_table, which aggregates the data while pivoting. Consequently, you can reduce the whole workflow to a few simple steps:

pvt_df = df.pivot_table(index='year_join', columns='continent', aggfunc='count')
pvt_df.columns = pvt_df.columns.get_level_values(1)   # FLATTENS HIERARCHICAL COLUMNS
pvt_df
# continent  Africa  Asia  Europe
# year_join                      
# 2015            3     3       3
# 2016            2     2       2
# 2017            2     2       2
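The claim that groupby + pivot amounts to pivot_table can be checked directly. The following is a minimal sketch, assuming the example DataFrame from the question. Besides the question's own groupby/pivot chain and pivot_table, it shows two other common counting idioms, groupby().size().unstack() and pd.crosstab. Passing values='id' and fill_value=0 to pivot_table is an addition not made in the answer above: it yields flat columns without the get_level_values step and fills missing year/continent combinations with 0 instead of NaN.

```python
import pandas as pd

# Example data from the question (id is only a row identifier)
df = pd.DataFrame({
    'id': [str(i) for i in range(1, 22)],
    'year_join': ['2015', '2016', '2017', '2015', '2016', '2017', '2015',
                  '2015', '2016', '2017', '2015', '2016', '2017', '2015',
                  '2015', '2016', '2017', '2015', '2016', '2017', '2015'],
    'continent': ['Europe', 'Asia', 'Europe', 'Africa', 'Asia', 'Europe', 'Africa',
                  'Asia', 'Europe', 'Africa', 'Asia', 'Europe', 'Africa', 'Asia',
                  'Africa', 'Africa', 'Asia', 'Europe', 'Africa', 'Asia', 'Europe'],
})

# The question's approach: count rows per group, then pivot continents into columns
via_groupby_pivot = (df.groupby(['continent', 'year_join']).size()
                       .reset_index(name='count')
                       .pivot(index='year_join', columns='continent', values='count'))

# pivot_table aggregates while pivoting; a single values column keeps the columns flat
via_pivot_table = df.pivot_table(index='year_join', columns='continent',
                                 values='id', aggfunc='count', fill_value=0)

# groupby + size + unstack: the same counts without an intermediate DataFrame
via_unstack = df.groupby(['year_join', 'continent']).size().unstack(fill_value=0)

# crosstab: a frequency table of two columns in one call
via_crosstab = pd.crosstab(df['year_join'], df['continent'])

# Each table has years as rows and continents as columns;
# all four show 3 per continent in 2015 and 2 per continent in 2016 and 2017
for name, table in [('groupby + pivot', via_groupby_pivot),
                    ('pivot_table', via_pivot_table),
                    ('unstack', via_unstack),
                    ('crosstab', via_crosstab)]:
    print(name, table.to_numpy().tolist())
```

Any of these tables can be passed straight to DataFrame.plot; there is no need to convert it to a dictionary or lists first.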

pvt_df.plot(kind='bar', stacked=True, rot=0, title='Year Joined and Continent Count')

plt.legend(loc='upper left')
plt.tight_layout()
plt.show()
plt.clf()
plt.close()

[Figure: Plot Output]
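Regarding the missing legend in the question: plt.legend() stays empty because plt.stackplot(year, y) is never given any labels, so Matplotlib finds no labelled artists to list. Since the question asked for a stacked area plot rather than bars, here is a minimal sketch, assuming the same example data, of two ways to draw one straight from the count table with no dictionary or list detour: DataFrame.plot.area (stacked by default, with the legend taken from the column names) and Axes.stackplot with labels= passed explicitly. Converting year_join to integers is an assumption made here so the x-axis is numeric.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Example data from the question (id omitted: only the counts are plotted)
df = pd.DataFrame({
    'year_join': ['2015', '2016', '2017', '2015', '2016', '2017', '2015',
                  '2015', '2016', '2017', '2015', '2016', '2017', '2015',
                  '2015', '2016', '2017', '2015', '2016', '2017', '2015'],
    'continent': ['Europe', 'Asia', 'Europe', 'Africa', 'Asia', 'Europe', 'Africa',
                  'Asia', 'Europe', 'Africa', 'Asia', 'Europe', 'Africa', 'Asia',
                  'Africa', 'Africa', 'Asia', 'Europe', 'Africa', 'Asia', 'Europe'],
})

# Count rows per year and continent; fill_value=0 avoids NaN gaps in the stacked areas
counts = df.groupby(['year_join', 'continent']).size().unstack(fill_value=0)
# Assumption: numeric years give a continuous x-axis instead of string labels
counts.index = counts.index.astype(int)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), sharey=True)

# Option A: pandas area plot; stacked=True is the default and the legend uses the column names
counts.plot.area(ax=ax1, title='DataFrame.plot.area')

# Option B: Matplotlib stackplot; each row of counts.T is one continent's series,
# and labels= is what gives the legend something to show
ax2.stackplot(counts.index, counts.T.to_numpy(), labels=list(counts.columns))
ax2.set_title('Axes.stackplot with labels')
ax2.legend(loc='upper left')

fig.tight_layout()
# plt.show()  # uncomment to display the figure in an interactive session
```

Option A is the shortest path when the data already lives in a DataFrame; Option B stays closer to the question's original code and suggests that the only piece missing there was the labels= argument.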