lightning/docs/source-app/examples/file_server/app.py

244 lines
7.9 KiB
Python

import json
import os
import tarfile
import uuid
import zipfile
from pathlib import Path
import lightning as L
from lightning.app.storage import Drive
class FileServer(L.LightningWork):
def __init__(
self,
drive: Drive,
base_dir: str = "file_server",
chunk_size=10240,
**kwargs
):
"""This component uploads, downloads files to your application.
Arguments:
drive: The drive can share data inside your application.
base_dir: The local directory where the data will be stored.
chunk_size: The quantity of bytes to download/upload at once.
"""
super().__init__(
cloud_build_config=L.BuildConfig(["flask, flask-cors"]),
parallel=True,
**kwargs,
)
# 1: Attach the arguments to the state.
self.drive = drive
self.base_dir = base_dir
self.chunk_size = chunk_size
# 2: Create a folder to store the data.
os.makedirs(self.base_dir, exist_ok=True)
# 3: Keep a reference to the uploaded filenames.
self.uploaded_files = dict()
def get_filepath(self, path: str) -> str:
"""Returns file path stored on the file server."""
return os.path.join(self.base_dir, path)
def get_random_filename(self) -> str:
"""Returns a random hash for the file name."""
return uuid.uuid4().hex
def upload_file(self, file):
"""Upload a file while tracking its progress."""
# 1: Track metadata about the file
filename = file.filename
uploaded_file = self.get_random_filename()
meta_file = uploaded_file + ".meta"
self.uploaded_files[filename] = {
"progress": (0, None), "done": False
}
# 2: Create a stream and write bytes of
# the file to the disk under `uploaded_file` path.
with open(self.get_filepath(uploaded_file), "wb") as out_file:
content = file.read(self.chunk_size)
while content:
# 2.1 Write the file bytes
size = out_file.write(content)
# 2.2 Update the progress metadata
self.uploaded_files[filename]["progress"] = (
self.uploaded_files[filename]["progress"][0] + size,
None,
)
# 4: Read next chunk of data
content = file.read(self.chunk_size)
# 3: Update metadata that the file has been uploaded.
full_size = self.uploaded_files[filename]["progress"][0]
self.drive.put(self.get_filepath(uploaded_file))
self.uploaded_files[filename] = {
"progress": (full_size, full_size),
"done": True,
"uploaded_file": uploaded_file,
}
# 4: Write down the metadata about the file to the disk
meta = {
"original_path": filename,
"display_name": os.path.splitext(filename)[0],
"size": full_size,
"drive_path": uploaded_file,
}
with open(self.get_filepath(meta_file), "wt") as f:
json.dump(meta, f)
# 5: Put the file to the drive.
# It means other components can access get or list them.
self.drive.put(self.get_filepath(meta_file))
return meta
def list_files(self, file_path: str):
# 1: Get the local file path of the file server.
file_path = self.get_filepath(file_path)
# 2: If the file exists in the drive, transfer it locally.
if not os.path.exists(file_path):
self.drive.get(file_path)
if os.path.isdir(file_path):
result = set()
for _, _, f in os.walk(file_path):
for file in f:
if not file.endswith(".meta"):
for filename, meta in self.uploaded_files.items():
if meta["uploaded_file"] == file:
result.add(filename)
return {"asset_names": [v for v in result]}
# 3: If the filepath is a tar or zip file, list their contents
if zipfile.is_zipfile(file_path):
with zipfile.ZipFile(file_path, "r") as zf:
result = zf.namelist()
elif tarfile.is_tarfile(file_path):
with tarfile.TarFile(file_path, "r") as tf:
result = tf.getnames()
else:
raise ValueError("Cannot open archive file!")
# 4: Returns the matching files.
return {"asset_names": result}
def run(self):
# 1: Imports flask requirements.
from flask import Flask, request
from flask_cors import CORS
# 2: Create a flask app
flask_app = Flask(__name__)
CORS(flask_app)
# 3: Define the upload file endpoint
@flask_app.post("/upload_file/")
def upload_file():
"""Upload a file directly as form data."""
f = request.files["file"]
return self.upload_file(f)
@flask_app.get("/")
def list_files():
return self.list_files(str(Path(self.base_dir).resolve()))
# 5: Start the flask app while providing the `host` and `port`.
flask_app.run(host=self.host, port=self.port, load_dotenv=False)
def alive(self):
"""Hack: Returns whether the server is alive."""
return self.url != ""
import requests # noqa: E402
from lightning import LightningWork # noqa: E402
class TestFileServer(LightningWork):
def __init__(self, drive: Drive):
super().__init__(cache_calls=True)
self.drive = drive
def run(self, file_server_url: str, first=True):
if first:
with open("test.txt", "w") as f:
f.write("Some text.")
response = requests.post(
file_server_url + "/upload_file/",
files={'file': open("test.txt", 'rb')}
)
assert response.status_code == 200
else:
response = requests.get(file_server_url)
assert response.status_code == 200
assert response.json() == {"asset_names": ["test.txt"]}
from lightning import LightningApp, LightningFlow # noqa: E402
class Flow(LightningFlow):
def __init__(self):
super().__init__()
# 1: Create a drive to share data between works
self.drive = Drive("lit://file_server")
# 2: Create the filer server
self.file_server = FileServer(self.drive)
# 3: Create the file ser
self.test_file_server = TestFileServer(self.drive)
def run(self):
# 1: Start the file server.
self.file_server.run()
# 2: Trigger the test file server work when ready.
if self.file_server.alive():
# 3 Execute the test file server work.
self.test_file_server.run(self.file_server.url)
self.test_file_server.run(self.file_server.url, first=False)
# 4 When both execution are successful, exit the app.
if self.test_file_server.num_successes == 2:
self._exit()
def configure_layout(self):
# Expose the file_server component
# in the UI using its `/` endpoint.
return {"name": "File Server", "content": self.file_server}
from lightning.app.runners import MultiProcessRuntime # noqa: E402
def test_file_server():
app = LightningApp(Flow())
MultiProcessRuntime(app).dispatch()
from lightning.app.testing.testing import run_app_in_cloud # noqa: E402
def test_file_server_in_cloud():
# You need to provide the directory containing the app file.
app_dir = "docs/source-app/examples/file_server"
with run_app_in_cloud(app_dir) as (admin_page, view_page, get_logs_fn):
"""# 1. `admin_page` and `view_page` are playwright Page Objects.
# Check out https://playwright.dev/python/ doc to learn more.
# You can click the UI and trigger actions.
# 2. By calling logs = get_logs_fn(),
# you get all the logs currently on the admin page.
"""