From 2085d14fdebac762f0e3b16e6ca06963eea99d35 Mon Sep 17 00:00:00 2001 From: Kamil Tomsik Date: Sat, 20 Apr 2024 09:21:43 +0200 Subject: [PATCH] WIP: server-side templating, quick-tools api --- src/api/quick-tools.zig | 26 ++++ src/template.zig | 266 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 src/template.zig diff --git a/src/api/quick-tools.zig b/src/api/quick-tools.zig index 65b3558..1c2c2fb 100644 --- a/src/api/quick-tools.zig +++ b/src/api/quick-tools.zig @@ -1,6 +1,17 @@ const std = @import("std"); +const tk = @import("tokamak"); const fr = @import("fridge"); const schema = @import("../schema.zig"); +const llama = @import("../llama.zig"); +const template = @import("../template.zig"); +const generate = @import("generate.zig"); + +const GenerateParams = struct { + model: []const u8, + data: std.json.Value, + max_tokens: u32 = 2048, + sampling: llama.SamplingParams = .{}, +}; pub fn @"GET /quick-tools"(db: *fr.Session) ![]const schema.QuickTool { return db.findAll(fr.query(schema.QuickTool).orderBy(.id, .asc)); @@ -20,6 +31,21 @@ pub fn @"PUT /quick-tools/:id"(db: *fr.Session, id: u32, data: schema.QuickTool) return try db.find(schema.QuickTool, id) orelse error.NotFound; } +pub fn @"POST /quick-tools/:id/generate"(ctx: *tk.Context, db: *fr.Session, id: u32, params: GenerateParams) !void { + const tool = try db.find(schema.QuickTool, id) orelse return error.NotFound; + const tpl = try template.Template.parse(ctx.allocator, tool.prompt); + const prompt = try tpl.renderAlloc(ctx.allocator, params.data); + + std.log.debug("Generating with prompt: {s}", .{prompt}); + + return ctx.injector.call(generate.@"POST /generate", .{.{ + .model = params.model, + .prompt = prompt, + .max_tokens = params.max_tokens, + .sampling = params.sampling, + }}); +} + pub fn @"DELETE /quick-tools/:id"(db: *fr.Session, id: u32) !void { try db.delete(schema.QuickTool, id); } diff --git a/src/template.zig b/src/template.zig new file mode 100644 index 0000000..f8ff2bf --- /dev/null +++ b/src/template.zig @@ -0,0 +1,266 @@ +// TODO: consider switching to jinja/swig/xxx syntax because we could use this also for chat templates + +const std = @import("std"); + +pub const Template = struct { + tokens: std.ArrayList(Token), + + pub fn parse(allocator: std.mem.Allocator, input: []const u8) !Template { + var tokens = std.ArrayList(Token).init(allocator); + errdefer tokens.deinit(); + + var stack = std.ArrayList(struct { usize, []const u8 }).init(allocator); + defer stack.deinit(); + + var tokenizer = Tokenizer{ .input = input }; + + while (tokenizer.next()) |tok| { + switch (tok) { + .section_open => |section| try stack.append(.{ tokens.items.len, section.name }), + + .section_close => |name| { + const index, const open = stack.popOrNull() orelse return error.MissingSectionOpen; + + if (!std.mem.eql(u8, open, name)) { + return error.MismatchedSection; + } + + tokens.items[index].section_open.skip = tokens.items.len - index; + }, + + else => {}, + } + + try tokens.append(tok); + } + + return .{ + .tokens = tokens, + }; + } + + pub fn deinit(self: *Template) void { + self.tokens.deinit(); + } + + // TODO: accept anytype, but expect json.Value union(enum) shape + pub fn render(self: *const Template, data: std.json.Value, writer: anytype) !void { + try renderPart(self.tokens.items, data, writer); + } + + pub fn renderAlloc(self: *const Template, allocator: std.mem.Allocator, data: std.json.Value) ![]const u8 { + var buf = std.ArrayList(u8).init(allocator); + errdefer buf.deinit(); + + try self.render(data, buf.writer()); + + return buf.toOwnedSlice(); + } + + fn renderPart(tokens: []const Token, data: std.json.Value, writer: anytype) !void { + var i: usize = 0; + + while (i < tokens.len) : (i += 1) { + switch (tokens[i]) { + .text => |text| try writer.writeAll(text), + + .variable => |name| { + const val = if (name.len == 1 and name[0] == '.') data else (data.object.get(name) orelse continue); + + switch (val) { + .null => {}, + .bool => |v| try writer.writeAll(if (v) "true" else "false"), + .string, .number_string => |v| try writer.writeAll(v), + inline else => |v| try writer.print("{}", .{v}), + } + }, + + .section_open => |sec| { + defer i += sec.skip; + const next = tokens[i + 1 ..]; + const val = data.object.get(sec.name) orelse .null; + + const truthy = switch (val) { + .null => false, + .bool => |v| v, + .integer => |v| v != 0, + .float => |v| v != 0.0, + .string, .number_string => |v| v.len > 0, + .array => |arr| arr.items.len > 0, + .object => |_| true, + }; + + if (sec.inverted == truthy) { + continue; + } + + if (sec.inverted) { + try renderPart(next, data, writer); + continue; + } + + switch (val) { + .object => try renderPart(next, val, writer), + .array => |arr| for (arr.items) |it| try renderPart(next, it, writer), + else => try renderPart(next, data, writer), + } + }, + + .section_close => |_| return, + } + } + } +}; + +fn expectRender(tpl: Template, data: anytype, expected: []const u8) !void { + const json = try std.json.stringifyAlloc(std.testing.allocator, data, .{}); + defer std.testing.allocator.free(json); + + const parsed = try std.json.parseFromSlice(std.json.Value, std.testing.allocator, json, .{}); + defer parsed.deinit(); + + const res = try tpl.renderAlloc(std.testing.allocator, parsed.value); + defer std.testing.allocator.free(res); + + try std.testing.expectEqualStrings(expected, res); +} + +test "Template" { + var tpl = try Template.parse(std.testing.allocator, "Hello {{#name}}{{name}}{{/name}}{{^name}}World{{/name}}"); + defer tpl.deinit(); + + try expectRender(tpl, .{ .name = "Alice" }, "Hello Alice"); + try expectRender(tpl, .{ .name = null }, "Hello World"); + try expectRender(tpl, .{ .name = "" }, "Hello World"); + try expectRender(tpl, .{ .name = [_]u32{} }, "Hello World"); + try expectRender(tpl, .{ .name = struct {}{} }, "Hello "); + + var tpl2 = try Template.parse(std.testing.allocator, "{{#names}}- {{.}}\n{{/names}}{{^names}}No names{{/names}}"); + defer tpl2.deinit(); + + try expectRender(tpl2, .{ .names = .{} }, "No names"); + try expectRender(tpl2, .{ .names = .{ "first", "second" } }, "- first\n- second\n"); +} + +/// A token in a mustache template +const Token = union(enum) { + text: []const u8, + variable: []const u8, + section_open: struct { + name: []const u8, + inverted: bool = false, + skip: usize = 0, + }, + section_close: []const u8, +}; + +const Tag = std.meta.Tag(Token); + +const Tokenizer = struct { + input: []const u8, + pos: usize = 0, + + fn next(self: *Tokenizer) ?Token { + const start = self.pos; + + while (self.pos < self.input.len) : (self.pos += 1) { + if (self.startsWith("{{")) { + if (self.pos > start) { + return .{ .text = self.input[start..self.pos] }; + } else { + self.pos += 2; + } + + if (self.consumeIdent()) |name| { + if (self.consumeSeq("}}")) { + return .{ .variable = name }; + } + } + + if (self.consumeSeq("#")) { + if (self.consumeIdent()) |name| { + if (self.consumeSeq("}}")) { + return .{ .section_open = .{ .name = name } }; + } + } + } + + if (self.consumeSeq("^")) { + if (self.consumeIdent()) |name| { + if (self.consumeSeq("}}")) { + return .{ .section_open = .{ .name = name, .inverted = true } }; + } + } + } + + if (self.consumeSeq("/")) { + if (self.consumeIdent()) |name| { + if (self.consumeSeq("}}")) { + return .{ .section_close = name }; + } + } + } + + return .{ .text = self.input[start..self.pos] }; + } + } + + if (start < self.pos) { + return .{ .text = self.input[start..self.pos] }; + } + + return null; + } + + fn startsWith(self: *Tokenizer, seq: []const u8) bool { + return std.mem.startsWith(u8, self.input[self.pos..], seq); + } + + fn consumeSeq(self: *Tokenizer, seq: []const u8) bool { + if (self.startsWith(seq)) { + self.pos += seq.len; + return true; + } + + return false; + } + + fn consumeIdent(self: *Tokenizer) ?[]const u8 { + const start = self.pos; + + while (self.pos < self.input.len) { + switch (self.input[self.pos]) { + 'a'...'z', 'A'...'Z', '0'...'9', '_', '.' => self.pos += 1, + else => break, + } + } + + return if (self.pos > start) self.input[start..self.pos] else null; + } +}; + +fn expectTokens(tpl: []const u8, tokens: []const Tag) !void { + var tokenizer = Tokenizer{ .input = tpl }; + + for (tokens) |tag| { + const tok: Tag = tokenizer.next() orelse return error.Eof; + errdefer std.debug.print("rest: {s}\n", .{tokenizer.input[tokenizer.pos..]}); + + try std.testing.expectEqual(tag, tok); + } + + try std.testing.expectEqual(tpl.len, tokenizer.pos); +} + +test "Tokenizer" { + try expectTokens("", &.{}); + + try expectTokens("Hello", &.{.text}); + try expectTokens("{{name}}", &.{.variable}); + try expectTokens("{{#name}}", &.{.section_open}); + try expectTokens("{{^name}}", &.{.section_open}); + try expectTokens("{{/name}}", &.{.section_close}); + + try expectTokens("Hello {{name}}", &.{ .text, .variable }); + try expectTokens("{{#name}}{{/name}}", &.{ .section_open, .section_close }); +}