pytorch 踩坑笔记之w.grad.data.zero_()

  在使用pytorch实现多项线性回归中,在grad更新时,每一次运算后都须要将上一次的梯度记录清空,运用以下方法:spa

 w.grad.data.zero_() b.grad.data.zero_() 

   可是,运行程序就会报以下错误:code

  报错,grad没有data这个属性,blog

  缘由是,在系统将w的grad值初始化为none,第一次求梯度计算是在none值上进行报错,天然会没有data属性get

  修改方法:添加一个判断语句,从第二次循环开始执行求导运算class

for i in range(100): y_pred = multi_linear(x_train) loss = getloss(y_pred,y_train) if i != 0: w.grad.data.zero_() b.grad.data.zero_() loss.backward() w.data = w.data - 0.001 * w.grad.data b.data = b.data - 0.001 * b.grad.data
相关文章
相关标签/搜索