zhenyundeng
add files
e62781a
raw
history blame contribute delete
No virus
6.01 kB
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Model architecture/optimization options for DrQA document reader."""
import argparse
import logging
logger = logging.getLogger(__name__)

# Option names (argparse ``dest`` form, e.g. ``--model-type`` -> ``model_type``)
# describing the core model architecture. get_model_args keeps these keys, and
# override_model_args never replaces their saved values.
MODEL_ARCHITECTURE = {
    'model_type',
    'embedding_dim',
    'hidden_size',
    'doc_layers',
    'question_layers',
    'rnn_type',
    'concat_rnn_layers',
    'question_merge',
    # Feature toggles.
    'use_qemb',
    'use_in_question',
    'use_pos',
    'use_ner',
    'use_lemma',
    'use_tf',
}

# Option names concerning the optimizer / training procedure. get_model_args
# keeps these keys, and override_model_args lets newly supplied values replace
# the saved ones for exactly this set.
MODEL_OPTIMIZER = {
    'fix_embeddings',
    'optimizer',
    'learning_rate',
    'momentum',
    'weight_decay',
    'rnn_padding',
    'dropout_rnn',
    'dropout_rnn_output',
    'dropout_emb',
    'max_len',
    'grad_clipping',
    'tune_partial',
}
def str2bool(v):
    """Convert a command-line string to a bool (used as the argparse 'bool' type).

    Recognized values (case-insensitive, surrounding whitespace ignored):
        true:  'yes', 'true', 't', '1', 'y'
        false: 'no', 'false', 'f', '0', 'n'
    A value that is already a ``bool`` is returned unchanged.

    Args:
        v: the raw option value (normally a ``str`` supplied by argparse).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: for any other value. argparse turns this
            into a usage error, so a typo such as ``ture`` is reported instead
            of silently meaning False.
    """
    if isinstance(v, bool):
        return v
    value = v.strip().lower()
    if value in ('yes', 'true', 't', '1', 'y'):
        return True
    if value in ('no', 'false', 'f', '0', 'n'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (v,))
def add_model_args(parser):
    """Register the DrQA reader's model and optimization options on *parser*.

    Registers ``str2bool`` as the converter for ``type='bool'``. It then adds
    three argument groups: model architecture, model details and
    optimization.

    Args:
        parser: an ``argparse.ArgumentParser``, extended in place.
    """
    parser.register('type', 'bool', str2bool)

    # (group title, [(flag, type, default, help), ...]). List order fixes the
    # order in which groups and options are registered (and shown in --help).
    option_groups = [
        ('DrQA Reader Model Architecture', [
            ('--model-type', str, 'rnn',
             'Model architecture type'),
            ('--embedding-dim', int, 300,
             'Embedding size if embedding_file is not given'),
            ('--hidden-size', int, 128,
             'Hidden size of RNN units'),
            ('--doc-layers', int, 3,
             'Number of encoding layers for document'),
            ('--question-layers', int, 3,
             'Number of encoding layers for question'),
            ('--rnn-type', str, 'lstm',
             'RNN type: LSTM, GRU, or RNN'),
        ]),
        ('DrQA Reader Model Details', [
            ('--concat-rnn-layers', 'bool', True,
             'Combine hidden states from each encoding layer'),
            ('--question-merge', str, 'self_attn',
             'The way of computing the question representation'),
            ('--use-qemb', 'bool', True,
             'Whether to use weighted question embeddings'),
            ('--use-in-question', 'bool', True,
             'Whether to use in_question_* features'),
            ('--use-pos', 'bool', True,
             'Whether to use pos features'),
            ('--use-ner', 'bool', True,
             'Whether to use ner features'),
            ('--use-lemma', 'bool', True,
             'Whether to use lemma features'),
            ('--use-tf', 'bool', True,
             'Whether to use term frequency features'),
        ]),
        ('DrQA Reader Optimization', [
            ('--dropout-emb', float, 0.4,
             'Dropout rate for word embeddings'),
            ('--dropout-rnn', float, 0.4,
             'Dropout rate for RNN states'),
            ('--dropout-rnn-output', 'bool', True,
             'Whether to dropout the RNN output'),
            ('--optimizer', str, 'adamax',
             'Optimizer: sgd or adamax'),
            ('--learning-rate', float, 0.1,
             'Learning rate for SGD only'),
            ('--grad-clipping', float, 10,
             'Gradient clipping'),
            ('--weight-decay', float, 0,
             'Weight decay factor'),
            ('--momentum', float, 0,
             'Momentum factor'),
            ('--fix-embeddings', 'bool', True,
             'Keep word embeddings fixed (use pretrained)'),
            ('--tune-partial', int, 0,
             'Backprop through only the top N question words'),
            ('--rnn-padding', 'bool', False,
             'Explicitly account for padding in RNN encoding'),
            ('--max-len', int, 15,
             'The max span allowed during decoding'),
        ]),
    ]

    for title, options in option_groups:
        group = parser.add_argument_group(title)
        for flag, arg_type, default, help_text in options:
            group.add_argument(flag, type=arg_type, default=default,
                               help=help_text)
def get_model_args(args):
    """Filter args for model ones.

    From an args Namespace, return a new Namespace with *only* the args
    specific to the model architecture or optimization, i.e. those whose names
    appear in MODEL_ARCHITECTURE or MODEL_OPTIMIZER. Names listed there but
    absent from *args* are simply omitted.

    Args:
        args: an ``argparse.Namespace``.

    Returns:
        A new ``argparse.Namespace``; *args* itself is not modified.
    """
    # Module constants are only read here, so no ``global`` declaration is needed.
    required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER
    arg_values = {k: v for k, v in vars(args).items() if k in required_args}
    return argparse.Namespace(**arg_values)
def override_model_args(old_args, new_args):
    """Set args to new parameters.

    Decide which model args to keep and which to override when resolving a
    set of saved args and new args. Values for names in MODEL_OPTIMIZER are
    taken from *new_args*; every other differing value keeps the saved one.
    We keep the new optimization, but leave the model architecture alone.

    Only names already present in *old_args* are considered; names that exist
    only in *new_args* are ignored. Each differing value is logged at INFO
    level as either overridden or kept.

    Args:
        old_args: ``argparse.Namespace`` of saved args.
        new_args: ``argparse.Namespace`` of newly supplied args.

    Returns:
        A new ``argparse.Namespace`` with the resolved values. Neither input
        Namespace is modified.
    """
    saved = vars(old_args)
    new_values = vars(new_args)
    # vars() returns the Namespace's own __dict__; copy it before writing so
    # the caller's saved Namespace is not mutated in place.
    merged = dict(saved)
    for key, old_value in saved.items():
        if key in new_values and old_value != new_values[key]:
            if key in MODEL_OPTIMIZER:
                logger.info('Overriding saved %s: %s --> %s',
                            key, old_value, new_values[key])
                merged[key] = new_values[key]
            else:
                logger.info('Keeping saved %s: %s', key, old_value)
    return argparse.Namespace(**merged)