[Question] Help converting ONNX to TensorRT with graphsurgeon and HPS plugin #449

Closed · dmac opened this issue May 23, 2024 · 4 comments
Labels: question (Further information is requested)

dmac commented May 23, 2024

Hi,

I'm attempting to convert an existing ONNX model with embeddings into a TensorRT model that uses HPS to perform the embedding lookup. I have extracted the embeddings from the ONNX model into files on disk in the HPS format, then used graphsurgeon to replace the embedding Gather node with an HPS_TRT node. Then I use trtexec to convert the new ONNX model into a TRT model and get this error:

2024/05/23 05:56:51 [stderr] [05/23/2024-05:56:50] [I] [TRT] Successfully created plugin: HPS_TRT
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.594][ERROR][RK0][main]: Expression check failed!
2024/05/23 05:56:51 [stderr] 	File: /hugectr/hps_trt/hps_plugin/hps_plugin.cpp:74
2024/05/23 05:56:51 [stderr] 	Function: getOutputDimensions
2024/05/23 05:56:51 [stderr] 	Expression: inputs[0].nbDims == 2
2024/05/23 05:56:51 [stderr] 	Hint: The dimensions of inputs[0] should be 2

I think I am probably not performing the graph surgery correctly, but I'm having trouble figuring out how to swap the HPS_TRT node into the existing model.

Here is the relevant section from the original model. Not pictured is the categorical input feeding into this section. The upper Gather node selects a single categorical feature and sends it to the lower Gather node which does the embedding lookup.

After performing the graph surgery to replace the lower Gather node (the embedding lookup) with the HPS_TRT node, here is the relevant section in the new model. Here are details for the upper Gather node and the lower HPS_TRT node.

The specific changes I make using graphsurgeon are:

  1. Create a new HPS_TRT node and set its inputs and outputs to the inputs and outputs of the embedding Gather node.
  2. Change the input and output shapes to include the batch size and embedding vector size as documented here.
  3. Remove the outputs from the old embedding Gather node (so it gets removed from the graph).
Here is the code I'm using to do this.
import numpy as np
import onnx
import onnx_graphsurgeon as gs

model = onnx.load_model("model.onnx")
graph = gs.import_onnx(model)

# The upper Gather selects the categorical feature; its single consumer is the
# lower Gather that performs the embedding lookup we want to replace.
feature_gather = next(n for n in graph.nodes if n.name == "/input_layer/Gather_77")
assert len(feature_gather.outputs) == 1
assert len(feature_gather.outputs[0].outputs) == 1
old = feature_gather.outputs[0].outputs[0]
embedding = next(t for t in old.inputs if t.name.endswith("embedding.weight"))

# Rewrite the shapes to include the batch size and embedding vector size.
in_var = feature_gather.outputs[0]
in_var.shape = ("batch_size", 1)
in_var.dtype = np.int32
out_var = old.outputs[0]
out_var.shape = ("batch_size", 1, 512)
out_var.dtype = np.float32

new = gs.Node(
    name=old.name,
    op="HPS_TRT",
    attrs={
        "ps_config_file": "hps.json\0",
        "model_name": "mymodel\0",
        "table_id": 0,
        "emb_vec_size": embedding.shape[1],
    },
    inputs=[in_var],
    outputs=[out_var],
)
graph.nodes.append(new)

# Disconnect the old embedding Gather so cleanup() removes it.
old.outputs.clear()

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "model_hps.onnx")
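As a sanity check, the shapes the surgery recorded on the new node can be printed back out. This is a minimal sketch, reusing the file name and op type from the script above:

import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load_model("model_hps.onnx"))
hps = next(n for n in graph.nodes if n.op == "HPS_TRT")
print("inputs: ", [(t.name, t.dtype, t.shape) for t in hps.inputs])
print("outputs:", [(t.name, t.dtype, t.shape) for t in hps.outputs])
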
Here are the contents of my hps.json configuration file, in case it's relevant; these are mostly initial values just to get something working.
{
  "supportlonglong": false,
  "models": [{
    "model": "mymodel",
    "sparse_files":["embeddings/input_layer.pretrained_embeddings.0.layers.embedding.weight.496328x512.tritonhps"],
    "deployed_device_list": [0],
    "maxnum_catfeature_query_per_table_per_sample": [1],
    "embedding_vecsize_per_table": [512],
    "default_value_for_each_table": [0]
  }]
}
Here is the full HCTR output when running trtexec.
2024/05/23 05:56:51 [stderr] =====================================================HPS Parse====================================================
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: fuse_embedding_table is not specified using default: 0
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: max_batch_size is not specified using default: 64
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: gpucache is not specified using default: 1
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: hit_rate_threshold is not specified using default: 0.9
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: gpucacheper is not specified using default: 0.2
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: dense_file is not specified using default: 
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: num_of_worker_buffer_in_pool is not specified using default: 1
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: cache_refresh_percentage_per_iteration is not specified using default: 0
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: refresh_delay is not specified using default: 0
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: refresh_interval is not specified using default: 0
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: use_static_table is not specified using default: 0
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: use_context_stream is not specified using default: 1
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: use_hctr_cache_implementation is not specified using default: 1
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: HPS plugin uses context stream for model mymodel: True
2024/05/23 05:56:51 [stderr] ====================================================HPS Create====================================================
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: Creating HashMap CPU database backend...
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][DEBUG][RK0][main]: Created blank database backend in local memory!
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: Volatile DB: initial cache rate = 1
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.074][DEBUG][RK0][main]: Created raw model loader in local memory!
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.581][INFO][RK0][main]: Table: hps_et.mymodel.sparse_embedding0; cached 496328 / 496328 embeddings in volatile database (HashMapBackend); load: 496328 / 18446744073709551615 (0.00%).
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.585][DEBUG][RK0][main]: Real-time subscribers created!
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.585][INFO][RK0][main]: Creating embedding cache in device 0.
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: Model name: mymodel
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: Max batch size: 64
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: Fuse embedding tables: False
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: Number of embedding tables: 1
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 0.200000
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: Embedding cache type: dynamic
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: Use I64 input key: False
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: Configured cache hit rate threshold: 0.900000
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: The size of thread pool: 16
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: The size of worker memory pool: 1
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: The size of refresh memory pool: 1
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.586][INFO][RK0][main]: The refresh percentage : 0.000000
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.594][INFO][RK0][main]: LookupSession i64_input_key: False
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.594][INFO][RK0][main]: Creating lookup session for mymodel on device: 0
2024/05/23 05:56:51 [stderr] [05/23/2024-05:56:50] [I] [TRT] Successfully created plugin: HPS_TRT
2024/05/23 05:56:51 [stderr] [HCTR][05:56:50.594][ERROR][RK0][main]: Expression check failed!
2024/05/23 05:56:51 [stderr] 	File: /hugectr/hps_trt/hps_plugin/hps_plugin.cpp:74
2024/05/23 05:56:51 [stderr] 	Function: getOutputDimensions
2024/05/23 05:56:51 [stderr] 	Expression: inputs[0].nbDims == 2
2024/05/23 05:56:51 [stderr] 	Hint: The dimensions of inputs[0] should be 2
2024/05/23 05:56:51 Process exited with exit status 1

What's confusing to me is that I'm getting the assertion failure inputs[0].nbDims == 2, even though the input to the HPS_TRT node is tensor: int32[batch_size, 1]. I suspect I am not performing the graph surgery correctly.

Is it valid to perform step (2) above, where I manually change the existing inputs and outputs to include the batch size and embedding vector size? Will the upper Gather node know to pass the batch size down to the HPS_TRT node, or does it not work that way? Do I instead need to be setting max_batch_size = 0 in this situation? (When I tried that I got Runtime error: Invalid value for unique_op capacity and haven't yet debugged further.)

Please let me know if I'm not thinking about this correctly or if there's more information I could provide to illuminate what's going wrong.

Thanks!

dmac added the question (Further information is requested) label on May 23, 2024
yingcanw (Collaborator) commented

Thanks for the feedback, @KingsleyLiu-NV could you help check this out?


LiuJieShane commented May 27, 2024

Hi @dmac , thanks for reporting this issue. Can you please share the minimal code to reproduce this? I think we need these files:

  • The surgeoned ONNX model with the HPS_TRT plugin
  • The HPS JSON file
  • The embedding table in HPS format
  • The script to build TRT engine

You don't need to provide the original ONNX model and embedding table. Instead, you can provide an ONNX model containing only the plugin node that triggers the assertion error, along with a mini embedding table to help us reproduce.
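
For the mini embedding table, a sketch along these lines should be enough. This assumes the sparse-file layout used in the HPS demo notebooks (a directory containing raw binary "key" and "emb_vector" files, with int64 keys and float32 vectors); the sizes and paths are illustrative:

import os
import numpy as np

out_dir = "mini_table"
os.makedirs(out_dir, exist_ok=True)

num_keys, emb_vec_size = 100, 512
keys = np.arange(num_keys, dtype=np.int64)  # one int64 key per embedding row
vectors = np.random.rand(num_keys, emb_vec_size).astype(np.float32)

keys.tofile(os.path.join(out_dir, "key"))            # raw binary key file
vectors.tofile(os.path.join(out_dir, "emb_vector"))  # raw binary vector file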

KingsleyLiu-NV (Collaborator) commented May 28, 2024

Hi @dmac ,

I created an ONNX model that has the same layers as yours:

[screenshot: the model graph before surgery]

Then I used your script to perform the graph surgery and insert the HPS TRT plugin:

[screenshot: the model graph after surgery]

After that, I used the following script to build the TRT engine:

import tensorrt as trt
import ctypes

# Load the HPS plugin library so TensorRT can resolve the HPS_TRT op.
plugin_lib_name = "/usr/local/hps_trt/lib/libhps_plugin.so"
handle = ctypes.CDLL(plugin_lib_name, mode=ctypes.RTLD_GLOBAL)

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def build_engine_from_onnx(onnx_model_path):
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as builder_config:
        with open(onnx_model_path, "rb") as model:
            if not parser.parse(model.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                return None

        # The optimization profile supplies the dynamic batch dimension.
        profile = builder.create_optimization_profile()
        profile.set_shape("keys", (1, 26), (1024, 26), (1024, 26))
        profile.set_shape("numerical_features", (1, 13), (1024, 13), (1024, 13))
        builder_config.add_optimization_profile(profile)
        return builder.build_serialized_network(network, builder_config)

serialized_engine = build_engine_from_onnx("dlrm_tf_with_hps.onnx")
with open("dlrm_tf_with_hps.trt", "wb") as fout:
    fout.write(serialized_engine)
print("Successfully build the TensorRT engine")

Here is the log of the successful build:

[05/28/2024-01:58:08] [TRT] [I] [MemUsageChange] Init CUDA: CPU +18, GPU +0, now: CPU 26, GPU 423 (MiB)
[05/28/2024-01:58:14] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +1667, GPU +310, now: CPU 1770, GPU 733 (MiB)
[05/28/2024-01:58:14] [TRT] [W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[05/28/2024-01:58:14] [TRT] [I] No importer registered for op: HPS_TRT. Attempting to import as plugin.
[05/28/2024-01:58:14] [TRT] [I] Searching for plugin: HPS_TRT, plugin_version: 1, plugin_namespace:
=====================================================HPS Parse====================================================
[HCTR][01:58:14.412][INFO][RK0][main]: fuse_embedding_table is not specified using default: 0
[HCTR][01:58:14.412][INFO][RK0][main]: dense_file is not specified using default:
[HCTR][01:58:14.412][INFO][RK0][main]: num_of_refresher_buffer_in_pool is not specified using default: 1
[HCTR][01:58:14.412][INFO][RK0][main]: maxnum_des_feature_per_sample is not specified using default: 26
[HCTR][01:58:14.412][INFO][RK0][main]: refresh_delay is not specified using default: 0
[HCTR][01:58:14.412][INFO][RK0][main]: refresh_interval is not specified using default: 0
[HCTR][01:58:14.412][INFO][RK0][main]: use_static_table is not specified using default: 0
[HCTR][01:58:14.412][INFO][RK0][main]: use_context_stream is not specified using default: 1
[HCTR][01:58:14.412][INFO][RK0][main]: use_hctr_cache_implementation is not specified using default: 1
[HCTR][01:58:14.412][INFO][RK0][main]: thread_pool_size is not specified using default: 16
[HCTR][01:58:14.412][INFO][RK0][main]: init_ec is not specified using default: 1
[HCTR][01:58:14.412][INFO][RK0][main]: enable_pagelock is not specified using default: 0
[HCTR][01:58:14.412][INFO][RK0][main]: fp8_quant is not specified using default: 0
[HCTR][01:58:14.412][INFO][RK0][main]: HPS plugin uses context stream for model dlrm: True
====================================================HPS Create====================================================
[HCTR][01:58:14.413][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][01:58:14.413][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][01:58:14.413][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][01:58:14.413][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][01:58:14.413][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][01:58:15.068][INFO][RK0][main]: Table: hps_et.dlrm.sparse_embedding0; cached 2600 / 2600 embeddings in volatile database (HashMapBackend); load: 2600 / 18446744073709551615 (0.00%).
[HCTR][01:58:15.071][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][01:58:15.071][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][01:58:15.093][INFO][RK0][main]: Model name: dlrm
[HCTR][01:58:15.093][INFO][RK0][main]: Max batch size: 1024
[HCTR][01:58:15.093][INFO][RK0][main]: Fuse embedding tables: False
[HCTR][01:58:15.093][INFO][RK0][main]: Number of embedding tables: 1
[HCTR][01:58:15.093][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][01:58:15.093][INFO][RK0][main]: Embedding cache type: dynamic
[HCTR][01:58:15.093][INFO][RK0][main]: Use I64 input key: False
[HCTR][01:58:15.093][INFO][RK0][main]: Configured cache hit rate threshold: 1.000000
[HCTR][01:58:15.093][INFO][RK0][main]: The size of thread pool: 256
[HCTR][01:58:15.093][INFO][RK0][main]: The size of worker memory pool: 3
[HCTR][01:58:15.093][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][01:58:15.093][INFO][RK0][main]: The refresh percentage : 0.200000
[HCTR][01:58:15.109][INFO][RK0][main]: Initialize the embedding cache by by inserting the same size model file with embedding cache from beginning
[HCTR][01:58:15.109][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][01:58:15.109][INFO][RK0][main]: EC initialization on device 0 for hps_et.dlrm.sparse_embedding0
[HCTR][01:58:15.178][INFO][RK0][main]: Initialize the embedding table 0 for iteration 0 with number of 512 keys.
[HCTR][01:58:15.248][INFO][RK0][main]: Initialize the embedding table 0 for iteration 1 with number of 512 keys.
[HCTR][01:58:15.316][INFO][RK0][main]: Initialize the embedding table 0 for iteration 2 with number of 512 keys.
[HCTR][01:58:15.385][INFO][RK0][main]: Initialize the embedding table 0 for iteration 3 with number of 512 keys.
[HCTR][01:58:15.452][INFO][RK0][main]: Initialize the embedding table 0 for iteration 4 with number of 512 keys.
[HCTR][01:58:15.473][INFO][RK0][main]: LookupSession i64_input_key: False
[HCTR][01:58:15.473][INFO][RK0][main]: Creating lookup session for dlrm on device: 0
[05/28/2024-01:58:15] [TRT] [I] Successfully created plugin: HPS_TRT
[05/28/2024-01:58:15] [TRT] [I] Graph optimization time: 0.00860746 seconds.
[05/28/2024-01:58:15] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +6, GPU +8, now: CPU 5950, GPU 833 (MiB)
[05/28/2024-01:58:15] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +2, GPU +10, now: CPU 5952, GPU 843 (MiB)
[05/28/2024-01:58:15] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[05/28/2024-01:58:17] [TRT] [I] Detected 2 inputs and 1 output network tensors.
[05/28/2024-01:58:17] [TRT] [I] Total Host Persistent Memory: 7088
[05/28/2024-01:58:17] [TRT] [I] Total Device Persistent Memory: 0
[05/28/2024-01:58:17] [TRT] [I] Total Scratch Memory: 0
[05/28/2024-01:58:17] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 4 MiB
[05/28/2024-01:58:17] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 6 steps to complete.
[05/28/2024-01:58:17] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.0212ms to assign 3 blocks to 6 nodes requiring 1102336 bytes.
[05/28/2024-01:58:17] [TRT] [I] Total Activation Memory: 1102336
[05/28/2024-01:58:17] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 6138, GPU 931 (MiB)
[05/28/2024-01:58:17] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +10, now: CPU 6138, GPU 941 (MiB)
[05/28/2024-01:58:17] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +4, now: CPU 0, GPU 4 (MiB)
Successfully build the TensorRT engine

Maybe there is something wrong with your ONNX model after graph surgery. Can you please check if the output tensor of /input_layer/Gather_77 in your model is really a two-dimensional tensor of shape (batch_size, 1)?

The annotation tensor: int32[batch_size, 1] in the ONNX visualization may not reflect the actual tensor shape at runtime. You can run your original ONNX model via onnxruntime to check the tensor shape in the related layers.
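
As a static alternative to running the model with onnxruntime, ONNX shape inference can report the inferred shape directly. A minimal sketch follows; the tensor name is hypothetical and should be replaced with the actual output name of the Gather node:

import onnx
from onnx import shape_inference

inferred = shape_inference.infer_shapes(onnx.load_model("model.onnx"))
name = "/input_layer/Gather_77_output_0"  # hypothetical tensor name
for vi in inferred.graph.value_info:
    if vi.name == name:
        dims = [d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim]
        print(name, dims)  # a scalar-index Gather will show only one dim here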

KingsleyLiu-NV self-assigned this May 28, 2024

dmac (Author) commented May 29, 2024

@KingsleyLiu-NV Thanks for taking a look at this, your response helped put me on the right track.

There were two issues I needed to fix:

  1. The /input_layer/Gather_77 node was using a scalar for its indices input. I changed it from 100 to [100] so it outputs the two-dimensional shape the HPS_TRT node expects (see the numpy sketch below this list).
  2. I needed to add a Reshape node between HPS_TRT and Concat to reshape ["batch_size", 1, 512] to ["batch_size", 512].
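
A quick numpy illustration of fix (1), with hypothetical shapes: ONNX Gather follows numpy indexing semantics, so a scalar index drops the gathered axis while a length-1 index keeps it.

import numpy as np

cats = np.zeros((8, 256), dtype=np.int32)  # hypothetical (batch, features) input
print(cats[:, 100].shape)    # (8,)   scalar index: rank 1, fails nbDims == 2
print(cats[:, [100]].shape)  # (8, 1) index [100]: rank 2, what HPS_TRT expects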

After making these adjustments to my script I was able to successfully convert the model to TensorRT. Thanks again for your help!

For posterity, this is the script I ended up with.
import numpy as np
import onnx
import onnx_graphsurgeon as gs

model = onnx.load_model("model.onnx")
graph = gs.import_onnx(model)

def replace_embedding_node(gather_node_name: str):
    # Change the Gather indices from scalar i to [i] so its output has the
    # two-dimensional (batch_size, 1) shape the HPS_TRT node expects.
    gather = next(n for n in graph.nodes if n.name == gather_node_name)
    indices = next(t for t in gather.inputs if "Constant" in t.name).inputs[0]
    indices.attrs["value"].values = np.array([indices.attrs["value"].values])

    # Replace the inline-weights embedding lookup with HPS_TRT.
    embedding = gather.outputs[0].outputs[0]
    weights = next(t for t in embedding.inputs if t.name.endswith("embedding.weight"))
    concat_in = embedding.outputs[0]
    hps_out = gs.Variable(
        name=embedding.name + ".HPS_output_0",
        dtype=np.float32,
        shape=("batch_size", 1, weights.shape[1]),
    )
    hps = gs.Node(
        name=embedding.name + ".HPS",
        op="HPS_TRT",
        attrs={
            "ps_config_file": "hps.json\0",
            "model_name": "mymodel\0",
            "table_id": 0,
            "emb_vec_size": weights.shape[1],
        },
        inputs=[gather.outputs[0]],
        outputs=[hps_out],
    )

    # Reshape the HPS_TRT output from ["batch_size", 1, N] to ["batch_size", N]
    # to match the expected shape of the Concat inputs. ONNX Reshape takes its
    # target shape as an int64 tensor.
    new_shape = gs.Constant(
        name=hps.name + "/Reshape_input_shape",
        values=np.array([-1, weights.shape[1]], dtype=np.int64),
    )
    reshape = gs.Node(
        name=hps.name + "/Reshape",
        op="Reshape",
        inputs=[hps_out, new_shape],
        outputs=[concat_in],
    )

    graph.nodes.append(hps)
    graph.nodes.append(reshape)
    # Disconnect the old embedding Gather so cleanup() removes it.
    embedding.outputs.clear()

replace_embedding_node("/input_layer/Gather_77")
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "model_hps.onnx")

dmac closed this as completed May 29, 2024