
The plan for training
=====================

Suppose we're training from Python.

Raw training data is a list of tuples
   (feats, FSA)
 where `feats` is a torch.Tensor and FSA is a k2._C.fsa, which is a C++ object
 wrapped for Python.  (I will use the ._C submodule to indicate when something is
 actually a C++ object, because module k2 will also contain some Python objects.)

 We combine the features into `all_feats`, which will contain many sequences.  There
 may be some mechanism for packing short sequences into longer ones and using a mask
 to restart any recurrences, as in Lingvo.

 We `forward` the network to get `nnet_output`, which has shape
   (seqs_per_minibatch, num_frames, num_phones)

 We combine the FSAs for the sequences into `supervisions`, which is a k2.fsavec;
 this represents a vector of FSAs.  Note, its size may be greater than
 seqs_per_minibatch if we have packed multiple utterances into a single sequence.

 We also create a data structure `nnet_output_ranges` of the same dimension as
 `supervisions` which tells us, for each FSA in `supervisions`, which range of
 elements of `nnet_output` it needs to use as its input.
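
 As a rough sketch of what `nnet_output_ranges` might hold, assume each entry
 is a (seq_index, start_frame, num_frames) triple and that utterances are packed
 greedily into sequences of bounded length.  The record layout, the packing
 policy and the names `OutputRange` and `pack_utterances` are all assumptions
 made for illustration; none of them is part of k2.

```python
from typing import List, NamedTuple


class OutputRange(NamedTuple):
    """Hypothetical record: which slice of `nnet_output` one FSA reads."""
    seq_index: int    # row of nnet_output (which packed sequence)
    start_frame: int  # first frame belonging to this utterance
    num_frames: int   # number of frames belonging to this utterance


def pack_utterances(utt_lengths: List[int], max_frames: int) -> List[OutputRange]:
    """Greedily pack utterances into sequences of at most `max_frames` frames.

    Returns one OutputRange per utterance, in the same order as `utt_lengths`,
    i.e. the same order as the FSAs in `supervisions`.
    """
    ranges = []
    seq_index, used = 0, 0
    for length in utt_lengths:
        if length > max_frames:
            raise ValueError("utterance longer than max_frames")
        if used + length > max_frames:
            # The current sequence is full; start a new one.
            seq_index, used = seq_index + 1, 0
        ranges.append(OutputRange(seq_index, used, length))
        used += length
    return ranges


nnet_output_ranges = pack_utterances([40, 50, 30, 90], max_frames=100)
```

 In this toy case four utterances end up in three sequences, so `supervisions`
 would hold four FSAs while seqs_per_minibatch is three.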

 From `nnet_output_ranges` and the data in `nnet_output`, which in the initial
 implementation we'll have to transfer to CPU, we create `nnet_output_fsas`
 which is a data structure of type k2.dwfsavec (the "d" is for dense, the "w" is
 for weighted).  Conceptually this represents a vector of dense, weighted FSAs,
 one for each element of `nnet_output_ranges`.  This is a very lightweight
 object; the bulk of the data lives in `nnet_output_cpu` which is the CPU version
 of `nnet_output`.

 k2.dwfsavec actually contains two things: a torch.Tensor where the weights are
 owned, and a wrapped std::vector<Dwfsa>, each element of which holds a pointer
 to the relevant part of the weights in the torch.Tensor (having no ownership
 over it, although we can make it memory-safe later on).
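
 The owning-tensor-plus-non-owning-pointers idea can be imitated in pure
 PyTorch, since basic slicing returns views that share storage with the original
 tensor.  The sketch below is only an analogy for the planned C++ layout; the
 tensor sizes and the range triples are made up for illustration.

```python
import torch

# Stand-in for `nnet_output_cpu`: (seqs_per_minibatch, num_frames, num_phones).
nnet_output_cpu = torch.randn(2, 100, 5)

# Hypothetical (seq_index, start_frame, num_frames) ranges, one per FSA.
ranges = [(0, 0, 40), (0, 40, 50), (1, 0, 30)]

# Basic slicing returns views, not copies: each view plays the role of the
# non-owning weight pointer inside one Dwfsa.
views = [nnet_output_cpu[s, b:b + n] for (s, b, n) in ranges]

# Each view's data starts at the expected byte offset inside the owning tensor.
frame_bytes = nnet_output_cpu.stride(1) * nnet_output_cpu.element_size()
offset = views[1].data_ptr() - nnet_output_cpu.data_ptr()

# Writing through the owner is visible through the view (shared memory).
nnet_output_cpu[0, 40, 0] = 7.0
```

 Because the views do not own the memory, the owning tensor has to outlive
 them, which is the memory-safety concern mentioned above.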


 For now, suppose we just want the CTC part of the objective function.  (There's lots
 more we can do with this type of framework, but here my goal is just to explain
 how it works from a programming level, and the CTC part is enough).

 We do something like:

   ctc_lattices =  k2.PrunedIntersection(nnet_output_fsas, supervisions, prune_beam)

 Here, ctc_lattices would be of type k2.wfsavec, which is a Python object containing:
   a torch.Tensor where the weights in the lattices live, and a wrapped
   std::vector<Wfsa>, where Wfsa is basically { Fsa, float* }; the float*
   points to an array holding one cost per arc, in memory owned by the Tensor.


 At the C++ level the pruned-intersection algorithm outputs an Fsa and also two
 arrays of integer indexes, which show, for each arc in the output Wfsa, which
 arc in each of the two inputs it corresponds to.  These indexes, suitably
 concatenated and turned into a Tensor, will enable us to (a) construct a Tensor
 containing the weights of the output Wfsa, using PyTorch indexing; and
 (b) do the backprop when the time comes.
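
 To make (a) and (b) concrete, here is a minimal sketch with toy data.  It
 assumes that both inputs carry one weight per arc and that the weight of an
 output arc is the sum of the weights of its two source arcs; the names
 `a_arc_map` and `b_arc_map` are hypothetical stand-ins for the two index
 arrays.

```python
import torch

# Weights of the two inputs, one entry per arc (toy values).
a_weights = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True)
b_weights = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Hypothetical arc maps as the C++ intersection might return them: output arc i
# was formed from arc a_arc_map[i] of input A and arc b_arc_map[i] of input B.
a_arc_map = torch.tensor([0, 2, 2, 3])
b_arc_map = torch.tensor([1, 0, 2, 2])

# (a) Output weights by plain PyTorch indexing (assumes weights simply add).
out_weights = a_weights[a_arc_map] + b_weights[b_arc_map]

# (b) Backprop routes each output-arc gradient back to its source arcs,
# accumulating where several output arcs share a source arc.
out_weights.sum().backward()
```

 After the backward pass, `a_weights.grad` and `b_weights.grad` count how many
 output arcs each input arc contributed to, since the toy loss is just the sum
 of the output weights.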

 The k2.PrunedIntersection() function call is just a function call, but after
 constructing the output Fsas and getting the indexes it will construct
 the tensor of weights using a torch.autograd.Function whose ctx will retain
 the integer indexes mentioned above, which make it possible to
 automatically backprop derivatives w.r.t. the output k2.wfsavec
 back to the source weights.
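
 A sketch of how such a function could keep the indexes in its context is shown
 below.  The class name `IndexWeights` is hypothetical and it handles a single
 source tensor only; it illustrates the mechanism and is not the planned k2
 code.

```python
import torch


class IndexWeights(torch.autograd.Function):
    """Hypothetical helper: out[i] = src[arc_map[i]], with a hand-written backward."""

    @staticmethod
    def forward(ctx, src, arc_map):
        # Retain the integer indexes so backward can route the gradients.
        ctx.save_for_backward(arc_map)
        ctx.num_src = src.shape[0]
        return src[arc_map]

    @staticmethod
    def backward(ctx, grad_out):
        (arc_map,) = ctx.saved_tensors
        grad_src = grad_out.new_zeros(ctx.num_src)
        # Several output arcs may share a source arc, so accumulate.
        grad_src.index_add_(0, arc_map, grad_out)
        return grad_src, None  # no gradient for the integer indexes


src_weights = torch.tensor([0.5, 1.5, 2.5], requires_grad=True)
arc_map = torch.tensor([2, 0, 2, 1, 2])
out = IndexWeights.apply(src_weights, arc_map)

# Weight each output arc differently so the accumulation is visible.
(out * torch.arange(1.0, 6.0)).sum().backward()
```

 Plain indexing would give the same gradients here; a custom Function matters
 once the forward step is done by C++ code that autograd cannot trace.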

 Then we compute the per-sequence objective functions from the ctc lattices via:

  ctc_objfs = k2.ForwardProb(ctc_lattices)

 objf = ctc_objfs.sum()
 objf.backward() ...
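
 For intuition about what a forward-probability computation does, the sketch
 below runs the forward algorithm in log space over one tiny acyclic lattice.
 The function `forward_logprob`, the arc-list representation and the toy
 probabilities are assumptions for illustration and say nothing about how
 k2.ForwardProb would be implemented.

```python
import torch

# Toy acyclic lattice: arcs are (src_state, dest_state), listed so that every
# arc leaving a state comes after all arcs entering it (topological order).
arcs = [(0, 1), (0, 2), (1, 3), (2, 3)]
arc_logprobs = torch.log(torch.tensor([0.6, 0.4, 0.5, 0.25])).requires_grad_()
num_states, start_state, final_state = 4, 0, 3


def forward_logprob(arcs, arc_logprobs, num_states, start_state, final_state):
    """Log of the total probability of all paths from start to final state."""
    alpha = [torch.tensor(float("-inf"))] * num_states
    alpha[start_state] = torch.tensor(0.0)
    for i, (src, dest) in enumerate(arcs):
        # Log-add this arc's contribution to the score of its destination.
        alpha[dest] = torch.logaddexp(alpha[dest], alpha[src] + arc_logprobs[i])
    return alpha[final_state]


objf = forward_logprob(arcs, arc_logprobs, num_states, start_state, final_state)
objf.backward()  # arc_logprobs.grad now holds the arc posteriors
```

 The two paths have probabilities 0.6 * 0.5 and 0.4 * 0.25, so the objective is
 log(0.4), and the gradient with respect to each arc's log-probability is the
 posterior probability of passing through that arc.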

 Note: the ctc_lattices FSAs are sparse, which makes it easier to combine them
 with (say) language models.  For instance, we could compose ctc_lattices with
 some deterministic FSA object representing a language model or an RNNLM and use
 that as a numerator term; for the denominator term, we would take
 `nnet_output_fsas`, prune it, and compose it with the same LM or RNNLM.



