本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下
简介
本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4
代码
''' 梯度下降算法 Batch Gradient Descent Stochastic Gradient Descent SGD ''' __author__ = 'epleone' import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import sys # 使用随机数种子, 让每次的随机数生成相同,方便调试 # np.random.seed(111111111) class GradientDescent(object): eps = 1.0e-8 max_iter = 1000000 # 暂时不需要 dim = 1 func_args = [2.1, 2.7] # [w_0, .., w_dim, b] def __init__(self, func_arg=None, N=1000): self.data_num = N if func_arg is not None: self.FuncArgs = func_arg self._getData() def _getData(self): x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5) b_1 = np.ones((self.data_num, 1), dtype=np.float) # x = np.concatenate((x, b_1), axis=1) self.x = np.concatenate((x, b_1), axis=1) def func(self, x): # noise太大的话, 梯度下降法失去作用 noise = 0.01 * np.random.randn(self.data_num) + 0 w = np.array(self.func_args) # y1 = w * self.x[0, ] # 直接相乘 y = np.dot(self.x, w) # 矩阵乘法 y += noise return y @property def FuncArgs(self): return self.func_args @FuncArgs.setter def FuncArgs(self, args): if not isinstance(args, list): raise Exception( 'args is not list, it should be like [w_0, ..., w_dim, b]') if len(args) == 0: raise Exception('args is empty list!!') if len(args) == 1: args.append(0.0) self.func_args = args self.dim = len(args) - 1 self._getData() @property def EPS(self): return self.eps @EPS.setter def EPS(self, value): if not isinstance(value, float) and not isinstance(value, int): raise Exception("The type of eps should be an float number") self.eps = value def plotFunc(self): # 一维画图 if self.dim == 1: # x = np.sort(self.x, axis=0) x = self.x y = self.func(x) fig, ax = plt.subplots() ax.plot(x, y, 'o') ax.set(xlabel='x ', ylabel='y', title='Loss Curve') ax.grid() plt.show() # 二维画图 if self.dim == 2: # x = np.sort(self.x, axis=0) x = self.x y = self.func(x) xs = x[:, 0] ys = x[:, 1] zs = y fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.scatter(xs, ys, zs, c='r', marker='o') ax.set_xlabel('X Label') ax.set_ylabel('Y Label') ax.set_zlabel('Z Label') plt.show() else: # plt.axis('off') plt.text( 0.5, 0.5, "The dimension(x.dim > 2) \n is too high to draw", size=17, rotation=0., ha="center", va="center", bbox=dict( boxstyle="round", ec=(1., 0.5, 0.5), fc=(1., 0.8, 0.8), )) plt.draw() plt.show() # print('The dimension(x.dim > 2) is too high to draw') # 梯度下降法只能求解凸函数 def _gradient_descent(self, bs, lr, epoch): x = self.x # shuffle数据集没有必要 # np.random.shuffle(x) y = self.func(x) w = np.ones((self.dim + 1, 1), dtype=float) for e in range(epoch): print('epoch:' + str(e), end=',') # 批量梯度下降,bs为1时 等价单样本梯度下降 for i in range(0, self.data_num, bs): y_ = np.dot(x[i:i + bs], w) loss = y_ - y[i:i + bs].reshape(-1, 1) d = loss * x[i:i + bs] d = d.sum(axis=0) / bs d = lr * d d.shape = (-1, 1) w = w - d y_ = np.dot(self.x, w) loss_ = abs((y_ - y).sum()) print('\tLoss = ' + str(loss_)) print('拟合的结果为:', end=',') print(sum(w.tolist(), [])) print() if loss_ < self.eps: print('The Gradient Descent algorithm has converged!!\n') break pass def __call__(self, bs=1, lr=0.1, epoch=10): if sys.version_info < (3, 4): raise RuntimeError('At least Python 3.4 is required') if not isinstance(bs, int) or not isinstance(epoch, int): raise Exception( "The type of BatchSize/Epoch should be an integer number") self._gradient_descent(bs, lr, epoch) pass pass if __name__ == "__main__": if sys.version_info < (3, 4): raise RuntimeError('At least Python 3.4 is required') gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1]) # gd = GradientDescent([1.2, 1.4, 2.1]) print("要拟合的参数结果是: ") print(gd.FuncArgs) print("===================\n\n") # gd.EPS = 0.0 gd.plotFunc() gd(10, 0.01) print("Finished!")
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。
标签:
python,梯度下降
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件!
如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
白云城资源网 Copyright www.dyhadc.com
暂无“python梯度下降算法的实现”评论...
《魔兽世界》大逃杀!60人新游玩模式《强袭风暴》3月21日上线
暴雪近日发布了《魔兽世界》10.2.6 更新内容,新游玩模式《强袭风暴》即将于3月21 日在亚服上线,届时玩家将前往阿拉希高地展开一场 60 人大逃杀对战。
艾泽拉斯的冒险者已经征服了艾泽拉斯的大地及遥远的彼岸。他们在对抗世界上最致命的敌人时展现出过人的手腕,并且成功阻止终结宇宙等级的威胁。当他们在为即将于《魔兽世界》资料片《地心之战》中来袭的萨拉塔斯势力做战斗准备时,他们还需要在熟悉的阿拉希高地面对一个全新的敌人──那就是彼此。在《巨龙崛起》10.2.6 更新的《强袭风暴》中,玩家将会进入一个全新的海盗主题大逃杀式限时活动,其中包含极高的风险和史诗级的奖励。
《强袭风暴》不是普通的战场,作为一个独立于主游戏之外的活动,玩家可以用大逃杀的风格来体验《魔兽世界》,不分职业、不分装备(除了你在赛局中捡到的),光是技巧和战略的强弱之分就能决定出谁才是能坚持到最后的赢家。本次活动将会开放单人和双人模式,玩家在加入海盗主题的预赛大厅区域前,可以从强袭风暴角色画面新增好友。游玩游戏将可以累计名望轨迹,《巨龙崛起》和《魔兽世界:巫妖王之怒 经典版》的玩家都可以获得奖励。
更新日志
2024年12月29日
2024年12月29日
- 小骆驼-《草原狼2(蓝光CD)》[原抓WAV+CUE]
- 群星《欢迎来到我身边 电影原声专辑》[320K/MP3][105.02MB]
- 群星《欢迎来到我身边 电影原声专辑》[FLAC/分轨][480.9MB]
- 雷婷《梦里蓝天HQⅡ》 2023头版限量编号低速原抓[WAV+CUE][463M]
- 群星《2024好听新歌42》AI调整音效【WAV分轨】
- 王思雨-《思念陪着鸿雁飞》WAV
- 王思雨《喜马拉雅HQ》头版限量编号[WAV+CUE]
- 李健《无时无刻》[WAV+CUE][590M]
- 陈奕迅《酝酿》[WAV分轨][502M]
- 卓依婷《化蝶》2CD[WAV+CUE][1.1G]
- 群星《吉他王(黑胶CD)》[WAV+CUE]
- 齐秦《穿乐(穿越)》[WAV+CUE]
- 发烧珍品《数位CD音响测试-动向效果(九)》【WAV+CUE】
- 邝美云《邝美云精装歌集》[DSF][1.6G]
- 吕方《爱一回伤一回》[WAV+CUE][454M]