diff --git a/clap.zig b/clap.zig index 6054b24..8040852 100644 --- a/clap.zig +++ b/clap.zig @@ -13,6 +13,7 @@ const testing = std.testing; pub const args = @import("clap/args.zig"); pub const parsers = @import("clap/parsers.zig"); pub const streaming = @import("clap/streaming.zig"); +pub const ccw = @import("clap/codepoint_counting_writer.zig"); test "clap" { testing.refAllDecls(@This()); @@ -766,7 +767,7 @@ pub fn parseEx( // fields to slices and return that. var result_args = Arguments(Id, params, value_parsers, .slice){}; inline for (meta.fields(@TypeOf(arguments))) |field| { - if (@typeInfo(field.type) == .Struct and + if (@typeInfo(field.type) == .@"struct" and @hasDecl(field.type, "toOwnedSlice")) { const slice = try @field(arguments, field.name).toOwnedSlice(allocator); @@ -883,7 +884,7 @@ fn deinitArgs( // If the multi value field is a struct, we know it is a list and should be deinited. // Otherwise, it is a slice that should be freed. switch (@typeInfo(@TypeOf(field))) { - .Struct => @field(arguments, longest.name).deinit(allocator), + .@"struct" => @field(arguments, longest.name).deinit(allocator), else => allocator.free(@field(arguments, longest.name)), } } @@ -936,7 +937,7 @@ fn Arguments( i += 1; } - return @Type(.{ .Struct = .{ + return @Type(.{ .@"struct" = .{ .layout = .auto, .fields = &fields, .decls = &.{}, @@ -1153,10 +1154,10 @@ pub fn help( const max_spacing = blk: { var res: usize = 0; for (params) |param| { - var cs = io.countingWriter(io.null_writer); + var cs = ccw.codepointCountingWriter(io.null_writer); try printParam(cs.writer(), Id, param); - if (res < cs.bytes_written) - res = @intCast(cs.bytes_written); + if (res < cs.codepoints_written) + res = @intCast(cs.codepoints_written); } break :blk res; @@ -1166,22 +1167,22 @@ pub fn help( opt.description_indent + max_spacing * @intFromBool(!opt.description_on_new_line); - var first_paramter: bool = true; + var first_parameter: bool = true; for (params) |param| { - if (!first_paramter) + if 
(!first_parameter) try writer.writeByteNTimes('\n', opt.spacing_between_parameters); - first_paramter = false; + first_parameter = false; try writer.writeByteNTimes(' ', opt.indent); - var cw = io.countingWriter(writer); + var cw = ccw.codepointCountingWriter(writer); try printParam(cw.writer(), Id, param); const Writer = DescriptionWriter(@TypeOf(writer)); var description_writer = Writer{ .underlying_writer = writer, .indentation = description_indentation, - .printed_chars = @intCast(cw.bytes_written), + .printed_chars = @intCast(cw.codepoints_written), .max_width = opt.max_width, }; @@ -1260,8 +1261,7 @@ pub fn help( } else { // For none markdown like format, we just respect the newlines in the input // string and output them as is. - var i: usize = 0; - while (i < non_emitted_newlines) : (i += 1) + for (0..non_emitted_newlines) |_| try description_writer.newline(); } @@ -1292,7 +1292,7 @@ fn DescriptionWriter(comptime UnderlyingWriter: type) type { debug.assert(word.len != 0); var first_word = writer.printed_chars <= writer.indentation; - const chars_to_write = word.len + @intFromBool(!first_word); + const chars_to_write = try std.unicode.utf8CountCodepoints(word) + @intFromBool(!first_word); if (chars_to_write + writer.printed_chars > writer.max_width) { // If the word does not fit on this line, then we insert a new line and print // it on that line. The only exception to this is if this was the first word. @@ -1744,6 +1744,50 @@ test "clap.help" { \\-d, --dd ... Both repeated option. \\ ); + + // Test with multibyte characters. + try testHelp(.{ + .indent = 0, + .max_width = 46, + .description_on_new_line = false, + .description_indent = 4, + .spacing_between_parameters = 2, + }, + \\-a Shört flåg. + \\ + \\ + \\-b Shört öptiön. + \\ + \\ + \\ --aa Löng fläg. + \\ + \\ + \\ --bb Löng öptiön. + \\ + \\ + \\-c, --cc Bóth fläg. + \\ + \\ + \\ --complicate Fläg wíth ä cömplǐcätéd + \\ änd vërý löng dèscrıptıön + \\ thät späns mültíplë + \\ lınēs. 
+ \\ + \\ Pärägräph number 2: + \\ * Bullet pöint + \\ * Bullet pöint + \\ + \\ Exämple: + \\ sömething sömething + \\ sömething + \\ + \\ + \\-d, --dd Böth öptiön. + \\ + \\ + \\-d, --dd ... Böth repeäted öptiön. + \\ + ); } /// Will print a usage message in the following format: @@ -1752,18 +1796,18 @@ test "clap.help" { /// First all none value taking parameters, which have a short name are printed, then non /// positional parameters and finally the positional. pub fn usage(stream: anytype, comptime Id: type, params: []const Param(Id)) !void { - var cos = io.countingWriter(stream); + var cos = ccw.codepointCountingWriter(stream); const cs = cos.writer(); for (params) |param| { const name = param.names.short orelse continue; if (param.takes_value != .none) continue; - if (cos.bytes_written == 0) + if (cos.codepoints_written == 0) try stream.writeAll("[-"); try cs.writeByte(name); } - if (cos.bytes_written != 0) + if (cos.codepoints_written != 0) try cs.writeAll("]"); var has_positionals: bool = false; @@ -1782,7 +1826,7 @@ pub fn usage(stream: anytype, comptime Id: type, params: []const Param(Id)) !voi continue; }; - if (cos.bytes_written != 0) + if (cos.codepoints_written != 0) try cs.writeAll(" "); try cs.writeAll("["); @@ -1806,7 +1850,7 @@ pub fn usage(stream: anytype, comptime Id: type, params: []const Param(Id)) !voi if (param.names.short != null or param.names.long != null) continue; - if (cos.bytes_written != 0) + if (cos.codepoints_written != 0) try cs.writeAll(" "); try cs.writeAll("<"); diff --git a/clap/codepoint_counting_writer.zig b/clap/codepoint_counting_writer.zig new file mode 100644 index 0000000..e6b9d1c --- /dev/null +++ b/clap/codepoint_counting_writer.zig @@ -0,0 +1,102 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); + +/// A Writer that counts how many codepoints have been written to it. +/// Expects valid UTF-8 input; only the start byte of each sequence is checked, the input is not otherwise validated.
+pub fn CodepointCountingWriter(comptime WriterType: type) type { + return struct { + codepoints_written: u64, + child_stream: WriterType, + + pub const Error = WriterType.Error || error{Utf8InvalidStartByte}; + pub const Writer = std.io.Writer(*Self, Error, write); + + const Self = @This(); + + pub fn write(self: *Self, bytes: []const u8) Error!usize { + const bytes_and_codepoints = try utf8CountCodepointsAllowTruncate(bytes); + // Might not be the full input, so the leftover bytes are written on the next call. + const bytes_to_write = bytes[0..bytes_and_codepoints.bytes]; + const amt = try self.child_stream.write(bytes_to_write); + const bytes_written = bytes_to_write[0..amt]; + self.codepoints_written += (try utf8CountCodepointsAllowTruncate(bytes_written)).codepoints; + return amt; + } + + pub fn writer(self: *Self) Writer { + return .{ .context = self }; + } + }; +} + +// Like `std.unicode.utf8CountCodepoints`, but on truncated input, it returns +// the number of codepoints up to that point. +// Does not validate UTF-8 beyond checking the start byte. +fn utf8CountCodepointsAllowTruncate(s: []const u8) !struct { bytes: usize, codepoints: usize } { + var len: usize = 0; + + const N = @sizeOf(usize); + const MASK = 0x80 * (std.math.maxInt(usize) / 0xff); + + var i: usize = 0; + while (i < s.len) { + // Fast path for ASCII sequences + while (i + N <= s.len) : (i += N) { + const v = std.mem.readInt(usize, s[i..][0..N], native_endian); + if (v & MASK != 0) break; + len += N; + } + + if (i < s.len) { + const n = try std.unicode.utf8ByteSequenceLength(s[i]); + // Truncated input; return the current counts. 
+ if (i + n > s.len) return .{ .bytes = i, .codepoints = len }; + + i += n; + len += 1; + } + } + + return .{ .bytes = i, .codepoints = len }; +} + +pub fn codepointCountingWriter(child_stream: anytype) CodepointCountingWriter(@TypeOf(child_stream)) { + return .{ .codepoints_written = 0, .child_stream = child_stream }; +} + +const testing = std.testing; + +test CodepointCountingWriter { + var counting_stream = codepointCountingWriter(std.io.null_writer); + const stream = counting_stream.writer(); + + const utf8_text = "blåhaj" ** 100; + stream.writeAll(utf8_text) catch unreachable; + const expected_count = try std.unicode.utf8CountCodepoints(utf8_text); + try testing.expectEqual(expected_count, counting_stream.codepoints_written); +} + +test "handles partial UTF-8 writes" { + var buf: [100]u8 = undefined; + var fbs = std.io.fixedBufferStream(&buf); + var counting_stream = codepointCountingWriter(fbs.writer()); + const stream = counting_stream.writer(); + + const utf8_text = "ååå"; + // `å` is represented as `\xC3\xA5`, write 1.5 `å`s. + var wc = try stream.write(utf8_text[0..3]); + // One should have been written fully. + try testing.expectEqual("å".len, wc); + try testing.expectEqual(1, counting_stream.codepoints_written); + + // Write the rest, continuing from the reported number of bytes written.
+ wc = try stream.write(utf8_text[wc..]); + try testing.expectEqual(4, wc); + try testing.expectEqual(3, counting_stream.codepoints_written); + + const expected_count = try std.unicode.utf8CountCodepoints(utf8_text); + try testing.expectEqual(expected_count, counting_stream.codepoints_written); + + try testing.expectEqualSlices(u8, utf8_text, fbs.getWritten()); +} diff --git a/clap/parsers.zig b/clap/parsers.zig index 874c23b..8abdf57 100644 --- a/clap/parsers.zig +++ b/clap/parsers.zig @@ -92,9 +92,9 @@ test "enumeration" { } fn ReturnType(comptime P: type) type { - return @typeInfo(P).Fn.return_type.?; + return @typeInfo(P).@"fn".return_type.?; } pub fn Result(comptime P: type) type { - return @typeInfo(ReturnType(P)).ErrorUnion.payload; + return @typeInfo(ReturnType(P)).error_union.payload; }