Add LLaMA 8-bit support
This commit is contained in:
@@ -88,12 +88,20 @@ def load_model(model_name):
|
||||
|
||||
# LLaMA model (not on HuggingFace)
|
||||
elif shared.is_LLaMA:
|
||||
import modules.LLaMA
|
||||
from modules.LLaMA import LLaMAModel
|
||||
if shared.args.load_in_8bit:
|
||||
import modules.LLaMA_8bit
|
||||
from modules.LLaMA_8bit import LLaMAModel_8bit
|
||||
|
||||
model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
|
||||
model = LLaMAModel_8bit.from_pretrained(Path(f'models/{model_name}'))
|
||||
|
||||
return model, None
|
||||
return model, None
|
||||
else:
|
||||
import modules.LLaMA
|
||||
from modules.LLaMA import LLaMAModel
|
||||
|
||||
model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
|
||||
|
||||
return model, None
|
||||
|
||||
# Custom
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user