スポンサードリンク



こんにちは。sinyです。

この記事では、djangoでkerasを使ったAIアプリを開発する際によく遭遇する問題点まとめてみました。

計算グラフのリセット忘れ

JupterNotebookでモデルをロード、学習、推論を行っているようなケースでは何も問題がないのに、Django等でWEBアプリケーション化した場合に、突如以下のようなエラーが発生するケースがよくあります。

 Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(18, 50), dtype=float32) is not an element of this graph.

よくあるのが、WEB画面から推論処理の初回実行時は問題がないのに、同じ処理を再実行した場合に上記エラーが発生するといったケースがあります。

これは、古いTFグラフ(計算グラフ)情報が残ってしまっていることで発生するエラーです。

keras.backend.clear_session()

現在のTFグラフを壊し,新たなものを作成します.
古いモデル/レイヤが散らかってしまうを避けるのに役立ちます.

Django等でWEBアプリ化する際は、必ずkeras.backend.clear_session()を記載するようにしましょう。

計算グラフが共有されていないケース

JupterNotebookのように1つのファイル内で処理が完結している場合は問題になりませんが、Djangoで複数モジュール間を行き来してModel情報をやり取りする場合に以下のエラーが発生するケースがよくあります。

 

Tensor Tensor("dense_3/Softmax:0", shape=(?, 18), dtype=float32) is not an element of this graph.

 

具体例で、解決方法を簡単に解説します。

【utils.py内に定義されているモデルロード関数と推論関数をview側で実行する場合】

■views.py

 

ss.load_model()でutils.pyで定義されたScreenShotクラスのload_modelメソッドをCALLしてKerasのモデルをロードしています。
さらに、ロードされたモデル(loaded_model)を使い、utils.py側に定義されたss.predict_modelメソッドを実行し推論結果を取得するような処理になっています。

■utils.py

 

上記のように複数のモジュールファイル(views.py⇔utils.py)間でモデル情報をやり取りする場合、このまま実行すると以下のようなエラーが発生します。

 

Tensor Tensor("xxxxxxxxx) is not an element of this graph.

 

理由は、「TFグラフ情報が共有されていないため」です。

TFグラフを共有するには、以下のような処理を加えてあげればOKです。

  • モデルをロードした後にグラフ情報を変数に格納して戻り値にTFグラフ情報を渡す。
     graph = tf.get_default_graph()
  • 推論時には共有されているTFグラフを使うようにする。
      with graph.as_default():
         model.predict(******)

 

先に記載したviews.pyとutils.pyでTFグラフを共有するようにした場合のコードを以下に記載します。

■views.py

 

■utils.py

 

上記のようにロードしたモデルの計算グラフ情報tf.get_default_graph() を変数graphに格納して、引数として受け渡ししてあげます。

計算グラフを利用する場合は、with graph.as_default(): のようにwith句を使ってモデルの予測を実行すればOKです。

以上、djangoでkerasを使ったAIアプリを開発する際によく遭遇する問題点のまとめでした。

追記情報があれば、随時アップデートしていきます。

 

おすすめの記事