{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Train model using declarative and imperative API\n", " \n", "CNTK gives the user several ways to train a model:\n", "* High level declarative style API using the [Function.train](https://www.cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function.train) method (or training_session). Given a criterion function, the user can simply call the train method, providing configuration parameters for different aspects of the training, such as data sources, checkpointing, cross validation and progress printing. The corresponding [test](https://www.cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function.test) method can be used for evaluation. This API simplifies implementation of routine training tasks and eliminates boilerplate code.\n", "* Using the low level [Trainer.train_minibatch](https://www.cntk.ai/pythondocs/cntk.train.trainer.html#cntk.train.trainer.Trainer.train_minibatch) or [Trainer.test_minibatch](https://www.cntk.ai/pythondocs/cntk.train.trainer.html#cntk.train.trainer.Trainer.test_minibatch) methods. In this case the user writes the minibatch loop explicitly and has full control of all aspects. It is more flexible than the first option, but it is more error prone and requires a deeper understanding of the concepts, especially in a distributed environment. \n", "\n", "This document is organized as follows: first, we give an example of a typical imperative loop and its caveats in a distributed environment. Then we present the declarative API and show how it simplifies development and eliminates potential errors. 
If you are interested only in using the high level API (in general you should use it instead of an explicit minibatch loop), please jump directly to [the corresponding Function.train section](#2.-Using-Function.train).\n", "\n", "We start with some imports we need for the rest of this manual: " ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from __future__ import print_function\n", "import os\n", "import cntk\n", "import cntk.ops\n", "import cntk.io\n", "import cntk.train\n", "import cntk.tests.test_utils\n", "from cntk.layers import Dense, Sequential\n", "from cntk.io import StreamDef, StreamDefs, MinibatchSource, CTFDeserializer\n", "from cntk.logging import ProgressPrinter\n", "\n", "cntk.tests.test_utils.set_device_from_pytest_env() # (only needed for our build system)\n", "cntk.cntk_py.set_fixed_random_seed(1) # fix the random seed so that LR examples are repeatable" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 
Example script with an explicit loop\n", "\n", "Many scripts in CNTK have a very similar structure:\n", " - they create a network\n", " - instantiate a trainer and a learner with appropriate hyper-parameters\n", " - load training and testing data with minibatch sources\n", " - then run the main training loop fetching the data from the train minibatch source and feeding it to the trainer for N samples/sweeps\n", " - at the end they perform the eval loop using data from the test minibatch source\n", "\n", "As an example of such a script we take the toy task of learning the XOR operation with a simple feed-forward network.\n", "We will try to learn the following function:\n", "\n", "|x|y|result|\n", "|:-|:-|:------|\n", "|0|0| 0 |\n", "|0|1| 1 |\n", "|1|0| 1 |\n", "|1|1| 0 |\n", "\n", "The network has two dense layers; we use [tanh](https://www.cntk.ai/pythondocs/cntk.ops.html?highlight=tanh#cntk.ops.tanh) as the activation for the first layer and no activation for the second. The sample script is presented below:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Learning rate per sample: 0.1\n", " Minibatch[ 1- 10]: loss = 0.352425 * 40, metric = 35.24% * 40;\n", " Minibatch[ 11- 20]: loss = 0.207848 * 40, metric = 20.78% * 40;\n", " Minibatch[ 21- 30]: loss = 0.191173 * 40, metric = 19.12% * 40;\n", " Minibatch[ 31- 40]: loss = 0.176530 * 40, metric = 17.65% * 40;\n", " Minibatch[ 41- 50]: loss = 0.161325 * 40, metric = 16.13% * 40;\n", " Minibatch[ 51- 60]: loss = 0.143685 * 40, metric = 14.37% * 40;\n", " Minibatch[ 61- 70]: loss = 0.118660 * 40, metric = 11.87% * 40;\n", " Minibatch[ 71- 80]: loss = 0.082769 * 40, metric = 8.28% * 40;\n", " Minibatch[ 81- 90]: loss = 0.046990 * 40, metric = 4.70% * 40;\n", " Minibatch[ 91- 100]: loss = 0.048029 * 40, metric = 4.80% * 40;\n", " Minibatch[ 101- 110]: loss = 0.518075 * 40, metric = 51.81% * 40;\n", " Minibatch[ 
111- 120]: loss = 0.022979 * 40, metric = 2.30% * 40;\n", " Minibatch[ 121- 130]: loss = 0.018714 * 40, metric = 1.87% * 40;\n", " Minibatch[ 131- 140]: loss = 0.193923 * 40, metric = 19.39% * 40;\n", " Minibatch[ 141- 150]: loss = 0.030105 * 40, metric = 3.01% * 40;\n", " Minibatch[ 151- 160]: loss = 0.007559 * 40, metric = 0.76% * 40;\n", " Minibatch[ 161- 170]: loss = 0.006812 * 40, metric = 0.68% * 40;\n", " Minibatch[ 171- 180]: loss = 0.009554 * 40, metric = 0.96% * 40;\n", " Minibatch[ 181- 190]: loss = 0.012426 * 40, metric = 1.24% * 40;\n", " Minibatch[ 191- 200]: loss = 0.012157 * 40, metric = 1.22% * 40;\n", "Metric= 0.010537\n" ] } ], "source": [ "# Let's prepare data in the CTF format. It exactly matches\n", "# the table above\n", "INPUT_DATA = r'''|features 0 0\t|label 0\n", "|features 1 0\t|label 1\n", "|features 0 1\t|label 1\n", "|features 1 1\t|label 0\n", "'''\n", "\n", "# Write the data to a temporary file\n", "input_file = 'input.ctf.tmp'\n", "with open(input_file, 'w') as f:\n", " f.write(INPUT_DATA)\n", "\n", "# Create a network\n", "features = cntk.input_variable(2)\n", "label = cntk.input_variable(1)\n", "\n", "# Define our input data streams\n", "streams = StreamDefs(\n", " features = StreamDef(field='features', shape=2),\n", " label = StreamDef(field='label', shape=1))\n", "\n", "model = Sequential([\n", " Dense(2, activation=cntk.ops.tanh),\n", " Dense(1)])\n", "\n", "z = model(features)\n", "loss = cntk.squared_error(z, label)\n", "\n", "# Create a learner and a trainer and a progress writer to \n", "# output current progress\n", "learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n", "trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n", "\n", "# Now let's create a minibatch source for our input file\n", "mb_source = MinibatchSource(CTFDeserializer(input_file, streams))\n", "input_map = { features : mb_source['features'], label : mb_source['label'] }\n", "\n", "# Run a 
manual training minibatch loop\n", "minibatch_size = 4\n", "max_samples = 800\n", "train = True\n", "while train and trainer.total_number_of_samples_seen < max_samples:\n", "    data = mb_source.next_minibatch(minibatch_size, input_map)\n", "    train = trainer.train_minibatch(data)\n", "\n", "# Run a manual evaluation loop using the same data file for evaluation\n", "test_mb_source = MinibatchSource(CTFDeserializer(input_file, streams), randomize=False, max_samples=100)\n", "test_input_map = { features : test_mb_source['features'], label : test_mb_source['label'] }\n", "total_samples = 0\n", "error = 0.\n", "# Use the input map of the test source, not the one of the training source\n", "data = test_mb_source.next_minibatch(32, test_input_map)\n", "while data:\n", "    total_samples += data[label].number_of_samples\n", "    error += trainer.test_minibatch(data) * data[label].number_of_samples\n", "    data = test_mb_source.next_minibatch(32, test_input_map)\n", "\n", "print(\"Metric= %f\" % (error / total_samples))" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "As can be seen above, the actual model is specified in just two lines; the rest is boilerplate code that iterates over the data and feeds it manually for training and evaluation. With a manual loop, the user has complete flexibility in how to feed the data, but she also has to take several not so obvious things into account.\n", "\n", "\n", "### 1.1 Failover and recovery\n", "\n", "For the small sample above recovery is not important, but when training spans several days or weeks it is not safe to assume that the machine stays online all the time and that there are no hardware or software glitches. If the machine reboots, goes down, or the script has a bug, the user will have to rerun the same experiment from the beginning. That is highly undesirable. 
To avoid that, CNTK allows the user to take checkpoints and restore from them in the event of failure.\n", "\n", "One way to save the model state in CNTK is the [save method](https://cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function.save) on the [Function class](https://cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function).\n", "It is worth mentioning that this function only saves the model state, but there are other stateful entities in the script, including:\n", " * minibatch sources\n", " * trainer\n", " * learners\n", " \n", "In order to save the complete state of the script, the user has to manually save the current state of the minibatch source and the trainer. The minibatch source provides the [get_checkpoint_state](https://www.cntk.ai/pythondocs/cntk.io.html#cntk.io.MinibatchSource.get_checkpoint_state) method; its result can be passed to the trainer's [save_checkpoint](https://www.cntk.ai/pythondocs/cntk.train.trainer.html#cntk.train.trainer.Trainer.save_checkpoint) method, which takes care of saving the state to disk or exchanging the state in case of distributed training. There are also corresponding [restore_from_checkpoint](https://www.cntk.ai/pythondocs/cntk.train.trainer.html#cntk.train.trainer.Trainer.restore_from_checkpoint) methods on the trainer and the minibatch source that can be used for restoring. 
To recover from error, on start up the user has to restore a state using the trainer and set the current position of the minibatch source.\n", "\n", "With the above in mind, let's rewrite our loop as follows:" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No restore file found\n", "Learning rate per sample: 0.1\n", " Minibatch[ 1- 10]: loss = 0.352425 * 40, metric = 35.24% * 40;\n", " Minibatch[ 11- 20]: loss = 0.207848 * 40, metric = 20.78% * 40;\n", " Minibatch[ 21- 30]: loss = 0.191173 * 40, metric = 19.12% * 40;\n", " Minibatch[ 31- 40]: loss = 0.176530 * 40, metric = 17.65% * 40;\n", " Minibatch[ 41- 50]: loss = 0.161325 * 40, metric = 16.13% * 40;\n", " Minibatch[ 51- 60]: loss = 0.143685 * 40, metric = 14.37% * 40;\n", " Minibatch[ 61- 70]: loss = 0.118660 * 40, metric = 11.87% * 40;\n", " Minibatch[ 71- 80]: loss = 0.082769 * 40, metric = 8.28% * 40;\n", " Minibatch[ 81- 90]: loss = 0.046990 * 40, metric = 4.70% * 40;\n", " Minibatch[ 91- 100]: loss = 0.048029 * 40, metric = 4.80% * 40;\n", " Minibatch[ 101- 110]: loss = 0.518075 * 40, metric = 51.81% * 40;\n", " Minibatch[ 111- 120]: loss = 0.022979 * 40, metric = 2.30% * 40;\n", " Minibatch[ 121- 130]: loss = 0.018714 * 40, metric = 1.87% * 40;\n", " Minibatch[ 131- 140]: loss = 0.193923 * 40, metric = 19.39% * 40;\n", " Minibatch[ 141- 150]: loss = 0.030105 * 40, metric = 3.01% * 40;\n", " Minibatch[ 151- 160]: loss = 0.007559 * 40, metric = 0.76% * 40;\n", " Minibatch[ 161- 170]: loss = 0.006812 * 40, metric = 0.68% * 40;\n", " Minibatch[ 171- 180]: loss = 0.009554 * 40, metric = 0.96% * 40;\n", " Minibatch[ 181- 190]: loss = 0.012426 * 40, metric = 1.24% * 40;\n", " Minibatch[ 191- 200]: loss = 0.012157 * 40, metric = 1.22% * 40;\n" ] } ], "source": [ "# Run a manual training minibatch loop with checkpointing\n", "\n", "# Same as before\n", "mb_source = 
MinibatchSource(CTFDeserializer(input_file, streams))\n", "input_map = { features : mb_source['features'], label : mb_source['label'] }\n", "\n", "model = Sequential([\n", "    Dense(2, activation=cntk.ops.tanh),\n", "    Dense(1)])\n", "\n", "z = model(features)\n", "loss = cntk.squared_error(z, label)\n", "\n", "learner = cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1))\n", "trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))\n", "\n", "# Try to restore if the checkpoint exists\n", "checkpoint = 'manual_loop_checkpointed.tmp'\n", "# Comment out the two lines below if you want to restore from the checkpoint\n", "if os.path.exists(checkpoint):\n", "    os.remove(checkpoint)\n", "\n", "if os.path.exists(checkpoint):\n", "    print(\"Trying to restore from checkpoint\")\n", "    mb_source_state = trainer.restore_from_checkpoint(checkpoint)\n", "    mb_source.restore_from_checkpoint(mb_source_state)\n", "    print(\"Restore has finished successfully\")\n", "else:\n", "    print(\"No restore file found\")\n", "\n", "checkpoint_frequency = 100\n", "last_checkpoint = 0\n", "train = True\n", "while train and trainer.total_number_of_samples_seen < max_samples:\n", "    data = mb_source.next_minibatch(minibatch_size, input_map)\n", "    train = trainer.train_minibatch(data)\n", "    # Integer division: checkpoint only when another checkpoint_frequency samples have been seen\n", "    if trainer.total_number_of_samples_seen // checkpoint_frequency != last_checkpoint:\n", "        mb_source_state = mb_source.get_checkpoint_state()\n", "        trainer.save_checkpoint(checkpoint, mb_source_state)\n", "        last_checkpoint = trainer.total_number_of_samples_seen // checkpoint_frequency\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "At the beginning we check if the checkpoint file exists and we can restore from it. After that we start the training. Our loop is based on the total number of samples the trainer has seen. 
This information is included in the checkpoint, so in \n", "case of failure the training will resume at the saved position (this will become even more important for distributed training).\n", "\n", "Depending on the checkpointing frequency, the above script retrieves the current state of the minibatch source and creates a checkpoint using the trainer. If the script iterates over the same data many times, saving the state of the minibatch source is not that important, but for huge workloads you probably do not want to start seeing the same data from the beginning.\n", "\n", "At some point the user will want to parallelize the script to decrease the training time. Let's look at how this can be done in the next section.\n", "\n", "### 1.2 Distributed manual loop\n", "\n", "In order to make training distributed, CNTK provides a set of distributed learners that encapsulate a set of algorithms (1BitSGD, BlockMomentum, data parallel SGD) and use MPI to exchange the state. From the script perspective, almost everything stays the same. 
The only difference is that the user needs to wrap the learner into the corresponding distributed learner and make sure she picks up the data from the minibatch source based on the current worker rank (also the script should be run with `mpiexec`):" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No restore file found\n", " Minibatch[ 1- 10]: loss = 0.352425 * 40, metric = 35.24% * 40;\n", " Minibatch[ 11- 20]: loss = 0.207848 * 40, metric = 20.78% * 40;\n", " Minibatch[ 21- 30]: loss = 0.191173 * 40, metric = 19.12% * 40;\n", " Minibatch[ 31- 40]: loss = 0.176530 * 40, metric = 17.65% * 40;\n", " Minibatch[ 41- 50]: loss = 0.161325 * 40, metric = 16.13% * 40;\n", " Minibatch[ 51- 60]: loss = 0.143685 * 40, metric = 14.37% * 40;\n", " Minibatch[ 61- 70]: loss = 0.118660 * 40, metric = 11.87% * 40;\n", " Minibatch[ 71- 80]: loss = 0.082769 * 40, metric = 8.28% * 40;\n", " Minibatch[ 81- 90]: loss = 0.046990 * 40, metric = 4.70% * 40;\n", " Minibatch[ 91- 100]: loss = 0.048029 * 40, metric = 4.80% * 40;\n", " Minibatch[ 101- 110]: loss = 0.518075 * 40, metric = 51.81% * 40;\n", " Minibatch[ 111- 120]: loss = 0.022979 * 40, metric = 2.30% * 40;\n", " Minibatch[ 121- 130]: loss = 0.018714 * 40, metric = 1.87% * 40;\n", " Minibatch[ 131- 140]: loss = 0.193923 * 40, metric = 19.39% * 40;\n", " Minibatch[ 141- 150]: loss = 0.030105 * 40, metric = 3.01% * 40;\n", " Minibatch[ 151- 160]: loss = 0.007559 * 40, metric = 0.76% * 40;\n", " Minibatch[ 161- 170]: loss = 0.006812 * 40, metric = 0.68% * 40;\n", " Minibatch[ 171- 180]: loss = 0.009554 * 40, metric = 0.96% * 40;\n", " Minibatch[ 181- 190]: loss = 0.012426 * 40, metric = 1.24% * 40;\n", " Minibatch[ 191- 200]: loss = 0.012157 * 40, metric = 1.22% * 40;\n" ] } ], "source": [ "# Run a manual training minibatch loop with distributed learner\n", "checkpoint = 'manual_loop_distributed.tmp'\n", "#Please comment the 
line below if you want to restore the checkpoint\n", "if os.path.exists(checkpoint):\n", "    os.remove(checkpoint)\n", "\n", "model = Sequential([\n", "    Dense(2, activation=cntk.ops.tanh),\n", "    Dense(1)])\n", "\n", "z = model(features)\n", "loss = cntk.squared_error(z, label)\n", "\n", "mb_source = MinibatchSource(CTFDeserializer(input_file, streams))\n", "input_map = { features : mb_source['features'], label : mb_source['label'] }\n", "\n", "# Make sure the learner is distributed\n", "distributed_learner = cntk.distributed.data_parallel_distributed_learner(\n", "    cntk.sgd(model.parameters, cntk.learning_parameter_schedule_per_sample(0.1)))\n", "trainer = cntk.train.Trainer(z, (loss, loss), distributed_learner, ProgressPrinter(freq=10))\n", "\n", "if os.path.exists(checkpoint):\n", "    print(\"Trying to restore from checkpoint\")\n", "    mb_source_state = trainer.restore_from_checkpoint(checkpoint)\n", "    mb_source.restore_from_checkpoint(mb_source_state)\n", "    print(\"Restore has finished successfully\")\n", "else:\n", "    print(\"No restore file found\")\n", "\n", "last_checkpoint = 0\n", "train = True\n", "partition = cntk.distributed.Communicator.rank()\n", "num_partitions = cntk.distributed.Communicator.num_workers()\n", "while train and trainer.total_number_of_samples_seen < max_samples:\n", "    # Make sure each worker gets its own data only\n", "    data = mb_source.next_minibatch(minibatch_size_in_samples = minibatch_size,\n", "                                    input_map = input_map, device = cntk.use_default_device(),\n", "                                    num_data_partitions=num_partitions, partition_index=partition)\n", "    train = trainer.train_minibatch(data)\n", "    # Integer division: checkpoint only when another checkpoint_frequency samples have been seen\n", "    if trainer.total_number_of_samples_seen // checkpoint_frequency != last_checkpoint:\n", "        mb_source_state = mb_source.get_checkpoint_state()\n", "        trainer.save_checkpoint(checkpoint, mb_source_state)\n", "        last_checkpoint = trainer.total_number_of_samples_seen // checkpoint_frequency\n", "\n", "# When you use distributed learners, please call finalize MPI at the end of 
your script, \n", "# see the next cell.\n", "# cntk.distributed.Communicator.finalize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order for distribution to work properly, the minibatch loop should exit on all workers at the same time. Some of the workers can have more data than the others, so the exit condition of the loop should be based on the return value of the trainer (if no more work should be done by a particular worker this can be communicated by passing an empty minibatch to `train_minibatch`).\n", "\n", "As has been noted before, the decisions inside the loop are based on the [Trainer.total_number_of_samples_seen](https://www.cntk.ai/pythondocs/cntk.train.trainer.html#cntk.train.trainer.Trainer.total_number_of_samples_seen). Some of the operations (e.g. `train_minibatch`, checkpointing, cross validation, if done in a distributed fashion) require synchronization, and to stay in step across all the workers they use a global state: the global number of samples seen by the trainer.\n", "\n", "Even though writing manual training loops gives the user full flexibility, it can also be error prone and requires a lot of boilerplate code to make everything work. When this flexibility is not required, it is better to use a higher abstraction.\n", "\n", "## 2. Using Function.train\n", "\n", "Instead of writing the training loop manually and taking care of checkpointing and distribution herself, the user can delegate these aspects to the training session exposed through the [Function.train/test](https://www.cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function.train) methods. It automatically takes care of the following things:\n", " 1. checkpointing\n", " 2. validation\n", " 3. testing/evaluation\n", "\n", "All that is needed from the user is to provide the corresponding configuration parameters. 
In addition to the higher abstraction the training session is also implemented in C++, so it is generally faster than writing a loop in Python:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Minibatch[ 1- 10]: loss = 0.352425 * 40, metric = 35.24% * 40;\n", " Minibatch[ 11- 20]: loss = 0.207848 * 40, metric = 20.78% * 40;\n", " Minibatch[ 21- 30]: loss = 0.191173 * 40, metric = 19.12% * 40;\n", " Minibatch[ 31- 40]: loss = 0.176530 * 40, metric = 17.65% * 40;\n", " Minibatch[ 41- 50]: loss = 0.161325 * 40, metric = 16.13% * 40;\n", " Minibatch[ 51- 60]: loss = 0.143685 * 40, metric = 14.37% * 40;\n", " Minibatch[ 61- 70]: loss = 0.118660 * 40, metric = 11.87% * 40;\n", " Minibatch[ 71- 80]: loss = 0.082769 * 40, metric = 8.28% * 40;\n", " Minibatch[ 81- 90]: loss = 0.046990 * 40, metric = 4.70% * 40;\n", " Minibatch[ 91- 100]: loss = 0.048029 * 40, metric = 4.80% * 40;\n", " Minibatch[ 101- 110]: loss = 0.518075 * 40, metric = 51.81% * 40;\n", " Minibatch[ 111- 120]: loss = 0.022979 * 40, metric = 2.30% * 40;\n", " Minibatch[ 121- 130]: loss = 0.018714 * 40, metric = 1.87% * 40;\n", " Minibatch[ 131- 140]: loss = 0.193923 * 40, metric = 19.39% * 40;\n", " Minibatch[ 141- 150]: loss = 0.030105 * 40, metric = 3.01% * 40;\n", " Minibatch[ 151- 160]: loss = 0.007559 * 40, metric = 0.76% * 40;\n", " Minibatch[ 161- 170]: loss = 0.006812 * 40, metric = 0.68% * 40;\n", " Minibatch[ 171- 180]: loss = 0.009554 * 40, metric = 0.96% * 40;\n", " Minibatch[ 181- 190]: loss = 0.012426 * 40, metric = 1.24% * 40;\n", " Minibatch[ 191- 200]: loss = 0.012157 * 40, metric = 1.22% * 40;\n", "Finished Epoch[1]: loss = 0.118087 * 800, metric = 11.81% * 800 1.091s (733.3 samples/s);\n", "Finished Evaluation [1]: Minibatch[1-4]: metric = 1.05% * 100;\n" ] } ], "source": [ "checkpoint = 'training_session.tmp'\n", "#Please comment the line below if you want to 
restore from the checkpoint\n", "if os.path.exists(checkpoint):\n", " os.remove(checkpoint)\n", "\n", "# As before\n", "mb_source = MinibatchSource(CTFDeserializer(input_file, streams))\n", "test_mb_source = MinibatchSource(CTFDeserializer(input_file, streams), randomize=False, max_samples=100)\n", "\n", "model_factory = Sequential([\n", " Dense(2, activation=cntk.ops.tanh),\n", " Dense(1)])\n", "model = model_factory(features)\n", "\n", "# Criterion function\n", "@cntk.Function\n", "def criterion_factory(f, l):\n", " z = model_factory(f)\n", " loss = cntk.squared_error(z, l)\n", " return (loss, loss)\n", "\n", "criterion = criterion_factory(features, label)\n", "learner = cntk.distributed.data_parallel_distributed_learner(cntk.sgd(model.parameters, \n", " cntk.learning_parameter_schedule_per_sample(0.1)))\n", "progress_writer = cntk.logging.ProgressPrinter(freq=10)\n", "checkpoint_config = cntk.CheckpointConfig(filename=checkpoint, frequency=checkpoint_frequency)\n", "test_config = cntk.TestConfig(test_mb_source)\n", "\n", "# Actual training\n", "progress = criterion.train(mb_source, minibatch_size=minibatch_size,\n", " model_inputs_to_streams={ features : mb_source['features'], label : mb_source['label'] },\n", " max_samples=max_samples, parameter_learners=[learner], \n", " callbacks=[progress_writer, checkpoint_config, test_config])\n", "\n", "# When you use distributed learners, please call finalize MPI at the end of your script\n", "# cntk.distributed.Communicator.finalize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how to configure different aspects of the train method.\n", "\n", "### Progress tracking\n", "In order to report progress, please provide an instance of the [ProgressWriter](https://www.cntk.ai/pythondocs/cntk.logging.progress_print.html#module-cntk.logging.progress_print). It has its own set of parameters to control how often to print the loss value. 
If you need custom logic for retrieving the current status, please consider implementing your own ProgressWriter.\n", "\n", "### Checkpointing\n", "[Checkpoint configuration](https://www.cntk.ai/pythondocs/cntk.train.training_session.html#cntk.train.training_session.CheckpointConfig) specifies how often to save a checkpoint to the given file. The checkpointing frequency is specified in samples. When given, the method takes care of saving/restoring the state across the trainer/learners/minibatch source and propagating this information among distributed workers. If you need to preserve all checkpoints that were taken during training, please set `preserve_all` to `True`. \n", "\n", "### Validation\n", "When a [cross validation](https://www.cntk.ai/pythondocs/cntk.train.training_session.html#cntk.train.training_session.CrossValidationConfig) config is given, the training session runs the validation on the specified minibatch source with the specified frequency and reports the average metric error. The user can also provide a cross validation callback that will be called with the specified frequency. It is up to the user to perform validation in the callback and return `True` if the training should be continued, or `False` otherwise. \n", "\n", "### Testing\n", "If the test configuration is given, after completion of training, the train method runs evaluation on the specified minibatch source. 
If you need to run only evaluation without training, consider using the [Function.test](https://www.cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function.test) method instead.\n", "\n", "For more advanced scenarios of using Function.train, please see [Tutorial 200](https://github.com/Microsoft/CNTK/blob/release/2.6/Tutorials/CNTK_200_GuidedTour.ipynb)." ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }