Skip to content

Commit e236a65

Browse files
committed
fix: refinement type assert cast bug
1 parent 1762588 commit e236a65

File tree

7 files changed

+72
-8
lines changed

7 files changed

+72
-8
lines changed

crates/erg_common/triple.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ impl<T> Triple<T, T> {
141141
Triple::Ok(a) | Triple::Err(a) => Some(a),
142142
}
143143
}
144+
145+
pub fn merge_or(self, default: T) -> T {
146+
match self {
147+
Triple::None => default,
148+
Triple::Ok(ok) => ok,
149+
Triple::Err(err) => err,
150+
}
151+
}
144152
}
145153

146154
impl<T, E: std::error::Error> Triple<T, E> {

crates/erg_compiler/context/compare.rs

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ impl Context {
119119
self.supertype_of(lhs, rhs) || self.subtype_of(lhs, rhs)
120120
}
121121

122+
pub(crate) fn _related_tp(&self, lhs: &TyParam, rhs: &TyParam) -> bool {
123+
self._subtype_of_tp(lhs, rhs, Variance::Covariant)
124+
|| self.supertype_of_tp(lhs, rhs, Variance::Covariant)
125+
}
126+
122127
/// lhs :> rhs ?
123128
pub(crate) fn supertype_of(&self, lhs: &Type, rhs: &Type) -> bool {
124129
let res = match Self::cheap_supertype_of(lhs, rhs) {
@@ -1118,6 +1123,10 @@ impl Context {
11181123
}
11191124
}
11201125

1126+
pub(crate) fn covariant_supertype_of_tp(&self, lp: &TyParam, rp: &TyParam) -> bool {
1127+
self.supertype_of_tp(lp, rp, Variance::Covariant)
1128+
}
1129+
11211130
/// lhs <: rhs?
11221131
pub(crate) fn structural_subtype_of(&self, lhs: &Type, rhs: &Type) -> bool {
11231132
self.structural_supertype_of(rhs, lhs)
@@ -1282,6 +1291,7 @@ impl Context {
12821291
/// union(Array(Int, 2), Array(Str, 3)) == Array(Int, 2) or Array(Int, 3)
12831292
/// union({ .a = Int }, { .a = Str }) == { .a = Int or Str }
12841293
/// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int }
1294+
/// union((A and B) or C) == (A or C) and (B or C)
12851295
/// ```
12861296
pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type {
12871297
if lhs == rhs {
@@ -1345,6 +1355,16 @@ impl Context {
13451355
_ => self.simple_union(lhs, rhs),
13461356
},
13471357
(other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other),
1358+
// (A and B) or C ==> (A or C) and (B or C)
1359+
(and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => {
1360+
let ands = and_t.ands();
1361+
let mut t = Type::Obj;
1362+
for branch in ands.iter() {
1363+
let union = self.union(branch, other);
1364+
t = and(t, union);
1365+
}
1366+
t
1367+
}
13481368
(t, Type::Never) | (Type::Never, t) => t.clone(),
13491369
// Array({1, 2}, 2), Array({3, 4}, 2) ==> Array({1, 2, 3, 4}, 2)
13501370
(
@@ -1497,12 +1517,6 @@ impl Context {
14971517
self.intersection(&fv.crack(), other)
14981518
}
14991519
(Refinement(l), Refinement(r)) => Type::Refinement(self.intersection_refinement(l, r)),
1500-
(other, Refinement(refine)) | (Refinement(refine), other) => {
1501-
let other = other.clone().into_refinement();
1502-
let intersec = self.intersection_refinement(&other, refine);
1503-
self.try_squash_refinement(intersec)
1504-
.unwrap_or_else(Type::Refinement)
1505-
}
15061520
(Structural(l), Structural(r)) => self.intersection(l, r).structuralize(),
15071521
(Guard(l), Guard(r)) => {
15081522
if l.namespace == r.namespace && l.target == r.target {
@@ -1527,6 +1541,25 @@ impl Context {
15271541
(other, and @ And(_, _)) | (and @ And(_, _), other) => {
15281542
self.intersection_add(and, other)
15291543
}
1544+
// (A or B) and C == (A and C) or (B and C)
1545+
(or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => {
1546+
let ors = or_t.ors();
1547+
if ors.iter().any(|t| t.has_unbound_var()) {
1548+
return self.simple_intersection(lhs, rhs);
1549+
}
1550+
let mut t = Type::Never;
1551+
for branch in ors.iter() {
1552+
let isec = self.intersection(branch, other);
1553+
t = self.union(&t, &isec);
1554+
}
1555+
t
1556+
}
1557+
(other, Refinement(refine)) | (Refinement(refine), other) => {
1558+
let other = other.clone().into_refinement();
1559+
let intersec = self.intersection_refinement(&other, refine);
1560+
self.try_squash_refinement(intersec)
1561+
.unwrap_or_else(Type::Refinement)
1562+
}
15301563
// overloading
15311564
(l, r) if l.is_subr() && r.is_subr() => and(lhs.clone(), rhs.clone()),
15321565
_ => self.simple_intersection(lhs, rhs),

crates/erg_compiler/context/eval.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ impl<'c> Substituter<'c> {
155155
/// e.g.
156156
/// ```erg
157157
/// qt: Array(T, N), st: Array(Int, 3)
158+
/// qt: T or NoneType, st: NoneType or Int (T == Int)
158159
/// ```
159160
/// invalid (no effect):
160161
/// ```erg
@@ -167,8 +168,15 @@ impl<'c> Substituter<'c> {
167168
st: &Type,
168169
) -> EvalResult<Option<Self>> {
169170
let qtps = qt.typarams();
170-
let stps = st.typarams();
171-
if qt.qual_name() != st.qual_name() || qtps.len() != stps.len() {
171+
let mut stps = st.typarams();
172+
// Or, And are commutative, choose fitting order
173+
if qt.qual_name() == st.qual_name() && (st.qual_name() == "Or" || st.qual_name() == "And") {
174+
if ctx.covariant_supertype_of_tp(&qtps[0], &stps[1])
175+
&& ctx.covariant_supertype_of_tp(&qtps[1], &stps[0])
176+
{
177+
stps.swap(0, 1);
178+
}
179+
} else if qt.qual_name() != st.qual_name() || qtps.len() != stps.len() {
172180
if let Some(inner) = st.ref_inner().or_else(|| st.ref_mut_inner()) {
173181
return Self::substitute_typarams(ctx, qt, &inner);
174182
} else if let Some(sub) = st.get_sub() {

crates/erg_compiler/context/generalize.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,9 @@ impl Context {
11801180
super_exists
11811181
}
11821182

1183+
/// Check if a trait implementation exists for a polymorphic class.
1184+
/// This is needed because the trait implementation spec can contain projection types.
1185+
/// e.g. `Tuple(Ts) <: Container(Ts.union())`
11831186
fn poly_class_trait_impl_exists(&self, class: &Type, trait_: &Type) -> bool {
11841187
let class_hash = get_hash(&class);
11851188
let trait_hash = get_hash(&trait_);

crates/erg_compiler/context/inquire.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3657,6 +3657,7 @@ impl Context {
36573657
/// ```erg
36583658
/// recover_typarams(Int, Nat) == Nat
36593659
/// recover_typarams(Array!(Int, _), Array(Nat, 2)) == Array!(Nat, 2)
3660+
/// recover_typarams(Str or NoneType, {"a", "b"}) == {"a", "b"}
36603661
/// ```
36613662
/// ```erg
36623663
/// # REVIEW: should be?

crates/erg_compiler/tests/infer.er

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ c_new x, y = C.new x, y
3030
C = Class Int
3131
C.
3232
new x, y = Self x + y
33+
34+
val!() =
35+
for! [{ "a": "b" }], (pkg as {Str: Str}) =>
36+
x = pkg.get("a", "c")
37+
assert x in {"b"}
38+
val!::return x
39+
"d"
40+
val = val!()

crates/erg_compiler/tests/test.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ fn _test_infer_types() -> Result<(), ()> {
8787
let c_new_t = func2(add_r, r, c.clone()).quantify();
8888
module.context.assert_var_type("c_new", &c_new_t)?;
8989
module.context.assert_attr_type(&c, "new", &c_new_t)?;
90+
module
91+
.context
92+
.assert_var_type("val", &v_enum(set! { "b".into(), "d".into() }))?;
9093
Ok(())
9194
}
9295

0 commit comments

Comments
 (0)