ChatGPTのFunction callingで複数機能を実装してみた

スポンサードリンク



こんにちは。sinyです。

OpenAIのChatGPT APIに追加されたFunction calling機能で実際に動作する処理を繰み込んだコードを試作してみましたのでサンプルコードをご紹介します。

サンプルコードは以下リンクにありますので必要に応じてご利用ください。

 

以下のようなデモアプリも作成してみたので気になる方はチェックしてみてください。

実装するFunction機能

 

今回は以下の2つの機能をFunctionとして定義します。

Function機能
  •  DB問合せ機能
    SqliteDBのサンプルDB(アルバムの売上データ)に対して自然言語で命令を投げる機能
  • 天気機能
    Openweathermap のAPIをCALLして都道府県の天気情報(天気、最低気温、最高気温)を返す機能。

 

SqliteDBのサンプルデータベースはこちらのページからダウンロードできます。

このサンプルデータベースにはアルバムやアーティスト、アルバムの売上データなどが格納されています。

天気情報の取得には無料で使えるOpenWeatherMapを利用します。

OpenAIのAPIキーとOpenWeatherMapのAPIキーがそれぞれ必要ですので準備してください(手順は割愛)。

 

事前準備

jupyter notebook上で実装するのでまずnotebookをインストールします。

※Google CoraboなどでもOKです。

pip install jupyter notebook openai requests

 

 

SqliteDBのFunction実装

まず最初に実装するfunctionの定義を提示しておきます。

# データベーススキーマを関数に挿入していることに注意。
# これはモデルが知るべき重要な情報 
functions=[
    {
        "name": "ask_database",
        "description": "この関数を使用して、音楽に関するユーザーの質問に答えます。出力は完全に形成されたSQLクエリでなければなりません。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                        ユーザーの質問に答えるための情報を抽出するSQLクエリ。
                        SQLは、以下のデータベーススキーマを使って書かなければならない:
                        データベーススキーマ:{database_schema_string}
                        クエリーはJSONではなく、プレーンテキストで返す必要があります。
                    """,
                },
            },
            "required": ["query"],
        },
    },

    {
        "name": "ask_weather",
        "description": "この関数を使用して特定の都市の天気情報の質問に答えます。都市名のプレーンテキストでなければなりません。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                        ユーザーの質問に答えるため、知りたい天気情報の都市名を返さなければならない。
                        例:
                        【質問】:「東京都のお天気情報を教えてください。」
                        【回答】:東京都
                    """,
                },
            },
            "required": ["query"],
        },
    },

]

 

DB問合せ用のfunction名「ask_database」を以下の様に定義しています。

ask_databaseの定義
  • name:ask_database
  • description:「この関数を使用して、音楽に関するユーザーの質問に答えます。出力は完全に形成されたSQLクエリでなければなりません。」
  • properties:stringタイプ、description:「ユーザーの質問に答えるための情報を抽出するSQLクエリ。
    SQLは、以下のデータベーススキーマを使って書かなければならない:
    データベーススキーマ:{database_schema_string}
    クエリーはJSONではなく、プレーンテキストで返す必要があります。」

ポイントはpropertiesのdescriptionに「データベーススキーマ:{database_schema_string}」を与えることで、より正確なSQL文をChatGPTに生成させるようにしている点です。

DBスキーマ情報を取得する

 

まず、DBスキーマ情報を取得するコードを追加します。

import openai
import os
import requests
import sqlite3
import json

#SqliteDB接続
conn = sqlite3.connect("chinook.db") 
print("Opned database successfully")

