from __future__ import print_function
import argparse
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.utils.data.distributed
import horovod.torch as hvd
import math
import random
import numpy as np

# Training settings
parser = argparse.ArgumentParser(description='PyTorch GPT Text Generation Example')

# Every command-line option, listed in the order it appears in --help.
# Each entry is (flag, keyword arguments forwarded to parser.add_argument).
_CLI_OPTIONS = (
    ('--batch-size', dict(type=int, default=32, metavar='N',
                          help='input batch size for training (default: 32)')),
    ('--test-batch-size', dict(type=int, default=16, metavar='N',
                               help='input batch size for testing (default: 16)')),
    ('--epochs', dict(type=int, default=3, metavar='N',
                      help='number of epochs to train (default: 3)')),
    ('--lr', dict(type=float, default=0.0001, metavar='LR',
                  help='learning rate (default: 0.0001)')),
    ('--no-cuda', dict(action='store_true', default=False,
                       help='disables CUDA training')),
    ('--seed', dict(type=int, default=42, metavar='S',
                    help='random seed (default: 42)')),
    ('--log-interval', dict(type=int, default=10, metavar='N',
                            help='how many batches to wait before logging training status')),
    ('--fp16-allreduce', dict(action='store_true', default=False,
                              help='use fp16 compression during allreduce')),
    ('--use-adasum', dict(action='store_true', default=False,
                          help='use adasum algorithm to do reduction')),
    ('--seq-len', dict(type=int, default=128, metavar='N',
                       help='sequence length (default: 128)')),
    ('--vocab-size', dict(type=int, default=5000, metavar='N',
                          help='vocabulary size (default: 5000)')),
)
for _flag, _flag_kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_flag_kwargs)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value are each passed through their own linear layer,
    split into ``n_heads`` heads of size ``d_model // n_heads``, attended
    per head with a softmax over the key axis, then merged and passed
    through an output projection.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
        n_heads: Number of attention heads; must evenly divide ``d_model``.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        # Fail here rather than with an opaque reshape error inside forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                'd_model ({}) must be a positive multiple of n_heads ({})'.format(
                    d_model, n_heads))
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (batch, query_len, d_model).
            key: Tensor of shape (batch, key_len, d_model).
            value: Tensor of shape (batch, key_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, n_heads, query_len, key_len); positions where it
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, query_len, d_model).
        """
        batch_size = query.size(0)
        query_len = query.size(1)

        # Project and split into heads -> (batch, n_heads, length, d_k).
        # The length axis is inferred (-1) per tensor so keys/values may be
        # longer or shorter than the queries.
        Q = self.w_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product scores -> (batch, n_heads, query_len, key_len).
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Use the dtype's most negative finite value: a fixed -1e9 is not
            # representable in float16, and -inf would yield NaN for rows
            # that are masked everywhere.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention = F.softmax(scores, dim=-1)
        context = torch.matmul(attention, V)

        # Merge heads back to (batch, query_len, d_model).
        context = context.transpose(1, 2).contiguous().view(batch_size, query_len, self.d_model)
        output = self.w_o(context)

        return output

class TransformerBlock(nn.Module):
    """One transformer layer: self-attention followed by a two-layer MLP.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input, and the sum is layer-normalised (normalisation after the
    residual addition).

    Args:
        d_model: Feature size of the input and output.
        n_heads: Number of attention heads passed to ``MultiHeadAttention``.
        d_ff: Hidden width of the feed-forward MLP.
        dropout: Dropout probability applied to both sub-layer outputs.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-modules are created in the same order as before so parameter
        # registration (and therefore initialisation order) is unchanged.
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

        mlp_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        ]
        self.feed_forward = nn.Sequential(*mlp_layers)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the layer to ``x``; ``mask`` is forwarded to the attention module."""
        # Attention sub-layer with residual connection.
        attended = self.attention(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer with residual connection.
        projected = self.feed_forward(hidden)
        return self.norm2(hidden + self.dropout(projected))

class GPTModel(nn.Module):
    """Decoder-only transformer language model.

    Learned token and position embeddings are summed, passed through a
    stack of ``TransformerBlock`` layers under a causal (lower-triangular)
    mask, layer-normalised and projected to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / hidden size.
        n_heads: Attention heads per block.
        n_layers: Number of transformer blocks.
        seq_len: Maximum input length (size of the position embedding table).
        d_ff: Hidden width of each block's feed-forward MLP.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6, seq_len=128, d_ff=1024):
        super(GPTModel, self).__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.seq_len = seq_len

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(seq_len, d_model)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, targets=None):
        """Compute logits and, when ``targets`` is given, the next-token loss.

        Args:
            input_ids: Integer tensor of shape (batch, length) with
                ``length <= seq_len``.
            targets: Optional integer tensor shaped like ``input_ids``.
                Position ``t`` of the logits is scored against
                ``targets[:, t + 1]``; entries equal to -1 are ignored.

        Returns:
            ``logits`` of shape (batch, length, vocab_size) when ``targets``
            is None, otherwise the tuple ``(logits, loss)``.

        Raises:
            ValueError: If the input is longer than ``seq_len``.
        """
        batch_size, seq_len = input_ids.size()

        # Longer inputs would index past the position-embedding table, which
        # surfaces as an opaque IndexError on CPU or a device-side assert on
        # CUDA; reject them up front with a clear message instead.
        if seq_len > self.seq_len:
            raise ValueError(
                'input sequence length {} exceeds the maximum of {} supported '
                'by the position embedding'.format(seq_len, self.seq_len))

        # Token and position embeddings
        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)
        token_emb = self.token_embedding(input_ids)
        pos_emb = self.position_embedding(positions)

        x = self.dropout(token_emb + pos_emb)

        # Causal mask of shape (1, 1, length, length): position i may only
        # attend to positions <= i.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).unsqueeze(0).unsqueeze(0)

        # Transformer blocks
        for block in self.transformer_blocks:
            x = block(x, mask)

        x = self.ln_f(x)
        logits = self.head(x)

        if targets is not None:
            # Drop the last logit and the first label so that each position
            # is trained to predict the following token.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = targets[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                                   shift_labels.view(-1), ignore_index=-1)
            return logits, loss

        return logits

