Skip to content

Commit

Permalink
Merge pull request #194 from CaptainDario/feature-TF_String_support
Browse files Browse the repository at this point in the history
TensorFlow String support for model output
  • Loading branch information
PaulTR authored Feb 14, 2024
2 parents a91be05 + 7c29260 commit 1472c73
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
4 changes: 2 additions & 2 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ android {
}

defaultConfig {
minSdkVersion 26
minSdkVersion 19
}
}


dependencies {
def tflite_version = "2.12.0"
def tflite_version = "2.11.0"

implementation("org.tensorflow:tensorflow-lite:${tflite_version}")
implementation("org.tensorflow:tensorflow-lite-gpu:${tflite_version}")
Expand Down
33 changes: 33 additions & 0 deletions lib/src/util/byte_conversion_utils.dart
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

import 'dart:convert';
import 'dart:ffi';
import 'dart:typed_data';

import 'package:tflite_flutter/tflite_flutter.dart';
Expand Down Expand Up @@ -150,6 +152,34 @@ class ByteConversionUtils {
);
}

/// Decodes a TensorFlow string to a List<String>
static List<String> decodeTFStrings(Uint8List bytes) {
/// The decoded string
List<String> decodedStrings = [];

/// get the first 32bit int representing num of strings
int numStrings = ByteData.view(bytes.sublist(0, sizeOf<Int32>()).buffer)
.getInt32(0, Endian.little);

/// parse subsequent string position and sizes
for (int s = 0; s < numStrings; s++) {
// get current str index
int startIdx = ByteData.view(bytes
.sublist((1 + s) * sizeOf<Int32>(), (2 + s) * sizeOf<Int32>())
.buffer)
.getInt32(0, Endian.little);
// get next str index, or in last case the ending byte position
int endIdx = ByteData.view(bytes
.sublist((2 + s) * sizeOf<Int32>(), (3 + s) * sizeOf<Int32>())
.buffer)
.getInt32(0, Endian.little);

decodedStrings.add(utf8.decode(bytes.sublist(startIdx, endIdx)));
}

return decodedStrings;
}

static Object convertBytesToObject(
Uint8List bytes, TensorType tensorType, List<int> shape) {
// stores flattened data
Expand Down Expand Up @@ -191,6 +221,9 @@ class ByteConversionUtils {
list.add(ByteData.view(bytes.buffer).getInt64(i));
}
return list.reshape<int>(shape);
} else if (tensorType.value == TfLiteType.kTfLiteString) {
list.add(decodeTFStrings(bytes));
return list;
}
throw UnsupportedError("$tensorType is not Supported.");
}
Expand Down

0 comments on commit 1472c73

Please sign in to comment.