
  These are some notes regarding how we'll interact with Python and PyTorch.




    # Assumes that A is an acceptor but B may
    # have auxiliary symbols (i.e. may be a transducer).
    def TransducerCompose(a: FsaVec, a_weights: Tensor,
                          b: FsaVec, b_weights: Tensor,
                          b_aux_symbols = None):
        """Compose `a` with `b`, carrying weights (and optionally aux symbols).

        `a` is assumed to be an acceptor, while `b` may be a transducer, i.e.
        it may come with auxiliary symbols in `b_aux_symbols`.

        The composed weights are the sum of the weights from `a` and `b`,
        each gathered with the index tensor that `fsa.FsaVecCompose` returns
        for that input (presumably one entry per arc of the output -- TODO
        confirm against the C++ interface).

        Returns:
          (c, c_weights) if `b_aux_symbols` is None, otherwise
          (c, c_weights, c_aux_symbols) where c_aux_symbols is the result of
          `fsa.MapAuxSymbols(b_aux_symbols, <indexes into b>)`.
        """
        composed, arcs_from_a, arcs_from_b = fsa.FsaVecCompose(a, b)
        composed_weights = a_weights[arcs_from_a] + b_weights[arcs_from_b]

        if b_aux_symbols is not None:
            mapped_aux = fsa.MapAuxSymbols(b_aux_symbols, arcs_from_b)
            return composed, composed_weights, mapped_aux

        return composed, composed_weights


    def Compose(a: FsaVec, a_weights: Tensor,
                b: FsaVec, b_weights: Tensor):
        """Compose `a` with `b`, propagating weights and any aux symbols.

        Unlike the variant that takes aux symbols as a separate argument,
        this one reads them from the `aux_symbol` attribute of the inputs
        and attaches the mapped result to the output as `c.aux_symbol`.

        Returns:
          (c, c_weights), where c_weights is the sum of the weights from `a`
          and `b`, each gathered with the index tensor returned by
          `fsa.FsaVecCompose` for that input.
        """
        c, input_indexes1, input_indexes2 = fsa.FsaVecCompose(a, b)

        c_weights = a_weights[input_indexes1] + b_weights[input_indexes2]

        # Handle transducers.  Use identity comparison with None: `!= None`
        # dispatches to the object's own comparison operator, which for
        # array-like objects is not guaranteed to yield a plain bool.
        # NOTE(review): if both inputs carry aux symbols, the ones from `b`
        # overwrite the ones from `a` -- confirm this is intended.
        if a.aux_symbol is not None:
            c.aux_symbol = a.aux_symbol[input_indexes1]
        if b.aux_symbol is not None:
            c.aux_symbol = b.aux_symbol[input_indexes2]

        return c, c_weights

    def RmEpsilon(a: FsaVec, a_weights, a_aux_symbols = None):
        """Remove epsilons from `a`, propagating weights (and aux symbols).

        At the C++ level, RmEpsilon outputs a `vector<vector<int> > indexes`
        that says, for each arc in the output, what list of arcs in the input
        it corresponds to.  For exposition purposes, imagine we're dealing
        with a single FSA, not a vector of FSAs.  Suppose
        indexes == [ [ 1, 2 ], [ 6, 8, 9 ] ]; then we'd form
        indexes1 = [ 1, 2, 6, 8, 9 ] (input-arc indexes, flattened) and
        indexes2 = [ 0, 0, 1, 1, 1 ] (the output arc each one belongs to).

        Parameters:
          a:  the input FsaVec.
          a_weights:  Tensor whose 1st dim must equal a.num_arcs; it is
              allowed to have other dims (e.g. for non-scalar weights).  In
              the normal case it will have just one axis.
          a_aux_symbols:  optional auxiliary symbols of `a`.

        Returns:
          (b, b_weights) if `a_aux_symbols` is None, otherwise
          (b, b_weights, b_aux_symbols).
        """
        b, indexes1, indexes2 = fsa.FsaVecRmEpsilon(a)

        # Allocate the output with the same dtype and device as the input
        # weights; index_add_ requires them to match.
        b_weights = a_weights.new_zeros((b.num_arcs, *a_weights.shape[1:]))

        # Each output arc's weight is the sum of the weights of the input
        # arcs it was built from.  The signature is
        # index_add_(dim, index, source).
        # Assumes indexes1/indexes2 are integer index tensors -- TODO confirm.
        b_weights.index_add_(0, indexes2, a_weights[indexes1])

        # If we later need access to indexes1 and indexes2, we can
        # create a different version of this function or extend its interface.

        if a_aux_symbols is None:
            return b, b_weights
        return b, b_weights, fsa.MapAuxSymbols(a_aux_symbols, indexes1, indexes2)




    class TotalWeight(Function):
        """
        Returns the total weight of FSAs (i.e. the log-sum-exp across
        paths) as a Tensor with shape (num_fsas,)
        """

        @staticmethod
        def forward(ctx, a: FsaVec, a_weights: Tensor):
            """Compute the total weight of each FSA in `a`.

            Saves `a` and the forward-backward object on `ctx` so that
            backward() can reuse them.
            """
            ctx.a = a
            ctx.fb = fsa.FsaVecWithFbWeights(a, a_weights, fsa.kLogSumWeight)
            return ctx.fb.GetTotalWeights()  # a Tensor

        @staticmethod
        def backward(ctx, grad_out):
            """Return the gradients w.r.t. the inputs of forward().

            One value is returned per forward() input: None for `a` (not
            differentiable) and the per-arc gradient for `a_weights`.
            """
            # TODO: handle transfers to/from GPU in case grad_out was on GPU.
            # Maybe mark this only once differentiable (it's twice
            # differentiable, I think, but this code doesn't currently
            # support that).

            # `indexes` is a Tensor that would contain, for each arc in a,
            # the index of the FSA in the FsaVec it belongs to.
            # It would be of the form  [ 0, 0, 0 .., 0, 1, 1, 1, .. 1, 2 ... ]
            indexes = fsa.GetFsaVecIndexes(ctx.a)

            # GetArcProbs would return the probability of traversing each arc
            # of each FSA, as a single Tensor.
            arc_grad = ctx.fb.GetArcProbs() * grad_out[indexes]
            return None, arc_grad




 =====================
 What follows are some ideas on how we'd use this in a program.


   # Sketch of a training step: build FSAs from the network output, compose
   # them with the supervision and with the decoding graph, compute
   # posteriors, and prepare the lattices for a later rescoring pass.
   # This is pseudo-code; several calls below still contain placeholders.

   # decoding_graph is an FsaVec, graph_weights is a float Tensor,
   # graph_word_syms is a LongTensor (both with 1 axis).

   # decoding_graph will have the following extra attributes:
   # decoding_graph.weights,
   # decoding_graph.word_syms
   # both of shape (num_arcs,), of dtypes float and long respectively.
   decoding_graph = fsa.ReadDecodingGraph('a/b/c.fsa')


   # nnet_output's shape is (num_seq, num_frames, num_symbols).. these symbols
   # might be phones or letters or small word-pieces.
   nnet_output = model(input_feats)

   # Interpret each sequence numbered `n` in `nnet_output` as being a FSA with
   # `num_frames + 1` states numbered 0, 1, ... num_frames, and for each 0 <= i
   # < num_frames, arcs from state i to i+1 with each symbol `s` as the label
   # and the output given by `nnet_output[n,i,s]`.  This is a lightweight
   # operation (except for the fact that it transfers the matrix to CPU for
   # now).
   # NOTE: the above is a slight simplification because the cuts may not
   # span the entirety of `nnet_output`, they may be sub-sequences of frames
   # within there, and `cut_info` will describe that somehow.  So the ends
   # of the cuts may be "ragged".
   nnet_output_fsas = fsa.DenseFsaVec(nnet_output, cut_info)

   # nnet_output_fsas.input_indexes is a LongTensor with dimension (num_arcs_in_nnet_output_fsas, 3)
   # where the 3 columns are the indexes n,i,s into nnet_output.
   nnet_output_fsas.input_indexes = nnet_output_fsas.GetIndexes()

   # ... and implicitly nnet_output_fsas has a `weights` vector, which is a
   # reshape/copy of parts of `nnet_output`.  (It's not just a reshape of the
   # whole thing due to its ragged structure).


   # Composing with the FSAs representing the supervision; this is CTC alignment.
   # alignment_fsas will be an FsaVec.
   # NOTE(review): `supervision_fsas` is not defined in this sketch;
   # presumably it comes from the training data -- confirm.
   alignment_fsas = fsa.compose_pruned(nnet_output_fsas, supervision_fsas, beam=10.0)

   # objf_part1 is the CTC part of the objective; we'll later add it with the others.
   objf_part1 = fsa.GetTotalLogWeight(alignment_fsas).sum()


   # first_pass_fsas will be an FsaVec.
   # This differs from `alignment_fsas` because it allows all possible word sequences,
   # not just the supervision one(s).
   # From here on, for a while, the code path will be the same as the one
   # we'll take at test time.
   #

   first_pass_fsas = fsa.compose_pruned(nnet_output_fsas, decoding_graph, beam=10.0)


   # `first_pass_fsas` will have the following attributes:
   #
   #   input0.{arc_indexes,weights,input_indexes}
   #   input1.{arc_indexes,weights,word_syms}
   #
   #  and the following derived/computed params:
   #    weights = input0.weights + input1.weights
   #    word_syms = input1.word_syms

   # We'll want to propagate `nnet_arc_indexes` forward, so we can
   # keep track of alignments.
   first_pass_fsas.nnet_arc_indexes = first_pass_fsas.input0.arc_indexes

   # the following computes arc-level log-posteriors
   log_posts = first_pass_fsas.ComputeLogPosteriors()

   # Also get the symbol-level posteriors:

   # The following will be the posteriors of a particular phone (or blank) at a particular
   # time-index of a particular sequence.   see fsautil.py.
   first_pass_fsas.phone_log_posts = fsautil.sum_by_index(
       log_posts.exp(),
       first_pass_fsas.input0.arc_indexes).log()

   # the following means "the posterior of the current word".
   # NOTE(review): torch.cat takes a sequence of tensors as its first
   # argument (the second positional argument is `dim`), so the call below
   # would need its two tensors wrapped in a tuple/list -- and possibly
   # torch.stack rather than torch.cat, depending on what sum_by_index
   # expects; confirm against fsautil.py.
   # NOTE(review): unlike phone_log_posts above, no `.log()` is applied here
   # even though the attribute is named `word_log_posts` -- confirm.
   first_pass_fsas.word_log_posts = fsautil.sum_by_index(
       log_posts.exp(),
       torch.cat(first_pass_fsas.input0.arc_indexes,
                 first_pass_fsas.word_syms))


   #  aggregate all these weights ?
   #    first_pass_fsas.all_weights = torch.cat(
   #       first_pass_fsas.input0.weights.unsqueeze(0),  # acoustic weight
   #       first_pass_fsas.input1.weights.unsqueeze(0),   # graph weight
   #       first_pass_fsas.phone_log_posts.unsqueeze(0),
   #       first_pass_fsas.word_log_posts.unsqueeze(0))



   # TODO: the second part of the objective has not been written yet; the
   # assignment below is an unfinished placeholder.
   objf_part2 =


   # OK, now we want to do a pass of RNNLM rescoring.  Compose with a
   # virtual FSA (DeterministicOnDemandFsa) that has words as the
   # labels.  We can make this a bit more efficient by determinizing the
   # input first.


   # we need the following to propagate members weights,all_weights.
   # Invert the FST, swapping its symbols with word_syms.
   fsas_inverted = fsa.Invert(first_pass_fsas,
                              first_pass_fsas.word_syms,
                              other_label_name='phone_syms',
                              keep=[...])

   # NOTE(review): `fsas_inverted` is not used below; presumably it, rather
   # than `first_pass_fsas`, was meant to be passed to RmEpsWeighted -- confirm.
   fsas_rmeps = fsa.RmEpsWeighted(first_pass_fsas)
   # The beam value is still a placeholder.
   fsas_det = fsa.DeterminizeTropicalPruned(fsas_rmeps, beam=..)

   # note: fsas_det still has `phone_syms` and `nnet_arc_indexes`, now as sequences.


