diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md
new file mode 100644
index 000000000..3b3daf9ef
--- /dev/null
+++ b/comps/finetuning/README.md
@@ -0,0 +1,48 @@
+# LLM Fine-tuning Microservice
+
+The LLM fine-tuning microservice adapts a base model to a specific task or dataset to improve its performance on that task.
+
+# 🚀1. Start Microservice with Python
+
+## 1.1 Install Requirements
+
+```bash
+pip install -r requirements.txt
+```
+
+## 1.2 Start Finetuning Service with Python Script
+
+### 1.2.1 Start Ray Cluster
+
+TBD
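+
+As a minimal single-node sketch (assuming Ray is installed and its default ports are used, including the client server on port `10001`):
+
+```bash
+# start a Ray head node on the current machine
+ray start --head --dashboard-host 0.0.0.0
+
+# optionally, join additional worker nodes to the cluster
+# ray start --address="${ray_head_ip}:6379"
+```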
+
+### 1.2.2 Start Finetuning Service
+
+```bash
+export RAY_ADDRESS="ray://${ray_head_ip}:10001"
+python finetuning/finetuning.py
+```
+
+# 🚀2. Consume Finetuning Service
+
+## 2.1 Check Service Status
+
+```bash
+curl http://${your_ip}:8000/v1/health_check \
+  -X GET \
+  -H 'Content-Type: application/json'
+```
+
+## 2.2 Create Fine-tuning Job
+
+Assuming a training file `file-vGxE9KywnSUkEL6dv9qZxKAF.jsonl` has been uploaded to the service's `datasets` directory, the following command launches a fine-tuning job with `meta-llama/Llama-2-7b-chat-hf` as the base model:
+
+```bash
+curl http://${your_ip}:8000/v1/fine_tuning/jobs \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+    "training_file": "file-vGxE9KywnSUkEL6dv9qZxKAF.jsonl",
+    "model": "meta-llama/Llama-2-7b-chat-hf"
+  }'
+```
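+
+## 2.3 Manage Fine-tuning Jobs
+
+`finetuning.py` also exposes OpenAI-style endpoints for listing, retrieving, and cancelling jobs. For example (replace `${fine_tuning_job_id}` with the `id` returned when the job was created):
+
+```bash
+# List all fine-tuning jobs
+curl http://${your_ip}:8000/v1/fine_tuning/jobs -X GET
+
+# Retrieve a single fine-tuning job
+curl http://${your_ip}:8000/v1/fine_tuning/jobs/${fine_tuning_job_id} -X GET
+
+# Cancel a fine-tuning job
+curl http://${your_ip}:8000/v1/fine_tuning/jobs/${fine_tuning_job_id}/cancel -X POST
+```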
diff --git a/comps/finetuning/datasets/.gitkeep b/comps/finetuning/datasets/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/comps/finetuning/docker/Dockerfile.finetune b/comps/finetuning/docker/Dockerfile.finetune
new file mode 100644
index 000000000..30b9c6171
--- /dev/null
+++ b/comps/finetuning/docker/Dockerfile.finetune
@@ -0,0 +1,22 @@
+# Use the same Python version as Ray
+FROM python:3.10.14
+
+WORKDIR /root/opea-finetune
+
+RUN --mount=type=cache,target=/var/cache/apt apt-get update -y \
+    && apt-get install -y vim htop net-tools dnsutils \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+# COPY ./install-llm-on-ray.sh /tmp/install-llm-on-ray.sh
+# RUN --mount=type=cache,target=/root/.cache/pip /tmp/install-llm-on-ray.sh
+
+COPY ./ .
+
+RUN --mount=type=cache,target=/root/.cache/pip cd ./llm-on-ray && pip install -v -e .[cpu] --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir --upgrade -r requirements.txt
+
+RUN echo 'source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh' >> ~/.bashrc
+
+CMD ["bash", "-c", "./run.sh"]
\ No newline at end of file
diff --git a/comps/finetuning/finetune_runner.py b/comps/finetuning/finetune_runner.py
new file mode 100644
index 000000000..646341d73
--- /dev/null
+++ b/comps/finetuning/finetune_runner.py
@@ -0,0 +1,48 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import time
+import uuid
+from typing import List
+
+from llm_on_ray.finetune.finetune_config import FinetuneConfig
+from pydantic_yaml import parse_yaml_raw_as
+from ray.train.base_trainer import TrainingFailedError
+from ray.tune.callback import Callback
+from ray.tune.experiment import Trial
+from ray.tune.logger import LoggerCallback
+from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
+
+
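+# Minimal TrainerCallback wired into the fine-tuning run; for now it only prints the trainer arguments and state on each logging event.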
+class FineTuneCallback(TrainerCallback):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        print("FineTuneCallback:", args, state)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Runner for llm_on_ray-finetune")
+    parser.add_argument("--config_file", type=str, required=True, default=None)
+    args = parser.parse_args()
+    model_config_file = args.config_file
+
+    with open(model_config_file) as f:
+        finetune_config = parse_yaml_raw_as(FinetuneConfig, f).model_dump()
+
+    callback = FineTuneCallback()
+    finetune_config["Training"]["callbacks"] = [callback]
+
+    from llm_on_ray.finetune.finetune import main as llm_on_ray_finetune_main
+
+    llm_on_ray_finetune_main(finetune_config)
+    # try:
+    #     llm_on_ray_finetune_main(finetune_config)
+    # except TrainingFailedError as e:
+    #     print(e)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/comps/finetuning/finetuning.py b/comps/finetuning/finetuning.py
new file mode 100644
index 000000000..a02833c00
--- /dev/null
+++ b/comps/finetuning/finetuning.py
@@ -0,0 +1,90 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import uvicorn
+from fastapi import BackgroundTasks, Cookie, FastAPI, Form, Header, Response
+from handlers import (
+    handle_cancel_finetuning_job,
+    handle_create_finetuning_jobs,
+    handle_list_finetuning_jobs,
+    handle_retrieve_finetuning_job,
+)
+from models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+@app.post("/v1/fine_tuning/jobs", response_model=FineTuningJob)
+def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
+    return handle_create_finetuning_jobs(request, background_tasks)
+    # return {
+    #     "object": "fine_tuning.job",
+    #     "id": "ftjob-abc123",
+    #     "model": "davinci-002",
+    #     "created_at": 1692661014,
+    #     "finished_at": 1692661190,
+    #     "fine_tuned_model": "ft:davinci-002:my-org:custom_suffix:7q8mpxmy",
+    #     "organization_id": "org-123",
+    #     "result_files": ["file-abc123"],
+    #     "status": "succeeded",
+    #     "validation_file": None,
+    #     "training_file": "file-abc123",
+    #     "hyperparameters": {
+    #         "n_epochs": 4,
+    #         "batch_size": 1,
+    #         "learning_rate_multiplier": 1.0,
+    #     },
+    #     "trained_tokens": 5768,
+    #     "integrations": [],
+    #     "seed": 0,
+    #     "estimated_finish": 0,
+    # }
+
+
+@app.get("/v1/fine_tuning/jobs", response_model=FineTuningJobList)
+def list_finetuning_jobs():
+    return handle_list_finetuning_jobs()
+    # return {
+    #     "object": "list",
+    #     "data": [
+    #         {
+    #     "object": "fine_tuning.job",
+    #     "id": "ftjob-abc123",
+    #     "model": "davinci-002",
+    #     "created_at": 1692661014,
+    #     "finished_at": 1692661190,
+    #     "fine_tuned_model": "ft:davinci-002:my-org:custom_suffix:7q8mpxmy",
+    #     "organization_id": "org-123",
+    #     "result_files": ["file-abc123"],
+    #     "status": "succeeded",
+    #     "training_file": "file-abc123",
+    #     "hyperparameters": {
+    #         "n_epochs": 4,
+    #         "batch_size": 1,
+    #         "learning_rate_multiplier": 1.0,
+    #     },
+    #     "trained_tokens": 5768,
+    #     "integrations": [],
+    #     "seed": 0,
+    #     "estimated_finish": 0,
+    # },
+    #     ],
+    #     "has_more": True,
+    # }
+
+
+@app.get("/v1/fine_tuning/jobs/{fine_tuning_job_id}", response_model=FineTuningJob)
+def retrieve_finetuning_job(fine_tuning_job_id):
+    job = handle_retrieve_finetuning_job(fine_tuning_job_id)
+    return job
+
+
+@app.post("/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel", response_model=FineTuningJob)
+def cancel_finetuning_job(fine_tuning_job_id):
+    job = handle_cancel_finetuning_job(fine_tuning_job_id)
+    return job
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py
new file mode 100644
index 000000000..a874369f9
--- /dev/null
+++ b/comps/finetuning/handlers.py
@@ -0,0 +1,140 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import random
+import time
+import uuid
+from typing import Any, Dict, List, Set
+
+from fastapi import BackgroundTasks, HTTPException
+from llm_on_ray.finetune.finetune import main
+from llm_on_ray.finetune.finetune_config import FinetuneConfig
+from models import FineTuningJob, FineTuningJobEvent, FineTuningJobList, FineTuningJobsRequest
+from pydantic_yaml import parse_yaml_raw_as, to_yaml_file
+from ray.job_submission import JobSubmissionClient
+from ray.train.base_trainer import TrainingFailedError
+from ray.tune.logger import LoggerCallback
+
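+# Supported base models and their fine-tuning config files; extend this map to support additional models.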
+MODEL_CONFIG_FILE_MAP = {
+    "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml",
+    "mistralai/Mistral-7B-v0.1": "./models/mistral-7b-v0.1.yaml",
+}
+
+DATASET_BASE_PATH = "datasets"
+
+FineTuningJobID = str
+CHECK_JOB_STATUS_INTERVAL = 5  # Check every 5 secs
+
+ray_client: JobSubmissionClient = None
+
+running_finetuning_jobs: Dict[FineTuningJobID, FineTuningJob] = {}
+finetuning_job_to_ray_job: Dict[FineTuningJobID, str] = {}
+
+
+# Background task to periodically update the job status
+def update_job_status(job_id: FineTuningJobID):
+    while True:
+        job_status = ray_client.get_job_status(finetuning_job_to_ray_job[job_id])
+        status = str(job_status).lower()
+        # Ray status "stopped" is OpenAI status "cancelled"
+        status = "cancelled" if status == "stopped" else status
+        print(f"Status of job {job_id} is '{status}'")
+        running_finetuning_jobs[job_id].status = status
+        # Break once the job reaches a terminal state (Ray reports "succeeded", "stopped" or "failed")
+        if status in ("succeeded", "cancelled", "failed"):
+            break
+        time.sleep(CHECK_JOB_STATUS_INTERVAL)
+
+
+def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
+    base_model = request.model
+    train_file = request.training_file
+    train_file_path = os.path.join(DATASET_BASE_PATH, train_file)
+
+    model_config_file = MODEL_CONFIG_FILE_MAP.get(base_model)
+    if not model_config_file:
+        raise HTTPException(status_code=404, detail=f"Base model '{base_model}' not supported!")
+
+    if not os.path.exists(train_file_path):
+        raise HTTPException(status_code=404, detail=f"Training file '{train_file}' not found!")
+
+    with open(model_config_file) as f:
+        finetune_config = parse_yaml_raw_as(FinetuneConfig, f)
+
+    finetune_config.Dataset.train_file = train_file_path
+
+    job = FineTuningJob(
+        id=f"ft-job-{uuid.uuid4()}",
+        model=base_model,
+        created_at=int(time.time()),
+        training_file=train_file,
+        hyperparameters={
+            "n_epochs": finetune_config.Training.epochs,
+            "batch_size": finetune_config.Training.batch_size,
+            "learning_rate_multiplier": finetune_config.Training.learning_rate,
+        },
+        status="running",
+        # TODO: Add seed in finetune config
+        seed=random.randint(0, 1000),
+    )
+
+    finetune_config_file = f"jobs/{job.id}.yaml"
+    to_yaml_file(finetune_config_file, finetune_config)
+
+    global ray_client
+    ray_client = JobSubmissionClient() if ray_client is None else ray_client
+
+    ray_job_id = ray_client.submit_job(
+        # Entrypoint shell command to execute
+        entrypoint=f"python finetune_runner.py --config_file {finetune_config_file}",
+        # Path to the local directory that contains finetune_runner.py
+        runtime_env={"working_dir": "./"},
+    )
+    print(f"Submitted Ray job: {ray_job_id} ...")
+
+    running_finetuning_jobs[job.id] = job
+    finetuning_job_to_ray_job[job.id] = ray_job_id
+
+    background_tasks.add_task(update_job_status, job.id)
+
+    return job
+
+
+def handle_list_finetuning_jobs():
+    finetuning_jobs_list = FineTuningJobList(data=list(running_finetuning_jobs.values()), has_more=False)
+
+    return finetuning_jobs_list
+
+
+def handle_retrieve_finetuning_job(fine_tuning_job_id):
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    if job is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+    return job
+
+
+def handle_cancel_finetuning_job(fine_tuning_job_id):
+    ray_job_id = finetuning_job_to_ray_job.get(fine_tuning_job_id)
+    if ray_job_id is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+
+    global ray_client
+    ray_client = JobSubmissionClient() if ray_client is None else ray_client
+    ray_client.stop_job(ray_job_id)
+
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    job.status = "cancelled"
+    return job
+
+
+# def cancel_all_jobs():
+#     global ray_client
+#     ray_client = JobSubmissionClient() if ray_client is None else ray_client
+#     # stop all jobs
+#     for job_id in finetuning_job_to_ray_job.values():
+#         ray_client.stop_job(job_id)
+
+#     for job_id in running_finetuning_jobs:
+#         running_finetuning_jobs[job_id].status = "cancelled"
+#     return running_finetuning_jobs
diff --git a/comps/finetuning/jobs/.gitkeep b/comps/finetuning/jobs/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/comps/finetuning/llm_on_ray/common/common.py b/comps/finetuning/llm_on_ray/common/common.py
new file mode 100644
index 000000000..87f74096c
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/common/common.py
@@ -0,0 +1,38 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import glob
+import importlib
+import os
+
+from llm_on_ray.common.logging import logger
+
+
+def import_all_modules(basedir, prefix=None):
+    all_py_files = glob.glob(basedir + "/*.py")
+    modules = [os.path.basename(f) for f in all_py_files]
+
+    for module in modules:
+        if not module.startswith("_"):
+            module = module[: -len(".py")]  # strip the ".py" suffix
+            if prefix is None:
+                module_name = module
+            else:
+                module_name = f"{prefix}.{module}"
+            try:
+                importlib.import_module(module_name)
+            except Exception:
+                logger.warning(f"import {module_name} error", exc_info=True)
diff --git a/comps/finetuning/llm_on_ray/common/logging.py b/comps/finetuning/llm_on_ray/common/logging.py
new file mode 100644
index 000000000..6d3f6ae80
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/common/logging.py
@@ -0,0 +1,67 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+import logging
+import logging.config
+import traceback
+
+__all__ = ["logger", "get_logger"]
+
+use_accelerate_log = False
+logger_name = "common"
+
+logging_config = {
+    "version": 1,
+    "loggers": {
+        "root": {"level": "INFO", "handlers": ["consoleHandler"]},
+        "common": {
+            "level": "INFO",
+            "handlers": ["consoleHandler"],
+            "qualname": "common",
+            "propagate": 0,
+        },
+    },
+    "handlers": {
+        "consoleHandler": {
+            "class": "logging.StreamHandler",
+            "level": "INFO",
+            "formatter": "standardFormatter",
+        },
+    },
+    "formatters": {
+        "standardFormatter": {
+            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+            "datefmt": "",
+        }
+    },
+}
+
+if logging_config is not None:
+    try:
+        logging.config.dictConfig(logging_config)
+    except Exception:
+        traceback.print_exc()
+        exit(1)
+
+if use_accelerate_log:
+    import accelerate
+
+    get_logger = functools.partial(accelerate.logging.get_logger, name=logger_name)
+else:
+    get_logger = functools.partial(logging.getLogger, name=logger_name)
+
+logger = get_logger()
diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py
new file mode 100644
index 000000000..ada3d3275
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/finetune/finetune.py
@@ -0,0 +1,410 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#!/usr/bin/env python
+
+import argparse
+import os
+import sys
+from importlib import util
+from typing import Any, Dict, Optional, Union
+
+import ray
+import torch
+import transformers
+from llm_on_ray import common
+from llm_on_ray.finetune.finetune_config import FinetuneConfig
+from pydantic_yaml import parse_yaml_raw_as
+from ray.air import FailureConfig, RunConfig
+from ray.air.config import ScalingConfig
+from ray.train.torch import TorchTrainer
+
+use_habana = False
+if util.find_spec("habana_frameworks") is not None:
+    from optimum.habana.utils import set_seed
+
+    use_habana = True
+else:
+    from accelerate.utils import is_xpu_available, set_seed
+
+    use_habana = False
+
+
+def get_accelerate_environment_variable(config: Dict[str, Any]) -> dict:
+    device = config["Training"]["device"]
+    accelerate_mode = config["Training"]["accelerate_mode"]
+    mixed_precision = config["Training"]["mixed_precision"]
+    mode_env_vars = {
+        "cpu": {
+            "DDP": {
+                "ACCELERATE_USE_CPU": "true",
+                "ACCELERATE_USE_IPEX": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            }
+        },
+        "gpu": {
+            "DDP": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "true",
+                "ACCELERATE_USE_IPEX": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+            "FSDP": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "true",
+                "ACCELERATE_USE_IPEX": "true",
+                "ACCELERATE_USE_FSDP": "true",
+                "FSDP_SHARDING_STRATEGY": "1",
+                "FSDP_OFFLOAD_PARAMS": "false",
+                "FSDP_AUTO_WRAP_POLICY": "NO_WRAP",
+                "FSDP_BACKWARD_PREFETCH": "BACKWARD_PRE",
+                "FSDP_STATE_DICT_TYPE": "SHARDED_STATE_DICT",
+                "FSDP_FORWARD_PREFETCH": "false",
+                "FSDP_USE_ORIG_PARAMS": "false",
+                "FSDP_SYNC_MODULE_STATES": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+            "DEEPSPEED": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "true",
+                "ACCELERATE_USE_IPEX": "true",
+                "ACCELERATE_USE_DEEPSPEED": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+        },
+        "hpu": {
+            "DDP": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "false",
+                "ACCELERATE_USE_IPEX": "false",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+            "DEEPSPEED": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "false",
+                "ACCELERATE_USE_IPEX": "false",
+                "ACCELERATE_USE_DEEPSPEED": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+        },
+    }
+    if device not in mode_env_vars or accelerate_mode not in mode_env_vars[device]:
+        supported_mode_info = ", ".join(f"{k}: [{', '.join(modes)}]" for k, modes in mode_env_vars.items())
+        raise ValueError(
+            f"Device '{device}' with accelerate mode '{accelerate_mode}' is not supported. "
+            f"Supported combinations are: {supported_mode_info}"
+        )
+    return mode_env_vars[device][accelerate_mode]
+
+
+def convert_to_training_args(cls, config):
+    device = config["Training"]["device"]
+    accelerate_mode = config["Training"]["accelerate_mode"]
+    save_strategy = config["General"]["save_strategy"]
+
+    args = {
+        "output_dir": config["General"]["output_dir"],
+        "resume_from_checkpoint": config["General"]["resume_from_checkpoint"],
+        "gradient_checkpointing": config["General"]["enable_gradient_checkpointing"],
+        "save_strategy": save_strategy if save_strategy != "False" else "no",
+        "bf16": config["Training"]["mixed_precision"] == "bf16",
+        "num_train_epochs": config["Training"]["epochs"],
+        "per_device_train_batch_size": config["Training"]["batch_size"],
+        "per_device_eval_batch_size": config["Training"]["batch_size"],
+        "optim": config["Training"]["optimizer"],
+        "learning_rate": config["Training"]["learning_rate"],
+        "logging_steps": config["Training"]["logging_steps"],
+        "lr_scheduler_type": config["Training"]["lr_scheduler"],
+        "weight_decay": config["Training"]["weight_decay"],
+        "gradient_accumulation_steps": config["Training"]["gradient_accumulation_steps"],
+    }
+
+    # set attr max_steps
+    if config["Training"]["max_train_steps"] is not None:
+        args.update({"max_steps": config["Training"]["max_train_steps"]})
+
+    # set attr for device cpu
+    if device == "cpu":
+        if hasattr(cls, "use_cpu"):
+            args.update({"use_cpu": True})
+        if hasattr(cls, "no_cuda"):
+            args.update({"no_cuda": True})
+        args.update({"use_ipex": True})
+
+    # set attr 'deepspeed'
+    if accelerate_mode == "DEEPSPEED":
+        args.update({"deepspeed": config["Training"]["deepspeed_config_file"]})
+
+    # set attr for FSDP
+    # if accelerate_mode == "FSDP":
+    #     args.updatwe({})
+
+    # set attr for Intel Gaudi
+    if device == "hpu":
+        args.update({"use_habana": True})
+        args.update({"use_lazy_mode": config["Training"]["hpu_execution_mode"] == "lazy"})
+        args.update({"pipelining_fwd_bwd": True})
+
+    return cls(**args)
+
+
+def get_device_environment_variable(device):
+    if device == "hpu":
+        return {
+            "HABANA_VISIBLE_DEVICES": "all",
+            "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES": "true",
+        }
+    return {}
+
+
+def convert_dtype(dtype: str) -> Optional[torch.dtype]:
+    supported_dtypes = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "no": None,
+    }
+    return supported_dtypes[dtype]
+
+
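+# Runs on every Ray Train worker; `config` is the train_loop_config passed to TorchTrainer in main().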
+def train_func(config: Dict[str, Any]):
+    os.chdir(config["cwd"])
+
+    device = config["Training"]["device"]
+
+    base_model = config["General"]["base_model"]
+    if config["General"].get("tokenizer_name") is not None:
+        tokenizer_name = config["General"].get("tokenizer_name")
+    else:
+        tokenizer_name = base_model
+    dataset_file = config["Dataset"]["train_file"]
+
+    seed = config["Training"].get("seed")
+    if seed is not None:
+        set_seed(seed)
+
+    tokenizer = common.tokenizer.Tokenizer.registory.get("HuggingFaceTokenizer")()(
+        config={
+            "name": tokenizer_name,
+            "config": config["General"]["config"],
+        }
+    )
+
+    datasets = common.dataset.Dataset.registory.get("HuggingfaceDataset")()(
+        config={
+            "name": dataset_file,
+            "validation_file": config["Dataset"]["validation_file"],
+            "validation_split_percentage": config["Dataset"]["validation_split_percentage"],
+        }
+    )
+
+    dataprocesser = common.dataprocesser.DataProcesser.registory.get("GeneralProcesser")(
+        config={
+            "per_device_train_batch_size": config["Training"]["batch_size"],
+            "per_device_eval_batch_size": config["Training"]["batch_size"],
+            "preprocessing_num_workers": config["Dataset"].get("preprocessing_num_workers", 1),
+            "max_length": config["Dataset"].get("max_length", 512),
+            "group": config["Dataset"].get("group", True),
+            "block_size": config["Dataset"].get("block_size", 512),
+            "shuffle": config["Dataset"].get("shuffle", False),
+        }
+    )
+    tokenized_datasets = dataprocesser.tokenize_dataset(tokenizer, datasets)
+
+    model = common.model.Model.registory.get("HuggingFaceModelForCausalLM")()(
+        config={
+            "name": base_model,
+            "dtype": convert_dtype(config["Training"].get("mixed_precision", "no")),
+            "device": torch.device(device),
+            "config": config["General"]["config"],
+            "enable_gradient_checkpointing": config["General"].get("enable_gradient_checkpointing", False),
+            "lora_config": config["General"].get("lora_config", None),
+        }
+    )
+
+    data_collator = common.dataprocesser.general_processer.DataCollatorForCompletionOnlyLM(
+        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
+    )
+
+    if device in ["cpu", "gpu"]:
+        from transformers import Trainer, TrainingArguments
+
+        training_args = convert_to_training_args(TrainingArguments, config)
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=tokenized_datasets["train"],
+            eval_dataset=tokenized_datasets["validation"] if tokenized_datasets.get("validation") is not None else None,
+            tokenizer=tokenizer,
+            data_collator=data_collator,
+        )
+
+        common.logger.info("train start")
+        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        common.logger.info("train finish")
+    elif device in ["hpu"]:
+        from optimum.habana import GaudiConfig
+        from optimum.habana.transformers import GaudiTrainer, GaudiTrainingArguments
+
+        # If gaudi_config_name is provided, load the GaudiConfig from the Hugging Face Hub (https://huggingface.co/Habana); otherwise use a default GaudiConfig
+        if config["General"].get("gaudi_config_name") is not None:
+            gaudi_config = GaudiConfig.from_pretrained(
+                config["General"].get("gaudi_config_name"),
+            )
+        else:
+            gaudi_config = GaudiConfig()
+            gaudi_config.use_fused_adam = True
+            gaudi_config.use_fused_clip_norm = True
+        training_args = convert_to_training_args(GaudiTrainingArguments, config)
+        trainer = GaudiTrainer(
+            model=model,
+            args=training_args,
+            gaudi_config=gaudi_config,
+            train_dataset=tokenized_datasets["train"],
+            eval_dataset=tokenized_datasets["validation"] if tokenized_datasets.get("validation") is not None else None,
+            tokenizer=tokenizer,
+            data_collator=data_collator,
+        )
+
+        common.logger.info("train start")
+        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        common.logger.info("train finish")
+
+
+def get_finetune_config():
+    parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
+    parser.add_argument(
+        "--config_file",
+        type=str,
+        required=True,
+        default=None,
+        help="The name of the dataset to use (via the datasets library).",
+    )
+
+    # Print help if no arguments were provided
+    if len(sys.argv) == 1:
+        parser.print_help(sys.stderr)
+        sys.exit(1)
+
+    args = parser.parse_args()
+    config_file = args.config_file
+
+    with open(config_file) as f:
+        finetune_config = parse_yaml_raw_as(FinetuneConfig, f)
+    return finetune_config.model_dump()
+
+
+def main(external_config=None):
+    if not external_config:
+        config = get_finetune_config()
+    else:
+        config = external_config
+
+    config["cwd"] = os.getcwd()
+
+    num_training_workers = config["Training"].get("num_training_workers")
+    resources_per_worker = config["Training"].get("resources_per_worker")
+
+    if config["Training"].get("accelerate_mode", None) is None:
+        config["Training"]["accelerate_mode"] = "DDP"  # will use DDP to accelerate if no method specified
+
+    ccl_worker_count = 1
+    device = config["Training"]["device"]
+    if device != "cpu":
+        ccl_worker_count = num_training_workers
+
+    if not ray.is_initialized():
+        runtime_env = {
+            "env_vars": {
+                "OMP_NUM_THREADS": str(resources_per_worker["CPU"]),
+                "CCL_ZE_IPC_EXCHANGE": "sockets",
+                "CCL_WORKER_COUNT": str(ccl_worker_count),
+                "CCL_LOG_LEVEL": "info",
+                "WORLD_SIZE": str(num_training_workers),
+                "FI_PROVIDER": "tcp",
+            }
+        }
+
+        if config["General"]["gpt_base_model"] is True:
+            runtime_env["pip"] = ["transformers==4.26.0"]
+
+        if device == "gpu":
+            num_cpus = resources_per_worker["CPU"] * num_training_workers + 1  # additional 1 for head worker
+            ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
+        else:
+            ray.init(runtime_env=runtime_env)
+
+    common.logger.info(f"ray available resources = {ray.available_resources()}")
+    use_gpu = device == "gpu"
+    scaling_config = ScalingConfig(
+        num_workers=num_training_workers,
+        use_gpu=use_gpu,
+        resources_per_worker=resources_per_worker,
+        placement_strategy="SPREAD",
+    )
+
+    # If an Intel GPU is requested, convert the device name to 'xpu',
+    # since accelerate uses 'xpu' to refer to Intel GPUs internally
+    if device == "gpu" and is_xpu_available():
+        device = "xpu"
+
+    if config.get("torch_config", None) is None:
+        backend = None
+        if device == "cpu" or device == "xpu" or device == "gpu":
+            backend = "ccl"
+        elif device == "hpu":
+            backend = "hccl"
+        torch_config = common.TorchConfig(backend=backend, device=device)
+    else:
+        customer_torch_config = config.get("torch_config")
+        torch_config = common.TorchConfig(**customer_torch_config, device=device)
+
+    if config.get("failure_config", None) is None:
+        failure_config = FailureConfig()
+    else:
+        customer_failure_config = config.get("failure_config")
+        failure_config = FailureConfig(**customer_failure_config)
+
+    if config.get("run_config", None) is None:
+        run_config = RunConfig(failure_config=failure_config)
+    else:
+        customer_run_config = config.get("run_config")
+        if customer_run_config.get("failure_config", None) is None:
+            customer_run_config["failure_config"] = failure_config
+        run_config = RunConfig(**customer_run_config)
+
+    trainer = TorchTrainer(
+        train_func,
+        train_loop_config=config,
+        scaling_config=scaling_config,
+        torch_config=torch_config,
+        run_config=run_config,
+    )
+    results = trainer.fit()
+    if external_config is not None:
+        return results
+
+
+if __name__ == "__main__":
+    main()
diff --git a/comps/finetuning/llm_on_ray/finetune/finetune_config.py b/comps/finetuning/llm_on_ray/finetune/finetune_config.py
new file mode 100644
index 000000000..91d2fa0ca
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/finetune/finetune_config.py
@@ -0,0 +1,157 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Optional
+
+from pydantic import BaseModel, validator
+
+PRECISION_BF16 = "bf16"
+PRECISION_FP16 = "fp16"
+PRECISION_NO = "no"
+
+DEVICE_CPU = "cpu"
+DEVICE_HPU = "hpu"
+DEVICE_GPU = "gpu"
+
+ACCELERATE_STRATEGY_DDP = "DDP"
+ACCELERATE_STRATEGY_FSDP = "FSDP"
+ACCELERATE_STRATEGY_DEEPSPEED = "DEEPSPEED"
+
+
+class GeneralConfig(BaseModel):
+    trust_remote_code: bool
+    use_auth_token: Optional[str]
+
+
+class LoraConfig(BaseModel):
+    task_type: str
+    r: int
+    lora_alpha: int
+    lora_dropout: float
+    target_modules: Optional[List[str]] = None
+
+
+class DeltatunerConfig(BaseModel):
+    algo: str
+    denas: bool
+    best_model_structure: str
+
+
+class General(BaseModel):
+    base_model: str
+    tokenizer_name: Optional[str] = None
+    gaudi_config_name: Optional[str] = None
+    gpt_base_model: bool
+    output_dir: str
+    resume_from_checkpoint: Optional[str] = None
+    save_strategy: str = "no"
+    config: GeneralConfig
+    lora_config: Optional[LoraConfig] = None
+    deltatuner_config: Optional[DeltatunerConfig] = None
+    enable_gradient_checkpointing: bool = False
+
+
+class Dataset(BaseModel):
+    train_file: str
+    validation_file: Optional[str]
+    validation_split_percentage: int
+    max_length: int = 512
+    group: bool = True
+    block_size: int = 512
+    shuffle: bool = False
+
+
+class RayResourceConfig(BaseModel):
+    CPU: int
+    GPU: int = 0
+    HPU: int = 0
+
+
+class Training(BaseModel):
+    optimizer: str
+    batch_size: int
+    epochs: int
+    max_train_steps: Optional[int] = None
+    learning_rate: float
+    lr_scheduler: str
+    weight_decay: float
+    device: str = DEVICE_CPU
+    hpu_execution_mode: str = "lazy"
+    num_training_workers: int
+    resources_per_worker: RayResourceConfig
+    accelerate_mode: str = ACCELERATE_STRATEGY_DDP
+    mixed_precision: str = PRECISION_NO
+    gradient_accumulation_steps: int = 1
+    logging_steps: int = 10
+    deepspeed_config_file: str = ""
+
+    @validator("device")
+    def check_device(cls, v: str):
+        # will convert to lower case
+        if v:
+            assert v.lower() in [DEVICE_CPU, DEVICE_GPU, DEVICE_HPU]
+        return v.lower()
+
+    @validator("hpu_execution_mode")
+    def check_hpu_execution_mode(cls, v: str):
+        if v:
+            assert v in ["lazy", "eager", "eager.compile"]
+        return v
+
+    @validator("accelerate_mode")
+    def check_accelerate_mode(cls, v: str):
+        if v:
+            assert v in [
+                ACCELERATE_STRATEGY_DDP,
+                ACCELERATE_STRATEGY_FSDP,
+                ACCELERATE_STRATEGY_DEEPSPEED,
+            ]
+        return v
+
+    @validator("mixed_precision")
+    def check_mixed_precision(cls, v: str):
+        if v:
+            assert v in [PRECISION_BF16, PRECISION_FP16, PRECISION_NO]
+        return v
+
+    @validator("logging_steps")
+    def check_logging_steps(cls, v: int):
+        assert v > 0
+        return v
+
+    # @model_validator(mode='after')
+    # def check_device_and_accelerate_mode(self) -> "Training":
+    #     dev = self.device
+    #     res = self.resources_per_worker
+    #     mode = self.accelerate_mode
+    #     if dev == "CPU":
+    #         if res.GPU is not None and res.GPU > 0:
+    #             raise ValueError("Please not specified GPU resource when use CPU only in Ray.")
+    #         if mode != "CPU_DDP":
+    #             raise ValueError("Please specified CPU related accelerate mode when use CPU only in Ray.")
+    #     elif dev == "GPU":
+    #         if res.GPU is None or res.GPU == 0:
+    #             raise ValueError("Please specified GPU resource when use GPU to fine tune in Ray.")
+    #         if mode not in ["GPU_DDP", "GPU_FSDP"]:
+    #             raise ValueError("Please speicifed GPU related accelerate mode when use GPU to fine tune in Ray.")
+
+    #     return self
+
+
+class FinetuneConfig(BaseModel):
+    General: General
+    Dataset: Dataset
+    Training: Training
diff --git a/comps/finetuning/models.py b/comps/finetuning/models.py
new file mode 100644
index 000000000..f6757364d
--- /dev/null
+++ b/comps/finetuning/models.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from datetime import datetime
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+
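+# Pydantic schemas mirroring the OpenAI fine-tuning API request/response objects.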
+class FineTuningJobsRequest(BaseModel):
+    training_file: str
+    model: str
+
+
+class Hyperparameters(BaseModel):
+    n_epochs: int
+    batch_size: int
+    learning_rate_multiplier: float
+
+
+class FineTuningJob(BaseModel):
+    object: str = "fine_tuning.job"  # Set as constant
+    id: str
+    model: str
+    created_at: int
+    finished_at: Optional[int] = None
+    fine_tuned_model: Optional[str] = None
+    organization_id: Optional[str] = None
+    result_files: Optional[List[str]] = None
+    status: str
+    validation_file: Optional[str] = None
+    training_file: str
+    hyperparameters: Hyperparameters
+    trained_tokens: Optional[int] = None
+    integrations: List[str] = []  # Empty list by default
+    seed: int
+    estimated_finish: int = 0  # Set default value to 0
+
+
+class FineTuningJobList(BaseModel):
+    object: str = "list"  # Set as constant
+    data: List[FineTuningJob]
+    has_more: bool
+
+
+class FineTuningJobEvent(BaseModel):
+    object: str = "fine_tuning.job.event"  # Set as constant
+    id: str
+    created_at: int
+    level: str
+    message: str
+    data: None = None  # No data expected for this event type, set to None
+    type: str = "message"  # Default event type is "message"
diff --git a/comps/finetuning/models/llama-2-7b-chat-hf.yaml b/comps/finetuning/models/llama-2-7b-chat-hf.yaml
new file mode 100644
index 000000000..ab62383d2
--- /dev/null
+++ b/comps/finetuning/models/llama-2-7b-chat-hf.yaml
@@ -0,0 +1,40 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+General:
+  base_model: meta-llama/Llama-2-7b-chat-hf
+  gpt_base_model: false
+  output_dir: /tmp/llm-ray/output
+  save_strategy: "no"
+  config:
+    trust_remote_code: false
+    use_auth_token: null
+  lora_config:
+    task_type: CAUSAL_LM
+    r: 8
+    lora_alpha: 32
+    lora_dropout: 0.1
+    target_modules:
+      - q_proj
+      - v_proj
+  enable_gradient_checkpointing: false
+Dataset:
+  train_file: examples/data/sample_finetune_data_small.jsonl
+  group: false
+  validation_file: null
+  validation_split_percentage: 5
+Training:
+  optimizer: adamw_torch
+  batch_size: 2
+  epochs: 3
+  learning_rate: 1.0e-05
+  lr_scheduler: linear
+  weight_decay: 0.0
+  mixed_precision: bf16
+  device: cpu
+  num_training_workers: 2
+  resources_per_worker:
+    CPU: 32
+  accelerate_mode: DDP
+  gradient_accumulation_steps: 1
+  logging_steps: 10
diff --git a/comps/finetuning/models/mistral-7b-v0.1.yaml b/comps/finetuning/models/mistral-7b-v0.1.yaml
new file mode 100644
index 000000000..29d05de93
--- /dev/null
+++ b/comps/finetuning/models/mistral-7b-v0.1.yaml
@@ -0,0 +1,45 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+General:
+  base_model: mistralai/Mistral-7B-v0.1
+  gpt_base_model: false
+  output_dir: /tmp/llm-ray/output
+  save_strategy: "no"
+  config:
+    trust_remote_code: false
+    use_auth_token: null
+  lora_config:
+    task_type: CAUSAL_LM
+    r: 8
+    lora_alpha: 32
+    lora_dropout: 0.1
+    target_modules:
+      - q_proj
+      - k_proj
+      - v_proj
+      - o_proj
+      - gate_proj
+      - up_proj
+      - down_proj
+      - lm_head
+  enable_gradient_checkpointing: false
+Dataset:
+  train_file: examples/data/sample_finetune_data_small.jsonl
+  validation_file: null
+  validation_split_percentage: 5
+Training:
+  optimizer: adamw_torch
+  batch_size: 2
+  epochs: 3
+  learning_rate: 1.0e-05
+  lr_scheduler: linear
+  weight_decay: 0.0
+  mixed_precision: bf16
+  device: cpu
+  num_training_workers: 2
+  resources_per_worker:
+    CPU: 32
+  accelerate_mode: DDP
+  gradient_accumulation_steps: 1
+  logging_steps: 10
diff --git a/comps/finetuning/requirements.txt b/comps/finetuning/requirements.txt
new file mode 100644
index 000000000..cefb56399
--- /dev/null
+++ b/comps/finetuning/requirements.txt
@@ -0,0 +1,3 @@
+fastapi
+pydantic
+pydantic_yaml
+ray[default]
+uvicorn