244 lines
7.9 KiB
Python
244 lines
7.9 KiB
Python
import json
|
|
import os
|
|
import tarfile
|
|
import uuid
|
|
import zipfile
|
|
from pathlib import Path
|
|
|
|
import lightning as L
|
|
from lightning.app.storage import Drive
|
|
|
|
|
|
class FileServer(L.LightningWork):
    """Work that runs a small Flask server to upload and list files.

    Uploaded files are written under ``base_dir`` with a random name, pushed to
    the shared ``drive``, and tracked in ``uploaded_files`` keyed by the
    original filename sent by the client.
    """

    def __init__(
        self,
        drive: Drive,
        base_dir: str = "file_server",
        chunk_size=10240,
        **kwargs,
    ):
        """This component uploads, downloads files to your application.

        Arguments:
            drive: The drive can share data inside your application.
            base_dir: The local directory where the data will be stored.
            chunk_size: The quantity of bytes to download/upload at once.
            **kwargs: Forwarded unchanged to ``LightningWork.__init__``.
        """
        super().__init__(
            # Each requirement is its own list entry; a single
            # "flask, flask-cors" string is not a valid requirement specifier.
            cloud_build_config=L.BuildConfig(["flask", "flask-cors"]),
            parallel=True,
            **kwargs,
        )
        # 1: Attach the arguments to the state.
        self.drive = drive
        self.base_dir = base_dir
        self.chunk_size = chunk_size

        # 2: Create a folder to store the data.
        os.makedirs(self.base_dir, exist_ok=True)

        # 3: Keep a reference to the uploaded filenames.
        self.uploaded_files = dict()

    def get_filepath(self, path: str) -> str:
        """Returns file path stored on the file server.

        ``path`` is joined onto ``base_dir``; an absolute ``path`` is returned
        unchanged because of how ``os.path.join`` treats absolute components.
        """
        return os.path.join(self.base_dir, path)

    def get_random_filename(self) -> str:
        """Returns a random hash for the file name."""
        return uuid.uuid4().hex

    def upload_file(self, file):
        """Upload a file while tracking its progress.

        Arguments:
            file: The uploaded file object. It must expose ``filename`` and
                ``read(size)``.

        Returns:
            The metadata dictionary that is also written next to the file
            (``original_path``, ``display_name``, ``size``, ``drive_path``).
        """
        # 1: Track metadata about the file
        filename = file.filename
        uploaded_file = self.get_random_filename()
        meta_file = uploaded_file + ".meta"
        self.uploaded_files[filename] = {
            "progress": (0, None),
            "done": False,
        }

        # 2: Create a stream and write bytes of
        # the file to the disk under `uploaded_file` path.
        with open(self.get_filepath(uploaded_file), "wb") as out_file:
            content = file.read(self.chunk_size)
            while content:
                # 2.1 Write the file bytes
                size = out_file.write(content)

                # 2.2 Update the progress metadata
                self.uploaded_files[filename]["progress"] = (
                    self.uploaded_files[filename]["progress"][0] + size,
                    None,
                )
                # 2.3 Read next chunk of data
                content = file.read(self.chunk_size)

        # 3: Update metadata that the file has been uploaded.
        full_size = self.uploaded_files[filename]["progress"][0]
        self.drive.put(self.get_filepath(uploaded_file))
        self.uploaded_files[filename] = {
            "progress": (full_size, full_size),
            "done": True,
            "uploaded_file": uploaded_file,
        }

        # 4: Write down the metadata about the file to the disk
        meta = {
            "original_path": filename,
            "display_name": os.path.splitext(filename)[0],
            "size": full_size,
            "drive_path": uploaded_file,
        }
        with open(self.get_filepath(meta_file), "wt") as f:
            json.dump(meta, f)

        # 5: Put the file to the drive.
        # It means other components can access get or list them.
        self.drive.put(self.get_filepath(meta_file))
        return meta

    def list_files(self, file_path: str):
        """List uploaded files in a directory, or the members of an archive.

        Arguments:
            file_path: A directory or a zip/tar archive, resolved through
                ``get_filepath``.

        Returns:
            ``{"asset_names": [...]}``. For a directory the names are the
            original filenames of completed uploads, in upload order. For an
            archive they are the archive member names.

        Raises:
            ValueError: If the path is neither a directory nor a zip/tar file.
        """
        # 1: Get the local file path of the file server.
        file_path = self.get_filepath(file_path)

        # 2: If the file exists in the drive, transfer it locally.
        if not os.path.exists(file_path):
            self.drive.get(file_path)

        if os.path.isdir(file_path):
            # Collect the stored (random) names found on disk once, so each
            # tracked upload needs a single membership test.
            stored_names = {
                name
                for _, _, names in os.walk(file_path)
                for name in names
                if not name.endswith(".meta")
            }
            # An upload still in progress has no "uploaded_file" key yet, so
            # use `.get` to skip it instead of raising KeyError.
            asset_names = [
                filename
                for filename, meta in self.uploaded_files.items()
                if meta.get("uploaded_file") in stored_names
            ]
            return {"asset_names": asset_names}

        # 3: If the filepath is a tar or zip file, list their contents
        if zipfile.is_zipfile(file_path):
            with zipfile.ZipFile(file_path, "r") as zf:
                result = zf.namelist()
        elif tarfile.is_tarfile(file_path):
            # `tarfile.open` detects compression transparently, whereas the
            # bare `TarFile` constructor only reads uncompressed archives.
            with tarfile.open(file_path, "r") as tf:
                result = tf.getnames()
        else:
            raise ValueError("Cannot open archive file!")

        # 4: Returns the matching files.
        return {"asset_names": result}

    def run(self):
        """Build the Flask app, register its endpoints and serve it."""
        # 1: Imports flask requirements.
        from flask import Flask, request
        from flask_cors import CORS

        # 2: Create a flask app
        flask_app = Flask(__name__)
        CORS(flask_app)

        # 3: Define the upload file endpoint
        @flask_app.post("/upload_file/")
        def upload_file():
            """Upload a file directly as form data."""
            f = request.files["file"]
            return self.upload_file(f)

        # 4: Define the endpoint listing the uploaded files
        @flask_app.get("/")
        def list_files():
            return self.list_files(str(Path(self.base_dir).resolve()))

        # 5: Start the flask app while providing the `host` and `port`.
        flask_app.run(host=self.host, port=self.port, load_dotenv=False)

    def alive(self):
        """Hack: Returns whether the server is alive."""
        return self.url != ""
