一个基本的线性回归模型1
2
3
4
5
6
7
8
9
10class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression,self).__init__()
self.linear = nn.Linear(1,1)#一个输入一个输出,此外隐藏一个偏差,y=ax+b
def forward(self,x):
out = self.linear(x)
return out
model = LinearRegression()
至此一个最简单的pytorch模型就定义完了。