Compare commits
26 Commits
beccc7b50b
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
0f0f2062a3
|
|||
|
b38e0c825b
|
|||
|
b370d50e8b
|
|||
|
1b84fd3128
|
|||
|
97e96739ac
|
|||
|
8a190496ca
|
|||
|
e5eb48057f
|
|||
|
b0bcb3460b
|
|||
|
9eca21de69
|
|||
|
43c1b09142
|
|||
|
e89ca23a42
|
|||
|
35cda247f2
|
|||
|
bd6c0bf211
|
|||
|
8568b147bb
|
|||
|
9c24bdca1f
|
|||
|
de106c61fe
|
|||
|
26b00f8566
|
|||
|
c1c8a3d6aa
|
|||
|
ea7e4f980b
|
|||
|
d059dc0032
|
|||
|
4be223d1b7
|
|||
|
9dcbd54387
|
|||
|
40768f0c3a
|
|||
|
994cf0935e
|
|||
|
a06ebf722f
|
|||
|
37bc07f667
|
@@ -0,0 +1,3 @@
|
||||
server/.env
|
||||
listener/target
|
||||
listener/config.toml
|
||||
+2
-1
@@ -1,2 +1,3 @@
|
||||
.devcontainer/
|
||||
scratch/
|
||||
scratch/
|
||||
.env
|
||||
@@ -0,0 +1,36 @@
|
||||
################################################################################
|
||||
# App builder stage
|
||||
FROM rust:1.96.0-trixie@sha256:fb328f0f58becb23ba1719940a2c94ece8b0b48afa837d05b79ef64bc1e18f6e AS app-builder
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y --no-install-recommends libclang-dev cmake protobuf-compiler
|
||||
|
||||
COPY listener/ /opt/listener
|
||||
COPY server/proto /opt/server/proto
|
||||
WORKDIR /opt/listener
|
||||
|
||||
RUN cargo build --release
|
||||
|
||||
################################################################################
|
||||
# FFmpeg builder stage
|
||||
FROM debian:13-slim@sha256:b6e2a152f22a40ff69d92cb397223c906017e1391a73c952b588e51af8883bf8 AS ffmpeg-builder
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends ffmpeg
|
||||
RUN mkdir -p /rootfs/bin && \
|
||||
cp /usr/bin/ffmpeg /usr/bin/ffprobe /rootfs/bin/ && \
|
||||
{ ldd /usr/bin/ffmpeg; ldd /yt-dlp; } 2>/dev/null | \
|
||||
grep -o '/[^ ]*' | sort -u | \
|
||||
xargs -I '{}' cp --parents '{}' /rootfs && \
|
||||
LIBDIR=$(dirname "$(find /rootfs -name 'libc.so.6' | head -1)") && \
|
||||
for stub in libutil.so.1 libdl.so.2 libpthread.so.0 librt.so.1; do \
|
||||
[ -f "${LIBDIR}/${stub}" ] || ln -sf libc.so.6 "${LIBDIR}/${stub}"; \
|
||||
done
|
||||
|
||||
################################################################################
|
||||
# Final squashed image
|
||||
FROM gcr.io/distroless/cc-debian13 AS final
|
||||
|
||||
# Copy listener library
|
||||
COPY --from=ffmpeg-builder /rootfs/ /
|
||||
COPY --from=app-builder /opt/listener/target/release/scannerbot-listener /bin/
|
||||
ENTRYPOINT ["/bin/scannerbot-listener"]
|
||||
@@ -0,0 +1,16 @@
|
||||
################################################################################
|
||||
# App builder stage
|
||||
FROM golang:1.26.3-trixie@sha256:0f6b034c99663ea8957e7dae99124e37374cbe7fcb5b5646f19b185f8f976279 AS app-builder
|
||||
|
||||
COPY server/ /opt/server
|
||||
WORKDIR /opt/server
|
||||
|
||||
RUN go get && CGO_ENABLED=0 go build -o out/scannerbot-server
|
||||
|
||||
################################################################################
|
||||
# Final squashed image
|
||||
FROM scratch AS final
|
||||
|
||||
# Copy server binary
|
||||
COPY --from=app-builder /opt/server/out/scannerbot-server /bin/
|
||||
ENTRYPOINT ["/bin/scannerbot-server"]
|
||||
@@ -0,0 +1,7 @@
|
||||
Copyright 2026 William T. Peebles
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
+7
-1
@@ -27,4 +27,10 @@ rustc-ice-*.txt
|
||||
models/
|
||||
|
||||
# Transcription text file
|
||||
transcription.txt
|
||||
transcription.txt
|
||||
|
||||
# Clips dir
|
||||
audio_clips/
|
||||
|
||||
# Config file
|
||||
config.toml
|
||||
Generated
+1102
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,19 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
axum = "0.8"
|
||||
bytes = "1"
|
||||
chrono = "0.4.44"
|
||||
ctrlc = "3.5.2"
|
||||
hound = "3.5.1"
|
||||
prost = "0.14.3"
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
toml = "1.1.2"
|
||||
tonic = "0.14.6"
|
||||
tonic-prost = "0.14.6"
|
||||
whisper-rs = "0.16.0"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-prost-build = "0.14.6"
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tonic_prost_build::configure()
|
||||
.build_server(false)
|
||||
.compile_protos(&["../server/proto/message.proto"], &["../server/proto"])?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
# Path to the Whisper GGML model file
|
||||
model = "models/ggml-base.en.bin"
|
||||
|
||||
# Audio source: a file path, or "pulse:<source>" for live PulseAudio capture
|
||||
# Examples:
|
||||
# audio = "recording.wav"
|
||||
# audio = "pulse:default"
|
||||
audio = "pulse:default"
|
||||
|
||||
# gRPC server address
|
||||
server_addr = "http://localhost:3001"
|
||||
|
||||
# Channel UUID to post transcriptions to (required by the server)
|
||||
channel_id = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
# Directory for temporary WAV files (deleted after upload); defaults to /tmp
|
||||
# temp_dir = "/tmp"
|
||||
|
||||
# File-mode only: how many seconds of audio to transcribe per chunk
|
||||
chunk_secs = 30
|
||||
|
||||
# VAD settings (pulse mode only)
|
||||
silence_threshold = 0.02 # RMS energy cutoff; raise for noisy environments
|
||||
pre_buffer_secs = 1.0 # seconds of audio kept before speech onset
|
||||
hangover_secs = 2.0 # trailing silence before a segment is closed
|
||||
max_segment_secs = 120.0 # hard cap per segment
|
||||
|
||||
# HTTP audio stream port; GET /audio streams a live WAV
|
||||
# http_port = 8080
|
||||
+90
-10
@@ -1,5 +1,19 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::io::Read;
|
||||
use std::process::{Child, ChildStdout, Command, Stdio};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use bytes::Bytes;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
const FRAME_SAMPLES: usize = 1600; // 100 ms at 16 kHz
|
||||
|
||||
pub struct VadConfig {
|
||||
pub silence_threshold: f32,
|
||||
pub pre_buffer_frames: usize,
|
||||
pub hangover_frames: usize,
|
||||
pub max_segment_samples: usize,
|
||||
}
|
||||
|
||||
pub fn decode(input: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
|
||||
let output = Command::new("ffmpeg")
|
||||
@@ -18,10 +32,11 @@ pub fn decode(input: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
|
||||
pub struct LiveStream {
|
||||
child: Child,
|
||||
stdout: ChildStdout,
|
||||
broadcast_tx: Option<broadcast::Sender<Bytes>>,
|
||||
}
|
||||
|
||||
impl LiveStream {
|
||||
pub fn open(source: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
pub fn open(source: &str, broadcast_tx: Option<broadcast::Sender<Bytes>>) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let mut child = Command::new("ffmpeg")
|
||||
.args(["-f", "pulse", "-i", source, "-ar", "16000", "-ac", "1", "-f", "f32le", "pipe:1"])
|
||||
.env("PULSE_PROP_application.name", "scannerbot-listener")
|
||||
@@ -30,20 +45,19 @@ impl LiveStream {
|
||||
.spawn()?;
|
||||
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
Ok(LiveStream { child, stdout })
|
||||
Ok(LiveStream { child, stdout, broadcast_tx })
|
||||
}
|
||||
|
||||
// Reads exactly `secs` seconds of audio. Returns None when the stream ends.
|
||||
pub fn next_chunk(&mut self, secs: u32) -> Result<Option<Vec<f32>>, Box<dyn std::error::Error>> {
|
||||
let num_bytes = secs as usize * 16000 * 4;
|
||||
let mut buf = vec![0u8; num_bytes];
|
||||
|
||||
fn next_frame(&mut self) -> Result<Option<Vec<f32>>, Box<dyn std::error::Error>> {
|
||||
let mut buf = vec![0u8; FRAME_SAMPLES * 4];
|
||||
match self.stdout.read_exact(&mut buf) {
|
||||
Ok(()) => {
|
||||
let samples = buf.chunks_exact(4)
|
||||
if let Some(tx) = &self.broadcast_tx {
|
||||
let _ = tx.send(Bytes::copy_from_slice(&buf));
|
||||
}
|
||||
Ok(Some(buf.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes(b.try_into().unwrap()))
|
||||
.collect();
|
||||
Ok(Some(samples))
|
||||
.collect()))
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None),
|
||||
Err(e) => Err(e.into()),
|
||||
@@ -56,3 +70,69 @@ impl Drop for LiveStream {
|
||||
let _ = self.child.kill();
|
||||
}
|
||||
}
|
||||
|
||||
fn rms(samples: &[f32]) -> f32 {
|
||||
(samples.iter().map(|&s| s * s).sum::<f32>() / samples.len() as f32).sqrt()
|
||||
}
|
||||
|
||||
pub struct VadStream {
|
||||
inner: LiveStream,
|
||||
pre_buffer: VecDeque<Vec<f32>>,
|
||||
running: Arc<AtomicBool>,
|
||||
cfg: VadConfig,
|
||||
}
|
||||
|
||||
impl VadStream {
|
||||
pub fn open(source: &str, running: Arc<AtomicBool>, cfg: VadConfig, broadcast_tx: Option<broadcast::Sender<Bytes>>) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
Ok(VadStream {
|
||||
inner: LiveStream::open(source, broadcast_tx)?,
|
||||
pre_buffer: VecDeque::with_capacity(cfg.pre_buffer_frames + 1),
|
||||
running,
|
||||
cfg,
|
||||
})
|
||||
}
|
||||
|
||||
/// Blocks until a complete speech segment is captured, then returns it.
|
||||
/// Returns None when the underlying stream ends or running is set to false.
|
||||
pub fn next_segment(&mut self) -> Result<Option<Vec<f32>>, Box<dyn std::error::Error>> {
|
||||
let mut speech: Vec<f32> = Vec::new();
|
||||
let mut speech_active = false;
|
||||
let mut hangover = 0usize;
|
||||
|
||||
loop {
|
||||
if !self.running.load(Ordering::SeqCst) {
|
||||
return if speech.is_empty() { Ok(None) } else { Ok(Some(speech)) };
|
||||
}
|
||||
|
||||
let frame = match self.inner.next_frame()? {
|
||||
Some(f) => f,
|
||||
None => return if speech.is_empty() { Ok(None) } else { Ok(Some(speech)) },
|
||||
};
|
||||
|
||||
let energy = rms(&frame);
|
||||
|
||||
if energy > self.cfg.silence_threshold {
|
||||
if !speech_active {
|
||||
speech_active = true;
|
||||
eprintln!(" [recording]");
|
||||
for pre in self.pre_buffer.drain(..) {
|
||||
speech.extend(pre);
|
||||
}
|
||||
}
|
||||
speech.extend(&frame);
|
||||
hangover = 0;
|
||||
} else if speech_active {
|
||||
speech.extend(&frame);
|
||||
hangover += 1;
|
||||
if hangover >= self.cfg.hangover_frames || speech.len() >= self.cfg.max_segment_samples {
|
||||
return Ok(Some(speech));
|
||||
}
|
||||
} else {
|
||||
if self.pre_buffer.len() >= self.cfg.pre_buffer_frames {
|
||||
self.pre_buffer.pop_front();
|
||||
}
|
||||
self.pre_buffer.push_back(frame);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
use axum::{Router, extract::State, response::Response, routing::get};
|
||||
use axum::body::Body;
|
||||
use bytes::Bytes;
|
||||
use std::convert::Infallible;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
audio_tx: broadcast::Sender<Bytes>,
|
||||
}
|
||||
|
||||
fn wav_header() -> Vec<u8> {
|
||||
let mut h = Vec::with_capacity(44);
|
||||
h.extend_from_slice(b"RIFF");
|
||||
h.extend_from_slice(&0xFFFF_FFFFu32.to_le_bytes()); // unknown size
|
||||
h.extend_from_slice(b"WAVE");
|
||||
h.extend_from_slice(b"fmt ");
|
||||
h.extend_from_slice(&16u32.to_le_bytes());
|
||||
h.extend_from_slice(&3u16.to_le_bytes()); // IEEE_FLOAT
|
||||
h.extend_from_slice(&1u16.to_le_bytes()); // mono
|
||||
h.extend_from_slice(&16000u32.to_le_bytes()); // sample rate
|
||||
h.extend_from_slice(&64000u32.to_le_bytes()); // byte rate = 16000 * 4
|
||||
h.extend_from_slice(&4u16.to_le_bytes()); // block align
|
||||
h.extend_from_slice(&32u16.to_le_bytes()); // bits per sample
|
||||
h.extend_from_slice(b"data");
|
||||
h.extend_from_slice(&0xFFFF_FFFFu32.to_le_bytes()); // unknown size
|
||||
h
|
||||
}
|
||||
|
||||
async fn audio_stream(State(state): State<AppState>) -> Response {
|
||||
let rx = state.audio_tx.subscribe();
|
||||
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Result<Bytes, Infallible>>(64);
|
||||
|
||||
tokio::spawn(async move {
|
||||
if body_tx.send(Ok(Bytes::from(wav_header()))).await.is_err() {
|
||||
return;
|
||||
}
|
||||
let mut stream = BroadcastStream::new(rx);
|
||||
loop {
|
||||
match stream.next().await {
|
||||
Some(Ok(frame)) => {
|
||||
if body_tx.send(Ok(frame)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Err(_)) => {} // lagged receiver, skip
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Response::builder()
|
||||
.header("Content-Type", "audio/wav")
|
||||
.header("Cache-Control", "no-cache")
|
||||
.body(Body::from_stream(ReceiverStream::new(body_rx)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn run(port: u16, audio_tx: broadcast::Sender<Bytes>) {
|
||||
let state = AppState { audio_tx };
|
||||
let app = Router::new()
|
||||
.route("/audio", get(audio_stream))
|
||||
.with_state(state);
|
||||
|
||||
let addr = format!("0.0.0.0:{port}");
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await
|
||||
.unwrap_or_else(|e| panic!("failed to bind HTTP server to {addr}: {e}"));
|
||||
eprintln!("HTTP audio stream: http://{addr}/audio");
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
+173
-25
@@ -1,40 +1,139 @@
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::fs;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{mpsc, Arc};
|
||||
use chrono::{DateTime, Local};
|
||||
use serde::Deserialize;
|
||||
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
|
||||
|
||||
mod audio;
|
||||
mod http;
|
||||
|
||||
const CHUNK_SECS: u32 = 30;
|
||||
const CHUNK_SAMPLES: usize = 16000 * CHUNK_SECS as usize;
|
||||
pub mod scannerbot {
|
||||
tonic::include_proto!("scannerbot");
|
||||
}
|
||||
|
||||
use scannerbot::message_service_client::MessageServiceClient;
|
||||
use scannerbot::SendMessageRequest;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Config {
|
||||
model: String,
|
||||
audio: String,
|
||||
server_addr: String,
|
||||
channel_id: String,
|
||||
#[serde(default = "default_chunk_secs")]
|
||||
chunk_secs: u32,
|
||||
#[serde(default = "default_silence_threshold")]
|
||||
silence_threshold: f32,
|
||||
#[serde(default = "default_pre_buffer_secs")]
|
||||
pre_buffer_secs: f32,
|
||||
#[serde(default = "default_hangover_secs")]
|
||||
hangover_secs: f32,
|
||||
#[serde(default = "default_max_segment_secs")]
|
||||
max_segment_secs: f32,
|
||||
#[serde(default = "default_temp_dir")]
|
||||
temp_dir: String,
|
||||
#[serde(default = "default_http_port")]
|
||||
http_port: u16,
|
||||
}
|
||||
|
||||
fn default_chunk_secs() -> u32 { 30 }
|
||||
fn default_silence_threshold() -> f32 { 0.02 }
|
||||
fn default_pre_buffer_secs() -> f32 { 1.0 }
|
||||
fn default_hangover_secs() -> f32 { 2.0 }
|
||||
fn default_max_segment_secs() -> f32 { 120.0 }
|
||||
fn default_temp_dir() -> String { "/tmp".into() }
|
||||
fn default_http_port() -> u16 { 8080 }
|
||||
|
||||
fn save_clip(samples: &[f32], path: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let spec = hound::WavSpec {
|
||||
channels: 1,
|
||||
sample_rate: 16000,
|
||||
bits_per_sample: 16,
|
||||
sample_format: hound::SampleFormat::Int,
|
||||
};
|
||||
let mut writer = hound::WavWriter::create(path, spec)?;
|
||||
for &s in samples {
|
||||
writer.write_sample((s.clamp(-1.0, 1.0) * i16::MAX as f32) as i16)?;
|
||||
}
|
||||
writer.finalize()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn transcribe_chunk(
|
||||
state: &mut whisper_rs::WhisperState,
|
||||
chunk: &[f32],
|
||||
out: &mut impl Write,
|
||||
counter: &mut u32,
|
||||
channel_id: &str,
|
||||
server_addr: &str,
|
||||
temp_dir: &str,
|
||||
rt: &tokio::runtime::Runtime,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let time: DateTime<Local> = Local::now();
|
||||
let id = format!("{}_{:04}", time.format("%Y%m%d_%H%M%S"), counter);
|
||||
*counter += 1;
|
||||
|
||||
let clip_path = std::path::Path::new(temp_dir).join(format!("scannerbot_{}.wav", id));
|
||||
save_clip(chunk, clip_path.to_str().unwrap())?;
|
||||
|
||||
let params = FullParams::new(SamplingStrategy::BeamSearch {
|
||||
beam_size: 5,
|
||||
patience: -1.0,
|
||||
});
|
||||
let time: DateTime<Local> = Local::now();
|
||||
out.write_all(format!("[{}]: ", time.to_string()).as_bytes())?;
|
||||
state.full(params, chunk)?;
|
||||
|
||||
let mut content = String::new();
|
||||
for segment in state.as_iter() {
|
||||
let line = format!("{}\n", segment);
|
||||
print!("{}", line);
|
||||
out.write_all(line.as_bytes())?;
|
||||
content.push_str(&line);
|
||||
}
|
||||
out.flush()?;
|
||||
|
||||
let audio_bytes = fs::read(&clip_path)?;
|
||||
let audio_filename = format!("scannerbot_{}.wav", id);
|
||||
let _ = fs::remove_file(&clip_path);
|
||||
|
||||
let channel_id = channel_id.to_string();
|
||||
let server_addr = server_addr.to_string();
|
||||
let result = rt.block_on(async move {
|
||||
let mut client = MessageServiceClient::connect(server_addr).await?;
|
||||
let request = tonic::Request::new(SendMessageRequest {
|
||||
channel_id,
|
||||
content,
|
||||
audio_filename,
|
||||
audio: audio_bytes,
|
||||
});
|
||||
let response = client.send_message(request).await?;
|
||||
eprintln!(" [message sent: {}]", response.into_inner().id);
|
||||
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
|
||||
});
|
||||
if let Err(e) = result {
|
||||
return Err(e.to_string().into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let model_path = std::env::args().nth(1).expect("usage: listener <model> <file.wav|pulse:SOURCE> [output.txt]");
|
||||
let audio_arg = std::env::args().nth(2).expect("usage: listener <model> <file.wav|pulse:SOURCE> [output.txt]");
|
||||
let output_path = std::env::args().nth(3).unwrap_or_else(|| "transcription.txt".to_string());
|
||||
let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string());
|
||||
let config_str = fs::read_to_string(&config_path)
|
||||
.unwrap_or_else(|_| panic!("failed to read config file: {config_path}"));
|
||||
let cfg: Config = toml::from_str(&config_str)
|
||||
.unwrap_or_else(|e| panic!("invalid config: {e}"));
|
||||
|
||||
let model_path = cfg.model.clone();
|
||||
let audio_arg = cfg.audio.clone();
|
||||
let channel_id = cfg.channel_id.clone();
|
||||
let server_addr = cfg.server_addr.clone();
|
||||
let temp_dir = cfg.temp_dir.clone();
|
||||
let chunk_samples: usize = 16000 * cfg.chunk_secs as usize;
|
||||
|
||||
let vad_cfg = audio::VadConfig {
|
||||
silence_threshold: cfg.silence_threshold,
|
||||
pre_buffer_frames: (cfg.pre_buffer_secs * 10.0) as usize,
|
||||
hangover_frames: (cfg.hangover_secs * 10.0) as usize,
|
||||
max_segment_samples: (cfg.max_segment_secs * 16000.0) as usize,
|
||||
};
|
||||
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let r = running.clone();
|
||||
@@ -42,30 +141,79 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
r.store(false, Ordering::SeqCst);
|
||||
})?;
|
||||
|
||||
let ctx = WhisperContext::new_with_params(&model_path, WhisperContextParameters::default())
|
||||
.expect("failed to load model");
|
||||
let mut state = ctx.create_state().expect("failed to create state");
|
||||
eprintln!("Transcribing {} → {} (Ctrl+C to stop)", audio_arg, server_addr);
|
||||
|
||||
let mut out = OpenOptions::new().create(true).append(true).open(&output_path)?;
|
||||
|
||||
eprintln!("Transcribing {} → {} (Ctrl+C to stop)", audio_arg, output_path);
|
||||
let (audio_tx, _) = tokio::sync::broadcast::channel::<bytes::Bytes>(128);
|
||||
let http_audio_tx = audio_tx.clone();
|
||||
let http_port = cfg.http_port;
|
||||
std::thread::spawn(move || {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("failed to build HTTP runtime");
|
||||
rt.block_on(http::run(http_port, http_audio_tx));
|
||||
});
|
||||
|
||||
if let Some(source) = audio_arg.strip_prefix("pulse:") {
|
||||
let mut stream = audio::LiveStream::open(source)?;
|
||||
while running.load(Ordering::SeqCst) {
|
||||
match stream.next_chunk(CHUNK_SECS)? {
|
||||
Some(chunk) => transcribe_chunk(&mut state, &chunk, &mut out)?,
|
||||
eprintln!("Listening for speech (silence threshold: {:.3} RMS) …", cfg.silence_threshold);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<Vec<f32>>();
|
||||
|
||||
let model_path_t = model_path.clone();
|
||||
let channel_id_t = channel_id.clone();
|
||||
let server_addr_t = server_addr.clone();
|
||||
let temp_dir_t = temp_dir.clone();
|
||||
let transcription_thread = std::thread::spawn(move || {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("failed to build tokio runtime");
|
||||
let ctx = WhisperContext::new_with_params(&model_path_t, WhisperContextParameters::default())
|
||||
.expect("failed to load model");
|
||||
let mut state = ctx.create_state().expect("failed to create state");
|
||||
let mut counter: u32 = 0;
|
||||
for segment in rx {
|
||||
let secs = segment.len() as f32 / 16000.0;
|
||||
eprintln!(" [transcribing {:.1}s segment…]", secs);
|
||||
if let Err(e) = transcribe_chunk(&mut state, &segment, &mut counter, &channel_id_t, &server_addr_t, &temp_dir_t, &rt) {
|
||||
eprintln!("Transcription error: {e}");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// capture loop never pauses for transcription
|
||||
let mut stream = audio::VadStream::open(source, running.clone(), vad_cfg, Some(audio_tx))?;
|
||||
loop {
|
||||
match stream.next_segment()? {
|
||||
Some(segment) => {
|
||||
let secs = segment.len() as f32 / 16000.0;
|
||||
eprintln!(" [captured {:.1}s, queued for transcription]", secs);
|
||||
if tx.send(segment).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
drop(tx);
|
||||
transcription_thread.join().expect("transcription thread panicked");
|
||||
} else {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
let ctx = WhisperContext::new_with_params(&model_path, WhisperContextParameters::default())
|
||||
.expect("failed to load model");
|
||||
let mut state = ctx.create_state().expect("failed to create state");
|
||||
let mut counter: u32 = 0;
|
||||
|
||||
while running.load(Ordering::SeqCst) {
|
||||
let audio = audio::decode(&audio_arg)?;
|
||||
for chunk in audio.chunks(CHUNK_SAMPLES) {
|
||||
for chunk in audio.chunks(chunk_samples) {
|
||||
if !running.load(Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
transcribe_chunk(&mut state, chunk, &mut out)?;
|
||||
transcribe_chunk(&mut state, chunk, &mut counter, &channel_id, &server_addr, &temp_dir, &rt)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"git.dubyatp.xyz/dubyatp/scannerbot/server/db"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func Start() {
|
||||
ctx := context.Background()
|
||||
db.InitPostgres(ctx)
|
||||
defer db.ClosePostgres()
|
||||
|
||||
Store = initFileStore()
|
||||
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hello world"))
|
||||
})
|
||||
|
||||
r.Route("/whoami", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
r.Use(LoginCtx)
|
||||
r.Get("/", Whoami)
|
||||
})
|
||||
|
||||
r.Route("/channels", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
|
||||
r.Get("/", ListChannels)
|
||||
r.Post("/", NewChannel)
|
||||
r.Route("/{channelID}", func(r chi.Router) {
|
||||
r.Get("/", GetChannel)
|
||||
r.Delete("/", DeleteChannel)
|
||||
|
||||
r.Route("/messages", func(r chi.Router) {
|
||||
r.Get("/", ListMessages)
|
||||
r.Route("/{messageID}", func(r chi.Router) {
|
||||
r.Get("/", GetMessage)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
|
||||
r.Get("/", ListUsers)
|
||||
r.Route("/{userID}", func(r chi.Router) {
|
||||
r.Get("/", GetUser)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
|
||||
r.Route("/{fileID}", func(r chi.Router) {
|
||||
r.Get("/", ServeFile)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/login", func(r chi.Router) {
|
||||
r.Post("/", Login)
|
||||
})
|
||||
|
||||
r.Route("/logout", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
|
||||
r.Post("/", Logout)
|
||||
})
|
||||
|
||||
r.Route("/register", func(r chi.Router) {
|
||||
r.Post("/", NewUser)
|
||||
})
|
||||
|
||||
http.ListenAndServe(":3000", r)
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var jwtSecret = []byte(os.Getenv("JWT_SECRET"))
|
||||
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func Login(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseMultipartForm(64 << 10)
|
||||
if err != nil {
|
||||
http.Error(w, "Unable to parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
if username == "" || password == "" {
|
||||
http.Error(w, "Username and password cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := dbGetUserByName(username)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePassword(user.Password, password); err != nil {
|
||||
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := CreateSession(user.ID)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
})
|
||||
|
||||
slog.Info("auth: login successful", "userid", user.ID, "username", user.Name)
|
||||
w.Write([]byte("Login successful"))
|
||||
}
|
||||
|
||||
func Logout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Error(w, "No session cookie found. You are already logged out", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := cookie.Value
|
||||
userID, valid := ValidateSession(sessionToken)
|
||||
if !valid {
|
||||
http.Error(w, "Session cookie could not be validated. You are already logged out", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := dbGetUser(userID.String())
|
||||
if err != nil {
|
||||
http.Error(w, "Session cookie validated but user could not be found", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
DeleteSession(sessionToken)
|
||||
|
||||
cookie.Expires = time.Now()
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
slog.Debug("auth: logout successful", "user ID", user.ID, "username", user.Name)
|
||||
w.Write([]byte(fmt.Sprintf("%v has been logged out", user.Name)))
|
||||
}
|
||||
|
||||
func ValidateSession(sessionToken string) (uuid.UUID, bool) {
|
||||
token, err := jwt.Parse(sessionToken, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
slog.Debug("auth: session token invalid, rejecting")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
slog.Debug("auth: could not map claims from JWT")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
userIDStr, ok := claims["userid"].(string)
|
||||
if !ok {
|
||||
slog.Debug("auth: userID claim is not a string")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
slog.Debug("auth: failed to parse userID as uuid", "error", err)
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
hashedToken := hashToken(sessionToken)
|
||||
|
||||
session, err := dbGetSession(hashedToken)
|
||||
if err != nil {
|
||||
slog.Debug("auth: failed to retrieve session from db", "error", err)
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
slog.Debug("auth: session validated", "userID", session.UserID)
|
||||
return userID, true
|
||||
}
|
||||
|
||||
func DeleteSession(sessionToken string) bool {
|
||||
hashedToken := hashToken(sessionToken)
|
||||
|
||||
err := dbDeleteSession(hashedToken)
|
||||
if err != nil {
|
||||
slog.Error("auth: failed to delete session", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
slog.Debug("auth: session deleted", "token", hashedToken)
|
||||
return true
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const userIDKey contextKey = "userID"
|
||||
|
||||
func SessionAuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := cookie.Value
|
||||
userID, valid := ValidateSession(sessionToken)
|
||||
if !valid {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Add username to request context
|
||||
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Token string
|
||||
UserID uuid.UUID
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
func CreateSession(userID uuid.UUID) string {
|
||||
expiry := time.Now().Add(7 * 24 * time.Hour)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"userid": userID.String(),
|
||||
"exp": expiry.Unix(), // 7 day token
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
slog.Error("auth: failed to create JWT", "error", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
hashedToken := hashToken(tokenString)
|
||||
session := Session{
|
||||
Token: hashedToken,
|
||||
UserID: userID,
|
||||
Expiry: expiry,
|
||||
}
|
||||
dbAddSession(&session)
|
||||
|
||||
slog.Debug("auth: new session created", "userid", session.UserID)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func hashPassword(password string) (string, error) {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password),
|
||||
bcrypt.DefaultCost)
|
||||
return string(hashedPassword), err
|
||||
}
|
||||
|
||||
func validatePassword(hashedPassword, password string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func GetChannel(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("channel: entering GetChannel handler")
|
||||
|
||||
channelID := chi.URLParam(r, "channelID")
|
||||
parsed, err := uuid.Parse(channelID)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := dbGetChannel(parsed.String())
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrChannelNotFound) {
|
||||
render.Render(w, r, ErrNotFound)
|
||||
} else {
|
||||
slog.Error("channel: failed to fetch channel", "channelid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("channel: rendering channel", "channelid", channel.ID, "channelname", channel.Name)
|
||||
if err := render.Render(w, r, NewChannelPayloadResponse(channel)); err != nil {
|
||||
slog.Error("channel: failed to render channel", "channelid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
}
|
||||
|
||||
func ListChannels(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("channel: entering ListChannels handler")
|
||||
dbChannels, err := dbGetAllChannels()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrChannelNotFound) {
|
||||
render.Render(w, r, ErrNotFound)
|
||||
} else {
|
||||
slog.Error("channel: failed to fetch channels", "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("channel: successfully fetched channels", "count", len(dbChannels))
|
||||
if err := render.RenderList(w, r, NewChannelListResponse(dbChannels)); err != nil {
|
||||
slog.Error("channel: failed to render channel list response", "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func newChannelID() uuid.UUID {
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
func NewChannel(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("channel: entering NewChannel handler")
|
||||
err := r.ParseMultipartForm(64 << 10)
|
||||
if err != nil {
|
||||
slog.Error("channel: failed to parse multipartform", "error", err)
|
||||
http.Error(w, "Unable to parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newChannelName := r.FormValue("name")
|
||||
newChannelType := r.FormValue("type")
|
||||
newChannelLocation := r.FormValue("location")
|
||||
newChannelNotes := r.FormValue("notes")
|
||||
if newChannelName == "" {
|
||||
slog.Error("channel: channelname is empty")
|
||||
http.Error(w, "Channel name cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newChannel := Channel{
|
||||
ID: newChannelID(),
|
||||
Name: newChannelName,
|
||||
Created: time.Now(),
|
||||
Type: newChannelType,
|
||||
Location: newChannelLocation,
|
||||
Notes: newChannelNotes,
|
||||
}
|
||||
|
||||
slog.Debug("channel: adding new channel to database", "channelid", newChannel.ID, "channelname", newChannel.Name)
|
||||
err = dbAddChannel(&newChannel)
|
||||
if err != nil {
|
||||
slog.Error("channel: failed to add new channel", "channelid", newChannel.ID, "channelname", newChannel.Name)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("channel: successfully added new channel", "channelid", newChannel.ID, "channelname", newChannel.Name)
|
||||
render.Render(w, r, NewChannelPayloadResponse(&newChannel))
|
||||
|
||||
}
|
||||
|
||||
func DeleteChannel(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("channel: entering DeleteChannel handler")
|
||||
|
||||
channelID := chi.URLParam(r, "channelID")
|
||||
parsed, err := uuid.Parse(channelID)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = dbDeleteChannel(parsed.String())
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrChannelNotFound) {
|
||||
render.Render(w, r, ErrNotFound)
|
||||
} else {
|
||||
slog.Error("channel: failed to delete channel", "channelid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("channel: deleted channel", "channelid", parsed.String())
|
||||
w.Write([]byte("Channel deleted successfully"))
|
||||
}
|
||||
|
||||
type Channel struct {
|
||||
ID uuid.UUID
|
||||
Name string
|
||||
Created time.Time
|
||||
Type string
|
||||
Location string
|
||||
Notes string
|
||||
}
|
||||
|
||||
type channelKey struct{}
|
||||
|
||||
type ChannelPayload struct {
|
||||
*Channel
|
||||
}
|
||||
@@ -0,0 +1,346 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.dubyatp.xyz/dubyatp/scannerbot/server/db"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
var ErrUserNotFound = errors.New("db: user not found")
|
||||
var ErrSessionNotFound = errors.New("db: session not found")
|
||||
var ErrChannelNotFound = errors.New("db: channel not found")
|
||||
var ErrFileNotFound = errors.New("db: file not found")
|
||||
var ErrMessageNotFound = errors.New("db: message not found")
|
||||
|
||||
func dbGetUser(id string) (*User, error) {
|
||||
query := `SELECT id, name, password FROM users WHERE id = $1`
|
||||
var user User
|
||||
err := db.Pool.QueryRow(context.Background(), query, id).Scan(&user.ID, &user.Name, &user.Password)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Debug("db: user not found", "userid", id)
|
||||
return nil, ErrUserNotFound
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query user", "error", err)
|
||||
return nil, fmt.Errorf("failed to query user")
|
||||
}
|
||||
|
||||
slog.Debug("db: user found", "userid", user.ID, "username", user.Name)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func dbGetUserByName(username string) (*User, error) {
|
||||
query := `SELECT id, name, password FROM users WHERE name = $1`
|
||||
var user User
|
||||
err := db.Pool.QueryRow(context.Background(), query, username).Scan(&user.ID, &user.Name, &user.Password)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Debug("db: user not found", "username", username)
|
||||
return nil, ErrUserNotFound
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query user", "error", err)
|
||||
return nil, fmt.Errorf("failed to query user")
|
||||
}
|
||||
|
||||
slog.Debug("db: user found", "userid", user.ID, "username", user.Name)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func dbGetAllUsers() ([]*User, error) {
|
||||
query := `SELECT id, name, password FROM users`
|
||||
rows, err := db.Pool.Query(context.Background(), query)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to query users", "error", err)
|
||||
return nil, fmt.Errorf("failed to query users")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*User
|
||||
for rows.Next() {
|
||||
user := &User{}
|
||||
if err := rows.Scan(&user.ID, &user.Name, &user.Password); err != nil {
|
||||
slog.Error("db: failed to scan user", "error", err)
|
||||
return nil, fmt.Errorf("failed to scan user")
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
slog.Error("db: row iteration error", "error", err)
|
||||
return nil, fmt.Errorf("failed to iterate users")
|
||||
}
|
||||
if len(users) == 0 {
|
||||
slog.Debug("db: no users found")
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
slog.Debug("db: user list returned")
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func dbAddUser(user *User) error {
|
||||
query := `INSERT INTO users (id, name, password) VALUES ($1, $2, $3)`
|
||||
_, err := db.Pool.Exec(context.Background(), query, user.ID, user.Name, user.Password)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to add user", "error", err, "userid", user.ID, "username", user.Name)
|
||||
return fmt.Errorf("failed to add user")
|
||||
}
|
||||
|
||||
slog.Debug("db: user added", "userid", user.ID, "username", user.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbAddSession(session *Session) error {
|
||||
query := `INSERT INTO sessions (jwttoken, userid, expiry) VALUES ($1, $2, $3)`
|
||||
_, err := db.Pool.Exec(context.Background(), query, session.Token, session.UserID, session.Expiry)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to add session", "error", err)
|
||||
return fmt.Errorf("failed to add session")
|
||||
}
|
||||
|
||||
slog.Debug("db: session added", "userid", session.UserID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbGetSession(jwtToken string) (*Session, error) {
|
||||
query := `SELECT jwttoken, userid, expiry FROM sessions WHERE jwttoken = $1`
|
||||
|
||||
var session Session
|
||||
err := db.Pool.QueryRow(context.Background(), query, jwtToken).Scan(&session.Token, &session.UserID, &session.Expiry)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Debug("db: session not found")
|
||||
return nil, ErrSessionNotFound
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query session", "error", err)
|
||||
return nil, fmt.Errorf("failed to query session")
|
||||
}
|
||||
|
||||
slog.Debug("db: session found", "userid", session.UserID)
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func dbDeleteSession(jwtToken string) error {
|
||||
query := `DELETE FROM sessions WHERE jwttoken = $1`
|
||||
tag, err := db.Pool.Exec(context.Background(), query, jwtToken)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to delete session", "error", err)
|
||||
return fmt.Errorf("failed to delete session")
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return ErrSessionNotFound
|
||||
}
|
||||
|
||||
slog.Debug("db: session deleted")
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbAddChannel(channel *Channel) error {
|
||||
query := `INSERT INTO channels (id, name, created, type, location, notes) VALUES ($1, $2, $3, $4, $5, $6)`
|
||||
_, err := db.Pool.Exec(context.Background(), query, channel.ID, channel.Name, channel.Created, channel.Type, channel.Location, channel.Notes)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to add channel", "channel", err, "channelid", channel.ID, "channelname", channel.Name)
|
||||
return fmt.Errorf("failed to add channel")
|
||||
}
|
||||
|
||||
slog.Debug("db: channel added", "channelid", channel.ID, "channelname", channel.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbGetChannel(id string) (*Channel, error) {
|
||||
query := `SELECT id, name, created, type, location, notes FROM channels WHERE id = $1`
|
||||
var channel Channel
|
||||
err := db.Pool.QueryRow(context.Background(), query, id).Scan(&channel.ID, &channel.Name, &channel.Created, &channel.Type, &channel.Location, &channel.Notes)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Debug("db: channel not found", "channelid", id)
|
||||
return nil, ErrChannelNotFound
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query channel", "error", err)
|
||||
return nil, fmt.Errorf("failed to query channel")
|
||||
}
|
||||
|
||||
slog.Debug("db: channel found", "channelid", channel.ID, "channelname", channel.Name)
|
||||
return &channel, nil
|
||||
}
|
||||
|
||||
func dbGetAllChannels() ([]*Channel, error) {
|
||||
query := `SELECT id, name, created, type, location, notes FROM channels`
|
||||
rows, err := db.Pool.Query(context.Background(), query)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to query channels", "error", err)
|
||||
return nil, fmt.Errorf("failed to query channels")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var channels []*Channel
|
||||
for rows.Next() {
|
||||
channel := &Channel{}
|
||||
if err := rows.Scan(&channel.ID, &channel.Name, &channel.Created, &channel.Type, &channel.Location, &channel.Notes); err != nil {
|
||||
slog.Error("db: failed to scan channel", "error", err)
|
||||
return nil, fmt.Errorf("failed to scan channel")
|
||||
}
|
||||
channels = append(channels, channel)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
slog.Error("db: row iteration error", "error", err)
|
||||
return nil, fmt.Errorf("failed to iterate channels")
|
||||
}
|
||||
if len(channels) == 0 {
|
||||
slog.Debug("db: no channels found")
|
||||
return nil, ErrChannelNotFound
|
||||
}
|
||||
|
||||
slog.Debug("db: channel list returned")
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func dbDeleteChannel(id string) error {
|
||||
query := `DELETE FROM channels WHERE id = $1`
|
||||
tag, err := db.Pool.Exec(context.Background(), query, id)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to delete channel", "error", err)
|
||||
return fmt.Errorf("failed to delete channel")
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return ErrChannelNotFound
|
||||
}
|
||||
|
||||
slog.Debug("db: channel deleted")
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbAddFile(file *File) error {
|
||||
query := `INSERT INTO files (id, name, created, backend, path) VALUES ($1, $2, $3, $4, $5)`
|
||||
_, err := db.Pool.Exec(context.Background(), query, file.ID, file.Name, file.Created, file.Backend, file.Path)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to add file", "error", err, "fileid", file.ID)
|
||||
return fmt.Errorf("failed to add file")
|
||||
}
|
||||
slog.Debug("db: file added", "fileid", file.ID, "filename", file.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbGetFile(id string) (*File, error) {
|
||||
query := `SELECT id, name, created, backend, path FROM files WHERE id = $1`
|
||||
var file File
|
||||
err := db.Pool.QueryRow(context.Background(), query, id).Scan(&file.ID, &file.Name, &file.Created, &file.Backend, &file.Path)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Debug("db: file not found", "fileid", id)
|
||||
return nil, ErrFileNotFound
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query file", "error", err)
|
||||
return nil, fmt.Errorf("failed to query file")
|
||||
}
|
||||
slog.Debug("db: file found", "fileid", file.ID, "filename", file.Name)
|
||||
return &file, nil
|
||||
}
|
||||
|
||||
func DBGetChannel(id string) (*Channel, error) { return dbGetChannel(id) }
|
||||
func DBAddFile(file *File) error { return dbAddFile(file) }
|
||||
|
||||
func DBAddMessage(msg *Message) error {
|
||||
query := `INSERT INTO messages (id, channel, created, content, audio) VALUES ($1, $2, $3, $4, $5)`
|
||||
_, err := db.Pool.Exec(context.Background(), query, msg.ID, msg.Channel.ID, msg.Created, msg.Content, msg.Audio.ID)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to add message", "error", err, "messageid", msg.ID)
|
||||
return fmt.Errorf("failed to add message")
|
||||
}
|
||||
slog.Debug("db: message added", "messageid", msg.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbGetMessage(id string) (*Message, error) {
|
||||
query := `SELECT id, channel, created, content, audio FROM messages WHERE id = $1`
|
||||
var channelID, audioID string
|
||||
var msg Message
|
||||
err := db.Pool.QueryRow(context.Background(), query, id).Scan(&msg.ID, &channelID, &msg.Created, &msg.Content, &audioID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Debug("db: message not found", "messageid", id)
|
||||
return nil, ErrMessageNotFound
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query message", "error", err)
|
||||
return nil, fmt.Errorf("failed to query message")
|
||||
}
|
||||
|
||||
channel, err := dbGetChannel(channelID)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to fetch channel for message", "messageid", id, "channelid", channelID, "error", err)
|
||||
return nil, fmt.Errorf("failed to fetch channel for message")
|
||||
}
|
||||
audio, err := dbGetFile(audioID)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to fetch audio for message", "messageid", id, "audioid", audioID, "error", err)
|
||||
return nil, fmt.Errorf("failed to fetch audio for message")
|
||||
}
|
||||
|
||||
msg.Channel = *channel
|
||||
msg.Audio = *audio
|
||||
slog.Debug("db: message found", "messageid", msg.ID)
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func dbGetMessagesByChannel(channelID string, from, to *time.Time) ([]*Message, error) {
|
||||
query := `SELECT id, channel, created, content, audio FROM messages WHERE channel = $1`
|
||||
args := []any{channelID}
|
||||
if from != nil {
|
||||
args = append(args, *from)
|
||||
query += fmt.Sprintf(" AND created >= $%d", len(args))
|
||||
}
|
||||
if to != nil {
|
||||
args = append(args, *to)
|
||||
query += fmt.Sprintf(" AND created <= $%d", len(args))
|
||||
}
|
||||
query += " ORDER BY created DESC"
|
||||
rows, err := db.Pool.Query(context.Background(), query, args...)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to query messages", "error", err)
|
||||
return nil, fmt.Errorf("failed to query messages")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type messageRow struct {
|
||||
msg Message
|
||||
channelID string
|
||||
audioID string
|
||||
}
|
||||
|
||||
var rows_ []messageRow
|
||||
for rows.Next() {
|
||||
var mr messageRow
|
||||
if err := rows.Scan(&mr.msg.ID, &mr.channelID, &mr.msg.Created, &mr.msg.Content, &mr.audioID); err != nil {
|
||||
slog.Error("db: failed to scan message", "error", err)
|
||||
return nil, fmt.Errorf("failed to scan message")
|
||||
}
|
||||
rows_ = append(rows_, mr)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
slog.Error("db: row iteration error", "error", err)
|
||||
return nil, fmt.Errorf("failed to iterate messages")
|
||||
}
|
||||
if len(rows_) == 0 {
|
||||
slog.Debug("db: no messages found", "channelid", channelID)
|
||||
return nil, ErrMessageNotFound
|
||||
}
|
||||
|
||||
channel, err := dbGetChannel(channelID)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to fetch channel for messages", "channelid", channelID, "error", err)
|
||||
return nil, fmt.Errorf("failed to fetch channel for messages")
|
||||
}
|
||||
|
||||
var messages []*Message
|
||||
for _, mr := range rows_ {
|
||||
audio, err := dbGetFile(mr.audioID)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to fetch audio for message", "messageid", mr.msg.ID, "audioid", mr.audioID, "error", err)
|
||||
return nil, fmt.Errorf("failed to fetch audio for message")
|
||||
}
|
||||
mr.msg.Channel = *channel
|
||||
mr.msg.Audio = *audio
|
||||
messages = append(messages, &mr.msg)
|
||||
}
|
||||
|
||||
slog.Debug("db: message list returned", "channelid", channelID, "count", len(messages))
|
||||
return messages, nil
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
type ErrResponse struct {
|
||||
Err error `json:"-"`
|
||||
HTTPStatusCode int `json:"-"`
|
||||
|
||||
StatusText string `json:"status"`
|
||||
AppCode int64 `json:"code,omitempty"`
|
||||
ErrorText string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
render.Status(r, e.HTTPStatusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ErrInvalidRequest(err error) render.Renderer {
|
||||
return &ErrResponse{
|
||||
Err: err,
|
||||
HTTPStatusCode: 400,
|
||||
StatusText: "Invalid request.",
|
||||
ErrorText: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
func ErrRender(err error) render.Renderer {
|
||||
return &ErrResponse{
|
||||
Err: err,
|
||||
HTTPStatusCode: 422,
|
||||
StatusText: "Error rendering response.",
|
||||
ErrorText: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
func ErrInternal(err error) render.Renderer {
|
||||
return &ErrResponse{
|
||||
Err: err,
|
||||
HTTPStatusCode: 500,
|
||||
StatusText: "Internal server error.",
|
||||
}
|
||||
}
|
||||
|
||||
var ErrNotFound = &ErrResponse{HTTPStatusCode: 404, StatusText: "Resource not found."}
|
||||
@@ -0,0 +1,156 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func initFileStore() FileStore {
|
||||
val, ok := os.LookupEnv("FILE_BACKEND")
|
||||
if !ok {
|
||||
slog.Error("FILE_BACKEND environment variable not set")
|
||||
os.Exit(1)
|
||||
}
|
||||
switch FileBackend(val) {
|
||||
case FileBackendLocal:
|
||||
localFilePath, ok := os.LookupEnv("LOCAL_FILEPATH")
|
||||
if !ok {
|
||||
slog.Error("LOCAL_FILEPATH environment variable not set")
|
||||
os.Exit(1)
|
||||
}
|
||||
return &LocalFileStore{BaseDir: localFilePath}
|
||||
}
|
||||
slog.Error("unsupported FILE_BACKEND", "value", val)
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
type File struct {
|
||||
ID uuid.UUID
|
||||
Name string
|
||||
Created time.Time
|
||||
Backend FileBackend
|
||||
Path string
|
||||
}
|
||||
|
||||
type FileBackend string
|
||||
|
||||
const (
|
||||
FileBackendLocal FileBackend = "local"
|
||||
FileBackendS3 FileBackend = "s3"
|
||||
)
|
||||
|
||||
var Store FileStore
|
||||
|
||||
type FileStore interface {
|
||||
Save(name string, r io.Reader) (*File, error)
|
||||
URL(file *File) (string, error)
|
||||
}
|
||||
|
||||
type LocalFileStore struct {
|
||||
BaseDir string
|
||||
}
|
||||
|
||||
func (s *LocalFileStore) Save(name string, r io.Reader) (*File, error) {
|
||||
id := uuid.New()
|
||||
path := filepath.Join(s.BaseDir, id.String())
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("file(local): failed to create file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(f, r); err != nil {
|
||||
os.Remove(path)
|
||||
return nil, fmt.Errorf("file(local): failed to write file: %w", err)
|
||||
}
|
||||
|
||||
return &File{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Created: time.Now(),
|
||||
Backend: FileBackendLocal,
|
||||
Path: path,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *LocalFileStore) URL(file *File) (string, error) {
|
||||
return "/files/" + file.ID.String(), nil
|
||||
}
|
||||
|
||||
func ServeFile(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("file: entering ServeFile handler")
|
||||
|
||||
fileID := chi.URLParam(r, "fileID")
|
||||
parsed, err := uuid.Parse(fileID)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
file, err := dbGetFile(parsed.String())
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrFileNotFound) {
|
||||
render.Render(w, r, ErrNotFound)
|
||||
} else {
|
||||
slog.Error("file: failed to fetch file", "fileid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
f, err := os.Open(file.Path)
|
||||
if err != nil {
|
||||
slog.Error("file: failed to open file", "fileid", file.ID, "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
http.ServeContent(w, r, file.Name, file.Created, f)
|
||||
}
|
||||
|
||||
// UploadFile is a temporary handler for testing file uploads.
|
||||
/*
|
||||
func UploadFile(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("file: entering UploadFile handler")
|
||||
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
f, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
file, err := Store.Save(header.Filename, f)
|
||||
if err != nil {
|
||||
slog.Error("file: failed to save file", "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := dbAddFile(file); err != nil {
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("file: uploaded file", "fileid", file.ID, "filename", file.Name)
|
||||
render.Render(w, r, NewFilePayloadResponse(file))
|
||||
}
|
||||
*/
|
||||
@@ -0,0 +1,99 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uuid.UUID
|
||||
Channel Channel
|
||||
Created time.Time
|
||||
Content string
|
||||
Audio File
|
||||
}
|
||||
|
||||
type MessagePayload struct {
|
||||
*Message
|
||||
}
|
||||
|
||||
func GetMessage(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("message: entering GetMessage handler")
|
||||
|
||||
messageID := chi.URLParam(r, "messageID")
|
||||
parsed, err := uuid.Parse(messageID)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := dbGetMessage(parsed.String())
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrMessageNotFound) {
|
||||
render.Render(w, r, ErrNotFound)
|
||||
} else {
|
||||
slog.Error("message: failed to fetch message", "messageid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: rendering message", "messageid", msg.ID)
|
||||
if err := render.Render(w, r, NewMessagePayloadResponse(msg)); err != nil {
|
||||
slog.Error("message: failed to render message", "messageid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
}
|
||||
|
||||
func ListMessages(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("message: entering ListMessages handler")
|
||||
|
||||
channelID := chi.URLParam(r, "channelID")
|
||||
parsed, err := uuid.Parse(channelID)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
var from, to *time.Time
|
||||
if v := r.URL.Query().Get("from"); v != "" {
|
||||
t, err := time.Parse(time.RFC3339, v)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(fmt.Errorf("invalid 'from' timestamp: %w", err)))
|
||||
return
|
||||
}
|
||||
from = &t
|
||||
}
|
||||
if v := r.URL.Query().Get("to"); v != "" {
|
||||
t, err := time.Parse(time.RFC3339, v)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(fmt.Errorf("invalid 'to' timestamp: %w", err)))
|
||||
return
|
||||
}
|
||||
to = &t
|
||||
}
|
||||
|
||||
messages, err := dbGetMessagesByChannel(parsed.String(), from, to)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrMessageNotFound) {
|
||||
render.Render(w, r, ErrNotFound)
|
||||
} else {
|
||||
slog.Error("message: failed to fetch messages", "channelid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: successfully fetched messages", "channelid", parsed.String(), "count", len(messages))
|
||||
if err := render.RenderList(w, r, NewMessageListResponse(messages)); err != nil {
|
||||
slog.Error("message: failed to render message list", "channelid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
func NewUserPayloadResponse(user *User) *UserPayload {
|
||||
return &UserPayload{User: user}
|
||||
}
|
||||
|
||||
func NewUserListResponse(users []*User) []render.Renderer {
|
||||
list := []render.Renderer{}
|
||||
for _, user := range users {
|
||||
list = append(list, NewUserPayloadResponse(user))
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func (u *UserPayload) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewChannelPayloadResponse(channel *Channel) *ChannelPayload {
|
||||
return &ChannelPayload{Channel: channel}
|
||||
}
|
||||
|
||||
func NewChannelListResponse(channels []*Channel) []render.Renderer {
|
||||
list := []render.Renderer{}
|
||||
for _, channel := range channels {
|
||||
list = append(list, NewChannelPayloadResponse(channel))
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func (c *ChannelPayload) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type FilePayload struct {
|
||||
*File
|
||||
}
|
||||
|
||||
func NewFilePayloadResponse(file *File) *FilePayload {
|
||||
return &FilePayload{File: file}
|
||||
}
|
||||
|
||||
func (f *FilePayload) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewMessagePayloadResponse(msg *Message) *MessagePayload {
|
||||
return &MessagePayload{Message: msg}
|
||||
}
|
||||
|
||||
func NewMessageListResponse(messages []*Message) []render.Renderer {
|
||||
list := []render.Renderer{}
|
||||
for _, msg := range messages {
|
||||
list = append(list, NewMessagePayloadResponse(msg))
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func (m *MessagePayload) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func Whoami(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering Whoami handler")
|
||||
user, ok := r.Context().Value(userKey{}).(*User)
|
||||
if !ok || user == nil {
|
||||
slog.Debug("user: anonymous user")
|
||||
w.Write([]byte("anonymous"))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: returning username", "userid", user.ID, "username", user.Name)
|
||||
w.Write([]byte(user.Name))
|
||||
}
|
||||
|
||||
func LoginCtx(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering LoginCtx middleware")
|
||||
userID, ok := r.Context().Value(userIDKey).(uuid.UUID)
|
||||
if !ok || userID == uuid.Nil {
|
||||
slog.Debug("user: no user ID provided, assuming anonymous user")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: fetching user by ID", "user ID", userID)
|
||||
user, err := dbGetUser(userID.String())
|
||||
if err != nil {
|
||||
slog.Error("user: failed to fetch user by ID", "user ID", userID, "error", err)
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: successfully fetched user", "user ID", user.ID, "username", user.Name)
|
||||
ctx := context.WithValue(r.Context(), userKey{}, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func ListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering ListUsers handler")
|
||||
dbUsers, err := dbGetAllUsers()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
render.Render(w, r, ErrNotFound)
|
||||
} else {
|
||||
slog.Error("user: failed to fetch users", "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: successfully fetched users", "count", len(dbUsers))
|
||||
if err := render.RenderList(w, r, NewUserListResponse(dbUsers)); err != nil {
|
||||
slog.Error("user: failed to render user list response", "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func GetUser(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering GetUser handler")
|
||||
|
||||
userID := chi.URLParam(r, "userID")
|
||||
parsed, err := uuid.Parse(userID)
|
||||
if err != nil {
|
||||
render.Render(w, r, ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
user, err := dbGetUser(parsed.String())
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
render.Render(w, r, ErrNotFound)
|
||||
} else {
|
||||
slog.Error("user: failed to fetch user", "userid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: rendering user", "userid", user.ID, "username", user.Name)
|
||||
if err := render.Render(w, r, NewUserPayloadResponse(user)); err != nil {
|
||||
slog.Error("user: failed to render user", "userid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
}
|
||||
|
||||
func newUserID() uuid.UUID {
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
func NewUser(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering NewUser handler")
|
||||
err := r.ParseMultipartForm(64 << 10)
|
||||
if err != nil {
|
||||
slog.Error("user: failed to parse multipartform", "error", err)
|
||||
http.Error(w, "Unable to parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newUserName := r.FormValue("name")
|
||||
password := r.FormValue("password")
|
||||
if newUserName == "" || password == "" {
|
||||
slog.Error("user: username or password is empty")
|
||||
http.Error(w, "Username and password cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: hashing password for new user", "userName", newUserName)
|
||||
hashedPassword, err := hashPassword(password)
|
||||
if err != nil {
|
||||
slog.Error("user: failed to hash password", "error", err)
|
||||
http.Error(w, "Unable to hash password", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
newUser := User{
|
||||
ID: newUserID(),
|
||||
Name: newUserName,
|
||||
Password: hashedPassword,
|
||||
}
|
||||
|
||||
slog.Debug("user: adding new user to database", "userID", newUser.ID, "userName", newUser.Name)
|
||||
err = dbAddUser(&newUser)
|
||||
if err != nil {
|
||||
slog.Error("user: failed to add new user", "userID", newUser.ID, "userName", newUser.Name, "error", err)
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: successfully added new user", "userID", newUser.ID, "userName", newUser.Name)
|
||||
render.Render(w, r, NewUserPayloadResponse(&newUser))
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Password string `json:"-"`
|
||||
}
|
||||
|
||||
type userKey struct{}
|
||||
|
||||
type UserPayload struct {
|
||||
*User
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var Pool *pgxpool.Pool
|
||||
|
||||
func InitPostgres(ctx context.Context) {
|
||||
pool, err := pgxpool.New(ctx, os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
slog.Error("Failed to connect to Postgres", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
Pool = pool
|
||||
err = pool.Ping(ctx)
|
||||
if err != nil {
|
||||
slog.Error("Failed to connect to Postgres", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("Connected to Postgres")
|
||||
}
|
||||
|
||||
func ClosePostgres() {
|
||||
if Pool != nil {
|
||||
Pool.Close()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
module git.dubyatp.xyz/dubyatp/scannerbot/server
|
||||
|
||||
go 1.26.3
|
||||
|
||||
require (
|
||||
github.com/go-chi/chi/v5 v5.2.5
|
||||
github.com/go-chi/render v1.0.3
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.9.2
|
||||
github.com/joho/godotenv v1.5.1
|
||||
golang.org/x/crypto v0.51.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.44.0 // indirect
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
|
||||
google.golang.org/grpc v1.81.1 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
)
|
||||
@@ -0,0 +1,50 @@
|
||||
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
||||
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||
github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4=
|
||||
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
|
||||
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
||||
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
||||
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ=
|
||||
google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,72 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.dubyatp.xyz/dubyatp/scannerbot/server/api"
|
||||
pb "git.dubyatp.xyz/dubyatp/scannerbot/server/proto"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type MessageServer struct {
|
||||
pb.UnimplementedMessageServiceServer
|
||||
}
|
||||
|
||||
func (s *MessageServer) SendMessage(ctx context.Context, req *pb.SendMessageRequest) (*pb.SendMessageResponse, error) {
|
||||
slog.Debug("grpc: entering SendMessage handler")
|
||||
|
||||
channelID, err := uuid.Parse(req.ChannelId)
|
||||
if err != nil {
|
||||
slog.Error("grpc: invalid channel_id", "error", err)
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid channel_id: %v", err)
|
||||
}
|
||||
|
||||
channel, err := api.DBGetChannel(channelID.String())
|
||||
if err != nil {
|
||||
slog.Error("grpc: channel not found", "channelid", channelID, "error", err)
|
||||
return nil, status.Errorf(codes.NotFound, "channel not found")
|
||||
}
|
||||
|
||||
filename := req.AudioFilename
|
||||
if filename == "" {
|
||||
filename = channelID.String() + ".wav"
|
||||
}
|
||||
|
||||
audio, err := api.Store.Save(filename, bytes.NewReader(req.Audio))
|
||||
if err != nil {
|
||||
slog.Error("grpc: failed to save audio file", "error", err)
|
||||
return nil, status.Errorf(codes.Internal, "failed to save audio file")
|
||||
}
|
||||
|
||||
if err := api.DBAddFile(audio); err != nil {
|
||||
slog.Error("grpc: failed to persist audio file record", "error", err)
|
||||
return nil, status.Errorf(codes.Internal, "failed to persist audio file")
|
||||
}
|
||||
|
||||
msg := &api.Message{
|
||||
ID: uuid.New(),
|
||||
Channel: *channel,
|
||||
Created: time.Now(),
|
||||
Content: req.Content,
|
||||
Audio: *audio,
|
||||
}
|
||||
|
||||
if err := api.DBAddMessage(msg); err != nil {
|
||||
slog.Error("grpc: failed to persist message", "error", err)
|
||||
return nil, status.Errorf(codes.Internal, "failed to persist message")
|
||||
}
|
||||
|
||||
slog.Debug("grpc: message saved", "messageid", msg.ID, "channelid", channelID)
|
||||
return &pb.SendMessageResponse{
|
||||
Id: msg.ID.String(),
|
||||
ChannelId: channel.ID.String(),
|
||||
Created: msg.Created.Format(time.RFC3339),
|
||||
Content: msg.Content,
|
||||
AudioId: audio.ID.String(),
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
|
||||
pb "git.dubyatp.xyz/dubyatp/scannerbot/server/proto"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func Start() {
|
||||
lis, err := net.Listen("tcp", ":3001")
|
||||
if err != nil {
|
||||
slog.Error("grpc: failed to listen", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
s := grpc.NewServer()
|
||||
pb.RegisterMessageServiceServer(s, &MessageServer{})
|
||||
|
||||
slog.Info("Starting the gRPC server...", "addr", lis.Addr())
|
||||
if err := s.Serve(lis); err != nil {
|
||||
slog.Error("grpc: server failed", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"git.dubyatp.xyz/dubyatp/scannerbot/server/api"
|
||||
grpcserver "git.dubyatp.xyz/dubyatp/scannerbot/server/grpc"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
var REQUIRED_ENVS = [...]string{
|
||||
"DATABASE_URL", "JWT_SECRET", "FILE_BACKEND",
|
||||
}
|
||||
|
||||
func checkEnvVars(keys []string) (bool, []string) {
|
||||
var missing []string
|
||||
for _, key := range keys {
|
||||
if _, ok := os.LookupEnv(key); !ok {
|
||||
missing = append(missing, key)
|
||||
}
|
||||
}
|
||||
return len(missing) == 0, missing
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
slog.Debug("No .env file loaded, will try OS environment variables")
|
||||
}
|
||||
|
||||
exists, missingVars := checkEnvVars(REQUIRED_ENVS[:])
|
||||
if !exists {
|
||||
slog.Error("Missing environment variables", "missing", missingVars)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go grpcserver.Start()
|
||||
|
||||
slog.Info("Starting the API server...")
|
||||
api.Start()
|
||||
}
|
||||
@@ -0,0 +1,239 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.11
|
||||
// protoc v3.21.12
|
||||
// source: proto/message.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type SendMessageRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ChannelId string `protobuf:"bytes,1,opt,name=channel_id,json=channelId,proto3" json:"channel_id,omitempty"`
|
||||
Content string `protobuf:"bytes,2,opt,name=content,proto3" json:"content,omitempty"`
|
||||
AudioFilename string `protobuf:"bytes,3,opt,name=audio_filename,json=audioFilename,proto3" json:"audio_filename,omitempty"`
|
||||
Audio []byte `protobuf:"bytes,4,opt,name=audio,proto3" json:"audio,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *SendMessageRequest) Reset() {
|
||||
*x = SendMessageRequest{}
|
||||
mi := &file_proto_message_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *SendMessageRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SendMessageRequest) ProtoMessage() {}
|
||||
|
||||
func (x *SendMessageRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_proto_message_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SendMessageRequest.ProtoReflect.Descriptor instead.
|
||||
func (*SendMessageRequest) Descriptor() ([]byte, []int) {
|
||||
return file_proto_message_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *SendMessageRequest) GetChannelId() string {
|
||||
if x != nil {
|
||||
return x.ChannelId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *SendMessageRequest) GetContent() string {
|
||||
if x != nil {
|
||||
return x.Content
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *SendMessageRequest) GetAudioFilename() string {
|
||||
if x != nil {
|
||||
return x.AudioFilename
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *SendMessageRequest) GetAudio() []byte {
|
||||
if x != nil {
|
||||
return x.Audio
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SendMessageResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
|
||||
ChannelId string `protobuf:"bytes,2,opt,name=channel_id,json=channelId,proto3" json:"channel_id,omitempty"`
|
||||
Created string `protobuf:"bytes,3,opt,name=created,proto3" json:"created,omitempty"`
|
||||
Content string `protobuf:"bytes,4,opt,name=content,proto3" json:"content,omitempty"`
|
||||
AudioId string `protobuf:"bytes,5,opt,name=audio_id,json=audioId,proto3" json:"audio_id,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *SendMessageResponse) Reset() {
|
||||
*x = SendMessageResponse{}
|
||||
mi := &file_proto_message_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *SendMessageResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SendMessageResponse) ProtoMessage() {}
|
||||
|
||||
func (x *SendMessageResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_proto_message_proto_msgTypes[1]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SendMessageResponse.ProtoReflect.Descriptor instead.
|
||||
func (*SendMessageResponse) Descriptor() ([]byte, []int) {
|
||||
return file_proto_message_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *SendMessageResponse) GetId() string {
|
||||
if x != nil {
|
||||
return x.Id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *SendMessageResponse) GetChannelId() string {
|
||||
if x != nil {
|
||||
return x.ChannelId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *SendMessageResponse) GetCreated() string {
|
||||
if x != nil {
|
||||
return x.Created
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *SendMessageResponse) GetContent() string {
|
||||
if x != nil {
|
||||
return x.Content
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *SendMessageResponse) GetAudioId() string {
|
||||
if x != nil {
|
||||
return x.AudioId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_proto_message_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_proto_message_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x13proto/message.proto\x12\n" +
|
||||
"scannerbot\"\x8a\x01\n" +
|
||||
"\x12SendMessageRequest\x12\x1d\n" +
|
||||
"\n" +
|
||||
"channel_id\x18\x01 \x01(\tR\tchannelId\x12\x18\n" +
|
||||
"\acontent\x18\x02 \x01(\tR\acontent\x12%\n" +
|
||||
"\x0eaudio_filename\x18\x03 \x01(\tR\raudioFilename\x12\x14\n" +
|
||||
"\x05audio\x18\x04 \x01(\fR\x05audio\"\x93\x01\n" +
|
||||
"\x13SendMessageResponse\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n" +
|
||||
"\n" +
|
||||
"channel_id\x18\x02 \x01(\tR\tchannelId\x12\x18\n" +
|
||||
"\acreated\x18\x03 \x01(\tR\acreated\x12\x18\n" +
|
||||
"\acontent\x18\x04 \x01(\tR\acontent\x12\x19\n" +
|
||||
"\baudio_id\x18\x05 \x01(\tR\aaudioId2`\n" +
|
||||
"\x0eMessageService\x12N\n" +
|
||||
"\vSendMessage\x12\x1e.scannerbot.SendMessageRequest\x1a\x1f.scannerbot.SendMessageResponseB1Z/git.dubyatp.xyz/dubyatp/scannerbot/server/protob\x06proto3"
|
||||
|
||||
var (
|
||||
file_proto_message_proto_rawDescOnce sync.Once
|
||||
file_proto_message_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_proto_message_proto_rawDescGZIP() []byte {
|
||||
file_proto_message_proto_rawDescOnce.Do(func() {
|
||||
file_proto_message_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proto_message_proto_rawDesc), len(file_proto_message_proto_rawDesc)))
|
||||
})
|
||||
return file_proto_message_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_proto_message_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_proto_message_proto_goTypes = []any{
|
||||
(*SendMessageRequest)(nil), // 0: scannerbot.SendMessageRequest
|
||||
(*SendMessageResponse)(nil), // 1: scannerbot.SendMessageResponse
|
||||
}
|
||||
var file_proto_message_proto_depIdxs = []int32{
|
||||
0, // 0: scannerbot.MessageService.SendMessage:input_type -> scannerbot.SendMessageRequest
|
||||
1, // 1: scannerbot.MessageService.SendMessage:output_type -> scannerbot.SendMessageResponse
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_proto_message_proto_init() }
|
||||
func file_proto_message_proto_init() {
|
||||
if File_proto_message_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_proto_message_proto_rawDesc), len(file_proto_message_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_proto_message_proto_goTypes,
|
||||
DependencyIndexes: file_proto_message_proto_depIdxs,
|
||||
MessageInfos: file_proto_message_proto_msgTypes,
|
||||
}.Build()
|
||||
File_proto_message_proto = out.File
|
||||
file_proto_message_proto_goTypes = nil
|
||||
file_proto_message_proto_depIdxs = nil
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package scannerbot;
|
||||
|
||||
option go_package = "git.dubyatp.xyz/dubyatp/scannerbot/server/proto";
|
||||
|
||||
service MessageService {
|
||||
rpc SendMessage(SendMessageRequest) returns (SendMessageResponse);
|
||||
}
|
||||
|
||||
message SendMessageRequest {
|
||||
string channel_id = 1;
|
||||
string content = 2;
|
||||
string audio_filename = 3;
|
||||
bytes audio = 4;
|
||||
}
|
||||
|
||||
message SendMessageResponse {
|
||||
string id = 1;
|
||||
string channel_id = 2;
|
||||
string created = 3;
|
||||
string content = 4;
|
||||
string audio_id = 5;
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.2
|
||||
// - protoc v3.21.12
|
||||
// source: proto/message.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.64.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
MessageService_SendMessage_FullMethodName = "/scannerbot.MessageService/SendMessage"
|
||||
)
|
||||
|
||||
// MessageServiceClient is the client API for MessageService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type MessageServiceClient interface {
|
||||
SendMessage(ctx context.Context, in *SendMessageRequest, opts ...grpc.CallOption) (*SendMessageResponse, error)
|
||||
}
|
||||
|
||||
type messageServiceClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewMessageServiceClient(cc grpc.ClientConnInterface) MessageServiceClient {
|
||||
return &messageServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *messageServiceClient) SendMessage(ctx context.Context, in *SendMessageRequest, opts ...grpc.CallOption) (*SendMessageResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(SendMessageResponse)
|
||||
err := c.cc.Invoke(ctx, MessageService_SendMessage_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// MessageServiceServer is the server API for MessageService service.
|
||||
// All implementations must embed UnimplementedMessageServiceServer
|
||||
// for forward compatibility.
|
||||
type MessageServiceServer interface {
|
||||
SendMessage(context.Context, *SendMessageRequest) (*SendMessageResponse, error)
|
||||
mustEmbedUnimplementedMessageServiceServer()
|
||||
}
|
||||
|
||||
// UnimplementedMessageServiceServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedMessageServiceServer struct{}
|
||||
|
||||
func (UnimplementedMessageServiceServer) SendMessage(context.Context, *SendMessageRequest) (*SendMessageResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method SendMessage not implemented")
|
||||
}
|
||||
func (UnimplementedMessageServiceServer) mustEmbedUnimplementedMessageServiceServer() {}
|
||||
func (UnimplementedMessageServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeMessageServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to MessageServiceServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeMessageServiceServer interface {
|
||||
mustEmbedUnimplementedMessageServiceServer()
|
||||
}
|
||||
|
||||
func RegisterMessageServiceServer(s grpc.ServiceRegistrar, srv MessageServiceServer) {
|
||||
// If the following call panics, it indicates UnimplementedMessageServiceServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&MessageService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _MessageService_SendMessage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SendMessageRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(MessageServiceServer).SendMessage(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: MessageService_SendMessage_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(MessageServiceServer).SendMessage(ctx, req.(*SendMessageRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// MessageService_ServiceDesc is the grpc.ServiceDesc for MessageService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var MessageService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "scannerbot.MessageService",
|
||||
HandlerType: (*MessageServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "SendMessage",
|
||||
Handler: _MessageService_SendMessage_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "proto/message.proto",
|
||||
}
|
||||
Reference in New Issue
Block a user