|
|
|
|
|
|
import requests # noqa: E402
|
|
|
|
from lightning import LightningWork # noqa: E402
|
|
|
|
|
|
class TestFileServer(LightningWork):
    """Work that exercises the file server over HTTP."""

    def __init__(self, drive: Drive):
        """Store the shared drive; calls are cached (``cache_calls=True``)."""
        super().__init__(cache_calls=True)
        self.drive = drive

    def run(self, file_server_url: str, first=True):
        """Upload a test file, or check that it is listed.

        Arguments:
            file_server_url: Base URL of the running file server.
            first: When true, create ``test.txt`` and upload it. Otherwise
                request the file listing and check it contains ``test.txt``.
        """
        if first:
            with open("test.txt", "w") as f:
                f.write("Some text.")

            # Keep the upload handle in a context manager so it is closed
            # once the request has been sent.
            with open("test.txt", "rb") as payload:
                response = requests.post(
                    file_server_url + "/upload_file/",
                    files={"file": payload},
                )
            assert response.status_code == 200
        else:
            response = requests.get(file_server_url)
            assert response.status_code == 200
            assert response.json() == {"asset_names": ["test.txt"]}
|
|
|
|
|
|
from lightning import LightningApp, LightningFlow # noqa: E402
|
|
|
|
|
|
class Flow(LightningFlow):
    """Root flow wiring the file server to the work that tests it."""

    def __init__(self):
        super().__init__()
        # 1: Create a drive to share data between works
        self.drive = Drive("lit://file_server")
        # 2: Create the file server
        self.file_server = FileServer(self.drive)
        # 3: Create the work that tests the file server
        self.test_file_server = TestFileServer(self.drive)

    def run(self):
        """Start the server, then run the upload and listing checks."""
        # 1: Start the file server.
        self.file_server.run()

        # 2: Nothing more to do until the file server reports it is ready.
        if not self.file_server.alive():
            return

        # 3: Execute the test file server work: upload first, then list.
        server_url = self.file_server.url
        self.test_file_server.run(server_url)
        self.test_file_server.run(self.file_server.url, first=False)

        # 4: When both executions are successful, exit the app.
        both_succeeded = self.test_file_server.num_successes == 2
        if both_succeeded:
            self._exit()

    def configure_layout(self):
        """Expose the file_server component in the UI using its `/` endpoint."""
        tab = {"name": "File Server", "content": self.file_server}
        return tab
|
|
|
|
|
|
from lightning.app.runners import MultiProcessRuntime # noqa: E402
|
|
|
|
|
|
def test_file_server():
    """Dispatch the file server flow with the multi-process runtime."""
    runtime = MultiProcessRuntime(LightningApp(Flow()))
    runtime.dispatch()
|
|
|
|
|
|
from lightning.app.testing.testing import run_app_in_cloud # noqa: E402
|
|
|
|
|
|
def test_file_server_in_cloud():
    """Launch the file server example through `run_app_in_cloud`."""
    # You need to provide the directory containing the app file.
    example_dir = "docs/source-app/examples/file_server"
    with run_app_in_cloud(example_dir) as (admin_page, view_page, get_logs_fn):
        """# 1. `admin_page` and `view_page` are playwright Page Objects.

        # Check out https://playwright.dev/python/ doc to learn more.
        # You can click the UI and trigger actions.

        # 2. By calling logs = get_logs_fn(),
        # you get all the logs currently on the admin page.
        """
|