1
0
This commit is contained in:
Author: liushuang — 2025-10-15 10:53:17 +08:00
Parent commit: 25e718d9a6
This commit: a27b0aa2c6

View File

@ -74,7 +74,7 @@ from pydantic import BaseModel
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
# Guided decoding by JSON using Pydantic schema # 定义结构化输出 schema
class CarType(str, Enum): class CarType(str, Enum):
sedan = "sedan" sedan = "sedan"
suv = "SUV" suv = "SUV"
@ -86,25 +86,52 @@ class CarDescription(BaseModel):
model: str model: str
car_type: CarType car_type: CarType
# JSON schema describing the structured output we want the model to emit.
json_schema = CarDescription.model_json_schema()

# Prompt asking the model for a JSON object matching that schema.
prompt = (
    "Generate a JSON with the brand, model and car_type of "
    "the most iconic car from the 90's"
)
def format_output(title: str, output: str) -> None:
    """Print *output* labelled with *title*, framed by 50-dash rule lines."""
    rule = "-" * 50
    print(f"{rule}\n{title}: {output}\n{rule}")
def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM) -> str:
    """Run one generation on *llm* and return the text of the first candidate
    of the first request."""
    results = llm.generate(prompts=prompt, sampling_params=sampling_params)
    first_request = results[0]
    return first_request.outputs[0].text
def main():
    """Drive a local vLLM model with a GuidedDecodingParams instance that has
    no active field, to reproduce the expected ValueError from
    get_structured_output_key()."""
    # 1. Initialize the local LLM from an on-disk model checkpoint.
    llm = LLM(
        model="/home/ss/vllm-py12/qwen3-06b",  # path to the local model files
        max_model_len=1024,
        enable_prefix_caching=True,
        gpu_memory_utilization=0.9,
    )

    # 2. Build an invalid guided-decoding spec: every field is empty/None,
    #    which should make get_structured_output_key() raise ValueError.
    guided_decoding_invalid = GuidedDecodingParams(
        json=None,
        json_object=False,
        regex=None,
        choice=None,
        grammar=None,
        structural_tag=None,
    )
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=512,
        guided_decoding=guided_decoding_invalid,  # present, but carries no valid field
    )

    # 3. Generate; the invalid spec is expected to trigger the error.
    try:
        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
        for output in outputs:
            generated_text = output.outputs[0].text
            format_output("Output", generated_text)
    except Exception as e:  # broad on purpose: the point is to observe the raised error
        print(f"Caught expected error: {e}")
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()