diff --git a/AGENTS.md b/AGENTS.md index 6bd69686..01369f68 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1112,17 +1112,10 @@ val t1 = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) val t2 = Tensor1(Axis[B]).fromArray(Array(3.0f, 4.0f, 5.0f)) val wrong = t1 + t2 // Different labels AND different sizes // error: -// -// A tuple of axis labels Tuple1[MdocApp11.this.A | MdocApp11.this.B] was given or inferred that does not have a valid Labels instance. -// -// Ensure that all of the types in the tuple have a 'derives Label' clause. -// . -// I found: -// -// dimwit.tensor.Labels.concat[head, tail]( -// /* missing */summon[dimwit.tensor.Label[head]], ???) -// -// But no implicit values were found that match type dimwit.tensor.Label[head]. +// Found: (MdocApp11.this.t2 : +// dimwit.tensor.Tensor1[MdocApp11.this.B, dimwit.tensor.DType.Float32]) +// Required: dimwit.tensor.Tensor[Tuple1[MdocApp11.this.A], +// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)] ``` ```scala diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index 5d09abe1..0ce086d9 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -111,21 +111,21 @@ object TensorOps: // IsNumber operations (IsFloat or IsInt) // --------------------------------------------------------- - def add[T <: Tuple: Labels, T1 <: T, T2 <: T, V: IsNumber](t1: Tensor[T1, V], t2: Tensor[T2, V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue)) + def add[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue)) def addScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue)) def negate[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.negative(t.jaxValue)) - def subtract[T <: Tuple: Labels, T1 <: T, T2 <: T, V: IsNumber](t1: Tensor[T1, V], t2: Tensor[T2, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t1.jaxValue, t2.jaxValue)) + def subtract[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t1.jaxValue, t2.jaxValue)) def subtractScalar[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, t2.jaxValue)) - def multiply[T <: Tuple: Labels, T1 <: T, T2 <: T, V: IsNumber](t1: Tensor[T1, V], t2: Tensor[T2, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) + def multiply[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) def multiplyScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) - extension [T <: Tuple: Labels, T1 <: T, T2 <: T, V: IsNumber](t: Tensor[T1, V]) + extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) - def +(other: Tensor[T2, V]): Tensor[T, V] = add(t, other) - def -(other: Tensor[T2, V]): Tensor[T, V] = subtract(t, other) - def *(other: Tensor[T2, V]): Tensor[T, V] = multiply(t, other) + def +(other: Tensor[T, V]): Tensor[T, V] = add(t, other) + def -(other: Tensor[T, V]): Tensor[T, V] = subtract(t, other) + def *(other: Tensor[T, V]): Tensor[T, V] = multiply(t, other) extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) diff --git a/docs/quickstart.md b/docs/quickstart.md index ba901c2e..c75886c4 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -207,25 +207,12 @@ val tensor1 = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2)).fill(1.0f) val tensor3 = Tensor(Shape(Axis[A] -> 3, Axis[C] -> 2)).fill(2.0f) tensor1 + tensor3 // error: -// -// A tuple of axis labels MdocApp1.this.A *: (MdocApp1.this.B *: EmptyTuple | MdocApp1.this.C *: -// EmptyTuple) was given or inferred that does not have a valid Labels instance. -// -// Ensure that all of the types in the tuple have a 'derives Label' clause. -// . -// I found: -// -// dimwit.tensor.Labels.given_Labels_A_B[A², B²]( -// dimwit.tensor.Labels.lift[MdocApp1.this.A](this.A.derived$Label), -// dimwit.tensor.Labels.lift[B²](/* missing */summon[dimwit.tensor.Label[B²]])) -// -// But no implicit values were found that match type dimwit.tensor.Label[B²] -// -// where: A is a trait in class MdocApp1 -// A² is a type variable -// B is a trait in class MdocApp1 -// B² is a type variable -// . +// Found: (MdocApp1.this.tensor3 : +// dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.C), +// dimwit.tensor.DType.Float32] +// ) +// Required: dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B), +// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)] // tensor1 + tensor3 // ^^^^^^^ // error: @@ -249,19 +236,12 @@ trait C derives Label val tensor3 = Tensor(Shape(Axis[A] -> 3)).fill(1.0f) tensor1 + tensor3 // error: -// -// A tuple of axis labels MdocApp1.this.A *: (MdocApp1.this.B *: EmptyTuple | EmptyTuple) was given or inferred that does not have a valid Labels instance. -// -// Ensure that all of the types in the tuple have a 'derives Label' clause. -// . -// I found: -// -// dimwit.tensor.Labels.concat[head, tail](this.A.derived$Label, -// dimwit.tensor.Labels.lift[tail](/* missing */summon[dimwit.tensor.Label[tail]] -// ) -// ) -// -// But no implicit values were found that match type dimwit.tensor.Label[tail]. +// Found: (MdocApp1.this.tensor3 : +// dimwit.tensor.Tensor[MdocApp1.this.A *: EmptyTuple, +// dimwit.tensor.DType.Float32] +// ) +// Required: dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B), +// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)] // tensor1 + tensor3 // ^^^^^^^ ```