Snapshots: allow snapshotting some user code (#4720)

This adds some basic ability to snapshot after executing user code. It is pretty
brittle right now:
1. It will crash if the user loads any binary extensions before taking the
snapshot
2. It doesn't track changes to the file system

Snapshots will probably have to be experimental for quite a while.

1. I think I have a pretty good solution for this, which I will work on in a
followup.

2. One possibility here is we could serialize the entire filesystem state into
the memory snapshot. This would be hard and make the snapshot big, but we
wouldn't have to load python_stdlib.zip when restoring from a snapshot so it
probably wouldn't increase the total download size by much...
This commit is contained in:
Hood Chatham 2024-05-10 18:15:17 -04:00 committed by GitHub
parent afe7215c06
commit a8021791a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 477 additions and 121 deletions

View File

@ -164,52 +164,53 @@ export MAIN_MODULE_LDFLAGS= $(LDFLAGS_BASE) \
-lsdl.js \
-sGL_WORKAROUND_SAFARI_GETCONTEXT_BUG=0
EXPORTS= _main\
\
,_free \
\
,_hiwire_new \
,_hiwire_intern \
,_hiwire_num_refs \
,_hiwire_get \
,_hiwire_incref \
,_hiwire_decref \
,_hiwire_pop \
,__hiwire_set \
,__hiwire_immortal_add \
,_jslib_init \
,_init_pyodide_proxy \
\
,_PyBuffer_Release \
,_Py_DecRef \
,_PyDict_New \
,_PyDict_SetItem \
,__PyErr_CheckSignals \
,_PyErr_CheckSignals \
,_PyErr_Clear \
,_PyErr_Occurred \
,_PyErr_Print \
,_PyErr_SetRaisedException \
,_PyErr_SetString \
,_PyEval_SaveThread \
,_PyEval_RestoreThread \
,_PyFloat_FromDouble \
,_PyGILState_Check \
,_Py_IncRef \
,_PyList_New \
,_PyList_SetItem \
,__PyLong_FromByteArray \
,_PyLong_FromDouble \
,_PyMem_Free \
,_PyObject_GetAIter \
,_PyObject_GetIter \
,_PyObject_Size \
,_PyRun_SimpleString \
,_PySet_Add \
,_PySet_New \
,__PyTraceback_Add \
,_PyUnicode_Data \
,_PyUnicode_New \
EXPORTS=_main \
,_free \
\
,_hiwire_new \
,_hiwire_intern \
,_hiwire_num_refs \
,_hiwire_get \
,_hiwire_incref \
,_hiwire_decref \
,_hiwire_pop \
,__hiwire_get \
,__hiwire_set \
,__hiwire_immortal_get \
,__hiwire_immortal_add \
,_jslib_init \
,_init_pyodide_proxy \
\
,_PyBuffer_Release \
,_Py_DecRef \
,_PyDict_New \
,_PyDict_SetItem \
,__PyErr_CheckSignals \
,_PyErr_CheckSignals \
,_PyErr_Clear \
,_PyErr_Occurred \
,_PyErr_Print \
,_PyErr_SetRaisedException \
,_PyErr_SetString \
,_PyEval_SaveThread \
,_PyEval_RestoreThread \
,_PyFloat_FromDouble \
,_PyGILState_Check \
,_Py_IncRef \
,_PyList_New \
,_PyList_SetItem \
,__PyLong_FromByteArray \
,_PyLong_FromDouble \
,_PyMem_Free \
,_PyObject_GetAIter \
,_PyObject_GetIter \
,_PyObject_Size \
,_PyRun_SimpleString \
,_PySet_Add \
,_PySet_New \
,__PyTraceback_Add \
,_PyUnicode_Data \
,_PyUnicode_New \
ifeq ($(DISABLE_DYLINK), 1)

View File

@ -65,6 +65,11 @@ EM_JS(void, set_pyodide_module, (JsVal mod), {
int
init_pyodide_proxy()
{
EM_ASM({
// sourmash needs open64 to mean the same thing as open.
// Emscripten 3.1.44 seems to have removed it??
wasmImports["open64"] = wasmImports["open"];
});
bool success = false;
// Enable JavaScript access to the _pyodide module.
PyObject* _pyodide = PyImport_ImportModule("_pyodide");
@ -83,12 +88,6 @@ EM_JS_DEPS(pyodide_core_deps, "stackAlloc,stackRestore,stackSave");
PyObject*
PyInit__pyodide_core(void)
{
EM_ASM({
// sourmash needs open64 to mean the same thing as open.
// Emscripten 3.1.44 seems to have removed it??
wasmImports["open64"] = wasmImports["open"];
});
bool success = false;
PyObject* _pyodide = NULL;
PyObject* core_module = NULL;

View File

@ -10,6 +10,12 @@ import { scheduleCallback } from "./scheduler";
import { TypedArray } from "./types";
import { IN_NODE, detectEnvironment } from "./environments";
import "./literal-map.js";
import {
makeGlobalsProxy,
SnapshotConfig,
syncUpSnapshotLoad1,
syncUpSnapshotLoad2,
} from "./snapshot";
// Exported for micropip
API.loadBinaryFile = loadBinaryFile;
@ -645,6 +651,15 @@ export class PyodideAPI {
API.debug_ffi = debug;
return orig;
}
static makeMemorySnapshot(): Uint8Array {
if (!API.config._makeSnapshot) {
throw new Error(
"Can only use pyodide.makeMemorySnapshot if the _makeSnapshot option is passed to loadPyodide",
);
}
return API.makeSnapshot();
}
}
/** @hidden */
@ -699,7 +714,7 @@ API.bootstrapFinalizedPromise = new Promise<void>(
(r) => (bootstrapFinalized = r),
);
function jsFinderHook(o: object) {
export function jsFinderHook(o: object) {
if ("__all__" in o) {
return;
}
@ -713,54 +728,6 @@ function jsFinderHook(o: object) {
});
}
/**
* Set up some of the JavaScript state that is normally set up by C initialization code. TODO:
* adjust C code to simplify.
*
* This is divided up into two parts: syncUpSnapshotLoad1 has to happen at the beginning of
* finalizeBootstrap before the public API is setup, syncUpSnapshotLoad2 happens near the end.
*
* This code is quite sensitive to the details of our setup, so it might break if we move stuff
* around far away in the code base. Ideally over time we can structure the code to make it less
* brittle.
*/
function syncUpSnapshotLoad1() {
// hiwire init puts a null at the beginning of both the mortal and immortal tables.
Module.__hiwire_set(0, null);
Module.__hiwire_immortal_add(null);
// Usually importing _pyodide_core would trigger jslib_init but we need to manually call it.
Module._jslib_init();
// Puts deduplication map into the immortal table.
// TODO: Add support for snapshots to hiwire and move this to a hiwire_snapshot_init function.
Module.__hiwire_immortal_add(new Map());
// An interned JS string.
// TODO: Better system for handling interned strings.
Module.__hiwire_immortal_add(
"This borrowed proxy was automatically destroyed at the end of a function call. Try using create_proxy or create_once_callable.",
);
// Set API._pyodide to a proxy of the _pyodide module.
// Normally called by import _pyodide.
Module._init_pyodide_proxy();
}
/**
* Fill in the JsRef table.
*/
function syncUpSnapshotLoad2() {
[
null,
jsFinderHook,
API.config.jsglobals,
API.public_api,
Module.API,
scheduleCallback,
Module.API,
{},
null,
null,
].forEach((v, idx) => Module.__hiwire_set(idx, v));
}
/**
* This function is called after the emscripten module is finished initializing,
* so eval_code is newly available.
@ -768,8 +735,10 @@ function syncUpSnapshotLoad2() {
* the core `pyodide` apis. (But package loading is not ready quite yet.)
* @private
*/
API.finalizeBootstrap = function (fromSnapshot?: boolean): PyodideInterface {
if (fromSnapshot) {
API.finalizeBootstrap = function (
snapshotConfig?: SnapshotConfig,
): PyodideInterface {
if (snapshotConfig) {
syncUpSnapshotLoad1();
}
let [err, captured_stderr] = API.rawRun("import _pyodide_core");
@ -802,11 +771,15 @@ API.finalizeBootstrap = function (fromSnapshot?: boolean): PyodideInterface {
// Set up key Javascript modules.
let importhook = API._pyodide._importhook;
let pyodide = makePublicAPI();
if (fromSnapshot) {
syncUpSnapshotLoad2();
if (API.config._makeSnapshot) {
API.config.jsglobals = makeGlobalsProxy(API.config.jsglobals);
}
const jsglobals = API.config.jsglobals;
if (snapshotConfig) {
syncUpSnapshotLoad2(jsglobals, snapshotConfig);
} else {
importhook.register_js_finder.callKwargs({ hook: jsFinderHook });
importhook.register_js_module("js", API.config.jsglobals);
importhook.register_js_module("js", jsglobals);
importhook.register_js_module("pyodide_js", pyodide);
}

View File

@ -16,6 +16,7 @@ import type { PyodideInterface } from "./api.js";
import type { TypedArray, Module } from "./types";
import type { EmscriptenSettings } from "./emscripten-settings";
import type { PackageData } from "./load-package";
import { SnapshotConfig } from "./snapshot";
export type { PyodideInterface, TypedArray };
export { version, type PackageData };
@ -42,6 +43,7 @@ export type ConfigType = {
_node_mounts: string[];
env: { [key: string]: string };
packages: string[];
_makeSnapshot: boolean;
};
/**
@ -213,11 +215,13 @@ export async function loadPyodide(
await loadScript(scriptSrc);
}
let snapshot;
let snapshot: Uint8Array | undefined = undefined;
if (options._loadSnapshot) {
snapshot = await options._loadSnapshot;
if (snapshot?.constructor?.name === "ArrayBuffer") {
snapshot = new Uint8Array(snapshot);
const snp = await options._loadSnapshot;
if (ArrayBuffer.isView(snp)) {
snapshot = snp;
} else {
snapshot = new Uint8Array(snp);
}
emscriptenSettings.noInitialRun = true;
// @ts-ignore
@ -248,17 +252,12 @@ If you updated the Pyodide version, make sure you also updated the 'indexURL' pa
throw new Error("Didn't expect to load any more file_packager files!");
};
let snapshotConfig: SnapshotConfig | undefined = undefined;
if (snapshot) {
// @ts-ignore
Module.HEAP8.set(snapshot);
snapshotConfig = API.restoreSnapshot(snapshot);
}
// runPython works starting after the call to finalizeBootstrap.
const pyodide = API.finalizeBootstrap(!!snapshot);
if (options._makeSnapshot) {
// @ts-ignore
pyodide._snapshot = Module.HEAP8.slice();
}
const pyodide = API.finalizeBootstrap(snapshotConfig);
API.sys.path.insert(0, API.config.env.HOME);
if (!pyodide.version.includes("dev")) {

243
src/js/snapshot.ts Normal file
View File

@ -0,0 +1,243 @@
import { jsFinderHook } from "./api";
import { scheduleCallback } from "./scheduler";
declare var Module: any;
export function getExpectedKeys() {
return [
null,
jsFinderHook,
API.config.jsglobals,
API.public_api,
API,
scheduleCallback,
API,
{},
];
}
const getAccessorList = Symbol("getAccessorList");
/**
* @private
*/
export function makeGlobalsProxy(
obj: any,
accessorList: (string | symbol)[] = [],
): any {
return new Proxy(obj, {
get(target, prop, receiver) {
if (prop === getAccessorList) {
return accessorList;
}
// @ts-ignore
const orig = Reflect.get(...arguments);
const descr = Reflect.getOwnPropertyDescriptor(target, prop);
// We're required to return the original value unmodified if it's an own
// property with a non-writable, non-configurable data descriptor
if (descr && descr.writable === false && !descr.configurable) {
return orig;
}
// Or an accessor descriptor with a setter but no getter
if (descr && descr.set && !descr.get) {
return orig;
}
if (!["object", "function"].includes(typeof orig)) {
return orig;
}
return makeGlobalsProxy(orig, [...accessorList, prop]);
},
getPrototypeOf() {
// @ts-ignore
return makeGlobalsProxy(Reflect.getPrototypeOf(...arguments), [
...accessorList,
"[getProtoTypeOf]",
]);
},
});
}
export type SnapshotConfig = {
hiwireKeys: (string[] | null)[];
immortalKeys: string[];
};
const SNAPSHOT_MAGIC = 0x706e7300; // "\x00snp"
// TODO: Make SNAPSHOT_BUILD_ID distinct for each build of pyodide.asm.js / pyodide.asm.wasm
const SNAPSHOT_BUILD_ID = 0;
const HEADER_SIZE = 4 * 4;
// The expected index of the deduplication map in the immortal externref table.
// We double check that this is still right in makeSnapshot (when creating the
// snapshot) and in syncUpSnapshotLoad1 (when using it).
const MAP_INDEX = 5;
API.makeSnapshot = function (): Uint8Array {
if (!API.config._makeSnapshot) {
throw new Error(
"makeSnapshot only works if you passed the makeSnapshot option to loadPyodide",
);
}
const hiwireKeys: (string[] | null)[] = [];
const expectedKeys = getExpectedKeys();
for (let i = 0; i < expectedKeys.length; i++) {
let value;
try {
value = Module.__hiwire_get(i);
} catch (e) {
throw new Error(`Failed to get value at index ${i}`);
}
let isOkay = false;
try {
isOkay =
value === expectedKeys[i] ||
JSON.stringify(value) === JSON.stringify(expectedKeys[i]);
} catch (e) {
// first comparison returned false and stringify raised
console.warn(e);
}
if (!isOkay) {
console.warn(expectedKeys[i], value);
throw new Error(`Unexpected hiwire entry at index ${i}`);
}
}
for (let i = expectedKeys.length; ; i++) {
let value;
try {
value = Module.__hiwire_get(i);
} catch (e) {
break;
}
if (!["object", "function"].includes(typeof value)) {
throw new Error(
`Unexpected object of type ${typeof value} at index ${i}`,
);
}
if (value === null) {
hiwireKeys.push(value);
continue;
}
const accessorList = value[getAccessorList];
if (!accessorList) {
throw new Error(`Can't serialize object at index ${i}`);
}
hiwireKeys.push(accessorList);
}
const immortalKeys = [];
const shouldBeAMap = Module.__hiwire_immortal_get(MAP_INDEX);
if (Object.prototype.toString.call(shouldBeAMap) !== "[object Map]") {
throw new Error(`Internal error: expected a map at index ${MAP_INDEX}`);
}
for (let i = MAP_INDEX + 1; ; i++) {
let v;
try {
v = Module.__hiwire_immortal_get(i);
} catch (e) {
break;
}
if (typeof v !== "string") {
throw new Error("Expected a string");
}
immortalKeys.push(v);
}
const snapshotConfig: SnapshotConfig = {
hiwireKeys,
immortalKeys,
};
const snapshotConfigString = JSON.stringify(snapshotConfig);
let snapshotOffset = HEADER_SIZE + 2 * snapshotConfigString.length;
// align to 8 bytes
snapshotOffset = Math.ceil(snapshotOffset / 8) * 8;
const snapshot = new Uint8Array(snapshotOffset + Module.HEAP8.length);
const encoder = new TextEncoder();
const { written: jsonLength } = encoder.encodeInto(
snapshotConfigString,
snapshot.subarray(HEADER_SIZE),
);
const uint32View = new Uint32Array(snapshot.buffer);
uint32View[0] = SNAPSHOT_MAGIC;
uint32View[1] = SNAPSHOT_BUILD_ID;
uint32View[2] = snapshotOffset;
uint32View[3] = jsonLength!;
snapshot.subarray(snapshotOffset).set(Module.HEAP8);
return snapshot;
};
API.restoreSnapshot = function (snapshot: Uint8Array): SnapshotConfig {
const uint32View = new Uint32Array(
snapshot.buffer,
snapshot.byteOffset,
snapshot.byteLength / 4,
);
if (uint32View[0] !== SNAPSHOT_MAGIC) {
throw new Error("Snapshot has invalid magic number");
}
if (uint32View[1] !== SNAPSHOT_BUILD_ID) {
throw new Error("Snapshot has invalid BUILD_ID");
}
const snpOffset = uint32View[2];
const jsonSize = uint32View[3];
const jsonBuf = snapshot.subarray(HEADER_SIZE, HEADER_SIZE + jsonSize);
snapshot = snapshot.subarray(snpOffset);
const jsonStr = new TextDecoder().decode(jsonBuf);
const snapshotConfig: SnapshotConfig = JSON.parse(jsonStr);
// @ts-ignore
Module.HEAP8.set(snapshot);
return snapshotConfig;
};
/**
* Set up some of the JavaScript state that is normally set up by C
* initialization code. TODO: adjust C code to simplify.
*
* This is divided up into two parts: syncUpSnapshotLoad1 has to happen at the
* beginning of finalizeBootstrap before the public API is setup,
* syncUpSnapshotLoad2 happens near the end so that API.public_api exists.
*
* This code is quite sensitive to the details of our setup, so it might break
* if we move stuff around far away in the code base. Ideally over time we can
* structure the code to make it less brittle.
*/
export function syncUpSnapshotLoad1() {
// hiwire init puts a null at the beginning of both the mortal and immortal tables.
Module.__hiwire_set(0, null);
Module.__hiwire_immortal_add(null);
// Usually importing _pyodide_core would trigger jslib_init but we need to manually call it.
Module._jslib_init();
// Puts deduplication map into the immortal table.
// TODO: Add support for snapshots to hiwire and move this to a hiwire_snapshot_init function?
let mapIndex = Module.__hiwire_immortal_add(new Map());
// We expect everything after this in the immortal table to be interned strings.
// We need to know where to start looking for the strings so that we serialized correctly.
if (mapIndex !== MAP_INDEX) {
throw new Error(
`Internal error: Expected mapIndex to be ${MAP_INDEX}, got ${mapIndex}`,
);
}
// Set API._pyodide to a proxy of the _pyodide module.
// Normally called by import _pyodide.
Module._init_pyodide_proxy();
}
function tableSet(idx: number, val: any): void {
if (Module.__hiwire_set(idx, val) < 0) {
throw new Error("table set failed");
}
}
/**
* Fill in the JsRef table.
*/
export function syncUpSnapshotLoad2(
jsglobals: any,
snapshotConfig: SnapshotConfig,
) {
const expectedKeys = getExpectedKeys();
expectedKeys.forEach((v, idx) => tableSet(idx, v));
snapshotConfig.hiwireKeys.forEach((e, idx) => {
const x = e?.reduce((x, y) => x[y], jsglobals) || null;
// @ts-ignore
tableSet(expectedKeys.length + idx, x);
});
snapshotConfig.immortalKeys.forEach((v) => Module.__hiwire_immortal_add(v));
}

View File

@ -13,6 +13,6 @@ describe("scheduleCallback", () => {
const start = Date.now();
scheduleCallback(() => {
chai.assert.isAtLeast(Date.now() - start, 10);
}, 10);
}, 11);
});
});

View File

@ -8,6 +8,7 @@ import {
type InternalPackageData,
type PackageLoadMetadata,
} from "./load-package";
import { SnapshotConfig } from "./snapshot";
export type TypedArray =
| Int8Array
@ -424,6 +425,10 @@ export interface API {
sys: PyProxy;
os: PyProxy;
finalizeBootstrap: (fromSnapshot?: boolean) => PyodideInterface;
restoreSnapshot(snapshot: Uint8Array): SnapshotConfig;
makeSnapshot(): Uint8Array;
saveSnapshot(): Uint8Array;
finalizeBootstrap: (fromSnapshot?: SnapshotConfig) => PyodideInterface;
syncUpSnapshotLoad3(conf: SnapshotConfig): void;
version: string;
}

View File

@ -1,5 +1,8 @@
import { loadPyodide } from "./pyodide.mjs";
import { writeFileSync } from "fs";
import { fileURLToPath } from "url";
import { dirname } from "path";
const __dirname = dirname(fileURLToPath(import.meta.url));
const py = await loadPyodide({ _makeSnapshot: true });
writeFileSync("snapshot.bin", py._snapshot);
writeFileSync(__dirname + "/snapshot.bin", py.makeMemorySnapshot());

133
src/tests/test_snapshots.py Normal file
View File

@ -0,0 +1,133 @@
import pytest
def test_make_snapshot_requires_arg(selenium):
match = "Can only use pyodide.makeMemorySnapshot if the _makeSnapshot option is passed to loadPyodide"
with pytest.raises(selenium.JavascriptException, match=match):
selenium.run_js(
"""
pyodide.makeMemorySnapshot();
"""
)
def test_snapshot_bad_magic(selenium_standalone_noload):
selenium = selenium_standalone_noload
match = "Snapshot has invalid magic number"
with pytest.raises(selenium.JavascriptException, match=match):
selenium.run_js(
"""
const pyodide = await loadPyodide({_loadSnapshot: new Uint8Array(20 * (1<<20))});
"""
)
def test_snapshot_simple(selenium_standalone_noload):
selenium = selenium_standalone_noload
selenium.run_js(
"""
const py1 = await loadPyodide({_makeSnapshot: true});
py1.runPython(`
from js import Headers, URL
canParse = URL.canParse
`);
const snapshot = py1.makeMemorySnapshot();
const py2 = await loadPyodide({_loadSnapshot: snapshot});
assert(() => py2.globals.get("Headers") === Headers);
assert(() => py2.globals.get("URL") === URL);
assert(() => py2.globals.get("canParse") === URL.canParse);
"""
)
def test_snapshot_cannot_serialize(selenium_standalone_noload):
selenium = selenium_standalone_noload
match = "Can't serialize object at index"
with pytest.raises(selenium.JavascriptException, match=match):
selenium.run_js(
"""
const py1 = await loadPyodide({_makeSnapshot: true});
py1.runPython(`
from js import Headers, URL
a = Headers.new()
`);
py1.makeMemorySnapshot();
"""
)
def test_snapshot_deleted_proxy(selenium_standalone_noload):
"""In previous test, we fail to make the snapshot because we have a proxy of
a Headers which we don't know how to serialize.
In this test, we delete the headers proxy and should be able to successfully
create the snapshot.
"""
selenium = selenium_standalone_noload
selenium.run_js(
"""
const py1 = await loadPyodide({_makeSnapshot: true});
py1.runPython(`
from js import Headers, URL
from pyodide.code import run_js
assert run_js("1+1") == 2
assert run_js("(x) => x.get('a')")({'a': 7}) == 7
a = Headers.new()
del a # delete non-serializable JsProxy
`);
const snapshot = py1.makeMemorySnapshot();
const py2 = await loadPyodide({_loadSnapshot: snapshot});
py2.runPython(`
assert run_js("1+1") == 2
assert run_js("(x) => x.get('a')")({'a': 7}) == 7
a = Headers.new()
`);
"""
)
def test_snapshot_stacked(selenium_standalone_noload):
selenium = selenium_standalone_noload
selenium.run_js(
"""
const py1 = await loadPyodide({_makeSnapshot: true});
py1.runPython(`
from js import Headers
from pyodide.code import run_js
assert run_js("1+1") == 2
assert run_js("(x) => x.get('a')")({'a': 7}) == 7
a = Headers.new()
del a
`);
const snapshot = py1.makeMemorySnapshot();
const py2 = await loadPyodide({_loadSnapshot: snapshot, _makeSnapshot: true});
py2.runPython(`
assert run_js("1+1") == 2
assert run_js("(x) => x.get('a')")({'a': 7}) == 7
from js import URL
t = URL.new("http://a.com/z?t=2").searchParams["t"]
assert t == "2"
a = Headers.new()
del a
`);
const snapshot2 = py2.makeMemorySnapshot();
const py3 = await loadPyodide({_loadSnapshot: snapshot2, _makeSnapshot: true});
py3.runPython(`
assert run_js("1+1") == 2
assert run_js("(x) => x.get('a')")({'a': 7}) == 7
t = URL.new("http://a.com/z?t=2").searchParams["t"]
assert t == "2"
a = Headers.new()
`);
"""
)