Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ members = ["codegen", "examples", "performance_measurement", "performance_measur

[package]
name = "worktable"
version = "0.9.0-alpha7"
version = "0.9.0-alpha8"
edition = "2024"
authors = ["Handy-caT"]
license = "MIT"
Expand All @@ -12,14 +12,12 @@ description = "WorkTable is in-memory storage"

[features]
perf_measurements = ["dep:performance_measurement", "dep:performance_measurement_codegen"]
s3-support = ["dep:rust-s3", "dep:aws-creds", "dep:aws-region", "dep:walkdir", "worktable_codegen/s3-support"]
s3-support = ["dep:rusty-s3", "dep:url", "dep:reqwest", "dep:walkdir", "worktable_codegen/s3-support"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.89"
aws-creds = { version = "0.39", optional = true, default-features = false, features = ["rustls-tls"] }
aws-region = { version = "0.28", optional = true }
convert_case = "0.6.0"
data_bucket = "=0.3.12"
# data_bucket = { git = "https://github.com/pathscale/DataBucket", branch = "page_cdc_correction", version = "0.2.7" }
Expand All @@ -40,10 +38,12 @@ performance_measurement_codegen = { path = "performance_measurement/codegen", ve
prettytable-rs = "^0.10"
psc-nanoid = { version = "3.1.1", features = ["rkyv", "packed"] }
rkyv = { version = "0.8.9", features = ["uuid-1"] }
rust-s3 = { version = "0.37", optional = true, default-features = false, features = ["tokio-rustls-tls"] }
reqwest = { version = "0.12", optional = true, default-features = false, features = ["rustls-tls-webpki-roots", "charset", "http2"] }
rusty-s3 = { version = "0.9.0", optional = true }
smart-default = "0.7.1"
tokio = { version = "1", features = ["full"] }
tracing = "0.1"
url = { version = "2", optional = true }
uuid = { version = "1.10.0", features = ["v4", "v7"] }
walkdir = { version = "2", optional = true }
worktable_codegen = { path = "codegen", version = "=0.9.0-alpha4" }
Expand Down
121 changes: 74 additions & 47 deletions src/features/s3_support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use std::fmt::Debug;
use std::hash::Hash;
use std::marker::PhantomData;
use std::path::Path;
use std::time::Duration;

use awscreds::Credentials;
use awsregion::Region;
use s3::Bucket;
use reqwest::Client;
use rusty_s3::{Bucket, Credentials, S3Action, UrlStyle};
use url::Url;
use walkdir::WalkDir;

use crate::persistence::operation::{BatchOperation, Operation};
Expand Down Expand Up @@ -64,7 +65,9 @@ where
PrimaryKeyGenState,
>,
config: S3DiskConfig,
bucket: Box<Bucket>,
bucket: Bucket,
credentials: Credentials,
client: Client,
phantom: PhantomData<(PrimaryKey, SecondaryIndexEvents, PrimaryKeyGenState, AvailableIndexes)>,
}

Expand Down Expand Up @@ -97,29 +100,17 @@ where
PrimaryKeyGenState: Clone + Debug + Send + Sync,
AvailableIndexes: Clone + Copy + Debug + Eq + Hash + Send + Sync,
{
fn create_bucket(config: &S3Config) -> eyre::Result<Box<Bucket>> {
let credentials = Credentials::new(
Some(&config.access_key),
Some(&config.secret_key),
None,
None,
None,
)?;

let region = if let Some(region) = &config.region {
Region::Custom {
region: region.clone(),
endpoint: config.endpoint.clone(),
}
} else {
Region::Custom {
region: "auto".to_string(),
endpoint: config.endpoint.clone(),
}
};
fn create_bucket(config: &S3Config) -> eyre::Result<(Bucket, Credentials, Client)> {
let credentials = Credentials::new(&config.access_key, &config.secret_key);
let endpoint: Url = config.endpoint.parse()?;
let region = config.region.clone().unwrap_or_else(|| "auto".to_string());
let bucket = Bucket::new(endpoint, UrlStyle::Path, config.bucket_name.clone(), region)?;

let client = Client::builder()
.timeout(Duration::from_secs(30))
.build()?;

let bucket = Bucket::new(&config.bucket_name, region, credentials)?.with_path_style();
Ok(bucket)
Ok((bucket, credentials, client))
}

async fn sync_to_s3(&self) -> eyre::Result<()> {
Expand Down Expand Up @@ -147,7 +138,16 @@ where
tracing::debug!(local_path = %local_path.display(), s3_key = %s3_key, "Uploading file to S3");

let content = tokio::fs::read(local_path).await?;
self.bucket.put_object(&s3_key, &content).await?;

let action = self.bucket.put_object(Some(&self.credentials), &s3_key);
let url = action.sign(Duration::from_secs(3600));

self.client
.put(url)
.body(content)
.send()
.await?
.error_for_status()?;
}

tracing::debug!("S3 sync complete");
Expand All @@ -164,7 +164,14 @@ where
}
}

async fn sync_from_s3(bucket: &Bucket, config: &S3DiskConfig) -> eyre::Result<()> {
async fn sync_from_s3(
bucket: &Bucket,
credentials: &Credentials,
client: &Client,
config: &S3DiskConfig,
) -> eyre::Result<()> {
use rusty_s3::actions::ListObjectsV2;

let table_path = config.disk.table_path();
let table_path = Path::new(table_path);
let prefix = config.s3.prefix.as_deref().unwrap_or("");
Expand All @@ -175,35 +182,53 @@ where
.ok_or_else(|| eyre::eyre!("Invalid table path"))?;

let s3_path = Self::full_s3_path(prefix, "", table_name);
let results = bucket.list(s3_path.clone(), Some("/".to_string())).await?;

if results.is_empty() {
let mut action = bucket.list_objects_v2(Some(credentials));
action.with_prefix(&s3_path);
action.with_delimiter("/");
let url = action.sign(Duration::from_secs(3600));

let response = client
.get(url)
.send()
.await?
.error_for_status()?;

let text = response.text().await?;
let parsed = ListObjectsV2::parse_response(&text)?;

if parsed.contents.is_empty() {
tracing::debug!(s3_prefix = %s3_path, "No objects found in S3");
return Ok(());
}

tokio::fs::create_dir_all(table_path).await?;

for result in results {
for obj in result.contents {
let s3_key = &obj.key;
for obj in parsed.contents {
let s3_key = &obj.key;

let filename = s3_key.rsplit('/').next().unwrap_or(s3_key);
let filename = s3_key.rsplit('/').next().unwrap_or(s3_key);

if !filename.ends_with(WT_DATA_EXTENSION) && !filename.ends_with(WT_INDEX_EXTENSION)
{
tracing::debug!(s3_key = %s3_key, "Skipping non-table file");
continue;
}
if !filename.ends_with(WT_DATA_EXTENSION) && !filename.ends_with(WT_INDEX_EXTENSION) {
tracing::debug!(s3_key = %s3_key, "Skipping non-table file");
continue;
}

let local_path = table_path.join(filename);
let local_path = table_path.join(filename);

tracing::debug!(s3_key = %s3_key, local_path = %local_path.display(), "Downloading file from S3");
tracing::debug!(s3_key = %s3_key, local_path = %local_path.display(), "Downloading file from S3");

let content = bucket.get_object(s3_key).await?;
let action = bucket.get_object(Some(credentials), s3_key);
let url = action.sign(Duration::from_secs(3600));

tokio::fs::write(&local_path, content.bytes()).await?;
}
let response = client
.get(url)
.send()
.await?
.error_for_status()?;

let content = response.bytes().await?;
tokio::fs::write(&local_path, content).await?;
}

tracing::info!(table_name = %table_name, "S3 download sync complete");
Expand Down Expand Up @@ -246,9 +271,9 @@ where
where
Self: Sized,
{
let bucket = Self::create_bucket(&config.s3)?;
let (bucket, credentials, client) = Self::create_bucket(&config.s3)?;

if let Err(e) = Self::sync_from_s3(&bucket, &config).await {
if let Err(e) = Self::sync_from_s3(&bucket, &credentials, &client, &config).await {
tracing::warn!(error = %e, "Failed to sync from S3, continuing with local files");
}

Expand All @@ -258,6 +283,8 @@ where
inner,
config,
bucket,
credentials,
client,
phantom: PhantomData,
})
}
Expand Down Expand Up @@ -288,4 +315,4 @@ where
fn config(&self) -> &Self::Config {
&self.config
}
}
}
Loading