Skip to content
16 changes: 7 additions & 9 deletions monarch_extension/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use hyperactor::ActorId;
use monarch_hyperactor::ndslice::PySlice;
use monarch_hyperactor::proc::PyActorId;
use monarch_messages::controller::Seq;
use monarch_messages::wire_value::func_call_args_to_wire_values;
use monarch_messages::worker;
use monarch_messages::worker::ArgsKwargs;
use monarch_messages::worker::CallFunctionParams;
use monarch_messages::worker::Cloudpickle;
use monarch_messages::worker::Factory;
Expand Down Expand Up @@ -220,17 +220,16 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
});
m.insert(key("CallFunction"), |p| {
let function = p.parseFunction("function")?;
let args = p.parse("args")?;
let kwargs = p.parse("kwargs")?;
let args: Bound<'_, PyTuple> = p.parse("args")?;
let kwargs: Bound<'_, PyDict> = p.parse("kwargs")?;

let (args, kwargs) = func_call_args_to_wire_values(Some(&function), &args, &kwargs)?;
let args_kwargs = ArgsKwargs::from_python(args.into_any(), kwargs.into_any())?;
Ok(WorkerMessage::CallFunction(CallFunctionParams {
seq: p.parseSeq("ident")?,
results: p.parseFlatReferences("result")?,
mutates: p.parseRefList("mutates")?,
function,
args,
kwargs,
args_kwargs,
stream: p.parseStreamRef("stream")?,
remote_process_groups: p.parseRefList("remote_process_groups")?,
}))
Expand Down Expand Up @@ -340,14 +339,13 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
"SendValue with no function must have exactly one argument and no keyword arguments",
));
}
let (args, kwargs) = func_call_args_to_wire_values(function.as_ref(), &args, &kwargs)?;
let args_kwargs = ArgsKwargs::from_python(args.into_any(), kwargs.into_any())?;
Ok(WorkerMessage::SendValue {
seq: p.parseSeq("ident")?,
destination: p.parseOptionalRef("destination")?,
mutates: p.parseRefList("mutates")?,
function,
args,
kwargs,
args_kwargs,
stream: p.parseStreamRef("stream")?,
})
});
Expand Down
266 changes: 7 additions & 259 deletions monarch_messages/src/wire_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,22 @@
* LICENSE file in the root directory of this source tree.
*/

use std::collections::HashMap;

use derive_more::From;
use derive_more::TryInto;
use enum_as_inner::EnumAsInner;
use hyperactor::Named;
use monarch_types::PickledPyObject;
use monarch_types::TryIntoPyObjectUnsafe;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBool;
use pyo3::types::PyDict;
use pyo3::types::PyFloat;
use pyo3::types::PyList;
use pyo3::types::PyNone;
use pyo3::types::PyString;
use pyo3::types::PyTuple;
use serde::Deserialize;
use serde::Serialize;
use torch_sys::Device;
use torch_sys::Layout;
use torch_sys::MemoryFormat;
use torch_sys::OpaqueIValue;
use torch_sys::ScalarType;

use crate::worker::Ref;
use crate::worker::ResolvableFunction;

/// A value used as an input to CallFunction.
// TODO, this is basically the same as RValue, but with TensorIndices swapped
Expand All @@ -59,81 +47,20 @@ pub enum WireValue {
// empty enum variants.
None(()),
PyObject(PickledPyObject),
// It is ok to just have IValue without an alias tracking cell as we just use
// WireValue as a way to serialize and send args to workers. We dont mutate the
// IValue and use the opaque wrapper to make accessing the IValue directly
// an unsafe op.
IValue(torch_sys::OpaqueIValue),
}