class TextDataset(Dataset):
    """Synthetic token sequences with a few learnable regularities.

    Every sequence starts with the tokens 1, 2, 3, carries the marker token
    100 at each later position that is a multiple of 10, and is filled with
    uniformly random tokens from ``[1, vocab_size - 1]`` elsewhere. The
    fixed tokens are clamped to ``vocab_size - 1`` so that every generated
    id is valid for an embedding table of size ``vocab_size``.

    Random tokens come from the ``random`` module, so seeding it makes the
    dataset reproducible.

    Args:
        vocab_size: Size of the vocabulary; must be at least 2.
        seq_len: Length of every generated sequence.
        num_samples: Number of sequences to generate.

    Raises:
        ValueError: If ``vocab_size`` is smaller than 2.
    """

    def __init__(self, vocab_size, seq_len, num_samples=10000):
        if vocab_size < 2:
            raise ValueError(
                'vocab_size must be at least 2, got {}'.format(vocab_size))
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples

        # Largest id that is valid for a vocabulary of this size.
        max_token = vocab_size - 1

        # Generate synthetic text data
        self.data = []
        for _ in range(num_samples):
            sequence = []
            for i in range(seq_len):
                if i < 3:  # Fixed start pattern 1, 2, 3
                    token = min(i + 1, max_token)
                elif i % 10 == 0:  # Periodic marker token
                    token = min(100, max_token)
                else:
                    token = random.randint(1, max_token)
                sequence.append(token)
            self.data.append(sequence)

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(input, target)`` for sample ``idx``.

        Both elements are the same int64 tensor of length ``seq_len``; the
        shift needed for next-token prediction is left to the consumer.
        """
        sequence = torch.tensor(self.data[idx], dtype=torch.long)
        return sequence, sequence

def train(epoch):
    """Run one training epoch over this worker's share of the training data.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``,
    ``train_sampler`` and ``args`` that the ``__main__`` section creates.

    Args:
        epoch: Epoch number, passed to the sampler and shown in log lines.
    """
    model.train()
    # Tell the distributed sampler which epoch this is so it can vary its
    # shuffling between epochs.
    train_sampler.set_epoch(epoch)

    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        _, loss = model(data, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            # loss.item() copies the value to the host, so it is only called
            # on logging steps rather than on every batch.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_sampler),
                100. * batch_idx / len(train_loader), loss.item()))

def metric_average(val, name):
    """Combine a scalar metric across Horovod workers.

    Wraps ``val`` in a tensor, reduces it with ``hvd.allreduce`` (using
    Horovod's default reduction op) under the given ``name``, and returns
    the reduced value as a Python number.
    """
    return hvd.allreduce(torch.tensor(val), name=name).item()

def test():
    """Evaluate the model on this worker's test shard and report the mean loss.

    Uses the module-level ``model``, ``test_loader`` and ``args``. Every
    worker computes its local per-sample mean loss and contributes it to the
    cross-worker reduction; only rank 0 prints the result and a generated
    sample.
    """
    model.eval()
    loss_sum = 0.
    sample_count = 0

    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()

            _, loss = model(data, target)
            # The model returns a loss already averaged over the batch, so
            # weight it by the batch size; summing the raw means and dividing
            # by the sample count would under-report the loss by roughly a
            # factor of the batch size.
            batch_samples = data.size(0)
            loss_sum += loss.item() * batch_samples
            sample_count += batch_samples

    # Guard against an empty shard so every worker still reaches the
    # collective call below.
    test_loss = loss_sum / sample_count if sample_count else 0.
    test_loss = metric_average(test_loss, 'avg_test_loss')

    if hvd.rank() == 0:
        print(f'\nTest set: Average loss: {test_loss:.4f}\n')

        # Generate some text samples
        generate_text_sample()

def generate_text_sample():
    """Sample a continuation of the prompt [1, 2, 3] and print it.

    Draws up to 20 further tokens, one at a time, from the softmax of the
    model's logits at the last position, stopping early once the sequence
    reaches ``args.seq_len`` tokens. Uses the module-level ``model`` and
    ``args``.
    """
    model.eval()
    with torch.no_grad():
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        if args.cuda:
            prompt = prompt.cuda()

        tokens = prompt.clone()

        remaining = 20
        while remaining > 0 and tokens.size(1) < args.seq_len:
            remaining -= 1

            # Distribution over the vocabulary for the token after the last one.
            last_step_logits = model(tokens)[0, -1, :]
            probabilities = F.softmax(last_step_logits, dim=-1)
            sampled = torch.multinomial(probabilities, 1)
            tokens = torch.cat([tokens, sampled.unsqueeze(0)], dim=1)

        print(f"Generated sequence: {tokens[0].cpu().numpy().tolist()}")

# Script entry point: parse options, initialise Horovod, build the data
# pipeline, model and distributed optimizer, then alternate training and
# evaluation for the requested number of epochs. The names bound here
# (args, model, optimizer, train_loader, train_sampler, test_loader,
# test_sampler) are read as globals by train(), test() and
# generate_text_sample().
if __name__ == '__main__':
    args = parser.parse_args()
    # Use CUDA only when it was not disabled on the command line and a device is available.
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Horovod: initialize library.
    hvd.init()
    # Seed every RNG with the same value on every worker. TextDataset draws
    # from `random`, so all workers build identical datasets.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if args.cuda:
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    # DataLoader options: one worker process with pinned memory on CUDA,
    # library defaults otherwise.
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    # When worker processes are used and this torch.multiprocessing build
    # offers the 'forkserver' start method, use it for the DataLoader workers.
    if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
            mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
        kwargs['multiprocessing_context'] = 'forkserver'

    # Create datasets; each DistributedSampler is given this worker's rank
    # and the total number of workers.
    train_dataset = TextDataset(args.vocab_size, args.seq_len, num_samples=5000)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)

    test_dataset = TextDataset(args.vocab_size, args.seq_len, num_samples=1000)
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.test_batch_size, sampler=test_sampler, **kwargs)

    # Create model (remaining hyper-parameters use the GPTModel defaults).
    model = GPTModel(vocab_size=args.vocab_size, seq_len=args.seq_len)
    
    # Learning rate scaling: multiply the base rate by the number of workers,
    # except under Adasum, where it is left unscaled.
    lr_scaler = hvd.size() if not args.use_adasum else 1

    if args.cuda:
        model.cuda()
        # Adasum on a CUDA run with an NCCL-enabled Horovod build scales by
        # the number of workers on this node instead.
        if args.use_adasum and hvd.nccl_built():
            lr_scaler = hvd.local_size()

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.lr * lr_scaler, weight_decay=0.01)

    # Horovod: broadcast parameters & optimizer state.
    # Rank 0's values are sent to all other workers so they start identically.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Horovod: compression algorithm.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(optimizer,
                                         named_parameters=model.named_parameters(),
                                         compression=compression,
                                         op=hvd.Adasum if args.use_adasum else hvd.Average)

    # Training loop: epochs are numbered from 1; evaluate after each one.
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test()