Skip to content

Commit

Permalink
extract the server-side mini-framework to its own repo
Browse files Browse the repository at this point in the history
  • Loading branch information
cztomsik committed Feb 3, 2024
1 parent 978007f commit 306cc51
Show file tree
Hide file tree
Showing 17 changed files with 132 additions and 462 deletions.
3 changes: 3 additions & 0 deletions build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ pub fn build(b: *std.Build) !void {
fn buildExe(b: *std.Build, exe: anytype) !void {
exe.addIncludePath(.{ .path = "llama.cpp" });

const tokamak = b.dependency("tokamak", .{});
exe.root_module.addImport("tokamak", tokamak.module("tokamak"));

const sqlite = b.dependency("ava-sqlite", .{ .bundle = exe.rootModuleTarget().os.tag != .macos });
exe.root_module.addImport("ava-sqlite", sqlite.module("ava-sqlite"));
if (@hasField(@TypeOf(exe.*), "sdk")) sqlite.module("ava-sqlite").addSystemIncludePath(.{ .path = b.fmt("{s}/usr/include", .{exe.sdk}) });
Expand Down
5 changes: 5 additions & 0 deletions build.zig.zon
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
.url = "https://github.com/cztomsik/ava-sqlite/archive/c48ac06.tar.gz",
.hash = "122049ceedb98046f5d3362145b5f67a93d52d33a7a5357885795971be2f5bffcd45",
},

.tokamak = .{
.url = "https://github.com/cztomsik/tokamak/archive/b173be5.tar.gz",
.hash = "1220eeb2ce65750c676b5a4a7c577c892de25c3aaa57daff777f97d1a1228625e3bb",
},
},
.paths = .{
"",
Expand Down
57 changes: 21 additions & 36 deletions src/api/chat.zig
Original file line number Diff line number Diff line change
@@ -1,83 +1,68 @@
const db = @import("../db.zig");
const server = @import("../server.zig");
const tk = @import("tokamak");

pub fn @"GET /chat"(ctx: *server.Context) !void {
pub fn @"GET /chat"(r: *tk.Responder) !void {
var stmt = try db.query(
\\SELECT id, name,
\\(SELECT content FROM ChatMessage WHERE chat_id = Chat.id ORDER BY id DESC LIMIT 1) as last_message
\\FROM Chat ORDER BY id DESC
, .{});
defer stmt.deinit();

return ctx.sendJson(stmt.iterator(struct { id: u32, name: []const u8, last_message: ?[]const u8 }));
return r.sendJson(stmt.iterator(struct { id: u32, name: []const u8, last_message: ?[]const u8 }));
}

pub fn @"POST /chat"(ctx: *server.Context) !void {
const data = try ctx.readJson(struct {
name: []const u8,
prompt: ?[]const u8,
});

pub fn @"POST /chat"(r: *tk.Responder, data: db.Chat) !void {
var stmt = try db.query("INSERT INTO Chat (name, prompt) VALUES (?, ?) RETURNING *", .{ data.name, data.prompt });
defer stmt.deinit();

try ctx.sendJson(try stmt.read(db.Chat));
try r.sendJson(try stmt.read(db.Chat));
}

pub fn @"GET /chat/:id"(ctx: *server.Context, id: u32) !void {
pub fn @"GET /chat/:id"(r: *tk.Responder, id: u32) !void {
var stmt = try db.query("SELECT * FROM Chat WHERE id = ?", .{id});
defer stmt.deinit();

return ctx.sendJson(try stmt.read(db.Chat));
return r.sendJson(try stmt.read(db.Chat));
}

pub fn @"PUT /chat/:id"(ctx: *server.Context, id: u32, data: db.Chat) !void {
pub fn @"PUT /chat/:id"(r: *tk.Responder, id: u32, data: db.Chat) !void {
try db.exec("UPDATE Chat SET name = ?, prompt = ? WHERE id = ?", .{ data.name, data.prompt, id });
return ctx.noContent();
return r.noContent();
}

pub fn @"GET /chat/:id/messages"(ctx: *server.Context, id: u32) !void {
pub fn @"GET /chat/:id/messages"(r: *tk.Responder, id: u32) !void {
var stmt = try db.query("SELECT * FROM ChatMessage WHERE chat_id = ? ORDER BY id", .{id});
defer stmt.deinit();

return ctx.sendJson(stmt.iterator(db.ChatMessage));
return r.sendJson(stmt.iterator(db.ChatMessage));
}

pub fn @"POST /chat/:id/messages"(ctx: *server.Context, id: u32) !void {
const data = try ctx.readJson(struct {
role: []const u8,
content: []const u8,
});

pub fn @"POST /chat/:id/messages"(r: *tk.Responder, id: u32, data: db.ChatMessage) !void {
var stmt = try db.query("INSERT INTO ChatMessage (chat_id, role, content) VALUES (?, ?, ?) RETURNING *", .{ id, data.role, data.content });
defer stmt.deinit();

try ctx.sendJson(try stmt.read(db.ChatMessage));
try r.sendJson(try stmt.read(db.ChatMessage));
}

pub fn @"GET /chat/:id/messages/:message_id"(ctx: *server.Context, id: u32, message_id: u32) !void {
pub fn @"GET /chat/:id/messages/:message_id"(r: *tk.Responder, id: u32, message_id: u32) !void {
var stmt = try db.query("SELECT * FROM ChatMessage WHERE id = ? AND chat_id = ?", .{ message_id, id });
defer stmt.deinit();

return ctx.sendJson(try stmt.read(db.ChatMessage));
return r.sendJson(try stmt.read(db.ChatMessage));
}

pub fn @"PUT /chat/:id/messages/:message_id"(ctx: *server.Context, id: u32, message_id: u32) !void {
const data = try ctx.readJson(struct {
role: []const u8,
content: []const u8,
});

pub fn @"PUT /chat/:id/messages/:message_id"(r: *tk.Responder, id: u32, message_id: u32, data: db.ChatMessage) !void {
try db.exec("UPDATE ChatMessage SET role = ?, content = ? WHERE id = ? AND chat_id = ?", .{ data.role, data.content, message_id, id });
return ctx.noContent();
return r.noContent();
}

pub fn @"DELETE /chat/:id/messages/:message_id"(ctx: *server.Context, id: u32, message_id: u32) !void {
pub fn @"DELETE /chat/:id/messages/:message_id"(r: *tk.Responder, id: u32, message_id: u32) !void {
try db.exec("DELETE FROM ChatMessage WHERE id = ? AND chat_id = ?", .{ message_id, id });
return ctx.noContent();
return r.noContent();
}

pub fn @"DELETE /chat/:id"(ctx: *server.Context, id: u32) !void {
pub fn @"DELETE /chat/:id"(r: *tk.Responder, id: u32) !void {
try db.exec("DELETE FROM Chat WHERE id = ?", .{id});
return ctx.noContent();
return r.noContent();
}
31 changes: 15 additions & 16 deletions src/api/download.zig
Original file line number Diff line number Diff line change
@@ -1,41 +1,40 @@
const builtin = @import("builtin");
const std = @import("std");
const server = @import("../server.zig");
const tk = @import("tokamak");
const util = @import("../util.zig");

pub fn @"POST /download"(ctx: *server.Context) !void {
const url = try ctx.readJson([]const u8);
pub fn @"POST /download"(allocator: std.mem.Allocator, sreq: *tk.Request, r: *tk.Responder, params: struct { url: []const u8 }) !void {
inline for (.{ "Content-Type", "Content-Length", "Host", "Referer", "Origin" }) |h| {
_ = ctx.res.request.headers.delete(h);
_ = sreq.headers.delete(h);
}

var client: std.http.Client = .{ .allocator = ctx.arena };
var client: std.http.Client = .{ .allocator = allocator };
defer client.deinit();

if (builtin.target.os.tag == .windows) {
try client.ca_bundle.rescan(ctx.arena);
try client.ca_bundle.rescan(allocator);
const start = client.ca_bundle.bytes.items.len;
try client.ca_bundle.bytes.appendSlice(ctx.arena, @embedFile("../windows/amazon1.cer"));
try client.ca_bundle.parseCert(ctx.arena, @intCast(start), std.time.timestamp());
try client.ca_bundle.bytes.appendSlice(allocator, @embedFile("../windows/amazon1.cer"));
try client.ca_bundle.parseCert(allocator, @intCast(start), std.time.timestamp());
}

var req = try client.open(.GET, try std.Uri.parse(url), ctx.res.request.headers, .{});
var req = try client.open(.GET, try std.Uri.parse(params.url), sreq.headers, .{});
defer req.deinit();

try req.send(.{});
try req.wait();

if (req.response.status != .ok) {
return ctx.sendJson(.{ .@"error" = try std.fmt.allocPrint(ctx.arena, "Invalid status code: `{d}`", .{req.response.status}) });
return r.sendJson(.{ .@"error" = try std.fmt.allocPrint(allocator, "Invalid status code: `{d}`", .{req.response.status}) });
}

const content_type = req.response.headers.getFirstValue("Content-Type") orelse "";
if (!std.mem.eql(u8, content_type, "binary/octet-stream")) {
return ctx.sendJson(.{ .@"error" = try std.fmt.allocPrint(ctx.arena, "Invalid content type: `{s}`", .{content_type}) });
return r.sendJson(.{ .@"error" = try std.fmt.allocPrint(allocator, "Invalid content type: `{s}`", .{content_type}) });
}

const path = try util.getWritableHomePath(ctx.arena, &.{ "models", std.fs.path.basename(url) });
const tmp_path = try std.fmt.allocPrint(ctx.arena, "{s}.part", .{path});
const path = try util.getWritableHomePath(allocator, &.{ "models", std.fs.path.basename(params.url) });
const tmp_path = try std.fmt.allocPrint(allocator, "{s}.part", .{path});
var file = try std.fs.createFileAbsolute(tmp_path, .{});
defer file.close();
errdefer std.fs.deleteFileAbsolute(tmp_path) catch {};
Expand All @@ -51,9 +50,9 @@ pub fn @"POST /download"(ctx: *server.Context) !void {
if (n < buf.len) break;

progress += n;
try ctx.sendJson(.{ .progress = progress });
} else |_| return ctx.sendJson(.{ .@"error" = "Failed to download the model" });
try r.sendJson(.{ .progress = progress });
} else |_| return r.sendJson(.{ .@"error" = "Failed to download the model" });

try std.fs.renameAbsolute(tmp_path, path);
try ctx.sendJson(.{ .path = path });
try r.sendJson(.{ .path = path });
}
16 changes: 7 additions & 9 deletions src/api/find-models.zig
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
const std = @import("std");
const server = @import("../server.zig");
const tk = @import("tokamak");

pub fn @"POST /find-models"(ctx: *server.Context) !void {
var models_found = std.ArrayList(struct { path: []const u8, size: ?u64 }).init(ctx.arena);
pub fn @"POST /find-models"(allocator: std.mem.Allocator, r: *tk.Responder, params: struct { path: []const u8 }) !void {
var models_found = std.ArrayList(struct { path: []const u8, size: ?u64 }).init(allocator);

const path = try ctx.readJson([]const u8);

var dir = try std.fs.openDirAbsolute(path, .{ .iterate = true });
var dir = try std.fs.openDirAbsolute(params.path, .{ .iterate = true });
defer dir.close();

var walker = try dir.walk(ctx.arena);
var walker = try dir.walk(allocator);
defer walker.deinit();

while (try walker.next()) |entry| switch (entry.kind) {
Expand All @@ -18,13 +16,13 @@ pub fn @"POST /find-models"(ctx: *server.Context) !void {
defer file.close();

try models_found.append(.{
.path = try std.fs.path.join(ctx.arena, &.{ path, entry.path }),
.path = try std.fs.path.join(allocator, &.{ params.path, entry.path }),
.size = (try file.stat()).size,
});
},
.directory => _ = if (walker.stack.items.len > 3) walker.stack.pop(),
else => {},
};

return ctx.sendJson(models_found.items);
return r.sendJson(models_found.items);
}
16 changes: 8 additions & 8 deletions src/api/generate.zig
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const std = @import("std");
const server = @import("../server.zig");
const tk = @import("tokamak");
const db = @import("../db.zig");
const llama = @import("../llama.zig");

Expand All @@ -11,27 +11,27 @@ const GenerateParams = struct {
sampling: llama.SamplingParams = .{},
};

pub fn @"POST /generate"(ctx: *server.Context, params: GenerateParams) !void {
try ctx.sendJson(.{ .status = "Waiting for the model..." });
const model_path = try db.getString(ctx.arena, "SELECT path FROM Model WHERE id = ?", .{params.model_id});
pub fn @"POST /generate"(allocator: std.mem.Allocator, r: *tk.Responder, params: GenerateParams) !void {
try r.sendJson(.{ .status = "Waiting for the model..." });
const model_path = try db.getString(allocator, "SELECT path FROM Model WHERE id = ?", .{params.model_id});
var cx = try llama.Pool.get(model_path, 60_000);
defer llama.Pool.release(cx);

try ctx.sendJson(.{ .status = "Reading the history..." });
try r.sendJson(.{ .status = "Reading the history..." });
try cx.prepare(params.prompt, &params.sampling);

while (cx.n_past < cx.tokens.items.len) {
try ctx.sendJson(.{ .status = try std.fmt.allocPrint(ctx.arena, "Reading the history... ({}/{})", .{ cx.n_past, cx.tokens.items.len }) });
try r.sendJson(.{ .status = try std.fmt.allocPrint(allocator, "Reading the history... ({}/{})", .{ cx.n_past, cx.tokens.items.len }) });
_ = try cx.evalOnce();
}

// TODO: send enums/unions
try ctx.sendJson(.{ .status = "" });
try r.sendJson(.{ .status = "" });

var tokens: u32 = 0;

while (try cx.generate(&params.sampling)) |content| {
try ctx.sendJson(.{
try r.sendJson(.{
.content = if (tokens == 0 and params.trim_first) std.mem.trimLeft(u8, content, " \t\n\r") else content,
});

Expand Down
6 changes: 2 additions & 4 deletions src/api/log.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
const server = @import("../server.zig");
const tk = @import("tokamak");
const util = @import("../util.zig");

pub fn @"GET /log"(ctx: *server.Context) !void {
try ctx.sendChunk(try util.Logger.dump(ctx.arena));
}
pub const @"GET /log" = util.Logger.dump;
34 changes: 14 additions & 20 deletions src/api/models.zig
Original file line number Diff line number Diff line change
@@ -1,49 +1,43 @@
const std = @import("std");
const db = @import("../db.zig");
const server = @import("../server.zig");
const tk = @import("tokamak");
const util = @import("../util.zig");

pub fn @"GET /models"(ctx: *server.Context) !void {
pub fn @"GET /models"(allocator: std.mem.Allocator, r: *tk.Responder) !void {
var stmt = try db.query("SELECT * FROM Model ORDER BY id", .{});
defer stmt.deinit();

var rows = std.ArrayList(struct { id: u32, name: []const u8, path: []const u8, imported: bool, size: ?u64 }).init(ctx.arena);
var rows = std.ArrayList(struct { id: u32, name: []const u8, path: []const u8, imported: bool, size: ?u64 }).init(allocator);
var it = stmt.iterator(db.Model);
while (try it.next()) |m| {
try rows.append(.{
.id = m.id,
.name = try ctx.arena.dupe(u8, m.name),
.path = try ctx.arena.dupe(u8, m.path),
.id = m.id.?,
.name = try allocator.dupe(u8, m.name),
.path = try allocator.dupe(u8, m.path),
.imported = m.imported,
.size = util.getFileSize(m.path) catch null,
});
}

return ctx.sendJson(rows.items);
return r.sendJson(rows.items);
}

pub fn @"POST /models"(ctx: *server.Context) !void {
const data = try ctx.readJson(struct {
name: []const u8,
path: []const u8,
imported: bool = false,
});

pub fn @"POST /models"(r: *tk.Responder, data: db.Model) !void {
var stmt = try db.query("INSERT INTO Model (name, path, imported) VALUES (?, ?, ?) RETURNING *", .{ data.name, data.path, data.imported });
defer stmt.deinit();

try ctx.sendJson(try stmt.read(db.Model));
try r.sendJson(try stmt.read(db.Model));
}

pub fn @"PUT /models/:id"(ctx: *server.Context, id: u32, data: db.Model) !void {
pub fn @"PUT /models/:id"(r: *tk.Responder, id: u32, data: db.Model) !void {
try db.exec("UPDATE Model SET name = ?, path = ? WHERE id = ?", .{ data.name, data.path, id });
return ctx.noContent();
return r.noContent();
}

pub fn @"DELETE /models/:id"(ctx: *server.Context, id: []const u8) !void {
const path = try db.getString(ctx.arena, "SELECT path FROM Model WHERE id = ?", .{id});
pub fn @"DELETE /models/:id"(allocator: std.mem.Allocator, r: *tk.Responder, id: []const u8) !void {
const path = try db.getString(allocator, "SELECT path FROM Model WHERE id = ?", .{id});
const imported = try db.get(bool, "SELECT imported FROM Model WHERE id = ?", .{id});
try db.exec("DELETE FROM Model WHERE id = ?", .{id});
if (!imported) std.fs.deleteFileAbsolute(path) catch {};
return ctx.noContent();
return r.noContent();
}
19 changes: 7 additions & 12 deletions src/api/prompts.zig
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
const db = @import("../db.zig");
const server = @import("../server.zig");
const tk = @import("tokamak");

pub fn @"GET /prompts"(ctx: *server.Context) !void {
pub fn @"GET /prompts"(r: *tk.Responder) !void {
var stmt = try db.query("SELECT * FROM Prompt ORDER BY id", .{});
defer stmt.deinit();

return ctx.sendJson(stmt.iterator(db.Prompt));
return r.sendJson(stmt.iterator(db.Prompt));
}

pub fn @"POST /prompts"(ctx: *server.Context) !void {
const data = try ctx.readJson(struct {
name: []const u8,
prompt: []const u8,
});

pub fn @"POST /prompts"(r: *tk.Responder, data: db.Prompt) !void {
var stmt = try db.query("INSERT INTO Prompt (name, prompt) VALUES (?, ?) RETURNING *", .{ data.name, data.prompt });
defer stmt.deinit();

try ctx.sendJson(try stmt.read(db.Prompt));
try r.sendJson(try stmt.read(db.Prompt));
}

pub fn @"DELETE /prompts/:id"(ctx: *server.Context, id: u32) !void {
pub fn @"DELETE /prompts/:id"(r: *tk.Responder, id: u32) !void {
try db.exec("DELETE FROM Prompt WHERE id = ?", .{id});
return ctx.noContent();
return r.noContent();
}
Loading

0 comments on commit 306cc51

Please sign in to comment.