impl FromPyObject<'_> for WireValue {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(ref_) = Ref::from_py_object(obj) {
Ok(WireValue::Ref(ref_))
} else if let Ok(list) = obj.downcast::<PyList>() {
let len = list.len();
if len == 0 {
// TODO: This is done for now as this seems to be the most common case for empty lists
// in torch ops but we should use the op schema to do this correctly.
return Ok(WireValue::IntList(vec![]));
}

// SAFETY: We know it is within bounds
let item = unsafe { list.get_item_unchecked(0) };
let len = list.len();
if let Ok(int) = item.extract::<i64>() {
let mut int_list = Vec::with_capacity(len);
int_list.push(int);
for item in list.iter().skip(1) {
int_list.push(item.extract::<i64>().map_err(|_| {
PyValueError::new_err(format!(
"Expected homogeneous list of ints got: {:?}",
list
))
})?);
}
return Ok(WireValue::IntList(int_list));
}
if let Ok(ref_) = Ref::from_py_object(&item) {
let mut ref_list = Vec::with_capacity(len);
ref_list.push(ref_);
for item in list.iter().skip(1) {
ref_list.push(Ref::from_py_object(&item).map_err(|_| {
PyValueError::new_err(format!(
"Expected homogeneous list of ints got: {:?}",
list
))
})?);
}
return Ok(WireValue::RefList(ref_list));
}
Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
} else if obj.is_none() {
Ok(WireValue::None(()))
} else if let Ok(bool_) = obj.downcast::<PyBool>() {
Ok(WireValue::Bool(bool_.is_true()))
} else if let Ok(int) = obj.extract::<i64>() {
Ok(WireValue::Int(int))
} else if let Ok(double) = obj.downcast::<PyFloat>() {
Ok(WireValue::Double(double.value()))
} else if let Ok(string) = obj.downcast::<PyString>() {
Ok(WireValue::String(string.to_str()?.to_string()))
} else if let Ok(device) = obj.extract::<Device>() {
Ok(WireValue::Device(device))
} else if let Ok(layout) = obj.extract::<Layout>() {
Ok(WireValue::Layout(layout))
} else if let Ok(scalar_type) = obj.extract::<ScalarType>() {
Ok(WireValue::ScalarType(scalar_type))
} else if let Ok(memory_format) = obj.extract::<MemoryFormat>() {
Ok(WireValue::MemoryFormat(memory_format))
} else {
Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
}
Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
}
}

impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue {
unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
impl<'py> IntoPyObject<'py> for WireValue {
type Target = PyAny;
type Output = Bound<'py, PyAny>;
type Error = PyErr;

fn into_pyobject(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
match self {
WireValue::Ref(ref_) => ref_.into_bound_py_any(py),
WireValue::RefList(ref_list) => ref_list.clone().into_bound_py_any(py),
Expand All @@ -148,191 +75,12 @@ impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue {
WireValue::MemoryFormat(val) => val.into_bound_py_any(py),
WireValue::None(()) => PyNone::get(py).into_bound_py_any(py),
WireValue::PyObject(val) => val.unpickle(py),
// SAFETY: WireValue is only used for serde between client and worker.
// This function is used to access the args / kwargs of a function call
// on the client side only.
WireValue::IValue(val) => unsafe { val.try_to_object_unsafe(py) },
}
}
}

impl<'py> IntoPyObject<'py> for WireValue {
type Target = PyAny;
type Output = Bound<'py, PyAny>;
type Error = PyErr;

fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
// SAFETY: We are going to remove this safe/unsafe distinction
unsafe { self.try_to_object_unsafe(py) }
}
}

impl From<PyObject> for WireValue {
fn from(obj: PyObject) -> Self {
Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap()))
}
}

pub fn func_call_args_to_wire_values(
_func: Option<&ResolvableFunction>,
args: &Bound<'_, PyTuple>,
kwargs: &Bound<'_, PyDict>,
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
python_func_args_to_wire_value(args, kwargs)
}

fn python_func_args_to_wire_value(
args: &Bound<'_, PyTuple>,
kwargs: &Bound<'_, PyDict>,
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
let args = args
.iter()
.map(|arg| Ok(WireValue::PyObject(PickledPyObject::pickle(&arg)?)))
.collect::<PyResult<_>>()?;
let kwargs = kwargs
.iter()
.map(|(k, v)| {
Ok((
k.extract::<String>()?,
WireValue::PyObject(PickledPyObject::pickle(&v)?),
))
})
.collect::<Result<HashMap<_, _>, PyErr>>()?;
Ok((args, kwargs))
}

