Versions Compared

Key

  • This line was added.
  • This line was removed.
  • Formatting was changed.

...

A detailed design document on this topic is being written separately (see "Thread-safety in MXNet").

  • Informative errors

  • Documentation and versioning of the API

...

 

// MultiplyOperation.h
//
// Public header. Clients see only forward declarations of Tensor and of the
// private implementation class, so the backend behind MultiplyOperation can
// change without changing this header or the object layout clients compile
// against.

#include <memory>

class Tensor;
class MultiplyOperationPrivate;

/// Multiplies two tensors A and B. The operands and the backend-specific
/// computation live in a MultiplyOperationPrivate object owned through `p`
/// (the "pointer to implementation" idiom).
///
/// The class is move-only: std::unique_ptr is not copyable, so no copy
/// operations are generated. A moved-from object owns no implementation and
/// may only be assigned to or destroyed.
class MultiplyOperation {
public:
  MultiplyOperation();

  // Declared here but defined in MultiplyOperation.cpp: destroying (or
  // move-assigning over) a std::unique_ptr<MultiplyOperationPrivate> requires
  // MultiplyOperationPrivate to be a complete type, and it is only complete
  // in the .cpp file. Letting the compiler generate these inline in this
  // header would fail to compile in every client translation unit.
  ~MultiplyOperation();
  MultiplyOperation(MultiplyOperation &&other) noexcept;
  MultiplyOperation &operator=(MultiplyOperation &&other) noexcept;

  /// Stores a copy of the left-hand operand.
  void setA(const Tensor &a);
  /// Stores a copy of the right-hand operand.
  void setB(const Tensor &b);

  /// Returns the product of the operands previously passed to setA()/setB(),
  /// computed by whichever backend getResultImpl() was built with.
  Tensor getResult() const;

private:
  std::unique_ptr<MultiplyOperationPrivate> p;
};

// MultiplyOperation.cpp

#include "MultiplyOperation.h"

// NOTE(review): Tensor's full definition must be visible from here on (its
// header is not shown in this example).

// The private class must be fully defined before any member function
// dereferences `p`. In a real build it would presumably live in a private
// header (e.g. MultiplyOperationPrivate.h) so that the backend files
// MultiplyOperation_CUDA.cpp and MultiplyOperation_MKL.cpp can include it.
class MultiplyOperationPrivate {
public:
  Tensor a;
  Tensor b;

  // Defined in a backend-specific source file (CUDA, MKL, ...); presumably
  // exactly one of those files is compiled into a given build.
  Tensor getResultImpl() const;
};

// Allocate the implementation up front so every member function can rely on
// `p` being non-null. The pointer is owned by unique_ptr immediately, so this
// is exception-safe (std::make_unique would require C++14).
MultiplyOperation::MultiplyOperation() : p(new MultiplyOperationPrivate()) {}

// Defaulted here, where MultiplyOperationPrivate is complete.
MultiplyOperation::~MultiplyOperation() = default;
MultiplyOperation::MultiplyOperation(MultiplyOperation &&other) noexcept = default;
MultiplyOperation &
MultiplyOperation::operator=(MultiplyOperation &&other) noexcept = default;

void MultiplyOperation::setA(const Tensor &a) {
  p->a = a;
}

void MultiplyOperation::setB(const Tensor &b) {
  p->b = b;
}

Tensor MultiplyOperation::getResult() const {
  return p->getResultImpl();
}

// MultiplyOperation_CUDA.cpp

// CUDA backend of MultiplyOperationPrivate::getResultImpl: returns the
// product of the stored operands `a` and `b`, computed by `multiply` and
// wrapped in CUDA_CALL.
// NOTE(review): if CUDA_CALL is MXNet's error-checking macro (a braced
// statement block that checks a cudaError_t), it is not an expression and
// cannot appear in a `return`. Confirm the intended signature of `multiply`
// and restructure if needed, e.g. write into an output Tensor and then return it.
// NOTE(review): this file also needs MultiplyOperationPrivate's (and
// Tensor's) definition in scope, presumably via a private header not shown.
Tensor MultiplyOperationPrivate::getResultImpl() const {
  return CUDA_CALL(multiply(a, b));
}

// MultiplyOperation_MKL.cpp

// MKL backend of MultiplyOperationPrivate::getResultImpl: delegates the
// multiplication of the stored operands to mkl_multiply and returns its result.
// NOTE(review): this file needs MultiplyOperationPrivate's (and Tensor's)
// definition in scope, presumably via a private header not shown here.
Tensor MultiplyOperationPrivate::getResultImpl() const {
  const Tensor &lhs = this->a;
  const Tensor &rhs = this->b;
  return mkl_multiply(lhs, rhs);
}

 

Semantically versioned API with no additional dependencies

...