# 导入数据
import pandas as pd
url = "https://raw.githubusercontent.com/Resgression/Salary_Data.csv"
data = pd.read_csv(url)
data
# y = w*x + b
x = data["YearsExperience"]
y = data["Salary"]
x
0 0.3
1 0.6
2 0.8
3 1.1
4 1.3
5 1.5
6 2.0
7 2.2
8 2.9
9 3.0
10 3.2
11 3.2
12 3.7
13 3.9
14 4.0
15 4.0
16 4.1
17 4.5
18 4.9
19 5.1
20 5.3
21 5.9
22 6.0
23 6.8
24 7.1
25 7.9
26 8.2
27 8.7
28 9.0
29 9.5
30 9.6
31 10.3
32 10.5
Name: YearsExperience, dtype: float64
# 预测值的线性公式:y_pred = w*x + b
# cost function 成本函数公式:cost=(真实值-预测值)的平方,找出最适合的直线
w = 10
b = 0
y_pred = w*x + b
cost = (y - y_pred)**2
# 定义cost function函数
def compute(x,y,w,b):
y_pred = w*x + b
cost = (y - y_pred)**2
cost = cost.sum() / len(x)
return cost
# 示例:当b为0,w为[-100,100]之间,查看cost的值
costs = []
for w in range(-100,101):
cost = compute(x,y,w,0)
costs.append(cost)
costs
153122.113030303,
148678.02909090905,
144299.49606060603,
139986.51393939395,
135739.0827272727,
131557.20242424242,
127440.87303030302,
123390.09454545454,
119404.86696969694,
115485.19030303031,
111631.06454545453,
107842.4896969697,
104119.46575757574,
100461.99272727272,
96870.0706060606,
93343.69939393939,
89882.87909090909,
86487.60969696968,
83157.8912121212,
79893.72363636363,
76695.10696969698,
73562.04121212121,
70494.52636363637,
67492.56242424241,
64556.149393939384,
61685.28727272727,
58879.97606060606,
56140.21575757576,
53466.00636363636,
50857.34787878788,
48314.2403030303,
45836.68363636363,
43424.67787878788,
41078.22303030303,
38797.31909090909,
36581.96606060606,
34432.16393939394,
32347.91272727273,
30329.21242424242,
28376.063030303034,
26488.464545454543,
24666.41696969697,
22909.920303030303,
21218.97454545455,
19593.5796969697,
18033.73575757576,
16539.44272727273,
15110.700606060605,
13747.509393939394,
12449.869090909091,
11217.779696969697,
10051.241212121213,
8950.253636363635,
7914.81696969697,
6944.931212121212,
6040.596363636363,
5201.812424242424,
4428.579393939393,
3720.8972727272726,
3078.7660606060604,
2502.1857575757576,
1991.1563636363637,
1545.6778787878789,
1165.750303030303,
851.3736363636364,
602.547878787879,
419.27303030303034,
301.54909090909086,
249.37606060606066,
262.75393939393945,
341.6827272727273,
486.1624242424242,
696.1930303030302,
971.7745454545454,
1312.9069696969698,
1719.590303030303,
2191.8245454545454,
2729.609696969697,
3332.9457575757574,
4001.8327272727265,
4736.270606060605,
5536.259393939394,
6401.7990909090895,
7332.889696969696,
8329.531212121212,
9391.723636363637,
10519.466969696969,
11712.761212121211,
12971.606363636363,
14296.002424242422,
15685.949393939394,
17141.44727272727,
18662.49606060606,
20249.095757575757,
21901.24636363636,
23618.94787878788,
25402.200303030302,
27251.003636363635,
29165.357878787876,
31145.26303030303,
33190.71909090909,
35301.72606060606,
37478.283939393936,
39720.39272727272,
42028.05242424242,
44401.26303030303,
46840.024545454544,
49344.33696969697,
51914.20030303031,
54549.61454545455,
57250.5796969697,
60017.09575757576,
62849.16272727272,
65746.7806060606,
68709.9493939394,
71738.6690909091,
74832.9396969697,
77992.76121212122,
81218.13363636362,
84509.05696969696,
87865.53121212122,
91287.55636363637,
94775.13242424243,
98328.25939393938,
101946.93727272727,
105631.16606060608,
109380.94575757576,
113196.27636363637,
117077.15787878788,
121023.5903030303,
125035.57363636365,
129113.1078787879,
133256.19303030302,
137464.82909090907,
141739.01606060608,
146078.75393939397,
150484.04272727275,
154954.88242424242,
159491.273030303,
164093.21454545457,
168760.70696969697,
173493.7503030303,
178292.34454545454,
183156.48969696971,
188086.18575757576,
193081.4327272727,
198142.23060606062,
203268.57939393944,
208460.4790909091,
213717.9296969697,
219040.9312121212,
224429.48363636364,
229883.586969697,
235403.24121212118,
240988.44636363635,
246639.20242424245]
# 画出cost function直线图
import matplotlib.pyplot as plt
# plt.scatter(range(-100,101),costs)
# plt.show()
plt.plot(range(-100,101),costs)
plt.title("cost function b=0,w=[-100,100]")
plt.xlabel("w")
plt.ylabel("cost")
plt.show()
# 计算出 w,b均在[-100,100]之间的3维cost function值
import numpy as np
ws = np.arange(-100,101)
bs = np.arange(-100,101)
costs = np.zeros((201,201))
i = 0
for w in ws :
j = 0
for b in bs :
cost = compute(x,y,w,b)
costs[i,j]=cost
j = j+1
i = i+1
costs
array([[543097.74787879, 541777.28121212, 540458.81454545, ...,
320651.34787879, 319726.88121212, 318804.41454545],
[534727.50939394, 533416.80636364, 532108.10333333, ...,
314214.30939394, 313299.60636364, 312386.90333333],
[526422.82181818, 525121.88242424, 523822.9430303 , ...,
307842.82181818, 306937.88242424, 306034.9430303 ],
...,
[164229.90787879, 164842.64121212, 165457.37454545, ...,
324557.10787879, 325565.84121212, 326576.57454545],
[168838.74939394, 169461.24636364, 170085.74333333, ...,
331099.14939394, 332117.64636364, 333138.14333333],
[173513.14181818, 174145.40242424, 174779.6630303 , ...,
337706.74181818, 338735.00242424, 339765.2630303 ]])
# 导入中文包
!pip install wget
import wget
wget.download("https://github.com/GrandmaCan/ML/raw/main/Resgression/ChineseFont.ttf")
# 画出 w,b均在[-100,100]之间的3维cost function图
import matplotlib as mlp
# 引入上面下载得到中文包
from matplotlib.font_manager import fontManager
fontManager.addfont("ChineseFont.ttf")
mlp.rc('font',family="ChineseFont")
plt.figure(figsize=(7,7)) # 设定3维图大小
ax = plt.axes(projection="3d") # 3D图
ax.view_init(45,-120) # 图像旋转
ax.xaxis.set_pane_color((1,1,1)) # X轴color
ax.yaxis.set_pane_color((1,1,1)) # Y轴color
ax.zaxis.set_pane_color((1,1,1)) # Z轴color
b_grid,w_grid = np.meshgrid(bs,ws) # 二维网格
# 3D图展示
ax.plot_surface(w_grid,b_grid,costs,cmap="Spectral_r",alpha=0.7) #绘制曲面图
ax.plot_wireframe(w_grid,b_grid,costs,color="white",alpha=0.1)
ax.set_title("w、b 对应的 cost")
ax.set_xlabel("w")
ax.set_ylabel("b")
ax.set_zlabel("cost")
# 寻找cost最小值
w_index,b_index = np.where(costs==np.min(costs))
ax.scatter(ws[w_index],bs[b_index],costs[w_index,b_index],color="black",s=40)
plt.show()
print(f"当w={ws[w_index]},b={bs[b_index]},有最小cost值{costs[w_index,b_index]}")
matplotlib.pyplot
简介matplotlib.pyplot
是 Python 编程语言中 matplotlib
库的一个模块,它提供了一个类似于 MATLAB 的绘图框架。matplotlib
是 Python 中最受欢迎和广泛使用的绘图库之一,它能够生成高质量的二维图形。使用 pyplot
接口,用户可以轻松地制作线图、散点图、柱状图、饼图、直方图、误差图、箱型图等多种图表。
简单易用:
pyplot
提供了一组类似于 MATLAB 的命令式函数,使得用户能够方便地创建图表和修改图表元素。高度可定制:
多种图表类型:
交互式环境:
pyplot
可以生成交互式图表,用户可以放大、缩小或平移图表。集成其他库:
matplotlib.pyplot
与 Pandas、NumPy 等数据处理库良好集成,可以直接处理这些库的数据结构。保存和导出:
下面是一个简单的使用 matplotlib.pyplot
绘制线图的示例:
import matplotlib.pyplot as plt
# 准备数据
x = [1, 2, 3, 4, 5]
y = [1, 4, 9, 16, 25]
# 创建图表
plt.plot(x, y)
# 添加标题和标签
plt.title("Simple Plot")
plt.xlabel("x")
plt.ylabel("y")
# 显示图表
plt.show()
这段代码会创建一个简单的线图,显示 x
和 y
数值的关系。
matplotlib.pyplot
是 Python 中用于数据可视化的强大工具,适用于科学计算、数据分析、机器学习等多个领域。它的灵活性和易用性使其成为数据可视化的首选工具之一。通过 pyplot
,即使是初学者也能快速地创建漂亮且有表现力的图表。
在 matplotlib
中,您可以使用 mpl_toolkits.mplot3d
模块来创建三维图形。matplotlib.pyplot
与此模块结合,使得绘制三维图形变得简单。以下是绘制三维图形的一些常用语法和技巧:
首先,您需要导入 matplotlib.pyplot
和 mplot3d
模块:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
在绘图之前,您需要创建一个带有3D坐标轴的图形:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
add_subplot(111, projection='3d')
创建了一个3D坐标轴的实例。
3D线图(Line plot):
使用 plot
方法绘制三维线图。
x = [1, 2, 3, 4, 5]
y = [5, 6, 2, 3, 13]
z = [2, 3, 3, 3, 5]
ax.plot(x, y, z)
3D散点图(Scatter plot):
使用 scatter
方法绘制三维散点图。
ax.scatter(x, y, z)
3D曲面图(Surface plot):
首先需要生成网格数据,然后使用 plot_surface
方法。
import numpy as np
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2))
ax.plot_surface(X, Y, Z, cmap='viridis')
这里 cmap='viridis'
是设置曲面的颜色映射。
为了更好地理解图形,您可以设置坐标轴的标签:
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
最后,使用 plt.show()
方法来显示图形:
plt.show()
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
# 创建图形和3D坐标轴
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# 准备数据
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2))
# 绘制曲面图
ax.plot_surface(X, Y, Z, cmap='viridis')
# 设置坐标轴标签
ax.set_xlabel('X Axis')
ax.set_ylabel('Y Axis')
ax.set_zlabel('Z Axis')
# 显示图形
plt.show()
这个示例将创建一个三维曲面图,其中曲面表示的是 Z = sin(sqrt(X^2 + Y^2))
函数。
numpy
库中的 np.meshgrid
函数是一个非常有用的工具,它用于生成二维网格,这在数据可视化和数学建模中非常常见,特别是在绘制三维图形时。下面是对 np.meshgrid
的详细介绍:
np.meshgrid
函数接受两个一维数组,并生成两个二维矩阵。这两个矩阵对应于两个数组中所有的 (x, y)
点对。换句话说,它可以帮助你创建一个“网格”,在这个网格上,你可以评估两个变量的函数。
np.meshgrid
的主要参数是两个一维数组。例如,np.meshgrid(bs, ws)
中的 bs
和 ws
就是这样的数组。
bs
:第一个输入数组。ws
:第二个输入数组。bs
的复制,其中每一行都是 bs
的副本。ws
的复制,其中每一列都是 ws
的副本。假设我们有两个一维数组 bs = [1, 2, 3]
和 ws = [4, 5, 6, 7]
,并且我们想要创建一个网格来评估某个函数 f(b, w)
在这个网格上的值。
import numpy as np
bs = [1, 2, 3]
ws = [4, 5, 6, 7]
B, W = np.meshgrid(bs, ws)
print("B =", B)
print("W =", W)
输出将是:
B = [[1 2 3]
[1 2 3]
[1 2 3]
[1 2 3]]
W = [[4 4 4]
[5 5 5]
[6 6 6]
[7 7 7]]
在这个例子中,B
和 W
就形成了一个网格,你可以在每个 (b, w)
点上计算函数 f
的值。
在绘制三维图形时,np.meshgrid
尤其有用。例如,如果你想绘制一个由 bs
和 ws
变量定义的表面,你可以使用 np.meshgrid
来生成这个表面上每一点的坐标。然后,你可以使用这些坐标来计算函数值,并使用例如 matplotlib
的 plot_surface
函数来绘制这个表面。
总的来说,np.meshgrid
是一个在多维空间进行工作时不可或缺的工具,它让你能够方便地处理和可视化高维数据。
如下是一个使用 Python 和 Matplotlib 库,通过 np.meshgrid
创建三维曲面图的完整示例代码。这个示例展示了如何用二元高斯函数生成三维曲面:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 定义一维数组
bs = np.linspace(-5, 5, 30)
ws = np.linspace(-5, 5, 30)
# 使用 np.meshgrid 生成网格
B, W = np.meshgrid(bs, ws)
# 定义一个二元高斯函数
def gaussian(b, w):
return np.exp(-(b**2 + w**2))
# 计算函数值
Z = gaussian(B, W)
# 创建图形和3D坐标轴
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# 绘制曲面图
ax.plot_surface(B, W, Z, cmap='viridis')
# 设置坐标轴标签
ax.set_xlabel('B Axis')
ax.set_ylabel('W Axis')
ax.set_zlabel('Z Axis')
# 显示图形
plt.show()
这段代码首先定义了两个一维数组 bs
和 ws
,它们分别表示两个变量的范围。然后,使用 np.meshgrid
函数生成一个网格,这个网格上的每个点对应于 (b, w)
坐标。接下来,我们定义了一个高斯函数,并在网格上的每个点计算这个函数的值。最后,使用 Matplotlib 的 plot_surface
方法绘制了这个函数的三维曲面图。可以运行这段代码以可视化二元高斯函数的曲面。