博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
反向传播&梯度下降 的直观理解程序(numpy)
阅读量:2135 次
发布时间:2019-04-30

本文共 1052 字,大约阅读时间需要 3 分钟。

 

 

import numpy as npimport mathimport matplotlib.pyplot as plt# Create random input and output datax = np.linspace(-math.pi, math.pi, 2000)y = np.sin(x)plt.scatter(x,y)plt.show()# Randomly initialize weightsa = np.random.randn()b = np.random.randn()c = np.random.randn()d = np.random.randn()learning_rate = 1e-6for t in range(4000):    # Forward pass: compute predicted y    # y = a + b x + c x^2 + d x^3    y_pred = a + b * x + c * x**2 + d * x**3    # Compute and print loss    loss = np.square(y_pred - y).sum()    if t % 100 == 99:        print(t, loss)    # Backprop to compute gradients of a, b, c, d with respect to loss    grad_y_pred = 2.0 * (y_pred - y)    grad_a = grad_y_pred.sum()    grad_b = (grad_y_pred * x).sum()    grad_c = (grad_y_pred * x**2).sum()    grad_d = (grad_y_pred * x**3).sum()    # Update weights    a -= learning_rate * grad_a    b -= learning_rate * grad_b    c -= learning_rate * grad_c    d -= learning_rate * grad_dprint(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')

 

首先记住,权重w更新是 减去 损失函数L 对权重w的求导,即αL/αw

这里a,b,c,d都是权重

 

转载地址:http://sufgf.baihongyu.com/

你可能感兴趣的文章
简单封装FMDB操作sqlite的模板
查看>>
iOS开发中Instruments的用法
查看>>
iOS常用宏定义
查看>>
什么是ActiveRecord
查看>>
有道词典for mac在Mac OS X 10.9不能取词
查看>>
关于“团队建设”的反思
查看>>
利用jekyll在github中搭建博客
查看>>
Windows7中IIS简单安装与配置(详细图解)
查看>>
linux基本命令
查看>>
BlockQueue 生产消费 不需要判断阻塞唤醒条件
查看>>
强引用 软引用 弱引用 虚引用
查看>>
数据类型 java转换
查看>>
"NetworkError: 400 Bad Request - http://172.16.47.117:8088/rhip/**/####t/approval?date=976
查看>>
mybatis 根据 数据库表 自动生成 实体
查看>>
win10将IE11兼容ie10
查看>>
checkbox设置字体颜色
查看>>
第一篇 HelloWorld.java重新学起
查看>>
ORACLE表空间扩张
查看>>
orcal 循环执行sql
查看>>
web.xml配置监听器,加载数据库信息配置文件ServletContextListener
查看>>