11 Plotting examples (3)
This post works through plotting examples from the official seaborn gallery. It covers:
- Plotting a diagonal correlation matrix (heatmap)
- Scatterplot with marginal ticks (JointGrid)
- Multiple bivariate KDE plots (kdeplot)
- Multiple linear regression (lmplot)
- Paired density and scatterplot matrix (PairGrid)
- Paired categorical plots (PairGrid)
- Dot plot with several variables (PairGrid)
- Plotting a three-way ANOVA (catplot)
- Linear regression with marginal distributions (jointplot)
- Plotting model residuals (residplot)
# import packages
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
1. Plotting a diagonal correlation matrix (heatmap)
# Import the ASCII letters (used as column names)
from string import ascii_letters
# Generate a large random dataset
rs = np.random.RandomState(33)
d = pd.DataFrame(data=rs.normal(size=(100, 26)),
                 columns=list(ascii_letters[26:]))
# Compute the correlation matrix
corr = d.corr()
# Generate a mask for the upper triangle (diagonal included)
# np.bool was removed in NumPy 1.24, so use the built-in bool
mask = np.zeros_like(corr, dtype=bool)
mask[np.triu_indices_from(mask)] = True
# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(11, 9))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Draw the heatmap with the mask and correct aspect ratio
# square=True makes every cell square
sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0,
            square=True, linewidths=.5, cbar_kws={"shrink": .5});
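Because a correlation matrix is symmetric, the mask hides the upper triangle (diagonal included) and `heatmap` draws only the cells where the mask is `False`. A minimal sketch of the same masking on a hypothetical 4 × 4 example, which also checks that `np.triu` applied to an all-`True` array is an equivalent one-line way to build the mask:

```python
import numpy as np
import pandas as pd

# Hypothetical small dataset: 4 variables, 50 observations
rs = np.random.RandomState(0)
small = pd.DataFrame(rs.normal(size=(50, 4)), columns=list("ABCD"))
corr = small.corr()

# Two-step construction used in the example above
mask = np.zeros_like(corr, dtype=bool)
mask[np.triu_indices_from(mask)] = True

# Equivalent one-liner: upper triangle, diagonal included (k=0 by default)
mask_alt = np.triu(np.ones(corr.shape, dtype=bool))

# The cells heatmap actually draws are those where mask is False
visible = corr.where(~mask)
print(mask.astype(int))
print(visible.round(2))
```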
2. Scatterplot with marginal ticks (JointGrid)
sns.set(style="white", color_codes=True)
# Generate a random bivariate dataset
rs = np.random.RandomState(9)
mean = [0, 0]
cov = [(1, 0), (0, 2)]
x, y = rs.multivariate_normal(mean, cov, 100).T
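# Use JointGrid directly to draw a custom plot
# Create the grid and bind the x and y data
# (recent seaborn versions require x and y as keyword arguments)
grid = sns.JointGrid(x=x, y=y, space=0, height=6, ratio=50)
# Draw a scatterplot on the joint (central) axes
grid.plot_joint(plt.scatter, color="g")
# Draw rug plots on the marginal axes
grid.plot_marginals(sns.rugplot, height=1, color="g");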
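The data come from a bivariate normal distribution with variance 1 along x, variance 2 along y and zero covariance, which is why the point cloud is taller than it is wide. A quick NumPy check with the same seed compares the sample covariance with `cov`; with only 100 draws, expect estimates close to, not equal to, the true values:

```python
import numpy as np

# Same generator and parameters as the example above
rs = np.random.RandomState(9)
mean = [0, 0]
cov = [(1, 0), (0, 2)]
x, y = rs.multivariate_normal(mean, cov, 100).T

# Sample covariance matrix (each row of the stacked array is one variable)
sample_cov = np.cov(np.vstack([x, y]))
print(np.round(sample_cov, 2))
```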
3. Multiple bivariate KDE plots (kdeplot)
sns.set(style="darkgrid")
iris = sns.load_dataset("iris")
# Subset the iris dataset by species
setosa = iris.query("species == 'setosa'")
virginica = iris.query("species == 'virginica'")
# Set up the figure
f, ax = plt.subplots(figsize=(8, 8))
# "equal" gives the x and y axes the same scale
ax.set_aspect("equal")
# Draw the two density plots
# seaborn >= 0.11 replaced shade with fill and shade_lowest with thresh:
# fill=True shades the contours, and thresh leaves the lowest-density region
# unfilled so the two overlapping plots do not cover each other
ax = sns.kdeplot(x=setosa.sepal_width, y=setosa.sepal_length,
                 cmap="Reds", fill=True, thresh=0.05)
ax = sns.kdeplot(x=virginica.sepal_width, y=virginica.sepal_length,
                 cmap="Blues", fill=True, thresh=0.05)
# Add labels to the plot
# Pick label colors from the same palettes as the density plots
red = sns.color_palette("Reds")[-2]
blue = sns.color_palette("Blues")[-2]
ax.text(2.5, 8.2, "virginica", size=16, color=blue)
ax.text(3.8, 4.5, "setosa", size=16, color=red);
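`kdeplot` estimates each density with a Gaussian kernel, and its default bandwidth rule is Scott's (`bw_method="scott"`). The NumPy-only sketch below evaluates such a 2-D estimate on a grid, using hypothetical standard-normal data instead of the iris measurements; it shows the computation, not seaborn's exact output. The filled contours are drawn from a density grid like `dens`, and `thresh` decides how much of the low-density region stays blank.

```python
import numpy as np

def gaussian_kde_2d(points, grid_x, grid_y):
    """Evaluate a 2-D Gaussian KDE with Scott's-rule bandwidth on a mesh grid."""
    n, d = points.shape
    factor = n ** (-1.0 / (d + 4))                    # Scott's rule
    cov = np.cov(points, rowvar=False) * factor ** 2  # kernel covariance
    inv = np.linalg.inv(cov)
    norm = 1.0 / (2 * np.pi * np.sqrt(np.linalg.det(cov)))
    xx, yy = np.meshgrid(grid_x, grid_y)
    grid = np.column_stack([xx.ravel(), yy.ravel()])  # (m, 2)
    diff = grid[:, None, :] - points[None, :, :]      # (m, n, 2)
    # Squared Mahalanobis distance from every grid point to every sample
    mahal = np.einsum("mni,ij,mnj->mn", diff, inv, diff)
    dens = norm * np.exp(-0.5 * mahal).mean(axis=1)
    return xx, yy, dens.reshape(xx.shape)

# Hypothetical data: 50 points from a standard bivariate normal
rs = np.random.RandomState(1)
pts = rs.normal(size=(50, 2))
gx = gy = np.linspace(-6, 6, 121)
xx, yy, dens = gaussian_kde_2d(pts, gx, gy)

# A density integrates to about 1 over the plane (Riemann sum on the grid)
step = gx[1] - gx[0]
total = dens.sum() * step * step
print(round(total, 3))
```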
4. Multiple linear regression (lmplot)
# Load the iris dataset
iris = sns.load_dataset("iris")
# Plot sepal_width as a function of sepal_length across species
# hue draws a separate scatter and regression fit for each species
# truncate=True limits each regression line to the range of its data
g = sns.lmplot(x="sepal_length", y="sepal_width", hue="species",
               truncate=True, height=5, data=iris);
# Use more informative axis labels than are provided by default
g.set_axis_labels("Sepal length (cm)", "Sepal width (cm)");
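With `hue`, `lmplot` fits a separate straight line to each level of the hue variable. A minimal sketch of those per-group fits using `numpy.polyfit` (ordinary least squares) on a hypothetical noise-free dataset, so the recovered coefficients are exact; seaborn also bootstraps a confidence band around each line, which this sketch omits:

```python
import numpy as np
import pandas as pd

def fit_lines_by_group(df, x, y, group):
    """Return {group level: (slope, intercept)} from an OLS fit per group."""
    fits = {}
    for level, sub in df.groupby(group):
        slope, intercept = np.polyfit(sub[x], sub[y], deg=1)
        fits[level] = (slope, intercept)
    return fits

# Hypothetical noise-free data: two groups lying on known lines
xs = np.linspace(4.0, 8.0, 10)
df = pd.DataFrame({
    "length": np.concatenate([xs, xs]),
    "width": np.concatenate([0.8 * xs - 1.0, 0.3 * xs + 1.0]),
    "group": ["a"] * 10 + ["b"] * 10,
})
fits = fit_lines_by_group(df, "length", "width", "group")
for level, (slope, intercept) in fits.items():
    print(level, round(slope, 3), round(intercept, 3))
```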
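5. Paired density and scatterplot matrix (PairGrid)
sns.set(style="white")
df = sns.load_dataset("iris")
# Build a grid of pairwise plots (scatterplot matrix)
# diag_sharey=False: the diagonal plots do not share the y axis
g = sns.PairGrid(df, diag_sharey=False)
# Lower triangle: bivariate KDE plots
g.map_lower(sns.kdeplot)
# Upper triangle: scatterplots
g.map_upper(sns.scatterplot)
# Diagonal: univariate KDE plots; lw sets the line width
g.map_diag(sns.kdeplot, lw=3);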
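`map_lower`, `map_upper` and `map_diag` each target a different set of cells in the k × k grid, where row i puts the i-th variable on the y axis and column j puts the j-th variable on the x axis. For the four numeric iris columns (typed in by hand here), a small NumPy sketch that enumerates which variable pairs land in each region:

```python
import numpy as np

variables = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
k = len(variables)

def cells(rows, cols):
    """Pair each (row, col) cell with its (x variable, y variable)."""
    return [(variables[c], variables[r]) for r, c in zip(rows, cols)]

lower = cells(*np.tril_indices(k, -1))   # below the diagonal: map_lower
upper = cells(*np.triu_indices(k, 1))    # above the diagonal: map_upper
diag = [variables[i] for i in range(k)]  # on the diagonal: map_diag

print(len(lower), len(upper), len(diag))
```

Each lower cell has a mirror-image upper cell with x and y swapped, which is why this example can show two different plot types for the same pair.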
6. Paired categorical plots (PairGrid)
sns.set(style="whitegrid")
# Load the example Titanic dataset
titanic = sns.load_dataset("titanic")
# Set up a grid to plot survival probability against several variables
# y_vars puts "survived" on the y axis of every panel; x_vars gives one panel per column
g = sns.PairGrid(titanic, y_vars="survived",
                 x_vars=["class", "sex", "who", "alone"],
                 height=5, aspect=.5)
# Draw a seaborn pointplot onto each Axes
# Each point is the mean of "survived" (the survival rate) for one category;
# the bars are confidence intervals, and errwidth sets their line thickness
g.map(sns.pointplot, scale=1.3, errwidth=4, color="xkcd:plum")
g.set(ylim=(0, 1))
sns.despine(fig=g.fig, left=True);
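Each point in the pointplot is the mean of `survived` within one category, i.e. the survival rate, and its error bar is a bootstrapped confidence interval (95% by default). A minimal sketch of that computation with pandas and NumPy on a hypothetical toy frame rather than the real Titanic data; seaborn's own bootstrap (1000 resamples by default) will give slightly different interval ends:

```python
import numpy as np
import pandas as pd

def mean_with_bootstrap_ci(values, n_boot=1000, ci=95, seed=0):
    """Return (mean, lower, upper) using a percentile bootstrap of the mean."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # Resample with replacement n_boot times and take each resample's mean
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    boot_means = values[idx].mean(axis=1)
    half = (100 - ci) / 2
    lower, upper = np.percentile(boot_means, [half, 100 - half])
    return values.mean(), lower, upper

# Hypothetical toy data: survival (0/1) by sex
toy = pd.DataFrame({
    "sex": ["male"] * 10 + ["female"] * 10,
    "survived": [0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
                 1, 1, 1, 0, 1, 1, 1, 1, 0, 1],
})
summary = {sex: mean_with_bootstrap_ci(group["survived"])
           for sex, group in toy.groupby("sex")}
for sex, (m, lo, hi) in summary.items():
    print(f"{sex}: rate={m:.2f}, 95% CI=({lo:.2f}, {hi:.2f})")
```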
7. Dot plot with several variables (PairGrid)
sns.set(style="whitegrid")
# Load the dataset
crashes = sns.load_dataset("car_crashes")
# Make the PairGrid
# Rows are sorted by "total" (highest first); x_vars gives one panel per
# crash column, and every panel uses the state abbreviation on the y axis
g = sns.PairGrid(crashes.sort_values("total", ascending=False),
                 x_vars=crashes.columns[:-3], y_vars=["abbrev"],
                 height=10, aspect=.25)
# Draw a dot plot using the stripplot function
g.map(sns.stripplot, size=10, orient="h",
      palette="ch:s=1,r=-.1,h=1_r", linewidth=1, edgecolor="w")
# Use the same x axis limits on all columns and add better labels
g.set(xlim=(0, 25), xlabel="Crashes", ylabel="")
# Use semantically meaningful titles for the columns
titles = ["Total crashes", "Speeding crashes", "Alcohol crashes",
          "Not distracted crashes", "No previous crashes"]
for ax, title in zip(g.axes.flat, titles):
    # Set a different title for each axes
    ax.set(title=title)
    # Make the grid horizontal instead of vertical
    ax.xaxis.grid(False)
    ax.yaxis.grid(True)
# Remove the left and bottom spines
sns.despine(left=True, bottom=True);
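The five titles line up with `crashes.columns[:-3]` because the last three columns (`ins_premium`, `ins_losses`, `abbrev`) are left out of `x_vars`. A sketch with a hypothetical two-row frame that assumes the dataset's column names and order; the numbers are made up, not real crash statistics:

```python
import pandas as pd

# Hypothetical rows; only the column names and their order are meant to
# mirror seaborn's car_crashes dataset (an assumption, nothing is loaded)
columns = ["total", "speeding", "alcohol", "not_distracted",
           "no_previous", "ins_premium", "ins_losses", "abbrev"]
toy = pd.DataFrame([[10.0, 2.0, 3.0, 9.0, 8.0, 700.0, 120.0, "BB"],
                    [20.0, 6.0, 5.0, 17.0, 16.0, 900.0, 150.0, "AA"]],
                   columns=columns)

# Everything except the last three columns becomes an x variable
x_vars = toy.columns[:-3]
# Rows sorted so the highest total comes first, as in the example
ordered = toy.sort_values("total", ascending=False)

titles = ["Total crashes", "Speeding crashes", "Alcohol crashes",
          "Not distracted crashes", "No previous crashes"]
for col, title in zip(x_vars, titles):
    print(f"{col} -> {title}")
print(ordered["abbrev"].tolist())
```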
8. Plotting a three-way ANOVA (catplot)
# Load the example exercise dataset
df = sns.load_dataset("exercise")
# Draw a pointplot to show pulse as a function of three categorical factors
# catplot draws categorical plots on a grid of axes:
# col makes one column of panels per "diet" level,
# hue separates "kind" by color within each panel,
# capsize sets the width of the caps on the error bars
g = sns.catplot(x="time", y="pulse", hue="kind", col="diet",
                capsize=0.2, palette="YlGnBu_d", height=6, aspect=.75,
                kind="point", data=df)
g.despine(left=True);
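The figure does not run an ANOVA itself; it visualizes the three-factor design by plotting the mean pulse for every combination of `diet`, `kind` and `time`. Those cell means come from a single `groupby`. A sketch on a hypothetical miniature table with the same column names (all pulse values are made up):

```python
import itertools
import pandas as pd

# Hypothetical miniature design: 2 diets x 2 kinds x 2 times,
# two made-up pulse readings per cell
rows = []
for diet, kind, time in itertools.product(["low fat", "no fat"],
                                          ["rest", "running"],
                                          ["1 min", "30 min"]):
    base = 90 if kind == "rest" else 110
    if kind == "running" and time == "30 min":
        base += 20  # made-up interaction: pulse keeps rising while running
    for offset in (-2, 2):
        rows.append({"diet": diet, "kind": kind, "time": time,
                     "pulse": base + offset})
df = pd.DataFrame(rows)

# One mean per (diet, kind, time) cell: the value each point in the figure shows
cell_means = df.groupby(["diet", "kind", "time"])["pulse"].mean()
print(cell_means.unstack("time"))
```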
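9. Linear regression with marginal distributions (jointplot)
sns.set(style="darkgrid")
tips = sns.load_dataset("tips")
# kind="reg": scatterplot with a regression fit, plus marginal distributions
# (recent seaborn versions require x and y as keyword arguments)
g = sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg",
                  xlim=(0, 60), ylim=(0, 12), color="m", height=7)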
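With `kind="reg"`, the central panel adds a least-squares line (with a bootstrapped confidence band) and the margins show histograms with a density curve. The line follows the closed-form OLS formulas slope = cov(x, y) / var(x) and intercept = mean(y) - slope * mean(x). A NumPy sketch on hypothetical bill/tip-like data, cross-checked against `numpy.polyfit` and `numpy.corrcoef`:

```python
import numpy as np

def ols_line(x, y):
    """Closed-form simple linear regression: returns (slope, intercept, r)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx, dy = x - x.mean(), y - y.mean()
    slope = (dx * dy).sum() / (dx ** 2).sum()
    intercept = y.mean() - slope * x.mean()
    # Pearson correlation coefficient
    r = (dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum())
    return slope, intercept, r

# Hypothetical data: tip is roughly 15% of the bill plus noise
rs = np.random.RandomState(3)
bill = rs.uniform(5, 50, 200)
tip = 0.15 * bill + rs.normal(0, 1, 200)

slope, intercept, r = ols_line(bill, tip)
print(f"tip = {slope:.3f} * bill + {intercept:.3f}, r = {r:.2f}")
```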
10. Plotting model residuals (residplot)
sns.set(style="whitegrid")
# Make an example dataset with y ~ x
rs = np.random.RandomState(7)
x = rs.normal(2, 1, 75)
y = 2 + 1.5 * x + rs.normal(0, 2, 75)
# Plot the residuals after fitting a linear model
# Each point is a residual from a first-order (straight-line) fit;
# lowess=True overlays a LOWESS smoother (requires statsmodels)
sns.residplot(x=x, y=y, lowess=True, color="g");
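`residplot` fits y on x by least squares and plots observed minus fitted values against x; points scattered evenly around zero suggest the linear model is adequate. A NumPy sketch reproducing the residuals for the same simulated data (the LOWESS smoother is omitted). With an intercept in the model, OLS residuals always sum to zero and are orthogonal to x, which the printed mean reflects:

```python
import numpy as np

# Same simulated data as the example above
rs = np.random.RandomState(7)
x = rs.normal(2, 1, 75)
y = 2 + 1.5 * x + rs.normal(0, 2, 75)

# First-order least-squares fit, then residual = observed - fitted
slope, intercept = np.polyfit(x, y, deg=1)
residuals = y - (intercept + slope * x)

print(f"fit: y = {intercept:.2f} + {slope:.2f} x")
print(f"mean residual = {residuals.mean():.1e}")
```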