Python model module: train() example source code

We extracted the following 12 code examples from open-source Python projects to illustrate how model.train() is used. Note that the call appears in two senses below: in the PyTorch snippets, model.train() is nn.Module.train(), which switches the model into training mode; in the TensorFlow snippet from facial-emotion-detection-dl, model.train(loss, global_step) is a project-defined function that builds the training op.
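
In PyTorch, model.train() simply switches an nn.Module (and all of its submodules) into training mode, which changes the behaviour of layers such as dropout and batch normalization; model.eval() switches back for evaluation. A minimal self-contained sketch of that switch (the TinyNet class here is hypothetical, not taken from any of the listed projects):

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Linear(10, 10)
        self.drop = nn.Dropout(p=0.5)   # active only in training mode

    def forward(self, x):
        return self.drop(self.fc(x))

model = TinyNet()
model.train()                           # enable dropout for training
train_out = model(torch.randn(4, 10))
model.eval()                            # disable dropout for evaluation
with torch.no_grad():
    eval_out = model(torch.randn(4, 10))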

Project: Saliency_Detection_Convolutional_Autoencoder    Author: arthurmeyer

def save_model(model, sess, log_path, step):
  """
  Save model using tensorflow checkpoint (also save hidden variables)

  Args:
    model            :         model to save variable from
    sess             :         tensorflow session
    log_path         :         where to save
    step             :         number of step at time of saving
  """

  path = log_path + '/' + model.name
  if tf.gfile.Exists(path):
    tf.gfile.DeleteRecursively(path)
  tf.gfile.MakeDirs(path)
  saver = tf.train.Saver()
  checkpoint_path = os.path.join(path, 'model.ckpt')
  saver.save(sess, checkpoint_path, global_step=step)
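
A minimal usage sketch for the helper above; the net object (which must expose a .name attribute, as save_model expects) and the './logs' directory are placeholders for illustration:

# Hypothetical usage of save_model; 'net' and './logs' are placeholders.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... run training ops here ...
    save_model(net, sess, './logs', step=1000)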

Project: Saliency_Detection_Convolutional_Autoencoder    Author: arthurmeyer

def restore_model(model, sess, log_path):
  """
  Restore model (including hidden variable)
  In practice use to resume the training of the same model

  Args
    model            :         model to restore variable to
    sess             :         tensorflow session
    log_path         :         where to save

  Returns:
    step_b           :         the step number at which training ended
  """

  path = log_path + '/' + model.name  
  saver = tf.train.Saver()
  ckpt = tf.train.get_checkpoint_state(path)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    return ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  else:
    print('------------------------------------------------------')
    print('No checkpoint file found')
    print('------------------------------------------------------ \n')
    exit()
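
restore_model is the counterpart used to resume training: it reloads the latest checkpoint written by save_model and returns the step at which it was saved. A sketch under the same placeholder assumptions (net, './logs'):

# Hypothetical usage of restore_model to resume an interrupted run.
with tf.Session() as sess:
    last_step = int(restore_model(net, sess, './logs'))
    for step in range(last_step + 1, last_step + 1001):
        pass  # ... continue running training ops from where the run stopped ...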

Project: Saliency_Detection_Convolutional_Autoencoder    Author: arthurmeyer

def save_weight_only(model, sess, log_path, step):
  """
  Save model but only weight (meaning no hidden variable)
  In practice use this to just transfer weights from one model to the other

  Args:
    model            :         model to save variable from
    sess             :         tensorflow session
    log_path         :         where to save
    step             :         number of step at time of saving
  """

  path = log_path + '/' + model.name + '_weight_only'
  if tf.gfile.Exists(path):
    tf.gfile.DeleteRecursively(path)
  tf.gfile.MakeDirs(path)

  variable_to_save = {}
  for i in range(30):
    name = 'conv_' + str(i)
    variable_to_save[name] = model.parameters_conv[i]
    if i in [2, 4] and model.concat:
      name = 'deconv_' + str(i)
      variable_to_save[name] = model.parameters_deconv[i][0]
      name = 'deconv_' + str(i) + '_bis'
      variable_to_save[name] = model.parameters_deconv[i][1]
    else:
      name = 'deconv_' + str(i)
      variable_to_save[name] = model.parameters_deconv[i]
    if i < 2:
      name = 'deconv_bis_' + str(i)
      variable_to_save[name] = model.deconv[i]
  saver = tf.train.Saver(variable_to_save)
  checkpoint_path = os.path.join(path, 'model.ckpt')
  saver.save(sess, checkpoint_path, global_step=step)
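
The key technique in save_weight_only is passing an explicit name-to-variable dictionary to tf.train.Saver, so the names stored in the checkpoint are decoupled from the variables' graph names and can later be mapped onto a different model. A small self-contained sketch of that pattern (the variables and names here are illustrative only):

import tensorflow as tf

w = tf.Variable(tf.zeros([3, 3, 16, 16]), name='encoder/conv_0/weights')
b = tf.Variable(tf.zeros([16]), name='encoder/conv_0/bias')

# The dictionary keys become the checkpoint names, independent of the
# variables' own names in the graph.
saver = tf.train.Saver({'conv_0': w, 'conv_0_bias': b})

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.gfile.MakeDirs('./transfer')
    saver.save(sess, './transfer/model.ckpt', global_step=0)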

Project: Tree-LSTM-LM    Author: vgene

def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to the start of the dataset.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.data

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, lr,
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

# Loop over epochs.
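
The trailing comment refers to the driver loop that calls train() once per epoch. In this style of language-model script it typically looks roughly like the sketch below; args.epochs, args.save, evaluate(), val_data and the learning-rate annealing rule are assumptions for illustration, not part of the snippet above:

# Hypothetical epoch loop; names follow the conventions of the snippet above.
best_val_loss = None
for epoch in range(1, args.epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(val_data)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
        epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss)))
    if not best_val_loss or val_loss < best_val_loss:
        with open(args.save, 'wb') as f:
            torch.save(model, f)
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement was seen on the validation set.
        lr /= 4.0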

Project: facial-emotion-detection-dl    Author: dllatas

def train():
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)
        image, label = input.get_input(LABEL_PATH, LABEL_FORMAT, IMAGE_PATH, IMAGE_FORMAT)
        logits = model.inference(image)
        loss = model.loss(logits, label)
        train_op = model.train(loss, global_step)
        saver = tf.train.Saver(tf.all_variables())
        summary_op = tf.merge_all_summaries()
        init = tf.initialize_all_variables()
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=input.FLAGS.log_device_placement))
        sess.run(init)
        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)
        summary_writer = tf.train.SummaryWriter(input.FLAGS.train_dir, graph_def=sess.graph_def)
        for step in xrange(input.FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            if step % 1 == 0:
                num_examples_per_step = input.FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)')
                print (format_str % (datetime.now(), step, loss_value, examples_per_sec, sec_per_batch))
            if step % 10 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)
            # Save the model checkpoint periodically.
            if step % 25 == 0:
                checkpoint_path = os.path.join(input.FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

Project: facial-emotion-detection-dl    Author: dllatas

def main(argv=None):
    train()
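
Here model.train(loss, global_step) is a project-defined function that returns the training op executed by sess.run above. A hedged sketch of what such a function commonly looks like in this CIFAR-10-tutorial style of code (the optimizer choice and learning rate are assumptions, not taken from the project):

# Hypothetical model.train(): build an op that minimizes the loss with SGD
# and increments global_step on every update.
def train(total_loss, global_step):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(total_loss, global_step=global_step)
    return train_op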

Project: examples    Author: pytorch

def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to the start of the dataset.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.data

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, lr,
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

# Loop over epochs.

Project: IllustrationGAN    Author: tdrussell

def init_variables(sess):
    init_op = tf.initialize_all_variables()
    sess.run(init_op)

    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
    discrim_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
    variables_to_restore = gen_vars + discrim_vars
    saver = tf.train.Saver(variables_to_restore)
    if FLAGS.restore_model:
        ckpt = tf.train.get_checkpoint_state('./')
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise RuntimeError('No checkpoint file found.')
    return
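
A usage sketch for init_variables; the variables are expected to live under the 'generator' and 'discriminator' variable scopes, and FLAGS.restore_model must be defined as in the project. The tiny graph built here is a placeholder for illustration:

# Hypothetical setup: any trainable variables created under these scopes
# are picked up by the collection lookups inside init_variables.
with tf.variable_scope('generator'):
    g_w = tf.get_variable('w', shape=[4, 4])
with tf.variable_scope('discriminator'):
    d_w = tf.get_variable('w', shape=[4, 4])

sess = tf.Session()
init_variables(sess)  # initializes everything, then restores G/D weights if FLAGS.restore_model is set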

Project: IllustrationGAN    Author: tdrussell

def main(argv=None):
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)

    train()

Project: Saliency_Detection_Convolutional_Autoencoder    Author: arthurmeyer

def restore_weight_from(model, name, sess, log_path, copy_concat = False):
  """
  Restore model (excluding hidden variable)
  In practice use to train a model with the weight from another model. 
  As long as both model have architecture from the original model.py, then it works 
  Compatible w or w/o direct connections

  Args
    model            :         model to restore variable to
    name             :         name of model to copy
    sess             :         tensorflow session
    log_path         :         where to restore
    copy_concat      :         specify if the model to copy from also had direct connections

  Returns:
    step_b           :         the step number at which training ended
  """

  path = log_path + '/' + name + '_weight_only'

  variable_to_save = {}
  for i in range(30):
    name = 'conv_' + str(i)
    variable_to_save[name] = model.parameters_conv[i]
    if i < 2:
      if copy_concat == model.concat:
        name = 'deconv_' + str(i)
        variable_to_save[name] = model.parameters_deconv[i]
        name = 'deconv_bis_' + str(i)
        variable_to_save[name] = model.deconv[i]
    else:
      if i in [2, 4] and model.concat:
        name = 'deconv_' + str(i)
        variable_to_save[name] = model.parameters_deconv[i][0]
        if copy_concat:
          name = 'deconv_' + str(i) + '_bis'
          variable_to_save[name] = model.parameters_deconv[i][1]
      elif i in [2, 4] and not model.concat:
        name = 'deconv_' + str(i)
        variable_to_save[name] = model.parameters_deconv[i]
      else:
        name = 'deconv_' + str(i)
        variable_to_save[name] = model.parameters_deconv[i]

  saver = tf.train.Saver(variable_to_save)
  ckpt = tf.train.get_checkpoint_state(path)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    return ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  else:
    print('------------------------------------------------------')
    print('No checkpoint file found')
    print('------------------------------------------------------ \n')
    exit()
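
restore_weight_from pairs with save_weight_only above: weights are first dumped under '<name>_weight_only', then mapped onto another model instance. A sketch of that pairing; source_model, target_model, the session and './logs' are placeholders:

# Hypothetical weight transfer between two models of the same architecture.
save_weight_only(source_model, sess, './logs', step=0)
last_step = restore_weight_from(target_model, source_model.name, sess, './logs',
                                copy_concat=source_model.concat)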

Project: awd-lstm-lm    Author: salesforce

def train():
    # Turn on training mode which enables dropout.
    if args.model == 'QRNN': model.reset()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    batch, i = 0, 0
    while i < train_data.size(0) - 1 - 1:
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        seq_len = min(seq_len, args.bptt + 10)

        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        model.train()
        data, targets = get_batch(train_data, i, args, seq_len=seq_len)

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to the start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()

        output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
        raw_loss = criterion(output.view(-1, ntokens), targets)

        loss = raw_loss
        # Activation Regularization
        loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
        # Temporal Activation Regularization (slowness)
        loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        optimizer.step()

        total_loss += raw_loss.data
        optimizer.param_groups[0]['lr'] = lr2
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        ###
        batch += 1
        i += seq_len


# Load the best saved model.
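
The trailing comment refers to the step that usually follows this function in the same script: before final evaluation, the checkpoint with the best validation loss is reloaded. A hedged sketch, assuming the model was saved with torch.save(model, f) to args.save during the epoch loop:

# Hypothetical reload of the best checkpoint before evaluating on the test set.
with open(args.save, 'rb') as f:
    model = torch.load(f)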

Project: awd-lstm-lm    Author: salesforce

def train():
    # Turn on training mode which enables dropout.
    if args.model == 'QRNN': model.reset()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    batch, i = 0, 0
    while i < train_data.size(0) - 1 - 1:
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        # seq_len = min(seq_len, args.bptt + 10)

        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        model.train()
        data, targets = get_batch(train_data, i, args, seq_len=seq_len)

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to the start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()

        output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
        raw_loss = criterion(output.view(-1, ntokens), targets)

        loss = raw_loss
        # Activation Regularization
        loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
        # Temporal Activation Regularization (slowness)
        loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        optimizer.step()

        total_loss += raw_loss.data
        optimizer.param_groups[0]['lr'] = lr2
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        ###
        batch += 1
        i += seq_len

# Loop over epochs.