1. はじめに
前回はTensorクラスの一部とTensorを加工するIFunctionインターフェースの実装を行いました。今回は、四則演算を行う関数を定義し、それを用いてTensorの四則演算を実装していきたいと思います。
2. 関数クラスについて
関数クラスはこのライブラリにおいて、Tensorを加工する重要なクラスとなります。関数クラスは全てIFunctionから派生したBaseFunctionを継承するようにしてその下に入力と出力の関係から分類した抽象クラスを実装して最後に具体的な関数のクラスを実装します。下の図のような構成です。
UnaryFunctionは一つの入力に対して一つの出力が得られるような関数に対して定義しようと考えています。
SetFunctionは集合に対する関数(minやmax, sum, meanなど)の基底クラスとなる予定です。これらの他にもTensorのShapeを操作したりTensorを二つに分けたりする関数なども考えています。
3. 関数クラスの実装1(BaseFunction)
ここではIFunctionでForwardやBackwardを呼び出された際の具体的な動作を実装していきます。これによって子クラスでは勾配の伝播のための処理を省くことができます。
usingSystem;namespaceRein.Functions{publicabstractclassBaseFunction:IFunction{protectedTensor[]Inputs,Outputs;protectedintUseCount=0;protectedFunc<Tensor[],Tensor[]>FunctionForward;protectedActionFunctionBackward;protectedstringName;publicBaseFunction(stringname){this.Name=name;}publicTensor[]Forward(paramsTensor[]inputs){this.Inputs=inputs;this.Outputs=this.FunctionForward(inputs);this.UseCount=this.Outputs.Length;returnthis.Outputs;}publicvoidBackward(){this.UseCount--;if(this.UseCount!=0)return;this.FunctionBackward();}}}
4. 関数クラスの実装2(BinaryFunction)
ここでは、さらに入力を二つ出力が一つであるような関数の処理を求めています。BaseFunctionで定義したFunctionForwardとFunctionBackwardに対して、二つの入力と一つの出力を行うBinaryForwardとBinaryBackwardを追加し、継承したクラスで実装しやすいようLeft(演算子の左), Right(演算子の右), Outの三つを定義しています。
namespaceRein.Functions{publicabstractclassBinaryFunction:BaseFunction{protectedabstractTensorBinaryForward(Tensortensor1,Tensortensor2);protectedabstractvoidBinaryBackward();protectedTensorLeft{get{returnthis.Inputs[0];}set{this.Inputs[0]=value;}}protectedTensorRight{get{returnthis.Inputs[1];}set{this.Inputs[1]=value;}}protectedTensorOut{get{returnthis.Outputs[0];}set{this.Outputs[0]=value;}}publicBinaryFunction(stringname):base(name){this.FunctionForward=(tensor)=>newTensor[1]{this.BinaryForward(tensor[0],tensor[1])};this.FunctionBackward=this.BinaryBackward;}}}
5. 四則演算クラスの実装
正直四則演算のクラス同士にそこまで差異はありません。一つ実装できたら他も簡単に実装できると思います。二項演算においては二つの入力を持つ関数を考えます。
O = f(L, R)
最終的な出力を$E$とすると$E.Backward()$を行った際にはそれぞれのTensorに対する$E$の勾配が計算されていくので
\frac{\partial E}{\partial L}=\frac{\partial E}{\partial O}\frac{\partial O}{\partial L}
しかし、実際は$L$を入力とするのは$O$だけとは限らないので、$O$に対して添字$i=(1, 2, 3, ..., n)$がつき、
\frac{\partial E}{\partial L}=\sum_{i=1}^{n}\frac{\partial E}{\partial O_i}\frac{\partial O_i}{\partial L}
よって二項演算において実際に行う計算は
L.Backward\leftarrow L.Backward + O.Backward \times \frac{\partial O}{\partial L}
5.1 Add(和)
足し算の場合
O = f(L, R) = L + R
実際にはTensorの配列に対して行うので、
O_i = L_i + R_i
さらに微分は
\frac{\partial O_i}{\partial L_i}=1\\
\frac{\partial O_i}{\partial R_i}=1
したがって実装は
usingSystem.Linq;usingR=System.Double;namespaceRein.Functions.Arithmetic{publicclassAdd:BinaryFunction{publicAdd():base("Add"){}protectedoverrideTensorBinaryForward(Tensorleft,Tensorright){R[]data=newR[left.Size];for(inti=0;i<left.Size;i++){data[i]=left.Data[i]+right.Data[i];}returnnewTensor(data,left.Shape);}protectedoverridevoidBinaryBackward(){for(inti=0;i<this.Left.Size;i++){this.Left.Grad[i]+=this.Out.Grad[i];this.Right.Grad[i]+=this.Out.Grad[i];}}}}
5.2 Sub(差)
引き算の場合
O_i = L_i - R_i
微分は
\frac{\partial O_i}{\partial L_i}=1\\
\frac{\partial O_i}{\partial R_i}=-1
実装は
usingSystem.Linq;usingR=System.Double;namespaceRein.Functions.Arithmetic{publicclassSub:BinaryFunction{publicSub():base("Sub"){}protectedoverrideTensorBinaryForward(Tensorleft,Tensorright){R[]data=newR[left.Size];for(inti=0;i<left.Size;i++){data[i]=left.Data[i]-right.Data[i];}returnnewTensor(data,left.Shape);}protectedoverridevoidBinaryBackward(){for(inti=0;i<this.Left.Size;i++){this.Left.Grad[i]+=this.Out.Grad[i];this.Right.Grad[i]-=this.Out.Grad[i];}}}}usingSystem.Linq;usingR=System.Double;namespaceRein.Functions.Arithmetic{publicclassSub:BinaryFunction{publicSub():base("Sub"){}protectedoverrideTensorBinaryForward(Tensorleft,Tensorright){R[]data=newR[left.Size];for(inti=0;i<left.Size;i++){data[i]=left.Data[i]-right.Data[i];}returnnewTensor(data,left.Shape);}protectedoverridevoidBinaryBackward(){for(inti=0;i<this.Left.Size;i++){this.Left.Grad[i]+=this.Out.Grad[i];this.Right.Grad[i]-=this.Out.Grad[i];}}}}
5.3 Mul(積)
掛け算の場合
O_i = L_i * R_i
微分は
\frac{\partial O_i}{\partial L_i}=R_i\\
\frac{\partial O_i}{\partial R_i}=L_i
よって実装は
usingSystem.Linq;usingR=System.Double;namespaceRein.Functions.Arithmetic{publicclassMul:BinaryFunction{publicMul():base("Mul"){}protectedoverrideTensorBinaryForward(Tensorleft,Tensorright){R[]data=newR[left.Size];for(inti=0;i<left.Size;i++){data[i]=left.Data[i]*right.Data[i];}returnnewTensor(data,left.Shape);}protectedoverridevoidBinaryBackward(){for(inti=0;i<this.Left.Size;i++){this.Left.Grad[i]+=this.Out.Grad[i]*this.Right.Data[i];this.Right.Grad[i]+=this.Out.Grad[i]*this.Left.Data[i];}}}}
5.4 Div(商)
割り算の場合
O_i = L_i / R_i
微分は
\frac{\partial O_i}{\partial L_i}=\frac{1}{R_i}\\
\frac{\partial O_i}{\partial R_i}=-\frac{L_i}{R_i^2}=-\frac{O_i}{R_i}
$O_i$については$Forward$で計算しているので上のようにすることで計算回数を抑えられる。
これを用いて実装を行うと
usingSystem.Linq;usingR=System.Double;namespaceRein.Functions.Arithmetic{publicclassDiv:BinaryFunction{publicDiv():base("Div"){}protectedoverrideTensorBinaryForward(Tensorleft,Tensorright){R[]data=newR[left.Size];for(inti=0;i<left.Size;i++){data[i]=left.Data[i]/right.Data[i];}returnnewTensor(data,left.Shape);}protectedoverridevoidBinaryBackward(){for(inti=0;i<this.Left.Size;i++){this.Left.Grad[i]+=this.Out.Grad[i]/this.Right.Data[i];this.Right.Grad[i]-=this.Out.Grad[i]*this.Out.Data[i]/this.Right.Data[i];}}}}
6. Tensorクラスの実装(四則演算)
Tensorクラスは割と処理の内容が多くなってくると思うので前回使用したファイルとは別のファイルに四則演算を実装します。(そのために前回partialでTensorクラスを実装しました)
演算子の実装では、それぞれの関数のコンストラクタを呼び出すことで計算グラフを構築しながら演算を行うことができるようになりました。
usingRein.Functions.Arithmetic;usingRein.Utils.Exceptions;namespaceRein{publicpartialclassTensor{// 演算子のオーバーロードpublicstaticTensoroperator+(Tensortensor1,Tensortensor2){returnnewAdd().Forward(tensor1,tensor2);}publicstaticTensoroperator-(Tensortensor1,Tensortensor2){returnnewSub().Forward(tensor1,tensor2);}publicstaticTensoroperator-(Tensortensor){returnnull;}publicstaticTensoroperator*(Tensortensor1,Tensortensor2){returnnewMul().Forward(tensor1,tensor2);}publicstaticTensoroperator/(Tensortensor1,Tensortensor2){returnnewDiv().Forward(tensor1,tensor2);}publicstaticimplicitoperatorTensor(Tensor[]tensor1){if(!(tensor1.Length==1))thrownewInvalidLengthException();returntensor1[0];}}}
最後のTensor[]$\rightarrow$Tensorへの変換は実装するか悩みました。しかしこれが無いと出力一つの関数(minやmaxなど)を使用するたびにインデックス0を指定しないといけなくなるので(IFunctionではTensor[]で数値のやりとりを行うため)利便性のためにこれを追加することにしました。
8. コード
ここまでのコードはhttps://github.com/aokyut/Rein/tree/v0.0.2で公開して居ます。現時点では使い物になりませんが、続きを実装してみたい方やこれまでの実装を確認したい方は見てみてください。
7. 終わりに
ということで今回は関数の基底クラスの定義とTensorクラスの四則演算の定義を行いました。機械学習の実用に使えるようなものではありませんがTensor同士で計算できるようにはなりました。次は多次元配列の演算として多用されるドット演算や単項演算あたりを定義したいと思います。