Skip to content

Commit

Permalink
[OpenCL] Reduced ifdef checks
Browse files Browse the repository at this point in the history
Reduced and clubbed some ifdef checks

Signed-off-by: Debadri Samaddar <[email protected]>
  • Loading branch information
s-debadri committed Mar 6, 2024
1 parent 9ef8f3a commit 61dedf3
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 34 deletions.
30 changes: 15 additions & 15 deletions nntrainer/layers/layer_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -813,21 +813,6 @@ class RunLayerContext {
opencl::ContextManager &context_inst_ = opencl::ContextManager::GetInstance();
opencl::Kernel kernel_;

/**
* @brief set the compute engine for this node
* @param compute engine: (CPU/GPU)
*/
void setComputeEngine(const ml::train::LayerComputeEngine &compute_engine =
ml::train::LayerComputeEngine::CPU) {
this->compute_engine = compute_engine;
}

/**
* @brief get the compute engine for this node
* @return ompute engine: (CPU/GPU)
*/
ml::train::LayerComputeEngine getComputeEngine() { return compute_engine; }

/**
* @brief create OpenCl kernel
* @param kernel implementation string
Expand All @@ -846,6 +831,21 @@ class RunLayerContext {
}
#endif

/**
* @brief set the compute engine for this node
* @param compute engine: (CPU/GPU)
*/
void setComputeEngine(const ml::train::LayerComputeEngine &compute_engine =
ml::train::LayerComputeEngine::CPU) {
this->compute_engine = compute_engine;
}

/**
* @brief get the compute engine for this node
* @return ompute engine: (CPU/GPU)
*/
ml::train::LayerComputeEngine getComputeEngine() { return compute_engine; }

private:
std::tuple<props::Name, props::Trainable> props; /**< props of the layer */
float loss; /**< loss of the layer */
Expand Down
5 changes: 1 addition & 4 deletions nntrainer/layers/layer_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,10 @@ createLayerNode(std::unique_ptr<nntrainer::Layer> &&layer,
auto lnode = std::make_unique<LayerNode>(std::move(layer));

lnode->setProperty(properties);
#ifdef ENABLE_OPENCL

if (compute_engine == ml::train::LayerComputeEngine::GPU) {
lnode->setComputeEngine(compute_engine);
}
#endif

return lnode;
}
Expand Down Expand Up @@ -267,12 +266,10 @@ void LayerNode::setOutputConnection(unsigned nth, const std::string &name,
con = std::make_unique<Connection>(name, index);
}

#ifdef ENABLE_OPENCL
void LayerNode::setComputeEngine(
const ml::train::LayerComputeEngine &compute_engine) {
run_context->setComputeEngine(compute_engine);
}
#endif

const std::string LayerNode::getName() const noexcept {
auto &name = std::get<props::Name>(*layer_node_props);
Expand Down
2 changes: 0 additions & 2 deletions nntrainer/layers/layer_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,12 @@ class LayerNode final : public ml::train::Layer, public GraphNode {
void setOutputConnection(unsigned nth, const std::string &name,
unsigned index);

#ifdef ENABLE_OPENCL
/**
* @brief set the compute engine for this node
* @param compute engine (CPU/GPU)
*/
void setComputeEngine(const ml::train::LayerComputeEngine &compute_engine =
ml::train::LayerComputeEngine::CPU);
#endif

/**
* @brief Get the input connections for this node
Expand Down
26 changes: 13 additions & 13 deletions nntrainer/tensor/cl_operations/cl_sgemv.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ class GpuCLSgemv : public nntrainer::opencl::GpuCLOpInterface {
})";

public:
/**
* @brief Function to set buffers and kernel arguments for SGEMV
*
* @tparam T
* @param matAdata
* @param vecXdata
* @param vecYdata
* @param alpha
* @param beta
* @param dim1
* @param dim2
* @return T*
*/
/**
* @brief Function to set buffers and kernel arguments for SGEMV
*
* @tparam T
* @param matAdata
* @param vecXdata
* @param vecYdata
* @param alpha
* @param beta
* @param dim1
* @param dim2
* @return T*
*/
template <typename T>
T *cLSgemv(const T *matAdata, const T *vecXdata, T *vecYdata, T alpha, T beta,
unsigned int dim1, unsigned int dim2);
Expand Down

0 comments on commit 61dedf3

Please sign in to comment.