Skip to content

Commit

Permalink
allow decoding tf strings
Browse files Browse the repository at this point in the history
  • Loading branch information
CaptainDario committed Feb 11, 2024
1 parent c3a2f05 commit 06d2fdc
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions lib/src/util/byte_conversion_utils.dart
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

import 'dart:convert';
import 'dart:ffi';
import 'dart:typed_data';

import 'package:tflite_flutter/tflite_flutter.dart';
Expand Down Expand Up @@ -165,6 +167,37 @@ class ByteConversionUtils {
);
}

/// Decodes a TensorFlow string to a List<String>
static List<String> decodeTFStrings(Uint8List bytes){
/// The decoded string
List<String> decodedStrings = [];
/// get the first 32bit int representing num of strings
int numStrings = ByteData.view(bytes.sublist(0,sizeOf<Int32>()).buffer).getInt32(0, Endian.little);

/// parse subsequent string position and sizes
for(int s = 0; s < numStrings; s++){

// get current str index
int startIdx = ByteData.view(
bytes.sublist(
(1+s)*sizeOf<Int32>(),
(2+s)*sizeOf<Int32>()
)
.buffer).getInt32(0, Endian.little);
// get next str index, or in last case the ending byte position
int endIdx = ByteData.view(
bytes.sublist(
(2+s)*sizeOf<Int32>(),
(3+s)*sizeOf<Int32>()
)
.buffer).getInt32(0, Endian.little);

decodedStrings.add(utf8.decode(bytes.sublist(startIdx,endIdx)));
}

return decodedStrings;
}

static Object convertBytesToObject(
Uint8List bytes, TensorType tensorType, List<int> shape) {
// stores flattened data
Expand Down Expand Up @@ -209,6 +242,9 @@ class ByteConversionUtils {
list.add(ByteData.view(bytes.buffer).getInt64(i));
}
return list.reshape<int>(shape);
} else if (tensorType.value == TfLiteType.kTfLiteString) {
list.add(decodeTFStrings(bytes));
return list;
}
throw UnsupportedError("$tensorType is not Supported.");
}
Expand Down

0 comments on commit 06d2fdc

Please sign in to comment.