#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Example driver: run an NNFusion runtime repeatedly and print latency stats."""

# NOTE(review): raiseExceptions, torch, nn and F are not used by this script.
from logging import raiseExceptions
import time
import argparse
import numpy as np
import torch
# torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F

from nnfusion.executor import Executor
from nnfusion.session import generate_sample
from nnfusion.data_format import cast_pytorch_tensor, cast_hlsl_tensor, HLSLTensor

# Values of ``executor.device_type`` this example knows how to feed.
_DEVICE_CUDA = 0
_DEVICE_HLSL = 3

# Untimed runs executed before measuring.
_WARMUP_ITERS = 5

# Percentiles reported after the timed loop.
_REPORTED_PERCENTILES = (50, 75, 90, 95, 99)


def _allocate_io(executor):
    """Build the name -> DataFormat dicts for every model input and output.

    The buffer flavour depends on the runtime: host mode uses PyTorch
    tensors as storage, the CUDA device uses PyTorch tensors created with
    the "cuda" argument, and the HLSL device wraps the sample in an
    HLSLTensor.

    Raises:
        Exception: if the runtime is a device runtime other than CUDA/HLSL.
    """
    if executor.host_mode:
        # host mode leverage pytorch tensor as storage
        def make(desc):
            return cast_pytorch_tensor(generate_sample(desc))
    elif executor.device_type == _DEVICE_CUDA:
        def make(desc):
            return cast_pytorch_tensor(generate_sample(desc, "cuda"))
    elif executor.device_type == _DEVICE_HLSL:
        def make(desc):
            return cast_hlsl_tensor(HLSLTensor(generate_sample(desc)))
    else:
        raise Exception("only support device kernel_entry on cuda/hlsl backend.")

    input_dict = {desc.name: make(desc) for desc in executor.get_inputs()}
    output_dict = {desc.name: make(desc) for desc in executor.get_outputs()}
    return input_dict, output_dict


def _batch_size(input_dict):
    """Return the leading dimension of the first input, or 1 if there is none.

    Falling back to 1 keeps the throughput line meaningful (runs per second)
    for models without inputs or whose first input has an empty shape.
    """
    first = next(iter(input_dict.values()), None)
    if first is None or len(first.shape) == 0:
        return 1
    return first.shape[0]


