diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index 6ac5b56f3752..0a771404694d 100644 --- a/src/runtime/relax_vm/executable.cc +++ b/src/runtime/relax_vm/executable.cc @@ -45,6 +45,7 @@ enum ConstantType : int { kShapeTuple = 2, kString = 3, kInt = 4, + kFloat = 5, }; #define STREAM_CHECK(val, section) \ @@ -312,6 +313,9 @@ void Executable::SaveConstantSection(dmlc::Stream* strm) { } else if (it.type_code() == kDLInt) { strm->Write(ConstantType::kInt); strm->Write(it.value()); + } else if (it.type_code() == kDLFloat) { + strm->Write(ConstantType::kFloat); + strm->Write(it.value()); } else { try { strm->Write(ConstantType::kDLDataType); @@ -385,6 +389,12 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) { TVMRetValue cell; cell = value; this->constants.push_back(cell); + } else if (constant_type == ConstantType::kFloat) { + double value; + strm->Read(&value); + TVMRetValue cell; + cell = value; + this->constants.push_back(cell); } else { LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " << ArgTypeCode2Str(constant_type) << " when loading the VM constant pool.";