はじめに
前回の記事の続きです。今回は作成した3Dオブジェクトを動かすスクリプトを書いていきます。
エージェントのスクリプトを記述する
今回はカートポールのカートに力を加えてポールを制御していきます。つまり、カートが学習させるエージェントになります。
まず初めにScriptフォルダにCartPoleAgentというC# Scriptを作成してください。
エージェントのスクリプトではAgentクラスを継承し、以下の関数を記述します。
- Initialize() : 初期値を取得
- CollectObservations() : 観測データをセンサーに送る
- OnActionReceived() : 行動する
- OnEpisodeBegin() : ステップ開始の条件を決定
- Heuristic() : キーボードから操作する
- SetResetParameters() : パラメータをリセット
CartPoleAgent.csの全体
はじめにスクリプトの全体のコードを載せます。
usingUnityEngine;usingUnity.MLAgents;usingUnity.MLAgents.Sensors;usingSystem.Collections;usingSystem.Collections.Generic;publicclassCartPoleAgent:Agent{//学習用publicGameObjectpole;RigidbodypoleRB;RigidbodycartRB;EnvironmentParametersm_ResetParams;//初期値publicoverridevoidInitialize(){//学習の初期化poleRB=pole.GetComponent<Rigidbody>();cartRB=gameObject.GetComponent<Rigidbody>();m_ResetParams=Academy.Instance.EnvironmentParameters;SetResetParameters();}//センサーにデータを送るpublicoverridevoidCollectObservations(VectorSensorsensor){sensor.AddObservation(gameObject.transform.localPosition.z);sensor.AddObservation(cartRB.velocity.z);sensor.AddObservation(pole.transform.localRotation.eulerAngles.x);sensor.AddObservation(poleRB.angularVelocity.x);}//各ステップでの行動publicoverridevoidOnActionReceived(float[]verctorAction){//カートに力を加えるvaractionZ=200f*Mathf.Clamp(verctorAction[0],-1f,1f);cartRB.AddForce(newVector3(0.0f,0.0f,actionZ),ForceMode.Force);//カートの位置、ポールの角度と角速度floatcart_z=this.gameObject.transform.localPosition.z;floatangle_x=pole.transform.localRotation.eulerAngles.x;//angle_zを-180~180に変換if(180f<angle_x&&angle_x<360f){angle_x=angle_x-360f;}//カートが+-45度いないなら報酬+0.1 それ以外は -1if((-180f<angle_x&&angle_x<-45f)||(45f<angle_x&&angle_x<180f)){SetReward(-1.0f);EndEpisode();}else{SetReward(0.1f);}//カートの位置が-10~10の範囲を超えたら報酬-1if(cart_z<-10f||10f<cart_z){SetReward(-1.0f);EndEpisode();}}//ステップ開始の初期条件を決定publicoverridevoidOnEpisodeBegin(){//エージェントの状態をリセットgameObject.transform.localPosition=newVector3(0f,0f,0f);pole.transform.localPosition=newVector3(0f,2.5f,0f);pole.transform.localRotation=Quaternion.Euler(0f,0f,0f);poleRB.angularVelocity=newVector3(0f,0f,0f);poleRB.velocity=newVector3(0f,0f,0f);//ポールにランダムな傾きを与えるpoleRB.angularVelocity=newVector3(Random.Range(-0.1f,0.1f),0f,0f);SetResetParameters();}//キーボードから操作する場合publicoverridevoidHeuristic(float[]actionsOut){actionsOut[0]=Input.GetAxis("Horizontal");}//ポールの条件をリセットpublicvoidSetPole(){poleRB.mass=m_ResetParams.GetWithDefault("mass",1.0f);pole.transform.localScale=newVector3(0.4f,2f,0.4f);}//パラメータをリセットする関数publicvoidSetResetParameters(){SetPole();}}publicclassCartPoleAgent:AgentAgentクラスを継承します。
publicGameObjectpole;RigidbodypoleRB;RigidbodycartRB;EnviromentParametersm_ResetParams;で PoleオブジェクトやCart, Pole の rigidbody をいれる変数を定義します。Poleオブジェクトはスクリプト外から取得するので public にします。
Initialize()
Initialize()では学習の初期値を取得します。今回はCart と Pole のrigidbody と環境のパラメータです。また最後の行で各パラメータをリセットしています。
publicoverridevoidInitialize(){//学習の初期化poleRB=pole.GetComponent<Rigidbody>();cartRB=gameObject.GetComponent<Rigidbody>();m_ResetParams=Academy.Instance.EnvironmentParameters;SetResetParameters();}CollectObservations()
CollectionObservations()ではエージェントが得た観測情報をセンサーに追加します。
publicoverridevoidCollectObservations(VectorSensorsensor){sensor.AddObservation(gameObject.transform.localPosition.z);sensor.AddObservation(cartRB.velocity.z);sensor.AddObservation(pole.transform.localRotation.eulerAngles.x);sensor.AddObservation(poleRB.angularVelocity.x);}今回は上から、「Cartの位置」、「Cartの速度」、「Poleの角度」、「Poleの角速度」を与えています。
この値の選択は学習モデルによって変わります。この値が適切なのかわからないので、いろいろ試してみてください。
OnActionReceived()
OnActionReceived()では各ステップでのエージェントの動きについて記述します。今回はカートに移動方向の力を加えます。
また、報酬の設定などもここで行います。
//各ステップでの行動publicoverridevoidOnActionReceived(float[]verctorAction){//カートに力を加えるvaractionZ=200f*Mathf.Clamp(verctorAction[0],-1f,1f);cartRB.AddForce(newVector3(0.0f,0.0f,actionZ),ForceMode.Force);//カートの位置、ポールの角度と角速度floatcart_z=this.gameObject.transform.localPosition.z;floatangle_x=pole.transform.localRotation.eulerAngles.x;//angle_zを-180~180に変換if(180f<angle_x&&angle_x<360f){angle_x=angle_x-360f;}//カートが+-45度いないなら報酬+0.1 それ以外は -1if((-180f<angle_x&&angle_x<-45f)||(45f<angle_x&&angle_x<180f)){SetReward(-1.0f);EndEpisode();}else{SetReward(0.1f);}//カートの位置が-10~10の範囲を超えたら報酬-1if(cart_z<-10f||10f<cart_z){SetReward(-1.0f);EndEpisode();}}カートを動かすコードです。入力値がvectorActionになっています。この値を -200~200 に変換し、加える力の大きさにします。この値の幅は適当です。CartのrigidbodyのAddForceをいじることでカートに力を加えることができます。今回はZ方向に力を加えます。
//カートに力を加えるvaractionZ=200f*Mathf.Clamp(verctorAction[0],-1f,1f);cartRB.AddForce(newVector3(0.0f,0.0f,actionZ),ForceMode.Force);エージェント(カート)の報酬に関するコードです。今回はポールの角度が -45°~45° 以内だと報酬として +0.1, それ以外だと報酬を -1.0 与えゲームを終了させます。また、カートが横に動きすぎないようにカートの位置が -10~10 の範囲から出ると -1.0 の報酬を与えゲームを終了させます。
//カートが+-45度いないなら報酬+0.1 それ以外は -1if((-180f<angle_x&&angle_x<-45f)||(45f<angle_x&&angle_x<180f)){SetReward(-1.0f);EndEpisode();}else{SetReward(0.1f);}//カートの位置が-10~10の範囲を超えたら報酬-1if(cart_z<-10f||10f<cart_z){SetReward(-1.0f);EndEpisode();}OnEpisodeBegin()
OnEpisodeBegin()ではゲームの初期条件を決めます。今回はカートは初期位置に戻し、ポールにランダムな傾きを与えます。
//ステップ開始の初期条件を決定publicoverridevoidOnEpisodeBegin(){//エージェントの状態をリセットgameObject.transform.localPosition=newVector3(0f,0f,0f);pole.transform.localPosition=newVector3(0f,2.5f,0f);pole.transform.localRotation=Quaternion.Euler(0f,0f,0f);poleRB.angularVelocity=newVector3(0f,0f,0f);poleRB.velocity=newVector3(0f,0f,0f);//ポールにランダムな傾きを与えるpoleRB.angularVelocity=newVector3(Random.Range(-0.1f,0.1f),0f,0f);SetResetParameters();}Heuristic()
Heuristic()ではキーボード入力でモデルを動かすときに使用します。今回は十字キーの左右の入力に対応しています。
//キーボードから操作する場合publicoverridevoidHeuristic(float[]actionsOut){actionsOut[0]=Input.GetAxis("Horizontal");}Behavior Parametersの設定、実際に動かしてみる
このスクリプトをCartオブジェクトに追加してください。このときPoleにPoleオブジェクトを追加してください。
また、Add Component から、Behavior Parameters と Decision Requester を追加してください。これは強化学習時に必要になります。
Behavior Parameters / Behavior Type でモデルのモードを変更することができます。今回はキーボード入力を使用したいので Heuristic Only にしてください。学習時は Defaults, 学習済みモデルを使用するときは Inference Only にします。(Inference Onlyを使用する場合は学習済みモデルをModelに追加する必要があります。)
以下の画像のようにパラメータを設定してください。
この状態でモデルを実行するとキーボードで操作することができると思います。しっかり、ポールの角度でリセットされていればOKです。
ここまでで学習に必要な準備は完了しました。あとはサンプルと同じように学習されれば完成になります。
学習させる
実際にカートポールを学習させます。学習させるため、Behavior Parameters / Behavior Type を Defaults にしてください。
また学習の効率化のためカートポールを複数台にします。今回は10台にしました。コピペで簡単に増やすことができます。
学習パラメータを設定するYAMLファイルをconfigフォルダに作成してください。3DBallのパラメータを参考にしました。
モデルの学習方法の詳しい解説はこの記事を参考にしてください。
25万回ぐらいで学習が完了します。学習が完了すると CartPole.nn というファイルが作成されています。このファイルを TFModels に追加してください。

TrainingArea(1)~(9)を非アクティブにして学習済みモデルを実行した動画です。ポールを倒さないように制御できていることが確認できます。
まとめ
カートポールのモデルを作成し、学習させませた。ML-Agentsでは基本的にエージェントのスクリプトを書くだけで強化学習モデルを簡単に実装することができます。まだわからないことが多いので知識が増えたらまた記事を書きたいと思います。