# chinook.dbからテーブル情報を取得
## dbから情報を収集する関数
def get_table_names(conn):
    """テーブル名のリストを返す"""
    table_names = []
    tables = conn.execute("select name from sqlite_master where type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names

def get_column_names(conn, table_name):
    """テーブルのカラム名のリストを返す"""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names

def get_database_info(conn):
    """テーブル名とカラムの情報を辞書のリストで返す"""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts

database_schema_dict = get_database_info(conn)

database_schema_string = "\n".join(
    f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
    for table in database_schema_dict
)

print(database_schema_string)

 

上記コードを実行するとdatabase_schema_stringに以下のようなDBスキーマの情報が格納されます。

Table: albums
Columns: AlbumId, Title, ArtistId
Table: sqlite_sequence
Columns: name, seq
Table: artists
Columns: ArtistId, Name
Table: customers
Columns: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId
Table: employees
Columns: EmployeeId, LastName, FirstName, Title, ReportsTo, BirthDate, HireDate, Address, City, State, Country, PostalCode, Phone, Fax, Email
Table: genres
Columns: GenreId, Name
Table: invoices
Columns: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity, BillingState, BillingCountry, BillingPostalCode, Total
Table: invoice_items
Columns: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity
Table: media_types
Columns: MediaTypeId, Name
Table: playlists
Columns: PlaylistId, Name
Table: playlist_track
Columns: PlaylistId, TrackId
Table: tracks
Columns: TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, UnitPrice
Table: sqlite_stat1
Columns: tbl, idx, stat

 

会話情報を格納するクラスの定義

会話情報を格納する簡易的なクラスを以下の通り定義しておきます。

class Conversation:
    
    # 会話履歴を空のリストで初期化
    def __init__(self):
        self.conversation_history = []
        #self.add_message("system", agent_system_message)
        
    # 会話履歴にメッセージを追加するメソッド    
    def add_message(self, role, content):
        message = {"role": role, "content":content}
        self.conversation_history.append(message)

 

ChatGPTリクエスト関数の定義

次にChatGPT APIをCallする関数を定義します。
モデルには「gpt-3.5-turbo-0613」を指定しています。

 

def chat_completion_request(messages, functions, model="gpt-3.5-turbo-0613"):
    openai.api_key = "自分のKeyを設定"
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            functions= functions,
            function_call="auto",
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response1")
        print(f"Exception:{e}")
        return e
    
    
def ask_database(conn, query):
    """
    指定されたSQLクエリでSQliteデータベースに問い合わせを行う関数。
    
    Return:
    conn(sqlite3.Connection):SQLiteデータベースへの接続オブジェクト。
    query(str):データベースに対して実行するSQLクエリを含む文字列。
    
    Return:
    list:  クエリの結果を含むタプルのリスト。
    
    Raises:
    Exception:SQLクエリーの実行に問題があった場合。
    """
    try:
        # 与えられたSQLiteデータベース接続オブジェクトに対してSQLクエリを実行し、すべての結果を取得
        results = conn.execute(query).fetchall()
        return results
    except Exception as e:
        # SQLクエリの実行に失敗した場合、エラーメッセージとともに例外を発生
        raise Exception(f"SQL error: {e}")



def chat_completion_with_function_execution(messages, functions):
   """
   この関数はChatCompletion APIコールを行い、関数コールが要求された場合、その関数を実行する。
   
   Parameters:
   messages(list): 会話履歴を表す文字列のリスト
   functions(list): モデルから呼び出すことができる関数を表す辞書のリスト(オプション)
   
   Return:
   dict: ChatCompletion API呼び出しによるレスポンス、または関数呼び出しの結果を含む辞書。
   """ 
   
   try:
       response = chat_completion_request(messages, functions)
       #full_message = response.json()["choices"][0]
       full_message = response["choices"][0]
       print("full_message=", full_message)
       if full_message["finish_reason"] == "function_call":
           print(f"Function generation requested, calling function")
           return call_function(messages, full_message,functions)
       else:
           print(f"function not required, responding to user")
           #return response.json()
           return response
    
   except Exception as e:
       print("Unable to generate ChatCompletion response2")
       print(f"Exception:{e}")
       return e



def call_function(messages, full_message,functions):
    """
    excecutes function calls using model generated function arguments.
    """
    if full_message["message"]["function_call"]["name"] == "ask_database":
        query = eval(full_message["message"]["function_call"]["arguments"])
        print(f"Prepped query is {query}")
        print("query['query']=", query["query"])
        try:
            conn = sqlite3.connect("../db/chinook.db") 
            results = ask_database(conn, query["query"])
        except Exception as e:
            print(e)
            
            # if there is an error in the query, try to fix it with a subsequent call
            messages.append(
                {
                    "role": "system",
                    "content": f"""Query: {query['query']}
                    The previous query received the error {e}.
                    Please return a fixed SQL query in plain text.
                    Your response should consist of ONLY the SQL query with the separator sql_start at the beginning and sql_end at the end""",
                }
            )
            response = chat_completion_request(messages, functions)
            try:
                cleaned_query = response["choices"][0]["message"]["content"].split("sql_start")[1]
                cleaned_query = cleaned_query.split("sql_end")[0]
                print(cleaned_query)
                results = ask_database(conn, cleaned_query)
                print(results)
                print("Got on second try")
            except Exception as e:
                print("Second failure, exiting")
                print(f"Function execution failed")
                print(f"Error message: {e}")
                
        messages.append(
            {"role": "function", "name": "ask_database", "content": str(results)}
        )
        try:
            response = chat_completion_request(messages, functions)
            return response
        except Exception as e:
            print(type(e))
            print(e)
            raise Exception("Function chat request failed")

    else:
        raise Exception("Function does not exist and cannot be called")

 

functionsは以下の様に定義しておきます。

# データベーススキーマを関数に挿入していることに注意。
# これはモデルが知るべき重要な情報 
functions=[
    {
        "name": "ask_database",
        "description": "この関数を使用して、音楽に関するユーザーの質問に答えます。出力は完全に形成されたSQLクエリでなければなりません。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                        ユーザーの質問に答えるための情報を抽出するSQLクエリ。
                        SQLは、以下のデータベーススキーマを使って書かなければならない:
                        データベーススキーマ:{database_schema_string}
                        クエリーはJSONではなく、プレーンテキストで返す必要があります。
                    """,
                },
            },
            "required": ["query"],
        },
    }

]

 

クエリの実行(自然言語でDB問合せ)

試しに「2009年のInvoiceデータは何件ありますか?」というQueryを投げてみます。

処理フローとしては以下の様になります。

処理フロー
  1. 自然言語で問合せ(2009年のInvoiceデータは何件ありますか?)
  2. LLMモデル(gpt-3.5-turbo-0613)から「ask_database関数の実行が必要だよ」というレスポンスが返ってくる。
    ※この際にfunction_callのargumentsにask_database関数に渡す引数(実行するSQL文)が返ってきます。
  3. ask_database関数の引数に2のargmentsに返された値(SQL文)を渡しSQliteDBに対してクエリを実行します。
  4. 3の結果を再度ChatGPTモデルに投げて最終的な回答を生成する。

 

実行結果は以下の様になります。

sql_conversation.add_message("user", "2009年のInvoiceデータは何件ありますか?") 
chat_response = chat_completion_with_function_execution(sql_conversation.conversation_history, functions)  
try:
    assistant_message = chat_response["choices"][0]["message"]["content"]
    print(assistant_message)
except Exception as e:
    print(e)
    print(chat_response)

 

2009年のInvoiceデータは83件あります。

 

chat_responseには以下のようなレスポンスが返ってきます。

full_message= {
  "index": 0,
  "message": {
    "role": "assistant",
    "content": null,
    "function_call": {
      "name": "ask_database",
      "arguments": "{\n  \"query\": \"SELECT COUNT(*) FROM invoices WHERE InvoiceDate LIKE '2009%'\"\n}"
    }
  },
  "finish_reason": "function_call"
}
Function generation requested, calling function
Prepped query is {'query': "SELECT COUNT(*) FROM invoices WHERE InvoiceDate LIKE '2009%'"}
query['query']= SELECT COUNT(*) FROM invoices WHERE InvoiceDate LIKE '2009%'

 

以下のSQL文が自動生成されて実行結果が返されていることがわかります。

SELECT COUNT(*) FROM invoices WHERE InvoiceDate LIKE '2009%'

 

もうちょっと複雑なSQL文が生成されるような質問をしてみます。

sql_conversation = Conversation()
sql_conversation.add_message("user", "売上が一番多いTOP3の商品情報を教えてください。") 
chat_response = chat_completion_with_function_execution(sql_conversation.conversation_history, functions)
try:
    assistant_message = chat_response["choices"][0]["message"]["content"]
    print(assistant_message)
except Exception as e:
    print(e)
    print(chat_response)

 

実行結果は以下の様になりました。

full_message= {
  "index": 0,
  "message": {
    "role": "assistant",
    "content": null,
    "function_call": {
      "name": "ask_database",
      "arguments": "{\n  \"query\": \"SELECT tracks.Name AS TrackName, albums.Title AS AlbumTitle, artists.Name AS ArtistName, tracks.UnitPrice, SUM(invoice_items.Quantity) AS TotalQuantity, SUM(invoice_items.Quantity * invoice_items.UnitPrice) AS TotalRevenue FROM tracks JOIN albums ON tracks.AlbumId = albums.AlbumId JOIN artists ON albums.ArtistId = artists.ArtistId JOIN invoice_items ON tracks.TrackId = invoice_items.TrackId GROUP BY tracks.TrackId ORDER BY TotalRevenue DESC LIMIT 3;\"\n}"
    }
  },
  "finish_reason": "function_call"
}
Function generation requested, calling function
Prepped query is {'query': 'SELECT tracks.Name AS TrackName, albums.Title AS AlbumTitle, artists.Name AS ArtistName, tracks.UnitPrice, SUM(invoice_items.Quantity) AS TotalQuantity, SUM(invoice_items.Quantity * invoice_items.UnitPrice) AS TotalRevenue FROM tracks JOIN albums ON tracks.AlbumId = albums.AlbumId JOIN artists ON albums.ArtistId = artists.ArtistId JOIN invoice_items ON tracks.TrackId = invoice_items.TrackId GROUP BY tracks.TrackId ORDER BY TotalRevenue DESC LIMIT 3;'}
query['query']= SELECT tracks.Name AS TrackName, albums.Title AS AlbumTitle, artists.Name AS ArtistName, tracks.UnitPrice, SUM(invoice_items.Quantity) AS TotalQuantity, SUM(invoice_items.Quantity * invoice_items.UnitPrice) AS TotalRevenue FROM tracks JOIN albums ON tracks.AlbumId = albums.AlbumId JOIN artists ON albums.ArtistId = artists.ArtistId JOIN invoice_items ON tracks.TrackId = invoice_items.TrackId GROUP BY tracks.TrackId ORDER BY TotalRevenue DESC LIMIT 3;
売上が一番多いTOP3の商品情報は以下の通りです:

1. 商品名: The Woman King
   アルバム: Battlestar Galactica, Season 3
   アーティスト: Battlestar Galactica
   単価: 1.99
   数量: 2
   売上: 3.98

2. 商品名: The Fix
   アルバム: Heroes, Season 1
   アーティスト: Heroes
   単価: 1.99
   数量: 2
   売上: 3.98

3. 商品名: Walkabout
   アルバム: Lost, Season 1
   アーティスト: Lost
   単価: 1.99
   数量: 2
   売上: 3.98

以上が売上が一番多いTOP3の商品情報です。

 

functionsで定義するpropertiesdescriptionを以下の様に定義しておくことでChatGPTがSQLを生成する際にデータベーススキーマの情報を文脈として与えることができ、精度が高いSQL文の自動生成が可能になるようですね。

「ユーザーの質問に答えるための情報を抽出するSQLクエリ。
SQLは、以下のデータベーススキーマを使って書かなければならない
データベーススキーマ:{database_schema_string}」

もちろん、不適切なSQLが生成される可能性は0ではないため、自動的にSQLを修正させるような処理を追加しています。
それでもダメな場合は処理をエラーにするといった制御は必要だと思います。

また、実際に業務活用することを想定するならレコード数の上限や実行するクエリの種類の制御なども必要になると思います。

お天気情報Functionの実装

次は、Openweathermap のAPIをCALLして都道府県の天気情報(天気、最低気温、最高気温)を返すfunctionを追加します。

functionsの追加

まず、functionsask_weatherを以下の通り追加します。

# データベーススキーマを関数に挿入していることに注意。
# これはモデルが知るべき重要な情報 
functions=[
    {
        "name": "ask_database",
        "description": "この関数を使用して、音楽に関するユーザーの質問に答えます。出力は完全に形成されたSQLクエリでなければなりません。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                        ユーザーの質問に答えるための情報を抽出するSQLクエリ。
                        SQLは、以下のデータベーススキーマを使って書かなければならない:
                        データベーススキーマ:{database_schema_string}
                        クエリーはJSONではなく、プレーンテキストで返す必要があります。
                    """,
                },
            },
            "required": ["query"],
        },
    },

    {
        "name": "ask_weather",
        "description": "この関数を使用して特定の都市の天気情報の質問に答えます。都市名のプレーンテキストでなければなりません。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                        ユーザーの質問に答えるため、知りたい天気情報の都市名を返さなければならない。
                        例:
                        【質問】:「東京都のお天気情報を教えてください。」
                        【回答】:東京都
                    """,
                },
            },
            "required": ["query"],
        },
    },

]

 

APIコール関数の追加

OpenWeatherMapのAPIをコールするask_weather関数を追加します。
ask_weather関数の引数にはChatGPTから返されたargumentsの値を渡すようにします。

def ask_weather(query):
    try:
        response = requests.get("https://api.openweathermap.org/data/2.5/weather",
        params={
            "q": query,

            "appid": "自分のトークン情報をここに記載",
            "units": "metric",
            "lang": "ja",
               }
             )
        response = json.loads(response.text)
        description = response['weather'][0]['description']
        temp_min = response['main']['temp_min']
        temp_max = response['main']['temp_max']
        return "{}の天気は{}です。最低気温は{}℃で、最大気温は{}℃です。".format(query,description,temp_min,temp_max)
      
    except Exception as e:
        # API実行に失敗した場合、エラーメッセージとともに例外を発生
        raise Exception(f"Wheather API error: {e}")

 

call_function関数の修正

call_function関数内でfunction_callnameask_weatherだった場合の分岐処理を追加します。

 

def call_function(messages, full_message,functions):
    """
    excecutes function calls using model generated function arguments.
    """
    if full_message["message"]["function_call"]["name"] == "ask_database":
        query = eval(full_message["message"]["function_call"]["arguments"])
        print(f"Prepped query is {query}")
        print("query['query']=", query["query"])
        try:
            conn = sqlite3.connect("../db/chinook.db") 
            results = ask_database(conn, query["query"])
        except Exception as e:
            print(e)
            
            # if there is an error in the query, try to fix it with a subsequent call
            messages.append(
                {
                    "role": "system",
                    "content": f"""Query: {query['query']}
                    The previous query received the error {e}.
                    Please return a fixed SQL query in plain text.
                    Your response should consist of ONLY the SQL query with the separator sql_start at the beginning and sql_end at the end""",
                }
            )
            response = chat_completion_request(messages, functions)
            try:
                cleaned_query = response["choices"][0]["message"]["content"].split("sql_start")[1]
                cleaned_query = cleaned_query.split("sql_end")[0]
                print(cleaned_query)
                results = ask_database(conn, cleaned_query)
                print(results)
                print("Got on second try")
            except Exception as e:
                print("Second failure, exiting")
                print(f"Function execution failed")
                print(f"Error message: {e}")
                
        messages.append(
            {"role": "function", "name": "ask_database", "content": str(results)}
        )
        try:
            response = chat_completion_request(messages, functions)
            return response
        except Exception as e:
            print(type(e))
            print(e)
            raise Exception("Function chat request failed")

    #ここから下を追加
    elif full_message["message"]["function_call"]["name"] == "ask_weather":

        query = eval(full_message["message"]["function_call"]["arguments"])

        print("query['query']=", query["query"])
        try:
            results = ask_weather(query["query"])
        except Exception as e:
            print(e)
            
        messages.append(
            {"role": "function", "name": "ask_weather", "content": str(results)}
        )
        try:
            response = chat_completion_request(messages, functions)
            return response
        except Exception as e:
            print(type(e))
            print(e)
            raise Exception("Function chat request failed")        

    #ここまでを追加

    else:
        raise Exception("Function does not exist and cannot be called")

 

以上で実装は完了です。

実際にお天気情報について質問してみましょう。

sql_conversation = Conversation() 
sql_conversation.add_message("user", "東京都の今日の天気を教えてください。") 
chat_response = chat_completion_with_function_execution(sql_conversation.conversation_history, functions)
try:
    assistant_message = chat_response["choices"][0]["message"]["content"]
    print(assistant_message)
except Exception as e:
    print(e)
    print(chat_response)

 

実行結果は以下の様になります。

full_message= {
  "index": 0,
  "message": {
    "role": "assistant",
    "content": null,
    "function_call": {
      "name": "ask_weather",
      "arguments": "{\n  \"query\": \"\u6771\u4eac\u90fd\"\n}"
    }
  },
  "finish_reason": "function_call"
}
Function generation requested, calling function
query['query']= 東京都
東京都の今日の天気は雲です。最低気温は25.79℃で、最高気温は29.88℃です。

 

OpenweathermapのAPIをコールしてちゃんと結果が返されてますね。

おまけ

今回のfunction機能を活用したデモアプリも作成してみました。
以下にリポジトリがありますので気になる方はぜひ試してみてください。

 

 

 

 

おすすめの記事