Description of feature
Once we have merged #228, #235, #239, and #240, we can start with the torch backend implementation. Therefore:
- we will keep one `CellFlow` class for both backends
- we will have a `src/backends/jax` and a `src/backends/torch` directory
- `CellFlow.prepare_model` will have an argument `backend: Literal["jax", "torch"]`
- for now, let's not implement GENOT for torch; let's just go with `OTFlowMatching` to keep things simple
- the only problem that arises is that `prepare_model` takes backend-specific arguments, namely `match_fn`, `optimizer`, and `vf_act_fn`.
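A minimal sketch of how a single `CellFlow` class could dispatch on `backend` while refusing GENOT for torch. It assumes a `solver` argument on `prepare_model` and a registry of dotted module paths mirroring the `src/backends/<backend>` layout; the registry, the paths, the `"otfm"`/`"genot"` keys, and the `"jax"` default are all hypothetical, not the existing CellFlow API.

```python
from typing import Dict, Literal, Optional, Tuple

Backend = Literal["jax", "torch"]

# Hypothetical registry: (backend, solver) -> module that would hold the
# implementation under src/backends/<backend>/.
# GENOT is intentionally not registered for torch yet.
_SOLVER_MODULES: Dict[Tuple[str, str], str] = {
    ("jax", "otfm"): "backends.jax.otfm",
    ("jax", "genot"): "backends.jax.genot",
    ("torch", "otfm"): "backends.torch.otfm",
}


class CellFlow:
    """One user-facing class shared by both backends (sketch)."""

    def __init__(self) -> None:
        self.backend: Optional[Backend] = None
        self.solver_module: Optional[str] = None

    def prepare_model(self, solver: str = "otfm", backend: Backend = "jax") -> None:
        # Reject unknown backends before looking up the solver.
        if backend not in ("jax", "torch"):
            raise ValueError(f"Unknown backend {backend!r}; expected 'jax' or 'torch'.")
        key = (backend, solver)
        if key not in _SOLVER_MODULES:
            # e.g. GENOT with backend="torch" lands here for now.
            raise NotImplementedError(
                f"Solver {solver!r} is not implemented for the {backend!r} backend."
            )
        self.backend = backend
        self.solver_module = _SOLVER_MODULES[key]
```

With this shape, `CellFlow().prepare_model(solver="otfm", backend="torch")` selects the torch flow-matching module, while `solver="genot"` with `backend="torch"` fails loudly instead of silently falling back to jax. Adding GENOT for torch later would only mean adding one registry entry.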
For the last point, I see the following as the best solution: allow passing both jax and torch instances, set these arguments to `None` by default, describe the defaults in the docs, and, where they are still `None`, instantiate the defaults later in the solver classes.