#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;

use anyhow::Result;
use anyhow::bail;
use paste::paste;
use pyo3::Python;
use pyo3::ffi::c_str;
use pyo3::types::PyDict;
use torch_sys::DeviceType;
use torch_sys::ScalarType;

use super::*;
use crate::worker::Ref;

const MOCK_REFERNCABLE_MODULE: &std::ffi::CStr = c_str!(
r#"
class Referencable:
def __init__(self, ref: int):
self.ref = ref

def __monarch_ref__(self):
return self.ref
"#
);

fn setup() -> Result<()> {
pyo3::prepare_freethreaded_python();
// We need to load torch to initialize some internal structures used by
// the FFI funcs we use to convert ivalues to/from py objects.
Python::with_gil(|py| py.run(c_str!("import torch"), None, None))?;
Ok(())
}

fn create_py_object() -> PyObject {
pyo3::prepare_freethreaded_python();
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item("foo", "bar").unwrap();
dict.into_any().clone().unbind()
})
}

macro_rules! generate_wire_value_from_py_tests {
($($kind:ident, $input:expr);* $(;)?) => {
paste! {
$(
#[test]
fn [<test_wire_value_from_py_$kind:snake:lower>]() -> Result<()> {
setup()?;
Python::with_gil(|py| {
let actual = $input.into_pyobject(py)?.extract::<WireValue>()?;
assert_matches!(actual, WireValue::$kind(_));
anyhow::Ok(())
})
}
)*

#[test]
fn test_wire_value_from_py_none() -> Result<()> {
setup()?;
Python::with_gil(|py| {
let obj = PyNone::get(py).into_pyobject(py)?;
let actual = obj.extract::<WireValue>()?;
assert_matches!(actual, WireValue::None(_));
anyhow::Ok(())
})
}

#[test]
fn test_wire_value_from_py_empty_list() -> Result<()> {
setup()?;
Python::with_gil(|py| {
let obj: PyObject = PyList::empty(py).into_any().unbind();
let actual = obj.extract::<WireValue>(py)?;
match actual {
WireValue::IntList(list) if list.len() == 0 => (),
_ => bail!("Expected empty list to be converted to empty int list"),
}
anyhow::Ok(())
})
}

#[test]
fn test_wire_value_from_py_referencable_class() -> Result<()> {
setup()?;
Python::with_gil(|py| {
let referencable = PyModule::from_code(
py,
MOCK_REFERNCABLE_MODULE,
c_str!("referencable.py"),
c_str!("referencable"),
)?;
let ref_ = referencable.getattr("Referencable")?.call1((1,))?.unbind();
let actual = ref_.extract::<WireValue>(py)?;
assert_matches!(actual, WireValue::Ref(Ref { id: 1 }));
anyhow::Ok(())
})
}

#[test]
fn test_wire_value_from_py_roundtrip_was_exhaustive() {
let val = WireValue::Int(0);
match val {
$(WireValue::$kind(_) => (),)*
WireValue::None(_) => (),
// Can't test from py here as PyObject behaves as catch all for conversion from PY.
// We will manually convert torch ops args to IValue respecting the schema so its
// not super important to have this.
WireValue::IValue(_) => (),
}
}
}
}
}

// Generate exhaustive roundtrip tests for all IValue kind.
// If you got a "non-exhaustive patterns" error here, you need to add a new
// test entry for your IValue kind!
generate_wire_value_from_py_tests! {
Bool, false;
Double, 1.23f64;
Int, 123i64;
IntList, vec![1i64];
Ref, Ref::from(1);
RefList, vec![Ref::from(1), Ref::from(2)];
String, "foobar".to_owned();
Device, Device::new(DeviceType::CPU);
Layout, Layout(2);
ScalarType, ScalarType(3);
MemoryFormat, MemoryFormat(1);
PyObject, create_py_object();
}
}
Loading