Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -1112,17 +1112,10 @@ val t1 = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f))
val t2 = Tensor1(Axis[B]).fromArray(Array(3.0f, 4.0f, 5.0f))
val wrong = t1 + t2 // Different labels AND different sizes
// error:
//
// A tuple of axis labels Tuple1[MdocApp11.this.A | MdocApp11.this.B] was given or inferred that does not have a valid Labels instance.
//
// Ensure that all of the types in the tuple have a 'derives Label' clause.
// .
// I found:
//
// dimwit.tensor.Labels.concat[head, tail](
// /* missing */summon[dimwit.tensor.Label[head]], ???)
//
// But no implicit values were found that match type dimwit.tensor.Label[head].
// Found: (MdocApp11.this.t2 :
// dimwit.tensor.Tensor1[MdocApp11.this.B, dimwit.tensor.DType.Float32])
// Required: dimwit.tensor.Tensor[Tuple1[MdocApp11.this.A],
// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)]
```

```scala
Expand Down
14 changes: 7 additions & 7 deletions core/src/main/scala/dimwit/tensor/TensorOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,21 @@ object TensorOps:
// IsNumber operations (IsFloat or IsInt)
// ---------------------------------------------------------

def add[T <: Tuple: Labels, T1 <: T, T2 <: T, V: IsNumber](t1: Tensor[T1, V], t2: Tensor[T2, V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue))
def add[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue))
def addScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue))

def negate[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.negative(t.jaxValue))
def subtract[T <: Tuple: Labels, T1 <: T, T2 <: T, V: IsNumber](t1: Tensor[T1, V], t2: Tensor[T2, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t1.jaxValue, t2.jaxValue))
def subtract[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t1.jaxValue, t2.jaxValue))
def subtractScalar[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, t2.jaxValue))

def multiply[T <: Tuple: Labels, T1 <: T, T2 <: T, V: IsNumber](t1: Tensor[T1, V], t2: Tensor[T2, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue))
def multiply[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue))
def multiplyScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue))

extension [T <: Tuple: Labels, T1 <: T, T2 <: T, V: IsNumber](t: Tensor[T1, V])
extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])

def +(other: Tensor[T2, V]): Tensor[T, V] = add(t, other)
def -(other: Tensor[T2, V]): Tensor[T, V] = subtract(t, other)
def *(other: Tensor[T2, V]): Tensor[T, V] = multiply(t, other)
def +(other: Tensor[T, V]): Tensor[T, V] = add(t, other)
def -(other: Tensor[T, V]): Tensor[T, V] = subtract(t, other)
def *(other: Tensor[T, V]): Tensor[T, V] = multiply(t, other)

extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])

Expand Down
44 changes: 12 additions & 32 deletions docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,25 +207,12 @@ val tensor1 = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2)).fill(1.0f)
val tensor3 = Tensor(Shape(Axis[A] -> 3, Axis[C] -> 2)).fill(2.0f)
tensor1 + tensor3
// error:
//
// A tuple of axis labels MdocApp1.this.A *: (MdocApp1.this.B *: EmptyTuple | MdocApp1.this.C *:
// EmptyTuple) was given or inferred that does not have a valid Labels instance.
//
// Ensure that all of the types in the tuple have a 'derives Label' clause.
// .
// I found:
//
// dimwit.tensor.Labels.given_Labels_A_B[A², B²](
// dimwit.tensor.Labels.lift[MdocApp1.this.A](this.A.derived$Label),
// dimwit.tensor.Labels.lift[B²](/* missing */summon[dimwit.tensor.Label[B²]]))
//
// But no implicit values were found that match type dimwit.tensor.Label[B²]
//
// where: A is a trait in class MdocApp1
// A² is a type variable
// B is a trait in class MdocApp1
// B² is a type variable
// .
// Found: (MdocApp1.this.tensor3 :
// dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.C),
// dimwit.tensor.DType.Float32]
// )
// Required: dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B),
// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)]
// tensor1 + tensor3
// ^^^^^^^
// error:
Expand All @@ -249,19 +236,12 @@ trait C derives Label
val tensor3 = Tensor(Shape(Axis[A] -> 3)).fill(1.0f)
tensor1 + tensor3
// error:
//
// A tuple of axis labels MdocApp1.this.A *: (MdocApp1.this.B *: EmptyTuple | EmptyTuple) was given or inferred that does not have a valid Labels instance.
//
// Ensure that all of the types in the tuple have a 'derives Label' clause.
// .
// I found:
//
// dimwit.tensor.Labels.concat[head, tail](this.A.derived$Label,
// dimwit.tensor.Labels.lift[tail](/* missing */summon[dimwit.tensor.Label[tail]]
// )
// )
//
// But no implicit values were found that match type dimwit.tensor.Label[tail].
// Found: (MdocApp1.this.tensor3 :
// dimwit.tensor.Tensor[MdocApp1.this.A *: EmptyTuple,
// dimwit.tensor.DType.Float32]
// )
// Required: dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B),
// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)]
// tensor1 + tensor3
// ^^^^^^^
```
Expand Down
Loading