Overview

Jamba models support several quantization techniques (a back-of-the-envelope memory comparison follows the list):

  • FP8 Quantization: 8-bit floating point weights for reduced memory footprint and efficient deployment
  • ExpertsInt8: Innovative quantization for MoE models in vLLM deployment
  • 8-bit Quantization: Using BitsAndBytesConfig for training and inference
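
To see why weight precision dominates the deployment picture, here is a rough memory estimate. The ~52B total parameter count is an illustrative assumption for a Jamba-class MoE, not an official figure; check the model card for exact sizes.

# Rough weight-memory estimate for a Jamba-class MoE.
# The parameter count below is an illustrative assumption, not an official figure.
total_params = 52e9

for precision, bytes_per_param in [("bf16", 2), ("int8 / fp8", 1)]:
    gib = total_params * bytes_per_param / 1024**3
    print(f"{precision:>10}: ~{gib:.0f} GiB of weights")

# bf16 weights alone (~97 GiB) would not fit on a single 80 GB GPU, while
# 8-bit weights (~48 GiB) leave headroom for activations and the KV cache.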

FP8 Quantization (vLLM)

The FP8 Jamba checkpoints (for example, ai21labs/AI21-Jamba-Mini-1.7-FP8) ship with pre-quantized FP8 weights, significantly reducing storage requirements and memory footprint without compromising output quality.

FP8 quantization requires Hopper architecture GPUs such as NVIDIA H100 and NVIDIA H200.
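
If you are unsure whether a GPU qualifies, a quick check with PyTorch: Hopper parts such as the H100 and H200 report CUDA compute capability 9.0.

import torch

# Hopper GPUs (H100/H200) report CUDA compute capability 9.0.
major, minor = torch.cuda.get_device_capability()
if (major, minor) < (9, 0):
    print(f"Compute capability {major}.{minor} detected; "
          "the FP8 Jamba checkpoints target Hopper-class GPUs.")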

Pre-quantized Model Weights

Prerequisites

pip install "vllm>=0.6.5,<=0.8.5.post1"

Implementation

1. Load Pre-quantized FP8 Model

from vllm import LLM, SamplingParams

llm = LLM(
    model="ai21labs/AI21-Jamba-Mini-1.7-FP8",
    max_model_len=100*1024,
)

2. Generate Text

sampling_params = SamplingParams(
    temperature=0.4,
    top_p=1.0,
    max_tokens=100
)

prompts = ["Explain the advantages of FP8 quantization:"]
outputs = llm.generate(prompts, sampling_params)

print(outputs[0].outputs[0].text)

Pre-quantized FP8 models require no additional quantization parameters since the weights are already quantized.
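
Because the Jamba checkpoints ship with a chat template (as used in the Hugging Face example further below), you can also use vLLM's chat interface, which applies that template before generation. A minimal sketch that reuses the llm and sampling_params objects from the steps above (LLM.chat is available in the vLLM versions pinned in the prerequisites):

messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Explain the advantages of FP8 quantization."},
]

# LLM.chat applies the model's chat template, then generates as usual.
chat_outputs = llm.chat(messages, sampling_params)
print(chat_outputs[0].outputs[0].text)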

ExpertsInt8 Quantization (vLLM)

ExpertsInt8 is an innovative and efficient quantization technique developed specifically for Mixture of Experts (MoE) models deployed in vLLM, including Jamba models. This technique enables:

  • Jamba Mini 1.7: Deploy on a single 80GB GPU
  • Jamba Large 1.7: Deploy on a single node with 8x 80GB GPUs (see the multi-GPU sketch after this list)
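
Deploying Jamba Large follows the same pattern as the single-GPU implementation below, with tensor parallelism added. A minimal sketch, assuming the ai21labs/AI21-Jamba-Large-1.7 checkpoint name and a node with 8 visible 80GB GPUs; adjust max_model_len to your memory budget:

from vllm import LLM

llm_large = LLM(
    model="ai21labs/AI21-Jamba-Large-1.7",  # assumed checkpoint name; verify on the Hub
    tensor_parallel_size=8,                 # shard the weights across 8x 80GB GPUs
    max_model_len=100*1024,                 # adjust to your memory budget
    quantization="experts_int8",            # ExpertsInt8, as for Jamba Mini
)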

Prerequisites

pip install "vllm>=0.6.5,<=0.8.5.post1"

Implementation

1. Load Model with ExpertsInt8

from vllm import LLM

llm = LLM(
    model="ai21labs/AI21-Jamba-Mini-1.7",
    max_model_len=100*1024,
    quantization="experts_int8"  # Enable ExpertsInt8 quantization
)

2. Generate Text

from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature=0.4, 
    top_p=0.95, 
    max_tokens=100
)

# Generate text
prompts = ["Explain the benefits of model quantization:"]
outputs = llm.generate(prompts, sampling_params)

print(outputs[0].outputs[0].text)

With ExpertsInt8 quantization, you can fit prompts of up to 100K tokens on a single 80GB A100 GPU with Jamba Mini.
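
If long prompts still run out of KV-cache space, one knob worth knowing is vLLM's gpu_memory_utilization argument (it defaults to 0.9). A hedged sketch of raising it slightly for the single-GPU setup above; whether this helps depends on your GPU and workload:

from vllm import LLM

llm = LLM(
    model="ai21labs/AI21-Jamba-Mini-1.7",
    quantization="experts_int8",
    max_model_len=100*1024,
    gpu_memory_utilization=0.95,  # fraction of GPU memory vLLM may claim (default 0.9)
)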

8-bit Quantization (Hugging Face)

With 8-bit quantization using BitsAndBytesConfig, it is possible to fit a sequence length of up to 140K tokens on a single 80GB GPU.

Prerequisites

pip install transformers torch bitsandbytes accelerate

Implementation

1. Configure 8-bit Quantization

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,  # Enable 8-bit quantization
    llm_int8_skip_modules=["mamba"]  # Exclude Mamba blocks to preserve quality
)

2. Load Model with Quantization

model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Mini-1.7",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")

3. Run Inference

messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What are the advantages of 8-bit quantization?"}
]

input_ids = tokenizer.apply_chat_template(
    messages, 
    add_generation_prompt=True, 
    return_tensors='pt'
).to(model.device)

with torch.no_grad():
    outputs = model.generate(
        input_ids,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

To maintain model quality, we recommend excluding Mamba blocks from quantization using llm_int8_skip_modules=["mamba"].
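
To sanity-check that the Mamba blocks were indeed left unquantized, you can inspect which Linear layers bitsandbytes replaced in the model loaded above. A rough sketch; it assumes the Mamba-block projections have "mamba" in their module names, which is how the skip list is expected to match them:

import bitsandbytes as bnb
import torch

int8_linears, skipped_linears = 0, 0
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear8bitLt):            # converted to 8-bit
        int8_linears += 1
    elif isinstance(module, torch.nn.Linear) and "mamba" in name:
        skipped_linears += 1                                # left in bf16

print(f"8-bit linear layers: {int8_linears}; mamba linears kept in bf16: {skipped_linears}")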
