본문 바로가기

Algorithm

[C++][백준] 문자열 집합 - TRIE(14425)

문제

총 N개의 문자열로 이루어진 집합 S가 주어진다.

입력으로 주어지는 M개의 문자열 중에서 집합 S에 포함되어 있는 것이 총 몇 개인지 구하는 프로그램을 작성하시오.

 

입력

첫째 줄에 문자열의 개수 N과 M (1 ≤ N ≤ 10,000, 1 ≤ M ≤ 10,000)이 주어진다.

다음 N개의 줄에는 집합 S에 포함되어 있는 문자열들이 주어진다.

다음 M개의 줄에는 검사해야 하는 문자열들이 주어진다.

입력으로 주어지는 문자열은 알파벳 소문자로만 이루어져 있으며, 길이는 500을 넘지 않는다. 집합 S에 같은 문자열이 여러 번 주어지는 경우는 없다.

 

출력

첫째 줄에 M개의 문자열 중에 총 몇 개가 집합 S에 포함되어 있는지 출력한다.

 

풀이

TRIE 자료구조를 사용하여 문제를 풀었습니다.

LIST, 배열 등의 자료구조를 사용하는 경우 집합 S에 검사할 문자열이 포함되어있는지 확인 하기 위해 O(N), M번 검사해야 하기 때문에 O(N) * O(M) 시간복잡도를 갖습니다.

그에 비해 TRIE 자료구조를 사용하는 경우 상수 시간으로 집합 S에 검사할 문자열이 포함되어있는지 확인할 수 있기 때문에 O(M) 시간복잡도를 갖습니다.

 

TRIE 구조체를 만들어줍니다.

struct TRIE {
	bool finished;
	TRIE* node[26];

	TRIE() {
		finished = false;
		for (int i = 0; i < 26; i++) {
			node[i] = NULL;
		}
	}
}

 

TRIE 객체를 생성하고, 해당 객체에 접근할 수 있는 포인터를 만들어줍니다.

TRIE* root = new TRIE();

N개의 문자열을 TRIE에 삽입합니다.

for (int i = 0; i < N; i++) {
    string input;
    cin >> input;

    root->insert(input.c_str());
}
   

root 포인터가 가리키는 객체의 insert 함수를 호출합니다.

 

insert 함수를 아래와 같습니다.

void insert(const char* str) {
    if (*str == NULL) {
     finished = true;
     return;
    }
    int next = *str - 'a';
    if (node[next] == NULL) {
     node[next] = new TRIE();
    }
    node[next]->insert(str + 1);
}

삽입할 문자가 NULL이라면 문자열이 끝났다는 의미로 finished 변수를 true로 셋팅하고

그렇지 않은 경우 삽입할 문자가 이미 다음 노드로 가지고 있는지 없는지를 판별하여 계속 insert 를 진행합니다.

 

M번의 Find를 수행합니다.(TRIE에서 검사할 문자열이 있는지 확인합니다.)

for (int i = 0; i < M; i++) {
    string input;
    cin >> input;

    if (root->find(input.c_str())) {
     answer++;
    }
}

 

find 함수는 아래와 같습니다.

bool find(const char* str) {
    if (*str == NULL) {
        if (finished == true) return true;
        else return false;
    }

    int next = *str - 'a';
    if (node[next] == NULL) return false;
    return node[next]->find(str + 1);
}

탐색할 문자열이 끝난 경우 해당 노드에서 finished 변수가 true인지 체크합니다.

끝나지 않은 경우 다음 문자열이 존재하는지 존재하지 않는지에 따라 계속 find를 진행합니다.

 

전체 코드는 아래와 같습니다.

#include <iostream>
#include <string>
using namespace std;

int N, M;

struct TRIE {
	bool finished;
	TRIE* node[26];

	TRIE() {
		finished = false;
		for (int i = 0; i < 26; i++) {
			node[i] = NULL;
		}
	}

	void insert(const char* str) {
		if (*str == NULL) {
			finished = true;
			return;
		}
		int next = *str - 'a';
		if (node[next] == NULL) {
			node[next] = new TRIE();
		}
		node[next]->insert(str + 1);
	}

	bool find(const char* str) {
		if (*str == NULL) {
			if (finished == true) return true;
			else return false;
		}

		int next = *str - 'a';
		if (node[next] == NULL) return false;
		return node[next]->find(str + 1);
	}
};

int main() {
	int answer = 0;
	TRIE* root = new TRIE();
	cin >> N >> M;

	for (int i = 0; i < N; i++) {
		string input;
		cin >> input;

		root->insert(input.c_str());
	}

	for (int i = 0; i < M; i++) {
		string input;
		cin >> input;

		if (root->find(input.c_str())) {
			answer++;
		}
	}

	cout << answer;

	return 0;
}