add
This commit is contained in:
parent
25e718d9a6
commit
a27b0aa2c6
49
20251014.md
49
20251014.md
@ -74,7 +74,7 @@ from pydantic import BaseModel
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
# Guided decoding by JSON using Pydantic schema
|
||||
# 定义结构化输出 schema
|
||||
class CarType(str, Enum):
|
||||
sedan = "sedan"
|
||||
suv = "SUV"
|
||||
@ -86,10 +86,11 @@ class CarDescription(BaseModel):
|
||||
model: str
|
||||
car_type: CarType
|
||||
|
||||
# 获取 JSON schema
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
# guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
|
||||
sampling_params_json = SamplingParams(guided_decoding={})
|
||||
prompt_json = (
|
||||
|
||||
# 设置 prompt
|
||||
prompt = (
|
||||
"Generate a JSON with the brand, model and car_type of "
|
||||
"the most iconic car from the 90's"
|
||||
)
|
||||
@ -97,14 +98,40 @@ prompt_json = (
|
||||
def format_output(title: str, output: str):
|
||||
print(f"{'-' * 50}\n{title}: {output}\n{'-' * 50}")
|
||||
|
||||
def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
|
||||
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
|
||||
return outputs[0].outputs[0].text
|
||||
|
||||
def main():
|
||||
llm = LLM(model="qwen", max_model_len=100)
|
||||
json_output = generate_output(prompt_json, sampling_params_json, llm)
|
||||
format_output("Guided decoding by JSON", json_output)
|
||||
# 1. 初始化本地 LLM,加载本地模型文件
|
||||
llm = LLM(
|
||||
model="/home/ss/vllm-py12/qwen3-06b", # 指向你的本地模型路径
|
||||
max_model_len=1024,
|
||||
enable_prefix_caching=True,
|
||||
gpu_memory_utilization=0.9,
|
||||
)
|
||||
|
||||
# 2. 构造一个无效的 guided_decoding:没有任何有效字段
|
||||
# 这将导致 get_structured_output_key() 中 raise ValueError
|
||||
guided_decoding_invalid = GuidedDecodingParams(
|
||||
json=None,
|
||||
json_object=False,
|
||||
regex=None,
|
||||
choice=None,
|
||||
grammar=None,
|
||||
structural_tag=None
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=512,
|
||||
guided_decoding=guided_decoding_invalid # ✅ 传入但无有效字段
|
||||
)
|
||||
|
||||
# 3. 生成输出(预期会触发 ValueError)
|
||||
try:
|
||||
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
format_output("Output", generated_text)
|
||||
except Exception as e:
|
||||
print(f"Caught expected error: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user