def inference(nnf_model_path, total_iter):
    """Run the runtime in ``nnf_model_path`` and print latency statistics.

    Parameters:
        nnf_model_path: directory handed to ``Executor``
        total_iter: number of timed executions, must be >= 1

    Returns:
        None

    Raises:
        ValueError: if ``total_iter`` is smaller than 1.
        Exception: if the runtime targets an unsupported device.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if total_iter < 1:
        raise ValueError(f"total_iter must be >= 1, got {total_iter}")

    executor = Executor(nnf_model_path)
    input_dict, output_dict = _allocate_io(executor)

    # warm up
    for _ in range(_WARMUP_ITERS):
        executor(input_dict, output_dict)
    for name, data_format in output_dict.items():
        print(f"{name} = {data_format.reference}")

    # evaluate
    print(f"Begin evaluation of {total_iter} iters")
    # perf_counter is monotonic and high resolution; time.time() can jump
    # when the system clock is adjusted.
    perf_list = []
    start = time.perf_counter()
    for _ in range(total_iter):
        start_i = time.perf_counter()
        executor(input_dict, output_dict)
        perf_list.append(time.perf_counter() - start_i)
    end = time.perf_counter()

    latency_ms = np.array(perf_list) * 1000
    mean_ms = np.mean(latency_ms)
    batch_size = _batch_size(input_dict)
    print(f"average_latency = {mean_ms} ms")
    for pct in _REPORTED_PERCENTILES:
        print(f"latency_{pct} = {np.percentile(latency_ms, pct)} ms")
    print(f"throughput = {batch_size * (1000.0 / mean_ms)} sample/s")
    print(f"total elaspe {end - start} s")


def main():
    """Parse the command line and run the benchmark."""
    parser = argparse.ArgumentParser()
    # Required: without it Executor would be constructed with None.
    parser.add_argument('--nnf_model_path', type=str, required=True,
                        help="directory of the generated nnfusion runtime")
    parser.add_argument('--total_iter', type=int, default=1,
                        help="number of timed iterations (>= 1)")
    args = parser.parse_args()
    inference(args.nnf_model_path, args.total_iter)


if __name__ == "__main__":
    main()
class HLSLTensor(object):
    """Device buffer allocated through the Antares DirectX library.

    Construction copies a PyTorch tensor into memory obtained from
    ``dxMemAlloc``; the buffer is released with ``dxMemFree`` when the object
    is destroyed. ``init_antares_lib`` must have been called beforehand.
    """

    # ctypes handle shared by all instances; None until init_antares_lib runs.
    antares_lib = None

    @classmethod
    def init_antares_lib(cls, antares_dll_path):
        """Load antares.dll once and declare the ctypes signatures used here.

        Parameters:
            antares_dll_path: path of the DLL handed to ctypes

        Returns:
            None
        """
        if cls.antares_lib is not None:
            return
        lib = ctypes.cdll.LoadLibrary(antares_dll_path)
        # alloc
        lib.dxMemAlloc.argtypes = [ctypes.c_uint64]
        lib.dxMemAlloc.restype = ctypes.c_void_p
        # free
        lib.dxMemFree.argtypes = [ctypes.c_void_p]
        lib.dxMemFree.restype = ctypes.c_int32
        # H2D
        lib.dxMemcpyHtoDAsync.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint64, ctypes.c_void_p]
        lib.dxMemcpyHtoDAsync.restype = ctypes.c_int32
        # D2H
        lib.dxMemcpyDtoHAsync.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint64, ctypes.c_void_p]
        lib.dxMemcpyDtoHAsync.restype = ctypes.c_int32
        # Sync
        lib.dxStreamSynchronize.argtypes = [ctypes.c_void_p]
        lib.dxStreamSynchronize.restype = ctypes.c_int32
        # Publish only after every signature is declared, so a DLL that lacks
        # one of the symbols never becomes the shared handle.
        cls.antares_lib = lib

    def __init__(self, pytorch_tensor) -> None:
        """Allocate a device buffer and upload ``pytorch_tensor`` into it.

        Parameters:
            pytorch_tensor: source tensor; it is made contiguous (and moved
                to host memory if needed) before the copy

        Raises:
            Exception: if the Antares library has not been initialised.
            MemoryError: if the allocation of a non-empty buffer returns NULL.
        """
        if self.antares_lib is None:
            raise Exception("Please init antares lib firstly(e.g. creating a executor instance antomatically init antares lib")
        # The upload reads host memory; .cpu() is a no-op for host tensors.
        pytorch_tensor = pytorch_tensor.cpu().contiguous()
        self.shape = pytorch_tensor.shape
        self.pt_type = str(pytorch_tensor.dtype).split(".")[-1]
        self.dtype = dtypes.str2type[self.pt_type].type_str
        self.size = pytorch_tensor.numel() * pytorch_tensor.element_size()
        self.pointer = self.antares_lib.dxMemAlloc(self.size)
        # ctypes maps a NULL c_void_p result to None.
        if self.size and not self.pointer:
            raise MemoryError(f"dxMemAlloc failed to allocate {self.size} bytes")
        # Tensor.data_ptr() is the address of the first element; the storage
        # pointer would ignore the storage offset of a sliced tensor.
        host_ptr = ctypes.cast(pytorch_tensor.data_ptr(), ctypes.c_void_p)
        self.antares_lib.dxMemcpyHtoDAsync(self.pointer, host_ptr, self.size, None)
        # Wait for the copy so the host tensor may be released afterwards.
        self.antares_lib.dxStreamSynchronize(None)

    def __del__(self):
        """Free the device buffer once; later calls are no-ops."""
        pointer = getattr(self, "pointer", None)
        lib = self.antares_lib
        if pointer and lib is not None:
            lib.dxMemFree(pointer)
        # Assignment (the original compared with ``==``) so the buffer can
        # never be released twice.
        self.pointer = None

    def __str__(self):
        """Format the buffer content like the equivalent PyTorch tensor."""
        return str(self.to_pytorch_tensor())

    def to_pytorch_tensor(self):
        """Download the buffer into a new host tensor.

        Returns:
            a torch tensor with this object's shape and dtype
        """
        res = torch.empty(self.shape, dtype=dtypes.str2type[self.dtype].torch_type)
        host_ptr = ctypes.cast(res.data_ptr(), ctypes.c_void_p)
        self.antares_lib.dxMemcpyDtoHAsync(host_ptr, self.pointer, self.size, None)
        self.antares_lib.dxStreamSynchronize(None)
        return res


def cast_hlsl_tensor(hlsl_tensor):
    """Wrap an HLSLTensor in the DataFormat consumed by the executor.

    Parameters:
        hlsl_tensor: the HLSLTensor whose device pointer is exposed

    Returns:
        a DataFormat holding the typed pointer, shape and dtype string; the
        tensor itself is stored as ``reference``
    """
    type_str = hlsl_tensor.dtype
    pointer_type = ctypes.POINTER(dtypes.str2type[type_str].c_type)
    pointer = ctypes.cast(hlsl_tensor.pointer, pointer_type)
    return DataFormat(pointer, pointer_type, hlsl_tensor.shape, type_str, hlsl_tensor)
TypeObject._make(["uint8", ctypes.c_uint32, None, numpy.uint32]), + TypeObject._make(["uint32", ctypes.c_uint32, None, numpy.uint32]), "uint64": - TypeObject._make(["uint8", ctypes.c_uint64, None, numpy.uint64]), + TypeObject._make(["uint64", ctypes.c_uint64, None, numpy.uint64]), } diff --git a/src/python/nnfusion/executor.py b/src/python/nnfusion/executor.py index 1fec938bd..79fb374fd 100644 --- a/src/python/nnfusion/executor.py +++ b/src/python/nnfusion/executor.py @@ -4,10 +4,9 @@ import json import os import platform - import torch -from .data_format import cast_pytorch_tensor +from .data_format import HLSLTensor, cast_pytorch_tensor from .description import IODescription from .utils import cd @@ -98,6 +97,8 @@ def __init__(self, nnf_rt_dir, device=None): # prepare init/free/kernel_entry self.init_flag = False + if os.path.exists(os.path.join(nnf_rt_dir, "antares.dll")): + HLSLTensor.init_antares_lib(os.path.join(nnf_rt_dir, "antares.dll")) # dxil.dll and dxcompiler.dll must be manually imported if os.path.exists(os.path.join(nnf_rt_dir, "dxil.dll")): ctypes.cdll.LoadLibrary(os.path.join(nnf_rt_dir, "dxil.dll")) @@ -106,8 +107,10 @@ def __init__(self, nnf_rt_dir, device=None): self.libnnf = ctypes.cdll.LoadLibrary(self.libnnf_path) if hasattr(self.libnnf, "kernel_entry_host"): self.kernel_entry = self.libnnf.kernel_entry_host + self.host_mode = True elif hasattr(self.libnnf, "kernel_entry"): self.kernel_entry = self.libnnf.kernel_entry + self.host_mode = False else: raise Exception("No kernel_entry found in nnfusion_rt") device_type = self.get_device_type() @@ -180,18 +183,7 @@ def __call__(self, *args, **kwargs): # self.feed_tensors(*args, **kwargs) self.feed_data(*args, **kwargs) - def feed_data(self, inputs, outputs, strict=True): - """ - Execute the kernel_entry in nnf runtime - - Parameters: - inputs: a dict from name to nnf DataFormat - outputs: a dict from name to nnf DataFormat - strict: False if allow unused inputs/outputs - - Returns: - None - """ + 
def _dict_to_pointer_list(self, inputs, outputs, strict=True): signature = [None] * (len(self.input_descs) + len(self.output_descs)) params = [None] * (len(self.input_descs) + len(self.output_descs)) for name, data_format in inputs.items(): @@ -223,6 +215,21 @@ def feed_data(self, inputs, outputs, strict=True): else: if strict: raise Exception(f"Unused output {name}") + return signature, params + + def feed_data(self, inputs, outputs, strict=True): + """ + Execute the kernel_entry in nnf runtime + + Parameters: + inputs: a dict from name to nnf DataFormat + outputs: a dict from name to nnf DataFormat + strict: False if allow unused inputs/outputs + + Returns: + None + """ + signature, params = self._dict_to_pointer_list(inputs, outputs, strict=strict) self.feed_pointers(signature, params) def feed_pointers(self, signature, params):