はじめに
今回はロス関数とTensorに対する操作(detach, reshapeなど)を実装したいと思います。
その前に
実装でロス関数をBackward処理を行う際に
Tensor loss = LossFunction(y, target)
loss.Backward();
と呼び出したいのですが、そのままだとBackward処理が連鎖していかない(起点となる処理が必要)ので、Tensorクラスで現在Backward関数としている関数をBackwardChainに変更し新たにBackwardを追加します。
namespaceRein{[Serializable]publicpartialclassTensor{publicvoidBackwardChain(){// BackFunctionが存在しない時は終了if(this.BackFunction==null)return;this.UseCount--;// 他の関数に対しても出力している場合にはまだ勾配を計算しないif(this.UseCount!=0)return;this.BackFunction.Backward();}publicvoidBackward(){ // 一つの変数しか持たないことを確認するif(this.Size>1)thrownewInvalidSizeException($"expect size : 1, but actual : {this.Size}");this.Grad[0]=1.0;this.BackFunction.Backward();}}}Loss関数の実装
それではいくつか主となるLoss関数を実装していきます。基本的にはLossの関数をLambdaで計算した後SumやMeanを計算することになります
MSELoss(二乗誤差)
これは以下のような関数です
Loss_{MSE}=\frac{1}{n}\sum_{i=1}^{n}(y_i-t_i)^2
MSELossの実装
F.cs内部に直接LambdaFunctionとして実装していきます。
namespaceRein{publicstaticclassF{publicstaticTensorMSELoss(TensorIn){returnnewLambda("MSELoss",(x)=>x*x,(x)=>2*x).Forward(In)[0].Mean();}}}HuberLossの実装
HuberLossは以下のような計算を行います。
f_{huber}(x) = \left\{
\begin{array}{ll}
\frac{1}{2}x^2 & (-\delta \leq x \leq \delta) \\
\delta|x|-\frac{1}{2}\delta^2 & (x \lt -\delta \, or\, x \gt \delta)
\end{array}
\right.\\
L_{huber}=\frac{1}{n}\sum_{i=1}^{n}f_{huber}(y_i-t_i)
実装
こちらも同様にF.csに加えていきます。
namespaceRein{publicstaticclassF{publicstaticTensorHuberLoss(Tensorleft,Tensorright,Rdelta=1.0){RdeltaSquare=delta*delta/2;returnnewLambda("HuberLossFunction",newFunc<R,R>((x)=>x<-delta?-delta*x-deltaSquare:(x>delta?delta*x-deltaSquare:x*x/2)),newFunc<R,R>((x)=>x<-delta?-delta:(x>delta?delta:x))).Forward(left-right);}publicstaticTensorHuberLoss(Tensorleft,Tensorright,Rdelta=1.0){RdeltaSquare=delta*delta/2;returnnewLambda("HuberLossFunction",newFunc<R,R>((x)=>x<-delta?-delta*x-deltaSquare:(x>delta?delta*x-deltaSquare:x*x/2)),newFunc<R,R>((x)=>x<-delta?-delta:(x>delta?delta:x))).Forward(left-right);}}}Tensorの操作
次はいくつかTensorの構造に作用する関数を実装していきたいと思います。Tensorの操作を行うメソッドでは基本的にShapeに作用するためDataの中身を変えないため、入力したTensorと同じインスタンスが出力されることとなります。
Detach
これはTensorの依存関係を切り離し、勾配の伝播を止める操作です。要は学習はさせないがネットワークの出力だけ欲しいという時に使う関数です。これをTensorの関数として実装したいのですが、一つ問題があります。
例えば以下のような形式で使用するとします。
Tensor y = network(x).detach();
Tensor z = network(t);
Tensor loss = (y - z) * (y - z);
loss.Backward();
ここでTensor yは独立したBackFuncを持たないTensorとなるのですが、network内部ではxが入力された時に計算グラフが作られ保存されているので、これらの関係を解消するためには一々yからグラフを遡る必要が出てきます。
そのため、残念ながらTensorの操作としてのDetach操作は断念せざるを得ません。
そこで、代わりにBaseFunctionに「勾配情報を保存しないForward」を定義します。これをPredictとします。
実装(IFunction.csの追記)
まずIFunctionに対してPredictを追加します。
namespaceRein.Functions{publicinterfaceIFunction{publicTensor[]Forward(paramsTensor[]inputs);publicTensor[]Predict(paramsTensor[]inputs);publicvoidBackward();publicTensor[]Parameters{get;}}}実装(BaseFuncttion.csの追記)
IFunctionに追加した関数の詳細をBaseFunctionで定義します。
namespaceRein.Functions{publicabstractclassBaseFunction:IFunction{// ...publicvirtualTensor[]Predict(paramsTensor[]inputs){returnthis.FunctionForward(inputs);}// ...}}これを使用することで、学習時に勾配を計算させないようにすることができます。PytorchのようにDetachをTensorの操作として呼び出したいなら、計算グラフの実装方法を変える必要があるようです。
Squeeze・Unsqueeze
SqueezeはTensorのある軸方向のサイズが1の時にその軸を消し次元を減らす操作で、
Unsqueezeは逆に次元を増やす操作です。これらも同様に関数クラスとして実装しTensorから呼び出せるようにしておきます。
Squeezeの実装
namespaceRein.Functions.Process{publicclassSqueeze:UnaryFunction{privateList<int>InShape;privateintDim;publicSqueeze(intdim):base($"Squeeze-{dim}"){this.Dim=dim;}protectedoverrideTensorUnaryForward(Tensortensor){this.InShape=newList<int>(tensor.Shape);if(tensor.Shape[this.Dim]==1)tensor.Shape.RemoveAt(this.Dim);returntensor;}protectedoverridevoidUnaryBackward(){this.In.Shape=this.InShape;}}}Unsqueezeの実装
namespaceRein.Functions.Process{publicclassUnsqueeze:UnaryFunction{privateList<int>InShape;privateintDim;publicUnsqueeze(intdim):base($"Unsqueeze-{dim}"){this.Dim=dim;}protectedoverrideTensorUnaryForward(Tensortensor){this.InShape=newList<int>(tensor.Shape);tensor.Shape.Insert(this.Dim,1);returntensor;}protectedoverridevoidUnaryBackward(){this.In.Shape=this.InShape;}}}Reshape
ReshapeでもSqueezeと同様にTensorのデータは変えずにShapeのみを入れ替えることになります。
実装
namespaceRein.Functions.Process{publicclassReshape:UnaryFunction{privateList<int>OutShape;privateList<int>InShape;publicReshape(List<int>shape):base($"Reshape-({string.Join(",",shape)})"){this.OutShape=shape;}protectedoverrideTensorUnaryForward(Tensortensor){// サイズ確認if(this.OutShape.Aggregate((now,next)=>now*next)!=tensor.Size)thrownewInvalidShapeException($"Expected Output Shape : ({string.Join(",",this.OutShape)}) ,Input Shape :({string.Join(",",tensor.Shape)})");this.InShape=tensor.Shape;tensor.Shape=this.OutShape;returntensor;}protectedoverridevoidUnaryBackward(){this.In.Shape=this.InShape;}}}Tensorクラスへの追加
ここまで実装したクラスのForwardをTensorから実行できるようにしておきます。
namespaceRein{publicpartialclassTensor{publicTensorDetach(){returnnewDetach().Forward(this);}publicTensorSqueeze(intdim){returnnewSqueeze(dim).Forward(this);}publicTensorUnsqueeze(intdim=0){returnnewUnsqueeze(dim).Forward(this);}publicTensorReshape(List<int>shape){returnnewReshape(shape).Forward(this);}}}これでTensor側でいつでも操作できるようになりました。
終わりに
今回は、ロス関数とTensorの操作関数を定義しました。ロス関数は他にもクロスエントロピーとかがよく使うと思いますが、現時点では使わなさそうなので必要になったら実装しようと思います。
次はOptimizerの実装を行います。