class PseudoTensor:
    """This is a class that behaves like a torch.Tensor but in fact only
    supports one kind of operation, namely indexing with another
    torch.Tensor.

    Indexing with `indexes` returns `t[indexes // divisor]`, i.e. every run
    of `divisor` consecutive index values maps to the same element of `t`.
    """

    def __init__(self, t, divisor):
        """Constructor.

        Parameters:
          t:  torch.LongTensor
          divisor: int
        """
        self.t = t
        self.divisor = divisor

    def __getitem__(self, indexes):
        """Return the elements of `t` at positions `indexes // divisor`."""
        # Floor division keeps the result an integer index; true division
        # would produce floats, which cannot be used to index `t`.
        return self.t[indexes // self.divisor]

class DenseFsaVec:
    """Represents a vector of FSAs, but with a special regular
    structure.  Each FSA would normally correspond to one supervised
    segment within an acoustic sequence.  This wraps the data
    output from the neural net.  Each segment has T+2 states
    numbered 0, 1, .. T, T+1 (the T+1'th is the final-state and
    only has final-arcs entering it).  From state i to i+1
    there is an arc for each symbol, whose loglike comes from
    lookup in the neural-net output."""

    def __init__(self, loglikes, seq_indexes, start_times, end_times):
        """Constructor.

        Params:
          loglikes (torch.Tensor):  The tensor of log-likelihoods
            output by the neural network.  Will be interpreted
            as having shape (num_seqs, num_frames, num_symbols).
            Here `num_symbols` includes epsilon.
          seq_indexes, start_times, end_times (torch.Tensor):
            These must all have the same shape, of the form
            (num_segments,).  Here, num_segments would normally
            be >= num_seqs (each sequence may have more than
            one supervised segment in it, and they may
            overlap).
              - seq_indexes says, for each segment, which sequence
                it is a part of
              - start_times says, for each segment, what the first
                frame-index in `loglikes` is
              - end_times says, for each segment, what the
                one-past-the-last frame-index in `loglikes` is.
        """
        # TODO: not implemented yet.
        # note: this will create a csrc.CfsaVec object internal to this
        # object.
        # It will also create self.arc_loglikes containing the
        # loglikes, one per arc of the CfsaVec object.  This is
        # a repeat of `loglikes` but possibly in a different
        # order.
        pass

    @property
    def loglikes(self):
        """Return the per-arc log-likelihoods (`self.arc_loglikes`)."""
        return self.arc_loglikes

    @property
    def seg_frames(self):
        """Return something that 'acts' like a tensor, indexed by arc, of
        the frame-index relative to the segment start corresponding to that
        arc.  NOTE: self.frame_loglikes will actually be a sub-Tensor
        of the Tensor created at the C++ level as the DenseFsaVecMeta object.
        """
        # NOTE(review): neither `frame_loglikes` nor `num_symbols` is set by
        # the (still unimplemented) constructor -- confirm the intended names.
        return PseudoTensor(self.frame_loglikes, self.num_symbols)

    @property
    def seq_frames(self):
        """As for seg_frames.  TODO: not implemented yet."""
        pass

    @property
    def seq_ids(self):
        """As for seg_frames.  TODO: not implemented yet."""
        pass

    @property
    def seg_ids(self):
        """As for seg_frames.  TODO: not implemented yet."""
        pass
