diff --git a/extensions/api/script.py b/extensions/api/script.py index 20562cc..6726d61 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler): prompt_lines.pop(0) prompt = '\n'.join(prompt_lines) + generate_params = { + 'max_new_tokens': int(body.get('max_length', 200)), + 'do_sample': bool(body.get('do_sample', True)), + 'temperature': float(body.get('temperature', 0.5)), + 'top_p': float(body.get('top_p', 1)), + 'typical_p': float(body.get('typical', 1)), + 'repetition_penalty': float(body.get('rep_pen', 1.1)), + 'encoder_repetition_penalty': 1, + 'top_k': int(body.get('top_k', 0)), + 'min_length': int(body.get('min_length', 0)), + 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)), + 'num_beams': int(body.get('num_beams',1)), + 'penalty_alpha': float(body.get('penalty_alpha', 0)), + 'length_penalty': float(body.get('length_penalty', 1)), + 'early_stopping': bool(body.get('early_stopping', False)), + 'seed': int(body.get('seed', -1)), + } generator = generate_reply( - question = prompt, - max_new_tokens = int(body.get('max_length', 200)), - do_sample=bool(body.get('do_sample', True)), - temperature=float(body.get('temperature', 0.5)), - top_p=float(body.get('top_p', 1)), - typical_p=float(body.get('typical', 1)), - repetition_penalty=float(body.get('rep_pen', 1.1)), - encoder_repetition_penalty=1, - top_k=int(body.get('top_k', 0)), - min_length=int(body.get('min_length', 0)), - no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)), - num_beams=int(body.get('num_beams',1)), - penalty_alpha=float(body.get('penalty_alpha', 0)), - length_penalty=float(body.get('length_penalty', 1)), - early_stopping=bool(body.get('early_stopping', False)), - seed=int(body.get('seed', -1)), + prompt, + generate_params, stopping_strings=body.get('stopping_strings', []), )