malbox_plugin_sdk/
context.rs1pub mod message;
8mod result_sink;
9
10pub use result_sink::ResultSink;
11
12use crate::error::{Result, SdkError};
13use message::{ResultFormat, ResultKind, TaskResultMessage};
14
15use malbox_plugin_transport::messages::events::Event;
16use malbox_plugin_transport::traits::TransportEmitter;
17use serde::Serialize;
18use std::collections::{HashMap, HashSet};
19use std::path::{Path, PathBuf};
20use std::sync::{Arc, Mutex};
21use tracing::{info, warn};
22
23#[cfg(feature = "guest")]
24use crate::stash::ResultStash;
25
26#[derive(Serialize)]
27struct ProgressPayload<'a> {
28 progress: f64,
29 message: &'a str,
30}
31
32pub type ResultSender = tokio::sync::mpsc::Sender<TaskResultMessage>;
38
39pub(super) struct ContextInner {
40 pub task_id: i32,
41 pub sample_path: PathBuf,
42 pub config: HashMap<String, String>,
43 pub emitter: Arc<dyn TransportEmitter + Send + Sync>,
44 pub result_tx: Mutex<Option<ResultSender>>,
45 #[cfg(feature = "guest")]
46 pub stash: Option<Arc<ResultStash>>,
47 pub claimed_paths: Mutex<HashSet<PathBuf>>,
48}
49
50fn lock_or_err<T>(mutex: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>> {
51 mutex
52 .lock()
53 .map_err(|_| SdkError::Channel("internal mutex poisoned".into()))
54}
55
56#[derive(Clone)]
62pub struct Context {
63 inner: Arc<ContextInner>,
64}
65
66impl Context {
67 pub(crate) fn new(
73 task_id: i32,
74 sample_path: PathBuf,
75 config: HashMap<String, String>,
76 emitter: Arc<dyn TransportEmitter + Send + Sync>,
77 result_tx: Option<ResultSender>,
78 #[cfg(feature = "guest")] stash: Option<Arc<ResultStash>>,
79 ) -> Self {
80 Self {
81 inner: Arc::new(ContextInner {
82 task_id,
83 sample_path,
84 config,
85 emitter,
86 result_tx: Mutex::new(result_tx),
87 #[cfg(feature = "guest")]
88 stash,
89 claimed_paths: Mutex::new(HashSet::new()),
90 }),
91 }
92 }
93
94 pub fn task(&self) -> TaskInfo<'_> {
96 TaskInfo { inner: &self.inner }
97 }
98
99 pub fn results(&self) -> ResultSink<'_> {
101 ResultSink { inner: &self.inner }
102 }
103
104 pub fn progress(&self, pct: f64, message: &str) -> Result<()> {
117 let clamped = pct.clamp(0.0, 1.0);
118 info!(kind = "progress", progress = clamped, %message, "Plugin progress");
119
120 let tx = {
121 let guard = lock_or_err(&self.inner.result_tx)?;
122 guard.clone()
123 };
124
125 if let Some(tx) = tx {
126 let progress_data = serde_json::to_vec(&ProgressPayload {
127 progress: clamped,
128 message,
129 })?;
130
131 let msg = TaskResultMessage {
132 task_id: self.inner.task_id,
133 result_name: String::new(),
134 data: progress_data,
135 format: ResultFormat::Json,
136 is_final: false,
137 kind: ResultKind::Progress,
138 stash_handle: String::new(),
139 stash_format: ResultFormat::Unspecified,
140 stash_size: 0,
141 };
142
143 tx.blocking_send(msg)
144 .map_err(|e| SdkError::Channel(format!("progress: {e}")))?;
145 }
146
147 Ok(())
148 }
149
150 pub fn emit_event(&self, event: Event) -> Result<()> {
155 self.inner.emitter.emit(event).map_err(SdkError::Transport)
156 }
157
158 pub fn warn(&self, message: &str) -> Result<()> {
164 warn!(kind = "warning", %message, "Plugin warning");
165 Ok(())
166 }
167
168 pub fn mark_collected(&self, path: impl AsRef<Path>) {
174 if let Ok(canonical) = std::fs::canonicalize(path.as_ref())
175 && let Ok(mut guard) = self.inner.claimed_paths.lock()
176 {
177 guard.insert(canonical);
178 }
179 }
180
181 #[allow(dead_code)]
185 pub(crate) fn close_result_channel(&self) {
186 if let Ok(mut guard) = self.inner.result_tx.lock() {
187 *guard = None;
188 }
189 }
190
191 #[cfg(feature = "guest")]
196 pub(crate) fn claimed_paths(&self) -> HashSet<PathBuf> {
197 self.inner
198 .claimed_paths
199 .lock()
200 .unwrap_or_else(|e| e.into_inner())
201 .clone()
202 }
203}
204
205pub struct TaskInfo<'a> {
208 inner: &'a ContextInner,
209}
210
211impl TaskInfo<'_> {
212 pub fn id(&self) -> i32 {
214 self.inner.task_id
215 }
216
217 pub fn sample_path(&self) -> &Path {
219 &self.inner.sample_path
220 }
221
222 pub fn sample_bytes(&self) -> Result<Vec<u8>> {
224 std::fs::read(&self.inner.sample_path).map_err(SdkError::Io)
225 }
226
227 pub fn config(&self) -> &HashMap<String, String> {
229 &self.inner.config
230 }
231
232 pub fn config_value(&self, key: &str) -> Option<&str> {
235 self.inner.config.get(key).map(|s| s.as_str())
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Emitter backed by `()`; the tests below expect it to accept events
    /// without error.
    fn noop_emitter() -> Arc<dyn TransportEmitter + Send + Sync> {
        Arc::new(())
    }

    /// Default test context wired to the no-op emitter.
    fn test_ctx() -> Context {
        Context::test_new(noop_emitter())
    }

    #[test]
    fn context_is_clone() {
        let original = test_ctx();
        let cloned = original.clone();
        assert_eq!(original.task().id(), cloned.task().id());
    }

    #[test]
    fn context_progress_clamps_values() {
        let ctx = test_ctx();
        // Below range, above range, and in range: all must be accepted.
        for (pct, label) in [(-0.5, "negative"), (1.5, "over"), (0.5, "normal")] {
            assert!(ctx.progress(pct, label).is_ok());
        }
    }

    #[test]
    fn context_warn_succeeds() {
        let ctx = test_ctx();
        assert!(ctx.warn("test warning").is_ok());
    }

    #[test]
    fn context_emit_event_succeeds_with_noop() {
        let ctx = test_ctx();
        let outcome = ctx.emit_event(Event::PluginStarted { plugin_id: 1 });
        assert!(outcome.is_ok());
    }

    #[test]
    fn task_info_accessors_work() {
        let config = HashMap::from([("k".to_string(), "v".to_string())]);
        let ctx = Context::new(
            7,
            PathBuf::from("/tmp/s.bin"),
            config,
            noop_emitter(),
            None,
            #[cfg(feature = "guest")]
            None,
        );

        let task = ctx.task();
        assert_eq!(task.id(), 7);
        assert_eq!(task.sample_path(), Path::new("/tmp/s.bin"));
        assert_eq!(task.config().get("k").map(String::as_str), Some("v"));
        assert_eq!(task.config_value("k"), Some("v"));
        assert_eq!(task.config_value("missing"), None);
    }
}