Skip to content

Commit

Permalink
Update for xla 0.6.0.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Nov 30, 2023
1 parent 5db7c29 commit 0dccbab
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
Experimentation using the xla compiler from rust

Pre-compiled binaries for the xla library can be downloaded from the
[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.5.1).
[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.6.0).
These should be extracted at the root of this repository, resulting
in a `xla_extension` subdirectory being created, the currently supported version
is 0.5.1.
is 0.6.0.

For a linux platform, this can be done via:
```bash
wget https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget https://github.com/elixir-nx/xla/releases/download/v0.6.0/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
```

Expand Down
2 changes: 2 additions & 0 deletions src/wrappers/pjrt_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ impl PjRtClient {
Ok(Self(Rc::new(PjRtClientInternal(ptr))))
}

/*
/// A TPU client.
pub fn tpu(max_inflight_computations: usize) -> Result<Self> {
let mut ptr: c_lib::pjrt_client = std::ptr::null_mut();
Expand All @@ -39,6 +40,7 @@ impl PjRtClient {
super::handle_status(status)?;
Ok(Self(Rc::new(PjRtClientInternal(ptr))))
}
*/

fn ptr(&self) -> c_lib::pjrt_client {
self.0 .0
Expand Down
2 changes: 2 additions & 0 deletions xla_rs/xla_rs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ status pjrt_gpu_client_create(pjrt_client *output, double memory_fraction,
return nullptr;
}

/*
status pjrt_tpu_client_create(pjrt_client *output,
int max_inflight_computations) {
ASSIGN_OR_RETURN_STATUS(client, xla::GetTpuClient(max_inflight_computations));
*output = new std::shared_ptr(std::move(client));
return nullptr;
}
*/

int pjrt_client_device_count(pjrt_client c) { return (*c)->device_count(); }

Expand Down
3 changes: 1 addition & 2 deletions xla_rs/xla_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/pjrt/tpu_client.h"
#include "xla/service/hlo_parser.h"
#include "xla/shape_util.h"
#include "xla/statusor.h"
Expand Down Expand Up @@ -52,7 +51,7 @@ typedef struct _hlo_module_proto *hlo_module_proto;

status pjrt_cpu_client_create(pjrt_client *);
status pjrt_gpu_client_create(pjrt_client *, double, bool);
status pjrt_tpu_client_create(pjrt_client *, int);
// status pjrt_tpu_client_create(pjrt_client *, int);
void pjrt_client_free(pjrt_client);
int pjrt_client_device_count(pjrt_client);
int pjrt_client_addressable_device_count(pjrt_client);
Expand Down

0 comments on commit 0dccbab

Please sign in to comment.