From 98f942952931f9d8d10619654a3d843ac83e1000 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 8 Apr 2024 14:24:44 +0800 Subject: [PATCH] Support --context-size=1 (#565) When --context-size is 1, there is no conv module at all. --- sherpa/csrc/online-lstm-transducer-model.cc | 4 +++- sherpa/csrc/online-zipformer-transducer-model.cc | 4 +++- sherpa/csrc/online-zipformer2-transducer-model.cc | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sherpa/csrc/online-lstm-transducer-model.cc b/sherpa/csrc/online-lstm-transducer-model.cc index 96fd0799f..b1ef4df8a 100644 --- a/sherpa/csrc/online-lstm-transducer-model.cc +++ b/sherpa/csrc/online-lstm-transducer-model.cc @@ -27,7 +27,9 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( joiner_.eval(); context_size_ = - decoder_.attr("conv").toModule().attr("weight").toTensor().size(2); + decoder_.hasattr("conv") + ? decoder_.attr("conv").toModule().attr("weight").toTensor().size(2) + : 1; // Use 5 here since the subsampling is ((len - 3) // 2 - 1) // 2. int32_t pad_length = 5; diff --git a/sherpa/csrc/online-zipformer-transducer-model.cc b/sherpa/csrc/online-zipformer-transducer-model.cc index 78f92a743..0eb7b2d22 100644 --- a/sherpa/csrc/online-zipformer-transducer-model.cc +++ b/sherpa/csrc/online-zipformer-transducer-model.cc @@ -27,7 +27,9 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( joiner_.eval(); context_size_ = - decoder_.attr("conv").toModule().attr("weight").toTensor().size(2); + decoder_.hasattr("conv") + ? decoder_.attr("conv").toModule().attr("weight").toTensor().size(2) + : 1; // Use 7 here since the subsampling is ((len - 7) // 2 + 1) // 2. int32_t pad_length = 7; diff --git a/sherpa/csrc/online-zipformer2-transducer-model.cc b/sherpa/csrc/online-zipformer2-transducer-model.cc index ed5331ead..43937b1a2 100644 --- a/sherpa/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa/csrc/online-zipformer2-transducer-model.cc @@ -24,7 +24,9 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( joiner_ = model_.attr("joiner").toModule(); context_size_ = - decoder_.attr("conv").toModule().attr("weight").toTensor().size(2); + decoder_.hasattr("conv") + ? decoder_.attr("conv").toModule().attr("weight").toTensor().size(2) + : 1; int32_t pad_length = encoder_.attr("pad_length").toInt();