diff --git a/__tests__/strict-mode.js b/__tests__/strict-mode.js new file mode 100644 index 00000000..7328a35e --- /dev/null +++ b/__tests__/strict-mode.js @@ -0,0 +1,132 @@ +"use strict" +import produce, { + setStrictMode, + unsafe, + immerable, + enableMapSet +} from "../src/immer" + +enableMapSet() + +describe("Strict Mode", () => { + class Foo {} + + describe("by default", () => { + it("should not throw an error when accessing a non-draftable class instance", () => { + expect.hasAssertions() + + produce({instance: new Foo()}, draft => { + expect(() => { + draft.instance.value = 5 + }).not.toThrow() + }) + }) + }) + + afterAll(() => { + setStrictMode(false) + }) + + describe("when disabled", () => { + beforeEach(() => { + setStrictMode(false) + }) + + it("should allow accessing a non-draftable class instance", () => { + expect.hasAssertions() + + produce({instance: new Foo()}, draft => { + expect(() => { + draft.instance.value = 5 + }).not.toThrow() + }) + }) + + it("should not throw errors when using the `unsafe` function", () => { + expect.hasAssertions() + + produce({instance: new Foo()}, draft => { + unsafe(() => { + expect(() => { + draft.instance.value = 5 + }).not.toThrow() + }) + }) + }) + }) + + describe("when enabled", () => { + beforeEach(() => { + setStrictMode(true) + }) + + it("should throw an error when accessing a non-draftable class instance", () => { + expect.hasAssertions() + + produce({instance: new Foo()}, draft => { + expect(() => draft.instance).toThrow() + }) + }) + + it("should allow accessing a non-draftable using the `unsafe` function", () => { + expect.hasAssertions() + + produce({instance: new Foo()}, draft => { + unsafe(() => { + expect(() => { + draft.instance.value = 5 + }).not.toThrow() + }) + }) + }) + + it("should require using unsafe for non-draftables in a different scope", () => { + expect.assertions(2) + + produce({instance: new Foo()}, () => { + unsafe(() => { + produce({nested: new Foo()}, nestedDraft => { + expect(() => nestedDraft.nested).toThrow() + + unsafe(() => { + expect(() => nestedDraft.nested).not.toThrow() + }) + }) + }) + }) + }) + + describe("with an immerable class", () => { + beforeAll(() => { + Foo[immerable] = true + }) + + afterAll(() => { + Foo[immerable] = true + }) + + it("should allow accessing the class instance", () => { + expect.hasAssertions() + + produce({instance: new Foo()}, draft => { + expect(() => { + draft.instance.value = 5 + }).not.toThrow() + }) + }) + }) + + it("should allow accessing draftable properties", () => { + expect(() => + produce({arr: [], obj: {}, map: new Map(), set: new Set()}, draft => { + draft.arr.push(1) + draft.arr[0] = 1 + draft.obj.foo = 5 + draft.obj.hasOwnProperty("abc") + draft.map.set("foo", 5) + draft.set.add("foo") + }) + ).not.toThrow() + }) + }) +}) diff --git a/src/core/immerClass.ts b/src/core/immerClass.ts index f98d3756..c80c7ba5 100644 --- a/src/core/immerClass.ts +++ b/src/core/immerClass.ts @@ -37,11 +37,19 @@ export class Immer implements ProducersFns { autoFreeze_: boolean = true - constructor(config?: {useProxies?: boolean; autoFreeze?: boolean}) { + strictModeEnabled_: boolean = false + + constructor(config?: { + useProxies?: boolean + autoFreeze?: boolean + strictMode?: boolean + }) { if (typeof config?.useProxies === "boolean") this.setUseProxies(config!.useProxies) if (typeof config?.autoFreeze === "boolean") this.setAutoFreeze(config!.autoFreeze) + if (typeof config?.strictMode === "boolean") + this.setStrictMode(config!.strictMode) this.produce = this.produce.bind(this) this.produceWithPatches = this.produceWithPatches.bind(this) } @@ -183,6 +191,23 @@ export class Immer implements ProducersFns { this.useProxies_ = value } + /** + * Pass true to throw errors when attempting to access a non-draftable reference. + * + * By default, strict mode is disabled. + */ + setStrictMode(value: boolean) { + this.strictModeEnabled_ = value + } + + unsafe(callback: () => void) { + const scope = getCurrentScope() + + scope.unsafeNonDraftabledAllowed_ = true + callback() + scope.unsafeNonDraftabledAllowed_ = false + } + applyPatches(base: Objectish, patches: Patch[]) { // If a patch replaces the entire state, take that replacement as base // before applying patches diff --git a/src/core/proxy.ts b/src/core/proxy.ts index 2469d73e..319af4fb 100644 --- a/src/core/proxy.ts +++ b/src/core/proxy.ts @@ -108,7 +108,19 @@ export const objectTraps: ProxyHandler = { return readPropFromProto(state, source, prop) } const value = source[prop] - if (state.finalized_ || !isDraftable(value)) { + if (state.finalized_) { + return value + } + if (!isDraftable(value)) { + if ( + state.scope_.immer_.strictModeEnabled_ && + !state.scope_.unsafeNonDraftabledAllowed_ && + typeof value === "object" && + value !== null + ) { + die(24) + } + return value } // Check for existing draft in modified state. diff --git a/src/core/scope.ts b/src/core/scope.ts index 4505ea63..90bbe893 100644 --- a/src/core/scope.ts +++ b/src/core/scope.ts @@ -22,6 +22,7 @@ export interface ImmerScope { patchListener_?: PatchListener immer_: Immer unfinalizedDrafts_: number + unsafeNonDraftabledAllowed_: boolean } let currentScope: ImmerScope | undefined @@ -42,7 +43,8 @@ function createScope( // Whenever the modified draft contains a draft from another scope, we // need to prevent auto-freezing so the unowned draft can be finalized. canAutoFreeze_: true, - unfinalizedDrafts_: 0 + unfinalizedDrafts_: 0, + unsafeNonDraftabledAllowed_: false } } diff --git a/src/immer.ts b/src/immer.ts index 53455494..e3e6e1ab 100644 --- a/src/immer.ts +++ b/src/immer.ts @@ -67,6 +67,18 @@ export const setAutoFreeze = immer.setAutoFreeze.bind(immer) */ export const setUseProxies = immer.setUseProxies.bind(immer) +/** + * Pass true to throw errors when attempting to access a non-draftable reference. + * + * By default, strict mode is disabled. + */ +export const setStrictMode = immer.setStrictMode.bind(immer) + +/** + * Allow accessing non-draftable references in strict mode inside the callback. + */ +export const unsafe = immer.unsafe.bind(immer) + /** * Apply an array of Immer patches to the first argument. * diff --git a/src/plugins/es5.ts b/src/plugins/es5.ts index d0c3e02d..45f6ec7f 100644 --- a/src/plugins/es5.ts +++ b/src/plugins/es5.ts @@ -31,7 +31,9 @@ export function enableES5() { ) { if (!isReplaced) { if (scope.patches_) { + scope.unsafeNonDraftabledAllowed_ = true markChangesRecursively(scope.drafts_![0]) + scope.unsafeNonDraftabledAllowed_ = false } // This is faster when we don't care about which attributes changed. markChangesSweep(scope.drafts_) diff --git a/src/types/index.js.flow b/src/types/index.js.flow index 1be017a2..c6e742f7 100644 --- a/src/types/index.js.flow +++ b/src/types/index.js.flow @@ -84,6 +84,18 @@ declare export function setAutoFreeze(autoFreeze: boolean): void */ declare export function setUseProxies(useProxies: boolean): void +/** + * Pass true to throw errors when attempting to access a non-draftable reference. + * + * By default, strict mode is disabled. + */ +declare export function setStrictMode(strictMode: boolean): void + +/** + * Allow accessing non-draftable references in strict mode inside the callback. + */ +declare export function unsafe(callback: () => void): void + declare export function applyPatches(state: S, patches: Patch[]): S declare export function original(value: S): S diff --git a/src/utils/errors.ts b/src/utils/errors.ts index b73e2aac..7d158bb0 100644 --- a/src/utils/errors.ts +++ b/src/utils/errors.ts @@ -38,7 +38,8 @@ const errors = { }, 23(thing: string) { return `'original' expects a draft, got: ${thing}` - } + }, + 24: "Cannot get a non-draftable reference in strict mode. Use the `unsafe` function, add the `immerable` symbol, or disable strict mode" } as const export function die(error: keyof typeof errors, ...args: any[]): never {