目次
こんにちは。sinyです。
この記事ではPytorchで自薦言語処理を行う場合にとても便利なtorchtextの使い方について学習を兼ねて情報をまとめています。
継続してアップデートしていきます。
torchtextとは?
torchtextとはPytorchでテキストデータを扱うためのパッケージです。
torchtextと使うとテキストデータの前処理として行う単語、インデックス辞書の作成や単語語録等を少ないコーディングで非常に簡単に行うことができるので大変便利です。
テキストデータの読み込み
torchtextを使ってテキストデータを読み込む方法を簡単にまとめました。
- torchtextではCSV、TSV、JSON形式のデータを読み込むことができます。
- 読み込むデータのカラムをtorchtext.data.Field()を利用して定義します。
- 読み込んだデータに対して前処理したい場合はtorchtext.data.Field()の引数で各種設定をします。
→tokenizeに個別に定義した前処理の関数を指定すれば個別の前処理を組み込むことができる。 - torchtext.data.TabularDataset.splitsを利用して学習用、訓練用(テスト用)のDatasetを生成することができます。
torchtext.data.Field()の引数には以下のパラメータがあります。
- sequential: データの長さが可変か?(True or False)
- tokenize: 文章を読み込んだときに、前処理や単語分割をするための関数を定義
- use_vocab:単語をボキャブラリーに追加するかどうか
- lower:アルファベットがあったときに小文字に変換するかどうか
- include_length:文章の単語数のデータを保持するか
- batch_first:ミニバッチの次元を先頭に用意するかどうか
- fix_length:全部の文章を指定した長さと同じになるようにpaddingする
sample_train.tsvに以下のようなタブ区切りのテキスト情報が格納されているデータを読み込む手順を示します。
ディープラーニングの基礎について学ぶ。 0
異常検知の手法としてオートエンコーダがよく利用されます。 1
自然言語処理にtorchtextを利用すると便利です。 2
画像認識の手法の1つとしてGANがあります。 3
tsvを読み込む場合の例
※今回はsample_train.tsvとsample_val.tsvは同じ情報が格納されているものとします。
#文章読み込みテスト import torchtext TEXT = torchtext.data.Field() LABEL = torchtext.data.Field() train_ds, test_ds = torchtext.data.TabularDataset.splits( path='./data/', train='sample_train.tsv', validation='sample_val.tsv', format='tsv', fields=[('Text', TEXT), ('Label', LABEL)])
一番簡素な例として、以下のようにtorchtext.data.Field()を使って読み込むデータのカラムを定義します。
TEXT = torchtext.data.Field() LABEL = torchtext.data.Field()
※各種オプションを指定して前処理を追加したい場合はFIled()内にオプションを指定するだけです。
今回は、文章の部分「ディープラーニングの基礎について学ぶ。」をTEXT、タブ区切りの2つ目の番号部分「0」をLABELとして定義しています。
次にtorchtext.data.TabularDataset.splitsを利用して、学習用と訓練用データセットを生成しています。
torchtext.data.TabularDataset.splitsにもいくつか指定できるパラメータがあります。
- path:読み込むデータファイルのルートパスを指定
- train:学習データファイルを指定
- validation:検証データファイルを指定
- test:テストファイルを指定
- fields:読み込むデータのフィールドをタプルのリスト形式で指定
今回は、学習データと検証データの2つだけ指定しています。
また、読み込むデータはテキスト部分とラベル情報(数字)の2つのフィールドを持っているため、fieldsオプションで以下のように設定します。
fields=[('Text', TEXT), ('Label', LABEL)])
読み込んだデータを表示させてみます。
print('訓練データの数', len(train_ds)) print('1つ目の訓練データ', vars(train_ds[0])) print(vars(train_ds[0])['Text']) print(vars(train_ds[0])['Label'])
出力結果
訓練データの数 4
1つ目の訓練データ {'Text': ['ディープラーニングの基礎について学ぶ。'], 'Label': ['0']}
['ディープラーニングの基礎について学ぶ。']
['0']
※出力結果を表示する場合はvarsメソッドを使い引数の辞書を返してあげます。
csvを読み込む場合の例
別の例としてcsvファイルを読み込む例です。
今度は以下のような3つのフィールドが存在するデータファイルとします。
ディープラーニングの基礎について学ぶ。,0,基礎
異常検知の手法としてオートエンコーダがよく利用されます,1,異常検知
自然言語処理にtorchtextを利用すると便利です。,2,NLP
画像認識の手法の1つとしてGANがあります,3,画像分類
csvを読み込む場合はformat='csvと指定するだけです。
#文章読み込みテスト import torchtext TEXT = torchtext.data.Field() LABEL = torchtext.data.Field() LABEL2 = torchtext.data.Field() train_ds, test_ds = torchtext.data.TabularDataset.splits( path='./data/', train='sample_train.csv', validation='sample_val.csv', format='csv', fields=[('Text', TEXT), ('Label', LABEL), ('Label2', LABEL2)])
今度はフィールドが3つあるので、torchtext.data.Field()を使って3つのフィールド(TEXT,LABEL,LABEL2)を定義します。
あとはtorchtext.data.TabularDataset.splitsを使って学習用、検証用データとしてデータセットを生成します。
今回はフィールドが3つあるので以下の通りfiledsオプションで先ほど定義した3つのフィールドを指定します。
fields=[('Text', TEXT), ('Label', LABEL), ('Label2', LABEL2)])
出力結果を確認してみます。
print('訓練データの数', len(train_ds)) print('1つ目の訓練データ', vars(train_ds[0])) print(vars(train_ds[0])['Text']) print(vars(train_ds[0])['Label']) print(vars(train_ds[0])['Label2'])
出力結果
訓練データの数 4
1つ目の訓練データ {'Text': ['ディープラーニングの基礎について学ぶ。'], 'Label': ['0'], 'Label2': ['基礎']}
['ディープラーニングの基礎について学ぶ。']
['0'] ['基礎']
csvデータをutf-8形式で保存していないと読み込み時に以下のエラーが発生するので注意しましょう。
UnicodeDecodeError: 'utf8' codec can't decode byte
torchtextにはほかにも便利機能がありますので、本記事は、継続して更新していきます。
torchtext.build_vocab
テキストデータを元に単語辞書データを生成するメソッドが用意されています。
- build_vocab(data, min_freq=1)
※dataをソースとして単語辞書データを生成してくれる。
※min_freqに指定した数以上の頻度の単語をターゲットとする。
torchtext.vocab
torchtext.vocabには以下の3つの機能があります。
- freqs –Vocab内の単語と頻度をcollections.Counterオブジェクトとして生成してくれます。
Counter({'王': 1, 'と': 5, '王子': 1, '女王': 1, '姫': 1, '男性': 1, '女性': 1})
- stoi –トークン文字列をインデックス(数値)にマッピングするcollections.defaultdictインスタンスを生成してくれる。
※unk(未知語)とpad(パディング)は自動的に0,1として割り振られる。
defaultdict(<bound method Vocab._default_unk_index of <torchtext.vocab.Vocab object at 0x7fc91f94de80>>, {'<unk>': 0, '<pad>': 1, 'と': 2, '。': 3, 'な': 4, 'の': 5, '文章': 6, '、': 7, 'が': 8)
- itos –数値識別子でインデックス付けされたトークン文字列のリスト。
※stoiの逆で数値を単語にマッピングしてくれる。