Updated on 28 Aug, 2025

The Template Method Design Pattern is a behavioral design pattern that defines the skeleton of an algorithm in a base class but allows subclasses to override specific steps of the algorithm without changing its structure.

Intent

  • Define the framework of an algorithm in a base class.
  • Allow subclasses to provide specific implementations for one or more steps of the algorithm.

Key Concepts

  1. Base Class (Template):
    • Contains the template method that defines the sequence of steps in an algorithm.
    • Implements some of the steps directly and defines placeholders (abstract or virtual methods) for steps that need to be customized.
  2. Concrete Subclasses:
    • Implement the abstract or virtual steps of the algorithm.
    • Can override certain steps while adhering to the overall structure.
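
The two roles above can be illustrated with a minimal sketch. The names `Report`, `SalesReport`, and `AuditReport` are hypothetical and chosen only for illustration. Each step returns a string so the assembled result is easy to inspect.

```cpp
#include <string>

// Base class (template): owns the fixed sequence of steps.
class Report {  // hypothetical name, for illustration only
public:
    virtual ~Report() = default;

    // Template method: non-virtual, so subclasses cannot reorder the steps.
    std::string build() const {
        return header() + body() + footer();
    }

protected:
    // Step implemented directly in the base class.
    std::string header() const { return "[header]\n"; }

    // Abstract step: every concrete subclass must supply it.
    virtual std::string body() const = 0;

    // Virtual step with a default: subclasses may override it.
    virtual std::string footer() const { return "[footer]\n"; }
};

// Concrete subclass: implements the required step, keeps the default footer.
class SalesReport : public Report {
protected:
    std::string body() const override { return "sales figures\n"; }
};

// Concrete subclass: implements the required step and overrides the optional one.
class AuditReport : public Report {
protected:
    std::string body() const override { return "audit findings\n"; }
    std::string footer() const override { return "[confidential]\n"; }
};
```

Calling `SalesReport().build()` assembles the header, the sales body, and the default footer in that order, while `AuditReport().build()` swaps in its own footer without touching the sequence.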

When to Use

  • When multiple classes share the same overall behavior but differ in specific steps.
  • To enforce a common structure for algorithms while allowing flexibility in implementation.
  • When you want to promote code reuse and avoid duplicating common logic.

Problem (Without Template Pattern)

Suppose you are building a system in which different types of data parsers (CSV, XML, JSON) each need to:

  1. Open a file
  2. Read data
  3. Parse data
  4. Close the file

Without a template method, each parser repeats the same open and close logic, so the code is duplicated:

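#include <iostream>
using namespace std;

// Each parser hard-codes the full sequence, including the shared steps
class CSVParser {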
public:
    void parseFile() {
        cout << "Opening file\n";
        cout << "Reading CSV data\n";
        cout << "Parsing CSV data\n";
        cout << "Closing file\n";
    }
};

class XMLParser {
public:
    void parseFile() {
        cout << "Opening file\n";
        cout << "Reading XML data\n";
        cout << "Parsing XML data\n";
        cout << "Closing file\n";
    }
};

Problems:

  • The open and close steps are duplicated in every parser.
  • Maintenance is harder: if the open or close logic changes, every parser must be updated.

Solution (With Template Method Pattern)

We put the common steps in a base class and allow subclasses to define the variable steps.

#include <iostream>
using namespace std;

// Abstract base class
class DataParser {
public:
    // Template method: defines the skeleton and fixes the order of the steps
    void parseFile() {
        openFile();    // common step
        readData();    // step left to subclasses
        parseData();   // step left to subclasses
        closeFile();   // common step
    }

    virtual ~DataParser() = default;

protected:
    void openFile() { cout << "Opening file\n"; }
    void closeFile() { cout << "Closing file\n"; }

    virtual void readData() = 0;   // abstract step
    virtual void parseData() = 0;  // abstract step
};

// Concrete Class: CSV Parser
class CSVParser : public DataParser {
protected:
    void readData() override { cout << "Reading CSV data\n"; }
    void parseData() override { cout << "Parsing CSV data\n"; }
};

// Concrete Class: XML Parser
class XMLParser : public DataParser {
protected:
    void readData() override { cout << "Reading XML data\n"; }
    void parseData() override { cout << "Parsing XML data\n"; }
};

// Client
int main() {
    DataParser* parser1 = new CSVParser();
    parser1->parseFile();

    cout << "----\n";

    DataParser* parser2 = new XMLParser();
    parser2->parseFile();

    delete parser1;
    delete parser2;
}

Output:

Opening file
Reading CSV data
Parsing CSV data
Closing file
----
Opening file
Reading XML data
Parsing XML data
Closing file
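
The problem statement also mentions JSON, and adding a third parser only requires the two variable steps. The following is a sketch under stated assumptions, not the article's own code: `JSONParser`, the `std::ostream&` constructor parameter, and the `jsonDemo` helper are all introduced here for illustration, so that the output can be captured as a string instead of going straight to `cout`.

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Same skeleton as DataParser above, except that output goes to a
// caller-supplied stream (an assumption of this sketch; default is std::cout).
class DataParser {
public:
    explicit DataParser(std::ostream& out = std::cout) : out_(out) {}
    virtual ~DataParser() = default;

    // Template method: the order of the four steps is fixed here.
    void parseFile() {
        openFile();
        readData();
        parseData();
        closeFile();
    }

protected:
    void openFile()  { out_ << "Opening file\n"; }
    void closeFile() { out_ << "Closing file\n"; }

    virtual void readData() = 0;
    virtual void parseData() = 0;

    std::ostream& out_;  // where each step writes its message
};

// Hypothetical third parser: only the two variable steps are written.
class JSONParser : public DataParser {
public:
    using DataParser::DataParser;  // reuse the base-class constructor

protected:
    void readData() override  { out_ << "Reading JSON data\n"; }
    void parseData() override { out_ << "Parsing JSON data\n"; }
};

// Usage sketch: run the parser against an in-memory stream and return the text.
std::string jsonDemo() {
    std::ostringstream out;
    JSONParser parser(out);
    parser.parseFile();
    return out.str();
}
```

The open and close steps are not repeated in `JSONParser`, and the base class is left unchanged apart from the stream parameter, which is the point of the pattern.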

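Example: Model Training Pipeline

The same pattern applies to a machine learning workflow: loading and preprocessing the data are shared steps, while training and evaluation differ per model.
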
#include <iostream>
#include <string>

using namespace std;

// ───────────────────────────────────────────────────────────
// 1. Base class defining the template method
// ───────────────────────────────────────────────────────────
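class ModelTrainer {
public:
    // Virtual destructor so derived trainers can be destroyed through a base pointer
    virtual ~ModelTrainer() = default;

    // The template method: non-virtual, so subclasses can't change the sequence
    void trainPipeline(const string& dataPath) {
        loadData(dataPath);
        preprocessData();  // default provided, may be overridden
        trainModel();      // subclass-specific
        evaluateModel();   // subclass-specific
        saveModel();       // default provided, may be overridden
    }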

protected:
    void loadData(const string& path) {
        cout << "[Common] Loading dataset from " << path << "\n";
        // e.g., read CSV, images, etc.
    }

    virtual void preprocessData() {
        cout << "[Common] Splitting into train/test and normalizing\n";
    }

    virtual void trainModel() = 0;
    virtual void evaluateModel() = 0;

    // Provide a default save, but subclasses can override if needed
    virtual void saveModel() {
        cout << "[Common] Saving model to disk as default format\n";
    }
};

// ───────────────────────────────────────────────────────────
// 2. Concrete subclass: Neural Network
// ───────────────────────────────────────────────────────────
class NeuralNetworkTrainer : public ModelTrainer {
protected:
    void trainModel() override {
        cout << "[NeuralNet] Training Neural Network for 100 epochs\n";
        // pseudo-code: forward/backward passes, gradient descent...
    }
    void evaluateModel() override {
        cout << "[NeuralNet] Evaluating accuracy and loss on validation set\n";
    }
    void saveModel() override {
        cout << "[NeuralNet] Serializing network weights to .h5 file\n";
    }
};

// ───────────────────────────────────────────────────────────
// 3. Concrete subclass: Decision Tree
// ───────────────────────────────────────────────────────────
class DecisionTreeTrainer : public ModelTrainer {
protected:
    // Use the default preprocessData() (train/test split + normalize)

    void trainModel() override {
        cout << "[DecisionTree] Building decision tree with max_depth=5\n";
        // pseudo-code: recursive splitting on features...
    }
    void evaluateModel() override {
        cout << "[DecisionTree] Computing classification report (precision/recall)\n";
    }
    // use the default saveModel()
};

// ───────────────────────────────────────────────────────────
// 4. Usage
// ───────────────────────────────────────────────────────────
int main() {
    cout << "=== Neural Network Training ===\n";
    NeuralNetworkTrainer nnTrainer;   // automatic storage, so nothing is leaked
    nnTrainer.trainPipeline("data/images/");

    cout << "\n=== Decision Tree Training ===\n";
    DecisionTreeTrainer dtTrainer;
    dtTrainer.trainPipeline("data/iris.csv");

    return 0;
}
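
Both trainers above keep the default `preprocessData()`. The sketch below shows the opposite case, where a subclass replaces that optional step while the pipeline order stays the same. It is an illustration under assumptions: `TextClassifierTrainer`, its messages, the path `data/reviews.csv`, the `std::ostream&` constructor parameter, and the `textDemo` helper are all hypothetical and not part of the example above.

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Trimmed copy of ModelTrainer that writes to a caller-supplied stream
// (an assumption of this sketch, so the output can be captured).
class ModelTrainer {
public:
    explicit ModelTrainer(std::ostream& out = std::cout) : out_(out) {}
    virtual ~ModelTrainer() = default;

    // Template method: fixed sequence of five steps.
    void trainPipeline(const std::string& dataPath) {
        loadData(dataPath);
        preprocessData();
        trainModel();
        evaluateModel();
        saveModel();
    }

protected:
    void loadData(const std::string& path) {
        out_ << "[Common] Loading dataset from " << path << "\n";
    }
    // Optional step with a default implementation.
    virtual void preprocessData() {
        out_ << "[Common] Splitting into train/test and normalizing\n";
    }
    virtual void trainModel() = 0;
    virtual void evaluateModel() = 0;
    virtual void saveModel() {
        out_ << "[Common] Saving model to disk as default format\n";
    }

    std::ostream& out_;  // where each step writes its message
};

// Hypothetical trainer that replaces the default preprocessing step.
class TextClassifierTrainer : public ModelTrainer {
public:
    using ModelTrainer::ModelTrainer;  // reuse the base-class constructor

protected:
    void preprocessData() override {
        out_ << "[Text] Tokenizing documents and building vocabulary\n";
    }
    void trainModel() override {
        out_ << "[Text] Training text classifier\n";
    }
    void evaluateModel() override {
        out_ << "[Text] Evaluating on held-out documents\n";
    }
    // saveModel() is inherited unchanged.
};

// Usage sketch: run the pipeline against an in-memory stream and return the text.
std::string textDemo() {
    std::ostringstream out;
    TextClassifierTrainer trainer(out);
    trainer.trainPipeline("data/reviews.csv");
    return out.str();
}
```

Only the overridden step changes its message; loading and saving still come from the base class, and the five steps run in the order fixed by `trainPipeline`.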
